diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_multiprocess/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_multiprocess/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ae5fd3279dfe756f292e77484cf0788c428c5a3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/_multiprocess/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/urllib3-2.0.7.dist-info/licenses/LICENSE.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/urllib3-2.0.7.dist-info/licenses/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..e6183d0276b26c5b87aecccf8d0d5bcd7b1148d4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/urllib3-2.0.7.dist-info/licenses/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2008-2020 Andrey Petrov and contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dac6c72e803d023fa21a16747f182eb5b7ad3a2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__main__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__main__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50cd37cdeb6a591915c84032ff50759e781f65d6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__main__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__pip-runner__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__pip-runner__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d720024841440bb123df28801c1d9de1ac35b55 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/__pycache__/__pip-runner__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5b7f87f973b36af0ee6fbfb76ce38420f5f9d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/__init__.py @@ -0,0 +1,18 @@ +from typing import List, Optional + +from pip._internal.utils import _log + +# init_logging() must be called before any call to logging.getLogger() +# which happens at import of most modules. +_log.init_logging() + + +def main(args: Optional[List[str]] = None) -> int: + """This is preserved for old console scripts that may still be referencing + it. + + For additional details, see https://github.com/pypa/pip/issues/7498. + """ + from pip._internal.utils.entrypoints import _wrapper + + return _wrapper(args) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/build_env.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/build_env.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d1aca0d6ae65972bd6642b27550f339e9d711f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/build_env.py @@ -0,0 +1,322 @@ +"""Build Environment used for isolation during sdist building +""" + +import logging +import os +import pathlib +import site +import sys +import textwrap +from collections import OrderedDict +from types import TracebackType +from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple, Type, Union + +from pip._vendor.packaging.version import Version + +from pip import __file__ as pip_location +from pip._internal.cli.spinners import open_spinner +from pip._internal.locations import get_platlib, get_purelib, get_scheme +from pip._internal.metadata import get_default_environment, get_environment +from pip._internal.utils.logging import VERBOSE +from pip._internal.utils.packaging import get_requirement +from pip._internal.utils.subprocess import call_subprocess +from pip._internal.utils.temp_dir import TempDirectory, tempdir_kinds + +if TYPE_CHECKING: + from pip._internal.index.package_finder import PackageFinder + +logger = logging.getLogger(__name__) + + +def _dedup(a: str, b: str) -> Union[Tuple[str], Tuple[str, str]]: + return (a, b) if a != b else (a,) + + +class _Prefix: + def __init__(self, path: str) -> None: + self.path = path + self.setup = False + scheme = get_scheme("", prefix=path) + self.bin_dir = scheme.scripts + self.lib_dirs = _dedup(scheme.purelib, scheme.platlib) + + +def get_runnable_pip() -> str: + """Get a file to pass to a Python executable, to run the currently-running pip. + + This is used to run a pip subprocess, for installing requirements into the build + environment. + """ + source = pathlib.Path(pip_location).resolve().parent + + if not source.is_dir(): + # This would happen if someone is using pip from inside a zip file. In that + # case, we can use that directly. + return str(source) + + return os.fsdecode(source / "__pip-runner__.py") + + +def _get_system_sitepackages() -> Set[str]: + """Get system site packages + + Usually from site.getsitepackages, + but fallback on `get_purelib()/get_platlib()` if unavailable + (e.g. in a virtualenv created by virtualenv<20) + + Returns normalized set of strings. + """ + if hasattr(site, "getsitepackages"): + system_sites = site.getsitepackages() + else: + # virtualenv < 20 overwrites site.py without getsitepackages + # fallback on get_purelib/get_platlib. + # this is known to miss things, but shouldn't in the cases + # where getsitepackages() has been removed (inside a virtualenv) + system_sites = [get_purelib(), get_platlib()] + return {os.path.normcase(path) for path in system_sites} + + +class BuildEnvironment: + """Creates and manages an isolated environment to install build deps""" + + def __init__(self) -> None: + temp_dir = TempDirectory(kind=tempdir_kinds.BUILD_ENV, globally_managed=True) + + self._prefixes = OrderedDict( + (name, _Prefix(os.path.join(temp_dir.path, name))) + for name in ("normal", "overlay") + ) + + self._bin_dirs: List[str] = [] + self._lib_dirs: List[str] = [] + for prefix in reversed(list(self._prefixes.values())): + self._bin_dirs.append(prefix.bin_dir) + self._lib_dirs.extend(prefix.lib_dirs) + + # Customize site to: + # - ensure .pth files are honored + # - prevent access to system site packages + system_sites = _get_system_sitepackages() + + self._site_dir = os.path.join(temp_dir.path, "site") + if not os.path.exists(self._site_dir): + os.mkdir(self._site_dir) + with open( + os.path.join(self._site_dir, "sitecustomize.py"), "w", encoding="utf-8" + ) as fp: + fp.write( + textwrap.dedent( + """ + import os, site, sys + + # First, drop system-sites related paths. + original_sys_path = sys.path[:] + known_paths = set() + for path in {system_sites!r}: + site.addsitedir(path, known_paths=known_paths) + system_paths = set( + os.path.normcase(path) + for path in sys.path[len(original_sys_path):] + ) + original_sys_path = [ + path for path in original_sys_path + if os.path.normcase(path) not in system_paths + ] + sys.path = original_sys_path + + # Second, add lib directories. + # ensuring .pth file are processed. + for path in {lib_dirs!r}: + assert not path in sys.path + site.addsitedir(path) + """ + ).format(system_sites=system_sites, lib_dirs=self._lib_dirs) + ) + + def __enter__(self) -> None: + self._save_env = { + name: os.environ.get(name, None) + for name in ("PATH", "PYTHONNOUSERSITE", "PYTHONPATH") + } + + path = self._bin_dirs[:] + old_path = self._save_env["PATH"] + if old_path: + path.extend(old_path.split(os.pathsep)) + + pythonpath = [self._site_dir] + + os.environ.update( + { + "PATH": os.pathsep.join(path), + "PYTHONNOUSERSITE": "1", + "PYTHONPATH": os.pathsep.join(pythonpath), + } + ) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + for varname, old_value in self._save_env.items(): + if old_value is None: + os.environ.pop(varname, None) + else: + os.environ[varname] = old_value + + def check_requirements( + self, reqs: Iterable[str] + ) -> Tuple[Set[Tuple[str, str]], Set[str]]: + """Return 2 sets: + - conflicting requirements: set of (installed, wanted) reqs tuples + - missing requirements: set of reqs + """ + missing = set() + conflicting = set() + if reqs: + env = ( + get_environment(self._lib_dirs) + if hasattr(self, "_lib_dirs") + else get_default_environment() + ) + for req_str in reqs: + req = get_requirement(req_str) + # We're explicitly evaluating with an empty extra value, since build + # environments are not provided any mechanism to select specific extras. + if req.marker is not None and not req.marker.evaluate({"extra": ""}): + continue + dist = env.get_distribution(req.name) + if not dist: + missing.add(req_str) + continue + if isinstance(dist.version, Version): + installed_req_str = f"{req.name}=={dist.version}" + else: + installed_req_str = f"{req.name}==={dist.version}" + if not req.specifier.contains(dist.version, prereleases=True): + conflicting.add((installed_req_str, req_str)) + # FIXME: Consider direct URL? + return conflicting, missing + + def install_requirements( + self, + finder: "PackageFinder", + requirements: Iterable[str], + prefix_as_string: str, + *, + kind: str, + ) -> None: + prefix = self._prefixes[prefix_as_string] + assert not prefix.setup + prefix.setup = True + if not requirements: + return + self._install_requirements( + get_runnable_pip(), + finder, + requirements, + prefix, + kind=kind, + ) + + @staticmethod + def _install_requirements( + pip_runnable: str, + finder: "PackageFinder", + requirements: Iterable[str], + prefix: _Prefix, + *, + kind: str, + ) -> None: + args: List[str] = [ + sys.executable, + pip_runnable, + "install", + "--ignore-installed", + "--no-user", + "--prefix", + prefix.path, + "--no-warn-script-location", + "--disable-pip-version-check", + # The prefix specified two lines above, thus + # target from config file or env var should be ignored + "--target", + "", + ] + if logger.getEffectiveLevel() <= logging.DEBUG: + args.append("-vv") + elif logger.getEffectiveLevel() <= VERBOSE: + args.append("-v") + for format_control in ("no_binary", "only_binary"): + formats = getattr(finder.format_control, format_control) + args.extend( + ( + "--" + format_control.replace("_", "-"), + ",".join(sorted(formats or {":none:"})), + ) + ) + + index_urls = finder.index_urls + if index_urls: + args.extend(["-i", index_urls[0]]) + for extra_index in index_urls[1:]: + args.extend(["--extra-index-url", extra_index]) + else: + args.append("--no-index") + for link in finder.find_links: + args.extend(["--find-links", link]) + + if finder.proxy: + args.extend(["--proxy", finder.proxy]) + for host in finder.trusted_hosts: + args.extend(["--trusted-host", host]) + if finder.custom_cert: + args.extend(["--cert", finder.custom_cert]) + if finder.client_cert: + args.extend(["--client-cert", finder.client_cert]) + if finder.allow_all_prereleases: + args.append("--pre") + if finder.prefer_binary: + args.append("--prefer-binary") + args.append("--") + args.extend(requirements) + with open_spinner(f"Installing {kind}") as spinner: + call_subprocess( + args, + command_desc=f"pip subprocess to install {kind}", + spinner=spinner, + ) + + +class NoOpBuildEnvironment(BuildEnvironment): + """A no-op drop-in replacement for BuildEnvironment""" + + def __init__(self) -> None: + pass + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + pass + + def cleanup(self) -> None: + pass + + def install_requirements( + self, + finder: "PackageFinder", + requirements: Iterable[str], + prefix_as_string: str, + *, + kind: str, + ) -> None: + raise NotImplementedError() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/cache.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4512672dbad2766464405adb0193f12500e767 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/cache.py @@ -0,0 +1,290 @@ +"""Cache Management +""" + +import hashlib +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Optional + +from pip._vendor.packaging.tags import Tag, interpreter_name, interpreter_version +from pip._vendor.packaging.utils import canonicalize_name + +from pip._internal.exceptions import InvalidWheelFilename +from pip._internal.models.direct_url import DirectUrl +from pip._internal.models.link import Link +from pip._internal.models.wheel import Wheel +from pip._internal.utils.temp_dir import TempDirectory, tempdir_kinds +from pip._internal.utils.urls import path_to_url + +logger = logging.getLogger(__name__) + +ORIGIN_JSON_NAME = "origin.json" + + +def _hash_dict(d: Dict[str, str]) -> str: + """Return a stable sha224 of a dictionary.""" + s = json.dumps(d, sort_keys=True, separators=(",", ":"), ensure_ascii=True) + return hashlib.sha224(s.encode("ascii")).hexdigest() + + +class Cache: + """An abstract class - provides cache directories for data from links + + :param cache_dir: The root of the cache. + """ + + def __init__(self, cache_dir: str) -> None: + super().__init__() + assert not cache_dir or os.path.isabs(cache_dir) + self.cache_dir = cache_dir or None + + def _get_cache_path_parts(self, link: Link) -> List[str]: + """Get parts of part that must be os.path.joined with cache_dir""" + + # We want to generate an url to use as our cache key, we don't want to + # just reuse the URL because it might have other items in the fragment + # and we don't care about those. + key_parts = {"url": link.url_without_fragment} + if link.hash_name is not None and link.hash is not None: + key_parts[link.hash_name] = link.hash + if link.subdirectory_fragment: + key_parts["subdirectory"] = link.subdirectory_fragment + + # Include interpreter name, major and minor version in cache key + # to cope with ill-behaved sdists that build a different wheel + # depending on the python version their setup.py is being run on, + # and don't encode the difference in compatibility tags. + # https://github.com/pypa/pip/issues/7296 + key_parts["interpreter_name"] = interpreter_name() + key_parts["interpreter_version"] = interpreter_version() + + # Encode our key url with sha224, we'll use this because it has similar + # security properties to sha256, but with a shorter total output (and + # thus less secure). However the differences don't make a lot of + # difference for our use case here. + hashed = _hash_dict(key_parts) + + # We want to nest the directories some to prevent having a ton of top + # level directories where we might run out of sub directories on some + # FS. + parts = [hashed[:2], hashed[2:4], hashed[4:6], hashed[6:]] + + return parts + + def _get_candidates(self, link: Link, canonical_package_name: str) -> List[Any]: + can_not_cache = not self.cache_dir or not canonical_package_name or not link + if can_not_cache: + return [] + + path = self.get_path_for_link(link) + if os.path.isdir(path): + return [(candidate, path) for candidate in os.listdir(path)] + return [] + + def get_path_for_link(self, link: Link) -> str: + """Return a directory to store cached items in for link.""" + raise NotImplementedError() + + def get( + self, + link: Link, + package_name: Optional[str], + supported_tags: List[Tag], + ) -> Link: + """Returns a link to a cached item if it exists, otherwise returns the + passed link. + """ + raise NotImplementedError() + + +class SimpleWheelCache(Cache): + """A cache of wheels for future installs.""" + + def __init__(self, cache_dir: str) -> None: + super().__init__(cache_dir) + + def get_path_for_link(self, link: Link) -> str: + """Return a directory to store cached wheels for link + + Because there are M wheels for any one sdist, we provide a directory + to cache them in, and then consult that directory when looking up + cache hits. + + We only insert things into the cache if they have plausible version + numbers, so that we don't contaminate the cache with things that were + not unique. E.g. ./package might have dozens of installs done for it + and build a version of 0.0...and if we built and cached a wheel, we'd + end up using the same wheel even if the source has been edited. + + :param link: The link of the sdist for which this will cache wheels. + """ + parts = self._get_cache_path_parts(link) + assert self.cache_dir + # Store wheels within the root cache_dir + return os.path.join(self.cache_dir, "wheels", *parts) + + def get( + self, + link: Link, + package_name: Optional[str], + supported_tags: List[Tag], + ) -> Link: + candidates = [] + + if not package_name: + return link + + canonical_package_name = canonicalize_name(package_name) + for wheel_name, wheel_dir in self._get_candidates(link, canonical_package_name): + try: + wheel = Wheel(wheel_name) + except InvalidWheelFilename: + continue + if canonicalize_name(wheel.name) != canonical_package_name: + logger.debug( + "Ignoring cached wheel %s for %s as it " + "does not match the expected distribution name %s.", + wheel_name, + link, + package_name, + ) + continue + if not wheel.supported(supported_tags): + # Built for a different python/arch/etc + continue + candidates.append( + ( + wheel.support_index_min(supported_tags), + wheel_name, + wheel_dir, + ) + ) + + if not candidates: + return link + + _, wheel_name, wheel_dir = min(candidates) + return Link(path_to_url(os.path.join(wheel_dir, wheel_name))) + + +class EphemWheelCache(SimpleWheelCache): + """A SimpleWheelCache that creates it's own temporary cache directory""" + + def __init__(self) -> None: + self._temp_dir = TempDirectory( + kind=tempdir_kinds.EPHEM_WHEEL_CACHE, + globally_managed=True, + ) + + super().__init__(self._temp_dir.path) + + +class CacheEntry: + def __init__( + self, + link: Link, + persistent: bool, + ): + self.link = link + self.persistent = persistent + self.origin: Optional[DirectUrl] = None + origin_direct_url_path = Path(self.link.file_path).parent / ORIGIN_JSON_NAME + if origin_direct_url_path.exists(): + try: + self.origin = DirectUrl.from_json( + origin_direct_url_path.read_text(encoding="utf-8") + ) + except Exception as e: + logger.warning( + "Ignoring invalid cache entry origin file %s for %s (%s)", + origin_direct_url_path, + link.filename, + e, + ) + + +class WheelCache(Cache): + """Wraps EphemWheelCache and SimpleWheelCache into a single Cache + + This Cache allows for gracefully degradation, using the ephem wheel cache + when a certain link is not found in the simple wheel cache first. + """ + + def __init__(self, cache_dir: str) -> None: + super().__init__(cache_dir) + self._wheel_cache = SimpleWheelCache(cache_dir) + self._ephem_cache = EphemWheelCache() + + def get_path_for_link(self, link: Link) -> str: + return self._wheel_cache.get_path_for_link(link) + + def get_ephem_path_for_link(self, link: Link) -> str: + return self._ephem_cache.get_path_for_link(link) + + def get( + self, + link: Link, + package_name: Optional[str], + supported_tags: List[Tag], + ) -> Link: + cache_entry = self.get_cache_entry(link, package_name, supported_tags) + if cache_entry is None: + return link + return cache_entry.link + + def get_cache_entry( + self, + link: Link, + package_name: Optional[str], + supported_tags: List[Tag], + ) -> Optional[CacheEntry]: + """Returns a CacheEntry with a link to a cached item if it exists or + None. The cache entry indicates if the item was found in the persistent + or ephemeral cache. + """ + retval = self._wheel_cache.get( + link=link, + package_name=package_name, + supported_tags=supported_tags, + ) + if retval is not link: + return CacheEntry(retval, persistent=True) + + retval = self._ephem_cache.get( + link=link, + package_name=package_name, + supported_tags=supported_tags, + ) + if retval is not link: + return CacheEntry(retval, persistent=False) + + return None + + @staticmethod + def record_download_origin(cache_dir: str, download_info: DirectUrl) -> None: + origin_path = Path(cache_dir) / ORIGIN_JSON_NAME + if origin_path.exists(): + try: + origin = DirectUrl.from_json(origin_path.read_text(encoding="utf-8")) + except Exception as e: + logger.warning( + "Could not read origin file %s in cache entry (%s). " + "Will attempt to overwrite it.", + origin_path, + e, + ) + else: + # TODO: use DirectUrl.equivalent when + # https://github.com/pypa/pip/pull/10564 is merged. + if origin.url != download_info.url: + logger.warning( + "Origin URL %s in cache entry %s does not match download URL " + "%s. This is likely a pip bug or a cache corruption issue. " + "Will overwrite it with the new value.", + origin.url, + cache_dir, + download_info.url, + ) + origin_path.write_text(download_info.to_json(), encoding="utf-8") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..858a41014169b8f0eb1b905fa3bb69c753a1bda5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/__init__.py @@ -0,0 +1,132 @@ +""" +Package containing all pip commands +""" + +import importlib +from collections import namedtuple +from typing import Any, Dict, Optional + +from pip._internal.cli.base_command import Command + +CommandInfo = namedtuple("CommandInfo", "module_path, class_name, summary") + +# This dictionary does a bunch of heavy lifting for help output: +# - Enables avoiding additional (costly) imports for presenting `--help`. +# - The ordering matters for help display. +# +# Even though the module path starts with the same "pip._internal.commands" +# prefix, the full path makes testing easier (specifically when modifying +# `commands_dict` in test setup / teardown). +commands_dict: Dict[str, CommandInfo] = { + "install": CommandInfo( + "pip._internal.commands.install", + "InstallCommand", + "Install packages.", + ), + "download": CommandInfo( + "pip._internal.commands.download", + "DownloadCommand", + "Download packages.", + ), + "uninstall": CommandInfo( + "pip._internal.commands.uninstall", + "UninstallCommand", + "Uninstall packages.", + ), + "freeze": CommandInfo( + "pip._internal.commands.freeze", + "FreezeCommand", + "Output installed packages in requirements format.", + ), + "inspect": CommandInfo( + "pip._internal.commands.inspect", + "InspectCommand", + "Inspect the python environment.", + ), + "list": CommandInfo( + "pip._internal.commands.list", + "ListCommand", + "List installed packages.", + ), + "show": CommandInfo( + "pip._internal.commands.show", + "ShowCommand", + "Show information about installed packages.", + ), + "check": CommandInfo( + "pip._internal.commands.check", + "CheckCommand", + "Verify installed packages have compatible dependencies.", + ), + "config": CommandInfo( + "pip._internal.commands.configuration", + "ConfigurationCommand", + "Manage local and global configuration.", + ), + "search": CommandInfo( + "pip._internal.commands.search", + "SearchCommand", + "Search PyPI for packages.", + ), + "cache": CommandInfo( + "pip._internal.commands.cache", + "CacheCommand", + "Inspect and manage pip's wheel cache.", + ), + "index": CommandInfo( + "pip._internal.commands.index", + "IndexCommand", + "Inspect information available from package indexes.", + ), + "wheel": CommandInfo( + "pip._internal.commands.wheel", + "WheelCommand", + "Build wheels from your requirements.", + ), + "hash": CommandInfo( + "pip._internal.commands.hash", + "HashCommand", + "Compute hashes of package archives.", + ), + "completion": CommandInfo( + "pip._internal.commands.completion", + "CompletionCommand", + "A helper command used for command completion.", + ), + "debug": CommandInfo( + "pip._internal.commands.debug", + "DebugCommand", + "Show information useful for debugging.", + ), + "help": CommandInfo( + "pip._internal.commands.help", + "HelpCommand", + "Show help for commands.", + ), +} + + +def create_command(name: str, **kwargs: Any) -> Command: + """ + Create an instance of the Command class with the given name. + """ + module_path, class_name, summary = commands_dict[name] + module = importlib.import_module(module_path) + command_class = getattr(module, class_name) + command = command_class(name=name, summary=summary, **kwargs) + + return command + + +def get_similar_commands(name: str) -> Optional[str]: + """Command name auto-correct.""" + from difflib import get_close_matches + + name = name.lower() + + close_commands = get_close_matches(name, commands_dict.keys()) + + if close_commands: + return close_commands[0] + else: + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/cache.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..ad65641edb262fd85c5065c76d6a81aae1168192 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/cache.py @@ -0,0 +1,228 @@ +import os +import textwrap +from optparse import Values +from typing import Any, List + +from pip._internal.cli.base_command import Command +from pip._internal.cli.status_codes import ERROR, SUCCESS +from pip._internal.exceptions import CommandError, PipError +from pip._internal.utils import filesystem +from pip._internal.utils.logging import getLogger +from pip._internal.utils.misc import format_size + +logger = getLogger(__name__) + + +class CacheCommand(Command): + """ + Inspect and manage pip's wheel cache. + + Subcommands: + + - dir: Show the cache directory. + - info: Show information about the cache. + - list: List filenames of packages stored in the cache. + - remove: Remove one or more package from the cache. + - purge: Remove all items from the cache. + + ```` can be a glob expression or a package name. + """ + + ignore_require_venv = True + usage = """ + %prog dir + %prog info + %prog list [] [--format=[human, abspath]] + %prog remove + %prog purge + """ + + def add_options(self) -> None: + self.cmd_opts.add_option( + "--format", + action="store", + dest="list_format", + default="human", + choices=("human", "abspath"), + help="Select the output format among: human (default) or abspath", + ) + + self.parser.insert_option_group(0, self.cmd_opts) + + def run(self, options: Values, args: List[str]) -> int: + handlers = { + "dir": self.get_cache_dir, + "info": self.get_cache_info, + "list": self.list_cache_items, + "remove": self.remove_cache_items, + "purge": self.purge_cache, + } + + if not options.cache_dir: + logger.error("pip cache commands can not function since cache is disabled.") + return ERROR + + # Determine action + if not args or args[0] not in handlers: + logger.error( + "Need an action (%s) to perform.", + ", ".join(sorted(handlers)), + ) + return ERROR + + action = args[0] + + # Error handling happens here, not in the action-handlers. + try: + handlers[action](options, args[1:]) + except PipError as e: + logger.error(e.args[0]) + return ERROR + + return SUCCESS + + def get_cache_dir(self, options: Values, args: List[Any]) -> None: + if args: + raise CommandError("Too many arguments") + + logger.info(options.cache_dir) + + def get_cache_info(self, options: Values, args: List[Any]) -> None: + if args: + raise CommandError("Too many arguments") + + num_http_files = len(self._find_http_files(options)) + num_packages = len(self._find_wheels(options, "*")) + + http_cache_location = self._cache_dir(options, "http-v2") + old_http_cache_location = self._cache_dir(options, "http") + wheels_cache_location = self._cache_dir(options, "wheels") + http_cache_size = filesystem.format_size( + filesystem.directory_size(http_cache_location) + + filesystem.directory_size(old_http_cache_location) + ) + wheels_cache_size = filesystem.format_directory_size(wheels_cache_location) + + message = ( + textwrap.dedent( + """ + Package index page cache location (pip v23.3+): {http_cache_location} + Package index page cache location (older pips): {old_http_cache_location} + Package index page cache size: {http_cache_size} + Number of HTTP files: {num_http_files} + Locally built wheels location: {wheels_cache_location} + Locally built wheels size: {wheels_cache_size} + Number of locally built wheels: {package_count} + """ # noqa: E501 + ) + .format( + http_cache_location=http_cache_location, + old_http_cache_location=old_http_cache_location, + http_cache_size=http_cache_size, + num_http_files=num_http_files, + wheels_cache_location=wheels_cache_location, + package_count=num_packages, + wheels_cache_size=wheels_cache_size, + ) + .strip() + ) + + logger.info(message) + + def list_cache_items(self, options: Values, args: List[Any]) -> None: + if len(args) > 1: + raise CommandError("Too many arguments") + + if args: + pattern = args[0] + else: + pattern = "*" + + files = self._find_wheels(options, pattern) + if options.list_format == "human": + self.format_for_human(files) + else: + self.format_for_abspath(files) + + def format_for_human(self, files: List[str]) -> None: + if not files: + logger.info("No locally built wheels cached.") + return + + results = [] + for filename in files: + wheel = os.path.basename(filename) + size = filesystem.format_file_size(filename) + results.append(f" - {wheel} ({size})") + logger.info("Cache contents:\n") + logger.info("\n".join(sorted(results))) + + def format_for_abspath(self, files: List[str]) -> None: + if files: + logger.info("\n".join(sorted(files))) + + def remove_cache_items(self, options: Values, args: List[Any]) -> None: + if len(args) > 1: + raise CommandError("Too many arguments") + + if not args: + raise CommandError("Please provide a pattern") + + files = self._find_wheels(options, args[0]) + + no_matching_msg = "No matching packages" + if args[0] == "*": + # Only fetch http files if no specific pattern given + files += self._find_http_files(options) + else: + # Add the pattern to the log message + no_matching_msg += f' for pattern "{args[0]}"' + + if not files: + logger.warning(no_matching_msg) + + bytes_removed = 0 + for filename in files: + bytes_removed += os.stat(filename).st_size + os.unlink(filename) + logger.verbose("Removed %s", filename) + logger.info("Files removed: %s (%s)", len(files), format_size(bytes_removed)) + + def purge_cache(self, options: Values, args: List[Any]) -> None: + if args: + raise CommandError("Too many arguments") + + return self.remove_cache_items(options, ["*"]) + + def _cache_dir(self, options: Values, subdir: str) -> str: + return os.path.join(options.cache_dir, subdir) + + def _find_http_files(self, options: Values) -> List[str]: + old_http_dir = self._cache_dir(options, "http") + new_http_dir = self._cache_dir(options, "http-v2") + return filesystem.find_files(old_http_dir, "*") + filesystem.find_files( + new_http_dir, "*" + ) + + def _find_wheels(self, options: Values, pattern: str) -> List[str]: + wheel_dir = self._cache_dir(options, "wheels") + + # The wheel filename format, as specified in PEP 427, is: + # {distribution}-{version}(-{build})?-{python}-{abi}-{platform}.whl + # + # Additionally, non-alphanumeric values in the distribution are + # normalized to underscores (_), meaning hyphens can never occur + # before `-{version}`. + # + # Given that information: + # - If the pattern we're given contains a hyphen (-), the user is + # providing at least the version. Thus, we can just append `*.whl` + # to match the rest of it. + # - If the pattern we're given doesn't contain a hyphen (-), the + # user is only providing the name. Thus, we append `-*.whl` to + # match the hyphen before the version, followed by anything else. + # + # PEP 427: https://www.python.org/dev/peps/pep-0427/ + pattern = pattern + ("*.whl" if "-" in pattern else "-*.whl") + + return filesystem.find_files(wheel_dir, pattern) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/check.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/check.py new file mode 100644 index 0000000000000000000000000000000000000000..f54a16dc0a15892d935a5e79679bab3510c8b0d5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/check.py @@ -0,0 +1,67 @@ +import logging +from optparse import Values +from typing import List + +from pip._internal.cli.base_command import Command +from pip._internal.cli.status_codes import ERROR, SUCCESS +from pip._internal.metadata import get_default_environment +from pip._internal.operations.check import ( + check_package_set, + check_unsupported, + create_package_set_from_installed, +) +from pip._internal.utils.compatibility_tags import get_supported +from pip._internal.utils.misc import write_output + +logger = logging.getLogger(__name__) + + +class CheckCommand(Command): + """Verify installed packages have compatible dependencies.""" + + ignore_require_venv = True + usage = """ + %prog [options]""" + + def run(self, options: Values, args: List[str]) -> int: + package_set, parsing_probs = create_package_set_from_installed() + missing, conflicting = check_package_set(package_set) + unsupported = list( + check_unsupported( + get_default_environment().iter_installed_distributions(), + get_supported(), + ) + ) + + for project_name in missing: + version = package_set[project_name].version + for dependency in missing[project_name]: + write_output( + "%s %s requires %s, which is not installed.", + project_name, + version, + dependency[0], + ) + + for project_name in conflicting: + version = package_set[project_name].version + for dep_name, dep_version, req in conflicting[project_name]: + write_output( + "%s %s has requirement %s, but you have %s %s.", + project_name, + version, + req, + dep_name, + dep_version, + ) + for package in unsupported: + write_output( + "%s %s is not supported on this platform", + package.raw_name, + package.version, + ) + if missing or conflicting or parsing_probs or unsupported: + return ERROR + else: + write_output("No broken requirements found.") + return SUCCESS diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/completion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/completion.py new file mode 100644 index 0000000000000000000000000000000000000000..9e89e27988368821f6936cd1e94ac9395ca0312d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/completion.py @@ -0,0 +1,130 @@ +import sys +import textwrap +from optparse import Values +from typing import List + +from pip._internal.cli.base_command import Command +from pip._internal.cli.status_codes import SUCCESS +from pip._internal.utils.misc import get_prog + +BASE_COMPLETION = """ +# pip {shell} completion start{script}# pip {shell} completion end +""" + +COMPLETION_SCRIPTS = { + "bash": """ + _pip_completion() + {{ + COMPREPLY=( $( COMP_WORDS="${{COMP_WORDS[*]}}" \\ + COMP_CWORD=$COMP_CWORD \\ + PIP_AUTO_COMPLETE=1 $1 2>/dev/null ) ) + }} + complete -o default -F _pip_completion {prog} + """, + "zsh": """ + #compdef -P pip[0-9.]# + __pip() {{ + compadd $( COMP_WORDS="$words[*]" \\ + COMP_CWORD=$((CURRENT-1)) \\ + PIP_AUTO_COMPLETE=1 $words[1] 2>/dev/null ) + }} + if [[ $zsh_eval_context[-1] == loadautofunc ]]; then + # autoload from fpath, call function directly + __pip "$@" + else + # eval/source/. command, register function for later + compdef __pip -P 'pip[0-9.]#' + fi + """, + "fish": """ + function __fish_complete_pip + set -lx COMP_WORDS (commandline -o) "" + set -lx COMP_CWORD ( \\ + math (contains -i -- (commandline -t) $COMP_WORDS)-1 \\ + ) + set -lx PIP_AUTO_COMPLETE 1 + string split \\ -- (eval $COMP_WORDS[1]) + end + complete -fa "(__fish_complete_pip)" -c {prog} + """, + "powershell": """ + if ((Test-Path Function:\\TabExpansion) -and -not ` + (Test-Path Function:\\_pip_completeBackup)) {{ + Rename-Item Function:\\TabExpansion _pip_completeBackup + }} + function TabExpansion($line, $lastWord) {{ + $lastBlock = [regex]::Split($line, '[|;]')[-1].TrimStart() + if ($lastBlock.StartsWith("{prog} ")) {{ + $Env:COMP_WORDS=$lastBlock + $Env:COMP_CWORD=$lastBlock.Split().Length - 1 + $Env:PIP_AUTO_COMPLETE=1 + (& {prog}).Split() + Remove-Item Env:COMP_WORDS + Remove-Item Env:COMP_CWORD + Remove-Item Env:PIP_AUTO_COMPLETE + }} + elseif (Test-Path Function:\\_pip_completeBackup) {{ + # Fall back on existing tab expansion + _pip_completeBackup $line $lastWord + }} + }} + """, +} + + +class CompletionCommand(Command): + """A helper command to be used for command completion.""" + + ignore_require_venv = True + + def add_options(self) -> None: + self.cmd_opts.add_option( + "--bash", + "-b", + action="store_const", + const="bash", + dest="shell", + help="Emit completion code for bash", + ) + self.cmd_opts.add_option( + "--zsh", + "-z", + action="store_const", + const="zsh", + dest="shell", + help="Emit completion code for zsh", + ) + self.cmd_opts.add_option( + "--fish", + "-f", + action="store_const", + const="fish", + dest="shell", + help="Emit completion code for fish", + ) + self.cmd_opts.add_option( + "--powershell", + "-p", + action="store_const", + const="powershell", + dest="shell", + help="Emit completion code for powershell", + ) + + self.parser.insert_option_group(0, self.cmd_opts) + + def run(self, options: Values, args: List[str]) -> int: + """Prints the completion code of the given shell""" + shells = COMPLETION_SCRIPTS.keys() + shell_options = ["--" + shell for shell in sorted(shells)] + if options.shell in shells: + script = textwrap.dedent( + COMPLETION_SCRIPTS.get(options.shell, "").format(prog=get_prog()) + ) + print(BASE_COMPLETION.format(script=script, shell=options.shell)) + return SUCCESS + else: + sys.stderr.write( + "ERROR: You must pass {}\n".format(" or ".join(shell_options)) + ) + return SUCCESS diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/debug.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..567ca967e5b64478d17455288b79dd80301b4888 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/debug.py @@ -0,0 +1,201 @@ +import locale +import logging +import os +import sys +from optparse import Values +from types import ModuleType +from typing import Any, Dict, List, Optional + +import pip._vendor +from pip._vendor.certifi import where +from pip._vendor.packaging.version import parse as parse_version + +from pip._internal.cli import cmdoptions +from pip._internal.cli.base_command import Command +from pip._internal.cli.cmdoptions import make_target_python +from pip._internal.cli.status_codes import SUCCESS +from pip._internal.configuration import Configuration +from pip._internal.metadata import get_environment +from pip._internal.utils.compat import open_text_resource +from pip._internal.utils.logging import indent_log +from pip._internal.utils.misc import get_pip_version + +logger = logging.getLogger(__name__) + + +def show_value(name: str, value: Any) -> None: + logger.info("%s: %s", name, value) + + +def show_sys_implementation() -> None: + logger.info("sys.implementation:") + implementation_name = sys.implementation.name + with indent_log(): + show_value("name", implementation_name) + + +def create_vendor_txt_map() -> Dict[str, str]: + with open_text_resource("pip._vendor", "vendor.txt") as f: + # Purge non version specifying lines. + # Also, remove any space prefix or suffixes (including comments). + lines = [ + line.strip().split(" ", 1)[0] for line in f.readlines() if "==" in line + ] + + # Transform into "module" -> version dict. + return dict(line.split("==", 1) for line in lines) + + +def get_module_from_module_name(module_name: str) -> Optional[ModuleType]: + # Module name can be uppercase in vendor.txt for some reason... + module_name = module_name.lower().replace("-", "_") + # PATCH: setuptools is actually only pkg_resources. + if module_name == "setuptools": + module_name = "pkg_resources" + + try: + __import__(f"pip._vendor.{module_name}", globals(), locals(), level=0) + return getattr(pip._vendor, module_name) + except ImportError: + # We allow 'truststore' to fail to import due + # to being unavailable on Python 3.9 and earlier. + if module_name == "truststore" and sys.version_info < (3, 10): + return None + raise + + +def get_vendor_version_from_module(module_name: str) -> Optional[str]: + module = get_module_from_module_name(module_name) + version = getattr(module, "__version__", None) + + if module and not version: + # Try to find version in debundled module info. + assert module.__file__ is not None + env = get_environment([os.path.dirname(module.__file__)]) + dist = env.get_distribution(module_name) + if dist: + version = str(dist.version) + + return version + + +def show_actual_vendor_versions(vendor_txt_versions: Dict[str, str]) -> None: + """Log the actual version and print extra info if there is + a conflict or if the actual version could not be imported. + """ + for module_name, expected_version in vendor_txt_versions.items(): + extra_message = "" + actual_version = get_vendor_version_from_module(module_name) + if not actual_version: + extra_message = ( + " (Unable to locate actual module version, using" + " vendor.txt specified version)" + ) + actual_version = expected_version + elif parse_version(actual_version) != parse_version(expected_version): + extra_message = ( + " (CONFLICT: vendor.txt suggests version should" + f" be {expected_version})" + ) + logger.info("%s==%s%s", module_name, actual_version, extra_message) + + +def show_vendor_versions() -> None: + logger.info("vendored library versions:") + + vendor_txt_versions = create_vendor_txt_map() + with indent_log(): + show_actual_vendor_versions(vendor_txt_versions) + + +def show_tags(options: Values) -> None: + tag_limit = 10 + + target_python = make_target_python(options) + tags = target_python.get_sorted_tags() + + # Display the target options that were explicitly provided. + formatted_target = target_python.format_given() + suffix = "" + if formatted_target: + suffix = f" (target: {formatted_target})" + + msg = f"Compatible tags: {len(tags)}{suffix}" + logger.info(msg) + + if options.verbose < 1 and len(tags) > tag_limit: + tags_limited = True + tags = tags[:tag_limit] + else: + tags_limited = False + + with indent_log(): + for tag in tags: + logger.info(str(tag)) + + if tags_limited: + msg = f"...\n[First {tag_limit} tags shown. Pass --verbose to show all.]" + logger.info(msg) + + +def ca_bundle_info(config: Configuration) -> str: + levels = {key.split(".", 1)[0] for key, _ in config.items()} + if not levels: + return "Not specified" + + levels_that_override_global = ["install", "wheel", "download"] + global_overriding_level = [ + level for level in levels if level in levels_that_override_global + ] + if not global_overriding_level: + return "global" + + if "global" in levels: + levels.remove("global") + return ", ".join(levels) + + +class DebugCommand(Command): + """ + Display debug information. + """ + + usage = """ + %prog """ + ignore_require_venv = True + + def add_options(self) -> None: + cmdoptions.add_target_python_options(self.cmd_opts) + self.parser.insert_option_group(0, self.cmd_opts) + self.parser.config.load() + + def run(self, options: Values, args: List[str]) -> int: + logger.warning( + "This command is only meant for debugging. " + "Do not use this with automation for parsing and getting these " + "details, since the output and options of this command may " + "change without notice." + ) + show_value("pip version", get_pip_version()) + show_value("sys.version", sys.version) + show_value("sys.executable", sys.executable) + show_value("sys.getdefaultencoding", sys.getdefaultencoding()) + show_value("sys.getfilesystemencoding", sys.getfilesystemencoding()) + show_value( + "locale.getpreferredencoding", + locale.getpreferredencoding(), + ) + show_value("sys.platform", sys.platform) + show_sys_implementation() + + show_value("'cert' config value", ca_bundle_info(self.parser.config)) + show_value("REQUESTS_CA_BUNDLE", os.environ.get("REQUESTS_CA_BUNDLE")) + show_value("CURL_CA_BUNDLE", os.environ.get("CURL_CA_BUNDLE")) + show_value("pip._vendor.certifi.where()", where()) + show_value("pip._vendor.DEBUNDLED", pip._vendor.DEBUNDLED) + + show_vendor_versions() + + show_tags(options) + + return SUCCESS diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/freeze.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/freeze.py new file mode 100644 index 0000000000000000000000000000000000000000..885fdfeb83b837b27db08e4abc3df2e7b580dc2b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/freeze.py @@ -0,0 +1,109 @@ +import sys +from optparse import Values +from typing import AbstractSet, List + +from pip._internal.cli import cmdoptions +from pip._internal.cli.base_command import Command +from pip._internal.cli.status_codes import SUCCESS +from pip._internal.operations.freeze import freeze +from pip._internal.utils.compat import stdlib_pkgs + + +def _should_suppress_build_backends() -> bool: + return sys.version_info < (3, 12) + + +def _dev_pkgs() -> AbstractSet[str]: + pkgs = {"pip"} + + if _should_suppress_build_backends(): + pkgs |= {"setuptools", "distribute", "wheel"} + + return pkgs + + +class FreezeCommand(Command): + """ + Output installed packages in requirements format. + + packages are listed in a case-insensitive sorted order. + """ + + ignore_require_venv = True + usage = """ + %prog [options]""" + log_streams = ("ext://sys.stderr", "ext://sys.stderr") + + def add_options(self) -> None: + self.cmd_opts.add_option( + "-r", + "--requirement", + dest="requirements", + action="append", + default=[], + metavar="file", + help=( + "Use the order in the given requirements file and its " + "comments when generating output. This option can be " + "used multiple times." + ), + ) + self.cmd_opts.add_option( + "-l", + "--local", + dest="local", + action="store_true", + default=False, + help=( + "If in a virtualenv that has global access, do not output " + "globally-installed packages." + ), + ) + self.cmd_opts.add_option( + "--user", + dest="user", + action="store_true", + default=False, + help="Only output packages installed in user-site.", + ) + self.cmd_opts.add_option(cmdoptions.list_path()) + self.cmd_opts.add_option( + "--all", + dest="freeze_all", + action="store_true", + help=( + "Do not skip these packages in the output:" + " {}".format(", ".join(_dev_pkgs())) + ), + ) + self.cmd_opts.add_option( + "--exclude-editable", + dest="exclude_editable", + action="store_true", + help="Exclude editable package from output.", + ) + self.cmd_opts.add_option(cmdoptions.list_exclude()) + + self.parser.insert_option_group(0, self.cmd_opts) + + def run(self, options: Values, args: List[str]) -> int: + skip = set(stdlib_pkgs) + if not options.freeze_all: + skip.update(_dev_pkgs()) + + if options.excludes: + skip.update(options.excludes) + + cmdoptions.check_list_path_option(options) + + for line in freeze( + requirement=options.requirements, + local_only=options.local, + user_only=options.user, + paths=options.path, + isolated=options.isolated_mode, + skip=skip, + exclude_editable=options.exclude_editable, + ): + sys.stdout.write(line + "\n") + return SUCCESS diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/hash.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/hash.py new file mode 100644 index 0000000000000000000000000000000000000000..042dac813e74b8187c3754cb9a937c7f7183e331 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/hash.py @@ -0,0 +1,59 @@ +import hashlib +import logging +import sys +from optparse import Values +from typing import List + +from pip._internal.cli.base_command import Command +from pip._internal.cli.status_codes import ERROR, SUCCESS +from pip._internal.utils.hashes import FAVORITE_HASH, STRONG_HASHES +from pip._internal.utils.misc import read_chunks, write_output + +logger = logging.getLogger(__name__) + + +class HashCommand(Command): + """ + Compute a hash of a local package archive. + + These can be used with --hash in a requirements file to do repeatable + installs. + """ + + usage = "%prog [options] ..." + ignore_require_venv = True + + def add_options(self) -> None: + self.cmd_opts.add_option( + "-a", + "--algorithm", + dest="algorithm", + choices=STRONG_HASHES, + action="store", + default=FAVORITE_HASH, + help="The hash algorithm to use: one of {}".format( + ", ".join(STRONG_HASHES) + ), + ) + self.parser.insert_option_group(0, self.cmd_opts) + + def run(self, options: Values, args: List[str]) -> int: + if not args: + self.parser.print_usage(sys.stderr) + return ERROR + + algorithm = options.algorithm + for path in args: + write_output( + "%s:\n--hash=%s:%s", path, algorithm, _hash_of_file(path, algorithm) + ) + return SUCCESS + + +def _hash_of_file(path: str, algorithm: str) -> str: + """Return the hash digest of a file.""" + with open(path, "rb") as archive: + hash = hashlib.new(algorithm) + for chunk in read_chunks(archive): + hash.update(chunk) + return hash.hexdigest() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/index.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/index.py new file mode 100644 index 0000000000000000000000000000000000000000..2e2661bba710fb2cc255ba6ee01ea743fdbf540e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/index.py @@ -0,0 +1,139 @@ +import logging +from optparse import Values +from typing import Any, Iterable, List, Optional + +from pip._vendor.packaging.version import Version + +from pip._internal.cli import cmdoptions +from pip._internal.cli.req_command import IndexGroupCommand +from pip._internal.cli.status_codes import ERROR, SUCCESS +from pip._internal.commands.search import print_dist_installation_info +from pip._internal.exceptions import CommandError, DistributionNotFound, PipError +from pip._internal.index.collector import LinkCollector +from pip._internal.index.package_finder import PackageFinder +from pip._internal.models.selection_prefs import SelectionPreferences +from pip._internal.models.target_python import TargetPython +from pip._internal.network.session import PipSession +from pip._internal.utils.misc import write_output + +logger = logging.getLogger(__name__) + + +class IndexCommand(IndexGroupCommand): + """ + Inspect information available from package indexes. + """ + + ignore_require_venv = True + usage = """ + %prog versions + """ + + def add_options(self) -> None: + cmdoptions.add_target_python_options(self.cmd_opts) + + self.cmd_opts.add_option(cmdoptions.ignore_requires_python()) + self.cmd_opts.add_option(cmdoptions.pre()) + self.cmd_opts.add_option(cmdoptions.no_binary()) + self.cmd_opts.add_option(cmdoptions.only_binary()) + + index_opts = cmdoptions.make_option_group( + cmdoptions.index_group, + self.parser, + ) + + self.parser.insert_option_group(0, index_opts) + self.parser.insert_option_group(0, self.cmd_opts) + + def run(self, options: Values, args: List[str]) -> int: + handlers = { + "versions": self.get_available_package_versions, + } + + logger.warning( + "pip index is currently an experimental command. " + "It may be removed/changed in a future release " + "without prior warning." + ) + + # Determine action + if not args or args[0] not in handlers: + logger.error( + "Need an action (%s) to perform.", + ", ".join(sorted(handlers)), + ) + return ERROR + + action = args[0] + + # Error handling happens here, not in the action-handlers. + try: + handlers[action](options, args[1:]) + except PipError as e: + logger.error(e.args[0]) + return ERROR + + return SUCCESS + + def _build_package_finder( + self, + options: Values, + session: PipSession, + target_python: Optional[TargetPython] = None, + ignore_requires_python: Optional[bool] = None, + ) -> PackageFinder: + """ + Create a package finder appropriate to the index command. + """ + link_collector = LinkCollector.create(session, options=options) + + # Pass allow_yanked=False to ignore yanked versions. + selection_prefs = SelectionPreferences( + allow_yanked=False, + allow_all_prereleases=options.pre, + ignore_requires_python=ignore_requires_python, + ) + + return PackageFinder.create( + link_collector=link_collector, + selection_prefs=selection_prefs, + target_python=target_python, + ) + + def get_available_package_versions(self, options: Values, args: List[Any]) -> None: + if len(args) != 1: + raise CommandError("You need to specify exactly one argument") + + target_python = cmdoptions.make_target_python(options) + query = args[0] + + with self._build_session(options) as session: + finder = self._build_package_finder( + options=options, + session=session, + target_python=target_python, + ignore_requires_python=options.ignore_requires_python, + ) + + versions: Iterable[Version] = ( + candidate.version for candidate in finder.find_all_candidates(query) + ) + + if not options.pre: + # Remove prereleases + versions = ( + version for version in versions if not version.is_prerelease + ) + versions = set(versions) + + if not versions: + raise DistributionNotFound( + f"No matching distribution found for {query}" + ) + + formatted_versions = [str(ver) for ver in sorted(versions, reverse=True)] + latest = formatted_versions[0] + + write_output(f"{query} ({latest})") + write_output("Available versions: {}".format(", ".join(formatted_versions))) + print_dist_installation_info(query, latest) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/install.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/install.py new file mode 100644 index 0000000000000000000000000000000000000000..232a34a6d3e5eac0ee8e80ff1b7477b987b425f8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/install.py @@ -0,0 +1,784 @@ +import errno +import json +import operator +import os +import shutil +import site +from optparse import SUPPRESS_HELP, Values +from typing import List, Optional + +from pip._vendor.packaging.utils import canonicalize_name +from pip._vendor.rich import print_json + +# Eagerly import self_outdated_check to avoid crashes. Otherwise, +# this module would be imported *after* pip was replaced, resulting +# in crashes if the new self_outdated_check module was incompatible +# with the rest of pip that's already imported, or allowing a +# wheel to execute arbitrary code on install by replacing +# self_outdated_check. +import pip._internal.self_outdated_check # noqa: F401 +from pip._internal.cache import WheelCache +from pip._internal.cli import cmdoptions +from pip._internal.cli.cmdoptions import make_target_python +from pip._internal.cli.req_command import ( + RequirementCommand, + with_cleanup, +) +from pip._internal.cli.status_codes import ERROR, SUCCESS +from pip._internal.exceptions import CommandError, InstallationError +from pip._internal.locations import get_scheme +from pip._internal.metadata import get_environment +from pip._internal.models.installation_report import InstallationReport +from pip._internal.operations.build.build_tracker import get_build_tracker +from pip._internal.operations.check import ConflictDetails, check_install_conflicts +from pip._internal.req import install_given_reqs +from pip._internal.req.req_install import ( + InstallRequirement, + check_legacy_setup_py_options, +) +from pip._internal.utils.compat import WINDOWS +from pip._internal.utils.filesystem import test_writable_dir +from pip._internal.utils.logging import getLogger +from pip._internal.utils.misc import ( + check_externally_managed, + ensure_dir, + get_pip_version, + protect_pip_from_modification_on_windows, + warn_if_run_as_root, + write_output, +) +from pip._internal.utils.temp_dir import TempDirectory +from pip._internal.utils.virtualenv import ( + running_under_virtualenv, + virtualenv_no_global, +) +from pip._internal.wheel_builder import build, should_build_for_install_command + +logger = getLogger(__name__) + + +class InstallCommand(RequirementCommand): + """ + Install packages from: + + - PyPI (and other indexes) using requirement specifiers. + - VCS project urls. + - Local project directories. + - Local or remote source archives. + + pip also supports installing from "requirements files", which provide + an easy way to specify a whole environment to be installed. + """ + + usage = """ + %prog [options] [package-index-options] ... + %prog [options] -r [package-index-options] ... + %prog [options] [-e] ... + %prog [options] [-e] ... + %prog [options] ...""" + + def add_options(self) -> None: + self.cmd_opts.add_option(cmdoptions.requirements()) + self.cmd_opts.add_option(cmdoptions.constraints()) + self.cmd_opts.add_option(cmdoptions.no_deps()) + self.cmd_opts.add_option(cmdoptions.pre()) + + self.cmd_opts.add_option(cmdoptions.editable()) + self.cmd_opts.add_option( + "--dry-run", + action="store_true", + dest="dry_run", + default=False, + help=( + "Don't actually install anything, just print what would be. " + "Can be used in combination with --ignore-installed " + "to 'resolve' the requirements." + ), + ) + self.cmd_opts.add_option( + "-t", + "--target", + dest="target_dir", + metavar="dir", + default=None, + help=( + "Install packages into . " + "By default this will not replace existing files/folders in " + ". Use --upgrade to replace existing packages in " + "with new versions." + ), + ) + cmdoptions.add_target_python_options(self.cmd_opts) + + self.cmd_opts.add_option( + "--user", + dest="use_user_site", + action="store_true", + help=( + "Install to the Python user install directory for your " + "platform. Typically ~/.local/, or %APPDATA%\\Python on " + "Windows. (See the Python documentation for site.USER_BASE " + "for full details.)" + ), + ) + self.cmd_opts.add_option( + "--no-user", + dest="use_user_site", + action="store_false", + help=SUPPRESS_HELP, + ) + self.cmd_opts.add_option( + "--root", + dest="root_path", + metavar="dir", + default=None, + help="Install everything relative to this alternate root directory.", + ) + self.cmd_opts.add_option( + "--prefix", + dest="prefix_path", + metavar="dir", + default=None, + help=( + "Installation prefix where lib, bin and other top-level " + "folders are placed. Note that the resulting installation may " + "contain scripts and other resources which reference the " + "Python interpreter of pip, and not that of ``--prefix``. " + "See also the ``--python`` option if the intention is to " + "install packages into another (possibly pip-free) " + "environment." + ), + ) + + self.cmd_opts.add_option(cmdoptions.src()) + + self.cmd_opts.add_option( + "-U", + "--upgrade", + dest="upgrade", + action="store_true", + help=( + "Upgrade all specified packages to the newest available " + "version. The handling of dependencies depends on the " + "upgrade-strategy used." + ), + ) + + self.cmd_opts.add_option( + "--upgrade-strategy", + dest="upgrade_strategy", + default="only-if-needed", + choices=["only-if-needed", "eager"], + help=( + "Determines how dependency upgrading should be handled " + "[default: %default]. " + '"eager" - dependencies are upgraded regardless of ' + "whether the currently installed version satisfies the " + "requirements of the upgraded package(s). " + '"only-if-needed" - are upgraded only when they do not ' + "satisfy the requirements of the upgraded package(s)." + ), + ) + + self.cmd_opts.add_option( + "--force-reinstall", + dest="force_reinstall", + action="store_true", + help="Reinstall all packages even if they are already up-to-date.", + ) + + self.cmd_opts.add_option( + "-I", + "--ignore-installed", + dest="ignore_installed", + action="store_true", + help=( + "Ignore the installed packages, overwriting them. " + "This can break your system if the existing package " + "is of a different version or was installed " + "with a different package manager!" + ), + ) + + self.cmd_opts.add_option(cmdoptions.ignore_requires_python()) + self.cmd_opts.add_option(cmdoptions.no_build_isolation()) + self.cmd_opts.add_option(cmdoptions.use_pep517()) + self.cmd_opts.add_option(cmdoptions.no_use_pep517()) + self.cmd_opts.add_option(cmdoptions.check_build_deps()) + self.cmd_opts.add_option(cmdoptions.override_externally_managed()) + + self.cmd_opts.add_option(cmdoptions.config_settings()) + self.cmd_opts.add_option(cmdoptions.global_options()) + + self.cmd_opts.add_option( + "--compile", + action="store_true", + dest="compile", + default=True, + help="Compile Python source files to bytecode", + ) + + self.cmd_opts.add_option( + "--no-compile", + action="store_false", + dest="compile", + help="Do not compile Python source files to bytecode", + ) + + self.cmd_opts.add_option( + "--no-warn-script-location", + action="store_false", + dest="warn_script_location", + default=True, + help="Do not warn when installing scripts outside PATH", + ) + self.cmd_opts.add_option( + "--no-warn-conflicts", + action="store_false", + dest="warn_about_conflicts", + default=True, + help="Do not warn about broken dependencies", + ) + self.cmd_opts.add_option(cmdoptions.no_binary()) + self.cmd_opts.add_option(cmdoptions.only_binary()) + self.cmd_opts.add_option(cmdoptions.prefer_binary()) + self.cmd_opts.add_option(cmdoptions.require_hashes()) + self.cmd_opts.add_option(cmdoptions.progress_bar()) + self.cmd_opts.add_option(cmdoptions.root_user_action()) + + index_opts = cmdoptions.make_option_group( + cmdoptions.index_group, + self.parser, + ) + + self.parser.insert_option_group(0, index_opts) + self.parser.insert_option_group(0, self.cmd_opts) + + self.cmd_opts.add_option( + "--report", + dest="json_report_file", + metavar="file", + default=None, + help=( + "Generate a JSON file describing what pip did to install " + "the provided requirements. " + "Can be used in combination with --dry-run and --ignore-installed " + "to 'resolve' the requirements. " + "When - is used as file name it writes to stdout. " + "When writing to stdout, please combine with the --quiet option " + "to avoid mixing pip logging output with JSON output." + ), + ) + + @with_cleanup + def run(self, options: Values, args: List[str]) -> int: + if options.use_user_site and options.target_dir is not None: + raise CommandError("Can not combine '--user' and '--target'") + + # Check whether the environment we're installing into is externally + # managed, as specified in PEP 668. Specifying --root, --target, or + # --prefix disables the check, since there's no reliable way to locate + # the EXTERNALLY-MANAGED file for those cases. An exception is also + # made specifically for "--dry-run --report" for convenience. + installing_into_current_environment = ( + not (options.dry_run and options.json_report_file) + and options.root_path is None + and options.target_dir is None + and options.prefix_path is None + ) + if ( + installing_into_current_environment + and not options.override_externally_managed + ): + check_externally_managed() + + upgrade_strategy = "to-satisfy-only" + if options.upgrade: + upgrade_strategy = options.upgrade_strategy + + cmdoptions.check_dist_restriction(options, check_target=True) + + logger.verbose("Using %s", get_pip_version()) + options.use_user_site = decide_user_install( + options.use_user_site, + prefix_path=options.prefix_path, + target_dir=options.target_dir, + root_path=options.root_path, + isolated_mode=options.isolated_mode, + ) + + target_temp_dir: Optional[TempDirectory] = None + target_temp_dir_path: Optional[str] = None + if options.target_dir: + options.ignore_installed = True + options.target_dir = os.path.abspath(options.target_dir) + if ( + # fmt: off + os.path.exists(options.target_dir) and + not os.path.isdir(options.target_dir) + # fmt: on + ): + raise CommandError( + "Target path exists but is not a directory, will not continue." + ) + + # Create a target directory for using with the target option + target_temp_dir = TempDirectory(kind="target") + target_temp_dir_path = target_temp_dir.path + self.enter_context(target_temp_dir) + + global_options = options.global_options or [] + + session = self.get_default_session(options) + + target_python = make_target_python(options) + finder = self._build_package_finder( + options=options, + session=session, + target_python=target_python, + ignore_requires_python=options.ignore_requires_python, + ) + build_tracker = self.enter_context(get_build_tracker()) + + directory = TempDirectory( + delete=not options.no_clean, + kind="install", + globally_managed=True, + ) + + try: + reqs = self.get_requirements(args, options, finder, session) + check_legacy_setup_py_options(options, reqs) + + wheel_cache = WheelCache(options.cache_dir) + + # Only when installing is it permitted to use PEP 660. + # In other circumstances (pip wheel, pip download) we generate + # regular (i.e. non editable) metadata and wheels. + for req in reqs: + req.permit_editable_wheels = True + + preparer = self.make_requirement_preparer( + temp_build_dir=directory, + options=options, + build_tracker=build_tracker, + session=session, + finder=finder, + use_user_site=options.use_user_site, + verbosity=self.verbosity, + ) + resolver = self.make_resolver( + preparer=preparer, + finder=finder, + options=options, + wheel_cache=wheel_cache, + use_user_site=options.use_user_site, + ignore_installed=options.ignore_installed, + ignore_requires_python=options.ignore_requires_python, + force_reinstall=options.force_reinstall, + upgrade_strategy=upgrade_strategy, + use_pep517=options.use_pep517, + py_version_info=options.python_version, + ) + + self.trace_basic_info(finder) + + requirement_set = resolver.resolve( + reqs, check_supported_wheels=not options.target_dir + ) + + if options.json_report_file: + report = InstallationReport(requirement_set.requirements_to_install) + if options.json_report_file == "-": + print_json(data=report.to_dict()) + else: + with open(options.json_report_file, "w", encoding="utf-8") as f: + json.dump(report.to_dict(), f, indent=2, ensure_ascii=False) + + if options.dry_run: + would_install_items = sorted( + (r.metadata["name"], r.metadata["version"]) + for r in requirement_set.requirements_to_install + ) + if would_install_items: + write_output( + "Would install %s", + " ".join("-".join(item) for item in would_install_items), + ) + return SUCCESS + + try: + pip_req = requirement_set.get_requirement("pip") + except KeyError: + modifying_pip = False + else: + # If we're not replacing an already installed pip, + # we're not modifying it. + modifying_pip = pip_req.satisfied_by is None + protect_pip_from_modification_on_windows(modifying_pip=modifying_pip) + + reqs_to_build = [ + r + for r in requirement_set.requirements.values() + if should_build_for_install_command(r) + ] + + _, build_failures = build( + reqs_to_build, + wheel_cache=wheel_cache, + verify=True, + build_options=[], + global_options=global_options, + ) + + if build_failures: + raise InstallationError( + "Failed to build installable wheels for some " + "pyproject.toml based projects ({})".format( + ", ".join(r.name for r in build_failures) # type: ignore + ) + ) + + to_install = resolver.get_installation_order(requirement_set) + + # Check for conflicts in the package set we're installing. + conflicts: Optional[ConflictDetails] = None + should_warn_about_conflicts = ( + not options.ignore_dependencies and options.warn_about_conflicts + ) + if should_warn_about_conflicts: + conflicts = self._determine_conflicts(to_install) + + # Don't warn about script install locations if + # --target or --prefix has been specified + warn_script_location = options.warn_script_location + if options.target_dir or options.prefix_path: + warn_script_location = False + + installed = install_given_reqs( + to_install, + global_options, + root=options.root_path, + home=target_temp_dir_path, + prefix=options.prefix_path, + warn_script_location=warn_script_location, + use_user_site=options.use_user_site, + pycompile=options.compile, + ) + + lib_locations = get_lib_location_guesses( + user=options.use_user_site, + home=target_temp_dir_path, + root=options.root_path, + prefix=options.prefix_path, + isolated=options.isolated_mode, + ) + env = get_environment(lib_locations) + + # Display a summary of installed packages, with extra care to + # display a package name as it was requested by the user. + installed.sort(key=operator.attrgetter("name")) + summary = [] + installed_versions = {} + for distribution in env.iter_all_distributions(): + installed_versions[distribution.canonical_name] = distribution.version + for package in installed: + display_name = package.name + version = installed_versions.get(canonicalize_name(display_name), None) + if version: + text = f"{display_name}-{version}" + else: + text = display_name + summary.append(text) + + if conflicts is not None: + self._warn_about_conflicts( + conflicts, + resolver_variant=self.determine_resolver_variant(options), + ) + + installed_desc = " ".join(summary) + if installed_desc: + write_output( + "Successfully installed %s", + installed_desc, + ) + except OSError as error: + show_traceback = self.verbosity >= 1 + + message = create_os_error_message( + error, + show_traceback, + options.use_user_site, + ) + logger.error(message, exc_info=show_traceback) + + return ERROR + + if options.target_dir: + assert target_temp_dir + self._handle_target_dir( + options.target_dir, target_temp_dir, options.upgrade + ) + if options.root_user_action == "warn": + warn_if_run_as_root() + return SUCCESS + + def _handle_target_dir( + self, target_dir: str, target_temp_dir: TempDirectory, upgrade: bool + ) -> None: + ensure_dir(target_dir) + + # Checking both purelib and platlib directories for installed + # packages to be moved to target directory + lib_dir_list = [] + + # Checking both purelib and platlib directories for installed + # packages to be moved to target directory + scheme = get_scheme("", home=target_temp_dir.path) + purelib_dir = scheme.purelib + platlib_dir = scheme.platlib + data_dir = scheme.data + + if os.path.exists(purelib_dir): + lib_dir_list.append(purelib_dir) + if os.path.exists(platlib_dir) and platlib_dir != purelib_dir: + lib_dir_list.append(platlib_dir) + if os.path.exists(data_dir): + lib_dir_list.append(data_dir) + + for lib_dir in lib_dir_list: + for item in os.listdir(lib_dir): + if lib_dir == data_dir: + ddir = os.path.join(data_dir, item) + if any(s.startswith(ddir) for s in lib_dir_list[:-1]): + continue + target_item_dir = os.path.join(target_dir, item) + if os.path.exists(target_item_dir): + if not upgrade: + logger.warning( + "Target directory %s already exists. Specify " + "--upgrade to force replacement.", + target_item_dir, + ) + continue + if os.path.islink(target_item_dir): + logger.warning( + "Target directory %s already exists and is " + "a link. pip will not automatically replace " + "links, please remove if replacement is " + "desired.", + target_item_dir, + ) + continue + if os.path.isdir(target_item_dir): + shutil.rmtree(target_item_dir) + else: + os.remove(target_item_dir) + + shutil.move(os.path.join(lib_dir, item), target_item_dir) + + def _determine_conflicts( + self, to_install: List[InstallRequirement] + ) -> Optional[ConflictDetails]: + try: + return check_install_conflicts(to_install) + except Exception: + logger.exception( + "Error while checking for conflicts. Please file an issue on " + "pip's issue tracker: https://github.com/pypa/pip/issues/new" + ) + return None + + def _warn_about_conflicts( + self, conflict_details: ConflictDetails, resolver_variant: str + ) -> None: + package_set, (missing, conflicting) = conflict_details + if not missing and not conflicting: + return + + parts: List[str] = [] + if resolver_variant == "legacy": + parts.append( + "pip's legacy dependency resolver does not consider dependency " + "conflicts when selecting packages. This behaviour is the " + "source of the following dependency conflicts." + ) + else: + assert resolver_variant == "resolvelib" + parts.append( + "pip's dependency resolver does not currently take into account " + "all the packages that are installed. This behaviour is the " + "source of the following dependency conflicts." + ) + + # NOTE: There is some duplication here, with commands/check.py + for project_name in missing: + version = package_set[project_name][0] + for dependency in missing[project_name]: + message = ( + f"{project_name} {version} requires {dependency[1]}, " + "which is not installed." + ) + parts.append(message) + + for project_name in conflicting: + version = package_set[project_name][0] + for dep_name, dep_version, req in conflicting[project_name]: + message = ( + "{name} {version} requires {requirement}, but {you} have " + "{dep_name} {dep_version} which is incompatible." + ).format( + name=project_name, + version=version, + requirement=req, + dep_name=dep_name, + dep_version=dep_version, + you=("you" if resolver_variant == "resolvelib" else "you'll"), + ) + parts.append(message) + + logger.critical("\n".join(parts)) + + +def get_lib_location_guesses( + user: bool = False, + home: Optional[str] = None, + root: Optional[str] = None, + isolated: bool = False, + prefix: Optional[str] = None, +) -> List[str]: + scheme = get_scheme( + "", + user=user, + home=home, + root=root, + isolated=isolated, + prefix=prefix, + ) + return [scheme.purelib, scheme.platlib] + + +def site_packages_writable(root: Optional[str], isolated: bool) -> bool: + return all( + test_writable_dir(d) + for d in set(get_lib_location_guesses(root=root, isolated=isolated)) + ) + + +def decide_user_install( + use_user_site: Optional[bool], + prefix_path: Optional[str] = None, + target_dir: Optional[str] = None, + root_path: Optional[str] = None, + isolated_mode: bool = False, +) -> bool: + """Determine whether to do a user install based on the input options. + + If use_user_site is False, no additional checks are done. + If use_user_site is True, it is checked for compatibility with other + options. + If use_user_site is None, the default behaviour depends on the environment, + which is provided by the other arguments. + """ + # In some cases (config from tox), use_user_site can be set to an integer + # rather than a bool, which 'use_user_site is False' wouldn't catch. + if (use_user_site is not None) and (not use_user_site): + logger.debug("Non-user install by explicit request") + return False + + if use_user_site: + if prefix_path: + raise CommandError( + "Can not combine '--user' and '--prefix' as they imply " + "different installation locations" + ) + if virtualenv_no_global(): + raise InstallationError( + "Can not perform a '--user' install. User site-packages " + "are not visible in this virtualenv." + ) + logger.debug("User install by explicit request") + return True + + # If we are here, user installs have not been explicitly requested/avoided + assert use_user_site is None + + # user install incompatible with --prefix/--target + if prefix_path or target_dir: + logger.debug("Non-user install due to --prefix or --target option") + return False + + # If user installs are not enabled, choose a non-user install + if not site.ENABLE_USER_SITE: + logger.debug("Non-user install because user site-packages disabled") + return False + + # If we have permission for a non-user install, do that, + # otherwise do a user install. + if site_packages_writable(root=root_path, isolated=isolated_mode): + logger.debug("Non-user install because site-packages writeable") + return False + + logger.info( + "Defaulting to user installation because normal site-packages " + "is not writeable" + ) + return True + + +def create_os_error_message( + error: OSError, show_traceback: bool, using_user_site: bool +) -> str: + """Format an error message for an OSError + + It may occur anytime during the execution of the install command. + """ + parts = [] + + # Mention the error if we are not going to show a traceback + parts.append("Could not install packages due to an OSError") + if not show_traceback: + parts.append(": ") + parts.append(str(error)) + else: + parts.append(".") + + # Spilt the error indication from a helper message (if any) + parts[-1] += "\n" + + # Suggest useful actions to the user: + # (1) using user site-packages or (2) verifying the permissions + if error.errno == errno.EACCES: + user_option_part = "Consider using the `--user` option" + permissions_part = "Check the permissions" + + if not running_under_virtualenv() and not using_user_site: + parts.extend( + [ + user_option_part, + " or ", + permissions_part.lower(), + ] + ) + else: + parts.append(permissions_part) + parts.append(".\n") + + # Suggest the user to enable Long Paths if path length is + # more than 260 + if ( + WINDOWS + and error.errno == errno.ENOENT + and error.filename + and len(error.filename) > 260 + ): + parts.append( + "HINT: This error might have occurred since " + "this system does not have Windows Long Path " + "support enabled. You can find information on " + "how to enable this at " + "https://pip.pypa.io/warnings/enable-long-paths\n" + ) + + return "".join(parts).strip() + "\n" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/list.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/list.py new file mode 100644 index 0000000000000000000000000000000000000000..8494370241064698d5aa891a10fce6bf3b8c1722 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/list.py @@ -0,0 +1,375 @@ +import json +import logging +from optparse import Values +from typing import TYPE_CHECKING, Generator, List, Optional, Sequence, Tuple, cast + +from pip._vendor.packaging.utils import canonicalize_name +from pip._vendor.packaging.version import Version + +from pip._internal.cli import cmdoptions +from pip._internal.cli.index_command import IndexGroupCommand +from pip._internal.cli.status_codes import SUCCESS +from pip._internal.exceptions import CommandError +from pip._internal.metadata import BaseDistribution, get_environment +from pip._internal.models.selection_prefs import SelectionPreferences +from pip._internal.utils.compat import stdlib_pkgs +from pip._internal.utils.misc import tabulate, write_output + +if TYPE_CHECKING: + from pip._internal.index.package_finder import PackageFinder + from pip._internal.network.session import PipSession + + class _DistWithLatestInfo(BaseDistribution): + """Give the distribution object a couple of extra fields. + + These will be populated during ``get_outdated()``. This is dirty but + makes the rest of the code much cleaner. + """ + + latest_version: Version + latest_filetype: str + + _ProcessedDists = Sequence[_DistWithLatestInfo] + + +logger = logging.getLogger(__name__) + + +class ListCommand(IndexGroupCommand): + """ + List installed packages, including editables. + + Packages are listed in a case-insensitive sorted order. + """ + + ignore_require_venv = True + usage = """ + %prog [options]""" + + def add_options(self) -> None: + self.cmd_opts.add_option( + "-o", + "--outdated", + action="store_true", + default=False, + help="List outdated packages", + ) + self.cmd_opts.add_option( + "-u", + "--uptodate", + action="store_true", + default=False, + help="List uptodate packages", + ) + self.cmd_opts.add_option( + "-e", + "--editable", + action="store_true", + default=False, + help="List editable projects.", + ) + self.cmd_opts.add_option( + "-l", + "--local", + action="store_true", + default=False, + help=( + "If in a virtualenv that has global access, do not list " + "globally-installed packages." + ), + ) + self.cmd_opts.add_option( + "--user", + dest="user", + action="store_true", + default=False, + help="Only output packages installed in user-site.", + ) + self.cmd_opts.add_option(cmdoptions.list_path()) + self.cmd_opts.add_option( + "--pre", + action="store_true", + default=False, + help=( + "Include pre-release and development versions. By default, " + "pip only finds stable versions." + ), + ) + + self.cmd_opts.add_option( + "--format", + action="store", + dest="list_format", + default="columns", + choices=("columns", "freeze", "json"), + help=( + "Select the output format among: columns (default), freeze, or json. " + "The 'freeze' format cannot be used with the --outdated option." + ), + ) + + self.cmd_opts.add_option( + "--not-required", + action="store_true", + dest="not_required", + help="List packages that are not dependencies of installed packages.", + ) + + self.cmd_opts.add_option( + "--exclude-editable", + action="store_false", + dest="include_editable", + help="Exclude editable package from output.", + ) + self.cmd_opts.add_option( + "--include-editable", + action="store_true", + dest="include_editable", + help="Include editable package from output.", + default=True, + ) + self.cmd_opts.add_option(cmdoptions.list_exclude()) + index_opts = cmdoptions.make_option_group(cmdoptions.index_group, self.parser) + + self.parser.insert_option_group(0, index_opts) + self.parser.insert_option_group(0, self.cmd_opts) + + def handle_pip_version_check(self, options: Values) -> None: + if options.outdated or options.uptodate: + super().handle_pip_version_check(options) + + def _build_package_finder( + self, options: Values, session: "PipSession" + ) -> "PackageFinder": + """ + Create a package finder appropriate to this list command. + """ + # Lazy import the heavy index modules as most list invocations won't need 'em. + from pip._internal.index.collector import LinkCollector + from pip._internal.index.package_finder import PackageFinder + + link_collector = LinkCollector.create(session, options=options) + + # Pass allow_yanked=False to ignore yanked versions. + selection_prefs = SelectionPreferences( + allow_yanked=False, + allow_all_prereleases=options.pre, + ) + + return PackageFinder.create( + link_collector=link_collector, + selection_prefs=selection_prefs, + ) + + def run(self, options: Values, args: List[str]) -> int: + if options.outdated and options.uptodate: + raise CommandError("Options --outdated and --uptodate cannot be combined.") + + if options.outdated and options.list_format == "freeze": + raise CommandError( + "List format 'freeze' cannot be used with the --outdated option." + ) + + cmdoptions.check_list_path_option(options) + + skip = set(stdlib_pkgs) + if options.excludes: + skip.update(canonicalize_name(n) for n in options.excludes) + + packages: _ProcessedDists = [ + cast("_DistWithLatestInfo", d) + for d in get_environment(options.path).iter_installed_distributions( + local_only=options.local, + user_only=options.user, + editables_only=options.editable, + include_editables=options.include_editable, + skip=skip, + ) + ] + + # get_not_required must be called firstly in order to find and + # filter out all dependencies correctly. Otherwise a package + # can't be identified as requirement because some parent packages + # could be filtered out before. + if options.not_required: + packages = self.get_not_required(packages, options) + + if options.outdated: + packages = self.get_outdated(packages, options) + elif options.uptodate: + packages = self.get_uptodate(packages, options) + + self.output_package_listing(packages, options) + return SUCCESS + + def get_outdated( + self, packages: "_ProcessedDists", options: Values + ) -> "_ProcessedDists": + return [ + dist + for dist in self.iter_packages_latest_infos(packages, options) + if dist.latest_version > dist.version + ] + + def get_uptodate( + self, packages: "_ProcessedDists", options: Values + ) -> "_ProcessedDists": + return [ + dist + for dist in self.iter_packages_latest_infos(packages, options) + if dist.latest_version == dist.version + ] + + def get_not_required( + self, packages: "_ProcessedDists", options: Values + ) -> "_ProcessedDists": + dep_keys = { + canonicalize_name(dep.name) + for dist in packages + for dep in (dist.iter_dependencies() or ()) + } + + # Create a set to remove duplicate packages, and cast it to a list + # to keep the return type consistent with get_outdated and + # get_uptodate + return list({pkg for pkg in packages if pkg.canonical_name not in dep_keys}) + + def iter_packages_latest_infos( + self, packages: "_ProcessedDists", options: Values + ) -> Generator["_DistWithLatestInfo", None, None]: + with self._build_session(options) as session: + finder = self._build_package_finder(options, session) + + def latest_info( + dist: "_DistWithLatestInfo", + ) -> Optional["_DistWithLatestInfo"]: + all_candidates = finder.find_all_candidates(dist.canonical_name) + if not options.pre: + # Remove prereleases + all_candidates = [ + candidate + for candidate in all_candidates + if not candidate.version.is_prerelease + ] + + evaluator = finder.make_candidate_evaluator( + project_name=dist.canonical_name, + ) + best_candidate = evaluator.sort_best_candidate(all_candidates) + if best_candidate is None: + return None + + remote_version = best_candidate.version + if best_candidate.link.is_wheel: + typ = "wheel" + else: + typ = "sdist" + dist.latest_version = remote_version + dist.latest_filetype = typ + return dist + + for dist in map(latest_info, packages): + if dist is not None: + yield dist + + def output_package_listing( + self, packages: "_ProcessedDists", options: Values + ) -> None: + packages = sorted( + packages, + key=lambda dist: dist.canonical_name, + ) + if options.list_format == "columns" and packages: + data, header = format_for_columns(packages, options) + self.output_package_listing_columns(data, header) + elif options.list_format == "freeze": + for dist in packages: + if options.verbose >= 1: + write_output( + "%s==%s (%s)", dist.raw_name, dist.version, dist.location + ) + else: + write_output("%s==%s", dist.raw_name, dist.version) + elif options.list_format == "json": + write_output(format_for_json(packages, options)) + + def output_package_listing_columns( + self, data: List[List[str]], header: List[str] + ) -> None: + # insert the header first: we need to know the size of column names + if len(data) > 0: + data.insert(0, header) + + pkg_strings, sizes = tabulate(data) + + # Create and add a separator. + if len(data) > 0: + pkg_strings.insert(1, " ".join("-" * x for x in sizes)) + + for val in pkg_strings: + write_output(val) + + +def format_for_columns( + pkgs: "_ProcessedDists", options: Values +) -> Tuple[List[List[str]], List[str]]: + """ + Convert the package data into something usable + by output_package_listing_columns. + """ + header = ["Package", "Version"] + + running_outdated = options.outdated + if running_outdated: + header.extend(["Latest", "Type"]) + + has_editables = any(x.editable for x in pkgs) + if has_editables: + header.append("Editable project location") + + if options.verbose >= 1: + header.append("Location") + if options.verbose >= 1: + header.append("Installer") + + data = [] + for proj in pkgs: + # if we're working on the 'outdated' list, separate out the + # latest_version and type + row = [proj.raw_name, proj.raw_version] + + if running_outdated: + row.append(str(proj.latest_version)) + row.append(proj.latest_filetype) + + if has_editables: + row.append(proj.editable_project_location or "") + + if options.verbose >= 1: + row.append(proj.location or "") + if options.verbose >= 1: + row.append(proj.installer) + + data.append(row) + + return data, header + + +def format_for_json(packages: "_ProcessedDists", options: Values) -> str: + data = [] + for dist in packages: + info = { + "name": dist.raw_name, + "version": str(dist.version), + } + if options.verbose >= 1: + info["location"] = dist.location or "" + info["installer"] = dist.installer + if options.outdated: + info["latest_version"] = str(dist.latest_version) + info["latest_filetype"] = dist.latest_filetype + editable_project_location = dist.editable_project_location + if editable_project_location: + info["editable_project_location"] = editable_project_location + data.append(info) + return json.dumps(data) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/search.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/search.py new file mode 100644 index 0000000000000000000000000000000000000000..74b8d656b4749d15cc47aebccd9ca22f74357f25 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/search.py @@ -0,0 +1,172 @@ +import logging +import shutil +import sys +import textwrap +import xmlrpc.client +from collections import OrderedDict +from optparse import Values +from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict + +from pip._vendor.packaging.version import parse as parse_version + +from pip._internal.cli.base_command import Command +from pip._internal.cli.req_command import SessionCommandMixin +from pip._internal.cli.status_codes import NO_MATCHES_FOUND, SUCCESS +from pip._internal.exceptions import CommandError +from pip._internal.metadata import get_default_environment +from pip._internal.models.index import PyPI +from pip._internal.network.xmlrpc import PipXmlrpcTransport +from pip._internal.utils.logging import indent_log +from pip._internal.utils.misc import write_output + +if TYPE_CHECKING: + + class TransformedHit(TypedDict): + name: str + summary: str + versions: List[str] + + +logger = logging.getLogger(__name__) + + +class SearchCommand(Command, SessionCommandMixin): + """Search for PyPI packages whose name or summary contains .""" + + usage = """ + %prog [options] """ + ignore_require_venv = True + + def add_options(self) -> None: + self.cmd_opts.add_option( + "-i", + "--index", + dest="index", + metavar="URL", + default=PyPI.pypi_url, + help="Base URL of Python Package Index (default %default)", + ) + + self.parser.insert_option_group(0, self.cmd_opts) + + def run(self, options: Values, args: List[str]) -> int: + if not args: + raise CommandError("Missing required argument (search query).") + query = args + pypi_hits = self.search(query, options) + hits = transform_hits(pypi_hits) + + terminal_width = None + if sys.stdout.isatty(): + terminal_width = shutil.get_terminal_size()[0] + + print_results(hits, terminal_width=terminal_width) + if pypi_hits: + return SUCCESS + return NO_MATCHES_FOUND + + def search(self, query: List[str], options: Values) -> List[Dict[str, str]]: + index_url = options.index + + session = self.get_default_session(options) + + transport = PipXmlrpcTransport(index_url, session) + pypi = xmlrpc.client.ServerProxy(index_url, transport) + try: + hits = pypi.search({"name": query, "summary": query}, "or") + except xmlrpc.client.Fault as fault: + message = ( + f"XMLRPC request failed [code: {fault.faultCode}]\n{fault.faultString}" + ) + raise CommandError(message) + assert isinstance(hits, list) + return hits + + +def transform_hits(hits: List[Dict[str, str]]) -> List["TransformedHit"]: + """ + The list from pypi is really a list of versions. We want a list of + packages with the list of versions stored inline. This converts the + list from pypi into one we can use. + """ + packages: Dict[str, TransformedHit] = OrderedDict() + for hit in hits: + name = hit["name"] + summary = hit["summary"] + version = hit["version"] + + if name not in packages.keys(): + packages[name] = { + "name": name, + "summary": summary, + "versions": [version], + } + else: + packages[name]["versions"].append(version) + + # if this is the highest version, replace summary and score + if version == highest_version(packages[name]["versions"]): + packages[name]["summary"] = summary + + return list(packages.values()) + + +def print_dist_installation_info(name: str, latest: str) -> None: + env = get_default_environment() + dist = env.get_distribution(name) + if dist is not None: + with indent_log(): + if dist.version == latest: + write_output("INSTALLED: %s (latest)", dist.version) + else: + write_output("INSTALLED: %s", dist.version) + if parse_version(latest).pre: + write_output( + "LATEST: %s (pre-release; install" + " with `pip install --pre`)", + latest, + ) + else: + write_output("LATEST: %s", latest) + + +def print_results( + hits: List["TransformedHit"], + name_column_width: Optional[int] = None, + terminal_width: Optional[int] = None, +) -> None: + if not hits: + return + if name_column_width is None: + name_column_width = ( + max( + [ + len(hit["name"]) + len(highest_version(hit.get("versions", ["-"]))) + for hit in hits + ] + ) + + 4 + ) + + for hit in hits: + name = hit["name"] + summary = hit["summary"] or "" + latest = highest_version(hit.get("versions", ["-"])) + if terminal_width is not None: + target_width = terminal_width - name_column_width - 5 + if target_width > 10: + # wrap and indent summary to fit terminal + summary_lines = textwrap.wrap(summary, target_width) + summary = ("\n" + " " * (name_column_width + 3)).join(summary_lines) + + name_latest = f"{name} ({latest})" + line = f"{name_latest:{name_column_width}} - {summary}" + try: + write_output(line) + print_dist_installation_info(name, latest) + except UnicodeEncodeError: + pass + + +def highest_version(versions: List[str]) -> str: + return max(versions, key=parse_version) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/show.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/show.py new file mode 100644 index 0000000000000000000000000000000000000000..b47500cf8b47598a001fd9a20eb2e5e1f3788cda --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/show.py @@ -0,0 +1,224 @@ +import logging +from optparse import Values +from typing import Generator, Iterable, Iterator, List, NamedTuple, Optional + +from pip._vendor.packaging.requirements import InvalidRequirement +from pip._vendor.packaging.utils import canonicalize_name + +from pip._internal.cli.base_command import Command +from pip._internal.cli.status_codes import ERROR, SUCCESS +from pip._internal.metadata import BaseDistribution, get_default_environment +from pip._internal.utils.misc import write_output + +logger = logging.getLogger(__name__) + + +class ShowCommand(Command): + """ + Show information about one or more installed packages. + + The output is in RFC-compliant mail header format. + """ + + usage = """ + %prog [options] ...""" + ignore_require_venv = True + + def add_options(self) -> None: + self.cmd_opts.add_option( + "-f", + "--files", + dest="files", + action="store_true", + default=False, + help="Show the full list of installed files for each package.", + ) + + self.parser.insert_option_group(0, self.cmd_opts) + + def run(self, options: Values, args: List[str]) -> int: + if not args: + logger.warning("ERROR: Please provide a package name or names.") + return ERROR + query = args + + results = search_packages_info(query) + if not print_results( + results, list_files=options.files, verbose=options.verbose + ): + return ERROR + return SUCCESS + + +class _PackageInfo(NamedTuple): + name: str + version: str + location: str + editable_project_location: Optional[str] + requires: List[str] + required_by: List[str] + installer: str + metadata_version: str + classifiers: List[str] + summary: str + homepage: str + project_urls: List[str] + author: str + author_email: str + license: str + license_expression: str + entry_points: List[str] + files: Optional[List[str]] + + +def search_packages_info(query: List[str]) -> Generator[_PackageInfo, None, None]: + """ + Gather details from installed distributions. Print distribution name, + version, location, and installed files. Installed files requires a + pip generated 'installed-files.txt' in the distributions '.egg-info' + directory. + """ + env = get_default_environment() + + installed = {dist.canonical_name: dist for dist in env.iter_all_distributions()} + query_names = [canonicalize_name(name) for name in query] + missing = sorted( + [name for name, pkg in zip(query, query_names) if pkg not in installed] + ) + if missing: + logger.warning("Package(s) not found: %s", ", ".join(missing)) + + def _get_requiring_packages(current_dist: BaseDistribution) -> Iterator[str]: + return ( + dist.metadata["Name"] or "UNKNOWN" + for dist in installed.values() + if current_dist.canonical_name + in {canonicalize_name(d.name) for d in dist.iter_dependencies()} + ) + + for query_name in query_names: + try: + dist = installed[query_name] + except KeyError: + continue + + try: + requires = sorted( + # Avoid duplicates in requirements (e.g. due to environment markers). + {req.name for req in dist.iter_dependencies()}, + key=str.lower, + ) + except InvalidRequirement: + requires = sorted(dist.iter_raw_dependencies(), key=str.lower) + + try: + required_by = sorted(_get_requiring_packages(dist), key=str.lower) + except InvalidRequirement: + required_by = ["#N/A"] + + try: + entry_points_text = dist.read_text("entry_points.txt") + entry_points = entry_points_text.splitlines(keepends=False) + except FileNotFoundError: + entry_points = [] + + files_iter = dist.iter_declared_entries() + if files_iter is None: + files: Optional[List[str]] = None + else: + files = sorted(files_iter) + + metadata = dist.metadata + + project_urls = metadata.get_all("Project-URL", []) + homepage = metadata.get("Home-page", "") + if not homepage: + # It's common that there is a "homepage" Project-URL, but Home-page + # remains unset (especially as PEP 621 doesn't surface the field). + # + # This logic was taken from PyPI's codebase. + for url in project_urls: + url_label, url = url.split(",", maxsplit=1) + normalized_label = ( + url_label.casefold().replace("-", "").replace("_", "").strip() + ) + if normalized_label == "homepage": + homepage = url.strip() + break + + yield _PackageInfo( + name=dist.raw_name, + version=dist.raw_version, + location=dist.location or "", + editable_project_location=dist.editable_project_location, + requires=requires, + required_by=required_by, + installer=dist.installer, + metadata_version=dist.metadata_version or "", + classifiers=metadata.get_all("Classifier", []), + summary=metadata.get("Summary", ""), + homepage=homepage, + project_urls=project_urls, + author=metadata.get("Author", ""), + author_email=metadata.get("Author-email", ""), + license=metadata.get("License", ""), + license_expression=metadata.get("License-Expression", ""), + entry_points=entry_points, + files=files, + ) + + +def print_results( + distributions: Iterable[_PackageInfo], + list_files: bool, + verbose: bool, +) -> bool: + """ + Print the information from installed distributions found. + """ + results_printed = False + for i, dist in enumerate(distributions): + results_printed = True + if i > 0: + write_output("---") + + metadata_version_tuple = tuple(map(int, dist.metadata_version.split("."))) + + write_output("Name: %s", dist.name) + write_output("Version: %s", dist.version) + write_output("Summary: %s", dist.summary) + write_output("Home-page: %s", dist.homepage) + write_output("Author: %s", dist.author) + write_output("Author-email: %s", dist.author_email) + if metadata_version_tuple >= (2, 4) and dist.license_expression: + write_output("License-Expression: %s", dist.license_expression) + else: + write_output("License: %s", dist.license) + write_output("Location: %s", dist.location) + if dist.editable_project_location is not None: + write_output( + "Editable project location: %s", dist.editable_project_location + ) + write_output("Requires: %s", ", ".join(dist.requires)) + write_output("Required-by: %s", ", ".join(dist.required_by)) + + if verbose: + write_output("Metadata-Version: %s", dist.metadata_version) + write_output("Installer: %s", dist.installer) + write_output("Classifiers:") + for classifier in dist.classifiers: + write_output(" %s", classifier) + write_output("Entry-points:") + for entry in dist.entry_points: + write_output(" %s", entry.strip()) + write_output("Project-URLs:") + for project_url in dist.project_urls: + write_output(" %s", project_url) + if list_files: + write_output("Files:") + if dist.files is None: + write_output("Cannot locate RECORD or installed-files.txt") + else: + for line in dist.files: + write_output(" %s", line.strip()) + return results_printed diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/wheel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/wheel.py new file mode 100644 index 0000000000000000000000000000000000000000..278719f4e0c6643e5f01bccdd440b1504bf5f7c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/commands/wheel.py @@ -0,0 +1,182 @@ +import logging +import os +import shutil +from optparse import Values +from typing import List + +from pip._internal.cache import WheelCache +from pip._internal.cli import cmdoptions +from pip._internal.cli.req_command import RequirementCommand, with_cleanup +from pip._internal.cli.status_codes import SUCCESS +from pip._internal.exceptions import CommandError +from pip._internal.operations.build.build_tracker import get_build_tracker +from pip._internal.req.req_install import ( + InstallRequirement, + check_legacy_setup_py_options, +) +from pip._internal.utils.misc import ensure_dir, normalize_path +from pip._internal.utils.temp_dir import TempDirectory +from pip._internal.wheel_builder import build, should_build_for_wheel_command + +logger = logging.getLogger(__name__) + + +class WheelCommand(RequirementCommand): + """ + Build Wheel archives for your requirements and dependencies. + + Wheel is a built-package format, and offers the advantage of not + recompiling your software during every install. For more details, see the + wheel docs: https://wheel.readthedocs.io/en/latest/ + + 'pip wheel' uses the build system interface as described here: + https://pip.pypa.io/en/stable/reference/build-system/ + + """ + + usage = """ + %prog [options] ... + %prog [options] -r ... + %prog [options] [-e] ... + %prog [options] [-e] ... + %prog [options] ...""" + + def add_options(self) -> None: + self.cmd_opts.add_option( + "-w", + "--wheel-dir", + dest="wheel_dir", + metavar="dir", + default=os.curdir, + help=( + "Build wheels into , where the default is the " + "current working directory." + ), + ) + self.cmd_opts.add_option(cmdoptions.no_binary()) + self.cmd_opts.add_option(cmdoptions.only_binary()) + self.cmd_opts.add_option(cmdoptions.prefer_binary()) + self.cmd_opts.add_option(cmdoptions.no_build_isolation()) + self.cmd_opts.add_option(cmdoptions.use_pep517()) + self.cmd_opts.add_option(cmdoptions.no_use_pep517()) + self.cmd_opts.add_option(cmdoptions.check_build_deps()) + self.cmd_opts.add_option(cmdoptions.constraints()) + self.cmd_opts.add_option(cmdoptions.editable()) + self.cmd_opts.add_option(cmdoptions.requirements()) + self.cmd_opts.add_option(cmdoptions.src()) + self.cmd_opts.add_option(cmdoptions.ignore_requires_python()) + self.cmd_opts.add_option(cmdoptions.no_deps()) + self.cmd_opts.add_option(cmdoptions.progress_bar()) + + self.cmd_opts.add_option( + "--no-verify", + dest="no_verify", + action="store_true", + default=False, + help="Don't verify if built wheel is valid.", + ) + + self.cmd_opts.add_option(cmdoptions.config_settings()) + self.cmd_opts.add_option(cmdoptions.build_options()) + self.cmd_opts.add_option(cmdoptions.global_options()) + + self.cmd_opts.add_option( + "--pre", + action="store_true", + default=False, + help=( + "Include pre-release and development versions. By default, " + "pip only finds stable versions." + ), + ) + + self.cmd_opts.add_option(cmdoptions.require_hashes()) + + index_opts = cmdoptions.make_option_group( + cmdoptions.index_group, + self.parser, + ) + + self.parser.insert_option_group(0, index_opts) + self.parser.insert_option_group(0, self.cmd_opts) + + @with_cleanup + def run(self, options: Values, args: List[str]) -> int: + session = self.get_default_session(options) + + finder = self._build_package_finder(options, session) + + options.wheel_dir = normalize_path(options.wheel_dir) + ensure_dir(options.wheel_dir) + + build_tracker = self.enter_context(get_build_tracker()) + + directory = TempDirectory( + delete=not options.no_clean, + kind="wheel", + globally_managed=True, + ) + + reqs = self.get_requirements(args, options, finder, session) + check_legacy_setup_py_options(options, reqs) + + wheel_cache = WheelCache(options.cache_dir) + + preparer = self.make_requirement_preparer( + temp_build_dir=directory, + options=options, + build_tracker=build_tracker, + session=session, + finder=finder, + download_dir=options.wheel_dir, + use_user_site=False, + verbosity=self.verbosity, + ) + + resolver = self.make_resolver( + preparer=preparer, + finder=finder, + options=options, + wheel_cache=wheel_cache, + ignore_requires_python=options.ignore_requires_python, + use_pep517=options.use_pep517, + ) + + self.trace_basic_info(finder) + + requirement_set = resolver.resolve(reqs, check_supported_wheels=True) + + reqs_to_build: List[InstallRequirement] = [] + for req in requirement_set.requirements.values(): + if req.is_wheel: + preparer.save_linked_requirement(req) + elif should_build_for_wheel_command(req): + reqs_to_build.append(req) + + preparer.prepare_linked_requirements_more(requirement_set.requirements.values()) + + # build wheels + build_successes, build_failures = build( + reqs_to_build, + wheel_cache=wheel_cache, + verify=(not options.no_verify), + build_options=options.build_options or [], + global_options=options.global_options or [], + ) + for req in build_successes: + assert req.link and req.link.is_wheel + assert req.local_file_path + # copy from cache to target directory + try: + shutil.copy(req.local_file_path, options.wheel_dir) + except OSError as e: + logger.warning( + "Building wheel for %s failed: %s", + req.name, + e, + ) + build_failures.append(req) + if len(build_failures) != 0: + raise CommandError("Failed to build one or more wheels") + + return SUCCESS diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/configuration.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..ffeda1d47a1bd2ec27174ed46b34ce9c7676bbca --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/configuration.py @@ -0,0 +1,383 @@ +"""Configuration management setup + +Some terminology: +- name + As written in config files. +- value + Value associated with a name +- key + Name combined with it's section (section.name) +- variant + A single word describing where the configuration key-value pair came from +""" + +import configparser +import locale +import os +import sys +from typing import Any, Dict, Iterable, List, NewType, Optional, Tuple + +from pip._internal.exceptions import ( + ConfigurationError, + ConfigurationFileCouldNotBeLoaded, +) +from pip._internal.utils import appdirs +from pip._internal.utils.compat import WINDOWS +from pip._internal.utils.logging import getLogger +from pip._internal.utils.misc import ensure_dir, enum + +RawConfigParser = configparser.RawConfigParser # Shorthand +Kind = NewType("Kind", str) + +CONFIG_BASENAME = "pip.ini" if WINDOWS else "pip.conf" +ENV_NAMES_IGNORED = "version", "help" + +# The kinds of configurations there are. +kinds = enum( + USER="user", # User Specific + GLOBAL="global", # System Wide + SITE="site", # [Virtual] Environment Specific + ENV="env", # from PIP_CONFIG_FILE + ENV_VAR="env-var", # from Environment Variables +) +OVERRIDE_ORDER = kinds.GLOBAL, kinds.USER, kinds.SITE, kinds.ENV, kinds.ENV_VAR +VALID_LOAD_ONLY = kinds.USER, kinds.GLOBAL, kinds.SITE + +logger = getLogger(__name__) + + +# NOTE: Maybe use the optionx attribute to normalize keynames. +def _normalize_name(name: str) -> str: + """Make a name consistent regardless of source (environment or file)""" + name = name.lower().replace("_", "-") + if name.startswith("--"): + name = name[2:] # only prefer long opts + return name + + +def _disassemble_key(name: str) -> List[str]: + if "." not in name: + error_message = ( + "Key does not contain dot separated section and key. " + f"Perhaps you wanted to use 'global.{name}' instead?" + ) + raise ConfigurationError(error_message) + return name.split(".", 1) + + +def get_configuration_files() -> Dict[Kind, List[str]]: + global_config_files = [ + os.path.join(path, CONFIG_BASENAME) for path in appdirs.site_config_dirs("pip") + ] + + site_config_file = os.path.join(sys.prefix, CONFIG_BASENAME) + legacy_config_file = os.path.join( + os.path.expanduser("~"), + "pip" if WINDOWS else ".pip", + CONFIG_BASENAME, + ) + new_config_file = os.path.join(appdirs.user_config_dir("pip"), CONFIG_BASENAME) + return { + kinds.GLOBAL: global_config_files, + kinds.SITE: [site_config_file], + kinds.USER: [legacy_config_file, new_config_file], + } + + +class Configuration: + """Handles management of configuration. + + Provides an interface to accessing and managing configuration files. + + This class converts provides an API that takes "section.key-name" style + keys and stores the value associated with it as "key-name" under the + section "section". + + This allows for a clean interface wherein the both the section and the + key-name are preserved in an easy to manage form in the configuration files + and the data stored is also nice. + """ + + def __init__(self, isolated: bool, load_only: Optional[Kind] = None) -> None: + super().__init__() + + if load_only is not None and load_only not in VALID_LOAD_ONLY: + raise ConfigurationError( + "Got invalid value for load_only - should be one of {}".format( + ", ".join(map(repr, VALID_LOAD_ONLY)) + ) + ) + self.isolated = isolated + self.load_only = load_only + + # Because we keep track of where we got the data from + self._parsers: Dict[Kind, List[Tuple[str, RawConfigParser]]] = { + variant: [] for variant in OVERRIDE_ORDER + } + self._config: Dict[Kind, Dict[str, Any]] = { + variant: {} for variant in OVERRIDE_ORDER + } + self._modified_parsers: List[Tuple[str, RawConfigParser]] = [] + + def load(self) -> None: + """Loads configuration from configuration files and environment""" + self._load_config_files() + if not self.isolated: + self._load_environment_vars() + + def get_file_to_edit(self) -> Optional[str]: + """Returns the file with highest priority in configuration""" + assert self.load_only is not None, "Need to be specified a file to be editing" + + try: + return self._get_parser_to_modify()[0] + except IndexError: + return None + + def items(self) -> Iterable[Tuple[str, Any]]: + """Returns key-value pairs like dict.items() representing the loaded + configuration + """ + return self._dictionary.items() + + def get_value(self, key: str) -> Any: + """Get a value from the configuration.""" + orig_key = key + key = _normalize_name(key) + try: + return self._dictionary[key] + except KeyError: + # disassembling triggers a more useful error message than simply + # "No such key" in the case that the key isn't in the form command.option + _disassemble_key(key) + raise ConfigurationError(f"No such key - {orig_key}") + + def set_value(self, key: str, value: Any) -> None: + """Modify a value in the configuration.""" + key = _normalize_name(key) + self._ensure_have_load_only() + + assert self.load_only + fname, parser = self._get_parser_to_modify() + + if parser is not None: + section, name = _disassemble_key(key) + + # Modify the parser and the configuration + if not parser.has_section(section): + parser.add_section(section) + parser.set(section, name, value) + + self._config[self.load_only][key] = value + self._mark_as_modified(fname, parser) + + def unset_value(self, key: str) -> None: + """Unset a value in the configuration.""" + orig_key = key + key = _normalize_name(key) + self._ensure_have_load_only() + + assert self.load_only + if key not in self._config[self.load_only]: + raise ConfigurationError(f"No such key - {orig_key}") + + fname, parser = self._get_parser_to_modify() + + if parser is not None: + section, name = _disassemble_key(key) + if not ( + parser.has_section(section) and parser.remove_option(section, name) + ): + # The option was not removed. + raise ConfigurationError( + "Fatal Internal error [id=1]. Please report as a bug." + ) + + # The section may be empty after the option was removed. + if not parser.items(section): + parser.remove_section(section) + self._mark_as_modified(fname, parser) + + del self._config[self.load_only][key] + + def save(self) -> None: + """Save the current in-memory state.""" + self._ensure_have_load_only() + + for fname, parser in self._modified_parsers: + logger.info("Writing to %s", fname) + + # Ensure directory exists. + ensure_dir(os.path.dirname(fname)) + + # Ensure directory's permission(need to be writeable) + try: + with open(fname, "w") as f: + parser.write(f) + except OSError as error: + raise ConfigurationError( + f"An error occurred while writing to the configuration file " + f"{fname}: {error}" + ) + + # + # Private routines + # + + def _ensure_have_load_only(self) -> None: + if self.load_only is None: + raise ConfigurationError("Needed a specific file to be modifying.") + logger.debug("Will be working with %s variant only", self.load_only) + + @property + def _dictionary(self) -> Dict[str, Any]: + """A dictionary representing the loaded configuration.""" + # NOTE: Dictionaries are not populated if not loaded. So, conditionals + # are not needed here. + retval = {} + + for variant in OVERRIDE_ORDER: + retval.update(self._config[variant]) + + return retval + + def _load_config_files(self) -> None: + """Loads configuration from configuration files""" + config_files = dict(self.iter_config_files()) + if config_files[kinds.ENV][0:1] == [os.devnull]: + logger.debug( + "Skipping loading configuration files due to " + "environment's PIP_CONFIG_FILE being os.devnull" + ) + return + + for variant, files in config_files.items(): + for fname in files: + # If there's specific variant set in `load_only`, load only + # that variant, not the others. + if self.load_only is not None and variant != self.load_only: + logger.debug("Skipping file '%s' (variant: %s)", fname, variant) + continue + + parser = self._load_file(variant, fname) + + # Keeping track of the parsers used + self._parsers[variant].append((fname, parser)) + + def _load_file(self, variant: Kind, fname: str) -> RawConfigParser: + logger.verbose("For variant '%s', will try loading '%s'", variant, fname) + parser = self._construct_parser(fname) + + for section in parser.sections(): + items = parser.items(section) + self._config[variant].update(self._normalized_keys(section, items)) + + return parser + + def _construct_parser(self, fname: str) -> RawConfigParser: + parser = configparser.RawConfigParser() + # If there is no such file, don't bother reading it but create the + # parser anyway, to hold the data. + # Doing this is useful when modifying and saving files, where we don't + # need to construct a parser. + if os.path.exists(fname): + locale_encoding = locale.getpreferredencoding(False) + try: + parser.read(fname, encoding=locale_encoding) + except UnicodeDecodeError: + # See https://github.com/pypa/pip/issues/4963 + raise ConfigurationFileCouldNotBeLoaded( + reason=f"contains invalid {locale_encoding} characters", + fname=fname, + ) + except configparser.Error as error: + # See https://github.com/pypa/pip/issues/4893 + raise ConfigurationFileCouldNotBeLoaded(error=error) + return parser + + def _load_environment_vars(self) -> None: + """Loads configuration from environment variables""" + self._config[kinds.ENV_VAR].update( + self._normalized_keys(":env:", self.get_environ_vars()) + ) + + def _normalized_keys( + self, section: str, items: Iterable[Tuple[str, Any]] + ) -> Dict[str, Any]: + """Normalizes items to construct a dictionary with normalized keys. + + This routine is where the names become keys and are made the same + regardless of source - configuration files or environment. + """ + normalized = {} + for name, val in items: + key = section + "." + _normalize_name(name) + normalized[key] = val + return normalized + + def get_environ_vars(self) -> Iterable[Tuple[str, str]]: + """Returns a generator with all environmental vars with prefix PIP_""" + for key, val in os.environ.items(): + if key.startswith("PIP_"): + name = key[4:].lower() + if name not in ENV_NAMES_IGNORED: + yield name, val + + # XXX: This is patched in the tests. + def iter_config_files(self) -> Iterable[Tuple[Kind, List[str]]]: + """Yields variant and configuration files associated with it. + + This should be treated like items of a dictionary. The order + here doesn't affect what gets overridden. That is controlled + by OVERRIDE_ORDER. However this does control the order they are + displayed to the user. It's probably most ergonomic to display + things in the same order as OVERRIDE_ORDER + """ + # SMELL: Move the conditions out of this function + + env_config_file = os.environ.get("PIP_CONFIG_FILE", None) + config_files = get_configuration_files() + + yield kinds.GLOBAL, config_files[kinds.GLOBAL] + + # per-user config is not loaded when env_config_file exists + should_load_user_config = not self.isolated and not ( + env_config_file and os.path.exists(env_config_file) + ) + if should_load_user_config: + # The legacy config file is overridden by the new config file + yield kinds.USER, config_files[kinds.USER] + + # virtualenv config + yield kinds.SITE, config_files[kinds.SITE] + + if env_config_file is not None: + yield kinds.ENV, [env_config_file] + else: + yield kinds.ENV, [] + + def get_values_in_config(self, variant: Kind) -> Dict[str, Any]: + """Get values present in a config file""" + return self._config[variant] + + def _get_parser_to_modify(self) -> Tuple[str, RawConfigParser]: + # Determine which parser to modify + assert self.load_only + parsers = self._parsers[self.load_only] + if not parsers: + # This should not happen if everything works correctly. + raise ConfigurationError( + "Fatal Internal error [id=2]. Please report as a bug." + ) + + # Use the highest priority parser. + return parsers[-1] + + # XXX: This is patched in the tests. + def _mark_as_modified(self, fname: str, parser: RawConfigParser) -> None: + file_parser_tuple = (fname, parser) + if file_parser_tuple not in self._modified_parsers: + self._modified_parsers.append(file_parser_tuple) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._dictionary!r})" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/exceptions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..45a876a850dfec2295ea982e1eadd5ab26364cd4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/exceptions.py @@ -0,0 +1,809 @@ +"""Exceptions used throughout package. + +This module MUST NOT try to import from anything within `pip._internal` to +operate. This is expected to be importable from any/all files within the +subpackage and, thus, should not depend on them. +""" + +import configparser +import contextlib +import locale +import logging +import pathlib +import re +import sys +from itertools import chain, groupby, repeat +from typing import TYPE_CHECKING, Dict, Iterator, List, Literal, Optional, Union + +from pip._vendor.packaging.requirements import InvalidRequirement +from pip._vendor.packaging.version import InvalidVersion +from pip._vendor.rich.console import Console, ConsoleOptions, RenderResult +from pip._vendor.rich.markup import escape +from pip._vendor.rich.text import Text + +if TYPE_CHECKING: + from hashlib import _Hash + + from pip._vendor.requests.models import Request, Response + + from pip._internal.metadata import BaseDistribution + from pip._internal.req.req_install import InstallRequirement + +logger = logging.getLogger(__name__) + + +# +# Scaffolding +# +def _is_kebab_case(s: str) -> bool: + return re.match(r"^[a-z]+(-[a-z]+)*$", s) is not None + + +def _prefix_with_indent( + s: Union[Text, str], + console: Console, + *, + prefix: str, + indent: str, +) -> Text: + if isinstance(s, Text): + text = s + else: + text = console.render_str(s) + + return console.render_str(prefix, overflow="ignore") + console.render_str( + f"\n{indent}", overflow="ignore" + ).join(text.split(allow_blank=True)) + + +class PipError(Exception): + """The base pip error.""" + + +class DiagnosticPipError(PipError): + """An error, that presents diagnostic information to the user. + + This contains a bunch of logic, to enable pretty presentation of our error + messages. Each error gets a unique reference. Each error can also include + additional context, a hint and/or a note -- which are presented with the + main error message in a consistent style. + + This is adapted from the error output styling in `sphinx-theme-builder`. + """ + + reference: str + + def __init__( + self, + *, + kind: 'Literal["error", "warning"]' = "error", + reference: Optional[str] = None, + message: Union[str, Text], + context: Optional[Union[str, Text]], + hint_stmt: Optional[Union[str, Text]], + note_stmt: Optional[Union[str, Text]] = None, + link: Optional[str] = None, + ) -> None: + # Ensure a proper reference is provided. + if reference is None: + assert hasattr(self, "reference"), "error reference not provided!" + reference = self.reference + assert _is_kebab_case(reference), "error reference must be kebab-case!" + + self.kind = kind + self.reference = reference + + self.message = message + self.context = context + + self.note_stmt = note_stmt + self.hint_stmt = hint_stmt + + self.link = link + + super().__init__(f"<{self.__class__.__name__}: {self.reference}>") + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__}(" + f"reference={self.reference!r}, " + f"message={self.message!r}, " + f"context={self.context!r}, " + f"note_stmt={self.note_stmt!r}, " + f"hint_stmt={self.hint_stmt!r}" + ")>" + ) + + def __rich_console__( + self, + console: Console, + options: ConsoleOptions, + ) -> RenderResult: + colour = "red" if self.kind == "error" else "yellow" + + yield f"[{colour} bold]{self.kind}[/]: [bold]{self.reference}[/]" + yield "" + + if not options.ascii_only: + # Present the main message, with relevant context indented. + if self.context is not None: + yield _prefix_with_indent( + self.message, + console, + prefix=f"[{colour}]×[/] ", + indent=f"[{colour}]│[/] ", + ) + yield _prefix_with_indent( + self.context, + console, + prefix=f"[{colour}]╰─>[/] ", + indent=f"[{colour}] [/] ", + ) + else: + yield _prefix_with_indent( + self.message, + console, + prefix="[red]×[/] ", + indent=" ", + ) + else: + yield self.message + if self.context is not None: + yield "" + yield self.context + + if self.note_stmt is not None or self.hint_stmt is not None: + yield "" + + if self.note_stmt is not None: + yield _prefix_with_indent( + self.note_stmt, + console, + prefix="[magenta bold]note[/]: ", + indent=" ", + ) + if self.hint_stmt is not None: + yield _prefix_with_indent( + self.hint_stmt, + console, + prefix="[cyan bold]hint[/]: ", + indent=" ", + ) + + if self.link is not None: + yield "" + yield f"Link: {self.link}" + + +# +# Actual Errors +# +class ConfigurationError(PipError): + """General exception in configuration""" + + +class InstallationError(PipError): + """General exception during installation""" + + +class MissingPyProjectBuildRequires(DiagnosticPipError): + """Raised when pyproject.toml has `build-system`, but no `build-system.requires`.""" + + reference = "missing-pyproject-build-system-requires" + + def __init__(self, *, package: str) -> None: + super().__init__( + message=f"Can not process {escape(package)}", + context=Text( + "This package has an invalid pyproject.toml file.\n" + "The [build-system] table is missing the mandatory `requires` key." + ), + note_stmt="This is an issue with the package mentioned above, not pip.", + hint_stmt=Text("See PEP 518 for the detailed specification."), + ) + + +class InvalidPyProjectBuildRequires(DiagnosticPipError): + """Raised when pyproject.toml an invalid `build-system.requires`.""" + + reference = "invalid-pyproject-build-system-requires" + + def __init__(self, *, package: str, reason: str) -> None: + super().__init__( + message=f"Can not process {escape(package)}", + context=Text( + "This package has an invalid `build-system.requires` key in " + f"pyproject.toml.\n{reason}" + ), + note_stmt="This is an issue with the package mentioned above, not pip.", + hint_stmt=Text("See PEP 518 for the detailed specification."), + ) + + +class NoneMetadataError(PipError): + """Raised when accessing a Distribution's "METADATA" or "PKG-INFO". + + This signifies an inconsistency, when the Distribution claims to have + the metadata file (if not, raise ``FileNotFoundError`` instead), but is + not actually able to produce its content. This may be due to permission + errors. + """ + + def __init__( + self, + dist: "BaseDistribution", + metadata_name: str, + ) -> None: + """ + :param dist: A Distribution object. + :param metadata_name: The name of the metadata being accessed + (can be "METADATA" or "PKG-INFO"). + """ + self.dist = dist + self.metadata_name = metadata_name + + def __str__(self) -> str: + # Use `dist` in the error message because its stringification + # includes more information, like the version and location. + return f"None {self.metadata_name} metadata found for distribution: {self.dist}" + + +class UserInstallationInvalid(InstallationError): + """A --user install is requested on an environment without user site.""" + + def __str__(self) -> str: + return "User base directory is not specified" + + +class InvalidSchemeCombination(InstallationError): + def __str__(self) -> str: + before = ", ".join(str(a) for a in self.args[:-1]) + return f"Cannot set {before} and {self.args[-1]} together" + + +class DistributionNotFound(InstallationError): + """Raised when a distribution cannot be found to satisfy a requirement""" + + +class RequirementsFileParseError(InstallationError): + """Raised when a general error occurs parsing a requirements file line.""" + + +class BestVersionAlreadyInstalled(PipError): + """Raised when the most up-to-date version of a package is already + installed.""" + + +class BadCommand(PipError): + """Raised when virtualenv or a command is not found""" + + +class CommandError(PipError): + """Raised when there is an error in command-line arguments""" + + +class PreviousBuildDirError(PipError): + """Raised when there's a previous conflicting build directory""" + + +class NetworkConnectionError(PipError): + """HTTP connection error""" + + def __init__( + self, + error_msg: str, + response: Optional["Response"] = None, + request: Optional["Request"] = None, + ) -> None: + """ + Initialize NetworkConnectionError with `request` and `response` + objects. + """ + self.response = response + self.request = request + self.error_msg = error_msg + if ( + self.response is not None + and not self.request + and hasattr(response, "request") + ): + self.request = self.response.request + super().__init__(error_msg, response, request) + + def __str__(self) -> str: + return str(self.error_msg) + + +class InvalidWheelFilename(InstallationError): + """Invalid wheel filename.""" + + +class UnsupportedWheel(InstallationError): + """Unsupported wheel.""" + + +class InvalidWheel(InstallationError): + """Invalid (e.g. corrupt) wheel.""" + + def __init__(self, location: str, name: str): + self.location = location + self.name = name + + def __str__(self) -> str: + return f"Wheel '{self.name}' located at {self.location} is invalid." + + +class MetadataInconsistent(InstallationError): + """Built metadata contains inconsistent information. + + This is raised when the metadata contains values (e.g. name and version) + that do not match the information previously obtained from sdist filename, + user-supplied ``#egg=`` value, or an install requirement name. + """ + + def __init__( + self, ireq: "InstallRequirement", field: str, f_val: str, m_val: str + ) -> None: + self.ireq = ireq + self.field = field + self.f_val = f_val + self.m_val = m_val + + def __str__(self) -> str: + return ( + f"Requested {self.ireq} has inconsistent {self.field}: " + f"expected {self.f_val!r}, but metadata has {self.m_val!r}" + ) + + +class MetadataInvalid(InstallationError): + """Metadata is invalid.""" + + def __init__(self, ireq: "InstallRequirement", error: str) -> None: + self.ireq = ireq + self.error = error + + def __str__(self) -> str: + return f"Requested {self.ireq} has invalid metadata: {self.error}" + + +class InstallationSubprocessError(DiagnosticPipError, InstallationError): + """A subprocess call failed.""" + + reference = "subprocess-exited-with-error" + + def __init__( + self, + *, + command_description: str, + exit_code: int, + output_lines: Optional[List[str]], + ) -> None: + if output_lines is None: + output_prompt = Text("See above for output.") + else: + output_prompt = ( + Text.from_markup(f"[red][{len(output_lines)} lines of output][/]\n") + + Text("".join(output_lines)) + + Text.from_markup(R"[red]\[end of output][/]") + ) + + super().__init__( + message=( + f"[green]{escape(command_description)}[/] did not run successfully.\n" + f"exit code: {exit_code}" + ), + context=output_prompt, + hint_stmt=None, + note_stmt=( + "This error originates from a subprocess, and is likely not a " + "problem with pip." + ), + ) + + self.command_description = command_description + self.exit_code = exit_code + + def __str__(self) -> str: + return f"{self.command_description} exited with {self.exit_code}" + + +class MetadataGenerationFailed(InstallationSubprocessError, InstallationError): + reference = "metadata-generation-failed" + + def __init__( + self, + *, + package_details: str, + ) -> None: + super(InstallationSubprocessError, self).__init__( + message="Encountered error while generating package metadata.", + context=escape(package_details), + hint_stmt="See above for details.", + note_stmt="This is an issue with the package mentioned above, not pip.", + ) + + def __str__(self) -> str: + return "metadata generation failed" + + +class HashErrors(InstallationError): + """Multiple HashError instances rolled into one for reporting""" + + def __init__(self) -> None: + self.errors: List[HashError] = [] + + def append(self, error: "HashError") -> None: + self.errors.append(error) + + def __str__(self) -> str: + lines = [] + self.errors.sort(key=lambda e: e.order) + for cls, errors_of_cls in groupby(self.errors, lambda e: e.__class__): + lines.append(cls.head) + lines.extend(e.body() for e in errors_of_cls) + if lines: + return "\n".join(lines) + return "" + + def __bool__(self) -> bool: + return bool(self.errors) + + +class HashError(InstallationError): + """ + A failure to verify a package against known-good hashes + + :cvar order: An int sorting hash exception classes by difficulty of + recovery (lower being harder), so the user doesn't bother fretting + about unpinned packages when he has deeper issues, like VCS + dependencies, to deal with. Also keeps error reports in a + deterministic order. + :cvar head: A section heading for display above potentially many + exceptions of this kind + :ivar req: The InstallRequirement that triggered this error. This is + pasted on after the exception is instantiated, because it's not + typically available earlier. + + """ + + req: Optional["InstallRequirement"] = None + head = "" + order: int = -1 + + def body(self) -> str: + """Return a summary of me for display under the heading. + + This default implementation simply prints a description of the + triggering requirement. + + :param req: The InstallRequirement that provoked this error, with + its link already populated by the resolver's _populate_link(). + + """ + return f" {self._requirement_name()}" + + def __str__(self) -> str: + return f"{self.head}\n{self.body()}" + + def _requirement_name(self) -> str: + """Return a description of the requirement that triggered me. + + This default implementation returns long description of the req, with + line numbers + + """ + return str(self.req) if self.req else "unknown package" + + +class VcsHashUnsupported(HashError): + """A hash was provided for a version-control-system-based requirement, but + we don't have a method for hashing those.""" + + order = 0 + head = ( + "Can't verify hashes for these requirements because we don't " + "have a way to hash version control repositories:" + ) + + +class DirectoryUrlHashUnsupported(HashError): + """A hash was provided for a version-control-system-based requirement, but + we don't have a method for hashing those.""" + + order = 1 + head = ( + "Can't verify hashes for these file:// requirements because they " + "point to directories:" + ) + + +class HashMissing(HashError): + """A hash was needed for a requirement but is absent.""" + + order = 2 + head = ( + "Hashes are required in --require-hashes mode, but they are " + "missing from some requirements. Here is a list of those " + "requirements along with the hashes their downloaded archives " + "actually had. Add lines like these to your requirements files to " + "prevent tampering. (If you did not enable --require-hashes " + "manually, note that it turns on automatically when any package " + "has a hash.)" + ) + + def __init__(self, gotten_hash: str) -> None: + """ + :param gotten_hash: The hash of the (possibly malicious) archive we + just downloaded + """ + self.gotten_hash = gotten_hash + + def body(self) -> str: + # Dodge circular import. + from pip._internal.utils.hashes import FAVORITE_HASH + + package = None + if self.req: + # In the case of URL-based requirements, display the original URL + # seen in the requirements file rather than the package name, + # so the output can be directly copied into the requirements file. + package = ( + self.req.original_link + if self.req.is_direct + # In case someone feeds something downright stupid + # to InstallRequirement's constructor. + else getattr(self.req, "req", None) + ) + return " {} --hash={}:{}".format( + package or "unknown package", FAVORITE_HASH, self.gotten_hash + ) + + +class HashUnpinned(HashError): + """A requirement had a hash specified but was not pinned to a specific + version.""" + + order = 3 + head = ( + "In --require-hashes mode, all requirements must have their " + "versions pinned with ==. These do not:" + ) + + +class HashMismatch(HashError): + """ + Distribution file hash values don't match. + + :ivar package_name: The name of the package that triggered the hash + mismatch. Feel free to write to this after the exception is raise to + improve its error message. + + """ + + order = 4 + head = ( + "THESE PACKAGES DO NOT MATCH THE HASHES FROM THE REQUIREMENTS " + "FILE. If you have updated the package versions, please update " + "the hashes. Otherwise, examine the package contents carefully; " + "someone may have tampered with them." + ) + + def __init__(self, allowed: Dict[str, List[str]], gots: Dict[str, "_Hash"]) -> None: + """ + :param allowed: A dict of algorithm names pointing to lists of allowed + hex digests + :param gots: A dict of algorithm names pointing to hashes we + actually got from the files under suspicion + """ + self.allowed = allowed + self.gots = gots + + def body(self) -> str: + return f" {self._requirement_name()}:\n{self._hash_comparison()}" + + def _hash_comparison(self) -> str: + """ + Return a comparison of actual and expected hash values. + + Example:: + + Expected sha256 abcdeabcdeabcdeabcdeabcdeabcdeabcdeabcdeabcde + or 123451234512345123451234512345123451234512345 + Got bcdefbcdefbcdefbcdefbcdefbcdefbcdefbcdefbcdef + + """ + + def hash_then_or(hash_name: str) -> "chain[str]": + # For now, all the decent hashes have 6-char names, so we can get + # away with hard-coding space literals. + return chain([hash_name], repeat(" or")) + + lines: List[str] = [] + for hash_name, expecteds in self.allowed.items(): + prefix = hash_then_or(hash_name) + lines.extend((f" Expected {next(prefix)} {e}") for e in expecteds) + lines.append( + f" Got {self.gots[hash_name].hexdigest()}\n" + ) + return "\n".join(lines) + + +class UnsupportedPythonVersion(InstallationError): + """Unsupported python version according to Requires-Python package + metadata.""" + + +class ConfigurationFileCouldNotBeLoaded(ConfigurationError): + """When there are errors while loading a configuration file""" + + def __init__( + self, + reason: str = "could not be loaded", + fname: Optional[str] = None, + error: Optional[configparser.Error] = None, + ) -> None: + super().__init__(error) + self.reason = reason + self.fname = fname + self.error = error + + def __str__(self) -> str: + if self.fname is not None: + message_part = f" in {self.fname}." + else: + assert self.error is not None + message_part = f".\n{self.error}\n" + return f"Configuration file {self.reason}{message_part}" + + +_DEFAULT_EXTERNALLY_MANAGED_ERROR = f"""\ +The Python environment under {sys.prefix} is managed externally, and may not be +manipulated by the user. Please use specific tooling from the distributor of +the Python installation to interact with this environment instead. +""" + + +class ExternallyManagedEnvironment(DiagnosticPipError): + """The current environment is externally managed. + + This is raised when the current environment is externally managed, as + defined by `PEP 668`_. The ``EXTERNALLY-MANAGED`` configuration is checked + and displayed when the error is bubbled up to the user. + + :param error: The error message read from ``EXTERNALLY-MANAGED``. + """ + + reference = "externally-managed-environment" + + def __init__(self, error: Optional[str]) -> None: + if error is None: + context = Text(_DEFAULT_EXTERNALLY_MANAGED_ERROR) + else: + context = Text(error) + super().__init__( + message="This environment is externally managed", + context=context, + note_stmt=( + "If you believe this is a mistake, please contact your " + "Python installation or OS distribution provider. " + "You can override this, at the risk of breaking your Python " + "installation or OS, by passing --break-system-packages." + ), + hint_stmt=Text("See PEP 668 for the detailed specification."), + ) + + @staticmethod + def _iter_externally_managed_error_keys() -> Iterator[str]: + # LC_MESSAGES is in POSIX, but not the C standard. The most common + # platform that does not implement this category is Windows, where + # using other categories for console message localization is equally + # unreliable, so we fall back to the locale-less vendor message. This + # can always be re-evaluated when a vendor proposes a new alternative. + try: + category = locale.LC_MESSAGES + except AttributeError: + lang: Optional[str] = None + else: + lang, _ = locale.getlocale(category) + if lang is not None: + yield f"Error-{lang}" + for sep in ("-", "_"): + before, found, _ = lang.partition(sep) + if not found: + continue + yield f"Error-{before}" + yield "Error" + + @classmethod + def from_config( + cls, + config: Union[pathlib.Path, str], + ) -> "ExternallyManagedEnvironment": + parser = configparser.ConfigParser(interpolation=None) + try: + parser.read(config, encoding="utf-8") + section = parser["externally-managed"] + for key in cls._iter_externally_managed_error_keys(): + with contextlib.suppress(KeyError): + return cls(section[key]) + except KeyError: + pass + except (OSError, UnicodeDecodeError, configparser.ParsingError): + from pip._internal.utils._log import VERBOSE + + exc_info = logger.isEnabledFor(VERBOSE) + logger.warning("Failed to read %s", config, exc_info=exc_info) + return cls(None) + + +class UninstallMissingRecord(DiagnosticPipError): + reference = "uninstall-no-record-file" + + def __init__(self, *, distribution: "BaseDistribution") -> None: + installer = distribution.installer + if not installer or installer == "pip": + dep = f"{distribution.raw_name}=={distribution.version}" + hint = Text.assemble( + "You might be able to recover from this via: ", + (f"pip install --force-reinstall --no-deps {dep}", "green"), + ) + else: + hint = Text( + f"The package was installed by {installer}. " + "You should check if it can uninstall the package." + ) + + super().__init__( + message=Text(f"Cannot uninstall {distribution}"), + context=( + "The package's contents are unknown: " + f"no RECORD file was found for {distribution.raw_name}." + ), + hint_stmt=hint, + ) + + +class LegacyDistutilsInstall(DiagnosticPipError): + reference = "uninstall-distutils-installed-package" + + def __init__(self, *, distribution: "BaseDistribution") -> None: + super().__init__( + message=Text(f"Cannot uninstall {distribution}"), + context=( + "It is a distutils installed project and thus we cannot accurately " + "determine which files belong to it which would lead to only a partial " + "uninstall." + ), + hint_stmt=None, + ) + + +class InvalidInstalledPackage(DiagnosticPipError): + reference = "invalid-installed-package" + + def __init__( + self, + *, + dist: "BaseDistribution", + invalid_exc: Union[InvalidRequirement, InvalidVersion], + ) -> None: + installed_location = dist.installed_location + + if isinstance(invalid_exc, InvalidRequirement): + invalid_type = "requirement" + else: + invalid_type = "version" + + super().__init__( + message=Text( + f"Cannot process installed package {dist} " + + (f"in {installed_location!r} " if installed_location else "") + + f"because it has an invalid {invalid_type}:\n{invalid_exc.args[0]}" + ), + context=( + "Starting with pip 24.1, packages with invalid " + f"{invalid_type}s can not be processed." + ), + hint_stmt="To proceed this package must be uninstalled.", + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a17b7b3b6ad49157ee41f3da304fec3d32342d3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/__init__.py @@ -0,0 +1,2 @@ +"""Index interaction code +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/collector.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/collector.py new file mode 100644 index 0000000000000000000000000000000000000000..5f8fdee3d46271652d498cbfc865a25c50f2cab0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/collector.py @@ -0,0 +1,494 @@ +""" +The main purpose of this module is to expose LinkCollector.collect_sources(). +""" + +import collections +import email.message +import functools +import itertools +import json +import logging +import os +import urllib.parse +import urllib.request +from dataclasses import dataclass +from html.parser import HTMLParser +from optparse import Values +from typing import ( + Callable, + Dict, + Iterable, + List, + MutableMapping, + NamedTuple, + Optional, + Protocol, + Sequence, + Tuple, + Union, +) + +from pip._vendor import requests +from pip._vendor.requests import Response +from pip._vendor.requests.exceptions import RetryError, SSLError + +from pip._internal.exceptions import NetworkConnectionError +from pip._internal.models.link import Link +from pip._internal.models.search_scope import SearchScope +from pip._internal.network.session import PipSession +from pip._internal.network.utils import raise_for_status +from pip._internal.utils.filetypes import is_archive_file +from pip._internal.utils.misc import redact_auth_from_url +from pip._internal.vcs import vcs + +from .sources import CandidatesFromPage, LinkSource, build_source + +logger = logging.getLogger(__name__) + +ResponseHeaders = MutableMapping[str, str] + + +def _match_vcs_scheme(url: str) -> Optional[str]: + """Look for VCS schemes in the URL. + + Returns the matched VCS scheme, or None if there's no match. + """ + for scheme in vcs.schemes: + if url.lower().startswith(scheme) and url[len(scheme)] in "+:": + return scheme + return None + + +class _NotAPIContent(Exception): + def __init__(self, content_type: str, request_desc: str) -> None: + super().__init__(content_type, request_desc) + self.content_type = content_type + self.request_desc = request_desc + + +def _ensure_api_header(response: Response) -> None: + """ + Check the Content-Type header to ensure the response contains a Simple + API Response. + + Raises `_NotAPIContent` if the content type is not a valid content-type. + """ + content_type = response.headers.get("Content-Type", "Unknown") + + content_type_l = content_type.lower() + if content_type_l.startswith( + ( + "text/html", + "application/vnd.pypi.simple.v1+html", + "application/vnd.pypi.simple.v1+json", + ) + ): + return + + raise _NotAPIContent(content_type, response.request.method) + + +class _NotHTTP(Exception): + pass + + +def _ensure_api_response(url: str, session: PipSession) -> None: + """ + Send a HEAD request to the URL, and ensure the response contains a simple + API Response. + + Raises `_NotHTTP` if the URL is not available for a HEAD request, or + `_NotAPIContent` if the content type is not a valid content type. + """ + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if scheme not in {"http", "https"}: + raise _NotHTTP() + + resp = session.head(url, allow_redirects=True) + raise_for_status(resp) + + _ensure_api_header(resp) + + +def _get_simple_response(url: str, session: PipSession) -> Response: + """Access an Simple API response with GET, and return the response. + + This consists of three parts: + + 1. If the URL looks suspiciously like an archive, send a HEAD first to + check the Content-Type is HTML or Simple API, to avoid downloading a + large file. Raise `_NotHTTP` if the content type cannot be determined, or + `_NotAPIContent` if it is not HTML or a Simple API. + 2. Actually perform the request. Raise HTTP exceptions on network failures. + 3. Check the Content-Type header to make sure we got a Simple API response, + and raise `_NotAPIContent` otherwise. + """ + if is_archive_file(Link(url).filename): + _ensure_api_response(url, session=session) + + logger.debug("Getting page %s", redact_auth_from_url(url)) + + resp = session.get( + url, + headers={ + "Accept": ", ".join( + [ + "application/vnd.pypi.simple.v1+json", + "application/vnd.pypi.simple.v1+html; q=0.1", + "text/html; q=0.01", + ] + ), + # We don't want to blindly returned cached data for + # /simple/, because authors generally expecting that + # twine upload && pip install will function, but if + # they've done a pip install in the last ~10 minutes + # it won't. Thus by setting this to zero we will not + # blindly use any cached data, however the benefit of + # using max-age=0 instead of no-cache, is that we will + # still support conditional requests, so we will still + # minimize traffic sent in cases where the page hasn't + # changed at all, we will just always incur the round + # trip for the conditional GET now instead of only + # once per 10 minutes. + # For more information, please see pypa/pip#5670. + "Cache-Control": "max-age=0", + }, + ) + raise_for_status(resp) + + # The check for archives above only works if the url ends with + # something that looks like an archive. However that is not a + # requirement of an url. Unless we issue a HEAD request on every + # url we cannot know ahead of time for sure if something is a + # Simple API response or not. However we can check after we've + # downloaded it. + _ensure_api_header(resp) + + logger.debug( + "Fetched page %s as %s", + redact_auth_from_url(url), + resp.headers.get("Content-Type", "Unknown"), + ) + + return resp + + +def _get_encoding_from_headers(headers: ResponseHeaders) -> Optional[str]: + """Determine if we have any encoding information in our headers.""" + if headers and "Content-Type" in headers: + m = email.message.Message() + m["content-type"] = headers["Content-Type"] + charset = m.get_param("charset") + if charset: + return str(charset) + return None + + +class CacheablePageContent: + def __init__(self, page: "IndexContent") -> None: + assert page.cache_link_parsing + self.page = page + + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and self.page.url == other.page.url + + def __hash__(self) -> int: + return hash(self.page.url) + + +class ParseLinks(Protocol): + def __call__(self, page: "IndexContent") -> Iterable[Link]: ... + + +def with_cached_index_content(fn: ParseLinks) -> ParseLinks: + """ + Given a function that parses an Iterable[Link] from an IndexContent, cache the + function's result (keyed by CacheablePageContent), unless the IndexContent + `page` has `page.cache_link_parsing == False`. + """ + + @functools.lru_cache(maxsize=None) + def wrapper(cacheable_page: CacheablePageContent) -> List[Link]: + return list(fn(cacheable_page.page)) + + @functools.wraps(fn) + def wrapper_wrapper(page: "IndexContent") -> List[Link]: + if page.cache_link_parsing: + return wrapper(CacheablePageContent(page)) + return list(fn(page)) + + return wrapper_wrapper + + +@with_cached_index_content +def parse_links(page: "IndexContent") -> Iterable[Link]: + """ + Parse a Simple API's Index Content, and yield its anchor elements as Link objects. + """ + + content_type_l = page.content_type.lower() + if content_type_l.startswith("application/vnd.pypi.simple.v1+json"): + data = json.loads(page.content) + for file in data.get("files", []): + link = Link.from_json(file, page.url) + if link is None: + continue + yield link + return + + parser = HTMLLinkParser(page.url) + encoding = page.encoding or "utf-8" + parser.feed(page.content.decode(encoding)) + + url = page.url + base_url = parser.base_url or url + for anchor in parser.anchors: + link = Link.from_element(anchor, page_url=url, base_url=base_url) + if link is None: + continue + yield link + + +@dataclass(frozen=True) +class IndexContent: + """Represents one response (or page), along with its URL. + + :param encoding: the encoding to decode the given content. + :param url: the URL from which the HTML was downloaded. + :param cache_link_parsing: whether links parsed from this page's url + should be cached. PyPI index urls should + have this set to False, for example. + """ + + content: bytes + content_type: str + encoding: Optional[str] + url: str + cache_link_parsing: bool = True + + def __str__(self) -> str: + return redact_auth_from_url(self.url) + + +class HTMLLinkParser(HTMLParser): + """ + HTMLParser that keeps the first base HREF and a list of all anchor + elements' attributes. + """ + + def __init__(self, url: str) -> None: + super().__init__(convert_charrefs=True) + + self.url: str = url + self.base_url: Optional[str] = None + self.anchors: List[Dict[str, Optional[str]]] = [] + + def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None: + if tag == "base" and self.base_url is None: + href = self.get_href(attrs) + if href is not None: + self.base_url = href + elif tag == "a": + self.anchors.append(dict(attrs)) + + def get_href(self, attrs: List[Tuple[str, Optional[str]]]) -> Optional[str]: + for name, value in attrs: + if name == "href": + return value + return None + + +def _handle_get_simple_fail( + link: Link, + reason: Union[str, Exception], + meth: Optional[Callable[..., None]] = None, +) -> None: + if meth is None: + meth = logger.debug + meth("Could not fetch URL %s: %s - skipping", link, reason) + + +def _make_index_content( + response: Response, cache_link_parsing: bool = True +) -> IndexContent: + encoding = _get_encoding_from_headers(response.headers) + return IndexContent( + response.content, + response.headers["Content-Type"], + encoding=encoding, + url=response.url, + cache_link_parsing=cache_link_parsing, + ) + + +def _get_index_content(link: Link, *, session: PipSession) -> Optional["IndexContent"]: + url = link.url.split("#", 1)[0] + + # Check for VCS schemes that do not support lookup as web pages. + vcs_scheme = _match_vcs_scheme(url) + if vcs_scheme: + logger.warning( + "Cannot look at %s URL %s because it does not support lookup as web pages.", + vcs_scheme, + link, + ) + return None + + # Tack index.html onto file:// URLs that point to directories + scheme, _, path, _, _, _ = urllib.parse.urlparse(url) + if scheme == "file" and os.path.isdir(urllib.request.url2pathname(path)): + # add trailing slash if not present so urljoin doesn't trim + # final segment + if not url.endswith("/"): + url += "/" + # TODO: In the future, it would be nice if pip supported PEP 691 + # style responses in the file:// URLs, however there's no + # standard file extension for application/vnd.pypi.simple.v1+json + # so we'll need to come up with something on our own. + url = urllib.parse.urljoin(url, "index.html") + logger.debug(" file: URL is directory, getting %s", url) + + try: + resp = _get_simple_response(url, session=session) + except _NotHTTP: + logger.warning( + "Skipping page %s because it looks like an archive, and cannot " + "be checked by a HTTP HEAD request.", + link, + ) + except _NotAPIContent as exc: + logger.warning( + "Skipping page %s because the %s request got Content-Type: %s. " + "The only supported Content-Types are application/vnd.pypi.simple.v1+json, " + "application/vnd.pypi.simple.v1+html, and text/html", + link, + exc.request_desc, + exc.content_type, + ) + except NetworkConnectionError as exc: + _handle_get_simple_fail(link, exc) + except RetryError as exc: + _handle_get_simple_fail(link, exc) + except SSLError as exc: + reason = "There was a problem confirming the ssl certificate: " + reason += str(exc) + _handle_get_simple_fail(link, reason, meth=logger.info) + except requests.ConnectionError as exc: + _handle_get_simple_fail(link, f"connection error: {exc}") + except requests.Timeout: + _handle_get_simple_fail(link, "timed out") + else: + return _make_index_content(resp, cache_link_parsing=link.cache_link_parsing) + return None + + +class CollectedSources(NamedTuple): + find_links: Sequence[Optional[LinkSource]] + index_urls: Sequence[Optional[LinkSource]] + + +class LinkCollector: + """ + Responsible for collecting Link objects from all configured locations, + making network requests as needed. + + The class's main method is its collect_sources() method. + """ + + def __init__( + self, + session: PipSession, + search_scope: SearchScope, + ) -> None: + self.search_scope = search_scope + self.session = session + + @classmethod + def create( + cls, + session: PipSession, + options: Values, + suppress_no_index: bool = False, + ) -> "LinkCollector": + """ + :param session: The Session to use to make requests. + :param suppress_no_index: Whether to ignore the --no-index option + when constructing the SearchScope object. + """ + index_urls = [options.index_url] + options.extra_index_urls + if options.no_index and not suppress_no_index: + logger.debug( + "Ignoring indexes: %s", + ",".join(redact_auth_from_url(url) for url in index_urls), + ) + index_urls = [] + + # Make sure find_links is a list before passing to create(). + find_links = options.find_links or [] + + search_scope = SearchScope.create( + find_links=find_links, + index_urls=index_urls, + no_index=options.no_index, + ) + link_collector = LinkCollector( + session=session, + search_scope=search_scope, + ) + return link_collector + + @property + def find_links(self) -> List[str]: + return self.search_scope.find_links + + def fetch_response(self, location: Link) -> Optional[IndexContent]: + """ + Fetch an HTML page containing package links. + """ + return _get_index_content(location, session=self.session) + + def collect_sources( + self, + project_name: str, + candidates_from_page: CandidatesFromPage, + ) -> CollectedSources: + # The OrderedDict calls deduplicate sources by URL. + index_url_sources = collections.OrderedDict( + build_source( + loc, + candidates_from_page=candidates_from_page, + page_validator=self.session.is_secure_origin, + expand_dir=False, + cache_link_parsing=False, + project_name=project_name, + ) + for loc in self.search_scope.get_index_urls_locations(project_name) + ).values() + find_links_sources = collections.OrderedDict( + build_source( + loc, + candidates_from_page=candidates_from_page, + page_validator=self.session.is_secure_origin, + expand_dir=True, + cache_link_parsing=True, + project_name=project_name, + ) + for loc in self.find_links + ).values() + + if logger.isEnabledFor(logging.DEBUG): + lines = [ + f"* {s.link}" + for s in itertools.chain(find_links_sources, index_url_sources) + if s is not None and s.link is not None + ] + lines = [ + f"{len(lines)} location(s) to search " + f"for versions of {project_name}:" + ] + lines + logger.debug("\n".join(lines)) + + return CollectedSources( + find_links=list(find_links_sources), + index_urls=list(index_url_sources), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/package_finder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/package_finder.py new file mode 100644 index 0000000000000000000000000000000000000000..85628ee5d7a791b4fc9e3a59cd0d8b7659c8ef89 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/package_finder.py @@ -0,0 +1,1029 @@ +"""Routines related to PyPI, indexes""" + +import enum +import functools +import itertools +import logging +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING, FrozenSet, Iterable, List, Optional, Set, Tuple, Union + +from pip._vendor.packaging import specifiers +from pip._vendor.packaging.tags import Tag +from pip._vendor.packaging.utils import canonicalize_name +from pip._vendor.packaging.version import InvalidVersion, _BaseVersion +from pip._vendor.packaging.version import parse as parse_version + +from pip._internal.exceptions import ( + BestVersionAlreadyInstalled, + DistributionNotFound, + InvalidWheelFilename, + UnsupportedWheel, +) +from pip._internal.index.collector import LinkCollector, parse_links +from pip._internal.models.candidate import InstallationCandidate +from pip._internal.models.format_control import FormatControl +from pip._internal.models.link import Link +from pip._internal.models.search_scope import SearchScope +from pip._internal.models.selection_prefs import SelectionPreferences +from pip._internal.models.target_python import TargetPython +from pip._internal.models.wheel import Wheel +from pip._internal.req import InstallRequirement +from pip._internal.utils._log import getLogger +from pip._internal.utils.filetypes import WHEEL_EXTENSION +from pip._internal.utils.hashes import Hashes +from pip._internal.utils.logging import indent_log +from pip._internal.utils.misc import build_netloc +from pip._internal.utils.packaging import check_requires_python +from pip._internal.utils.unpacking import SUPPORTED_EXTENSIONS + +if TYPE_CHECKING: + from pip._vendor.typing_extensions import TypeGuard + +__all__ = ["FormatControl", "BestCandidateResult", "PackageFinder"] + + +logger = getLogger(__name__) + +BuildTag = Union[Tuple[()], Tuple[int, str]] +CandidateSortingKey = Tuple[int, int, int, _BaseVersion, Optional[int], BuildTag] + + +def _check_link_requires_python( + link: Link, + version_info: Tuple[int, int, int], + ignore_requires_python: bool = False, +) -> bool: + """ + Return whether the given Python version is compatible with a link's + "Requires-Python" value. + + :param version_info: A 3-tuple of ints representing the Python + major-minor-micro version to check. + :param ignore_requires_python: Whether to ignore the "Requires-Python" + value if the given Python version isn't compatible. + """ + try: + is_compatible = check_requires_python( + link.requires_python, + version_info=version_info, + ) + except specifiers.InvalidSpecifier: + logger.debug( + "Ignoring invalid Requires-Python (%r) for link: %s", + link.requires_python, + link, + ) + else: + if not is_compatible: + version = ".".join(map(str, version_info)) + if not ignore_requires_python: + logger.verbose( + "Link requires a different Python (%s not in: %r): %s", + version, + link.requires_python, + link, + ) + return False + + logger.debug( + "Ignoring failed Requires-Python check (%s not in: %r) for link: %s", + version, + link.requires_python, + link, + ) + + return True + + +class LinkType(enum.Enum): + candidate = enum.auto() + different_project = enum.auto() + yanked = enum.auto() + format_unsupported = enum.auto() + format_invalid = enum.auto() + platform_mismatch = enum.auto() + requires_python_mismatch = enum.auto() + + +class LinkEvaluator: + """ + Responsible for evaluating links for a particular project. + """ + + _py_version_re = re.compile(r"-py([123]\.?[0-9]?)$") + + # Don't include an allow_yanked default value to make sure each call + # site considers whether yanked releases are allowed. This also causes + # that decision to be made explicit in the calling code, which helps + # people when reading the code. + def __init__( + self, + project_name: str, + canonical_name: str, + formats: FrozenSet[str], + target_python: TargetPython, + allow_yanked: bool, + ignore_requires_python: Optional[bool] = None, + ) -> None: + """ + :param project_name: The user supplied package name. + :param canonical_name: The canonical package name. + :param formats: The formats allowed for this package. Should be a set + with 'binary' or 'source' or both in it. + :param target_python: The target Python interpreter to use when + evaluating link compatibility. This is used, for example, to + check wheel compatibility, as well as when checking the Python + version, e.g. the Python version embedded in a link filename + (or egg fragment) and against an HTML link's optional PEP 503 + "data-requires-python" attribute. + :param allow_yanked: Whether files marked as yanked (in the sense + of PEP 592) are permitted to be candidates for install. + :param ignore_requires_python: Whether to ignore incompatible + PEP 503 "data-requires-python" values in HTML links. Defaults + to False. + """ + if ignore_requires_python is None: + ignore_requires_python = False + + self._allow_yanked = allow_yanked + self._canonical_name = canonical_name + self._ignore_requires_python = ignore_requires_python + self._formats = formats + self._target_python = target_python + + self.project_name = project_name + + def evaluate_link(self, link: Link) -> Tuple[LinkType, str]: + """ + Determine whether a link is a candidate for installation. + + :return: A tuple (result, detail), where *result* is an enum + representing whether the evaluation found a candidate, or the reason + why one is not found. If a candidate is found, *detail* will be the + candidate's version string; if one is not found, it contains the + reason the link fails to qualify. + """ + version = None + if link.is_yanked and not self._allow_yanked: + reason = link.yanked_reason or "" + return (LinkType.yanked, f"yanked for reason: {reason}") + + if link.egg_fragment: + egg_info = link.egg_fragment + ext = link.ext + else: + egg_info, ext = link.splitext() + if not ext: + return (LinkType.format_unsupported, "not a file") + if ext not in SUPPORTED_EXTENSIONS: + return ( + LinkType.format_unsupported, + f"unsupported archive format: {ext}", + ) + if "binary" not in self._formats and ext == WHEEL_EXTENSION: + reason = f"No binaries permitted for {self.project_name}" + return (LinkType.format_unsupported, reason) + if "macosx10" in link.path and ext == ".zip": + return (LinkType.format_unsupported, "macosx10 one") + if ext == WHEEL_EXTENSION: + try: + wheel = Wheel(link.filename) + except InvalidWheelFilename: + return ( + LinkType.format_invalid, + "invalid wheel filename", + ) + if canonicalize_name(wheel.name) != self._canonical_name: + reason = f"wrong project name (not {self.project_name})" + return (LinkType.different_project, reason) + + supported_tags = self._target_python.get_unsorted_tags() + if not wheel.supported(supported_tags): + # Include the wheel's tags in the reason string to + # simplify troubleshooting compatibility issues. + file_tags = ", ".join(wheel.get_formatted_file_tags()) + reason = ( + f"none of the wheel's tags ({file_tags}) are compatible " + f"(run pip debug --verbose to show compatible tags)" + ) + return (LinkType.platform_mismatch, reason) + + version = wheel.version + + # This should be up by the self.ok_binary check, but see issue 2700. + if "source" not in self._formats and ext != WHEEL_EXTENSION: + reason = f"No sources permitted for {self.project_name}" + return (LinkType.format_unsupported, reason) + + if not version: + version = _extract_version_from_fragment( + egg_info, + self._canonical_name, + ) + if not version: + reason = f"Missing project version for {self.project_name}" + return (LinkType.format_invalid, reason) + + match = self._py_version_re.search(version) + if match: + version = version[: match.start()] + py_version = match.group(1) + if py_version != self._target_python.py_version: + return ( + LinkType.platform_mismatch, + "Python version is incorrect", + ) + + supports_python = _check_link_requires_python( + link, + version_info=self._target_python.py_version_info, + ignore_requires_python=self._ignore_requires_python, + ) + if not supports_python: + reason = f"{version} Requires-Python {link.requires_python}" + return (LinkType.requires_python_mismatch, reason) + + logger.debug("Found link %s, version: %s", link, version) + + return (LinkType.candidate, version) + + +def filter_unallowed_hashes( + candidates: List[InstallationCandidate], + hashes: Optional[Hashes], + project_name: str, +) -> List[InstallationCandidate]: + """ + Filter out candidates whose hashes aren't allowed, and return a new + list of candidates. + + If at least one candidate has an allowed hash, then all candidates with + either an allowed hash or no hash specified are returned. Otherwise, + the given candidates are returned. + + Including the candidates with no hash specified when there is a match + allows a warning to be logged if there is a more preferred candidate + with no hash specified. Returning all candidates in the case of no + matches lets pip report the hash of the candidate that would otherwise + have been installed (e.g. permitting the user to more easily update + their requirements file with the desired hash). + """ + if not hashes: + logger.debug( + "Given no hashes to check %s links for project %r: " + "discarding no candidates", + len(candidates), + project_name, + ) + # Make sure we're not returning back the given value. + return list(candidates) + + matches_or_no_digest = [] + # Collect the non-matches for logging purposes. + non_matches = [] + match_count = 0 + for candidate in candidates: + link = candidate.link + if not link.has_hash: + pass + elif link.is_hash_allowed(hashes=hashes): + match_count += 1 + else: + non_matches.append(candidate) + continue + + matches_or_no_digest.append(candidate) + + if match_count: + filtered = matches_or_no_digest + else: + # Make sure we're not returning back the given value. + filtered = list(candidates) + + if len(filtered) == len(candidates): + discard_message = "discarding no candidates" + else: + discard_message = "discarding {} non-matches:\n {}".format( + len(non_matches), + "\n ".join(str(candidate.link) for candidate in non_matches), + ) + + logger.debug( + "Checked %s links for project %r against %s hashes " + "(%s matches, %s no digest): %s", + len(candidates), + project_name, + hashes.digest_count, + match_count, + len(matches_or_no_digest) - match_count, + discard_message, + ) + + return filtered + + +@dataclass +class CandidatePreferences: + """ + Encapsulates some of the preferences for filtering and sorting + InstallationCandidate objects. + """ + + prefer_binary: bool = False + allow_all_prereleases: bool = False + + +@dataclass(frozen=True) +class BestCandidateResult: + """A collection of candidates, returned by `PackageFinder.find_best_candidate`. + + This class is only intended to be instantiated by CandidateEvaluator's + `compute_best_candidate()` method. + + :param all_candidates: A sequence of all available candidates found. + :param applicable_candidates: The applicable candidates. + :param best_candidate: The most preferred candidate found, or None + if no applicable candidates were found. + """ + + all_candidates: List[InstallationCandidate] + applicable_candidates: List[InstallationCandidate] + best_candidate: Optional[InstallationCandidate] + + def __post_init__(self) -> None: + assert set(self.applicable_candidates) <= set(self.all_candidates) + + if self.best_candidate is None: + assert not self.applicable_candidates + else: + assert self.best_candidate in self.applicable_candidates + + +class CandidateEvaluator: + """ + Responsible for filtering and sorting candidates for installation based + on what tags are valid. + """ + + @classmethod + def create( + cls, + project_name: str, + target_python: Optional[TargetPython] = None, + prefer_binary: bool = False, + allow_all_prereleases: bool = False, + specifier: Optional[specifiers.BaseSpecifier] = None, + hashes: Optional[Hashes] = None, + ) -> "CandidateEvaluator": + """Create a CandidateEvaluator object. + + :param target_python: The target Python interpreter to use when + checking compatibility. If None (the default), a TargetPython + object will be constructed from the running Python. + :param specifier: An optional object implementing `filter` + (e.g. `packaging.specifiers.SpecifierSet`) to filter applicable + versions. + :param hashes: An optional collection of allowed hashes. + """ + if target_python is None: + target_python = TargetPython() + if specifier is None: + specifier = specifiers.SpecifierSet() + + supported_tags = target_python.get_sorted_tags() + + return cls( + project_name=project_name, + supported_tags=supported_tags, + specifier=specifier, + prefer_binary=prefer_binary, + allow_all_prereleases=allow_all_prereleases, + hashes=hashes, + ) + + def __init__( + self, + project_name: str, + supported_tags: List[Tag], + specifier: specifiers.BaseSpecifier, + prefer_binary: bool = False, + allow_all_prereleases: bool = False, + hashes: Optional[Hashes] = None, + ) -> None: + """ + :param supported_tags: The PEP 425 tags supported by the target + Python in order of preference (most preferred first). + """ + self._allow_all_prereleases = allow_all_prereleases + self._hashes = hashes + self._prefer_binary = prefer_binary + self._project_name = project_name + self._specifier = specifier + self._supported_tags = supported_tags + # Since the index of the tag in the _supported_tags list is used + # as a priority, precompute a map from tag to index/priority to be + # used in wheel.find_most_preferred_tag. + self._wheel_tag_preferences = { + tag: idx for idx, tag in enumerate(supported_tags) + } + + def get_applicable_candidates( + self, + candidates: List[InstallationCandidate], + ) -> List[InstallationCandidate]: + """ + Return the applicable candidates from a list of candidates. + """ + # Using None infers from the specifier instead. + allow_prereleases = self._allow_all_prereleases or None + specifier = self._specifier + + # We turn the version object into a str here because otherwise + # when we're debundled but setuptools isn't, Python will see + # packaging.version.Version and + # pkg_resources._vendor.packaging.version.Version as different + # types. This way we'll use a str as a common data interchange + # format. If we stop using the pkg_resources provided specifier + # and start using our own, we can drop the cast to str(). + candidates_and_versions = [(c, str(c.version)) for c in candidates] + versions = set( + specifier.filter( + (v for _, v in candidates_and_versions), + prereleases=allow_prereleases, + ) + ) + + applicable_candidates = [c for c, v in candidates_and_versions if v in versions] + filtered_applicable_candidates = filter_unallowed_hashes( + candidates=applicable_candidates, + hashes=self._hashes, + project_name=self._project_name, + ) + + return sorted(filtered_applicable_candidates, key=self._sort_key) + + def _sort_key(self, candidate: InstallationCandidate) -> CandidateSortingKey: + """ + Function to pass as the `key` argument to a call to sorted() to sort + InstallationCandidates by preference. + + Returns a tuple such that tuples sorting as greater using Python's + default comparison operator are more preferred. + + The preference is as follows: + + First and foremost, candidates with allowed (matching) hashes are + always preferred over candidates without matching hashes. This is + because e.g. if the only candidate with an allowed hash is yanked, + we still want to use that candidate. + + Second, excepting hash considerations, candidates that have been + yanked (in the sense of PEP 592) are always less preferred than + candidates that haven't been yanked. Then: + + If not finding wheels, they are sorted by version only. + If finding wheels, then the sort order is by version, then: + 1. existing installs + 2. wheels ordered via Wheel.support_index_min(self._supported_tags) + 3. source archives + If prefer_binary was set, then all wheels are sorted above sources. + + Note: it was considered to embed this logic into the Link + comparison operators, but then different sdist links + with the same version, would have to be considered equal + """ + valid_tags = self._supported_tags + support_num = len(valid_tags) + build_tag: BuildTag = () + binary_preference = 0 + link = candidate.link + if link.is_wheel: + # can raise InvalidWheelFilename + wheel = Wheel(link.filename) + try: + pri = -( + wheel.find_most_preferred_tag( + valid_tags, self._wheel_tag_preferences + ) + ) + except ValueError: + raise UnsupportedWheel( + f"{wheel.filename} is not a supported wheel for this platform. It " + "can't be sorted." + ) + if self._prefer_binary: + binary_preference = 1 + if wheel.build_tag is not None: + match = re.match(r"^(\d+)(.*)$", wheel.build_tag) + assert match is not None, "guaranteed by filename validation" + build_tag_groups = match.groups() + build_tag = (int(build_tag_groups[0]), build_tag_groups[1]) + else: # sdist + pri = -(support_num) + has_allowed_hash = int(link.is_hash_allowed(self._hashes)) + yank_value = -1 * int(link.is_yanked) # -1 for yanked. + return ( + has_allowed_hash, + yank_value, + binary_preference, + candidate.version, + pri, + build_tag, + ) + + def sort_best_candidate( + self, + candidates: List[InstallationCandidate], + ) -> Optional[InstallationCandidate]: + """ + Return the best candidate per the instance's sort order, or None if + no candidate is acceptable. + """ + if not candidates: + return None + best_candidate = max(candidates, key=self._sort_key) + return best_candidate + + def compute_best_candidate( + self, + candidates: List[InstallationCandidate], + ) -> BestCandidateResult: + """ + Compute and return a `BestCandidateResult` instance. + """ + applicable_candidates = self.get_applicable_candidates(candidates) + + best_candidate = self.sort_best_candidate(applicable_candidates) + + return BestCandidateResult( + candidates, + applicable_candidates=applicable_candidates, + best_candidate=best_candidate, + ) + + +class PackageFinder: + """This finds packages. + + This is meant to match easy_install's technique for looking for + packages, by reading pages and looking for appropriate links. + """ + + def __init__( + self, + link_collector: LinkCollector, + target_python: TargetPython, + allow_yanked: bool, + format_control: Optional[FormatControl] = None, + candidate_prefs: Optional[CandidatePreferences] = None, + ignore_requires_python: Optional[bool] = None, + ) -> None: + """ + This constructor is primarily meant to be used by the create() class + method and from tests. + + :param format_control: A FormatControl object, used to control + the selection of source packages / binary packages when consulting + the index and links. + :param candidate_prefs: Options to use when creating a + CandidateEvaluator object. + """ + if candidate_prefs is None: + candidate_prefs = CandidatePreferences() + + format_control = format_control or FormatControl(set(), set()) + + self._allow_yanked = allow_yanked + self._candidate_prefs = candidate_prefs + self._ignore_requires_python = ignore_requires_python + self._link_collector = link_collector + self._target_python = target_python + + self.format_control = format_control + + # These are boring links that have already been logged somehow. + self._logged_links: Set[Tuple[Link, LinkType, str]] = set() + + # Don't include an allow_yanked default value to make sure each call + # site considers whether yanked releases are allowed. This also causes + # that decision to be made explicit in the calling code, which helps + # people when reading the code. + @classmethod + def create( + cls, + link_collector: LinkCollector, + selection_prefs: SelectionPreferences, + target_python: Optional[TargetPython] = None, + ) -> "PackageFinder": + """Create a PackageFinder. + + :param selection_prefs: The candidate selection preferences, as a + SelectionPreferences object. + :param target_python: The target Python interpreter to use when + checking compatibility. If None (the default), a TargetPython + object will be constructed from the running Python. + """ + if target_python is None: + target_python = TargetPython() + + candidate_prefs = CandidatePreferences( + prefer_binary=selection_prefs.prefer_binary, + allow_all_prereleases=selection_prefs.allow_all_prereleases, + ) + + return cls( + candidate_prefs=candidate_prefs, + link_collector=link_collector, + target_python=target_python, + allow_yanked=selection_prefs.allow_yanked, + format_control=selection_prefs.format_control, + ignore_requires_python=selection_prefs.ignore_requires_python, + ) + + @property + def target_python(self) -> TargetPython: + return self._target_python + + @property + def search_scope(self) -> SearchScope: + return self._link_collector.search_scope + + @search_scope.setter + def search_scope(self, search_scope: SearchScope) -> None: + self._link_collector.search_scope = search_scope + + @property + def find_links(self) -> List[str]: + return self._link_collector.find_links + + @property + def index_urls(self) -> List[str]: + return self.search_scope.index_urls + + @property + def proxy(self) -> Optional[str]: + return self._link_collector.session.pip_proxy + + @property + def trusted_hosts(self) -> Iterable[str]: + for host_port in self._link_collector.session.pip_trusted_origins: + yield build_netloc(*host_port) + + @property + def custom_cert(self) -> Optional[str]: + # session.verify is either a boolean (use default bundle/no SSL + # verification) or a string path to a custom CA bundle to use. We only + # care about the latter. + verify = self._link_collector.session.verify + return verify if isinstance(verify, str) else None + + @property + def client_cert(self) -> Optional[str]: + cert = self._link_collector.session.cert + assert not isinstance(cert, tuple), "pip only supports PEM client certs" + return cert + + @property + def allow_all_prereleases(self) -> bool: + return self._candidate_prefs.allow_all_prereleases + + def set_allow_all_prereleases(self) -> None: + self._candidate_prefs.allow_all_prereleases = True + + @property + def prefer_binary(self) -> bool: + return self._candidate_prefs.prefer_binary + + def set_prefer_binary(self) -> None: + self._candidate_prefs.prefer_binary = True + + def requires_python_skipped_reasons(self) -> List[str]: + reasons = { + detail + for _, result, detail in self._logged_links + if result == LinkType.requires_python_mismatch + } + return sorted(reasons) + + def make_link_evaluator(self, project_name: str) -> LinkEvaluator: + canonical_name = canonicalize_name(project_name) + formats = self.format_control.get_allowed_formats(canonical_name) + + return LinkEvaluator( + project_name=project_name, + canonical_name=canonical_name, + formats=formats, + target_python=self._target_python, + allow_yanked=self._allow_yanked, + ignore_requires_python=self._ignore_requires_python, + ) + + def _sort_links(self, links: Iterable[Link]) -> List[Link]: + """ + Returns elements of links in order, non-egg links first, egg links + second, while eliminating duplicates + """ + eggs, no_eggs = [], [] + seen: Set[Link] = set() + for link in links: + if link not in seen: + seen.add(link) + if link.egg_fragment: + eggs.append(link) + else: + no_eggs.append(link) + return no_eggs + eggs + + def _log_skipped_link(self, link: Link, result: LinkType, detail: str) -> None: + # This is a hot method so don't waste time hashing links unless we're + # actually going to log 'em. + if not logger.isEnabledFor(logging.DEBUG): + return + + entry = (link, result, detail) + if entry not in self._logged_links: + # Put the link at the end so the reason is more visible and because + # the link string is usually very long. + logger.debug("Skipping link: %s: %s", detail, link) + self._logged_links.add(entry) + + def get_install_candidate( + self, link_evaluator: LinkEvaluator, link: Link + ) -> Optional[InstallationCandidate]: + """ + If the link is a candidate for install, convert it to an + InstallationCandidate and return it. Otherwise, return None. + """ + result, detail = link_evaluator.evaluate_link(link) + if result != LinkType.candidate: + self._log_skipped_link(link, result, detail) + return None + + try: + return InstallationCandidate( + name=link_evaluator.project_name, + link=link, + version=detail, + ) + except InvalidVersion: + return None + + def evaluate_links( + self, link_evaluator: LinkEvaluator, links: Iterable[Link] + ) -> List[InstallationCandidate]: + """ + Convert links that are candidates to InstallationCandidate objects. + """ + candidates = [] + for link in self._sort_links(links): + candidate = self.get_install_candidate(link_evaluator, link) + if candidate is not None: + candidates.append(candidate) + + return candidates + + def process_project_url( + self, project_url: Link, link_evaluator: LinkEvaluator + ) -> List[InstallationCandidate]: + logger.debug( + "Fetching project page and analyzing links: %s", + project_url, + ) + index_response = self._link_collector.fetch_response(project_url) + if index_response is None: + return [] + + page_links = list(parse_links(index_response)) + + with indent_log(): + package_links = self.evaluate_links( + link_evaluator, + links=page_links, + ) + + return package_links + + @functools.lru_cache(maxsize=None) + def find_all_candidates(self, project_name: str) -> List[InstallationCandidate]: + """Find all available InstallationCandidate for project_name + + This checks index_urls and find_links. + All versions found are returned as an InstallationCandidate list. + + See LinkEvaluator.evaluate_link() for details on which files + are accepted. + """ + link_evaluator = self.make_link_evaluator(project_name) + + collected_sources = self._link_collector.collect_sources( + project_name=project_name, + candidates_from_page=functools.partial( + self.process_project_url, + link_evaluator=link_evaluator, + ), + ) + + page_candidates_it = itertools.chain.from_iterable( + source.page_candidates() + for sources in collected_sources + for source in sources + if source is not None + ) + page_candidates = list(page_candidates_it) + + file_links_it = itertools.chain.from_iterable( + source.file_links() + for sources in collected_sources + for source in sources + if source is not None + ) + file_candidates = self.evaluate_links( + link_evaluator, + sorted(file_links_it, reverse=True), + ) + + if logger.isEnabledFor(logging.DEBUG) and file_candidates: + paths = [] + for candidate in file_candidates: + assert candidate.link.url # we need to have a URL + try: + paths.append(candidate.link.file_path) + except Exception: + paths.append(candidate.link.url) # it's not a local file + + logger.debug("Local files found: %s", ", ".join(paths)) + + # This is an intentional priority ordering + return file_candidates + page_candidates + + def make_candidate_evaluator( + self, + project_name: str, + specifier: Optional[specifiers.BaseSpecifier] = None, + hashes: Optional[Hashes] = None, + ) -> CandidateEvaluator: + """Create a CandidateEvaluator object to use.""" + candidate_prefs = self._candidate_prefs + return CandidateEvaluator.create( + project_name=project_name, + target_python=self._target_python, + prefer_binary=candidate_prefs.prefer_binary, + allow_all_prereleases=candidate_prefs.allow_all_prereleases, + specifier=specifier, + hashes=hashes, + ) + + @functools.lru_cache(maxsize=None) + def find_best_candidate( + self, + project_name: str, + specifier: Optional[specifiers.BaseSpecifier] = None, + hashes: Optional[Hashes] = None, + ) -> BestCandidateResult: + """Find matches for the given project and specifier. + + :param specifier: An optional object implementing `filter` + (e.g. `packaging.specifiers.SpecifierSet`) to filter applicable + versions. + + :return: A `BestCandidateResult` instance. + """ + candidates = self.find_all_candidates(project_name) + candidate_evaluator = self.make_candidate_evaluator( + project_name=project_name, + specifier=specifier, + hashes=hashes, + ) + return candidate_evaluator.compute_best_candidate(candidates) + + def find_requirement( + self, req: InstallRequirement, upgrade: bool + ) -> Optional[InstallationCandidate]: + """Try to find a Link matching req + + Expects req, an InstallRequirement and upgrade, a boolean + Returns a InstallationCandidate if found, + Raises DistributionNotFound or BestVersionAlreadyInstalled otherwise + """ + hashes = req.hashes(trust_internet=False) + best_candidate_result = self.find_best_candidate( + req.name, + specifier=req.specifier, + hashes=hashes, + ) + best_candidate = best_candidate_result.best_candidate + + installed_version: Optional[_BaseVersion] = None + if req.satisfied_by is not None: + installed_version = req.satisfied_by.version + + def _format_versions(cand_iter: Iterable[InstallationCandidate]) -> str: + # This repeated parse_version and str() conversion is needed to + # handle different vendoring sources from pip and pkg_resources. + # If we stop using the pkg_resources provided specifier and start + # using our own, we can drop the cast to str(). + return ( + ", ".join( + sorted( + {str(c.version) for c in cand_iter}, + key=parse_version, + ) + ) + or "none" + ) + + if installed_version is None and best_candidate is None: + logger.critical( + "Could not find a version that satisfies the requirement %s " + "(from versions: %s)", + req, + _format_versions(best_candidate_result.all_candidates), + ) + + raise DistributionNotFound(f"No matching distribution found for {req}") + + def _should_install_candidate( + candidate: Optional[InstallationCandidate], + ) -> "TypeGuard[InstallationCandidate]": + if installed_version is None: + return True + if best_candidate is None: + return False + return best_candidate.version > installed_version + + if not upgrade and installed_version is not None: + if _should_install_candidate(best_candidate): + logger.debug( + "Existing installed version (%s) satisfies requirement " + "(most up-to-date version is %s)", + installed_version, + best_candidate.version, + ) + else: + logger.debug( + "Existing installed version (%s) is most up-to-date and " + "satisfies requirement", + installed_version, + ) + return None + + if _should_install_candidate(best_candidate): + logger.debug( + "Using version %s (newest of versions: %s)", + best_candidate.version, + _format_versions(best_candidate_result.applicable_candidates), + ) + return best_candidate + + # We have an existing version, and its the best version + logger.debug( + "Installed version (%s) is most up-to-date (past versions: %s)", + installed_version, + _format_versions(best_candidate_result.applicable_candidates), + ) + raise BestVersionAlreadyInstalled + + +def _find_name_version_sep(fragment: str, canonical_name: str) -> int: + """Find the separator's index based on the package's canonical name. + + :param fragment: A + filename "fragment" (stem) or + egg fragment. + :param canonical_name: The package's canonical name. + + This function is needed since the canonicalized name does not necessarily + have the same length as the egg info's name part. An example:: + + >>> fragment = 'foo__bar-1.0' + >>> canonical_name = 'foo-bar' + >>> _find_name_version_sep(fragment, canonical_name) + 8 + """ + # Project name and version must be separated by one single dash. Find all + # occurrences of dashes; if the string in front of it matches the canonical + # name, this is the one separating the name and version parts. + for i, c in enumerate(fragment): + if c != "-": + continue + if canonicalize_name(fragment[:i]) == canonical_name: + return i + raise ValueError(f"{fragment} does not match {canonical_name}") + + +def _extract_version_from_fragment(fragment: str, canonical_name: str) -> Optional[str]: + """Parse the version string from a + filename + "fragment" (stem) or egg fragment. + + :param fragment: The string to parse. E.g. foo-2.1 + :param canonical_name: The canonicalized name of the package this + belongs to. + """ + try: + version_start = _find_name_version_sep(fragment, canonical_name) + 1 + except ValueError: + return None + version = fragment[version_start:] + if not version: + return None + return version diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/sources.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/sources.py new file mode 100644 index 0000000000000000000000000000000000000000..3dafb30e6eb843ac56315dd5d0ab223bf4f740b8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/index/sources.py @@ -0,0 +1,284 @@ +import logging +import mimetypes +import os +from collections import defaultdict +from typing import Callable, Dict, Iterable, List, Optional, Tuple + +from pip._vendor.packaging.utils import ( + InvalidSdistFilename, + InvalidWheelFilename, + canonicalize_name, + parse_sdist_filename, + parse_wheel_filename, +) + +from pip._internal.models.candidate import InstallationCandidate +from pip._internal.models.link import Link +from pip._internal.utils.urls import path_to_url, url_to_path +from pip._internal.vcs import is_url + +logger = logging.getLogger(__name__) + +FoundCandidates = Iterable[InstallationCandidate] +FoundLinks = Iterable[Link] +CandidatesFromPage = Callable[[Link], Iterable[InstallationCandidate]] +PageValidator = Callable[[Link], bool] + + +class LinkSource: + @property + def link(self) -> Optional[Link]: + """Returns the underlying link, if there's one.""" + raise NotImplementedError() + + def page_candidates(self) -> FoundCandidates: + """Candidates found by parsing an archive listing HTML file.""" + raise NotImplementedError() + + def file_links(self) -> FoundLinks: + """Links found by specifying archives directly.""" + raise NotImplementedError() + + +def _is_html_file(file_url: str) -> bool: + return mimetypes.guess_type(file_url, strict=False)[0] == "text/html" + + +class _FlatDirectoryToUrls: + """Scans directory and caches results""" + + def __init__(self, path: str) -> None: + self._path = path + self._page_candidates: List[str] = [] + self._project_name_to_urls: Dict[str, List[str]] = defaultdict(list) + self._scanned_directory = False + + def _scan_directory(self) -> None: + """Scans directory once and populates both page_candidates + and project_name_to_urls at the same time + """ + for entry in os.scandir(self._path): + url = path_to_url(entry.path) + if _is_html_file(url): + self._page_candidates.append(url) + continue + + # File must have a valid wheel or sdist name, + # otherwise not worth considering as a package + try: + project_filename = parse_wheel_filename(entry.name)[0] + except InvalidWheelFilename: + try: + project_filename = parse_sdist_filename(entry.name)[0] + except InvalidSdistFilename: + continue + + self._project_name_to_urls[project_filename].append(url) + self._scanned_directory = True + + @property + def page_candidates(self) -> List[str]: + if not self._scanned_directory: + self._scan_directory() + + return self._page_candidates + + @property + def project_name_to_urls(self) -> Dict[str, List[str]]: + if not self._scanned_directory: + self._scan_directory() + + return self._project_name_to_urls + + +class _FlatDirectorySource(LinkSource): + """Link source specified by ``--find-links=``. + + This looks the content of the directory, and returns: + + * ``page_candidates``: Links listed on each HTML file in the directory. + * ``file_candidates``: Archives in the directory. + """ + + _paths_to_urls: Dict[str, _FlatDirectoryToUrls] = {} + + def __init__( + self, + candidates_from_page: CandidatesFromPage, + path: str, + project_name: str, + ) -> None: + self._candidates_from_page = candidates_from_page + self._project_name = canonicalize_name(project_name) + + # Get existing instance of _FlatDirectoryToUrls if it exists + if path in self._paths_to_urls: + self._path_to_urls = self._paths_to_urls[path] + else: + self._path_to_urls = _FlatDirectoryToUrls(path=path) + self._paths_to_urls[path] = self._path_to_urls + + @property + def link(self) -> Optional[Link]: + return None + + def page_candidates(self) -> FoundCandidates: + for url in self._path_to_urls.page_candidates: + yield from self._candidates_from_page(Link(url)) + + def file_links(self) -> FoundLinks: + for url in self._path_to_urls.project_name_to_urls[self._project_name]: + yield Link(url) + + +class _LocalFileSource(LinkSource): + """``--find-links=`` or ``--[extra-]index-url=``. + + If a URL is supplied, it must be a ``file:`` URL. If a path is supplied to + the option, it is converted to a URL first. This returns: + + * ``page_candidates``: Links listed on an HTML file. + * ``file_candidates``: The non-HTML file. + """ + + def __init__( + self, + candidates_from_page: CandidatesFromPage, + link: Link, + ) -> None: + self._candidates_from_page = candidates_from_page + self._link = link + + @property + def link(self) -> Optional[Link]: + return self._link + + def page_candidates(self) -> FoundCandidates: + if not _is_html_file(self._link.url): + return + yield from self._candidates_from_page(self._link) + + def file_links(self) -> FoundLinks: + if _is_html_file(self._link.url): + return + yield self._link + + +class _RemoteFileSource(LinkSource): + """``--find-links=`` or ``--[extra-]index-url=``. + + This returns: + + * ``page_candidates``: Links listed on an HTML file. + * ``file_candidates``: The non-HTML file. + """ + + def __init__( + self, + candidates_from_page: CandidatesFromPage, + page_validator: PageValidator, + link: Link, + ) -> None: + self._candidates_from_page = candidates_from_page + self._page_validator = page_validator + self._link = link + + @property + def link(self) -> Optional[Link]: + return self._link + + def page_candidates(self) -> FoundCandidates: + if not self._page_validator(self._link): + return + yield from self._candidates_from_page(self._link) + + def file_links(self) -> FoundLinks: + yield self._link + + +class _IndexDirectorySource(LinkSource): + """``--[extra-]index-url=``. + + This is treated like a remote URL; ``candidates_from_page`` contains logic + for this by appending ``index.html`` to the link. + """ + + def __init__( + self, + candidates_from_page: CandidatesFromPage, + link: Link, + ) -> None: + self._candidates_from_page = candidates_from_page + self._link = link + + @property + def link(self) -> Optional[Link]: + return self._link + + def page_candidates(self) -> FoundCandidates: + yield from self._candidates_from_page(self._link) + + def file_links(self) -> FoundLinks: + return () + + +def build_source( + location: str, + *, + candidates_from_page: CandidatesFromPage, + page_validator: PageValidator, + expand_dir: bool, + cache_link_parsing: bool, + project_name: str, +) -> Tuple[Optional[str], Optional[LinkSource]]: + path: Optional[str] = None + url: Optional[str] = None + if os.path.exists(location): # Is a local path. + url = path_to_url(location) + path = location + elif location.startswith("file:"): # A file: URL. + url = location + path = url_to_path(location) + elif is_url(location): + url = location + + if url is None: + msg = ( + "Location '%s' is ignored: " + "it is either a non-existing path or lacks a specific scheme." + ) + logger.warning(msg, location) + return (None, None) + + if path is None: + source: LinkSource = _RemoteFileSource( + candidates_from_page=candidates_from_page, + page_validator=page_validator, + link=Link(url, cache_link_parsing=cache_link_parsing), + ) + return (url, source) + + if os.path.isdir(path): + if expand_dir: + source = _FlatDirectorySource( + candidates_from_page=candidates_from_page, + path=path, + project_name=project_name, + ) + else: + source = _IndexDirectorySource( + candidates_from_page=candidates_from_page, + link=Link(url, cache_link_parsing=cache_link_parsing), + ) + return (url, source) + elif os.path.isfile(path): + source = _LocalFileSource( + candidates_from_page=candidates_from_page, + link=Link(url, cache_link_parsing=cache_link_parsing), + ) + return (url, source) + logger.warning( + "Location '%s' is ignored: it is neither a file nor a directory.", + location, + ) + return (url, None) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32382be7fe5f6781047e9774679aa8bcf5bbce8a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/__init__.py @@ -0,0 +1,456 @@ +import functools +import logging +import os +import pathlib +import sys +import sysconfig +from typing import Any, Dict, Generator, Optional, Tuple + +from pip._internal.models.scheme import SCHEME_KEYS, Scheme +from pip._internal.utils.compat import WINDOWS +from pip._internal.utils.deprecation import deprecated +from pip._internal.utils.virtualenv import running_under_virtualenv + +from . import _sysconfig +from .base import ( + USER_CACHE_DIR, + get_major_minor_version, + get_src_prefix, + is_osx_framework, + site_packages, + user_site, +) + +__all__ = [ + "USER_CACHE_DIR", + "get_bin_prefix", + "get_bin_user", + "get_major_minor_version", + "get_platlib", + "get_purelib", + "get_scheme", + "get_src_prefix", + "site_packages", + "user_site", +] + + +logger = logging.getLogger(__name__) + + +_PLATLIBDIR: str = getattr(sys, "platlibdir", "lib") + +_USE_SYSCONFIG_DEFAULT = sys.version_info >= (3, 10) + + +def _should_use_sysconfig() -> bool: + """This function determines the value of _USE_SYSCONFIG. + + By default, pip uses sysconfig on Python 3.10+. + But Python distributors can override this decision by setting: + sysconfig._PIP_USE_SYSCONFIG = True / False + Rationale in https://github.com/pypa/pip/issues/10647 + + This is a function for testability, but should be constant during any one + run. + """ + return bool(getattr(sysconfig, "_PIP_USE_SYSCONFIG", _USE_SYSCONFIG_DEFAULT)) + + +_USE_SYSCONFIG = _should_use_sysconfig() + +if not _USE_SYSCONFIG: + # Import distutils lazily to avoid deprecation warnings, + # but import it soon enough that it is in memory and available during + # a pip reinstall. + from . import _distutils + +# Be noisy about incompatibilities if this platforms "should" be using +# sysconfig, but is explicitly opting out and using distutils instead. +if _USE_SYSCONFIG_DEFAULT and not _USE_SYSCONFIG: + _MISMATCH_LEVEL = logging.WARNING +else: + _MISMATCH_LEVEL = logging.DEBUG + + +def _looks_like_bpo_44860() -> bool: + """The resolution to bpo-44860 will change this incorrect platlib. + + See . + """ + from distutils.command.install import INSTALL_SCHEMES + + try: + unix_user_platlib = INSTALL_SCHEMES["unix_user"]["platlib"] + except KeyError: + return False + return unix_user_platlib == "$usersite" + + +def _looks_like_red_hat_patched_platlib_purelib(scheme: Dict[str, str]) -> bool: + platlib = scheme["platlib"] + if "/$platlibdir/" in platlib: + platlib = platlib.replace("/$platlibdir/", f"/{_PLATLIBDIR}/") + if "/lib64/" not in platlib: + return False + unpatched = platlib.replace("/lib64/", "/lib/") + return unpatched.replace("$platbase/", "$base/") == scheme["purelib"] + + +@functools.lru_cache(maxsize=None) +def _looks_like_red_hat_lib() -> bool: + """Red Hat patches platlib in unix_prefix and unix_home, but not purelib. + + This is the only way I can see to tell a Red Hat-patched Python. + """ + from distutils.command.install import INSTALL_SCHEMES + + return all( + k in INSTALL_SCHEMES + and _looks_like_red_hat_patched_platlib_purelib(INSTALL_SCHEMES[k]) + for k in ("unix_prefix", "unix_home") + ) + + +@functools.lru_cache(maxsize=None) +def _looks_like_debian_scheme() -> bool: + """Debian adds two additional schemes.""" + from distutils.command.install import INSTALL_SCHEMES + + return "deb_system" in INSTALL_SCHEMES and "unix_local" in INSTALL_SCHEMES + + +@functools.lru_cache(maxsize=None) +def _looks_like_red_hat_scheme() -> bool: + """Red Hat patches ``sys.prefix`` and ``sys.exec_prefix``. + + Red Hat's ``00251-change-user-install-location.patch`` changes the install + command's ``prefix`` and ``exec_prefix`` to append ``"/local"``. This is + (fortunately?) done quite unconditionally, so we create a default command + object without any configuration to detect this. + """ + from distutils.command.install import install + from distutils.dist import Distribution + + cmd: Any = install(Distribution()) + cmd.finalize_options() + return ( + cmd.exec_prefix == f"{os.path.normpath(sys.exec_prefix)}/local" + and cmd.prefix == f"{os.path.normpath(sys.prefix)}/local" + ) + + +@functools.lru_cache(maxsize=None) +def _looks_like_slackware_scheme() -> bool: + """Slackware patches sysconfig but fails to patch distutils and site. + + Slackware changes sysconfig's user scheme to use ``"lib64"`` for the lib + path, but does not do the same to the site module. + """ + if user_site is None: # User-site not available. + return False + try: + paths = sysconfig.get_paths(scheme="posix_user", expand=False) + except KeyError: # User-site not available. + return False + return "/lib64/" in paths["purelib"] and "/lib64/" not in user_site + + +@functools.lru_cache(maxsize=None) +def _looks_like_msys2_mingw_scheme() -> bool: + """MSYS2 patches distutils and sysconfig to use a UNIX-like scheme. + + However, MSYS2 incorrectly patches sysconfig ``nt`` scheme. The fix is + likely going to be included in their 3.10 release, so we ignore the warning. + See msys2/MINGW-packages#9319. + + MSYS2 MINGW's patch uses lowercase ``"lib"`` instead of the usual uppercase, + and is missing the final ``"site-packages"``. + """ + paths = sysconfig.get_paths("nt", expand=False) + return all( + "Lib" not in p and "lib" in p and not p.endswith("site-packages") + for p in (paths[key] for key in ("platlib", "purelib")) + ) + + +def _fix_abiflags(parts: Tuple[str]) -> Generator[str, None, None]: + ldversion = sysconfig.get_config_var("LDVERSION") + abiflags = getattr(sys, "abiflags", None) + + # LDVERSION does not end with sys.abiflags. Just return the path unchanged. + if not ldversion or not abiflags or not ldversion.endswith(abiflags): + yield from parts + return + + # Strip sys.abiflags from LDVERSION-based path components. + for part in parts: + if part.endswith(ldversion): + part = part[: (0 - len(abiflags))] + yield part + + +@functools.lru_cache(maxsize=None) +def _warn_mismatched(old: pathlib.Path, new: pathlib.Path, *, key: str) -> None: + issue_url = "https://github.com/pypa/pip/issues/10151" + message = ( + "Value for %s does not match. Please report this to <%s>" + "\ndistutils: %s" + "\nsysconfig: %s" + ) + logger.log(_MISMATCH_LEVEL, message, key, issue_url, old, new) + + +def _warn_if_mismatch(old: pathlib.Path, new: pathlib.Path, *, key: str) -> bool: + if old == new: + return False + _warn_mismatched(old, new, key=key) + return True + + +@functools.lru_cache(maxsize=None) +def _log_context( + *, + user: bool = False, + home: Optional[str] = None, + root: Optional[str] = None, + prefix: Optional[str] = None, +) -> None: + parts = [ + "Additional context:", + "user = %r", + "home = %r", + "root = %r", + "prefix = %r", + ] + + logger.log(_MISMATCH_LEVEL, "\n".join(parts), user, home, root, prefix) + + +def get_scheme( + dist_name: str, + user: bool = False, + home: Optional[str] = None, + root: Optional[str] = None, + isolated: bool = False, + prefix: Optional[str] = None, +) -> Scheme: + new = _sysconfig.get_scheme( + dist_name, + user=user, + home=home, + root=root, + isolated=isolated, + prefix=prefix, + ) + if _USE_SYSCONFIG: + return new + + old = _distutils.get_scheme( + dist_name, + user=user, + home=home, + root=root, + isolated=isolated, + prefix=prefix, + ) + + warning_contexts = [] + for k in SCHEME_KEYS: + old_v = pathlib.Path(getattr(old, k)) + new_v = pathlib.Path(getattr(new, k)) + + if old_v == new_v: + continue + + # distutils incorrectly put PyPy packages under ``site-packages/python`` + # in the ``posix_home`` scheme, but PyPy devs said they expect the + # directory name to be ``pypy`` instead. So we treat this as a bug fix + # and not warn about it. See bpo-43307 and python/cpython#24628. + skip_pypy_special_case = ( + sys.implementation.name == "pypy" + and home is not None + and k in ("platlib", "purelib") + and old_v.parent == new_v.parent + and old_v.name.startswith("python") + and new_v.name.startswith("pypy") + ) + if skip_pypy_special_case: + continue + + # sysconfig's ``osx_framework_user`` does not include ``pythonX.Y`` in + # the ``include`` value, but distutils's ``headers`` does. We'll let + # CPython decide whether this is a bug or feature. See bpo-43948. + skip_osx_framework_user_special_case = ( + user + and is_osx_framework() + and k == "headers" + and old_v.parent.parent == new_v.parent + and old_v.parent.name.startswith("python") + ) + if skip_osx_framework_user_special_case: + continue + + # On Red Hat and derived Linux distributions, distutils is patched to + # use "lib64" instead of "lib" for platlib. + if k == "platlib" and _looks_like_red_hat_lib(): + continue + + # On Python 3.9+, sysconfig's posix_user scheme sets platlib against + # sys.platlibdir, but distutils's unix_user incorrectly coninutes + # using the same $usersite for both platlib and purelib. This creates a + # mismatch when sys.platlibdir is not "lib". + skip_bpo_44860 = ( + user + and k == "platlib" + and not WINDOWS + and sys.version_info >= (3, 9) + and _PLATLIBDIR != "lib" + and _looks_like_bpo_44860() + ) + if skip_bpo_44860: + continue + + # Slackware incorrectly patches posix_user to use lib64 instead of lib, + # but not usersite to match the location. + skip_slackware_user_scheme = ( + user + and k in ("platlib", "purelib") + and not WINDOWS + and _looks_like_slackware_scheme() + ) + if skip_slackware_user_scheme: + continue + + # Both Debian and Red Hat patch Python to place the system site under + # /usr/local instead of /usr. Debian also places lib in dist-packages + # instead of site-packages, but the /usr/local check should cover it. + skip_linux_system_special_case = ( + not (user or home or prefix or running_under_virtualenv()) + and old_v.parts[1:3] == ("usr", "local") + and len(new_v.parts) > 1 + and new_v.parts[1] == "usr" + and (len(new_v.parts) < 3 or new_v.parts[2] != "local") + and (_looks_like_red_hat_scheme() or _looks_like_debian_scheme()) + ) + if skip_linux_system_special_case: + continue + + # MSYS2 MINGW's sysconfig patch does not include the "site-packages" + # part of the path. This is incorrect and will be fixed in MSYS. + skip_msys2_mingw_bug = ( + WINDOWS and k in ("platlib", "purelib") and _looks_like_msys2_mingw_scheme() + ) + if skip_msys2_mingw_bug: + continue + + # CPython's POSIX install script invokes pip (via ensurepip) against the + # interpreter located in the source tree, not the install site. This + # triggers special logic in sysconfig that's not present in distutils. + # https://github.com/python/cpython/blob/8c21941ddaf/Lib/sysconfig.py#L178-L194 + skip_cpython_build = ( + sysconfig.is_python_build(check_home=True) + and not WINDOWS + and k in ("headers", "include", "platinclude") + ) + if skip_cpython_build: + continue + + warning_contexts.append((old_v, new_v, f"scheme.{k}")) + + if not warning_contexts: + return old + + # Check if this path mismatch is caused by distutils config files. Those + # files will no longer work once we switch to sysconfig, so this raises a + # deprecation message for them. + default_old = _distutils.distutils_scheme( + dist_name, + user, + home, + root, + isolated, + prefix, + ignore_config_files=True, + ) + if any(default_old[k] != getattr(old, k) for k in SCHEME_KEYS): + deprecated( + reason=( + "Configuring installation scheme with distutils config files " + "is deprecated and will no longer work in the near future. If you " + "are using a Homebrew or Linuxbrew Python, please see discussion " + "at https://github.com/Homebrew/homebrew-core/issues/76621" + ), + replacement=None, + gone_in=None, + ) + return old + + # Post warnings about this mismatch so user can report them back. + for old_v, new_v, key in warning_contexts: + _warn_mismatched(old_v, new_v, key=key) + _log_context(user=user, home=home, root=root, prefix=prefix) + + return old + + +def get_bin_prefix() -> str: + new = _sysconfig.get_bin_prefix() + if _USE_SYSCONFIG: + return new + + old = _distutils.get_bin_prefix() + if _warn_if_mismatch(pathlib.Path(old), pathlib.Path(new), key="bin_prefix"): + _log_context() + return old + + +def get_bin_user() -> str: + return _sysconfig.get_scheme("", user=True).scripts + + +def _looks_like_deb_system_dist_packages(value: str) -> bool: + """Check if the value is Debian's APT-controlled dist-packages. + + Debian's ``distutils.sysconfig.get_python_lib()`` implementation returns the + default package path controlled by APT, but does not patch ``sysconfig`` to + do the same. This is similar to the bug worked around in ``get_scheme()``, + but here the default is ``deb_system`` instead of ``unix_local``. Ultimately + we can't do anything about this Debian bug, and this detection allows us to + skip the warning when needed. + """ + if not _looks_like_debian_scheme(): + return False + if value == "/usr/lib/python3/dist-packages": + return True + return False + + +def get_purelib() -> str: + """Return the default pure-Python lib location.""" + new = _sysconfig.get_purelib() + if _USE_SYSCONFIG: + return new + + old = _distutils.get_purelib() + if _looks_like_deb_system_dist_packages(old): + return old + if _warn_if_mismatch(pathlib.Path(old), pathlib.Path(new), key="purelib"): + _log_context() + return old + + +def get_platlib() -> str: + """Return the default platform-shared lib location.""" + new = _sysconfig.get_platlib() + if _USE_SYSCONFIG: + return new + + from . import _distutils + + old = _distutils.get_platlib() + if _looks_like_deb_system_dist_packages(old): + return old + if _warn_if_mismatch(pathlib.Path(old), pathlib.Path(new), key="platlib"): + _log_context() + return old diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/_distutils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/_distutils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d856256986f68b1bc38d012cfc96f8075268493 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/_distutils.py @@ -0,0 +1,172 @@ +"""Locations where we look for configs, install stuff, etc""" + +# The following comment should be removed at some point in the future. +# mypy: strict-optional=False + +# If pip's going to use distutils, it should not be using the copy that setuptools +# might have injected into the environment. This is done by removing the injected +# shim, if it's injected. +# +# See https://github.com/pypa/pip/issues/8761 for the original discussion and +# rationale for why this is done within pip. +try: + __import__("_distutils_hack").remove_shim() +except (ImportError, AttributeError): + pass + +import logging +import os +import sys +from distutils.cmd import Command as DistutilsCommand +from distutils.command.install import SCHEME_KEYS +from distutils.command.install import install as distutils_install_command +from distutils.sysconfig import get_python_lib +from typing import Dict, List, Optional, Union + +from pip._internal.models.scheme import Scheme +from pip._internal.utils.compat import WINDOWS +from pip._internal.utils.virtualenv import running_under_virtualenv + +from .base import get_major_minor_version + +logger = logging.getLogger(__name__) + + +def distutils_scheme( + dist_name: str, + user: bool = False, + home: Optional[str] = None, + root: Optional[str] = None, + isolated: bool = False, + prefix: Optional[str] = None, + *, + ignore_config_files: bool = False, +) -> Dict[str, str]: + """ + Return a distutils install scheme + """ + from distutils.dist import Distribution + + dist_args: Dict[str, Union[str, List[str]]] = {"name": dist_name} + if isolated: + dist_args["script_args"] = ["--no-user-cfg"] + + d = Distribution(dist_args) + if not ignore_config_files: + try: + d.parse_config_files() + except UnicodeDecodeError: + paths = d.find_config_files() + logger.warning( + "Ignore distutils configs in %s due to encoding errors.", + ", ".join(os.path.basename(p) for p in paths), + ) + obj: Optional[DistutilsCommand] = None + obj = d.get_command_obj("install", create=True) + assert obj is not None + i: distutils_install_command = obj + # NOTE: setting user or home has the side-effect of creating the home dir + # or user base for installations during finalize_options() + # ideally, we'd prefer a scheme class that has no side-effects. + assert not (user and prefix), f"user={user} prefix={prefix}" + assert not (home and prefix), f"home={home} prefix={prefix}" + i.user = user or i.user + if user or home: + i.prefix = "" + i.prefix = prefix or i.prefix + i.home = home or i.home + i.root = root or i.root + i.finalize_options() + + scheme: Dict[str, str] = {} + for key in SCHEME_KEYS: + scheme[key] = getattr(i, "install_" + key) + + # install_lib specified in setup.cfg should install *everything* + # into there (i.e. it takes precedence over both purelib and + # platlib). Note, i.install_lib is *always* set after + # finalize_options(); we only want to override here if the user + # has explicitly requested it hence going back to the config + if "install_lib" in d.get_option_dict("install"): + scheme.update({"purelib": i.install_lib, "platlib": i.install_lib}) + + if running_under_virtualenv(): + if home: + prefix = home + elif user: + prefix = i.install_userbase + else: + prefix = i.prefix + scheme["headers"] = os.path.join( + prefix, + "include", + "site", + f"python{get_major_minor_version()}", + dist_name, + ) + + if root is not None: + path_no_drive = os.path.splitdrive(os.path.abspath(scheme["headers"]))[1] + scheme["headers"] = os.path.join(root, path_no_drive[1:]) + + return scheme + + +def get_scheme( + dist_name: str, + user: bool = False, + home: Optional[str] = None, + root: Optional[str] = None, + isolated: bool = False, + prefix: Optional[str] = None, +) -> Scheme: + """ + Get the "scheme" corresponding to the input parameters. The distutils + documentation provides the context for the available schemes: + https://docs.python.org/3/install/index.html#alternate-installation + + :param dist_name: the name of the package to retrieve the scheme for, used + in the headers scheme path + :param user: indicates to use the "user" scheme + :param home: indicates to use the "home" scheme and provides the base + directory for the same + :param root: root under which other directories are re-based + :param isolated: equivalent to --no-user-cfg, i.e. do not consider + ~/.pydistutils.cfg (posix) or ~/pydistutils.cfg (non-posix) for + scheme paths + :param prefix: indicates to use the "prefix" scheme and provides the + base directory for the same + """ + scheme = distutils_scheme(dist_name, user, home, root, isolated, prefix) + return Scheme( + platlib=scheme["platlib"], + purelib=scheme["purelib"], + headers=scheme["headers"], + scripts=scheme["scripts"], + data=scheme["data"], + ) + + +def get_bin_prefix() -> str: + # XXX: In old virtualenv versions, sys.prefix can contain '..' components, + # so we need to call normpath to eliminate them. + prefix = os.path.normpath(sys.prefix) + if WINDOWS: + bin_py = os.path.join(prefix, "Scripts") + # buildout uses 'bin' on Windows too? + if not os.path.exists(bin_py): + bin_py = os.path.join(prefix, "bin") + return bin_py + # Forcing to use /usr/local/bin for standard macOS framework installs + # Also log to ~/Library/Logs/ for use with the Console.app log viewer + if sys.platform[:6] == "darwin" and prefix[:16] == "/System/Library/": + return "/usr/local/bin" + return os.path.join(prefix, "bin") + + +def get_purelib() -> str: + return get_python_lib(plat_specific=False) + + +def get_platlib() -> str: + return get_python_lib(plat_specific=True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/_sysconfig.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/_sysconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..ca860ea562c2c0c30982ba6cff654355e9f21c8a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/_sysconfig.py @@ -0,0 +1,214 @@ +import logging +import os +import sys +import sysconfig +import typing + +from pip._internal.exceptions import InvalidSchemeCombination, UserInstallationInvalid +from pip._internal.models.scheme import SCHEME_KEYS, Scheme +from pip._internal.utils.virtualenv import running_under_virtualenv + +from .base import change_root, get_major_minor_version, is_osx_framework + +logger = logging.getLogger(__name__) + + +# Notes on _infer_* functions. +# Unfortunately ``get_default_scheme()`` didn't exist before 3.10, so there's no +# way to ask things like "what is the '_prefix' scheme on this platform". These +# functions try to answer that with some heuristics while accounting for ad-hoc +# platforms not covered by CPython's default sysconfig implementation. If the +# ad-hoc implementation does not fully implement sysconfig, we'll fall back to +# a POSIX scheme. + +_AVAILABLE_SCHEMES = set(sysconfig.get_scheme_names()) + +_PREFERRED_SCHEME_API = getattr(sysconfig, "get_preferred_scheme", None) + + +def _should_use_osx_framework_prefix() -> bool: + """Check for Apple's ``osx_framework_library`` scheme. + + Python distributed by Apple's Command Line Tools has this special scheme + that's used when: + + * This is a framework build. + * We are installing into the system prefix. + + This does not account for ``pip install --prefix`` (also means we're not + installing to the system prefix), which should use ``posix_prefix``, but + logic here means ``_infer_prefix()`` outputs ``osx_framework_library``. But + since ``prefix`` is not available for ``sysconfig.get_default_scheme()``, + which is the stdlib replacement for ``_infer_prefix()``, presumably Apple + wouldn't be able to magically switch between ``osx_framework_library`` and + ``posix_prefix``. ``_infer_prefix()`` returning ``osx_framework_library`` + means its behavior is consistent whether we use the stdlib implementation + or our own, and we deal with this special case in ``get_scheme()`` instead. + """ + return ( + "osx_framework_library" in _AVAILABLE_SCHEMES + and not running_under_virtualenv() + and is_osx_framework() + ) + + +def _infer_prefix() -> str: + """Try to find a prefix scheme for the current platform. + + This tries: + + * A special ``osx_framework_library`` for Python distributed by Apple's + Command Line Tools, when not running in a virtual environment. + * Implementation + OS, used by PyPy on Windows (``pypy_nt``). + * Implementation without OS, used by PyPy on POSIX (``pypy``). + * OS + "prefix", used by CPython on POSIX (``posix_prefix``). + * Just the OS name, used by CPython on Windows (``nt``). + + If none of the above works, fall back to ``posix_prefix``. + """ + if _PREFERRED_SCHEME_API: + return _PREFERRED_SCHEME_API("prefix") + if _should_use_osx_framework_prefix(): + return "osx_framework_library" + implementation_suffixed = f"{sys.implementation.name}_{os.name}" + if implementation_suffixed in _AVAILABLE_SCHEMES: + return implementation_suffixed + if sys.implementation.name in _AVAILABLE_SCHEMES: + return sys.implementation.name + suffixed = f"{os.name}_prefix" + if suffixed in _AVAILABLE_SCHEMES: + return suffixed + if os.name in _AVAILABLE_SCHEMES: # On Windows, prefx is just called "nt". + return os.name + return "posix_prefix" + + +def _infer_user() -> str: + """Try to find a user scheme for the current platform.""" + if _PREFERRED_SCHEME_API: + return _PREFERRED_SCHEME_API("user") + if is_osx_framework() and not running_under_virtualenv(): + suffixed = "osx_framework_user" + else: + suffixed = f"{os.name}_user" + if suffixed in _AVAILABLE_SCHEMES: + return suffixed + if "posix_user" not in _AVAILABLE_SCHEMES: # User scheme unavailable. + raise UserInstallationInvalid() + return "posix_user" + + +def _infer_home() -> str: + """Try to find a home for the current platform.""" + if _PREFERRED_SCHEME_API: + return _PREFERRED_SCHEME_API("home") + suffixed = f"{os.name}_home" + if suffixed in _AVAILABLE_SCHEMES: + return suffixed + return "posix_home" + + +# Update these keys if the user sets a custom home. +_HOME_KEYS = [ + "installed_base", + "base", + "installed_platbase", + "platbase", + "prefix", + "exec_prefix", +] +if sysconfig.get_config_var("userbase") is not None: + _HOME_KEYS.append("userbase") + + +def get_scheme( + dist_name: str, + user: bool = False, + home: typing.Optional[str] = None, + root: typing.Optional[str] = None, + isolated: bool = False, + prefix: typing.Optional[str] = None, +) -> Scheme: + """ + Get the "scheme" corresponding to the input parameters. + + :param dist_name: the name of the package to retrieve the scheme for, used + in the headers scheme path + :param user: indicates to use the "user" scheme + :param home: indicates to use the "home" scheme + :param root: root under which other directories are re-based + :param isolated: ignored, but kept for distutils compatibility (where + this controls whether the user-site pydistutils.cfg is honored) + :param prefix: indicates to use the "prefix" scheme and provides the + base directory for the same + """ + if user and prefix: + raise InvalidSchemeCombination("--user", "--prefix") + if home and prefix: + raise InvalidSchemeCombination("--home", "--prefix") + + if home is not None: + scheme_name = _infer_home() + elif user: + scheme_name = _infer_user() + else: + scheme_name = _infer_prefix() + + # Special case: When installing into a custom prefix, use posix_prefix + # instead of osx_framework_library. See _should_use_osx_framework_prefix() + # docstring for details. + if prefix is not None and scheme_name == "osx_framework_library": + scheme_name = "posix_prefix" + + if home is not None: + variables = {k: home for k in _HOME_KEYS} + elif prefix is not None: + variables = {k: prefix for k in _HOME_KEYS} + else: + variables = {} + + paths = sysconfig.get_paths(scheme=scheme_name, vars=variables) + + # Logic here is very arbitrary, we're doing it for compatibility, don't ask. + # 1. Pip historically uses a special header path in virtual environments. + # 2. If the distribution name is not known, distutils uses 'UNKNOWN'. We + # only do the same when not running in a virtual environment because + # pip's historical header path logic (see point 1) did not do this. + if running_under_virtualenv(): + if user: + base = variables.get("userbase", sys.prefix) + else: + base = variables.get("base", sys.prefix) + python_xy = f"python{get_major_minor_version()}" + paths["include"] = os.path.join(base, "include", "site", python_xy) + elif not dist_name: + dist_name = "UNKNOWN" + + scheme = Scheme( + platlib=paths["platlib"], + purelib=paths["purelib"], + headers=os.path.join(paths["include"], dist_name), + scripts=paths["scripts"], + data=paths["data"], + ) + if root is not None: + converted_keys = {} + for key in SCHEME_KEYS: + converted_keys[key] = change_root(root, getattr(scheme, key)) + scheme = Scheme(**converted_keys) + return scheme + + +def get_bin_prefix() -> str: + # Forcing to use /usr/local/bin for standard macOS framework installs. + if sys.platform[:6] == "darwin" and sys.prefix[:16] == "/System/Library/": + return "/usr/local/bin" + return sysconfig.get_paths()["scripts"] + + +def get_purelib() -> str: + return sysconfig.get_paths()["purelib"] + + +def get_platlib() -> str: + return sysconfig.get_paths()["platlib"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9f896e632e929a63e9724ab80ecdfc9761b795 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/locations/base.py @@ -0,0 +1,81 @@ +import functools +import os +import site +import sys +import sysconfig +import typing + +from pip._internal.exceptions import InstallationError +from pip._internal.utils import appdirs +from pip._internal.utils.virtualenv import running_under_virtualenv + +# Application Directories +USER_CACHE_DIR = appdirs.user_cache_dir("pip") + +# FIXME doesn't account for venv linked to global site-packages +site_packages: str = sysconfig.get_path("purelib") + + +def get_major_minor_version() -> str: + """ + Return the major-minor version of the current Python as a string, e.g. + "3.7" or "3.10". + """ + return "{}.{}".format(*sys.version_info) + + +def change_root(new_root: str, pathname: str) -> str: + """Return 'pathname' with 'new_root' prepended. + + If 'pathname' is relative, this is equivalent to os.path.join(new_root, pathname). + Otherwise, it requires making 'pathname' relative and then joining the + two, which is tricky on DOS/Windows and Mac OS. + + This is borrowed from Python's standard library's distutils module. + """ + if os.name == "posix": + if not os.path.isabs(pathname): + return os.path.join(new_root, pathname) + else: + return os.path.join(new_root, pathname[1:]) + + elif os.name == "nt": + (drive, path) = os.path.splitdrive(pathname) + if path[0] == "\\": + path = path[1:] + return os.path.join(new_root, path) + + else: + raise InstallationError( + f"Unknown platform: {os.name}\n" + "Can not change root path prefix on unknown platform." + ) + + +def get_src_prefix() -> str: + if running_under_virtualenv(): + src_prefix = os.path.join(sys.prefix, "src") + else: + # FIXME: keep src in cwd for now (it is not a temporary folder) + try: + src_prefix = os.path.join(os.getcwd(), "src") + except OSError: + # In case the current working directory has been renamed or deleted + sys.exit("The folder you are executing pip from can no longer be found.") + + # under macOS + virtualenv sys.prefix is not properly resolved + # it is something like /path/to/python/bin/.. + return os.path.abspath(src_prefix) + + +try: + # Use getusersitepackages if this is present, as it ensures that the + # value is initialised properly. + user_site: typing.Optional[str] = site.getusersitepackages() +except AttributeError: + user_site = site.USER_SITE + + +@functools.lru_cache(maxsize=None) +def is_osx_framework() -> bool: + return bool(sysconfig.get_config_var("PYTHONFRAMEWORK")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/main.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/main.py new file mode 100644 index 0000000000000000000000000000000000000000..33c6d24cd85b55a9fb1b1e6ab784f471e2b135f0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/main.py @@ -0,0 +1,12 @@ +from typing import List, Optional + + +def main(args: Optional[List[str]] = None) -> int: + """This is preserved for old console scripts that may still be referencing + it. + + For additional details, see https://github.com/pypa/pip/issues/7498. + """ + from pip._internal.utils.entrypoints import _wrapper + + return _wrapper(args) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea1e7fd2e5c4908bec25d08b4ad32be05834985 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/__init__.py @@ -0,0 +1,128 @@ +import contextlib +import functools +import os +import sys +from typing import TYPE_CHECKING, List, Optional, Type, cast + +from pip._internal.utils.misc import strtobool + +from .base import BaseDistribution, BaseEnvironment, FilesystemWheel, MemoryWheel, Wheel + +if TYPE_CHECKING: + from typing import Literal, Protocol +else: + Protocol = object + +__all__ = [ + "BaseDistribution", + "BaseEnvironment", + "FilesystemWheel", + "MemoryWheel", + "Wheel", + "get_default_environment", + "get_environment", + "get_wheel_distribution", + "select_backend", +] + + +def _should_use_importlib_metadata() -> bool: + """Whether to use the ``importlib.metadata`` or ``pkg_resources`` backend. + + By default, pip uses ``importlib.metadata`` on Python 3.11+, and + ``pkg_resources`` otherwise. This can be overridden by a couple of ways: + + * If environment variable ``_PIP_USE_IMPORTLIB_METADATA`` is set, it + dictates whether ``importlib.metadata`` is used, regardless of Python + version. + * On Python 3.11+, Python distributors can patch ``importlib.metadata`` + to add a global constant ``_PIP_USE_IMPORTLIB_METADATA = False``. This + makes pip use ``pkg_resources`` (unless the user set the aforementioned + environment variable to *True*). + """ + with contextlib.suppress(KeyError, ValueError): + return bool(strtobool(os.environ["_PIP_USE_IMPORTLIB_METADATA"])) + if sys.version_info < (3, 11): + return False + import importlib.metadata + + return bool(getattr(importlib.metadata, "_PIP_USE_IMPORTLIB_METADATA", True)) + + +class Backend(Protocol): + NAME: 'Literal["importlib", "pkg_resources"]' + Distribution: Type[BaseDistribution] + Environment: Type[BaseEnvironment] + + +@functools.lru_cache(maxsize=None) +def select_backend() -> Backend: + if _should_use_importlib_metadata(): + from . import importlib + + return cast(Backend, importlib) + from . import pkg_resources + + return cast(Backend, pkg_resources) + + +def get_default_environment() -> BaseEnvironment: + """Get the default representation for the current environment. + + This returns an Environment instance from the chosen backend. The default + Environment instance should be built from ``sys.path`` and may use caching + to share instance state across calls. + """ + return select_backend().Environment.default() + + +def get_environment(paths: Optional[List[str]]) -> BaseEnvironment: + """Get a representation of the environment specified by ``paths``. + + This returns an Environment instance from the chosen backend based on the + given import paths. The backend must build a fresh instance representing + the state of installed distributions when this function is called. + """ + return select_backend().Environment.from_paths(paths) + + +def get_directory_distribution(directory: str) -> BaseDistribution: + """Get the distribution metadata representation in the specified directory. + + This returns a Distribution instance from the chosen backend based on + the given on-disk ``.dist-info`` directory. + """ + return select_backend().Distribution.from_directory(directory) + + +def get_wheel_distribution(wheel: Wheel, canonical_name: str) -> BaseDistribution: + """Get the representation of the specified wheel's distribution metadata. + + This returns a Distribution instance from the chosen backend based on + the given wheel's ``.dist-info`` directory. + + :param canonical_name: Normalized project name of the given wheel. + """ + return select_backend().Distribution.from_wheel(wheel, canonical_name) + + +def get_metadata_distribution( + metadata_contents: bytes, + filename: str, + canonical_name: str, +) -> BaseDistribution: + """Get the dist representation of the specified METADATA file contents. + + This returns a Distribution instance from the chosen backend sourced from the data + in `metadata_contents`. + + :param metadata_contents: Contents of a METADATA file within a dist, or one served + via PEP 658. + :param filename: Filename for the dist this metadata represents. + :param canonical_name: Normalized project name of the given dist. + """ + return select_backend().Distribution.from_metadata_file_contents( + metadata_contents, + filename, + canonical_name, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/_json.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/_json.py new file mode 100644 index 0000000000000000000000000000000000000000..f3aeab3225ffc8f6a63f5f40d351e04be4e3c647 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/_json.py @@ -0,0 +1,86 @@ +# Extracted from https://github.com/pfmoore/pkg_metadata + +from email.header import Header, decode_header, make_header +from email.message import Message +from typing import Any, Dict, List, Union, cast + +METADATA_FIELDS = [ + # Name, Multiple-Use + ("Metadata-Version", False), + ("Name", False), + ("Version", False), + ("Dynamic", True), + ("Platform", True), + ("Supported-Platform", True), + ("Summary", False), + ("Description", False), + ("Description-Content-Type", False), + ("Keywords", False), + ("Home-page", False), + ("Download-URL", False), + ("Author", False), + ("Author-email", False), + ("Maintainer", False), + ("Maintainer-email", False), + ("License", False), + ("License-Expression", False), + ("License-File", True), + ("Classifier", True), + ("Requires-Dist", True), + ("Requires-Python", False), + ("Requires-External", True), + ("Project-URL", True), + ("Provides-Extra", True), + ("Provides-Dist", True), + ("Obsoletes-Dist", True), +] + + +def json_name(field: str) -> str: + return field.lower().replace("-", "_") + + +def msg_to_json(msg: Message) -> Dict[str, Any]: + """Convert a Message object into a JSON-compatible dictionary.""" + + def sanitise_header(h: Union[Header, str]) -> str: + if isinstance(h, Header): + chunks = [] + for bytes, encoding in decode_header(h): + if encoding == "unknown-8bit": + try: + # See if UTF-8 works + bytes.decode("utf-8") + encoding = "utf-8" + except UnicodeDecodeError: + # If not, latin1 at least won't fail + encoding = "latin1" + chunks.append((bytes, encoding)) + return str(make_header(chunks)) + return str(h) + + result = {} + for field, multi in METADATA_FIELDS: + if field not in msg: + continue + key = json_name(field) + if multi: + value: Union[str, List[str]] = [ + sanitise_header(v) for v in msg.get_all(field) # type: ignore + ] + else: + value = sanitise_header(msg.get(field)) # type: ignore + if key == "keywords": + # Accept both comma-separated and space-separated + # forms, for better compatibility with old data. + if "," in value: + value = [v.strip() for v in value.split(",")] + else: + value = value.split() + result[key] = value + + payload = cast(str, msg.get_payload()) + if payload: + result["description"] = payload + + return result diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9eabcdb278bd53959f489a8a0fb5ee13a8512f5b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/base.py @@ -0,0 +1,688 @@ +import csv +import email.message +import functools +import json +import logging +import pathlib +import re +import zipfile +from typing import ( + IO, + Any, + Collection, + Container, + Dict, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Protocol, + Tuple, + Union, +) + +from pip._vendor.packaging.requirements import Requirement +from pip._vendor.packaging.specifiers import InvalidSpecifier, SpecifierSet +from pip._vendor.packaging.utils import NormalizedName, canonicalize_name +from pip._vendor.packaging.version import Version + +from pip._internal.exceptions import NoneMetadataError +from pip._internal.locations import site_packages, user_site +from pip._internal.models.direct_url import ( + DIRECT_URL_METADATA_NAME, + DirectUrl, + DirectUrlValidationError, +) +from pip._internal.utils.compat import stdlib_pkgs # TODO: Move definition here. +from pip._internal.utils.egg_link import egg_link_path_from_sys_path +from pip._internal.utils.misc import is_local, normalize_path +from pip._internal.utils.urls import url_to_path + +from ._json import msg_to_json + +InfoPath = Union[str, pathlib.PurePath] + +logger = logging.getLogger(__name__) + + +class BaseEntryPoint(Protocol): + @property + def name(self) -> str: + raise NotImplementedError() + + @property + def value(self) -> str: + raise NotImplementedError() + + @property + def group(self) -> str: + raise NotImplementedError() + + +def _convert_installed_files_path( + entry: Tuple[str, ...], + info: Tuple[str, ...], +) -> str: + """Convert a legacy installed-files.txt path into modern RECORD path. + + The legacy format stores paths relative to the info directory, while the + modern format stores paths relative to the package root, e.g. the + site-packages directory. + + :param entry: Path parts of the installed-files.txt entry. + :param info: Path parts of the egg-info directory relative to package root. + :returns: The converted entry. + + For best compatibility with symlinks, this does not use ``abspath()`` or + ``Path.resolve()``, but tries to work with path parts: + + 1. While ``entry`` starts with ``..``, remove the equal amounts of parts + from ``info``; if ``info`` is empty, start appending ``..`` instead. + 2. Join the two directly. + """ + while entry and entry[0] == "..": + if not info or info[-1] == "..": + info += ("..",) + else: + info = info[:-1] + entry = entry[1:] + return str(pathlib.Path(*info, *entry)) + + +class RequiresEntry(NamedTuple): + requirement: str + extra: str + marker: str + + +class BaseDistribution(Protocol): + @classmethod + def from_directory(cls, directory: str) -> "BaseDistribution": + """Load the distribution from a metadata directory. + + :param directory: Path to a metadata directory, e.g. ``.dist-info``. + """ + raise NotImplementedError() + + @classmethod + def from_metadata_file_contents( + cls, + metadata_contents: bytes, + filename: str, + project_name: str, + ) -> "BaseDistribution": + """Load the distribution from the contents of a METADATA file. + + This is used to implement PEP 658 by generating a "shallow" dist object that can + be used for resolution without downloading or building the actual dist yet. + + :param metadata_contents: The contents of a METADATA file. + :param filename: File name for the dist with this metadata. + :param project_name: Name of the project this dist represents. + """ + raise NotImplementedError() + + @classmethod + def from_wheel(cls, wheel: "Wheel", name: str) -> "BaseDistribution": + """Load the distribution from a given wheel. + + :param wheel: A concrete wheel definition. + :param name: File name of the wheel. + + :raises InvalidWheel: Whenever loading of the wheel causes a + :py:exc:`zipfile.BadZipFile` exception to be thrown. + :raises UnsupportedWheel: If the wheel is a valid zip, but malformed + internally. + """ + raise NotImplementedError() + + def __repr__(self) -> str: + return f"{self.raw_name} {self.raw_version} ({self.location})" + + def __str__(self) -> str: + return f"{self.raw_name} {self.raw_version}" + + @property + def location(self) -> Optional[str]: + """Where the distribution is loaded from. + + A string value is not necessarily a filesystem path, since distributions + can be loaded from other sources, e.g. arbitrary zip archives. ``None`` + means the distribution is created in-memory. + + Do not canonicalize this value with e.g. ``pathlib.Path.resolve()``. If + this is a symbolic link, we want to preserve the relative path between + it and files in the distribution. + """ + raise NotImplementedError() + + @property + def editable_project_location(self) -> Optional[str]: + """The project location for editable distributions. + + This is the directory where pyproject.toml or setup.py is located. + None if the distribution is not installed in editable mode. + """ + # TODO: this property is relatively costly to compute, memoize it ? + direct_url = self.direct_url + if direct_url: + if direct_url.is_local_editable(): + return url_to_path(direct_url.url) + else: + # Search for an .egg-link file by walking sys.path, as it was + # done before by dist_is_editable(). + egg_link_path = egg_link_path_from_sys_path(self.raw_name) + if egg_link_path: + # TODO: get project location from second line of egg_link file + # (https://github.com/pypa/pip/issues/10243) + return self.location + return None + + @property + def installed_location(self) -> Optional[str]: + """The distribution's "installed" location. + + This should generally be a ``site-packages`` directory. This is + usually ``dist.location``, except for legacy develop-installed packages, + where ``dist.location`` is the source code location, and this is where + the ``.egg-link`` file is. + + The returned location is normalized (in particular, with symlinks removed). + """ + raise NotImplementedError() + + @property + def info_location(self) -> Optional[str]: + """Location of the .[egg|dist]-info directory or file. + + Similarly to ``location``, a string value is not necessarily a + filesystem path. ``None`` means the distribution is created in-memory. + + For a modern .dist-info installation on disk, this should be something + like ``{location}/{raw_name}-{version}.dist-info``. + + Do not canonicalize this value with e.g. ``pathlib.Path.resolve()``. If + this is a symbolic link, we want to preserve the relative path between + it and other files in the distribution. + """ + raise NotImplementedError() + + @property + def installed_by_distutils(self) -> bool: + """Whether this distribution is installed with legacy distutils format. + + A distribution installed with "raw" distutils not patched by setuptools + uses one single file at ``info_location`` to store metadata. We need to + treat this specially on uninstallation. + """ + info_location = self.info_location + if not info_location: + return False + return pathlib.Path(info_location).is_file() + + @property + def installed_as_egg(self) -> bool: + """Whether this distribution is installed as an egg. + + This usually indicates the distribution was installed by (older versions + of) easy_install. + """ + location = self.location + if not location: + return False + return location.endswith(".egg") + + @property + def installed_with_setuptools_egg_info(self) -> bool: + """Whether this distribution is installed with the ``.egg-info`` format. + + This usually indicates the distribution was installed with setuptools + with an old pip version or with ``single-version-externally-managed``. + + Note that this ensure the metadata store is a directory. distutils can + also installs an ``.egg-info``, but as a file, not a directory. This + property is *False* for that case. Also see ``installed_by_distutils``. + """ + info_location = self.info_location + if not info_location: + return False + if not info_location.endswith(".egg-info"): + return False + return pathlib.Path(info_location).is_dir() + + @property + def installed_with_dist_info(self) -> bool: + """Whether this distribution is installed with the "modern format". + + This indicates a "modern" installation, e.g. storing metadata in the + ``.dist-info`` directory. This applies to installations made by + setuptools (but through pip, not directly), or anything using the + standardized build backend interface (PEP 517). + """ + info_location = self.info_location + if not info_location: + return False + if not info_location.endswith(".dist-info"): + return False + return pathlib.Path(info_location).is_dir() + + @property + def canonical_name(self) -> NormalizedName: + raise NotImplementedError() + + @property + def version(self) -> Version: + raise NotImplementedError() + + @property + def raw_version(self) -> str: + raise NotImplementedError() + + @property + def setuptools_filename(self) -> str: + """Convert a project name to its setuptools-compatible filename. + + This is a copy of ``pkg_resources.to_filename()`` for compatibility. + """ + return self.raw_name.replace("-", "_") + + @property + def direct_url(self) -> Optional[DirectUrl]: + """Obtain a DirectUrl from this distribution. + + Returns None if the distribution has no `direct_url.json` metadata, + or if `direct_url.json` is invalid. + """ + try: + content = self.read_text(DIRECT_URL_METADATA_NAME) + except FileNotFoundError: + return None + try: + return DirectUrl.from_json(content) + except ( + UnicodeDecodeError, + json.JSONDecodeError, + DirectUrlValidationError, + ) as e: + logger.warning( + "Error parsing %s for %s: %s", + DIRECT_URL_METADATA_NAME, + self.canonical_name, + e, + ) + return None + + @property + def installer(self) -> str: + try: + installer_text = self.read_text("INSTALLER") + except (OSError, ValueError, NoneMetadataError): + return "" # Fail silently if the installer file cannot be read. + for line in installer_text.splitlines(): + cleaned_line = line.strip() + if cleaned_line: + return cleaned_line + return "" + + @property + def requested(self) -> bool: + return self.is_file("REQUESTED") + + @property + def editable(self) -> bool: + return bool(self.editable_project_location) + + @property + def local(self) -> bool: + """If distribution is installed in the current virtual environment. + + Always True if we're not in a virtualenv. + """ + if self.installed_location is None: + return False + return is_local(self.installed_location) + + @property + def in_usersite(self) -> bool: + if self.installed_location is None or user_site is None: + return False + return self.installed_location.startswith(normalize_path(user_site)) + + @property + def in_site_packages(self) -> bool: + if self.installed_location is None or site_packages is None: + return False + return self.installed_location.startswith(normalize_path(site_packages)) + + def is_file(self, path: InfoPath) -> bool: + """Check whether an entry in the info directory is a file.""" + raise NotImplementedError() + + def iter_distutils_script_names(self) -> Iterator[str]: + """Find distutils 'scripts' entries metadata. + + If 'scripts' is supplied in ``setup.py``, distutils records those in the + installed distribution's ``scripts`` directory, a file for each script. + """ + raise NotImplementedError() + + def read_text(self, path: InfoPath) -> str: + """Read a file in the info directory. + + :raise FileNotFoundError: If ``path`` does not exist in the directory. + :raise NoneMetadataError: If ``path`` exists in the info directory, but + cannot be read. + """ + raise NotImplementedError() + + def iter_entry_points(self) -> Iterable[BaseEntryPoint]: + raise NotImplementedError() + + def _metadata_impl(self) -> email.message.Message: + raise NotImplementedError() + + @functools.cached_property + def metadata(self) -> email.message.Message: + """Metadata of distribution parsed from e.g. METADATA or PKG-INFO. + + This should return an empty message if the metadata file is unavailable. + + :raises NoneMetadataError: If the metadata file is available, but does + not contain valid metadata. + """ + metadata = self._metadata_impl() + self._add_egg_info_requires(metadata) + return metadata + + @property + def metadata_dict(self) -> Dict[str, Any]: + """PEP 566 compliant JSON-serializable representation of METADATA or PKG-INFO. + + This should return an empty dict if the metadata file is unavailable. + + :raises NoneMetadataError: If the metadata file is available, but does + not contain valid metadata. + """ + return msg_to_json(self.metadata) + + @property + def metadata_version(self) -> Optional[str]: + """Value of "Metadata-Version:" in distribution metadata, if available.""" + return self.metadata.get("Metadata-Version") + + @property + def raw_name(self) -> str: + """Value of "Name:" in distribution metadata.""" + # The metadata should NEVER be missing the Name: key, but if it somehow + # does, fall back to the known canonical name. + return self.metadata.get("Name", self.canonical_name) + + @property + def requires_python(self) -> SpecifierSet: + """Value of "Requires-Python:" in distribution metadata. + + If the key does not exist or contains an invalid value, an empty + SpecifierSet should be returned. + """ + value = self.metadata.get("Requires-Python") + if value is None: + return SpecifierSet() + try: + # Convert to str to satisfy the type checker; this can be a Header object. + spec = SpecifierSet(str(value)) + except InvalidSpecifier as e: + message = "Package %r has an invalid Requires-Python: %s" + logger.warning(message, self.raw_name, e) + return SpecifierSet() + return spec + + def iter_dependencies(self, extras: Collection[str] = ()) -> Iterable[Requirement]: + """Dependencies of this distribution. + + For modern .dist-info distributions, this is the collection of + "Requires-Dist:" entries in distribution metadata. + """ + raise NotImplementedError() + + def iter_raw_dependencies(self) -> Iterable[str]: + """Raw Requires-Dist metadata.""" + return self.metadata.get_all("Requires-Dist", []) + + def iter_provided_extras(self) -> Iterable[NormalizedName]: + """Extras provided by this distribution. + + For modern .dist-info distributions, this is the collection of + "Provides-Extra:" entries in distribution metadata. + + The return value of this function is expected to be normalised names, + per PEP 685, with the returned value being handled appropriately by + `iter_dependencies`. + """ + raise NotImplementedError() + + def _iter_declared_entries_from_record(self) -> Optional[Iterator[str]]: + try: + text = self.read_text("RECORD") + except FileNotFoundError: + return None + # This extra Path-str cast normalizes entries. + return (str(pathlib.Path(row[0])) for row in csv.reader(text.splitlines())) + + def _iter_declared_entries_from_legacy(self) -> Optional[Iterator[str]]: + try: + text = self.read_text("installed-files.txt") + except FileNotFoundError: + return None + paths = (p for p in text.splitlines(keepends=False) if p) + root = self.location + info = self.info_location + if root is None or info is None: + return paths + try: + info_rel = pathlib.Path(info).relative_to(root) + except ValueError: # info is not relative to root. + return paths + if not info_rel.parts: # info *is* root. + return paths + return ( + _convert_installed_files_path(pathlib.Path(p).parts, info_rel.parts) + for p in paths + ) + + def iter_declared_entries(self) -> Optional[Iterator[str]]: + """Iterate through file entries declared in this distribution. + + For modern .dist-info distributions, this is the files listed in the + ``RECORD`` metadata file. For legacy setuptools distributions, this + comes from ``installed-files.txt``, with entries normalized to be + compatible with the format used by ``RECORD``. + + :return: An iterator for listed entries, or None if the distribution + contains neither ``RECORD`` nor ``installed-files.txt``. + """ + return ( + self._iter_declared_entries_from_record() + or self._iter_declared_entries_from_legacy() + ) + + def _iter_requires_txt_entries(self) -> Iterator[RequiresEntry]: + """Parse a ``requires.txt`` in an egg-info directory. + + This is an INI-ish format where an egg-info stores dependencies. A + section name describes extra other environment markers, while each entry + is an arbitrary string (not a key-value pair) representing a dependency + as a requirement string (no markers). + + There is a construct in ``importlib.metadata`` called ``Sectioned`` that + does mostly the same, but the format is currently considered private. + """ + try: + content = self.read_text("requires.txt") + except FileNotFoundError: + return + extra = marker = "" # Section-less entries don't have markers. + for line in content.splitlines(): + line = line.strip() + if not line or line.startswith("#"): # Comment; ignored. + continue + if line.startswith("[") and line.endswith("]"): # A section header. + extra, _, marker = line.strip("[]").partition(":") + continue + yield RequiresEntry(requirement=line, extra=extra, marker=marker) + + def _iter_egg_info_extras(self) -> Iterable[str]: + """Get extras from the egg-info directory.""" + known_extras = {""} + for entry in self._iter_requires_txt_entries(): + extra = canonicalize_name(entry.extra) + if extra in known_extras: + continue + known_extras.add(extra) + yield extra + + def _iter_egg_info_dependencies(self) -> Iterable[str]: + """Get distribution dependencies from the egg-info directory. + + To ease parsing, this converts a legacy dependency entry into a PEP 508 + requirement string. Like ``_iter_requires_txt_entries()``, there is code + in ``importlib.metadata`` that does mostly the same, but not do exactly + what we need. + + Namely, ``importlib.metadata`` does not normalize the extra name before + putting it into the requirement string, which causes marker comparison + to fail because the dist-info format do normalize. This is consistent in + all currently available PEP 517 backends, although not standardized. + """ + for entry in self._iter_requires_txt_entries(): + extra = canonicalize_name(entry.extra) + if extra and entry.marker: + marker = f'({entry.marker}) and extra == "{extra}"' + elif extra: + marker = f'extra == "{extra}"' + elif entry.marker: + marker = entry.marker + else: + marker = "" + if marker: + yield f"{entry.requirement} ; {marker}" + else: + yield entry.requirement + + def _add_egg_info_requires(self, metadata: email.message.Message) -> None: + """Add egg-info requires.txt information to the metadata.""" + if not metadata.get_all("Requires-Dist"): + for dep in self._iter_egg_info_dependencies(): + metadata["Requires-Dist"] = dep + if not metadata.get_all("Provides-Extra"): + for extra in self._iter_egg_info_extras(): + metadata["Provides-Extra"] = extra + + +class BaseEnvironment: + """An environment containing distributions to introspect.""" + + @classmethod + def default(cls) -> "BaseEnvironment": + raise NotImplementedError() + + @classmethod + def from_paths(cls, paths: Optional[List[str]]) -> "BaseEnvironment": + raise NotImplementedError() + + def get_distribution(self, name: str) -> Optional["BaseDistribution"]: + """Given a requirement name, return the installed distributions. + + The name may not be normalized. The implementation must canonicalize + it for lookup. + """ + raise NotImplementedError() + + def _iter_distributions(self) -> Iterator["BaseDistribution"]: + """Iterate through installed distributions. + + This function should be implemented by subclass, but never called + directly. Use the public ``iter_distribution()`` instead, which + implements additional logic to make sure the distributions are valid. + """ + raise NotImplementedError() + + def iter_all_distributions(self) -> Iterator[BaseDistribution]: + """Iterate through all installed distributions without any filtering.""" + for dist in self._iter_distributions(): + # Make sure the distribution actually comes from a valid Python + # packaging distribution. Pip's AdjacentTempDirectory leaves folders + # e.g. ``~atplotlib.dist-info`` if cleanup was interrupted. The + # valid project name pattern is taken from PEP 508. + project_name_valid = re.match( + r"^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$", + dist.canonical_name, + flags=re.IGNORECASE, + ) + if not project_name_valid: + logger.warning( + "Ignoring invalid distribution %s (%s)", + dist.canonical_name, + dist.location, + ) + continue + yield dist + + def iter_installed_distributions( + self, + local_only: bool = True, + skip: Container[str] = stdlib_pkgs, + include_editables: bool = True, + editables_only: bool = False, + user_only: bool = False, + ) -> Iterator[BaseDistribution]: + """Return a list of installed distributions. + + This is based on ``iter_all_distributions()`` with additional filtering + options. Note that ``iter_installed_distributions()`` without arguments + is *not* equal to ``iter_all_distributions()``, since some of the + configurations exclude packages by default. + + :param local_only: If True (default), only return installations + local to the current virtualenv, if in a virtualenv. + :param skip: An iterable of canonicalized project names to ignore; + defaults to ``stdlib_pkgs``. + :param include_editables: If False, don't report editables. + :param editables_only: If True, only report editables. + :param user_only: If True, only report installations in the user + site directory. + """ + it = self.iter_all_distributions() + if local_only: + it = (d for d in it if d.local) + if not include_editables: + it = (d for d in it if not d.editable) + if editables_only: + it = (d for d in it if d.editable) + if user_only: + it = (d for d in it if d.in_usersite) + return (d for d in it if d.canonical_name not in skip) + + +class Wheel(Protocol): + location: str + + def as_zipfile(self) -> zipfile.ZipFile: + raise NotImplementedError() + + +class FilesystemWheel(Wheel): + def __init__(self, location: str) -> None: + self.location = location + + def as_zipfile(self) -> zipfile.ZipFile: + return zipfile.ZipFile(self.location, allowZip64=True) + + +class MemoryWheel(Wheel): + def __init__(self, location: str, stream: IO[bytes]) -> None: + self.location = location + self.stream = stream + + def as_zipfile(self) -> zipfile.ZipFile: + return zipfile.ZipFile(self.stream, allowZip64=True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/pkg_resources.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/pkg_resources.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea84f93a6fb8f2d04230d70491eac7809672031 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/metadata/pkg_resources.py @@ -0,0 +1,301 @@ +import email.message +import email.parser +import logging +import os +import zipfile +from typing import ( + Collection, + Iterable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, +) + +from pip._vendor import pkg_resources +from pip._vendor.packaging.requirements import Requirement +from pip._vendor.packaging.utils import NormalizedName, canonicalize_name +from pip._vendor.packaging.version import Version +from pip._vendor.packaging.version import parse as parse_version + +from pip._internal.exceptions import InvalidWheel, NoneMetadataError, UnsupportedWheel +from pip._internal.utils.egg_link import egg_link_path_from_location +from pip._internal.utils.misc import display_path, normalize_path +from pip._internal.utils.wheel import parse_wheel, read_wheel_metadata_file + +from .base import ( + BaseDistribution, + BaseEntryPoint, + BaseEnvironment, + InfoPath, + Wheel, +) + +__all__ = ["NAME", "Distribution", "Environment"] + +logger = logging.getLogger(__name__) + +NAME = "pkg_resources" + + +class EntryPoint(NamedTuple): + name: str + value: str + group: str + + +class InMemoryMetadata: + """IMetadataProvider that reads metadata files from a dictionary. + + This also maps metadata decoding exceptions to our internal exception type. + """ + + def __init__(self, metadata: Mapping[str, bytes], wheel_name: str) -> None: + self._metadata = metadata + self._wheel_name = wheel_name + + def has_metadata(self, name: str) -> bool: + return name in self._metadata + + def get_metadata(self, name: str) -> str: + try: + return self._metadata[name].decode() + except UnicodeDecodeError as e: + # Augment the default error with the origin of the file. + raise UnsupportedWheel( + f"Error decoding metadata for {self._wheel_name}: {e} in {name} file" + ) + + def get_metadata_lines(self, name: str) -> Iterable[str]: + return pkg_resources.yield_lines(self.get_metadata(name)) + + def metadata_isdir(self, name: str) -> bool: + return False + + def metadata_listdir(self, name: str) -> List[str]: + return [] + + def run_script(self, script_name: str, namespace: str) -> None: + pass + + +class Distribution(BaseDistribution): + def __init__(self, dist: pkg_resources.Distribution) -> None: + self._dist = dist + # This is populated lazily, to avoid loading metadata for all possible + # distributions eagerly. + self.__extra_mapping: Optional[Mapping[NormalizedName, str]] = None + + @property + def _extra_mapping(self) -> Mapping[NormalizedName, str]: + if self.__extra_mapping is None: + self.__extra_mapping = { + canonicalize_name(extra): extra for extra in self._dist.extras + } + + return self.__extra_mapping + + @classmethod + def from_directory(cls, directory: str) -> BaseDistribution: + dist_dir = directory.rstrip(os.sep) + + # Build a PathMetadata object, from path to metadata. :wink: + base_dir, dist_dir_name = os.path.split(dist_dir) + metadata = pkg_resources.PathMetadata(base_dir, dist_dir) + + # Determine the correct Distribution object type. + if dist_dir.endswith(".egg-info"): + dist_cls = pkg_resources.Distribution + dist_name = os.path.splitext(dist_dir_name)[0] + else: + assert dist_dir.endswith(".dist-info") + dist_cls = pkg_resources.DistInfoDistribution + dist_name = os.path.splitext(dist_dir_name)[0].split("-")[0] + + dist = dist_cls(base_dir, project_name=dist_name, metadata=metadata) + return cls(dist) + + @classmethod + def from_metadata_file_contents( + cls, + metadata_contents: bytes, + filename: str, + project_name: str, + ) -> BaseDistribution: + metadata_dict = { + "METADATA": metadata_contents, + } + dist = pkg_resources.DistInfoDistribution( + location=filename, + metadata=InMemoryMetadata(metadata_dict, filename), + project_name=project_name, + ) + return cls(dist) + + @classmethod + def from_wheel(cls, wheel: Wheel, name: str) -> BaseDistribution: + try: + with wheel.as_zipfile() as zf: + info_dir, _ = parse_wheel(zf, name) + metadata_dict = { + path.split("/", 1)[-1]: read_wheel_metadata_file(zf, path) + for path in zf.namelist() + if path.startswith(f"{info_dir}/") + } + except zipfile.BadZipFile as e: + raise InvalidWheel(wheel.location, name) from e + except UnsupportedWheel as e: + raise UnsupportedWheel(f"{name} has an invalid wheel, {e}") + dist = pkg_resources.DistInfoDistribution( + location=wheel.location, + metadata=InMemoryMetadata(metadata_dict, wheel.location), + project_name=name, + ) + return cls(dist) + + @property + def location(self) -> Optional[str]: + return self._dist.location + + @property + def installed_location(self) -> Optional[str]: + egg_link = egg_link_path_from_location(self.raw_name) + if egg_link: + location = egg_link + elif self.location: + location = self.location + else: + return None + return normalize_path(location) + + @property + def info_location(self) -> Optional[str]: + return self._dist.egg_info + + @property + def installed_by_distutils(self) -> bool: + # A distutils-installed distribution is provided by FileMetadata. This + # provider has a "path" attribute not present anywhere else. Not the + # best introspection logic, but pip has been doing this for a long time. + try: + return bool(self._dist._provider.path) + except AttributeError: + return False + + @property + def canonical_name(self) -> NormalizedName: + return canonicalize_name(self._dist.project_name) + + @property + def version(self) -> Version: + return parse_version(self._dist.version) + + @property + def raw_version(self) -> str: + return self._dist.version + + def is_file(self, path: InfoPath) -> bool: + return self._dist.has_metadata(str(path)) + + def iter_distutils_script_names(self) -> Iterator[str]: + yield from self._dist.metadata_listdir("scripts") + + def read_text(self, path: InfoPath) -> str: + name = str(path) + if not self._dist.has_metadata(name): + raise FileNotFoundError(name) + content = self._dist.get_metadata(name) + if content is None: + raise NoneMetadataError(self, name) + return content + + def iter_entry_points(self) -> Iterable[BaseEntryPoint]: + for group, entries in self._dist.get_entry_map().items(): + for name, entry_point in entries.items(): + name, _, value = str(entry_point).partition("=") + yield EntryPoint(name=name.strip(), value=value.strip(), group=group) + + def _metadata_impl(self) -> email.message.Message: + """ + :raises NoneMetadataError: if the distribution reports `has_metadata()` + True but `get_metadata()` returns None. + """ + if isinstance(self._dist, pkg_resources.DistInfoDistribution): + metadata_name = "METADATA" + else: + metadata_name = "PKG-INFO" + try: + metadata = self.read_text(metadata_name) + except FileNotFoundError: + if self.location: + displaying_path = display_path(self.location) + else: + displaying_path = repr(self.location) + logger.warning("No metadata found in %s", displaying_path) + metadata = "" + feed_parser = email.parser.FeedParser() + feed_parser.feed(metadata) + return feed_parser.close() + + def iter_dependencies(self, extras: Collection[str] = ()) -> Iterable[Requirement]: + if extras: + relevant_extras = set(self._extra_mapping) & set( + map(canonicalize_name, extras) + ) + extras = [self._extra_mapping[extra] for extra in relevant_extras] + return self._dist.requires(extras) + + def iter_provided_extras(self) -> Iterable[NormalizedName]: + return self._extra_mapping.keys() + + +class Environment(BaseEnvironment): + def __init__(self, ws: pkg_resources.WorkingSet) -> None: + self._ws = ws + + @classmethod + def default(cls) -> BaseEnvironment: + return cls(pkg_resources.working_set) + + @classmethod + def from_paths(cls, paths: Optional[List[str]]) -> BaseEnvironment: + return cls(pkg_resources.WorkingSet(paths)) + + def _iter_distributions(self) -> Iterator[BaseDistribution]: + for dist in self._ws: + yield Distribution(dist) + + def _search_distribution(self, name: str) -> Optional[BaseDistribution]: + """Find a distribution matching the ``name`` in the environment. + + This searches from *all* distributions available in the environment, to + match the behavior of ``pkg_resources.get_distribution()``. + """ + canonical_name = canonicalize_name(name) + for dist in self.iter_all_distributions(): + if dist.canonical_name == canonical_name: + return dist + return None + + def get_distribution(self, name: str) -> Optional[BaseDistribution]: + # Search the distribution by looking through the working set. + dist = self._search_distribution(name) + if dist: + return dist + + # If distribution could not be found, call working_set.require to + # update the working set, and try to find the distribution again. + # This might happen for e.g. when you install a package twice, once + # using setup.py develop and again using setup.py install. Now when + # running pip uninstall twice, the package gets removed from the + # working set in the first uninstall, so we have to populate the + # working set again so that pip knows about it and the packages gets + # picked up and is successfully uninstalled the second time too. + try: + # We didn't pass in any version specifiers, so this can never + # raise pkg_resources.VersionConflict. + self._ws.require(name) + except pkg_resources.DistributionNotFound: + return None + return self._search_distribution(name) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7855226e4b500142deef8fb247cd33a9a991d122 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/__init__.py @@ -0,0 +1,2 @@ +"""A package that contains models that represent entities. +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/candidate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/candidate.py new file mode 100644 index 0000000000000000000000000000000000000000..f27f283154ac5aa55d52ccac754138b36341ff6b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/candidate.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass + +from pip._vendor.packaging.version import Version +from pip._vendor.packaging.version import parse as parse_version + +from pip._internal.models.link import Link + + +@dataclass(frozen=True) +class InstallationCandidate: + """Represents a potential "candidate" for installation.""" + + __slots__ = ["name", "version", "link"] + + name: str + version: Version + link: Link + + def __init__(self, name: str, version: str, link: Link) -> None: + object.__setattr__(self, "name", name) + object.__setattr__(self, "version", parse_version(version)) + object.__setattr__(self, "link", link) + + def __str__(self) -> str: + return f"{self.name!r} candidate (version {self.version} at {self.link})" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/direct_url.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/direct_url.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5ec8d4aa9b02b7264f7a5a0222e7e1fe215ad0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/direct_url.py @@ -0,0 +1,224 @@ +""" PEP 610 """ + +import json +import re +import urllib.parse +from dataclasses import dataclass +from typing import Any, ClassVar, Dict, Iterable, Optional, Type, TypeVar, Union + +__all__ = [ + "DirectUrl", + "DirectUrlValidationError", + "DirInfo", + "ArchiveInfo", + "VcsInfo", +] + +T = TypeVar("T") + +DIRECT_URL_METADATA_NAME = "direct_url.json" +ENV_VAR_RE = re.compile(r"^\$\{[A-Za-z0-9-_]+\}(:\$\{[A-Za-z0-9-_]+\})?$") + + +class DirectUrlValidationError(Exception): + pass + + +def _get( + d: Dict[str, Any], expected_type: Type[T], key: str, default: Optional[T] = None +) -> Optional[T]: + """Get value from dictionary and verify expected type.""" + if key not in d: + return default + value = d[key] + if not isinstance(value, expected_type): + raise DirectUrlValidationError( + f"{value!r} has unexpected type for {key} (expected {expected_type})" + ) + return value + + +def _get_required( + d: Dict[str, Any], expected_type: Type[T], key: str, default: Optional[T] = None +) -> T: + value = _get(d, expected_type, key, default) + if value is None: + raise DirectUrlValidationError(f"{key} must have a value") + return value + + +def _exactly_one_of(infos: Iterable[Optional["InfoType"]]) -> "InfoType": + infos = [info for info in infos if info is not None] + if not infos: + raise DirectUrlValidationError( + "missing one of archive_info, dir_info, vcs_info" + ) + if len(infos) > 1: + raise DirectUrlValidationError( + "more than one of archive_info, dir_info, vcs_info" + ) + assert infos[0] is not None + return infos[0] + + +def _filter_none(**kwargs: Any) -> Dict[str, Any]: + """Make dict excluding None values.""" + return {k: v for k, v in kwargs.items() if v is not None} + + +@dataclass +class VcsInfo: + name: ClassVar = "vcs_info" + + vcs: str + commit_id: str + requested_revision: Optional[str] = None + + @classmethod + def _from_dict(cls, d: Optional[Dict[str, Any]]) -> Optional["VcsInfo"]: + if d is None: + return None + return cls( + vcs=_get_required(d, str, "vcs"), + commit_id=_get_required(d, str, "commit_id"), + requested_revision=_get(d, str, "requested_revision"), + ) + + def _to_dict(self) -> Dict[str, Any]: + return _filter_none( + vcs=self.vcs, + requested_revision=self.requested_revision, + commit_id=self.commit_id, + ) + + +class ArchiveInfo: + name = "archive_info" + + def __init__( + self, + hash: Optional[str] = None, + hashes: Optional[Dict[str, str]] = None, + ) -> None: + # set hashes before hash, since the hash setter will further populate hashes + self.hashes = hashes + self.hash = hash + + @property + def hash(self) -> Optional[str]: + return self._hash + + @hash.setter + def hash(self, value: Optional[str]) -> None: + if value is not None: + # Auto-populate the hashes key to upgrade to the new format automatically. + # We don't back-populate the legacy hash key from hashes. + try: + hash_name, hash_value = value.split("=", 1) + except ValueError: + raise DirectUrlValidationError( + f"invalid archive_info.hash format: {value!r}" + ) + if self.hashes is None: + self.hashes = {hash_name: hash_value} + elif hash_name not in self.hashes: + self.hashes = self.hashes.copy() + self.hashes[hash_name] = hash_value + self._hash = value + + @classmethod + def _from_dict(cls, d: Optional[Dict[str, Any]]) -> Optional["ArchiveInfo"]: + if d is None: + return None + return cls(hash=_get(d, str, "hash"), hashes=_get(d, dict, "hashes")) + + def _to_dict(self) -> Dict[str, Any]: + return _filter_none(hash=self.hash, hashes=self.hashes) + + +@dataclass +class DirInfo: + name: ClassVar = "dir_info" + + editable: bool = False + + @classmethod + def _from_dict(cls, d: Optional[Dict[str, Any]]) -> Optional["DirInfo"]: + if d is None: + return None + return cls(editable=_get_required(d, bool, "editable", default=False)) + + def _to_dict(self) -> Dict[str, Any]: + return _filter_none(editable=self.editable or None) + + +InfoType = Union[ArchiveInfo, DirInfo, VcsInfo] + + +@dataclass +class DirectUrl: + url: str + info: InfoType + subdirectory: Optional[str] = None + + def _remove_auth_from_netloc(self, netloc: str) -> str: + if "@" not in netloc: + return netloc + user_pass, netloc_no_user_pass = netloc.split("@", 1) + if ( + isinstance(self.info, VcsInfo) + and self.info.vcs == "git" + and user_pass == "git" + ): + return netloc + if ENV_VAR_RE.match(user_pass): + return netloc + return netloc_no_user_pass + + @property + def redacted_url(self) -> str: + """url with user:password part removed unless it is formed with + environment variables as specified in PEP 610, or it is ``git`` + in the case of a git URL. + """ + purl = urllib.parse.urlsplit(self.url) + netloc = self._remove_auth_from_netloc(purl.netloc) + surl = urllib.parse.urlunsplit( + (purl.scheme, netloc, purl.path, purl.query, purl.fragment) + ) + return surl + + def validate(self) -> None: + self.from_dict(self.to_dict()) + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "DirectUrl": + return DirectUrl( + url=_get_required(d, str, "url"), + subdirectory=_get(d, str, "subdirectory"), + info=_exactly_one_of( + [ + ArchiveInfo._from_dict(_get(d, dict, "archive_info")), + DirInfo._from_dict(_get(d, dict, "dir_info")), + VcsInfo._from_dict(_get(d, dict, "vcs_info")), + ] + ), + ) + + def to_dict(self) -> Dict[str, Any]: + res = _filter_none( + url=self.redacted_url, + subdirectory=self.subdirectory, + ) + res[self.info.name] = self.info._to_dict() + return res + + @classmethod + def from_json(cls, s: str) -> "DirectUrl": + return cls.from_dict(json.loads(s)) + + def to_json(self) -> str: + return json.dumps(self.to_dict(), sort_keys=True) + + def is_local_editable(self) -> bool: + return isinstance(self.info, DirInfo) and self.info.editable diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/format_control.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/format_control.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd11272c030c2d067e1bb6d90fc744c7379a923 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/format_control.py @@ -0,0 +1,78 @@ +from typing import FrozenSet, Optional, Set + +from pip._vendor.packaging.utils import canonicalize_name + +from pip._internal.exceptions import CommandError + + +class FormatControl: + """Helper for managing formats from which a package can be installed.""" + + __slots__ = ["no_binary", "only_binary"] + + def __init__( + self, + no_binary: Optional[Set[str]] = None, + only_binary: Optional[Set[str]] = None, + ) -> None: + if no_binary is None: + no_binary = set() + if only_binary is None: + only_binary = set() + + self.no_binary = no_binary + self.only_binary = only_binary + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + + if self.__slots__ != other.__slots__: + return False + + return all(getattr(self, k) == getattr(other, k) for k in self.__slots__) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.no_binary}, {self.only_binary})" + + @staticmethod + def handle_mutual_excludes(value: str, target: Set[str], other: Set[str]) -> None: + if value.startswith("-"): + raise CommandError( + "--no-binary / --only-binary option requires 1 argument." + ) + new = value.split(",") + while ":all:" in new: + other.clear() + target.clear() + target.add(":all:") + del new[: new.index(":all:") + 1] + # Without a none, we want to discard everything as :all: covers it + if ":none:" not in new: + return + for name in new: + if name == ":none:": + target.clear() + continue + name = canonicalize_name(name) + other.discard(name) + target.add(name) + + def get_allowed_formats(self, canonical_name: str) -> FrozenSet[str]: + result = {"binary", "source"} + if canonical_name in self.only_binary: + result.discard("source") + elif canonical_name in self.no_binary: + result.discard("binary") + elif ":all:" in self.only_binary: + result.discard("source") + elif ":all:" in self.no_binary: + result.discard("binary") + return frozenset(result) + + def disallow_binaries(self) -> None: + self.handle_mutual_excludes( + ":all:", + self.no_binary, + self.only_binary, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/index.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/index.py new file mode 100644 index 0000000000000000000000000000000000000000..b94c32511f0cda2363bfc4f29c9c8bfcc7101f9b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/index.py @@ -0,0 +1,28 @@ +import urllib.parse + + +class PackageIndex: + """Represents a Package Index and provides easier access to endpoints""" + + __slots__ = ["url", "netloc", "simple_url", "pypi_url", "file_storage_domain"] + + def __init__(self, url: str, file_storage_domain: str) -> None: + super().__init__() + self.url = url + self.netloc = urllib.parse.urlsplit(url).netloc + self.simple_url = self._url_for_path("simple") + self.pypi_url = self._url_for_path("pypi") + + # This is part of a temporary hack used to block installs of PyPI + # packages which depend on external urls only necessary until PyPI can + # block such packages themselves + self.file_storage_domain = file_storage_domain + + def _url_for_path(self, path: str) -> str: + return urllib.parse.urljoin(self.url, path) + + +PyPI = PackageIndex("https://pypi.org/", file_storage_domain="files.pythonhosted.org") +TestPyPI = PackageIndex( + "https://test.pypi.org/", file_storage_domain="test-files.pythonhosted.org" +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/installation_report.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/installation_report.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c6330df32bd2b57c885156cb7f8c0c8c3e3741 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/installation_report.py @@ -0,0 +1,56 @@ +from typing import Any, Dict, Sequence + +from pip._vendor.packaging.markers import default_environment + +from pip import __version__ +from pip._internal.req.req_install import InstallRequirement + + +class InstallationReport: + def __init__(self, install_requirements: Sequence[InstallRequirement]): + self._install_requirements = install_requirements + + @classmethod + def _install_req_to_dict(cls, ireq: InstallRequirement) -> Dict[str, Any]: + assert ireq.download_info, f"No download_info for {ireq}" + res = { + # PEP 610 json for the download URL. download_info.archive_info.hashes may + # be absent when the requirement was installed from the wheel cache + # and the cache entry was populated by an older pip version that did not + # record origin.json. + "download_info": ireq.download_info.to_dict(), + # is_direct is true if the requirement was a direct URL reference (which + # includes editable requirements), and false if the requirement was + # downloaded from a PEP 503 index or --find-links. + "is_direct": ireq.is_direct, + # is_yanked is true if the requirement was yanked from the index, but + # was still selected by pip to conform to PEP 592. + "is_yanked": ireq.link.is_yanked if ireq.link else False, + # requested is true if the requirement was specified by the user (aka + # top level requirement), and false if it was installed as a dependency of a + # requirement. https://peps.python.org/pep-0376/#requested + "requested": ireq.user_supplied, + # PEP 566 json encoding for metadata + # https://www.python.org/dev/peps/pep-0566/#json-compatible-metadata + "metadata": ireq.get_dist().metadata_dict, + } + if ireq.user_supplied and ireq.extras: + # For top level requirements, the list of requested extras, if any. + res["requested_extras"] = sorted(ireq.extras) + return res + + def to_dict(self) -> Dict[str, Any]: + return { + "version": "1", + "pip_version": __version__, + "install": [ + self._install_req_to_dict(ireq) for ireq in self._install_requirements + ], + # https://peps.python.org/pep-0508/#environment-markers + # TODO: currently, the resolver uses the default environment to evaluate + # environment markers, so that is what we report here. In the future, it + # should also take into account options such as --python-version or + # --platform, perhaps under the form of an environment_override field? + # https://github.com/pypa/pip/issues/11198 + "environment": default_environment(), + } diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/link.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/link.py new file mode 100644 index 0000000000000000000000000000000000000000..27ad016090c565af4375d9a236d363c2be62532c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/link.py @@ -0,0 +1,604 @@ +import functools +import itertools +import logging +import os +import posixpath +import re +import urllib.parse +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Mapping, + NamedTuple, + Optional, + Tuple, + Union, +) + +from pip._internal.utils.deprecation import deprecated +from pip._internal.utils.filetypes import WHEEL_EXTENSION +from pip._internal.utils.hashes import Hashes +from pip._internal.utils.misc import ( + pairwise, + redact_auth_from_url, + split_auth_from_netloc, + splitext, +) +from pip._internal.utils.urls import path_to_url, url_to_path + +if TYPE_CHECKING: + from pip._internal.index.collector import IndexContent + +logger = logging.getLogger(__name__) + + +# Order matters, earlier hashes have a precedence over later hashes for what +# we will pick to use. +_SUPPORTED_HASHES = ("sha512", "sha384", "sha256", "sha224", "sha1", "md5") + + +@dataclass(frozen=True) +class LinkHash: + """Links to content may have embedded hash values. This class parses those. + + `name` must be any member of `_SUPPORTED_HASHES`. + + This class can be converted to and from `ArchiveInfo`. While ArchiveInfo intends to + be JSON-serializable to conform to PEP 610, this class contains the logic for + parsing a hash name and value for correctness, and then checking whether that hash + conforms to a schema with `.is_hash_allowed()`.""" + + name: str + value: str + + _hash_url_fragment_re = re.compile( + # NB: we do not validate that the second group (.*) is a valid hex + # digest. Instead, we simply keep that string in this class, and then check it + # against Hashes when hash-checking is needed. This is easier to debug than + # proactively discarding an invalid hex digest, as we handle incorrect hashes + # and malformed hashes in the same place. + r"[#&]({choices})=([^&]*)".format( + choices="|".join(re.escape(hash_name) for hash_name in _SUPPORTED_HASHES) + ), + ) + + def __post_init__(self) -> None: + assert self.name in _SUPPORTED_HASHES + + @classmethod + @functools.lru_cache(maxsize=None) + def find_hash_url_fragment(cls, url: str) -> Optional["LinkHash"]: + """Search a string for a checksum algorithm name and encoded output value.""" + match = cls._hash_url_fragment_re.search(url) + if match is None: + return None + name, value = match.groups() + return cls(name=name, value=value) + + def as_dict(self) -> Dict[str, str]: + return {self.name: self.value} + + def as_hashes(self) -> Hashes: + """Return a Hashes instance which checks only for the current hash.""" + return Hashes({self.name: [self.value]}) + + def is_hash_allowed(self, hashes: Optional[Hashes]) -> bool: + """ + Return True if the current hash is allowed by `hashes`. + """ + if hashes is None: + return False + return hashes.is_hash_allowed(self.name, hex_digest=self.value) + + +@dataclass(frozen=True) +class MetadataFile: + """Information about a core metadata file associated with a distribution.""" + + hashes: Optional[Dict[str, str]] + + def __post_init__(self) -> None: + if self.hashes is not None: + assert all(name in _SUPPORTED_HASHES for name in self.hashes) + + +def supported_hashes(hashes: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]: + # Remove any unsupported hash types from the mapping. If this leaves no + # supported hashes, return None + if hashes is None: + return None + hashes = {n: v for n, v in hashes.items() if n in _SUPPORTED_HASHES} + if not hashes: + return None + return hashes + + +def _clean_url_path_part(part: str) -> str: + """ + Clean a "part" of a URL path (i.e. after splitting on "@" characters). + """ + # We unquote prior to quoting to make sure nothing is double quoted. + return urllib.parse.quote(urllib.parse.unquote(part)) + + +def _clean_file_url_path(part: str) -> str: + """ + Clean the first part of a URL path that corresponds to a local + filesystem path (i.e. the first part after splitting on "@" characters). + """ + # We unquote prior to quoting to make sure nothing is double quoted. + # Also, on Windows the path part might contain a drive letter which + # should not be quoted. On Linux where drive letters do not + # exist, the colon should be quoted. We rely on urllib.request + # to do the right thing here. + return urllib.request.pathname2url(urllib.request.url2pathname(part)) + + +# percent-encoded: / +_reserved_chars_re = re.compile("(@|%2F)", re.IGNORECASE) + + +def _clean_url_path(path: str, is_local_path: bool) -> str: + """ + Clean the path portion of a URL. + """ + if is_local_path: + clean_func = _clean_file_url_path + else: + clean_func = _clean_url_path_part + + # Split on the reserved characters prior to cleaning so that + # revision strings in VCS URLs are properly preserved. + parts = _reserved_chars_re.split(path) + + cleaned_parts = [] + for to_clean, reserved in pairwise(itertools.chain(parts, [""])): + cleaned_parts.append(clean_func(to_clean)) + # Normalize %xx escapes (e.g. %2f -> %2F) + cleaned_parts.append(reserved.upper()) + + return "".join(cleaned_parts) + + +def _ensure_quoted_url(url: str) -> str: + """ + Make sure a link is fully quoted. + For example, if ' ' occurs in the URL, it will be replaced with "%20", + and without double-quoting other characters. + """ + # Split the URL into parts according to the general structure + # `scheme://netloc/path?query#fragment`. + result = urllib.parse.urlsplit(url) + # If the netloc is empty, then the URL refers to a local filesystem path. + is_local_path = not result.netloc + path = _clean_url_path(result.path, is_local_path=is_local_path) + return urllib.parse.urlunsplit(result._replace(path=path)) + + +def _absolute_link_url(base_url: str, url: str) -> str: + """ + A faster implementation of urllib.parse.urljoin with a shortcut + for absolute http/https URLs. + """ + if url.startswith(("https://", "http://")): + return url + else: + return urllib.parse.urljoin(base_url, url) + + +@functools.total_ordering +class Link: + """Represents a parsed link from a Package Index's simple URL""" + + __slots__ = [ + "_parsed_url", + "_url", + "_path", + "_hashes", + "comes_from", + "requires_python", + "yanked_reason", + "metadata_file_data", + "cache_link_parsing", + "egg_fragment", + ] + + def __init__( + self, + url: str, + comes_from: Optional[Union[str, "IndexContent"]] = None, + requires_python: Optional[str] = None, + yanked_reason: Optional[str] = None, + metadata_file_data: Optional[MetadataFile] = None, + cache_link_parsing: bool = True, + hashes: Optional[Mapping[str, str]] = None, + ) -> None: + """ + :param url: url of the resource pointed to (href of the link) + :param comes_from: instance of IndexContent where the link was found, + or string. + :param requires_python: String containing the `Requires-Python` + metadata field, specified in PEP 345. This may be specified by + a data-requires-python attribute in the HTML link tag, as + described in PEP 503. + :param yanked_reason: the reason the file has been yanked, if the + file has been yanked, or None if the file hasn't been yanked. + This is the value of the "data-yanked" attribute, if present, in + a simple repository HTML link. If the file has been yanked but + no reason was provided, this should be the empty string. See + PEP 592 for more information and the specification. + :param metadata_file_data: the metadata attached to the file, or None if + no such metadata is provided. This argument, if not None, indicates + that a separate metadata file exists, and also optionally supplies + hashes for that file. + :param cache_link_parsing: A flag that is used elsewhere to determine + whether resources retrieved from this link should be cached. PyPI + URLs should generally have this set to False, for example. + :param hashes: A mapping of hash names to digests to allow us to + determine the validity of a download. + """ + + # The comes_from, requires_python, and metadata_file_data arguments are + # only used by classmethods of this class, and are not used in client + # code directly. + + # url can be a UNC windows share + if url.startswith("\\\\"): + url = path_to_url(url) + + self._parsed_url = urllib.parse.urlsplit(url) + # Store the url as a private attribute to prevent accidentally + # trying to set a new value. + self._url = url + # The .path property is hot, so calculate its value ahead of time. + self._path = urllib.parse.unquote(self._parsed_url.path) + + link_hash = LinkHash.find_hash_url_fragment(url) + hashes_from_link = {} if link_hash is None else link_hash.as_dict() + if hashes is None: + self._hashes = hashes_from_link + else: + self._hashes = {**hashes, **hashes_from_link} + + self.comes_from = comes_from + self.requires_python = requires_python if requires_python else None + self.yanked_reason = yanked_reason + self.metadata_file_data = metadata_file_data + + self.cache_link_parsing = cache_link_parsing + self.egg_fragment = self._egg_fragment() + + @classmethod + def from_json( + cls, + file_data: Dict[str, Any], + page_url: str, + ) -> Optional["Link"]: + """ + Convert an pypi json document from a simple repository page into a Link. + """ + file_url = file_data.get("url") + if file_url is None: + return None + + url = _ensure_quoted_url(_absolute_link_url(page_url, file_url)) + pyrequire = file_data.get("requires-python") + yanked_reason = file_data.get("yanked") + hashes = file_data.get("hashes", {}) + + # PEP 714: Indexes must use the name core-metadata, but + # clients should support the old name as a fallback for compatibility. + metadata_info = file_data.get("core-metadata") + if metadata_info is None: + metadata_info = file_data.get("dist-info-metadata") + + # The metadata info value may be a boolean, or a dict of hashes. + if isinstance(metadata_info, dict): + # The file exists, and hashes have been supplied + metadata_file_data = MetadataFile(supported_hashes(metadata_info)) + elif metadata_info: + # The file exists, but there are no hashes + metadata_file_data = MetadataFile(None) + else: + # False or not present: the file does not exist + metadata_file_data = None + + # The Link.yanked_reason expects an empty string instead of a boolean. + if yanked_reason and not isinstance(yanked_reason, str): + yanked_reason = "" + # The Link.yanked_reason expects None instead of False. + elif not yanked_reason: + yanked_reason = None + + return cls( + url, + comes_from=page_url, + requires_python=pyrequire, + yanked_reason=yanked_reason, + hashes=hashes, + metadata_file_data=metadata_file_data, + ) + + @classmethod + def from_element( + cls, + anchor_attribs: Dict[str, Optional[str]], + page_url: str, + base_url: str, + ) -> Optional["Link"]: + """ + Convert an anchor element's attributes in a simple repository page to a Link. + """ + href = anchor_attribs.get("href") + if not href: + return None + + url = _ensure_quoted_url(_absolute_link_url(base_url, href)) + pyrequire = anchor_attribs.get("data-requires-python") + yanked_reason = anchor_attribs.get("data-yanked") + + # PEP 714: Indexes must use the name data-core-metadata, but + # clients should support the old name as a fallback for compatibility. + metadata_info = anchor_attribs.get("data-core-metadata") + if metadata_info is None: + metadata_info = anchor_attribs.get("data-dist-info-metadata") + # The metadata info value may be the string "true", or a string of + # the form "hashname=hashval" + if metadata_info == "true": + # The file exists, but there are no hashes + metadata_file_data = MetadataFile(None) + elif metadata_info is None: + # The file does not exist + metadata_file_data = None + else: + # The file exists, and hashes have been supplied + hashname, sep, hashval = metadata_info.partition("=") + if sep == "=": + metadata_file_data = MetadataFile(supported_hashes({hashname: hashval})) + else: + # Error - data is wrong. Treat as no hashes supplied. + logger.debug( + "Index returned invalid data-dist-info-metadata value: %s", + metadata_info, + ) + metadata_file_data = MetadataFile(None) + + return cls( + url, + comes_from=page_url, + requires_python=pyrequire, + yanked_reason=yanked_reason, + metadata_file_data=metadata_file_data, + ) + + def __str__(self) -> str: + if self.requires_python: + rp = f" (requires-python:{self.requires_python})" + else: + rp = "" + if self.comes_from: + return f"{redact_auth_from_url(self._url)} (from {self.comes_from}){rp}" + else: + return redact_auth_from_url(str(self._url)) + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(self.url) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Link): + return NotImplemented + return self.url == other.url + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Link): + return NotImplemented + return self.url < other.url + + @property + def url(self) -> str: + return self._url + + @property + def filename(self) -> str: + path = self.path.rstrip("/") + name = posixpath.basename(path) + if not name: + # Make sure we don't leak auth information if the netloc + # includes a username and password. + netloc, user_pass = split_auth_from_netloc(self.netloc) + return netloc + + name = urllib.parse.unquote(name) + assert name, f"URL {self._url!r} produced no filename" + return name + + @property + def file_path(self) -> str: + return url_to_path(self.url) + + @property + def scheme(self) -> str: + return self._parsed_url.scheme + + @property + def netloc(self) -> str: + """ + This can contain auth information. + """ + return self._parsed_url.netloc + + @property + def path(self) -> str: + return self._path + + def splitext(self) -> Tuple[str, str]: + return splitext(posixpath.basename(self.path.rstrip("/"))) + + @property + def ext(self) -> str: + return self.splitext()[1] + + @property + def url_without_fragment(self) -> str: + scheme, netloc, path, query, fragment = self._parsed_url + return urllib.parse.urlunsplit((scheme, netloc, path, query, "")) + + _egg_fragment_re = re.compile(r"[#&]egg=([^&]*)") + + # Per PEP 508. + _project_name_re = re.compile( + r"^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$", re.IGNORECASE + ) + + def _egg_fragment(self) -> Optional[str]: + match = self._egg_fragment_re.search(self._url) + if not match: + return None + + # An egg fragment looks like a PEP 508 project name, along with + # an optional extras specifier. Anything else is invalid. + project_name = match.group(1) + if not self._project_name_re.match(project_name): + deprecated( + reason=f"{self} contains an egg fragment with a non-PEP 508 name.", + replacement="to use the req @ url syntax, and remove the egg fragment", + gone_in="25.1", + issue=13157, + ) + + return project_name + + _subdirectory_fragment_re = re.compile(r"[#&]subdirectory=([^&]*)") + + @property + def subdirectory_fragment(self) -> Optional[str]: + match = self._subdirectory_fragment_re.search(self._url) + if not match: + return None + return match.group(1) + + def metadata_link(self) -> Optional["Link"]: + """Return a link to the associated core metadata file (if any).""" + if self.metadata_file_data is None: + return None + metadata_url = f"{self.url_without_fragment}.metadata" + if self.metadata_file_data.hashes is None: + return Link(metadata_url) + return Link(metadata_url, hashes=self.metadata_file_data.hashes) + + def as_hashes(self) -> Hashes: + return Hashes({k: [v] for k, v in self._hashes.items()}) + + @property + def hash(self) -> Optional[str]: + return next(iter(self._hashes.values()), None) + + @property + def hash_name(self) -> Optional[str]: + return next(iter(self._hashes), None) + + @property + def show_url(self) -> str: + return posixpath.basename(self._url.split("#", 1)[0].split("?", 1)[0]) + + @property + def is_file(self) -> bool: + return self.scheme == "file" + + def is_existing_dir(self) -> bool: + return self.is_file and os.path.isdir(self.file_path) + + @property + def is_wheel(self) -> bool: + return self.ext == WHEEL_EXTENSION + + @property + def is_vcs(self) -> bool: + from pip._internal.vcs import vcs + + return self.scheme in vcs.all_schemes + + @property + def is_yanked(self) -> bool: + return self.yanked_reason is not None + + @property + def has_hash(self) -> bool: + return bool(self._hashes) + + def is_hash_allowed(self, hashes: Optional[Hashes]) -> bool: + """ + Return True if the link has a hash and it is allowed by `hashes`. + """ + if hashes is None: + return False + return any(hashes.is_hash_allowed(k, v) for k, v in self._hashes.items()) + + +class _CleanResult(NamedTuple): + """Convert link for equivalency check. + + This is used in the resolver to check whether two URL-specified requirements + likely point to the same distribution and can be considered equivalent. This + equivalency logic avoids comparing URLs literally, which can be too strict + (e.g. "a=1&b=2" vs "b=2&a=1") and produce conflicts unexpecting to users. + + Currently this does three things: + + 1. Drop the basic auth part. This is technically wrong since a server can + serve different content based on auth, but if it does that, it is even + impossible to guarantee two URLs without auth are equivalent, since + the user can input different auth information when prompted. So the + practical solution is to assume the auth doesn't affect the response. + 2. Parse the query to avoid the ordering issue. Note that ordering under the + same key in the query are NOT cleaned; i.e. "a=1&a=2" and "a=2&a=1" are + still considered different. + 3. Explicitly drop most of the fragment part, except ``subdirectory=`` and + hash values, since it should have no impact the downloaded content. Note + that this drops the "egg=" part historically used to denote the requested + project (and extras), which is wrong in the strictest sense, but too many + people are supplying it inconsistently to cause superfluous resolution + conflicts, so we choose to also ignore them. + """ + + parsed: urllib.parse.SplitResult + query: Dict[str, List[str]] + subdirectory: str + hashes: Dict[str, str] + + +def _clean_link(link: Link) -> _CleanResult: + parsed = link._parsed_url + netloc = parsed.netloc.rsplit("@", 1)[-1] + # According to RFC 8089, an empty host in file: means localhost. + if parsed.scheme == "file" and not netloc: + netloc = "localhost" + fragment = urllib.parse.parse_qs(parsed.fragment) + if "egg" in fragment: + logger.debug("Ignoring egg= fragment in %s", link) + try: + # If there are multiple subdirectory values, use the first one. + # This matches the behavior of Link.subdirectory_fragment. + subdirectory = fragment["subdirectory"][0] + except (IndexError, KeyError): + subdirectory = "" + # If there are multiple hash values under the same algorithm, use the + # first one. This matches the behavior of Link.hash_value. + hashes = {k: fragment[k][0] for k in _SUPPORTED_HASHES if k in fragment} + return _CleanResult( + parsed=parsed._replace(netloc=netloc, query="", fragment=""), + query=urllib.parse.parse_qs(parsed.query), + subdirectory=subdirectory, + hashes=hashes, + ) + + +@functools.lru_cache(maxsize=None) +def links_equivalent(link1: Link, link2: Link) -> bool: + return _clean_link(link1) == _clean_link(link2) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/scheme.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/scheme.py new file mode 100644 index 0000000000000000000000000000000000000000..06a9a550e34389c27ad3ee0bcef73d581cd4b448 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/scheme.py @@ -0,0 +1,25 @@ +""" +For types associated with installation schemes. + +For a general overview of available schemes and their context, see +https://docs.python.org/3/install/index.html#alternate-installation. +""" + +from dataclasses import dataclass + +SCHEME_KEYS = ["platlib", "purelib", "headers", "scripts", "data"] + + +@dataclass(frozen=True) +class Scheme: + """A Scheme holds paths which are used as the base directories for + artifacts associated with a Python package. + """ + + __slots__ = SCHEME_KEYS + + platlib: str + purelib: str + headers: str + scripts: str + data: str diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/search_scope.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/search_scope.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7bc86229acda0378707431e5b4e9f054305d85 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/search_scope.py @@ -0,0 +1,127 @@ +import itertools +import logging +import os +import posixpath +import urllib.parse +from dataclasses import dataclass +from typing import List + +from pip._vendor.packaging.utils import canonicalize_name + +from pip._internal.models.index import PyPI +from pip._internal.utils.compat import has_tls +from pip._internal.utils.misc import normalize_path, redact_auth_from_url + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class SearchScope: + """ + Encapsulates the locations that pip is configured to search. + """ + + __slots__ = ["find_links", "index_urls", "no_index"] + + find_links: List[str] + index_urls: List[str] + no_index: bool + + @classmethod + def create( + cls, + find_links: List[str], + index_urls: List[str], + no_index: bool, + ) -> "SearchScope": + """ + Create a SearchScope object after normalizing the `find_links`. + """ + # Build find_links. If an argument starts with ~, it may be + # a local file relative to a home directory. So try normalizing + # it and if it exists, use the normalized version. + # This is deliberately conservative - it might be fine just to + # blindly normalize anything starting with a ~... + built_find_links: List[str] = [] + for link in find_links: + if link.startswith("~"): + new_link = normalize_path(link) + if os.path.exists(new_link): + link = new_link + built_find_links.append(link) + + # If we don't have TLS enabled, then WARN if anyplace we're looking + # relies on TLS. + if not has_tls(): + for link in itertools.chain(index_urls, built_find_links): + parsed = urllib.parse.urlparse(link) + if parsed.scheme == "https": + logger.warning( + "pip is configured with locations that require " + "TLS/SSL, however the ssl module in Python is not " + "available." + ) + break + + return cls( + find_links=built_find_links, + index_urls=index_urls, + no_index=no_index, + ) + + def get_formatted_locations(self) -> str: + lines = [] + redacted_index_urls = [] + if self.index_urls and self.index_urls != [PyPI.simple_url]: + for url in self.index_urls: + redacted_index_url = redact_auth_from_url(url) + + # Parse the URL + purl = urllib.parse.urlsplit(redacted_index_url) + + # URL is generally invalid if scheme and netloc is missing + # there are issues with Python and URL parsing, so this test + # is a bit crude. See bpo-20271, bpo-23505. Python doesn't + # always parse invalid URLs correctly - it should raise + # exceptions for malformed URLs + if not purl.scheme and not purl.netloc: + logger.warning( + 'The index url "%s" seems invalid, please provide a scheme.', + redacted_index_url, + ) + + redacted_index_urls.append(redacted_index_url) + + lines.append( + "Looking in indexes: {}".format(", ".join(redacted_index_urls)) + ) + + if self.find_links: + lines.append( + "Looking in links: {}".format( + ", ".join(redact_auth_from_url(url) for url in self.find_links) + ) + ) + return "\n".join(lines) + + def get_index_urls_locations(self, project_name: str) -> List[str]: + """Returns the locations found via self.index_urls + + Checks the url_name on the main (first in the list) index and + use this url_name to produce all locations + """ + + def mkurl_pypi_url(url: str) -> str: + loc = posixpath.join( + url, urllib.parse.quote(canonicalize_name(project_name)) + ) + # For maximum compatibility with easy_install, ensure the path + # ends in a trailing slash. Although this isn't in the spec + # (and PyPI can handle it without the slash) some other index + # implementations might break if they relied on easy_install's + # behavior. + if not loc.endswith("/"): + loc = loc + "/" + return loc + + return [mkurl_pypi_url(url) for url in self.index_urls] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/selection_prefs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/selection_prefs.py new file mode 100644 index 0000000000000000000000000000000000000000..e9b50aa51756719d751ed0338aa7ca0a33d45f5a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/selection_prefs.py @@ -0,0 +1,53 @@ +from typing import Optional + +from pip._internal.models.format_control import FormatControl + + +# TODO: This needs Python 3.10's improved slots support for dataclasses +# to be converted into a dataclass. +class SelectionPreferences: + """ + Encapsulates the candidate selection preferences for downloading + and installing files. + """ + + __slots__ = [ + "allow_yanked", + "allow_all_prereleases", + "format_control", + "prefer_binary", + "ignore_requires_python", + ] + + # Don't include an allow_yanked default value to make sure each call + # site considers whether yanked releases are allowed. This also causes + # that decision to be made explicit in the calling code, which helps + # people when reading the code. + def __init__( + self, + allow_yanked: bool, + allow_all_prereleases: bool = False, + format_control: Optional[FormatControl] = None, + prefer_binary: bool = False, + ignore_requires_python: Optional[bool] = None, + ) -> None: + """Create a SelectionPreferences object. + + :param allow_yanked: Whether files marked as yanked (in the sense + of PEP 592) are permitted to be candidates for install. + :param format_control: A FormatControl object or None. Used to control + the selection of source packages / binary packages when consulting + the index and links. + :param prefer_binary: Whether to prefer an old, but valid, binary + dist over a new source dist. + :param ignore_requires_python: Whether to ignore incompatible + "Requires-Python" values in links. Defaults to False. + """ + if ignore_requires_python is None: + ignore_requires_python = False + + self.allow_yanked = allow_yanked + self.allow_all_prereleases = allow_all_prereleases + self.format_control = format_control + self.prefer_binary = prefer_binary + self.ignore_requires_python = ignore_requires_python diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/target_python.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/target_python.py new file mode 100644 index 0000000000000000000000000000000000000000..88925a9fd01a440e6de970bc234c3503b7f09cc1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/target_python.py @@ -0,0 +1,121 @@ +import sys +from typing import List, Optional, Set, Tuple + +from pip._vendor.packaging.tags import Tag + +from pip._internal.utils.compatibility_tags import get_supported, version_info_to_nodot +from pip._internal.utils.misc import normalize_version_info + + +class TargetPython: + """ + Encapsulates the properties of a Python interpreter one is targeting + for a package install, download, etc. + """ + + __slots__ = [ + "_given_py_version_info", + "abis", + "implementation", + "platforms", + "py_version", + "py_version_info", + "_valid_tags", + "_valid_tags_set", + ] + + def __init__( + self, + platforms: Optional[List[str]] = None, + py_version_info: Optional[Tuple[int, ...]] = None, + abis: Optional[List[str]] = None, + implementation: Optional[str] = None, + ) -> None: + """ + :param platforms: A list of strings or None. If None, searches for + packages that are supported by the current system. Otherwise, will + find packages that can be built on the platforms passed in. These + packages will only be downloaded for distribution: they will + not be built locally. + :param py_version_info: An optional tuple of ints representing the + Python version information to use (e.g. `sys.version_info[:3]`). + This can have length 1, 2, or 3 when provided. + :param abis: A list of strings or None. This is passed to + compatibility_tags.py's get_supported() function as is. + :param implementation: A string or None. This is passed to + compatibility_tags.py's get_supported() function as is. + """ + # Store the given py_version_info for when we call get_supported(). + self._given_py_version_info = py_version_info + + if py_version_info is None: + py_version_info = sys.version_info[:3] + else: + py_version_info = normalize_version_info(py_version_info) + + py_version = ".".join(map(str, py_version_info[:2])) + + self.abis = abis + self.implementation = implementation + self.platforms = platforms + self.py_version = py_version + self.py_version_info = py_version_info + + # This is used to cache the return value of get_(un)sorted_tags. + self._valid_tags: Optional[List[Tag]] = None + self._valid_tags_set: Optional[Set[Tag]] = None + + def format_given(self) -> str: + """ + Format the given, non-None attributes for display. + """ + display_version = None + if self._given_py_version_info is not None: + display_version = ".".join( + str(part) for part in self._given_py_version_info + ) + + key_values = [ + ("platforms", self.platforms), + ("version_info", display_version), + ("abis", self.abis), + ("implementation", self.implementation), + ] + return " ".join( + f"{key}={value!r}" for key, value in key_values if value is not None + ) + + def get_sorted_tags(self) -> List[Tag]: + """ + Return the supported PEP 425 tags to check wheel candidates against. + + The tags are returned in order of preference (most preferred first). + """ + if self._valid_tags is None: + # Pass versions=None if no py_version_info was given since + # versions=None uses special default logic. + py_version_info = self._given_py_version_info + if py_version_info is None: + version = None + else: + version = version_info_to_nodot(py_version_info) + + tags = get_supported( + version=version, + platforms=self.platforms, + abis=self.abis, + impl=self.implementation, + ) + self._valid_tags = tags + + return self._valid_tags + + def get_unsorted_tags(self) -> Set[Tag]: + """Exactly the same as get_sorted_tags, but returns a set. + + This is important for performance. + """ + if self._valid_tags_set is None: + self._valid_tags_set = set(self.get_sorted_tags()) + + return self._valid_tags_set diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/wheel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/wheel.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8560089d3df41689f41fe2639aa9f61dd1eace --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/models/wheel.py @@ -0,0 +1,118 @@ +"""Represents a wheel file and provides access to the various parts of the +name that have meaning. +""" + +import re +from typing import Dict, Iterable, List + +from pip._vendor.packaging.tags import Tag +from pip._vendor.packaging.utils import ( + InvalidWheelFilename as PackagingInvalidWheelName, +) +from pip._vendor.packaging.utils import parse_wheel_filename + +from pip._internal.exceptions import InvalidWheelFilename +from pip._internal.utils.deprecation import deprecated + + +class Wheel: + """A wheel file""" + + wheel_file_re = re.compile( + r"""^(?P(?P[^\s-]+?)-(?P[^\s-]*?)) + ((-(?P\d[^-]*?))?-(?P[^\s-]+?)-(?P[^\s-]+?)-(?P[^\s-]+?) + \.whl|\.dist-info)$""", + re.VERBOSE, + ) + + def __init__(self, filename: str) -> None: + """ + :raises InvalidWheelFilename: when the filename is invalid for a wheel + """ + wheel_info = self.wheel_file_re.match(filename) + if not wheel_info: + raise InvalidWheelFilename(f"{filename} is not a valid wheel filename.") + self.filename = filename + self.name = wheel_info.group("name").replace("_", "-") + _version = wheel_info.group("ver") + if "_" in _version: + try: + parse_wheel_filename(filename) + except PackagingInvalidWheelName as e: + deprecated( + reason=( + f"Wheel filename {filename!r} is not correctly normalised. " + "Future versions of pip will raise the following error:\n" + f"{e.args[0]}\n\n" + ), + replacement=( + "to rename the wheel to use a correctly normalised " + "name (this may require updating the version in " + "the project metadata)" + ), + gone_in="25.1", + issue=12938, + ) + + _version = _version.replace("_", "-") + + self.version = _version + self.build_tag = wheel_info.group("build") + self.pyversions = wheel_info.group("pyver").split(".") + self.abis = wheel_info.group("abi").split(".") + self.plats = wheel_info.group("plat").split(".") + + # All the tag combinations from this file + self.file_tags = { + Tag(x, y, z) for x in self.pyversions for y in self.abis for z in self.plats + } + + def get_formatted_file_tags(self) -> List[str]: + """Return the wheel's tags as a sorted list of strings.""" + return sorted(str(tag) for tag in self.file_tags) + + def support_index_min(self, tags: List[Tag]) -> int: + """Return the lowest index that one of the wheel's file_tag combinations + achieves in the given list of supported tags. + + For example, if there are 8 supported tags and one of the file tags + is first in the list, then return 0. + + :param tags: the PEP 425 tags to check the wheel against, in order + with most preferred first. + + :raises ValueError: If none of the wheel's file tags match one of + the supported tags. + """ + try: + return next(i for i, t in enumerate(tags) if t in self.file_tags) + except StopIteration: + raise ValueError() + + def find_most_preferred_tag( + self, tags: List[Tag], tag_to_priority: Dict[Tag, int] + ) -> int: + """Return the priority of the most preferred tag that one of the wheel's file + tag combinations achieves in the given list of supported tags using the given + tag_to_priority mapping, where lower priorities are more-preferred. + + This is used in place of support_index_min in some cases in order to avoid + an expensive linear scan of a large list of tags. + + :param tags: the PEP 425 tags to check the wheel against. + :param tag_to_priority: a mapping from tag to priority of that tag, where + lower is more preferred. + + :raises ValueError: If none of the wheel's file tags match one of + the supported tags. + """ + return min( + tag_to_priority[tag] for tag in self.file_tags if tag in tag_to_priority + ) + + def supported(self, tags: Iterable[Tag]) -> bool: + """Return whether the wheel is compatible with one of the given tags. + + :param tags: the PEP 425 tags to check the wheel against. + """ + return not self.file_tags.isdisjoint(tags) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/pyproject.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/pyproject.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8452f39dca0ab98888ab65c4ff34e4ac94f1c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/pyproject.py @@ -0,0 +1,185 @@ +import importlib.util +import os +import sys +from collections import namedtuple +from typing import Any, List, Optional + +if sys.version_info >= (3, 11): + import tomllib +else: + from pip._vendor import tomli as tomllib + +from pip._vendor.packaging.requirements import InvalidRequirement + +from pip._internal.exceptions import ( + InstallationError, + InvalidPyProjectBuildRequires, + MissingPyProjectBuildRequires, +) +from pip._internal.utils.packaging import get_requirement + + +def _is_list_of_str(obj: Any) -> bool: + return isinstance(obj, list) and all(isinstance(item, str) for item in obj) + + +def make_pyproject_path(unpacked_source_directory: str) -> str: + return os.path.join(unpacked_source_directory, "pyproject.toml") + + +BuildSystemDetails = namedtuple( + "BuildSystemDetails", ["requires", "backend", "check", "backend_path"] +) + + +def load_pyproject_toml( + use_pep517: Optional[bool], pyproject_toml: str, setup_py: str, req_name: str +) -> Optional[BuildSystemDetails]: + """Load the pyproject.toml file. + + Parameters: + use_pep517 - Has the user requested PEP 517 processing? None + means the user hasn't explicitly specified. + pyproject_toml - Location of the project's pyproject.toml file + setup_py - Location of the project's setup.py file + req_name - The name of the requirement we're processing (for + error reporting) + + Returns: + None if we should use the legacy code path, otherwise a tuple + ( + requirements from pyproject.toml, + name of PEP 517 backend, + requirements we should check are installed after setting + up the build environment + directory paths to import the backend from (backend-path), + relative to the project root. + ) + """ + has_pyproject = os.path.isfile(pyproject_toml) + has_setup = os.path.isfile(setup_py) + + if not has_pyproject and not has_setup: + raise InstallationError( + f"{req_name} does not appear to be a Python project: " + f"neither 'setup.py' nor 'pyproject.toml' found." + ) + + if has_pyproject: + with open(pyproject_toml, encoding="utf-8") as f: + pp_toml = tomllib.loads(f.read()) + build_system = pp_toml.get("build-system") + else: + build_system = None + + # The following cases must use PEP 517 + # We check for use_pep517 being non-None and falsy because that means + # the user explicitly requested --no-use-pep517. The value 0 as + # opposed to False can occur when the value is provided via an + # environment variable or config file option (due to the quirk of + # strtobool() returning an integer in pip's configuration code). + if has_pyproject and not has_setup: + if use_pep517 is not None and not use_pep517: + raise InstallationError( + "Disabling PEP 517 processing is invalid: " + "project does not have a setup.py" + ) + use_pep517 = True + elif build_system and "build-backend" in build_system: + if use_pep517 is not None and not use_pep517: + raise InstallationError( + "Disabling PEP 517 processing is invalid: " + "project specifies a build backend of {} " + "in pyproject.toml".format(build_system["build-backend"]) + ) + use_pep517 = True + + # If we haven't worked out whether to use PEP 517 yet, + # and the user hasn't explicitly stated a preference, + # we do so if the project has a pyproject.toml file + # or if we cannot import setuptools or wheels. + + # We fallback to PEP 517 when without setuptools or without the wheel package, + # so setuptools can be installed as a default build backend. + # For more info see: + # https://discuss.python.org/t/pip-without-setuptools-could-the-experience-be-improved/11810/9 + # https://github.com/pypa/pip/issues/8559 + elif use_pep517 is None: + use_pep517 = ( + has_pyproject + or not importlib.util.find_spec("setuptools") + or not importlib.util.find_spec("wheel") + ) + + # At this point, we know whether we're going to use PEP 517. + assert use_pep517 is not None + + # If we're using the legacy code path, there is nothing further + # for us to do here. + if not use_pep517: + return None + + if build_system is None: + # Either the user has a pyproject.toml with no build-system + # section, or the user has no pyproject.toml, but has opted in + # explicitly via --use-pep517. + # In the absence of any explicit backend specification, we + # assume the setuptools backend that most closely emulates the + # traditional direct setup.py execution, and require wheel and + # a version of setuptools that supports that backend. + + build_system = { + "requires": ["setuptools>=40.8.0"], + "build-backend": "setuptools.build_meta:__legacy__", + } + + # If we're using PEP 517, we have build system information (either + # from pyproject.toml, or defaulted by the code above). + # Note that at this point, we do not know if the user has actually + # specified a backend, though. + assert build_system is not None + + # Ensure that the build-system section in pyproject.toml conforms + # to PEP 518. + + # Specifying the build-system table but not the requires key is invalid + if "requires" not in build_system: + raise MissingPyProjectBuildRequires(package=req_name) + + # Error out if requires is not a list of strings + requires = build_system["requires"] + if not _is_list_of_str(requires): + raise InvalidPyProjectBuildRequires( + package=req_name, + reason="It is not a list of strings.", + ) + + # Each requirement must be valid as per PEP 508 + for requirement in requires: + try: + get_requirement(requirement) + except InvalidRequirement as error: + raise InvalidPyProjectBuildRequires( + package=req_name, + reason=f"It contains an invalid requirement: {requirement!r}", + ) from error + + backend = build_system.get("build-backend") + backend_path = build_system.get("backend-path", []) + check: List[str] = [] + if backend is None: + # If the user didn't specify a backend, we assume they want to use + # the setuptools backend. But we can't be sure they have included + # a version of setuptools which supplies the backend. So we + # make a note to check that this requirement is present once + # we have set up the environment. + # This is quite a lot of work to check for a very specific case. But + # the problem is, that case is potentially quite common - projects that + # adopted PEP 518 early for the ability to specify requirements to + # execute setup.py, but never considered needing to mention the build + # tools themselves. The original PEP 518 code had a similar check (but + # implemented in a different way). + backend = "setuptools.build_meta:__legacy__" + check = ["setuptools>=40.8.0"] + + return BuildSystemDetails(requires, backend, check, backend_path) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/resolution/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/resolution/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/resolution/base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/resolution/base.py new file mode 100644 index 0000000000000000000000000000000000000000..42dade18c1ec2b825f756dad4aaa89f2d9e6ce21 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/resolution/base.py @@ -0,0 +1,20 @@ +from typing import Callable, List, Optional + +from pip._internal.req.req_install import InstallRequirement +from pip._internal.req.req_set import RequirementSet + +InstallRequirementProvider = Callable[ + [str, Optional[InstallRequirement]], InstallRequirement +] + + +class BaseResolver: + def resolve( + self, root_reqs: List[InstallRequirement], check_supported_wheels: bool + ) -> RequirementSet: + raise NotImplementedError() + + def get_installation_order( + self, req_set: RequirementSet + ) -> List[InstallRequirement]: + raise NotImplementedError() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/self_outdated_check.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/self_outdated_check.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0e3df3542b54701a0c7117608511ad0db42848 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/self_outdated_check.py @@ -0,0 +1,252 @@ +import datetime +import functools +import hashlib +import json +import logging +import optparse +import os.path +import sys +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional + +from pip._vendor.packaging.version import Version +from pip._vendor.packaging.version import parse as parse_version +from pip._vendor.rich.console import Group +from pip._vendor.rich.markup import escape +from pip._vendor.rich.text import Text + +from pip._internal.index.collector import LinkCollector +from pip._internal.index.package_finder import PackageFinder +from pip._internal.metadata import get_default_environment +from pip._internal.models.selection_prefs import SelectionPreferences +from pip._internal.network.session import PipSession +from pip._internal.utils.compat import WINDOWS +from pip._internal.utils.entrypoints import ( + get_best_invocation_for_this_pip, + get_best_invocation_for_this_python, +) +from pip._internal.utils.filesystem import adjacent_tmp_file, check_path_owner, replace +from pip._internal.utils.misc import ( + ExternallyManagedEnvironment, + check_externally_managed, + ensure_dir, +) + +_WEEK = datetime.timedelta(days=7) + +logger = logging.getLogger(__name__) + + +def _get_statefile_name(key: str) -> str: + key_bytes = key.encode() + name = hashlib.sha224(key_bytes).hexdigest() + return name + + +def _convert_date(isodate: str) -> datetime.datetime: + """Convert an ISO format string to a date. + + Handles the format 2020-01-22T14:24:01Z (trailing Z) + which is not supported by older versions of fromisoformat. + """ + return datetime.datetime.fromisoformat(isodate.replace("Z", "+00:00")) + + +class SelfCheckState: + def __init__(self, cache_dir: str) -> None: + self._state: Dict[str, Any] = {} + self._statefile_path = None + + # Try to load the existing state + if cache_dir: + self._statefile_path = os.path.join( + cache_dir, "selfcheck", _get_statefile_name(self.key) + ) + try: + with open(self._statefile_path, encoding="utf-8") as statefile: + self._state = json.load(statefile) + except (OSError, ValueError, KeyError): + # Explicitly suppressing exceptions, since we don't want to + # error out if the cache file is invalid. + pass + + @property + def key(self) -> str: + return sys.prefix + + def get(self, current_time: datetime.datetime) -> Optional[str]: + """Check if we have a not-outdated version loaded already.""" + if not self._state: + return None + + if "last_check" not in self._state: + return None + + if "pypi_version" not in self._state: + return None + + # Determine if we need to refresh the state + last_check = _convert_date(self._state["last_check"]) + time_since_last_check = current_time - last_check + if time_since_last_check > _WEEK: + return None + + return self._state["pypi_version"] + + def set(self, pypi_version: str, current_time: datetime.datetime) -> None: + # If we do not have a path to cache in, don't bother saving. + if not self._statefile_path: + return + + # Check to make sure that we own the directory + if not check_path_owner(os.path.dirname(self._statefile_path)): + return + + # Now that we've ensured the directory is owned by this user, we'll go + # ahead and make sure that all our directories are created. + ensure_dir(os.path.dirname(self._statefile_path)) + + state = { + # Include the key so it's easy to tell which pip wrote the + # file. + "key": self.key, + "last_check": current_time.isoformat(), + "pypi_version": pypi_version, + } + + text = json.dumps(state, sort_keys=True, separators=(",", ":")) + + with adjacent_tmp_file(self._statefile_path) as f: + f.write(text.encode()) + + try: + # Since we have a prefix-specific state file, we can just + # overwrite whatever is there, no need to check. + replace(f.name, self._statefile_path) + except OSError: + # Best effort. + pass + + +@dataclass +class UpgradePrompt: + old: str + new: str + + def __rich__(self) -> Group: + if WINDOWS: + pip_cmd = f"{get_best_invocation_for_this_python()} -m pip" + else: + pip_cmd = get_best_invocation_for_this_pip() + + notice = "[bold][[reset][blue]notice[reset][bold]][reset]" + return Group( + Text(), + Text.from_markup( + f"{notice} A new release of pip is available: " + f"[red]{self.old}[reset] -> [green]{self.new}[reset]" + ), + Text.from_markup( + f"{notice} To update, run: " + f"[green]{escape(pip_cmd)} install --upgrade pip" + ), + ) + + +def was_installed_by_pip(pkg: str) -> bool: + """Checks whether pkg was installed by pip + + This is used not to display the upgrade message when pip is in fact + installed by system package manager, such as dnf on Fedora. + """ + dist = get_default_environment().get_distribution(pkg) + return dist is not None and "pip" == dist.installer + + +def _get_current_remote_pip_version( + session: PipSession, options: optparse.Values +) -> Optional[str]: + # Lets use PackageFinder to see what the latest pip version is + link_collector = LinkCollector.create( + session, + options=options, + suppress_no_index=True, + ) + + # Pass allow_yanked=False so we don't suggest upgrading to a + # yanked version. + selection_prefs = SelectionPreferences( + allow_yanked=False, + allow_all_prereleases=False, # Explicitly set to False + ) + + finder = PackageFinder.create( + link_collector=link_collector, + selection_prefs=selection_prefs, + ) + best_candidate = finder.find_best_candidate("pip").best_candidate + if best_candidate is None: + return None + + return str(best_candidate.version) + + +def _self_version_check_logic( + *, + state: SelfCheckState, + current_time: datetime.datetime, + local_version: Version, + get_remote_version: Callable[[], Optional[str]], +) -> Optional[UpgradePrompt]: + remote_version_str = state.get(current_time) + if remote_version_str is None: + remote_version_str = get_remote_version() + if remote_version_str is None: + logger.debug("No remote pip version found") + return None + state.set(remote_version_str, current_time) + + remote_version = parse_version(remote_version_str) + logger.debug("Remote version of pip: %s", remote_version) + logger.debug("Local version of pip: %s", local_version) + + pip_installed_by_pip = was_installed_by_pip("pip") + logger.debug("Was pip installed by pip? %s", pip_installed_by_pip) + if not pip_installed_by_pip: + return None # Only suggest upgrade if pip is installed by pip. + + local_version_is_older = ( + local_version < remote_version + and local_version.base_version != remote_version.base_version + ) + if local_version_is_older: + return UpgradePrompt(old=str(local_version), new=remote_version_str) + + return None + + +def pip_self_version_check(session: PipSession, options: optparse.Values) -> None: + """Check for an update for pip. + + Limit the frequency of checks to once per week. State is stored either in + the active virtualenv or in the user's USER_CACHE_DIR keyed off the prefix + of the pip script path. + """ + installed_dist = get_default_environment().get_distribution("pip") + if not installed_dist: + return + try: + check_externally_managed() + except ExternallyManagedEnvironment: + return + + upgrade_prompt = _self_version_check_logic( + state=SelfCheckState(cache_dir=options.cache_dir), + current_time=datetime.datetime.now(datetime.timezone.utc), + local_version=installed_dist.version, + get_remote_version=functools.partial( + _get_current_remote_pip_version, session, options + ), + ) + if upgrade_prompt is not None: + logger.warning("%s", upgrade_prompt, extra={"rich": True}) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/_jaraco_text.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/_jaraco_text.py new file mode 100644 index 0000000000000000000000000000000000000000..6ccf53b7ac5d415b8526e75ccabe31cf994ac7da --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/_jaraco_text.py @@ -0,0 +1,109 @@ +"""Functions brought over from jaraco.text. + +These functions are not supposed to be used within `pip._internal`. These are +helper functions brought over from `jaraco.text` to enable vendoring newer +copies of `pkg_resources` without having to vendor `jaraco.text` and its entire +dependency cone; something that our vendoring setup is not currently capable of +handling. + +License reproduced from original source below: + +Copyright Jason R. Coombs + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to +deal in the Software without restriction, including without limitation the +rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +IN THE SOFTWARE. +""" + +import functools +import itertools + + +def _nonblank(str): + return str and not str.startswith("#") + + +@functools.singledispatch +def yield_lines(iterable): + r""" + Yield valid lines of a string or iterable. + + >>> list(yield_lines('')) + [] + >>> list(yield_lines(['foo', 'bar'])) + ['foo', 'bar'] + >>> list(yield_lines('foo\nbar')) + ['foo', 'bar'] + >>> list(yield_lines('\nfoo\n#bar\nbaz #comment')) + ['foo', 'baz #comment'] + >>> list(yield_lines(['foo\nbar', 'baz', 'bing\n\n\n'])) + ['foo', 'bar', 'baz', 'bing'] + """ + return itertools.chain.from_iterable(map(yield_lines, iterable)) + + +@yield_lines.register(str) +def _(text): + return filter(_nonblank, map(str.strip, text.splitlines())) + + +def drop_comment(line): + """ + Drop comments. + + >>> drop_comment('foo # bar') + 'foo' + + A hash without a space may be in a URL. + + >>> drop_comment('http://example.com/foo#bar') + 'http://example.com/foo#bar' + """ + return line.partition(" #")[0] + + +def join_continuation(lines): + r""" + Join lines continued by a trailing backslash. + + >>> list(join_continuation(['foo \\', 'bar', 'baz'])) + ['foobar', 'baz'] + >>> list(join_continuation(['foo \\', 'bar', 'baz'])) + ['foobar', 'baz'] + >>> list(join_continuation(['foo \\', 'bar \\', 'baz'])) + ['foobarbaz'] + + Not sure why, but... + The character preceding the backslash is also elided. + + >>> list(join_continuation(['goo\\', 'dly'])) + ['godly'] + + A terrible idea, but... + If no line is available to continue, suppress the lines. + + >>> list(join_continuation(['foo', 'bar\\', 'baz\\'])) + ['foo'] + """ + lines = iter(lines) + for item in lines: + while item.endswith("\\"): + try: + item = item[:-2].strip() + next(lines) + except StopIteration: + return + yield item diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/_log.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/_log.py new file mode 100644 index 0000000000000000000000000000000000000000..92c4c6a193873ce09629f6cfaa2dabc4f14ecb03 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/_log.py @@ -0,0 +1,38 @@ +"""Customize logging + +Defines custom logger class for the `logger.verbose(...)` method. + +init_logging() must be called before any other modules that call logging.getLogger. +""" + +import logging +from typing import Any, cast + +# custom log level for `--verbose` output +# between DEBUG and INFO +VERBOSE = 15 + + +class VerboseLogger(logging.Logger): + """Custom Logger, defining a verbose log-level + + VERBOSE is between INFO and DEBUG. + """ + + def verbose(self, msg: str, *args: Any, **kwargs: Any) -> None: + return self.log(VERBOSE, msg, *args, **kwargs) + + +def getLogger(name: str) -> VerboseLogger: + """logging.getLogger, but ensures our VerboseLogger class is returned""" + return cast(VerboseLogger, logging.getLogger(name)) + + +def init_logging() -> None: + """Register our VerboseLogger and VERBOSE log level. + + Should be called before any calls to getLogger(), + i.e. in pip._internal.__init__ + """ + logging.setLoggerClass(VerboseLogger) + logging.addLevelName(VERBOSE, "VERBOSE") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/appdirs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/appdirs.py new file mode 100644 index 0000000000000000000000000000000000000000..16933bf8afedcbe3e9d4fcc04e5f7246228c56fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/appdirs.py @@ -0,0 +1,52 @@ +""" +This code wraps the vendored appdirs module to so the return values are +compatible for the current pip code base. + +The intention is to rewrite current usages gradually, keeping the tests pass, +and eventually drop this after all usages are changed. +""" + +import os +import sys +from typing import List + +from pip._vendor import platformdirs as _appdirs + + +def user_cache_dir(appname: str) -> str: + return _appdirs.user_cache_dir(appname, appauthor=False) + + +def _macos_user_config_dir(appname: str, roaming: bool = True) -> str: + # Use ~/Application Support/pip, if the directory exists. + path = _appdirs.user_data_dir(appname, appauthor=False, roaming=roaming) + if os.path.isdir(path): + return path + + # Use a Linux-like ~/.config/pip, by default. + linux_like_path = "~/.config/" + if appname: + linux_like_path = os.path.join(linux_like_path, appname) + + return os.path.expanduser(linux_like_path) + + +def user_config_dir(appname: str, roaming: bool = True) -> str: + if sys.platform == "darwin": + return _macos_user_config_dir(appname, roaming) + + return _appdirs.user_config_dir(appname, appauthor=False, roaming=roaming) + + +# for the discussion regarding site_config_dir locations +# see +def site_config_dirs(appname: str) -> List[str]: + if sys.platform == "darwin": + return [_appdirs.site_data_dir(appname, appauthor=False, multipath=True)] + + dirval = _appdirs.site_config_dir(appname, appauthor=False, multipath=True) + if sys.platform == "win32": + return [dirval] + + # Unix-y system. Look in /etc as well. + return dirval.split(os.pathsep) + ["/etc"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/compat.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b54e4ee51d03a7beca065971967b9c70cc3526 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/compat.py @@ -0,0 +1,79 @@ +"""Stuff that differs in different Python versions and platform +distributions.""" + +import importlib.resources +import logging +import os +import sys +from typing import IO + +__all__ = ["get_path_uid", "stdlib_pkgs", "WINDOWS"] + + +logger = logging.getLogger(__name__) + + +def has_tls() -> bool: + try: + import _ssl # noqa: F401 # ignore unused + + return True + except ImportError: + pass + + from pip._vendor.urllib3.util import IS_PYOPENSSL + + return IS_PYOPENSSL + + +def get_path_uid(path: str) -> int: + """ + Return path's uid. + + Does not follow symlinks: + https://github.com/pypa/pip/pull/935#discussion_r5307003 + + Placed this function in compat due to differences on AIX and + Jython, that should eventually go away. + + :raises OSError: When path is a symlink or can't be read. + """ + if hasattr(os, "O_NOFOLLOW"): + fd = os.open(path, os.O_RDONLY | os.O_NOFOLLOW) + file_uid = os.fstat(fd).st_uid + os.close(fd) + else: # AIX and Jython + # WARNING: time of check vulnerability, but best we can do w/o NOFOLLOW + if not os.path.islink(path): + # older versions of Jython don't have `os.fstat` + file_uid = os.stat(path).st_uid + else: + # raise OSError for parity with os.O_NOFOLLOW above + raise OSError(f"{path} is a symlink; Will not return uid for symlinks") + return file_uid + + +# The importlib.resources.open_text function was deprecated in 3.11 with suggested +# replacement we use below. +if sys.version_info < (3, 11): + open_text_resource = importlib.resources.open_text +else: + + def open_text_resource( + package: str, resource: str, encoding: str = "utf-8", errors: str = "strict" + ) -> IO[str]: + return (importlib.resources.files(package) / resource).open( + "r", encoding=encoding, errors=errors + ) + + +# packages in the stdlib that may have installation metadata, but should not be +# considered 'installed'. this theoretically could be determined based on +# dist.location (py27:`sysconfig.get_paths()['stdlib']`, +# py26:sysconfig.get_config_vars('LIBDEST')), but fear platform variation may +# make this ineffective, so hard-coding +stdlib_pkgs = {"python", "wsgiref", "argparse"} + + +# windows detection, covers cpython and ironpython +WINDOWS = sys.platform.startswith("win") or (sys.platform == "cli" and os.name == "nt") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/compatibility_tags.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/compatibility_tags.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7b7450dcea5b3bbcfe118f2e4cbe3fc16a7b1a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/compatibility_tags.py @@ -0,0 +1,188 @@ +"""Generate and work with PEP 425 Compatibility Tags. +""" + +import re +from typing import List, Optional, Tuple + +from pip._vendor.packaging.tags import ( + PythonVersion, + Tag, + compatible_tags, + cpython_tags, + generic_tags, + interpreter_name, + interpreter_version, + ios_platforms, + mac_platforms, +) + +_apple_arch_pat = re.compile(r"(.+)_(\d+)_(\d+)_(.+)") + + +def version_info_to_nodot(version_info: Tuple[int, ...]) -> str: + # Only use up to the first two numbers. + return "".join(map(str, version_info[:2])) + + +def _mac_platforms(arch: str) -> List[str]: + match = _apple_arch_pat.match(arch) + if match: + name, major, minor, actual_arch = match.groups() + mac_version = (int(major), int(minor)) + arches = [ + # Since we have always only checked that the platform starts + # with "macosx", for backwards-compatibility we extract the + # actual prefix provided by the user in case they provided + # something like "macosxcustom_". It may be good to remove + # this as undocumented or deprecate it in the future. + "{}_{}".format(name, arch[len("macosx_") :]) + for arch in mac_platforms(mac_version, actual_arch) + ] + else: + # arch pattern didn't match (?!) + arches = [arch] + return arches + + +def _ios_platforms(arch: str) -> List[str]: + match = _apple_arch_pat.match(arch) + if match: + name, major, minor, actual_multiarch = match.groups() + ios_version = (int(major), int(minor)) + arches = [ + # Since we have always only checked that the platform starts + # with "ios", for backwards-compatibility we extract the + # actual prefix provided by the user in case they provided + # something like "ioscustom_". It may be good to remove + # this as undocumented or deprecate it in the future. + "{}_{}".format(name, arch[len("ios_") :]) + for arch in ios_platforms(ios_version, actual_multiarch) + ] + else: + # arch pattern didn't match (?!) + arches = [arch] + return arches + + +def _custom_manylinux_platforms(arch: str) -> List[str]: + arches = [arch] + arch_prefix, arch_sep, arch_suffix = arch.partition("_") + if arch_prefix == "manylinux2014": + # manylinux1/manylinux2010 wheels run on most manylinux2014 systems + # with the exception of wheels depending on ncurses. PEP 599 states + # manylinux1/manylinux2010 wheels should be considered + # manylinux2014 wheels: + # https://www.python.org/dev/peps/pep-0599/#backwards-compatibility-with-manylinux2010-wheels + if arch_suffix in {"i686", "x86_64"}: + arches.append("manylinux2010" + arch_sep + arch_suffix) + arches.append("manylinux1" + arch_sep + arch_suffix) + elif arch_prefix == "manylinux2010": + # manylinux1 wheels run on most manylinux2010 systems with the + # exception of wheels depending on ncurses. PEP 571 states + # manylinux1 wheels should be considered manylinux2010 wheels: + # https://www.python.org/dev/peps/pep-0571/#backwards-compatibility-with-manylinux1-wheels + arches.append("manylinux1" + arch_sep + arch_suffix) + return arches + + +def _get_custom_platforms(arch: str) -> List[str]: + arch_prefix, arch_sep, arch_suffix = arch.partition("_") + if arch.startswith("macosx"): + arches = _mac_platforms(arch) + elif arch.startswith("ios"): + arches = _ios_platforms(arch) + elif arch_prefix in ["manylinux2014", "manylinux2010"]: + arches = _custom_manylinux_platforms(arch) + else: + arches = [arch] + return arches + + +def _expand_allowed_platforms(platforms: Optional[List[str]]) -> Optional[List[str]]: + if not platforms: + return None + + seen = set() + result = [] + + for p in platforms: + if p in seen: + continue + additions = [c for c in _get_custom_platforms(p) if c not in seen] + seen.update(additions) + result.extend(additions) + + return result + + +def _get_python_version(version: str) -> PythonVersion: + if len(version) > 1: + return int(version[0]), int(version[1:]) + else: + return (int(version[0]),) + + +def _get_custom_interpreter( + implementation: Optional[str] = None, version: Optional[str] = None +) -> str: + if implementation is None: + implementation = interpreter_name() + if version is None: + version = interpreter_version() + return f"{implementation}{version}" + + +def get_supported( + version: Optional[str] = None, + platforms: Optional[List[str]] = None, + impl: Optional[str] = None, + abis: Optional[List[str]] = None, +) -> List[Tag]: + """Return a list of supported tags for each version specified in + `versions`. + + :param version: a string version, of the form "33" or "32", + or None. The version will be assumed to support our ABI. + :param platform: specify a list of platforms you want valid + tags for, or None. If None, use the local system platform. + :param impl: specify the exact implementation you want valid + tags for, or None. If None, use the local interpreter impl. + :param abis: specify a list of abis you want valid + tags for, or None. If None, use the local interpreter abi. + """ + supported: List[Tag] = [] + + python_version: Optional[PythonVersion] = None + if version is not None: + python_version = _get_python_version(version) + + interpreter = _get_custom_interpreter(impl, version) + + platforms = _expand_allowed_platforms(platforms) + + is_cpython = (impl or interpreter_name()) == "cp" + if is_cpython: + supported.extend( + cpython_tags( + python_version=python_version, + abis=abis, + platforms=platforms, + ) + ) + else: + supported.extend( + generic_tags( + interpreter=interpreter, + abis=abis, + platforms=platforms, + ) + ) + supported.extend( + compatible_tags( + python_version=python_version, + interpreter=interpreter, + platforms=platforms, + ) + ) + + return supported diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/datetime.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..8668b3b0ec1deec2aeb7ff6bd94265d6705e05bf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/datetime.py @@ -0,0 +1,11 @@ +"""For when pip wants to check the date or time. +""" + +import datetime + + +def today_is_later_than(year: int, month: int, day: int) -> bool: + today = datetime.date.today() + given = datetime.date(year, month, day) + + return today > given diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/deprecation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..0911147e784737f58f174dce98ecae32b615c7b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/deprecation.py @@ -0,0 +1,124 @@ +""" +A module that implements tooling to enable easy warnings about deprecations. +""" + +import logging +import warnings +from typing import Any, Optional, TextIO, Type, Union + +from pip._vendor.packaging.version import parse + +from pip import __version__ as current_version # NOTE: tests patch this name. + +DEPRECATION_MSG_PREFIX = "DEPRECATION: " + + +class PipDeprecationWarning(Warning): + pass + + +_original_showwarning: Any = None + + +# Warnings <-> Logging Integration +def _showwarning( + message: Union[Warning, str], + category: Type[Warning], + filename: str, + lineno: int, + file: Optional[TextIO] = None, + line: Optional[str] = None, +) -> None: + if file is not None: + if _original_showwarning is not None: + _original_showwarning(message, category, filename, lineno, file, line) + elif issubclass(category, PipDeprecationWarning): + # We use a specially named logger which will handle all of the + # deprecation messages for pip. + logger = logging.getLogger("pip._internal.deprecations") + logger.warning(message) + else: + _original_showwarning(message, category, filename, lineno, file, line) + + +def install_warning_logger() -> None: + # Enable our Deprecation Warnings + warnings.simplefilter("default", PipDeprecationWarning, append=True) + + global _original_showwarning + + if _original_showwarning is None: + _original_showwarning = warnings.showwarning + warnings.showwarning = _showwarning + + +def deprecated( + *, + reason: str, + replacement: Optional[str], + gone_in: Optional[str], + feature_flag: Optional[str] = None, + issue: Optional[int] = None, +) -> None: + """Helper to deprecate existing functionality. + + reason: + Textual reason shown to the user about why this functionality has + been deprecated. Should be a complete sentence. + replacement: + Textual suggestion shown to the user about what alternative + functionality they can use. + gone_in: + The version of pip does this functionality should get removed in. + Raises an error if pip's current version is greater than or equal to + this. + feature_flag: + Command-line flag of the form --use-feature={feature_flag} for testing + upcoming functionality. + issue: + Issue number on the tracker that would serve as a useful place for + users to find related discussion and provide feedback. + """ + + # Determine whether or not the feature is already gone in this version. + is_gone = gone_in is not None and parse(current_version) >= parse(gone_in) + + message_parts = [ + (reason, f"{DEPRECATION_MSG_PREFIX}{{}}"), + ( + gone_in, + ( + "pip {} will enforce this behaviour change." + if not is_gone + else "Since pip {}, this is no longer supported." + ), + ), + ( + replacement, + "A possible replacement is {}.", + ), + ( + feature_flag, + ( + "You can use the flag --use-feature={} to test the upcoming behaviour." + if not is_gone + else None + ), + ), + ( + issue, + "Discussion can be found at https://github.com/pypa/pip/issues/{}", + ), + ] + + message = " ".join( + format_str.format(value) + for value, format_str in message_parts + if format_str is not None and value is not None + ) + + # Raise as an error if this behaviour is deprecated. + if is_gone: + raise PipDeprecationWarning(message) + + warnings.warn(message, category=PipDeprecationWarning, stacklevel=2) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/direct_url_helpers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/direct_url_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..66020d3964ad4d8bc55893380383b271642471f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/direct_url_helpers.py @@ -0,0 +1,87 @@ +from typing import Optional + +from pip._internal.models.direct_url import ArchiveInfo, DirectUrl, DirInfo, VcsInfo +from pip._internal.models.link import Link +from pip._internal.utils.urls import path_to_url +from pip._internal.vcs import vcs + + +def direct_url_as_pep440_direct_reference(direct_url: DirectUrl, name: str) -> str: + """Convert a DirectUrl to a pip requirement string.""" + direct_url.validate() # if invalid, this is a pip bug + requirement = name + " @ " + fragments = [] + if isinstance(direct_url.info, VcsInfo): + requirement += ( + f"{direct_url.info.vcs}+{direct_url.url}@{direct_url.info.commit_id}" + ) + elif isinstance(direct_url.info, ArchiveInfo): + requirement += direct_url.url + if direct_url.info.hash: + fragments.append(direct_url.info.hash) + else: + assert isinstance(direct_url.info, DirInfo) + requirement += direct_url.url + if direct_url.subdirectory: + fragments.append("subdirectory=" + direct_url.subdirectory) + if fragments: + requirement += "#" + "&".join(fragments) + return requirement + + +def direct_url_for_editable(source_dir: str) -> DirectUrl: + return DirectUrl( + url=path_to_url(source_dir), + info=DirInfo(editable=True), + ) + + +def direct_url_from_link( + link: Link, source_dir: Optional[str] = None, link_is_in_wheel_cache: bool = False +) -> DirectUrl: + if link.is_vcs: + vcs_backend = vcs.get_backend_for_scheme(link.scheme) + assert vcs_backend + url, requested_revision, _ = vcs_backend.get_url_rev_and_auth( + link.url_without_fragment + ) + # For VCS links, we need to find out and add commit_id. + if link_is_in_wheel_cache: + # If the requested VCS link corresponds to a cached + # wheel, it means the requested revision was an + # immutable commit hash, otherwise it would not have + # been cached. In that case we don't have a source_dir + # with the VCS checkout. + assert requested_revision + commit_id = requested_revision + else: + # If the wheel was not in cache, it means we have + # had to checkout from VCS to build and we have a source_dir + # which we can inspect to find out the commit id. + assert source_dir + commit_id = vcs_backend.get_revision(source_dir) + return DirectUrl( + url=url, + info=VcsInfo( + vcs=vcs_backend.name, + commit_id=commit_id, + requested_revision=requested_revision, + ), + subdirectory=link.subdirectory_fragment, + ) + elif link.is_existing_dir(): + return DirectUrl( + url=link.url_without_fragment, + info=DirInfo(), + subdirectory=link.subdirectory_fragment, + ) + else: + hash = None + hash_name = link.hash_name + if hash_name: + hash = f"{hash_name}={link.hash}" + return DirectUrl( + url=link.url_without_fragment, + info=ArchiveInfo(hash=hash), + subdirectory=link.subdirectory_fragment, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/egg_link.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/egg_link.py new file mode 100644 index 0000000000000000000000000000000000000000..4a384a63682ce53cafcf889551b13b9177a14e44 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/egg_link.py @@ -0,0 +1,80 @@ +import os +import re +import sys +from typing import List, Optional + +from pip._internal.locations import site_packages, user_site +from pip._internal.utils.virtualenv import ( + running_under_virtualenv, + virtualenv_no_global, +) + +__all__ = [ + "egg_link_path_from_sys_path", + "egg_link_path_from_location", +] + + +def _egg_link_names(raw_name: str) -> List[str]: + """ + Convert a Name metadata value to a .egg-link name, by applying + the same substitution as pkg_resources's safe_name function. + Note: we cannot use canonicalize_name because it has a different logic. + + We also look for the raw name (without normalization) as setuptools 69 changed + the way it names .egg-link files (https://github.com/pypa/setuptools/issues/4167). + """ + return [ + re.sub("[^A-Za-z0-9.]+", "-", raw_name) + ".egg-link", + f"{raw_name}.egg-link", + ] + + +def egg_link_path_from_sys_path(raw_name: str) -> Optional[str]: + """ + Look for a .egg-link file for project name, by walking sys.path. + """ + egg_link_names = _egg_link_names(raw_name) + for path_item in sys.path: + for egg_link_name in egg_link_names: + egg_link = os.path.join(path_item, egg_link_name) + if os.path.isfile(egg_link): + return egg_link + return None + + +def egg_link_path_from_location(raw_name: str) -> Optional[str]: + """ + Return the path for the .egg-link file if it exists, otherwise, None. + + There's 3 scenarios: + 1) not in a virtualenv + try to find in site.USER_SITE, then site_packages + 2) in a no-global virtualenv + try to find in site_packages + 3) in a yes-global virtualenv + try to find in site_packages, then site.USER_SITE + (don't look in global location) + + For #1 and #3, there could be odd cases, where there's an egg-link in 2 + locations. + + This method will just return the first one found. + """ + sites: List[str] = [] + if running_under_virtualenv(): + sites.append(site_packages) + if not virtualenv_no_global() and user_site: + sites.append(user_site) + else: + if user_site: + sites.append(user_site) + sites.append(site_packages) + + egg_link_names = _egg_link_names(raw_name) + for site in sites: + for egg_link_name in egg_link_names: + egglink = os.path.join(site, egg_link_name) + if os.path.isfile(egglink): + return egglink + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/entrypoints.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/entrypoints.py new file mode 100644 index 0000000000000000000000000000000000000000..150136938548af6aa5ae1f716b330d0eb2d3e013 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/entrypoints.py @@ -0,0 +1,84 @@ +import itertools +import os +import shutil +import sys +from typing import List, Optional + +from pip._internal.cli.main import main +from pip._internal.utils.compat import WINDOWS + +_EXECUTABLE_NAMES = [ + "pip", + f"pip{sys.version_info.major}", + f"pip{sys.version_info.major}.{sys.version_info.minor}", +] +if WINDOWS: + _allowed_extensions = {"", ".exe"} + _EXECUTABLE_NAMES = [ + "".join(parts) + for parts in itertools.product(_EXECUTABLE_NAMES, _allowed_extensions) + ] + + +def _wrapper(args: Optional[List[str]] = None) -> int: + """Central wrapper for all old entrypoints. + + Historically pip has had several entrypoints defined. Because of issues + arising from PATH, sys.path, multiple Pythons, their interactions, and most + of them having a pip installed, users suffer every time an entrypoint gets + moved. + + To alleviate this pain, and provide a mechanism for warning users and + directing them to an appropriate place for help, we now define all of + our old entrypoints as wrappers for the current one. + """ + sys.stderr.write( + "WARNING: pip is being invoked by an old script wrapper. This will " + "fail in a future version of pip.\n" + "Please see https://github.com/pypa/pip/issues/5599 for advice on " + "fixing the underlying issue.\n" + "To avoid this problem you can invoke Python with '-m pip' instead of " + "running pip directly.\n" + ) + return main(args) + + +def get_best_invocation_for_this_pip() -> str: + """Try to figure out the best way to invoke pip in the current environment.""" + binary_directory = "Scripts" if WINDOWS else "bin" + binary_prefix = os.path.join(sys.prefix, binary_directory) + + # Try to use pip[X[.Y]] names, if those executables for this environment are + # the first on PATH with that name. + path_parts = os.path.normcase(os.environ.get("PATH", "")).split(os.pathsep) + exe_are_in_PATH = os.path.normcase(binary_prefix) in path_parts + if exe_are_in_PATH: + for exe_name in _EXECUTABLE_NAMES: + found_executable = shutil.which(exe_name) + binary_executable = os.path.join(binary_prefix, exe_name) + if ( + found_executable + and os.path.exists(binary_executable) + and os.path.samefile( + found_executable, + binary_executable, + ) + ): + return exe_name + + # Use the `-m` invocation, if there's no "nice" invocation. + return f"{get_best_invocation_for_this_python()} -m pip" + + +def get_best_invocation_for_this_python() -> str: + """Try to figure out the best way to invoke the current Python.""" + exe = sys.executable + exe_name = os.path.basename(exe) + + # Try to use the basename, if it's the first executable. + found_executable = shutil.which(exe_name) + if found_executable and os.path.samefile(found_executable, exe): + return exe_name + + # Use the full executable name, because we couldn't find something simpler. + return exe diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/filesystem.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..22e356cdd75ae69c05c5488d701e978e01c9e7a3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/filesystem.py @@ -0,0 +1,149 @@ +import fnmatch +import os +import os.path +import random +import sys +from contextlib import contextmanager +from tempfile import NamedTemporaryFile +from typing import Any, BinaryIO, Generator, List, Union, cast + +from pip._internal.utils.compat import get_path_uid +from pip._internal.utils.misc import format_size +from pip._internal.utils.retry import retry + + +def check_path_owner(path: str) -> bool: + # If we don't have a way to check the effective uid of this process, then + # we'll just assume that we own the directory. + if sys.platform == "win32" or not hasattr(os, "geteuid"): + return True + + assert os.path.isabs(path) + + previous = None + while path != previous: + if os.path.lexists(path): + # Check if path is writable by current user. + if os.geteuid() == 0: + # Special handling for root user in order to handle properly + # cases where users use sudo without -H flag. + try: + path_uid = get_path_uid(path) + except OSError: + return False + return path_uid == 0 + else: + return os.access(path, os.W_OK) + else: + previous, path = path, os.path.dirname(path) + return False # assume we don't own the path + + +@contextmanager +def adjacent_tmp_file(path: str, **kwargs: Any) -> Generator[BinaryIO, None, None]: + """Return a file-like object pointing to a tmp file next to path. + + The file is created securely and is ensured to be written to disk + after the context reaches its end. + + kwargs will be passed to tempfile.NamedTemporaryFile to control + the way the temporary file will be opened. + """ + with NamedTemporaryFile( + delete=False, + dir=os.path.dirname(path), + prefix=os.path.basename(path), + suffix=".tmp", + **kwargs, + ) as f: + result = cast(BinaryIO, f) + try: + yield result + finally: + result.flush() + os.fsync(result.fileno()) + + +replace = retry(stop_after_delay=1, wait=0.25)(os.replace) + + +# test_writable_dir and _test_writable_dir_win are copied from Flit, +# with the author's agreement to also place them under pip's license. +def test_writable_dir(path: str) -> bool: + """Check if a directory is writable. + + Uses os.access() on POSIX, tries creating files on Windows. + """ + # If the directory doesn't exist, find the closest parent that does. + while not os.path.isdir(path): + parent = os.path.dirname(path) + if parent == path: + break # Should never get here, but infinite loops are bad + path = parent + + if os.name == "posix": + return os.access(path, os.W_OK) + + return _test_writable_dir_win(path) + + +def _test_writable_dir_win(path: str) -> bool: + # os.access doesn't work on Windows: http://bugs.python.org/issue2528 + # and we can't use tempfile: http://bugs.python.org/issue22107 + basename = "accesstest_deleteme_fishfingers_custard_" + alphabet = "abcdefghijklmnopqrstuvwxyz0123456789" + for _ in range(10): + name = basename + "".join(random.choice(alphabet) for _ in range(6)) + file = os.path.join(path, name) + try: + fd = os.open(file, os.O_RDWR | os.O_CREAT | os.O_EXCL) + except FileExistsError: + pass + except PermissionError: + # This could be because there's a directory with the same name. + # But it's highly unlikely there's a directory called that, + # so we'll assume it's because the parent dir is not writable. + # This could as well be because the parent dir is not readable, + # due to non-privileged user access. + return False + else: + os.close(fd) + os.unlink(file) + return True + + # This should never be reached + raise OSError("Unexpected condition testing for writable directory") + + +def find_files(path: str, pattern: str) -> List[str]: + """Returns a list of absolute paths of files beneath path, recursively, + with filenames which match the UNIX-style shell glob pattern.""" + result: List[str] = [] + for root, _, files in os.walk(path): + matches = fnmatch.filter(files, pattern) + result.extend(os.path.join(root, f) for f in matches) + return result + + +def file_size(path: str) -> Union[int, float]: + # If it's a symlink, return 0. + if os.path.islink(path): + return 0 + return os.path.getsize(path) + + +def format_file_size(path: str) -> str: + return format_size(file_size(path)) + + +def directory_size(path: str) -> Union[int, float]: + size = 0.0 + for root, _dirs, files in os.walk(path): + for filename in files: + file_path = os.path.join(root, filename) + size += file_size(file_path) + return size + + +def format_directory_size(path: str) -> str: + return format_size(directory_size(path)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/filetypes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/filetypes.py new file mode 100644 index 0000000000000000000000000000000000000000..5948570178f3e6e79d1ff574241d09d4d8ed78de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/filetypes.py @@ -0,0 +1,27 @@ +"""Filetype information. +""" + +from typing import Tuple + +from pip._internal.utils.misc import splitext + +WHEEL_EXTENSION = ".whl" +BZ2_EXTENSIONS: Tuple[str, ...] = (".tar.bz2", ".tbz") +XZ_EXTENSIONS: Tuple[str, ...] = ( + ".tar.xz", + ".txz", + ".tlz", + ".tar.lz", + ".tar.lzma", +) +ZIP_EXTENSIONS: Tuple[str, ...] = (".zip", WHEEL_EXTENSION) +TAR_EXTENSIONS: Tuple[str, ...] = (".tar.gz", ".tgz", ".tar") +ARCHIVE_EXTENSIONS = ZIP_EXTENSIONS + BZ2_EXTENSIONS + TAR_EXTENSIONS + XZ_EXTENSIONS + + +def is_archive_file(name: str) -> bool: + """Return True if `name` is a considered as an archive file.""" + ext = splitext(name)[1].lower() + if ext in ARCHIVE_EXTENSIONS: + return True + return False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/glibc.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/glibc.py new file mode 100644 index 0000000000000000000000000000000000000000..998868ff2a482648024c848c9650d584403cbc8a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/glibc.py @@ -0,0 +1,101 @@ +import os +import sys +from typing import Optional, Tuple + + +def glibc_version_string() -> Optional[str]: + "Returns glibc version string, or None if not using glibc." + return glibc_version_string_confstr() or glibc_version_string_ctypes() + + +def glibc_version_string_confstr() -> Optional[str]: + "Primary implementation of glibc_version_string using os.confstr." + # os.confstr is quite a bit faster than ctypes.DLL. It's also less likely + # to be broken or missing. This strategy is used in the standard library + # platform module: + # https://github.com/python/cpython/blob/fcf1d003bf4f0100c9d0921ff3d70e1127ca1b71/Lib/platform.py#L175-L183 + if sys.platform == "win32": + return None + try: + gnu_libc_version = os.confstr("CS_GNU_LIBC_VERSION") + if gnu_libc_version is None: + return None + # os.confstr("CS_GNU_LIBC_VERSION") returns a string like "glibc 2.17": + _, version = gnu_libc_version.split() + except (AttributeError, OSError, ValueError): + # os.confstr() or CS_GNU_LIBC_VERSION not available (or a bad value)... + return None + return version + + +def glibc_version_string_ctypes() -> Optional[str]: + "Fallback implementation of glibc_version_string using ctypes." + + try: + import ctypes + except ImportError: + return None + + # ctypes.CDLL(None) internally calls dlopen(NULL), and as the dlopen + # manpage says, "If filename is NULL, then the returned handle is for the + # main program". This way we can let the linker do the work to figure out + # which libc our process is actually using. + # + # We must also handle the special case where the executable is not a + # dynamically linked executable. This can occur when using musl libc, + # for example. In this situation, dlopen() will error, leading to an + # OSError. Interestingly, at least in the case of musl, there is no + # errno set on the OSError. The single string argument used to construct + # OSError comes from libc itself and is therefore not portable to + # hard code here. In any case, failure to call dlopen() means we + # can't proceed, so we bail on our attempt. + try: + process_namespace = ctypes.CDLL(None) + except OSError: + return None + + try: + gnu_get_libc_version = process_namespace.gnu_get_libc_version + except AttributeError: + # Symbol doesn't exist -> therefore, we are not linked to + # glibc. + return None + + # Call gnu_get_libc_version, which returns a string like "2.5" + gnu_get_libc_version.restype = ctypes.c_char_p + version_str: str = gnu_get_libc_version() + # py2 / py3 compatibility: + if not isinstance(version_str, str): + version_str = version_str.decode("ascii") + + return version_str + + +# platform.libc_ver regularly returns completely nonsensical glibc +# versions. E.g. on my computer, platform says: +# +# ~$ python2.7 -c 'import platform; print(platform.libc_ver())' +# ('glibc', '2.7') +# ~$ python3.5 -c 'import platform; print(platform.libc_ver())' +# ('glibc', '2.9') +# +# But the truth is: +# +# ~$ ldd --version +# ldd (Debian GLIBC 2.22-11) 2.22 +# +# This is unfortunate, because it means that the linehaul data on libc +# versions that was generated by pip 8.1.2 and earlier is useless and +# misleading. Solution: instead of using platform, use our code that actually +# works. +def libc_ver() -> Tuple[str, str]: + """Try to determine the glibc version + + Returns a tuple of strings (lib, version) which default to empty strings + in case the lookup fails. + """ + glibc_version = glibc_version_string() + if glibc_version is None: + return ("", "") + else: + return ("glibc", glibc_version) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/hashes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/hashes.py new file mode 100644 index 0000000000000000000000000000000000000000..535e94fca0cc8b049673ee0d02dba259c68af76c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/hashes.py @@ -0,0 +1,147 @@ +import hashlib +from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, NoReturn, Optional + +from pip._internal.exceptions import HashMismatch, HashMissing, InstallationError +from pip._internal.utils.misc import read_chunks + +if TYPE_CHECKING: + from hashlib import _Hash + + +# The recommended hash algo of the moment. Change this whenever the state of +# the art changes; it won't hurt backward compatibility. +FAVORITE_HASH = "sha256" + + +# Names of hashlib algorithms allowed by the --hash option and ``pip hash`` +# Currently, those are the ones at least as collision-resistant as sha256. +STRONG_HASHES = ["sha256", "sha384", "sha512"] + + +class Hashes: + """A wrapper that builds multiple hashes at once and checks them against + known-good values + + """ + + def __init__(self, hashes: Optional[Dict[str, List[str]]] = None) -> None: + """ + :param hashes: A dict of algorithm names pointing to lists of allowed + hex digests + """ + allowed = {} + if hashes is not None: + for alg, keys in hashes.items(): + # Make sure values are always sorted (to ease equality checks) + allowed[alg] = [k.lower() for k in sorted(keys)] + self._allowed = allowed + + def __and__(self, other: "Hashes") -> "Hashes": + if not isinstance(other, Hashes): + return NotImplemented + + # If either of the Hashes object is entirely empty (i.e. no hash + # specified at all), all hashes from the other object are allowed. + if not other: + return self + if not self: + return other + + # Otherwise only hashes that present in both objects are allowed. + new = {} + for alg, values in other._allowed.items(): + if alg not in self._allowed: + continue + new[alg] = [v for v in values if v in self._allowed[alg]] + return Hashes(new) + + @property + def digest_count(self) -> int: + return sum(len(digests) for digests in self._allowed.values()) + + def is_hash_allowed(self, hash_name: str, hex_digest: str) -> bool: + """Return whether the given hex digest is allowed.""" + return hex_digest in self._allowed.get(hash_name, []) + + def check_against_chunks(self, chunks: Iterable[bytes]) -> None: + """Check good hashes against ones built from iterable of chunks of + data. + + Raise HashMismatch if none match. + + """ + gots = {} + for hash_name in self._allowed.keys(): + try: + gots[hash_name] = hashlib.new(hash_name) + except (ValueError, TypeError): + raise InstallationError(f"Unknown hash name: {hash_name}") + + for chunk in chunks: + for hash in gots.values(): + hash.update(chunk) + + for hash_name, got in gots.items(): + if got.hexdigest() in self._allowed[hash_name]: + return + self._raise(gots) + + def _raise(self, gots: Dict[str, "_Hash"]) -> "NoReturn": + raise HashMismatch(self._allowed, gots) + + def check_against_file(self, file: BinaryIO) -> None: + """Check good hashes against a file-like object + + Raise HashMismatch if none match. + + """ + return self.check_against_chunks(read_chunks(file)) + + def check_against_path(self, path: str) -> None: + with open(path, "rb") as file: + return self.check_against_file(file) + + def has_one_of(self, hashes: Dict[str, str]) -> bool: + """Return whether any of the given hashes are allowed.""" + for hash_name, hex_digest in hashes.items(): + if self.is_hash_allowed(hash_name, hex_digest): + return True + return False + + def __bool__(self) -> bool: + """Return whether I know any known-good hashes.""" + return bool(self._allowed) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Hashes): + return NotImplemented + return self._allowed == other._allowed + + def __hash__(self) -> int: + return hash( + ",".join( + sorted( + ":".join((alg, digest)) + for alg, digest_list in self._allowed.items() + for digest in digest_list + ) + ) + ) + + +class MissingHashes(Hashes): + """A workalike for Hashes used when we're missing a hash for a requirement + + It computes the actual hash of the requirement and raises a HashMissing + exception showing it to the user. + + """ + + def __init__(self) -> None: + """Don't offer the ``hashes`` kwarg.""" + # Pass our favorite hash in to generate a "gotten hash". With the + # empty list, it will never match, so an error will always raise. + super().__init__(hashes={FAVORITE_HASH: []}) + + def _raise(self, gots: Dict[str, "_Hash"]) -> "NoReturn": + raise HashMissing(gots[FAVORITE_HASH].hexdigest()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/logging.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..62035fc40eca1311704175d80a5c7082a166924f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/logging.py @@ -0,0 +1,354 @@ +import contextlib +import errno +import logging +import logging.handlers +import os +import sys +import threading +from dataclasses import dataclass +from io import TextIOWrapper +from logging import Filter +from typing import Any, ClassVar, Generator, List, Optional, TextIO, Type + +from pip._vendor.rich.console import ( + Console, + ConsoleOptions, + ConsoleRenderable, + RenderableType, + RenderResult, + RichCast, +) +from pip._vendor.rich.highlighter import NullHighlighter +from pip._vendor.rich.logging import RichHandler +from pip._vendor.rich.segment import Segment +from pip._vendor.rich.style import Style + +from pip._internal.utils._log import VERBOSE, getLogger +from pip._internal.utils.compat import WINDOWS +from pip._internal.utils.deprecation import DEPRECATION_MSG_PREFIX +from pip._internal.utils.misc import ensure_dir + +_log_state = threading.local() +subprocess_logger = getLogger("pip.subprocessor") + + +class BrokenStdoutLoggingError(Exception): + """ + Raised if BrokenPipeError occurs for the stdout stream while logging. + """ + + +def _is_broken_pipe_error(exc_class: Type[BaseException], exc: BaseException) -> bool: + if exc_class is BrokenPipeError: + return True + + # On Windows, a broken pipe can show up as EINVAL rather than EPIPE: + # https://bugs.python.org/issue19612 + # https://bugs.python.org/issue30418 + if not WINDOWS: + return False + + return isinstance(exc, OSError) and exc.errno in (errno.EINVAL, errno.EPIPE) + + +@contextlib.contextmanager +def indent_log(num: int = 2) -> Generator[None, None, None]: + """ + A context manager which will cause the log output to be indented for any + log messages emitted inside it. + """ + # For thread-safety + _log_state.indentation = get_indentation() + _log_state.indentation += num + try: + yield + finally: + _log_state.indentation -= num + + +def get_indentation() -> int: + return getattr(_log_state, "indentation", 0) + + +class IndentingFormatter(logging.Formatter): + default_time_format = "%Y-%m-%dT%H:%M:%S" + + def __init__( + self, + *args: Any, + add_timestamp: bool = False, + **kwargs: Any, + ) -> None: + """ + A logging.Formatter that obeys the indent_log() context manager. + + :param add_timestamp: A bool indicating output lines should be prefixed + with their record's timestamp. + """ + self.add_timestamp = add_timestamp + super().__init__(*args, **kwargs) + + def get_message_start(self, formatted: str, levelno: int) -> str: + """ + Return the start of the formatted log message (not counting the + prefix to add to each line). + """ + if levelno < logging.WARNING: + return "" + if formatted.startswith(DEPRECATION_MSG_PREFIX): + # Then the message already has a prefix. We don't want it to + # look like "WARNING: DEPRECATION: ...." + return "" + if levelno < logging.ERROR: + return "WARNING: " + + return "ERROR: " + + def format(self, record: logging.LogRecord) -> str: + """ + Calls the standard formatter, but will indent all of the log message + lines by our current indentation level. + """ + formatted = super().format(record) + message_start = self.get_message_start(formatted, record.levelno) + formatted = message_start + formatted + + prefix = "" + if self.add_timestamp: + prefix = f"{self.formatTime(record)} " + prefix += " " * get_indentation() + formatted = "".join([prefix + line for line in formatted.splitlines(True)]) + return formatted + + +@dataclass +class IndentedRenderable: + renderable: RenderableType + indent: int + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + segments = console.render(self.renderable, options) + lines = Segment.split_lines(segments) + for line in lines: + yield Segment(" " * self.indent) + yield from line + yield Segment("\n") + + +class PipConsole(Console): + def on_broken_pipe(self) -> None: + # Reraise the original exception, rich 13.8.0+ exits by default + # instead, preventing our handler from firing. + raise BrokenPipeError() from None + + +class RichPipStreamHandler(RichHandler): + KEYWORDS: ClassVar[Optional[List[str]]] = [] + + def __init__(self, stream: Optional[TextIO], no_color: bool) -> None: + super().__init__( + console=PipConsole(file=stream, no_color=no_color, soft_wrap=True), + show_time=False, + show_level=False, + show_path=False, + highlighter=NullHighlighter(), + ) + + # Our custom override on Rich's logger, to make things work as we need them to. + def emit(self, record: logging.LogRecord) -> None: + style: Optional[Style] = None + + # If we are given a diagnostic error to present, present it with indentation. + if getattr(record, "rich", False): + assert isinstance(record.args, tuple) + (rich_renderable,) = record.args + assert isinstance( + rich_renderable, (ConsoleRenderable, RichCast, str) + ), f"{rich_renderable} is not rich-console-renderable" + + renderable: RenderableType = IndentedRenderable( + rich_renderable, indent=get_indentation() + ) + else: + message = self.format(record) + renderable = self.render_message(record, message) + if record.levelno is not None: + if record.levelno >= logging.ERROR: + style = Style(color="red") + elif record.levelno >= logging.WARNING: + style = Style(color="yellow") + + try: + self.console.print(renderable, overflow="ignore", crop=False, style=style) + except Exception: + self.handleError(record) + + def handleError(self, record: logging.LogRecord) -> None: + """Called when logging is unable to log some output.""" + + exc_class, exc = sys.exc_info()[:2] + # If a broken pipe occurred while calling write() or flush() on the + # stdout stream in logging's Handler.emit(), then raise our special + # exception so we can handle it in main() instead of logging the + # broken pipe error and continuing. + if ( + exc_class + and exc + and self.console.file is sys.stdout + and _is_broken_pipe_error(exc_class, exc) + ): + raise BrokenStdoutLoggingError() + + return super().handleError(record) + + +class BetterRotatingFileHandler(logging.handlers.RotatingFileHandler): + def _open(self) -> TextIOWrapper: + ensure_dir(os.path.dirname(self.baseFilename)) + return super()._open() + + +class MaxLevelFilter(Filter): + def __init__(self, level: int) -> None: + self.level = level + + def filter(self, record: logging.LogRecord) -> bool: + return record.levelno < self.level + + +class ExcludeLoggerFilter(Filter): + """ + A logging Filter that excludes records from a logger (or its children). + """ + + def filter(self, record: logging.LogRecord) -> bool: + # The base Filter class allows only records from a logger (or its + # children). + return not super().filter(record) + + +def setup_logging(verbosity: int, no_color: bool, user_log_file: Optional[str]) -> int: + """Configures and sets up all of the logging + + Returns the requested logging level, as its integer value. + """ + + # Determine the level to be logging at. + if verbosity >= 2: + level_number = logging.DEBUG + elif verbosity == 1: + level_number = VERBOSE + elif verbosity == -1: + level_number = logging.WARNING + elif verbosity == -2: + level_number = logging.ERROR + elif verbosity <= -3: + level_number = logging.CRITICAL + else: + level_number = logging.INFO + + level = logging.getLevelName(level_number) + + # The "root" logger should match the "console" level *unless* we also need + # to log to a user log file. + include_user_log = user_log_file is not None + if include_user_log: + additional_log_file = user_log_file + root_level = "DEBUG" + else: + additional_log_file = "/dev/null" + root_level = level + + # Disable any logging besides WARNING unless we have DEBUG level logging + # enabled for vendored libraries. + vendored_log_level = "WARNING" if level in ["INFO", "ERROR"] else "DEBUG" + + # Shorthands for clarity + log_streams = { + "stdout": "ext://sys.stdout", + "stderr": "ext://sys.stderr", + } + handler_classes = { + "stream": "pip._internal.utils.logging.RichPipStreamHandler", + "file": "pip._internal.utils.logging.BetterRotatingFileHandler", + } + handlers = ["console", "console_errors", "console_subprocess"] + ( + ["user_log"] if include_user_log else [] + ) + + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "filters": { + "exclude_warnings": { + "()": "pip._internal.utils.logging.MaxLevelFilter", + "level": logging.WARNING, + }, + "restrict_to_subprocess": { + "()": "logging.Filter", + "name": subprocess_logger.name, + }, + "exclude_subprocess": { + "()": "pip._internal.utils.logging.ExcludeLoggerFilter", + "name": subprocess_logger.name, + }, + }, + "formatters": { + "indent": { + "()": IndentingFormatter, + "format": "%(message)s", + }, + "indent_with_timestamp": { + "()": IndentingFormatter, + "format": "%(message)s", + "add_timestamp": True, + }, + }, + "handlers": { + "console": { + "level": level, + "class": handler_classes["stream"], + "no_color": no_color, + "stream": log_streams["stdout"], + "filters": ["exclude_subprocess", "exclude_warnings"], + "formatter": "indent", + }, + "console_errors": { + "level": "WARNING", + "class": handler_classes["stream"], + "no_color": no_color, + "stream": log_streams["stderr"], + "filters": ["exclude_subprocess"], + "formatter": "indent", + }, + # A handler responsible for logging to the console messages + # from the "subprocessor" logger. + "console_subprocess": { + "level": level, + "class": handler_classes["stream"], + "stream": log_streams["stderr"], + "no_color": no_color, + "filters": ["restrict_to_subprocess"], + "formatter": "indent", + }, + "user_log": { + "level": "DEBUG", + "class": handler_classes["file"], + "filename": additional_log_file, + "encoding": "utf-8", + "delay": True, + "formatter": "indent_with_timestamp", + }, + }, + "root": { + "level": root_level, + "handlers": handlers, + }, + "loggers": {"pip._vendor": {"level": vendored_log_level}}, + } + ) + + return level_number diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/misc.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..44f6a05fbdd7f7b5779141f53b25b523af7e15eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/misc.py @@ -0,0 +1,773 @@ +import errno +import getpass +import hashlib +import logging +import os +import posixpath +import shutil +import stat +import sys +import sysconfig +import urllib.parse +from dataclasses import dataclass +from functools import partial +from io import StringIO +from itertools import filterfalse, tee, zip_longest +from pathlib import Path +from types import FunctionType, TracebackType +from typing import ( + Any, + BinaryIO, + Callable, + Generator, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + TextIO, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from pip._vendor.packaging.requirements import Requirement +from pip._vendor.pyproject_hooks import BuildBackendHookCaller + +from pip import __version__ +from pip._internal.exceptions import CommandError, ExternallyManagedEnvironment +from pip._internal.locations import get_major_minor_version +from pip._internal.utils.compat import WINDOWS +from pip._internal.utils.retry import retry +from pip._internal.utils.virtualenv import running_under_virtualenv + +__all__ = [ + "rmtree", + "display_path", + "backup_dir", + "ask", + "splitext", + "format_size", + "is_installable_dir", + "normalize_path", + "renames", + "get_prog", + "ensure_dir", + "remove_auth_from_url", + "check_externally_managed", + "ConfiguredBuildBackendHookCaller", +] + +logger = logging.getLogger(__name__) + +T = TypeVar("T") +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] +VersionInfo = Tuple[int, int, int] +NetlocTuple = Tuple[str, Tuple[Optional[str], Optional[str]]] +OnExc = Callable[[FunctionType, Path, BaseException], Any] +OnErr = Callable[[FunctionType, Path, ExcInfo], Any] + +FILE_CHUNK_SIZE = 1024 * 1024 + + +def get_pip_version() -> str: + pip_pkg_dir = os.path.join(os.path.dirname(__file__), "..", "..") + pip_pkg_dir = os.path.abspath(pip_pkg_dir) + + return f"pip {__version__} from {pip_pkg_dir} (python {get_major_minor_version()})" + + +def normalize_version_info(py_version_info: Tuple[int, ...]) -> Tuple[int, int, int]: + """ + Convert a tuple of ints representing a Python version to one of length + three. + + :param py_version_info: a tuple of ints representing a Python version, + or None to specify no version. The tuple can have any length. + + :return: a tuple of length three if `py_version_info` is non-None. + Otherwise, return `py_version_info` unchanged (i.e. None). + """ + if len(py_version_info) < 3: + py_version_info += (3 - len(py_version_info)) * (0,) + elif len(py_version_info) > 3: + py_version_info = py_version_info[:3] + + return cast("VersionInfo", py_version_info) + + +def ensure_dir(path: str) -> None: + """os.path.makedirs without EEXIST.""" + try: + os.makedirs(path) + except OSError as e: + # Windows can raise spurious ENOTEMPTY errors. See #6426. + if e.errno != errno.EEXIST and e.errno != errno.ENOTEMPTY: + raise + + +def get_prog() -> str: + try: + prog = os.path.basename(sys.argv[0]) + if prog in ("__main__.py", "-c"): + return f"{sys.executable} -m pip" + else: + return prog + except (AttributeError, TypeError, IndexError): + pass + return "pip" + + +# Retry every half second for up to 3 seconds +@retry(stop_after_delay=3, wait=0.5) +def rmtree( + dir: str, ignore_errors: bool = False, onexc: Optional[OnExc] = None +) -> None: + if ignore_errors: + onexc = _onerror_ignore + if onexc is None: + onexc = _onerror_reraise + handler: OnErr = partial(rmtree_errorhandler, onexc=onexc) + if sys.version_info >= (3, 12): + # See https://docs.python.org/3.12/whatsnew/3.12.html#shutil. + shutil.rmtree(dir, onexc=handler) # type: ignore + else: + shutil.rmtree(dir, onerror=handler) # type: ignore + + +def _onerror_ignore(*_args: Any) -> None: + pass + + +def _onerror_reraise(*_args: Any) -> None: + raise # noqa: PLE0704 - Bare exception used to reraise existing exception + + +def rmtree_errorhandler( + func: FunctionType, + path: Path, + exc_info: Union[ExcInfo, BaseException], + *, + onexc: OnExc = _onerror_reraise, +) -> None: + """ + `rmtree` error handler to 'force' a file remove (i.e. like `rm -f`). + + * If a file is readonly then it's write flag is set and operation is + retried. + + * `onerror` is the original callback from `rmtree(... onerror=onerror)` + that is chained at the end if the "rm -f" still fails. + """ + try: + st_mode = os.stat(path).st_mode + except OSError: + # it's equivalent to os.path.exists + return + + if not st_mode & stat.S_IWRITE: + # convert to read/write + try: + os.chmod(path, st_mode | stat.S_IWRITE) + except OSError: + pass + else: + # use the original function to repeat the operation + try: + func(path) + return + except OSError: + pass + + if not isinstance(exc_info, BaseException): + _, exc_info, _ = exc_info + onexc(func, path, exc_info) + + +def display_path(path: str) -> str: + """Gives the display value for a given path, making it relative to cwd + if possible.""" + path = os.path.normcase(os.path.abspath(path)) + if path.startswith(os.getcwd() + os.path.sep): + path = "." + path[len(os.getcwd()) :] + return path + + +def backup_dir(dir: str, ext: str = ".bak") -> str: + """Figure out the name of a directory to back up the given dir to + (adding .bak, .bak2, etc)""" + n = 1 + extension = ext + while os.path.exists(dir + extension): + n += 1 + extension = ext + str(n) + return dir + extension + + +def ask_path_exists(message: str, options: Iterable[str]) -> str: + for action in os.environ.get("PIP_EXISTS_ACTION", "").split(): + if action in options: + return action + return ask(message, options) + + +def _check_no_input(message: str) -> None: + """Raise an error if no input is allowed.""" + if os.environ.get("PIP_NO_INPUT"): + raise Exception( + f"No input was expected ($PIP_NO_INPUT set); question: {message}" + ) + + +def ask(message: str, options: Iterable[str]) -> str: + """Ask the message interactively, with the given possible responses""" + while 1: + _check_no_input(message) + response = input(message) + response = response.strip().lower() + if response not in options: + print( + "Your response ({!r}) was not one of the expected responses: " + "{}".format(response, ", ".join(options)) + ) + else: + return response + + +def ask_input(message: str) -> str: + """Ask for input interactively.""" + _check_no_input(message) + return input(message) + + +def ask_password(message: str) -> str: + """Ask for a password interactively.""" + _check_no_input(message) + return getpass.getpass(message) + + +def strtobool(val: str) -> int: + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return 1 + elif val in ("n", "no", "f", "false", "off", "0"): + return 0 + else: + raise ValueError(f"invalid truth value {val!r}") + + +def format_size(bytes: float) -> str: + if bytes > 1000 * 1000: + return f"{bytes / 1000.0 / 1000:.1f} MB" + elif bytes > 10 * 1000: + return f"{int(bytes / 1000)} kB" + elif bytes > 1000: + return f"{bytes / 1000.0:.1f} kB" + else: + return f"{int(bytes)} bytes" + + +def tabulate(rows: Iterable[Iterable[Any]]) -> Tuple[List[str], List[int]]: + """Return a list of formatted rows and a list of column sizes. + + For example:: + + >>> tabulate([['foobar', 2000], [0xdeadbeef]]) + (['foobar 2000', '3735928559'], [10, 4]) + """ + rows = [tuple(map(str, row)) for row in rows] + sizes = [max(map(len, col)) for col in zip_longest(*rows, fillvalue="")] + table = [" ".join(map(str.ljust, row, sizes)).rstrip() for row in rows] + return table, sizes + + +def is_installable_dir(path: str) -> bool: + """Is path is a directory containing pyproject.toml or setup.py? + + If pyproject.toml exists, this is a PEP 517 project. Otherwise we look for + a legacy setuptools layout by identifying setup.py. We don't check for the + setup.cfg because using it without setup.py is only available for PEP 517 + projects, which are already covered by the pyproject.toml check. + """ + if not os.path.isdir(path): + return False + if os.path.isfile(os.path.join(path, "pyproject.toml")): + return True + if os.path.isfile(os.path.join(path, "setup.py")): + return True + return False + + +def read_chunks( + file: BinaryIO, size: int = FILE_CHUNK_SIZE +) -> Generator[bytes, None, None]: + """Yield pieces of data from a file-like object until EOF.""" + while True: + chunk = file.read(size) + if not chunk: + break + yield chunk + + +def normalize_path(path: str, resolve_symlinks: bool = True) -> str: + """ + Convert a path to its canonical, case-normalized, absolute version. + + """ + path = os.path.expanduser(path) + if resolve_symlinks: + path = os.path.realpath(path) + else: + path = os.path.abspath(path) + return os.path.normcase(path) + + +def splitext(path: str) -> Tuple[str, str]: + """Like os.path.splitext, but take off .tar too""" + base, ext = posixpath.splitext(path) + if base.lower().endswith(".tar"): + ext = base[-4:] + ext + base = base[:-4] + return base, ext + + +def renames(old: str, new: str) -> None: + """Like os.renames(), but handles renaming across devices.""" + # Implementation borrowed from os.renames(). + head, tail = os.path.split(new) + if head and tail and not os.path.exists(head): + os.makedirs(head) + + shutil.move(old, new) + + head, tail = os.path.split(old) + if head and tail: + try: + os.removedirs(head) + except OSError: + pass + + +def is_local(path: str) -> bool: + """ + Return True if path is within sys.prefix, if we're running in a virtualenv. + + If we're not in a virtualenv, all paths are considered "local." + + Caution: this function assumes the head of path has been normalized + with normalize_path. + """ + if not running_under_virtualenv(): + return True + return path.startswith(normalize_path(sys.prefix)) + + +def write_output(msg: Any, *args: Any) -> None: + logger.info(msg, *args) + + +class StreamWrapper(StringIO): + orig_stream: TextIO + + @classmethod + def from_stream(cls, orig_stream: TextIO) -> "StreamWrapper": + ret = cls() + ret.orig_stream = orig_stream + return ret + + # compileall.compile_dir() needs stdout.encoding to print to stdout + # type ignore is because TextIOBase.encoding is writeable + @property + def encoding(self) -> str: # type: ignore + return self.orig_stream.encoding + + +# Simulates an enum +def enum(*sequential: Any, **named: Any) -> Type[Any]: + enums = dict(zip(sequential, range(len(sequential))), **named) + reverse = {value: key for key, value in enums.items()} + enums["reverse_mapping"] = reverse + return type("Enum", (), enums) + + +def build_netloc(host: str, port: Optional[int]) -> str: + """ + Build a netloc from a host-port pair + """ + if port is None: + return host + if ":" in host: + # Only wrap host with square brackets when it is IPv6 + host = f"[{host}]" + return f"{host}:{port}" + + +def build_url_from_netloc(netloc: str, scheme: str = "https") -> str: + """ + Build a full URL from a netloc. + """ + if netloc.count(":") >= 2 and "@" not in netloc and "[" not in netloc: + # It must be a bare IPv6 address, so wrap it with brackets. + netloc = f"[{netloc}]" + return f"{scheme}://{netloc}" + + +def parse_netloc(netloc: str) -> Tuple[Optional[str], Optional[int]]: + """ + Return the host-port pair from a netloc. + """ + url = build_url_from_netloc(netloc) + parsed = urllib.parse.urlparse(url) + return parsed.hostname, parsed.port + + +def split_auth_from_netloc(netloc: str) -> NetlocTuple: + """ + Parse out and remove the auth information from a netloc. + + Returns: (netloc, (username, password)). + """ + if "@" not in netloc: + return netloc, (None, None) + + # Split from the right because that's how urllib.parse.urlsplit() + # behaves if more than one @ is present (which can be checked using + # the password attribute of urlsplit()'s return value). + auth, netloc = netloc.rsplit("@", 1) + pw: Optional[str] = None + if ":" in auth: + # Split from the left because that's how urllib.parse.urlsplit() + # behaves if more than one : is present (which again can be checked + # using the password attribute of the return value) + user, pw = auth.split(":", 1) + else: + user, pw = auth, None + + user = urllib.parse.unquote(user) + if pw is not None: + pw = urllib.parse.unquote(pw) + + return netloc, (user, pw) + + +def redact_netloc(netloc: str) -> str: + """ + Replace the sensitive data in a netloc with "****", if it exists. + + For example: + - "user:pass@example.com" returns "user:****@example.com" + - "accesstoken@example.com" returns "****@example.com" + """ + netloc, (user, password) = split_auth_from_netloc(netloc) + if user is None: + return netloc + if password is None: + user = "****" + password = "" + else: + user = urllib.parse.quote(user) + password = ":****" + return f"{user}{password}@{netloc}" + + +def _transform_url( + url: str, transform_netloc: Callable[[str], Tuple[Any, ...]] +) -> Tuple[str, NetlocTuple]: + """Transform and replace netloc in a url. + + transform_netloc is a function taking the netloc and returning a + tuple. The first element of this tuple is the new netloc. The + entire tuple is returned. + + Returns a tuple containing the transformed url as item 0 and the + original tuple returned by transform_netloc as item 1. + """ + purl = urllib.parse.urlsplit(url) + netloc_tuple = transform_netloc(purl.netloc) + # stripped url + url_pieces = (purl.scheme, netloc_tuple[0], purl.path, purl.query, purl.fragment) + surl = urllib.parse.urlunsplit(url_pieces) + return surl, cast("NetlocTuple", netloc_tuple) + + +def _get_netloc(netloc: str) -> NetlocTuple: + return split_auth_from_netloc(netloc) + + +def _redact_netloc(netloc: str) -> Tuple[str]: + return (redact_netloc(netloc),) + + +def split_auth_netloc_from_url( + url: str, +) -> Tuple[str, str, Tuple[Optional[str], Optional[str]]]: + """ + Parse a url into separate netloc, auth, and url with no auth. + + Returns: (url_without_auth, netloc, (username, password)) + """ + url_without_auth, (netloc, auth) = _transform_url(url, _get_netloc) + return url_without_auth, netloc, auth + + +def remove_auth_from_url(url: str) -> str: + """Return a copy of url with 'username:password@' removed.""" + # username/pass params are passed to subversion through flags + # and are not recognized in the url. + return _transform_url(url, _get_netloc)[0] + + +def redact_auth_from_url(url: str) -> str: + """Replace the password in a given url with ****.""" + return _transform_url(url, _redact_netloc)[0] + + +def redact_auth_from_requirement(req: Requirement) -> str: + """Replace the password in a given requirement url with ****.""" + if not req.url: + return str(req) + return str(req).replace(req.url, redact_auth_from_url(req.url)) + + +@dataclass(frozen=True) +class HiddenText: + secret: str + redacted: str + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.redacted + + # This is useful for testing. + def __eq__(self, other: Any) -> bool: + if type(self) is not type(other): + return False + + # The string being used for redaction doesn't also have to match, + # just the raw, original string. + return self.secret == other.secret + + +def hide_value(value: str) -> HiddenText: + return HiddenText(value, redacted="****") + + +def hide_url(url: str) -> HiddenText: + redacted = redact_auth_from_url(url) + return HiddenText(url, redacted=redacted) + + +def protect_pip_from_modification_on_windows(modifying_pip: bool) -> None: + """Protection of pip.exe from modification on Windows + + On Windows, any operation modifying pip should be run as: + python -m pip ... + """ + pip_names = [ + "pip", + f"pip{sys.version_info.major}", + f"pip{sys.version_info.major}.{sys.version_info.minor}", + ] + + # See https://github.com/pypa/pip/issues/1299 for more discussion + should_show_use_python_msg = ( + modifying_pip and WINDOWS and os.path.basename(sys.argv[0]) in pip_names + ) + + if should_show_use_python_msg: + new_command = [sys.executable, "-m", "pip"] + sys.argv[1:] + raise CommandError( + "To modify pip, please run the following command:\n{}".format( + " ".join(new_command) + ) + ) + + +def check_externally_managed() -> None: + """Check whether the current environment is externally managed. + + If the ``EXTERNALLY-MANAGED`` config file is found, the current environment + is considered externally managed, and an ExternallyManagedEnvironment is + raised. + """ + if running_under_virtualenv(): + return + marker = os.path.join(sysconfig.get_path("stdlib"), "EXTERNALLY-MANAGED") + if not os.path.isfile(marker): + return + raise ExternallyManagedEnvironment.from_config(marker) + + +def is_console_interactive() -> bool: + """Is this console interactive?""" + return sys.stdin is not None and sys.stdin.isatty() + + +def hash_file(path: str, blocksize: int = 1 << 20) -> Tuple[Any, int]: + """Return (hash, length) for path using hashlib.sha256()""" + + h = hashlib.sha256() + length = 0 + with open(path, "rb") as f: + for block in read_chunks(f, size=blocksize): + length += len(block) + h.update(block) + return h, length + + +def pairwise(iterable: Iterable[Any]) -> Iterator[Tuple[Any, Any]]: + """ + Return paired elements. + + For example: + s -> (s0, s1), (s2, s3), (s4, s5), ... + """ + iterable = iter(iterable) + return zip_longest(iterable, iterable) + + +def partition( + pred: Callable[[T], bool], iterable: Iterable[T] +) -> Tuple[Iterable[T], Iterable[T]]: + """ + Use a predicate to partition entries into false entries and true entries, + like + + partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + """ + t1, t2 = tee(iterable) + return filterfalse(pred, t1), filter(pred, t2) + + +class ConfiguredBuildBackendHookCaller(BuildBackendHookCaller): + def __init__( + self, + config_holder: Any, + source_dir: str, + build_backend: str, + backend_path: Optional[str] = None, + runner: Optional[Callable[..., None]] = None, + python_executable: Optional[str] = None, + ): + super().__init__( + source_dir, build_backend, backend_path, runner, python_executable + ) + self.config_holder = config_holder + + def build_wheel( + self, + wheel_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, + metadata_directory: Optional[str] = None, + ) -> str: + cs = self.config_holder.config_settings + return super().build_wheel( + wheel_directory, config_settings=cs, metadata_directory=metadata_directory + ) + + def build_sdist( + self, + sdist_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, + ) -> str: + cs = self.config_holder.config_settings + return super().build_sdist(sdist_directory, config_settings=cs) + + def build_editable( + self, + wheel_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, + metadata_directory: Optional[str] = None, + ) -> str: + cs = self.config_holder.config_settings + return super().build_editable( + wheel_directory, config_settings=cs, metadata_directory=metadata_directory + ) + + def get_requires_for_build_wheel( + self, config_settings: Optional[Mapping[str, Any]] = None + ) -> Sequence[str]: + cs = self.config_holder.config_settings + return super().get_requires_for_build_wheel(config_settings=cs) + + def get_requires_for_build_sdist( + self, config_settings: Optional[Mapping[str, Any]] = None + ) -> Sequence[str]: + cs = self.config_holder.config_settings + return super().get_requires_for_build_sdist(config_settings=cs) + + def get_requires_for_build_editable( + self, config_settings: Optional[Mapping[str, Any]] = None + ) -> Sequence[str]: + cs = self.config_holder.config_settings + return super().get_requires_for_build_editable(config_settings=cs) + + def prepare_metadata_for_build_wheel( + self, + metadata_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, + _allow_fallback: bool = True, + ) -> str: + cs = self.config_holder.config_settings + return super().prepare_metadata_for_build_wheel( + metadata_directory=metadata_directory, + config_settings=cs, + _allow_fallback=_allow_fallback, + ) + + def prepare_metadata_for_build_editable( + self, + metadata_directory: str, + config_settings: Optional[Mapping[str, Any]] = None, + _allow_fallback: bool = True, + ) -> Optional[str]: + cs = self.config_holder.config_settings + return super().prepare_metadata_for_build_editable( + metadata_directory=metadata_directory, + config_settings=cs, + _allow_fallback=_allow_fallback, + ) + + +def warn_if_run_as_root() -> None: + """Output a warning for sudo users on Unix. + + In a virtual environment, sudo pip still writes to virtualenv. + On Windows, users may run pip as Administrator without issues. + This warning only applies to Unix root users outside of virtualenv. + """ + if running_under_virtualenv(): + return + if not hasattr(os, "getuid"): + return + # On Windows, there are no "system managed" Python packages. Installing as + # Administrator via pip is the correct way of updating system environments. + # + # We choose sys.platform over utils.compat.WINDOWS here to enable Mypy platform + # checks: https://mypy.readthedocs.io/en/stable/common_issues.html + if sys.platform == "win32" or sys.platform == "cygwin": + return + + if os.getuid() != 0: + return + + logger.warning( + "Running pip as the 'root' user can result in broken permissions and " + "conflicting behaviour with the system package manager, possibly " + "rendering your system unusable. " + "It is recommended to use a virtual environment instead: " + "https://pip.pypa.io/warnings/venv. " + "Use the --root-user-action option if you know what you are doing and " + "want to suppress this warning." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/packaging.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/packaging.py new file mode 100644 index 0000000000000000000000000000000000000000..caad70f7fd17593769cbb5db99035e8c21b21a58 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/packaging.py @@ -0,0 +1,58 @@ +import functools +import logging +import re +from typing import NewType, Optional, Tuple, cast + +from pip._vendor.packaging import specifiers, version +from pip._vendor.packaging.requirements import Requirement + +NormalizedExtra = NewType("NormalizedExtra", str) + +logger = logging.getLogger(__name__) + + +@functools.lru_cache(maxsize=32) +def check_requires_python( + requires_python: Optional[str], version_info: Tuple[int, ...] +) -> bool: + """ + Check if the given Python version matches a "Requires-Python" specifier. + + :param version_info: A 3-tuple of ints representing a Python + major-minor-micro version to check (e.g. `sys.version_info[:3]`). + + :return: `True` if the given Python version satisfies the requirement. + Otherwise, return `False`. + + :raises InvalidSpecifier: If `requires_python` has an invalid format. + """ + if requires_python is None: + # The package provides no information + return True + requires_python_specifier = specifiers.SpecifierSet(requires_python) + + python_version = version.parse(".".join(map(str, version_info))) + return python_version in requires_python_specifier + + +@functools.lru_cache(maxsize=2048) +def get_requirement(req_string: str) -> Requirement: + """Construct a packaging.Requirement object with caching""" + # Parsing requirement strings is expensive, and is also expected to happen + # with a low diversity of different arguments (at least relative the number + # constructed). This method adds a cache to requirement object creation to + # minimize repeated parsing of the same string to construct equivalent + # Requirement objects. + return Requirement(req_string) + + +def safe_extra(extra: str) -> NormalizedExtra: + """Convert an arbitrary string to a standard 'extra' name + + Any runs of non-alphanumeric characters are replaced with a single '_', + and the result is always lowercased. + + This function is duplicated from ``pkg_resources``. Note that this is not + the same to either ``canonicalize_name`` or ``_egg_link_name``. + """ + return cast(NormalizedExtra, re.sub("[^A-Za-z0-9.-]+", "_", extra).lower()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/retry.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/retry.py new file mode 100644 index 0000000000000000000000000000000000000000..abfe07286ea747f656ea73f5a6919f1d66215847 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/retry.py @@ -0,0 +1,42 @@ +import functools +from time import perf_counter, sleep +from typing import Callable, TypeVar + +from pip._vendor.typing_extensions import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + + +def retry( + wait: float, stop_after_delay: float +) -> Callable[[Callable[P, T]], Callable[P, T]]: + """Decorator to automatically retry a function on error. + + If the function raises, the function is recalled with the same arguments + until it returns or the time limit is reached. When the time limit is + surpassed, the last exception raised is reraised. + + :param wait: The time to wait after an error before retrying, in seconds. + :param stop_after_delay: The time limit after which retries will cease, + in seconds. + """ + + def wrapper(func: Callable[P, T]) -> Callable[P, T]: + + @functools.wraps(func) + def retry_wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + # The performance counter is monotonic on all platforms we care + # about and has much better resolution than time.monotonic(). + start_time = perf_counter() + while True: + try: + return func(*args, **kwargs) + except Exception: + if perf_counter() - start_time > stop_after_delay: + raise + sleep(wait) + + return retry_wrapped + + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/setuptools_build.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/setuptools_build.py new file mode 100644 index 0000000000000000000000000000000000000000..96d1b2460670e20ac92a5ade7a74b7ab1cba71d8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/setuptools_build.py @@ -0,0 +1,146 @@ +import sys +import textwrap +from typing import List, Optional, Sequence + +# Shim to wrap setup.py invocation with setuptools +# Note that __file__ is handled via two {!r} *and* %r, to ensure that paths on +# Windows are correctly handled (it should be "C:\\Users" not "C:\Users"). +_SETUPTOOLS_SHIM = textwrap.dedent( + """ + exec(compile(''' + # This is -- a caller that pip uses to run setup.py + # + # - It imports setuptools before invoking setup.py, to enable projects that directly + # import from `distutils.core` to work with newer packaging standards. + # - It provides a clear error message when setuptools is not installed. + # - It sets `sys.argv[0]` to the underlying `setup.py`, when invoking `setup.py` so + # setuptools doesn't think the script is `-c`. This avoids the following warning: + # manifest_maker: standard file '-c' not found". + # - It generates a shim setup.py, for handling setup.cfg-only projects. + import os, sys, tokenize + + try: + import setuptools + except ImportError as error: + print( + "ERROR: Can not execute `setup.py` since setuptools is not available in " + "the build environment.", + file=sys.stderr, + ) + sys.exit(1) + + __file__ = %r + sys.argv[0] = __file__ + + if os.path.exists(__file__): + filename = __file__ + with tokenize.open(__file__) as f: + setup_py_code = f.read() + else: + filename = "" + setup_py_code = "from setuptools import setup; setup()" + + exec(compile(setup_py_code, filename, "exec")) + ''' % ({!r},), "", "exec")) + """ +).rstrip() + + +def make_setuptools_shim_args( + setup_py_path: str, + global_options: Optional[Sequence[str]] = None, + no_user_config: bool = False, + unbuffered_output: bool = False, +) -> List[str]: + """ + Get setuptools command arguments with shim wrapped setup file invocation. + + :param setup_py_path: The path to setup.py to be wrapped. + :param global_options: Additional global options. + :param no_user_config: If True, disables personal user configuration. + :param unbuffered_output: If True, adds the unbuffered switch to the + argument list. + """ + args = [sys.executable] + if unbuffered_output: + args += ["-u"] + args += ["-c", _SETUPTOOLS_SHIM.format(setup_py_path)] + if global_options: + args += global_options + if no_user_config: + args += ["--no-user-cfg"] + return args + + +def make_setuptools_bdist_wheel_args( + setup_py_path: str, + global_options: Sequence[str], + build_options: Sequence[str], + destination_dir: str, +) -> List[str]: + # NOTE: Eventually, we'd want to also -S to the flags here, when we're + # isolating. Currently, it breaks Python in virtualenvs, because it + # relies on site.py to find parts of the standard library outside the + # virtualenv. + args = make_setuptools_shim_args( + setup_py_path, global_options=global_options, unbuffered_output=True + ) + args += ["bdist_wheel", "-d", destination_dir] + args += build_options + return args + + +def make_setuptools_clean_args( + setup_py_path: str, + global_options: Sequence[str], +) -> List[str]: + args = make_setuptools_shim_args( + setup_py_path, global_options=global_options, unbuffered_output=True + ) + args += ["clean", "--all"] + return args + + +def make_setuptools_develop_args( + setup_py_path: str, + *, + global_options: Sequence[str], + no_user_config: bool, + prefix: Optional[str], + home: Optional[str], + use_user_site: bool, +) -> List[str]: + assert not (use_user_site and prefix) + + args = make_setuptools_shim_args( + setup_py_path, + global_options=global_options, + no_user_config=no_user_config, + ) + + args += ["develop", "--no-deps"] + + if prefix: + args += ["--prefix", prefix] + if home is not None: + args += ["--install-dir", home] + + if use_user_site: + args += ["--user", "--prefix="] + + return args + + +def make_setuptools_egg_info_args( + setup_py_path: str, + egg_info_dir: Optional[str], + no_user_config: bool, +) -> List[str]: + args = make_setuptools_shim_args(setup_py_path, no_user_config=no_user_config) + + args += ["egg_info"] + + if egg_info_dir: + args += ["--egg-base", egg_info_dir] + + return args diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/subprocess.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/subprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..cb2e23f007aca75c7e96e37df42ac0df6f2591e1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/subprocess.py @@ -0,0 +1,245 @@ +import logging +import os +import shlex +import subprocess +from typing import Any, Callable, Iterable, List, Literal, Mapping, Optional, Union + +from pip._vendor.rich.markup import escape + +from pip._internal.cli.spinners import SpinnerInterface, open_spinner +from pip._internal.exceptions import InstallationSubprocessError +from pip._internal.utils.logging import VERBOSE, subprocess_logger +from pip._internal.utils.misc import HiddenText + +CommandArgs = List[Union[str, HiddenText]] + + +def make_command(*args: Union[str, HiddenText, CommandArgs]) -> CommandArgs: + """ + Create a CommandArgs object. + """ + command_args: CommandArgs = [] + for arg in args: + # Check for list instead of CommandArgs since CommandArgs is + # only known during type-checking. + if isinstance(arg, list): + command_args.extend(arg) + else: + # Otherwise, arg is str or HiddenText. + command_args.append(arg) + + return command_args + + +def format_command_args(args: Union[List[str], CommandArgs]) -> str: + """ + Format command arguments for display. + """ + # For HiddenText arguments, display the redacted form by calling str(). + # Also, we don't apply str() to arguments that aren't HiddenText since + # this can trigger a UnicodeDecodeError in Python 2 if the argument + # has type unicode and includes a non-ascii character. (The type + # checker doesn't ensure the annotations are correct in all cases.) + return " ".join( + shlex.quote(str(arg)) if isinstance(arg, HiddenText) else shlex.quote(arg) + for arg in args + ) + + +def reveal_command_args(args: Union[List[str], CommandArgs]) -> List[str]: + """ + Return the arguments in their raw, unredacted form. + """ + return [arg.secret if isinstance(arg, HiddenText) else arg for arg in args] + + +def call_subprocess( + cmd: Union[List[str], CommandArgs], + show_stdout: bool = False, + cwd: Optional[str] = None, + on_returncode: 'Literal["raise", "warn", "ignore"]' = "raise", + extra_ok_returncodes: Optional[Iterable[int]] = None, + extra_environ: Optional[Mapping[str, Any]] = None, + unset_environ: Optional[Iterable[str]] = None, + spinner: Optional[SpinnerInterface] = None, + log_failed_cmd: Optional[bool] = True, + stdout_only: Optional[bool] = False, + *, + command_desc: str, +) -> str: + """ + Args: + show_stdout: if true, use INFO to log the subprocess's stderr and + stdout streams. Otherwise, use DEBUG. Defaults to False. + extra_ok_returncodes: an iterable of integer return codes that are + acceptable, in addition to 0. Defaults to None, which means []. + unset_environ: an iterable of environment variable names to unset + prior to calling subprocess.Popen(). + log_failed_cmd: if false, failed commands are not logged, only raised. + stdout_only: if true, return only stdout, else return both. When true, + logging of both stdout and stderr occurs when the subprocess has + terminated, else logging occurs as subprocess output is produced. + """ + if extra_ok_returncodes is None: + extra_ok_returncodes = [] + if unset_environ is None: + unset_environ = [] + # Most places in pip use show_stdout=False. What this means is-- + # + # - We connect the child's output (combined stderr and stdout) to a + # single pipe, which we read. + # - We log this output to stderr at DEBUG level as it is received. + # - If DEBUG logging isn't enabled (e.g. if --verbose logging wasn't + # requested), then we show a spinner so the user can still see the + # subprocess is in progress. + # - If the subprocess exits with an error, we log the output to stderr + # at ERROR level if it hasn't already been displayed to the console + # (e.g. if --verbose logging wasn't enabled). This way we don't log + # the output to the console twice. + # + # If show_stdout=True, then the above is still done, but with DEBUG + # replaced by INFO. + if show_stdout: + # Then log the subprocess output at INFO level. + log_subprocess: Callable[..., None] = subprocess_logger.info + used_level = logging.INFO + else: + # Then log the subprocess output using VERBOSE. This also ensures + # it will be logged to the log file (aka user_log), if enabled. + log_subprocess = subprocess_logger.verbose + used_level = VERBOSE + + # Whether the subprocess will be visible in the console. + showing_subprocess = subprocess_logger.getEffectiveLevel() <= used_level + + # Only use the spinner if we're not showing the subprocess output + # and we have a spinner. + use_spinner = not showing_subprocess and spinner is not None + + log_subprocess("Running command %s", command_desc) + env = os.environ.copy() + if extra_environ: + env.update(extra_environ) + for name in unset_environ: + env.pop(name, None) + try: + proc = subprocess.Popen( + # Convert HiddenText objects to the underlying str. + reveal_command_args(cmd), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT if not stdout_only else subprocess.PIPE, + cwd=cwd, + env=env, + errors="backslashreplace", + ) + except Exception as exc: + if log_failed_cmd: + subprocess_logger.critical( + "Error %s while executing command %s", + exc, + command_desc, + ) + raise + all_output = [] + if not stdout_only: + assert proc.stdout + assert proc.stdin + proc.stdin.close() + # In this mode, stdout and stderr are in the same pipe. + while True: + line: str = proc.stdout.readline() + if not line: + break + line = line.rstrip() + all_output.append(line + "\n") + + # Show the line immediately. + log_subprocess(line) + # Update the spinner. + if use_spinner: + assert spinner + spinner.spin() + try: + proc.wait() + finally: + if proc.stdout: + proc.stdout.close() + output = "".join(all_output) + else: + # In this mode, stdout and stderr are in different pipes. + # We must use communicate() which is the only safe way to read both. + out, err = proc.communicate() + # log line by line to preserve pip log indenting + for out_line in out.splitlines(): + log_subprocess(out_line) + all_output.append(out) + for err_line in err.splitlines(): + log_subprocess(err_line) + all_output.append(err) + output = out + + proc_had_error = proc.returncode and proc.returncode not in extra_ok_returncodes + if use_spinner: + assert spinner + if proc_had_error: + spinner.finish("error") + else: + spinner.finish("done") + if proc_had_error: + if on_returncode == "raise": + error = InstallationSubprocessError( + command_description=command_desc, + exit_code=proc.returncode, + output_lines=all_output if not showing_subprocess else None, + ) + if log_failed_cmd: + subprocess_logger.error("%s", error, extra={"rich": True}) + subprocess_logger.verbose( + "[bold magenta]full command[/]: [blue]%s[/]", + escape(format_command_args(cmd)), + extra={"markup": True}, + ) + subprocess_logger.verbose( + "[bold magenta]cwd[/]: %s", + escape(cwd or "[inherit]"), + extra={"markup": True}, + ) + + raise error + elif on_returncode == "warn": + subprocess_logger.warning( + 'Command "%s" had error code %s in %s', + command_desc, + proc.returncode, + cwd, + ) + elif on_returncode == "ignore": + pass + else: + raise ValueError(f"Invalid value: on_returncode={on_returncode!r}") + return output + + +def runner_with_spinner_message(message: str) -> Callable[..., None]: + """Provide a subprocess_runner that shows a spinner message. + + Intended for use with for BuildBackendHookCaller. Thus, the runner has + an API that matches what's expected by BuildBackendHookCaller.subprocess_runner. + """ + + def runner( + cmd: List[str], + cwd: Optional[str] = None, + extra_environ: Optional[Mapping[str, Any]] = None, + ) -> None: + with open_spinner(message) as spinner: + call_subprocess( + cmd, + command_desc=message, + cwd=cwd, + extra_environ=extra_environ, + spinner=spinner, + ) + + return runner diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/temp_dir.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/temp_dir.py new file mode 100644 index 0000000000000000000000000000000000000000..06668e8ab2dad131106cd9e4963d871cea147997 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/temp_dir.py @@ -0,0 +1,296 @@ +import errno +import itertools +import logging +import os.path +import tempfile +import traceback +from contextlib import ExitStack, contextmanager +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + Optional, + TypeVar, + Union, +) + +from pip._internal.utils.misc import enum, rmtree + +logger = logging.getLogger(__name__) + +_T = TypeVar("_T", bound="TempDirectory") + + +# Kinds of temporary directories. Only needed for ones that are +# globally-managed. +tempdir_kinds = enum( + BUILD_ENV="build-env", + EPHEM_WHEEL_CACHE="ephem-wheel-cache", + REQ_BUILD="req-build", +) + + +_tempdir_manager: Optional[ExitStack] = None + + +@contextmanager +def global_tempdir_manager() -> Generator[None, None, None]: + global _tempdir_manager + with ExitStack() as stack: + old_tempdir_manager, _tempdir_manager = _tempdir_manager, stack + try: + yield + finally: + _tempdir_manager = old_tempdir_manager + + +class TempDirectoryTypeRegistry: + """Manages temp directory behavior""" + + def __init__(self) -> None: + self._should_delete: Dict[str, bool] = {} + + def set_delete(self, kind: str, value: bool) -> None: + """Indicate whether a TempDirectory of the given kind should be + auto-deleted. + """ + self._should_delete[kind] = value + + def get_delete(self, kind: str) -> bool: + """Get configured auto-delete flag for a given TempDirectory type, + default True. + """ + return self._should_delete.get(kind, True) + + +_tempdir_registry: Optional[TempDirectoryTypeRegistry] = None + + +@contextmanager +def tempdir_registry() -> Generator[TempDirectoryTypeRegistry, None, None]: + """Provides a scoped global tempdir registry that can be used to dictate + whether directories should be deleted. + """ + global _tempdir_registry + old_tempdir_registry = _tempdir_registry + _tempdir_registry = TempDirectoryTypeRegistry() + try: + yield _tempdir_registry + finally: + _tempdir_registry = old_tempdir_registry + + +class _Default: + pass + + +_default = _Default() + + +class TempDirectory: + """Helper class that owns and cleans up a temporary directory. + + This class can be used as a context manager or as an OO representation of a + temporary directory. + + Attributes: + path + Location to the created temporary directory + delete + Whether the directory should be deleted when exiting + (when used as a contextmanager) + + Methods: + cleanup() + Deletes the temporary directory + + When used as a context manager, if the delete attribute is True, on + exiting the context the temporary directory is deleted. + """ + + def __init__( + self, + path: Optional[str] = None, + delete: Union[bool, None, _Default] = _default, + kind: str = "temp", + globally_managed: bool = False, + ignore_cleanup_errors: bool = True, + ): + super().__init__() + + if delete is _default: + if path is not None: + # If we were given an explicit directory, resolve delete option + # now. + delete = False + else: + # Otherwise, we wait until cleanup and see what + # tempdir_registry says. + delete = None + + # The only time we specify path is in for editables where it + # is the value of the --src option. + if path is None: + path = self._create(kind) + + self._path = path + self._deleted = False + self.delete = delete + self.kind = kind + self.ignore_cleanup_errors = ignore_cleanup_errors + + if globally_managed: + assert _tempdir_manager is not None + _tempdir_manager.enter_context(self) + + @property + def path(self) -> str: + assert not self._deleted, f"Attempted to access deleted path: {self._path}" + return self._path + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.path!r}>" + + def __enter__(self: _T) -> _T: + return self + + def __exit__(self, exc: Any, value: Any, tb: Any) -> None: + if self.delete is not None: + delete = self.delete + elif _tempdir_registry: + delete = _tempdir_registry.get_delete(self.kind) + else: + delete = True + + if delete: + self.cleanup() + + def _create(self, kind: str) -> str: + """Create a temporary directory and store its path in self.path""" + # We realpath here because some systems have their default tmpdir + # symlinked to another directory. This tends to confuse build + # scripts, so we canonicalize the path by traversing potential + # symlinks here. + path = os.path.realpath(tempfile.mkdtemp(prefix=f"pip-{kind}-")) + logger.debug("Created temporary directory: %s", path) + return path + + def cleanup(self) -> None: + """Remove the temporary directory created and reset state""" + self._deleted = True + if not os.path.exists(self._path): + return + + errors: List[BaseException] = [] + + def onerror( + func: Callable[..., Any], + path: Path, + exc_val: BaseException, + ) -> None: + """Log a warning for a `rmtree` error and continue""" + formatted_exc = "\n".join( + traceback.format_exception_only(type(exc_val), exc_val) + ) + formatted_exc = formatted_exc.rstrip() # remove trailing new line + if func in (os.unlink, os.remove, os.rmdir): + logger.debug( + "Failed to remove a temporary file '%s' due to %s.\n", + path, + formatted_exc, + ) + else: + logger.debug("%s failed with %s.", func.__qualname__, formatted_exc) + errors.append(exc_val) + + if self.ignore_cleanup_errors: + try: + # first try with @retry; retrying to handle ephemeral errors + rmtree(self._path, ignore_errors=False) + except OSError: + # last pass ignore/log all errors + rmtree(self._path, onexc=onerror) + if errors: + logger.warning( + "Failed to remove contents in a temporary directory '%s'.\n" + "You can safely remove it manually.", + self._path, + ) + else: + rmtree(self._path) + + +class AdjacentTempDirectory(TempDirectory): + """Helper class that creates a temporary directory adjacent to a real one. + + Attributes: + original + The original directory to create a temp directory for. + path + After calling create() or entering, contains the full + path to the temporary directory. + delete + Whether the directory should be deleted when exiting + (when used as a contextmanager) + + """ + + # The characters that may be used to name the temp directory + # We always prepend a ~ and then rotate through these until + # a usable name is found. + # pkg_resources raises a different error for .dist-info folder + # with leading '-' and invalid metadata + LEADING_CHARS = "-~.=%0123456789" + + def __init__(self, original: str, delete: Optional[bool] = None) -> None: + self.original = original.rstrip("/\\") + super().__init__(delete=delete) + + @classmethod + def _generate_names(cls, name: str) -> Generator[str, None, None]: + """Generates a series of temporary names. + + The algorithm replaces the leading characters in the name + with ones that are valid filesystem characters, but are not + valid package names (for both Python and pip definitions of + package). + """ + for i in range(1, len(name)): + for candidate in itertools.combinations_with_replacement( + cls.LEADING_CHARS, i - 1 + ): + new_name = "~" + "".join(candidate) + name[i:] + if new_name != name: + yield new_name + + # If we make it this far, we will have to make a longer name + for i in range(len(cls.LEADING_CHARS)): + for candidate in itertools.combinations_with_replacement( + cls.LEADING_CHARS, i + ): + new_name = "~" + "".join(candidate) + name + if new_name != name: + yield new_name + + def _create(self, kind: str) -> str: + root, name = os.path.split(self.original) + for candidate in self._generate_names(name): + path = os.path.join(root, candidate) + try: + os.mkdir(path) + except OSError as ex: + # Continue if the name exists already + if ex.errno != errno.EEXIST: + raise + else: + path = os.path.realpath(path) + break + else: + # Final fallback on the default behavior. + path = os.path.realpath(tempfile.mkdtemp(prefix=f"pip-{kind}-")) + + logger.debug("Created temporary directory: %s", path) + return path diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/unpacking.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/unpacking.py new file mode 100644 index 0000000000000000000000000000000000000000..87a6d19ab5a9f9f305cbb45f62b8f918fc867946 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/unpacking.py @@ -0,0 +1,337 @@ +"""Utilities related archives. +""" + +import logging +import os +import shutil +import stat +import sys +import tarfile +import zipfile +from typing import Iterable, List, Optional +from zipfile import ZipInfo + +from pip._internal.exceptions import InstallationError +from pip._internal.utils.filetypes import ( + BZ2_EXTENSIONS, + TAR_EXTENSIONS, + XZ_EXTENSIONS, + ZIP_EXTENSIONS, +) +from pip._internal.utils.misc import ensure_dir + +logger = logging.getLogger(__name__) + + +SUPPORTED_EXTENSIONS = ZIP_EXTENSIONS + TAR_EXTENSIONS + +try: + import bz2 # noqa + + SUPPORTED_EXTENSIONS += BZ2_EXTENSIONS +except ImportError: + logger.debug("bz2 module is not available") + +try: + # Only for Python 3.3+ + import lzma # noqa + + SUPPORTED_EXTENSIONS += XZ_EXTENSIONS +except ImportError: + logger.debug("lzma module is not available") + + +def current_umask() -> int: + """Get the current umask which involves having to set it temporarily.""" + mask = os.umask(0) + os.umask(mask) + return mask + + +def split_leading_dir(path: str) -> List[str]: + path = path.lstrip("/").lstrip("\\") + if "/" in path and ( + ("\\" in path and path.find("/") < path.find("\\")) or "\\" not in path + ): + return path.split("/", 1) + elif "\\" in path: + return path.split("\\", 1) + else: + return [path, ""] + + +def has_leading_dir(paths: Iterable[str]) -> bool: + """Returns true if all the paths have the same leading path name + (i.e., everything is in one subdirectory in an archive)""" + common_prefix = None + for path in paths: + prefix, rest = split_leading_dir(path) + if not prefix: + return False + elif common_prefix is None: + common_prefix = prefix + elif prefix != common_prefix: + return False + return True + + +def is_within_directory(directory: str, target: str) -> bool: + """ + Return true if the absolute path of target is within the directory + """ + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + return prefix == abs_directory + + +def _get_default_mode_plus_executable() -> int: + return 0o777 & ~current_umask() | 0o111 + + +def set_extracted_file_to_default_mode_plus_executable(path: str) -> None: + """ + Make file present at path have execute for user/group/world + (chmod +x) is no-op on windows per python docs + """ + os.chmod(path, _get_default_mode_plus_executable()) + + +def zip_item_is_executable(info: ZipInfo) -> bool: + mode = info.external_attr >> 16 + # if mode and regular file and any execute permissions for + # user/group/world? + return bool(mode and stat.S_ISREG(mode) and mode & 0o111) + + +def unzip_file(filename: str, location: str, flatten: bool = True) -> None: + """ + Unzip the file (with path `filename`) to the destination `location`. All + files are written based on system defaults and umask (i.e. permissions are + not preserved), except that regular file members with any execute + permissions (user, group, or world) have "chmod +x" applied after being + written. Note that for windows, any execute changes using os.chmod are + no-ops per the python docs. + """ + ensure_dir(location) + zipfp = open(filename, "rb") + try: + zip = zipfile.ZipFile(zipfp, allowZip64=True) + leading = has_leading_dir(zip.namelist()) and flatten + for info in zip.infolist(): + name = info.filename + fn = name + if leading: + fn = split_leading_dir(name)[1] + fn = os.path.join(location, fn) + dir = os.path.dirname(fn) + if not is_within_directory(location, fn): + message = ( + "The zip file ({}) has a file ({}) trying to install " + "outside target directory ({})" + ) + raise InstallationError(message.format(filename, fn, location)) + if fn.endswith("/") or fn.endswith("\\"): + # A directory + ensure_dir(fn) + else: + ensure_dir(dir) + # Don't use read() to avoid allocating an arbitrarily large + # chunk of memory for the file's content + fp = zip.open(name) + try: + with open(fn, "wb") as destfp: + shutil.copyfileobj(fp, destfp) + finally: + fp.close() + if zip_item_is_executable(info): + set_extracted_file_to_default_mode_plus_executable(fn) + finally: + zipfp.close() + + +def untar_file(filename: str, location: str) -> None: + """ + Untar the file (with path `filename`) to the destination `location`. + All files are written based on system defaults and umask (i.e. permissions + are not preserved), except that regular file members with any execute + permissions (user, group, or world) have "chmod +x" applied on top of the + default. Note that for windows, any execute changes using os.chmod are + no-ops per the python docs. + """ + ensure_dir(location) + if filename.lower().endswith(".gz") or filename.lower().endswith(".tgz"): + mode = "r:gz" + elif filename.lower().endswith(BZ2_EXTENSIONS): + mode = "r:bz2" + elif filename.lower().endswith(XZ_EXTENSIONS): + mode = "r:xz" + elif filename.lower().endswith(".tar"): + mode = "r" + else: + logger.warning( + "Cannot determine compression type for file %s", + filename, + ) + mode = "r:*" + + tar = tarfile.open(filename, mode, encoding="utf-8") # type: ignore + try: + leading = has_leading_dir([member.name for member in tar.getmembers()]) + + # PEP 706 added `tarfile.data_filter`, and made some other changes to + # Python's tarfile module (see below). The features were backported to + # security releases. + try: + data_filter = tarfile.data_filter + except AttributeError: + _untar_without_filter(filename, location, tar, leading) + else: + default_mode_plus_executable = _get_default_mode_plus_executable() + + if leading: + # Strip the leading directory from all files in the archive, + # including hardlink targets (which are relative to the + # unpack location). + for member in tar.getmembers(): + name_lead, name_rest = split_leading_dir(member.name) + member.name = name_rest + if member.islnk(): + lnk_lead, lnk_rest = split_leading_dir(member.linkname) + if lnk_lead == name_lead: + member.linkname = lnk_rest + + def pip_filter(member: tarfile.TarInfo, path: str) -> tarfile.TarInfo: + orig_mode = member.mode + try: + try: + member = data_filter(member, location) + except tarfile.LinkOutsideDestinationError: + if sys.version_info[:3] in { + (3, 8, 17), + (3, 9, 17), + (3, 10, 12), + (3, 11, 4), + }: + # The tarfile filter in specific Python versions + # raises LinkOutsideDestinationError on valid input + # (https://github.com/python/cpython/issues/107845) + # Ignore the error there, but do use the + # more lax `tar_filter` + member = tarfile.tar_filter(member, location) + else: + raise + except tarfile.TarError as exc: + message = "Invalid member in the tar file {}: {}" + # Filter error messages mention the member name. + # No need to add it here. + raise InstallationError( + message.format( + filename, + exc, + ) + ) + if member.isfile() and orig_mode & 0o111: + member.mode = default_mode_plus_executable + else: + # See PEP 706 note above. + # The PEP changed this from `int` to `Optional[int]`, + # where None means "use the default". Mypy doesn't + # know this yet. + member.mode = None # type: ignore [assignment] + return member + + tar.extractall(location, filter=pip_filter) + + finally: + tar.close() + + +def _untar_without_filter( + filename: str, + location: str, + tar: tarfile.TarFile, + leading: bool, +) -> None: + """Fallback for Python without tarfile.data_filter""" + for member in tar.getmembers(): + fn = member.name + if leading: + fn = split_leading_dir(fn)[1] + path = os.path.join(location, fn) + if not is_within_directory(location, path): + message = ( + "The tar file ({}) has a file ({}) trying to install " + "outside target directory ({})" + ) + raise InstallationError(message.format(filename, path, location)) + if member.isdir(): + ensure_dir(path) + elif member.issym(): + try: + tar._extract_member(member, path) + except Exception as exc: + # Some corrupt tar files seem to produce this + # (specifically bad symlinks) + logger.warning( + "In the tar file %s the member %s is invalid: %s", + filename, + member.name, + exc, + ) + continue + else: + try: + fp = tar.extractfile(member) + except (KeyError, AttributeError) as exc: + # Some corrupt tar files seem to produce this + # (specifically bad symlinks) + logger.warning( + "In the tar file %s the member %s is invalid: %s", + filename, + member.name, + exc, + ) + continue + ensure_dir(os.path.dirname(path)) + assert fp is not None + with open(path, "wb") as destfp: + shutil.copyfileobj(fp, destfp) + fp.close() + # Update the timestamp (useful for cython compiled files) + tar.utime(member, path) + # member have any execute permissions for user/group/world? + if member.mode & 0o111: + set_extracted_file_to_default_mode_plus_executable(path) + + +def unpack_file( + filename: str, + location: str, + content_type: Optional[str] = None, +) -> None: + filename = os.path.realpath(filename) + if ( + content_type == "application/zip" + or filename.lower().endswith(ZIP_EXTENSIONS) + or zipfile.is_zipfile(filename) + ): + unzip_file(filename, location, flatten=not filename.endswith(".whl")) + elif ( + content_type == "application/x-gzip" + or tarfile.is_tarfile(filename) + or filename.lower().endswith(TAR_EXTENSIONS + BZ2_EXTENSIONS + XZ_EXTENSIONS) + ): + untar_file(filename, location) + else: + # FIXME: handle? + # FIXME: magic signatures? + logger.critical( + "Cannot unpack file %s (downloaded from %s, content-type: %s); " + "cannot detect archive format", + filename, + location, + content_type, + ) + raise InstallationError(f"Cannot determine archive format of {location}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/urls.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/urls.py new file mode 100644 index 0000000000000000000000000000000000000000..9f34f882a1a6b7bf8e8ec5eb42c5d28f2c4e30aa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/urls.py @@ -0,0 +1,55 @@ +import os +import string +import urllib.parse +import urllib.request + +from .compat import WINDOWS + + +def path_to_url(path: str) -> str: + """ + Convert a path to a file: URL. The path will be made absolute and have + quoted path parts. + """ + path = os.path.normpath(os.path.abspath(path)) + url = urllib.parse.urljoin("file:", urllib.request.pathname2url(path)) + return url + + +def url_to_path(url: str) -> str: + """ + Convert a file: URL to a path. + """ + assert url.startswith( + "file:" + ), f"You can only turn file: urls into filenames (not {url!r})" + + _, netloc, path, _, _ = urllib.parse.urlsplit(url) + + if not netloc or netloc == "localhost": + # According to RFC 8089, same as empty authority. + netloc = "" + elif WINDOWS: + # If we have a UNC path, prepend UNC share notation. + netloc = "\\\\" + netloc + else: + raise ValueError( + f"non-local file URIs are not supported on this platform: {url!r}" + ) + + path = urllib.request.url2pathname(netloc + path) + + # On Windows, urlsplit parses the path as something like "/C:/Users/foo". + # This creates issues for path-related functions like io.open(), so we try + # to detect and strip the leading slash. + if ( + WINDOWS + and not netloc # Not UNC. + and len(path) >= 3 + and path[0] == "/" # Leading slash to strip. + and path[1] in string.ascii_letters # Drive letter. + and path[2:4] in (":", ":/") # Colon + end of string, or colon + absolute path. + ): + path = path[1:] + + return path diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/virtualenv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/virtualenv.py new file mode 100644 index 0000000000000000000000000000000000000000..882e36f5c1de19a8200000c216cf80119b37c96d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/virtualenv.py @@ -0,0 +1,104 @@ +import logging +import os +import re +import site +import sys +from typing import List, Optional + +logger = logging.getLogger(__name__) +_INCLUDE_SYSTEM_SITE_PACKAGES_REGEX = re.compile( + r"include-system-site-packages\s*=\s*(?Ptrue|false)" +) + + +def _running_under_venv() -> bool: + """Checks if sys.base_prefix and sys.prefix match. + + This handles PEP 405 compliant virtual environments. + """ + return sys.prefix != getattr(sys, "base_prefix", sys.prefix) + + +def _running_under_legacy_virtualenv() -> bool: + """Checks if sys.real_prefix is set. + + This handles virtual environments created with pypa's virtualenv. + """ + # pypa/virtualenv case + return hasattr(sys, "real_prefix") + + +def running_under_virtualenv() -> bool: + """True if we're running inside a virtual environment, False otherwise.""" + return _running_under_venv() or _running_under_legacy_virtualenv() + + +def _get_pyvenv_cfg_lines() -> Optional[List[str]]: + """Reads {sys.prefix}/pyvenv.cfg and returns its contents as list of lines + + Returns None, if it could not read/access the file. + """ + pyvenv_cfg_file = os.path.join(sys.prefix, "pyvenv.cfg") + try: + # Although PEP 405 does not specify, the built-in venv module always + # writes with UTF-8. (pypa/pip#8717) + with open(pyvenv_cfg_file, encoding="utf-8") as f: + return f.read().splitlines() # avoids trailing newlines + except OSError: + return None + + +def _no_global_under_venv() -> bool: + """Check `{sys.prefix}/pyvenv.cfg` for system site-packages inclusion + + PEP 405 specifies that when system site-packages are not supposed to be + visible from a virtual environment, `pyvenv.cfg` must contain the following + line: + + include-system-site-packages = false + + Additionally, log a warning if accessing the file fails. + """ + cfg_lines = _get_pyvenv_cfg_lines() + if cfg_lines is None: + # We're not in a "sane" venv, so assume there is no system + # site-packages access (since that's PEP 405's default state). + logger.warning( + "Could not access 'pyvenv.cfg' despite a virtual environment " + "being active. Assuming global site-packages is not accessible " + "in this environment." + ) + return True + + for line in cfg_lines: + match = _INCLUDE_SYSTEM_SITE_PACKAGES_REGEX.match(line) + if match is not None and match.group("value") == "false": + return True + return False + + +def _no_global_under_legacy_virtualenv() -> bool: + """Check if "no-global-site-packages.txt" exists beside site.py + + This mirrors logic in pypa/virtualenv for determining whether system + site-packages are visible in the virtual environment. + """ + site_mod_dir = os.path.dirname(os.path.abspath(site.__file__)) + no_global_site_packages_file = os.path.join( + site_mod_dir, + "no-global-site-packages.txt", + ) + return os.path.exists(no_global_site_packages_file) + + +def virtualenv_no_global() -> bool: + """Returns a boolean, whether running in venv with no system site-packages.""" + # PEP 405 compliance needs to be checked first since virtualenv >=20 would + # return True for both checks, but is only able to use the PEP 405 config. + if _running_under_venv(): + return _no_global_under_venv() + + if _running_under_legacy_virtualenv(): + return _no_global_under_legacy_virtualenv() + + return False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/wheel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/wheel.py new file mode 100644 index 0000000000000000000000000000000000000000..f85aee8a3f925ad831431de5251c4e9daa6877ea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/utils/wheel.py @@ -0,0 +1,134 @@ +"""Support functions for working with wheel files. +""" + +import logging +from email.message import Message +from email.parser import Parser +from typing import Tuple +from zipfile import BadZipFile, ZipFile + +from pip._vendor.packaging.utils import canonicalize_name + +from pip._internal.exceptions import UnsupportedWheel + +VERSION_COMPATIBLE = (1, 0) + + +logger = logging.getLogger(__name__) + + +def parse_wheel(wheel_zip: ZipFile, name: str) -> Tuple[str, Message]: + """Extract information from the provided wheel, ensuring it meets basic + standards. + + Returns the name of the .dist-info directory and the parsed WHEEL metadata. + """ + try: + info_dir = wheel_dist_info_dir(wheel_zip, name) + metadata = wheel_metadata(wheel_zip, info_dir) + version = wheel_version(metadata) + except UnsupportedWheel as e: + raise UnsupportedWheel(f"{name} has an invalid wheel, {e}") + + check_compatibility(version, name) + + return info_dir, metadata + + +def wheel_dist_info_dir(source: ZipFile, name: str) -> str: + """Returns the name of the contained .dist-info directory. + + Raises AssertionError or UnsupportedWheel if not found, >1 found, or + it doesn't match the provided name. + """ + # Zip file path separators must be / + subdirs = {p.split("/", 1)[0] for p in source.namelist()} + + info_dirs = [s for s in subdirs if s.endswith(".dist-info")] + + if not info_dirs: + raise UnsupportedWheel(".dist-info directory not found") + + if len(info_dirs) > 1: + raise UnsupportedWheel( + "multiple .dist-info directories found: {}".format(", ".join(info_dirs)) + ) + + info_dir = info_dirs[0] + + info_dir_name = canonicalize_name(info_dir) + canonical_name = canonicalize_name(name) + if not info_dir_name.startswith(canonical_name): + raise UnsupportedWheel( + f".dist-info directory {info_dir!r} does not start with {canonical_name!r}" + ) + + return info_dir + + +def read_wheel_metadata_file(source: ZipFile, path: str) -> bytes: + try: + return source.read(path) + # BadZipFile for general corruption, KeyError for missing entry, + # and RuntimeError for password-protected files + except (BadZipFile, KeyError, RuntimeError) as e: + raise UnsupportedWheel(f"could not read {path!r} file: {e!r}") + + +def wheel_metadata(source: ZipFile, dist_info_dir: str) -> Message: + """Return the WHEEL metadata of an extracted wheel, if possible. + Otherwise, raise UnsupportedWheel. + """ + path = f"{dist_info_dir}/WHEEL" + # Zip file path separators must be / + wheel_contents = read_wheel_metadata_file(source, path) + + try: + wheel_text = wheel_contents.decode() + except UnicodeDecodeError as e: + raise UnsupportedWheel(f"error decoding {path!r}: {e!r}") + + # FeedParser (used by Parser) does not raise any exceptions. The returned + # message may have .defects populated, but for backwards-compatibility we + # currently ignore them. + return Parser().parsestr(wheel_text) + + +def wheel_version(wheel_data: Message) -> Tuple[int, ...]: + """Given WHEEL metadata, return the parsed Wheel-Version. + Otherwise, raise UnsupportedWheel. + """ + version_text = wheel_data["Wheel-Version"] + if version_text is None: + raise UnsupportedWheel("WHEEL is missing Wheel-Version") + + version = version_text.strip() + + try: + return tuple(map(int, version.split("."))) + except ValueError: + raise UnsupportedWheel(f"invalid Wheel-Version: {version!r}") + + +def check_compatibility(version: Tuple[int, ...], name: str) -> None: + """Raises errors or warns if called with an incompatible Wheel-Version. + + pip should refuse to install a Wheel-Version that's a major series + ahead of what it's compatible with (e.g 2.0 > 1.1); and warn when + installing a version only minor version ahead (e.g 1.2 > 1.1). + + version: a 2-tuple representing a Wheel-Version (Major, Minor) + name: name of wheel or package to raise exception about + + :raises UnsupportedWheel: when an incompatible Wheel-Version is given + """ + if version[0] > VERSION_COMPATIBLE[0]: + raise UnsupportedWheel( + "{}'s Wheel-Version ({}) is not compatible with this version " + "of pip".format(name, ".".join(map(str, version))) + ) + elif version > VERSION_COMPATIBLE: + logger.warning( + "Installing from a newer Wheel-Version (%s)", + ".".join(map(str, version)), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6beddbe6d24d2949dc89ed07abfebd59d8b63b9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/__init__.py @@ -0,0 +1,15 @@ +# Expose a limited set of classes and functions so callers outside of +# the vcs package don't need to import deeper than `pip._internal.vcs`. +# (The test directory may still need to import from a vcs sub-package.) +# Import all vcs modules to register each VCS in the VcsSupport object. +import pip._internal.vcs.bazaar +import pip._internal.vcs.git +import pip._internal.vcs.mercurial +import pip._internal.vcs.subversion # noqa: F401 +from pip._internal.vcs.versioncontrol import ( # noqa: F401 + RemoteNotFoundError, + RemoteNotValidError, + is_url, + make_vcs_requirement_url, + vcs, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/bazaar.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/bazaar.py new file mode 100644 index 0000000000000000000000000000000000000000..c754b7cc5c0bb1c9473161f589f81e27a93286f9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/bazaar.py @@ -0,0 +1,112 @@ +import logging +from typing import List, Optional, Tuple + +from pip._internal.utils.misc import HiddenText, display_path +from pip._internal.utils.subprocess import make_command +from pip._internal.utils.urls import path_to_url +from pip._internal.vcs.versioncontrol import ( + AuthInfo, + RemoteNotFoundError, + RevOptions, + VersionControl, + vcs, +) + +logger = logging.getLogger(__name__) + + +class Bazaar(VersionControl): + name = "bzr" + dirname = ".bzr" + repo_name = "branch" + schemes = ( + "bzr+http", + "bzr+https", + "bzr+ssh", + "bzr+sftp", + "bzr+ftp", + "bzr+lp", + "bzr+file", + ) + + @staticmethod + def get_base_rev_args(rev: str) -> List[str]: + return ["-r", rev] + + def fetch_new( + self, dest: str, url: HiddenText, rev_options: RevOptions, verbosity: int + ) -> None: + rev_display = rev_options.to_display() + logger.info( + "Checking out %s%s to %s", + url, + rev_display, + display_path(dest), + ) + if verbosity <= 0: + flags = ["--quiet"] + elif verbosity == 1: + flags = [] + else: + flags = [f"-{'v'*verbosity}"] + cmd_args = make_command( + "checkout", "--lightweight", *flags, rev_options.to_args(), url, dest + ) + self.run_command(cmd_args) + + def switch(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + self.run_command(make_command("switch", url), cwd=dest) + + def update(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + output = self.run_command( + make_command("info"), show_stdout=False, stdout_only=True, cwd=dest + ) + if output.startswith("Standalone "): + # Older versions of pip used to create standalone branches. + # Convert the standalone branch to a checkout by calling "bzr bind". + cmd_args = make_command("bind", "-q", url) + self.run_command(cmd_args, cwd=dest) + + cmd_args = make_command("update", "-q", rev_options.to_args()) + self.run_command(cmd_args, cwd=dest) + + @classmethod + def get_url_rev_and_auth(cls, url: str) -> Tuple[str, Optional[str], AuthInfo]: + # hotfix the URL scheme after removing bzr+ from bzr+ssh:// re-add it + url, rev, user_pass = super().get_url_rev_and_auth(url) + if url.startswith("ssh://"): + url = "bzr+" + url + return url, rev, user_pass + + @classmethod + def get_remote_url(cls, location: str) -> str: + urls = cls.run_command( + ["info"], show_stdout=False, stdout_only=True, cwd=location + ) + for line in urls.splitlines(): + line = line.strip() + for x in ("checkout of branch: ", "parent branch: "): + if line.startswith(x): + repo = line.split(x)[1] + if cls._is_local_repository(repo): + return path_to_url(repo) + return repo + raise RemoteNotFoundError + + @classmethod + def get_revision(cls, location: str) -> str: + revision = cls.run_command( + ["revno"], + show_stdout=False, + stdout_only=True, + cwd=location, + ) + return revision.splitlines()[-1] + + @classmethod + def is_commit_id_equal(cls, dest: str, name: Optional[str]) -> bool: + """Always assume the versions don't match""" + return False + + +vcs.register(Bazaar) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/git.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/git.py new file mode 100644 index 0000000000000000000000000000000000000000..0425debb3ae768378812cbce068ca9fcceaada56 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/git.py @@ -0,0 +1,527 @@ +import logging +import os.path +import pathlib +import re +import urllib.parse +import urllib.request +from dataclasses import replace +from typing import List, Optional, Tuple + +from pip._internal.exceptions import BadCommand, InstallationError +from pip._internal.utils.misc import HiddenText, display_path, hide_url +from pip._internal.utils.subprocess import make_command +from pip._internal.vcs.versioncontrol import ( + AuthInfo, + RemoteNotFoundError, + RemoteNotValidError, + RevOptions, + VersionControl, + find_path_to_project_root_from_repo_root, + vcs, +) + +urlsplit = urllib.parse.urlsplit +urlunsplit = urllib.parse.urlunsplit + + +logger = logging.getLogger(__name__) + + +GIT_VERSION_REGEX = re.compile( + r"^git version " # Prefix. + r"(\d+)" # Major. + r"\.(\d+)" # Dot, minor. + r"(?:\.(\d+))?" # Optional dot, patch. + r".*$" # Suffix, including any pre- and post-release segments we don't care about. +) + +HASH_REGEX = re.compile("^[a-fA-F0-9]{40}$") + +# SCP (Secure copy protocol) shorthand. e.g. 'git@example.com:foo/bar.git' +SCP_REGEX = re.compile( + r"""^ + # Optional user, e.g. 'git@' + (\w+@)? + # Server, e.g. 'github.com'. + ([^/:]+): + # The server-side path. e.g. 'user/project.git'. Must start with an + # alphanumeric character so as not to be confusable with a Windows paths + # like 'C:/foo/bar' or 'C:\foo\bar'. + (\w[^:]*) + $""", + re.VERBOSE, +) + + +def looks_like_hash(sha: str) -> bool: + return bool(HASH_REGEX.match(sha)) + + +class Git(VersionControl): + name = "git" + dirname = ".git" + repo_name = "clone" + schemes = ( + "git+http", + "git+https", + "git+ssh", + "git+git", + "git+file", + ) + # Prevent the user's environment variables from interfering with pip: + # https://github.com/pypa/pip/issues/1130 + unset_environ = ("GIT_DIR", "GIT_WORK_TREE") + default_arg_rev = "HEAD" + + @staticmethod + def get_base_rev_args(rev: str) -> List[str]: + return [rev] + + def is_immutable_rev_checkout(self, url: str, dest: str) -> bool: + _, rev_options = self.get_url_rev_options(hide_url(url)) + if not rev_options.rev: + return False + if not self.is_commit_id_equal(dest, rev_options.rev): + # the current commit is different from rev, + # which means rev was something else than a commit hash + return False + # return False in the rare case rev is both a commit hash + # and a tag or a branch; we don't want to cache in that case + # because that branch/tag could point to something else in the future + is_tag_or_branch = bool(self.get_revision_sha(dest, rev_options.rev)[0]) + return not is_tag_or_branch + + def get_git_version(self) -> Tuple[int, ...]: + version = self.run_command( + ["version"], + command_desc="git version", + show_stdout=False, + stdout_only=True, + ) + match = GIT_VERSION_REGEX.match(version) + if not match: + logger.warning("Can't parse git version: %s", version) + return () + return (int(match.group(1)), int(match.group(2))) + + @classmethod + def get_current_branch(cls, location: str) -> Optional[str]: + """ + Return the current branch, or None if HEAD isn't at a branch + (e.g. detached HEAD). + """ + # git-symbolic-ref exits with empty stdout if "HEAD" is a detached + # HEAD rather than a symbolic ref. In addition, the -q causes the + # command to exit with status code 1 instead of 128 in this case + # and to suppress the message to stderr. + args = ["symbolic-ref", "-q", "HEAD"] + output = cls.run_command( + args, + extra_ok_returncodes=(1,), + show_stdout=False, + stdout_only=True, + cwd=location, + ) + ref = output.strip() + + if ref.startswith("refs/heads/"): + return ref[len("refs/heads/") :] + + return None + + @classmethod + def get_revision_sha(cls, dest: str, rev: str) -> Tuple[Optional[str], bool]: + """ + Return (sha_or_none, is_branch), where sha_or_none is a commit hash + if the revision names a remote branch or tag, otherwise None. + + Args: + dest: the repository directory. + rev: the revision name. + """ + # Pass rev to pre-filter the list. + output = cls.run_command( + ["show-ref", rev], + cwd=dest, + show_stdout=False, + stdout_only=True, + on_returncode="ignore", + ) + refs = {} + # NOTE: We do not use splitlines here since that would split on other + # unicode separators, which can be maliciously used to install a + # different revision. + for line in output.strip().split("\n"): + line = line.rstrip("\r") + if not line: + continue + try: + ref_sha, ref_name = line.split(" ", maxsplit=2) + except ValueError: + # Include the offending line to simplify troubleshooting if + # this error ever occurs. + raise ValueError(f"unexpected show-ref line: {line!r}") + + refs[ref_name] = ref_sha + + branch_ref = f"refs/remotes/origin/{rev}" + tag_ref = f"refs/tags/{rev}" + + sha = refs.get(branch_ref) + if sha is not None: + return (sha, True) + + sha = refs.get(tag_ref) + + return (sha, False) + + @classmethod + def _should_fetch(cls, dest: str, rev: str) -> bool: + """ + Return true if rev is a ref or is a commit that we don't have locally. + + Branches and tags are not considered in this method because they are + assumed to be always available locally (which is a normal outcome of + ``git clone`` and ``git fetch --tags``). + """ + if rev.startswith("refs/"): + # Always fetch remote refs. + return True + + if not looks_like_hash(rev): + # Git fetch would fail with abbreviated commits. + return False + + if cls.has_commit(dest, rev): + # Don't fetch if we have the commit locally. + return False + + return True + + @classmethod + def resolve_revision( + cls, dest: str, url: HiddenText, rev_options: RevOptions + ) -> RevOptions: + """ + Resolve a revision to a new RevOptions object with the SHA1 of the + branch, tag, or ref if found. + + Args: + rev_options: a RevOptions object. + """ + rev = rev_options.arg_rev + # The arg_rev property's implementation for Git ensures that the + # rev return value is always non-None. + assert rev is not None + + sha, is_branch = cls.get_revision_sha(dest, rev) + + if sha is not None: + rev_options = rev_options.make_new(sha) + rev_options = replace(rev_options, branch_name=(rev if is_branch else None)) + + return rev_options + + # Do not show a warning for the common case of something that has + # the form of a Git commit hash. + if not looks_like_hash(rev): + logger.warning( + "Did not find branch or tag '%s', assuming revision or ref.", + rev, + ) + + if not cls._should_fetch(dest, rev): + return rev_options + + # fetch the requested revision + cls.run_command( + make_command("fetch", "-q", url, rev_options.to_args()), + cwd=dest, + ) + # Change the revision to the SHA of the ref we fetched + sha = cls.get_revision(dest, rev="FETCH_HEAD") + rev_options = rev_options.make_new(sha) + + return rev_options + + @classmethod + def is_commit_id_equal(cls, dest: str, name: Optional[str]) -> bool: + """ + Return whether the current commit hash equals the given name. + + Args: + dest: the repository directory. + name: a string name. + """ + if not name: + # Then avoid an unnecessary subprocess call. + return False + + return cls.get_revision(dest) == name + + def fetch_new( + self, dest: str, url: HiddenText, rev_options: RevOptions, verbosity: int + ) -> None: + rev_display = rev_options.to_display() + logger.info("Cloning %s%s to %s", url, rev_display, display_path(dest)) + if verbosity <= 0: + flags: Tuple[str, ...] = ("--quiet",) + elif verbosity == 1: + flags = () + else: + flags = ("--verbose", "--progress") + if self.get_git_version() >= (2, 17): + # Git added support for partial clone in 2.17 + # https://git-scm.com/docs/partial-clone + # Speeds up cloning by functioning without a complete copy of repository + self.run_command( + make_command( + "clone", + "--filter=blob:none", + *flags, + url, + dest, + ) + ) + else: + self.run_command(make_command("clone", *flags, url, dest)) + + if rev_options.rev: + # Then a specific revision was requested. + rev_options = self.resolve_revision(dest, url, rev_options) + branch_name = getattr(rev_options, "branch_name", None) + logger.debug("Rev options %s, branch_name %s", rev_options, branch_name) + if branch_name is None: + # Only do a checkout if the current commit id doesn't match + # the requested revision. + if not self.is_commit_id_equal(dest, rev_options.rev): + cmd_args = make_command( + "checkout", + "-q", + rev_options.to_args(), + ) + self.run_command(cmd_args, cwd=dest) + elif self.get_current_branch(dest) != branch_name: + # Then a specific branch was requested, and that branch + # is not yet checked out. + track_branch = f"origin/{branch_name}" + cmd_args = [ + "checkout", + "-b", + branch_name, + "--track", + track_branch, + ] + self.run_command(cmd_args, cwd=dest) + else: + sha = self.get_revision(dest) + rev_options = rev_options.make_new(sha) + + logger.info("Resolved %s to commit %s", url, rev_options.rev) + + #: repo may contain submodules + self.update_submodules(dest) + + def switch(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + self.run_command( + make_command("config", "remote.origin.url", url), + cwd=dest, + ) + cmd_args = make_command("checkout", "-q", rev_options.to_args()) + self.run_command(cmd_args, cwd=dest) + + self.update_submodules(dest) + + def update(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + # First fetch changes from the default remote + if self.get_git_version() >= (1, 9): + # fetch tags in addition to everything else + self.run_command(["fetch", "-q", "--tags"], cwd=dest) + else: + self.run_command(["fetch", "-q"], cwd=dest) + # Then reset to wanted revision (maybe even origin/master) + rev_options = self.resolve_revision(dest, url, rev_options) + cmd_args = make_command("reset", "--hard", "-q", rev_options.to_args()) + self.run_command(cmd_args, cwd=dest) + #: update submodules + self.update_submodules(dest) + + @classmethod + def get_remote_url(cls, location: str) -> str: + """ + Return URL of the first remote encountered. + + Raises RemoteNotFoundError if the repository does not have a remote + url configured. + """ + # We need to pass 1 for extra_ok_returncodes since the command + # exits with return code 1 if there are no matching lines. + stdout = cls.run_command( + ["config", "--get-regexp", r"remote\..*\.url"], + extra_ok_returncodes=(1,), + show_stdout=False, + stdout_only=True, + cwd=location, + ) + remotes = stdout.splitlines() + try: + found_remote = remotes[0] + except IndexError: + raise RemoteNotFoundError + + for remote in remotes: + if remote.startswith("remote.origin.url "): + found_remote = remote + break + url = found_remote.split(" ")[1] + return cls._git_remote_to_pip_url(url.strip()) + + @staticmethod + def _git_remote_to_pip_url(url: str) -> str: + """ + Convert a remote url from what git uses to what pip accepts. + + There are 3 legal forms **url** may take: + + 1. A fully qualified url: ssh://git@example.com/foo/bar.git + 2. A local project.git folder: /path/to/bare/repository.git + 3. SCP shorthand for form 1: git@example.com:foo/bar.git + + Form 1 is output as-is. Form 2 must be converted to URI and form 3 must + be converted to form 1. + + See the corresponding test test_git_remote_url_to_pip() for examples of + sample inputs/outputs. + """ + if re.match(r"\w+://", url): + # This is already valid. Pass it though as-is. + return url + if os.path.exists(url): + # A local bare remote (git clone --mirror). + # Needs a file:// prefix. + return pathlib.PurePath(url).as_uri() + scp_match = SCP_REGEX.match(url) + if scp_match: + # Add an ssh:// prefix and replace the ':' with a '/'. + return scp_match.expand(r"ssh://\1\2/\3") + # Otherwise, bail out. + raise RemoteNotValidError(url) + + @classmethod + def has_commit(cls, location: str, rev: str) -> bool: + """ + Check if rev is a commit that is available in the local repository. + """ + try: + cls.run_command( + ["rev-parse", "-q", "--verify", "sha^" + rev], + cwd=location, + log_failed_cmd=False, + ) + except InstallationError: + return False + else: + return True + + @classmethod + def get_revision(cls, location: str, rev: Optional[str] = None) -> str: + if rev is None: + rev = "HEAD" + current_rev = cls.run_command( + ["rev-parse", rev], + show_stdout=False, + stdout_only=True, + cwd=location, + ) + return current_rev.strip() + + @classmethod + def get_subdirectory(cls, location: str) -> Optional[str]: + """ + Return the path to Python project root, relative to the repo root. + Return None if the project root is in the repo root. + """ + # find the repo root + git_dir = cls.run_command( + ["rev-parse", "--git-dir"], + show_stdout=False, + stdout_only=True, + cwd=location, + ).strip() + if not os.path.isabs(git_dir): + git_dir = os.path.join(location, git_dir) + repo_root = os.path.abspath(os.path.join(git_dir, "..")) + return find_path_to_project_root_from_repo_root(location, repo_root) + + @classmethod + def get_url_rev_and_auth(cls, url: str) -> Tuple[str, Optional[str], AuthInfo]: + """ + Prefixes stub URLs like 'user@hostname:user/repo.git' with 'ssh://'. + That's required because although they use SSH they sometimes don't + work with a ssh:// scheme (e.g. GitHub). But we need a scheme for + parsing. Hence we remove it again afterwards and return it as a stub. + """ + # Works around an apparent Git bug + # (see https://article.gmane.org/gmane.comp.version-control.git/146500) + scheme, netloc, path, query, fragment = urlsplit(url) + if scheme.endswith("file"): + initial_slashes = path[: -len(path.lstrip("/"))] + newpath = initial_slashes + urllib.request.url2pathname(path).replace( + "\\", "/" + ).lstrip("/") + after_plus = scheme.find("+") + 1 + url = scheme[:after_plus] + urlunsplit( + (scheme[after_plus:], netloc, newpath, query, fragment), + ) + + if "://" not in url: + assert "file:" not in url + url = url.replace("git+", "git+ssh://") + url, rev, user_pass = super().get_url_rev_and_auth(url) + url = url.replace("ssh://", "") + else: + url, rev, user_pass = super().get_url_rev_and_auth(url) + + return url, rev, user_pass + + @classmethod + def update_submodules(cls, location: str) -> None: + if not os.path.exists(os.path.join(location, ".gitmodules")): + return + cls.run_command( + ["submodule", "update", "--init", "--recursive", "-q"], + cwd=location, + ) + + @classmethod + def get_repository_root(cls, location: str) -> Optional[str]: + loc = super().get_repository_root(location) + if loc: + return loc + try: + r = cls.run_command( + ["rev-parse", "--show-toplevel"], + cwd=location, + show_stdout=False, + stdout_only=True, + on_returncode="raise", + log_failed_cmd=False, + ) + except BadCommand: + logger.debug( + "could not determine if %s is under git control " + "because git is not available", + location, + ) + return None + except InstallationError: + return None + return os.path.normpath(r.rstrip("\r\n")) + + @staticmethod + def should_add_vcs_url_prefix(repo_url: str) -> bool: + """In either https or ssh form, requirements must be prefixed with git+.""" + return True + + +vcs.register(Git) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/mercurial.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/mercurial.py new file mode 100644 index 0000000000000000000000000000000000000000..c183d41d09cf0752d99b74650427b2597f07d3b9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/mercurial.py @@ -0,0 +1,163 @@ +import configparser +import logging +import os +from typing import List, Optional, Tuple + +from pip._internal.exceptions import BadCommand, InstallationError +from pip._internal.utils.misc import HiddenText, display_path +from pip._internal.utils.subprocess import make_command +from pip._internal.utils.urls import path_to_url +from pip._internal.vcs.versioncontrol import ( + RevOptions, + VersionControl, + find_path_to_project_root_from_repo_root, + vcs, +) + +logger = logging.getLogger(__name__) + + +class Mercurial(VersionControl): + name = "hg" + dirname = ".hg" + repo_name = "clone" + schemes = ( + "hg+file", + "hg+http", + "hg+https", + "hg+ssh", + "hg+static-http", + ) + + @staticmethod + def get_base_rev_args(rev: str) -> List[str]: + return [f"--rev={rev}"] + + def fetch_new( + self, dest: str, url: HiddenText, rev_options: RevOptions, verbosity: int + ) -> None: + rev_display = rev_options.to_display() + logger.info( + "Cloning hg %s%s to %s", + url, + rev_display, + display_path(dest), + ) + if verbosity <= 0: + flags: Tuple[str, ...] = ("--quiet",) + elif verbosity == 1: + flags = () + elif verbosity == 2: + flags = ("--verbose",) + else: + flags = ("--verbose", "--debug") + self.run_command(make_command("clone", "--noupdate", *flags, url, dest)) + self.run_command( + make_command("update", *flags, rev_options.to_args()), + cwd=dest, + ) + + def switch(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + repo_config = os.path.join(dest, self.dirname, "hgrc") + config = configparser.RawConfigParser() + try: + config.read(repo_config) + config.set("paths", "default", url.secret) + with open(repo_config, "w") as config_file: + config.write(config_file) + except (OSError, configparser.NoSectionError) as exc: + logger.warning("Could not switch Mercurial repository to %s: %s", url, exc) + else: + cmd_args = make_command("update", "-q", rev_options.to_args()) + self.run_command(cmd_args, cwd=dest) + + def update(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + self.run_command(["pull", "-q"], cwd=dest) + cmd_args = make_command("update", "-q", rev_options.to_args()) + self.run_command(cmd_args, cwd=dest) + + @classmethod + def get_remote_url(cls, location: str) -> str: + url = cls.run_command( + ["showconfig", "paths.default"], + show_stdout=False, + stdout_only=True, + cwd=location, + ).strip() + if cls._is_local_repository(url): + url = path_to_url(url) + return url.strip() + + @classmethod + def get_revision(cls, location: str) -> str: + """ + Return the repository-local changeset revision number, as an integer. + """ + current_revision = cls.run_command( + ["parents", "--template={rev}"], + show_stdout=False, + stdout_only=True, + cwd=location, + ).strip() + return current_revision + + @classmethod + def get_requirement_revision(cls, location: str) -> str: + """ + Return the changeset identification hash, as a 40-character + hexadecimal string + """ + current_rev_hash = cls.run_command( + ["parents", "--template={node}"], + show_stdout=False, + stdout_only=True, + cwd=location, + ).strip() + return current_rev_hash + + @classmethod + def is_commit_id_equal(cls, dest: str, name: Optional[str]) -> bool: + """Always assume the versions don't match""" + return False + + @classmethod + def get_subdirectory(cls, location: str) -> Optional[str]: + """ + Return the path to Python project root, relative to the repo root. + Return None if the project root is in the repo root. + """ + # find the repo root + repo_root = cls.run_command( + ["root"], show_stdout=False, stdout_only=True, cwd=location + ).strip() + if not os.path.isabs(repo_root): + repo_root = os.path.abspath(os.path.join(location, repo_root)) + return find_path_to_project_root_from_repo_root(location, repo_root) + + @classmethod + def get_repository_root(cls, location: str) -> Optional[str]: + loc = super().get_repository_root(location) + if loc: + return loc + try: + r = cls.run_command( + ["root"], + cwd=location, + show_stdout=False, + stdout_only=True, + on_returncode="raise", + log_failed_cmd=False, + ) + except BadCommand: + logger.debug( + "could not determine if %s is under hg control " + "because hg is not available", + location, + ) + return None + except InstallationError: + return None + return os.path.normpath(r.rstrip("\r\n")) + + +vcs.register(Mercurial) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/subversion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/subversion.py new file mode 100644 index 0000000000000000000000000000000000000000..f359266d9c0879f55991377fa5f354bb01f17efe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/subversion.py @@ -0,0 +1,324 @@ +import logging +import os +import re +from typing import List, Optional, Tuple + +from pip._internal.utils.misc import ( + HiddenText, + display_path, + is_console_interactive, + is_installable_dir, + split_auth_from_netloc, +) +from pip._internal.utils.subprocess import CommandArgs, make_command +from pip._internal.vcs.versioncontrol import ( + AuthInfo, + RemoteNotFoundError, + RevOptions, + VersionControl, + vcs, +) + +logger = logging.getLogger(__name__) + +_svn_xml_url_re = re.compile('url="([^"]+)"') +_svn_rev_re = re.compile(r'committed-rev="(\d+)"') +_svn_info_xml_rev_re = re.compile(r'\s*revision="(\d+)"') +_svn_info_xml_url_re = re.compile(r"(.*)") + + +class Subversion(VersionControl): + name = "svn" + dirname = ".svn" + repo_name = "checkout" + schemes = ("svn+ssh", "svn+http", "svn+https", "svn+svn", "svn+file") + + @classmethod + def should_add_vcs_url_prefix(cls, remote_url: str) -> bool: + return True + + @staticmethod + def get_base_rev_args(rev: str) -> List[str]: + return ["-r", rev] + + @classmethod + def get_revision(cls, location: str) -> str: + """ + Return the maximum revision for all files under a given location + """ + # Note: taken from setuptools.command.egg_info + revision = 0 + + for base, dirs, _ in os.walk(location): + if cls.dirname not in dirs: + dirs[:] = [] + continue # no sense walking uncontrolled subdirs + dirs.remove(cls.dirname) + entries_fn = os.path.join(base, cls.dirname, "entries") + if not os.path.exists(entries_fn): + # FIXME: should we warn? + continue + + dirurl, localrev = cls._get_svn_url_rev(base) + + if base == location: + assert dirurl is not None + base = dirurl + "/" # save the root url + elif not dirurl or not dirurl.startswith(base): + dirs[:] = [] + continue # not part of the same svn tree, skip it + revision = max(revision, localrev) + return str(revision) + + @classmethod + def get_netloc_and_auth( + cls, netloc: str, scheme: str + ) -> Tuple[str, Tuple[Optional[str], Optional[str]]]: + """ + This override allows the auth information to be passed to svn via the + --username and --password options instead of via the URL. + """ + if scheme == "ssh": + # The --username and --password options can't be used for + # svn+ssh URLs, so keep the auth information in the URL. + return super().get_netloc_and_auth(netloc, scheme) + + return split_auth_from_netloc(netloc) + + @classmethod + def get_url_rev_and_auth(cls, url: str) -> Tuple[str, Optional[str], AuthInfo]: + # hotfix the URL scheme after removing svn+ from svn+ssh:// re-add it + url, rev, user_pass = super().get_url_rev_and_auth(url) + if url.startswith("ssh://"): + url = "svn+" + url + return url, rev, user_pass + + @staticmethod + def make_rev_args( + username: Optional[str], password: Optional[HiddenText] + ) -> CommandArgs: + extra_args: CommandArgs = [] + if username: + extra_args += ["--username", username] + if password: + extra_args += ["--password", password] + + return extra_args + + @classmethod + def get_remote_url(cls, location: str) -> str: + # In cases where the source is in a subdirectory, we have to look up in + # the location until we find a valid project root. + orig_location = location + while not is_installable_dir(location): + last_location = location + location = os.path.dirname(location) + if location == last_location: + # We've traversed up to the root of the filesystem without + # finding a Python project. + logger.warning( + "Could not find Python project for directory %s (tried all " + "parent directories)", + orig_location, + ) + raise RemoteNotFoundError + + url, _rev = cls._get_svn_url_rev(location) + if url is None: + raise RemoteNotFoundError + + return url + + @classmethod + def _get_svn_url_rev(cls, location: str) -> Tuple[Optional[str], int]: + from pip._internal.exceptions import InstallationError + + entries_path = os.path.join(location, cls.dirname, "entries") + if os.path.exists(entries_path): + with open(entries_path) as f: + data = f.read() + else: # subversion >= 1.7 does not have the 'entries' file + data = "" + + url = None + if data.startswith("8") or data.startswith("9") or data.startswith("10"): + entries = list(map(str.splitlines, data.split("\n\x0c\n"))) + del entries[0][0] # get rid of the '8' + url = entries[0][3] + revs = [int(d[9]) for d in entries if len(d) > 9 and d[9]] + [0] + elif data.startswith("= 1.7 + # Note that using get_remote_call_options is not necessary here + # because `svn info` is being run against a local directory. + # We don't need to worry about making sure interactive mode + # is being used to prompt for passwords, because passwords + # are only potentially needed for remote server requests. + xml = cls.run_command( + ["info", "--xml", location], + show_stdout=False, + stdout_only=True, + ) + match = _svn_info_xml_url_re.search(xml) + assert match is not None + url = match.group(1) + revs = [int(m.group(1)) for m in _svn_info_xml_rev_re.finditer(xml)] + except InstallationError: + url, revs = None, [] + + if revs: + rev = max(revs) + else: + rev = 0 + + return url, rev + + @classmethod + def is_commit_id_equal(cls, dest: str, name: Optional[str]) -> bool: + """Always assume the versions don't match""" + return False + + def __init__(self, use_interactive: Optional[bool] = None) -> None: + if use_interactive is None: + use_interactive = is_console_interactive() + self.use_interactive = use_interactive + + # This member is used to cache the fetched version of the current + # ``svn`` client. + # Special value definitions: + # None: Not evaluated yet. + # Empty tuple: Could not parse version. + self._vcs_version: Optional[Tuple[int, ...]] = None + + super().__init__() + + def call_vcs_version(self) -> Tuple[int, ...]: + """Query the version of the currently installed Subversion client. + + :return: A tuple containing the parts of the version information or + ``()`` if the version returned from ``svn`` could not be parsed. + :raises: BadCommand: If ``svn`` is not installed. + """ + # Example versions: + # svn, version 1.10.3 (r1842928) + # compiled Feb 25 2019, 14:20:39 on x86_64-apple-darwin17.0.0 + # svn, version 1.7.14 (r1542130) + # compiled Mar 28 2018, 08:49:13 on x86_64-pc-linux-gnu + # svn, version 1.12.0-SlikSvn (SlikSvn/1.12.0) + # compiled May 28 2019, 13:44:56 on x86_64-microsoft-windows6.2 + version_prefix = "svn, version " + version = self.run_command(["--version"], show_stdout=False, stdout_only=True) + if not version.startswith(version_prefix): + return () + + version = version[len(version_prefix) :].split()[0] + version_list = version.partition("-")[0].split(".") + try: + parsed_version = tuple(map(int, version_list)) + except ValueError: + return () + + return parsed_version + + def get_vcs_version(self) -> Tuple[int, ...]: + """Return the version of the currently installed Subversion client. + + If the version of the Subversion client has already been queried, + a cached value will be used. + + :return: A tuple containing the parts of the version information or + ``()`` if the version returned from ``svn`` could not be parsed. + :raises: BadCommand: If ``svn`` is not installed. + """ + if self._vcs_version is not None: + # Use cached version, if available. + # If parsing the version failed previously (empty tuple), + # do not attempt to parse it again. + return self._vcs_version + + vcs_version = self.call_vcs_version() + self._vcs_version = vcs_version + return vcs_version + + def get_remote_call_options(self) -> CommandArgs: + """Return options to be used on calls to Subversion that contact the server. + + These options are applicable for the following ``svn`` subcommands used + in this class. + + - checkout + - switch + - update + + :return: A list of command line arguments to pass to ``svn``. + """ + if not self.use_interactive: + # --non-interactive switch is available since Subversion 0.14.4. + # Subversion < 1.8 runs in interactive mode by default. + return ["--non-interactive"] + + svn_version = self.get_vcs_version() + # By default, Subversion >= 1.8 runs in non-interactive mode if + # stdin is not a TTY. Since that is how pip invokes SVN, in + # call_subprocess(), pip must pass --force-interactive to ensure + # the user can be prompted for a password, if required. + # SVN added the --force-interactive option in SVN 1.8. Since + # e.g. RHEL/CentOS 7, which is supported until 2024, ships with + # SVN 1.7, pip should continue to support SVN 1.7. Therefore, pip + # can't safely add the option if the SVN version is < 1.8 (or unknown). + if svn_version >= (1, 8): + return ["--force-interactive"] + + return [] + + def fetch_new( + self, dest: str, url: HiddenText, rev_options: RevOptions, verbosity: int + ) -> None: + rev_display = rev_options.to_display() + logger.info( + "Checking out %s%s to %s", + url, + rev_display, + display_path(dest), + ) + if verbosity <= 0: + flags = ["--quiet"] + else: + flags = [] + cmd_args = make_command( + "checkout", + *flags, + self.get_remote_call_options(), + rev_options.to_args(), + url, + dest, + ) + self.run_command(cmd_args) + + def switch(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + cmd_args = make_command( + "switch", + self.get_remote_call_options(), + rev_options.to_args(), + url, + dest, + ) + self.run_command(cmd_args) + + def update(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + cmd_args = make_command( + "update", + self.get_remote_call_options(), + rev_options.to_args(), + dest, + ) + self.run_command(cmd_args) + + +vcs.register(Subversion) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/versioncontrol.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/versioncontrol.py new file mode 100644 index 0000000000000000000000000000000000000000..a4133165e9ae2464e2eb84175e77dcb904394545 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/vcs/versioncontrol.py @@ -0,0 +1,688 @@ +"""Handles all VCS (version control) support""" + +import logging +import os +import shutil +import sys +import urllib.parse +from dataclasses import dataclass, field +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, +) + +from pip._internal.cli.spinners import SpinnerInterface +from pip._internal.exceptions import BadCommand, InstallationError +from pip._internal.utils.misc import ( + HiddenText, + ask_path_exists, + backup_dir, + display_path, + hide_url, + hide_value, + is_installable_dir, + rmtree, +) +from pip._internal.utils.subprocess import ( + CommandArgs, + call_subprocess, + format_command_args, + make_command, +) + +__all__ = ["vcs"] + + +logger = logging.getLogger(__name__) + +AuthInfo = Tuple[Optional[str], Optional[str]] + + +def is_url(name: str) -> bool: + """ + Return true if the name looks like a URL. + """ + scheme = urllib.parse.urlsplit(name).scheme + if not scheme: + return False + return scheme in ["http", "https", "file", "ftp"] + vcs.all_schemes + + +def make_vcs_requirement_url( + repo_url: str, rev: str, project_name: str, subdir: Optional[str] = None +) -> str: + """ + Return the URL for a VCS requirement. + + Args: + repo_url: the remote VCS url, with any needed VCS prefix (e.g. "git+"). + project_name: the (unescaped) project name. + """ + egg_project_name = project_name.replace("-", "_") + req = f"{repo_url}@{rev}#egg={egg_project_name}" + if subdir: + req += f"&subdirectory={subdir}" + + return req + + +def find_path_to_project_root_from_repo_root( + location: str, repo_root: str +) -> Optional[str]: + """ + Find the the Python project's root by searching up the filesystem from + `location`. Return the path to project root relative to `repo_root`. + Return None if the project root is `repo_root`, or cannot be found. + """ + # find project root. + orig_location = location + while not is_installable_dir(location): + last_location = location + location = os.path.dirname(location) + if location == last_location: + # We've traversed up to the root of the filesystem without + # finding a Python project. + logger.warning( + "Could not find a Python project for directory %s (tried all " + "parent directories)", + orig_location, + ) + return None + + if os.path.samefile(repo_root, location): + return None + + return os.path.relpath(location, repo_root) + + +class RemoteNotFoundError(Exception): + pass + + +class RemoteNotValidError(Exception): + def __init__(self, url: str): + super().__init__(url) + self.url = url + + +@dataclass(frozen=True) +class RevOptions: + """ + Encapsulates a VCS-specific revision to install, along with any VCS + install options. + + Args: + vc_class: a VersionControl subclass. + rev: the name of the revision to install. + extra_args: a list of extra options. + """ + + vc_class: Type["VersionControl"] + rev: Optional[str] = None + extra_args: CommandArgs = field(default_factory=list) + branch_name: Optional[str] = None + + def __repr__(self) -> str: + return f"" + + @property + def arg_rev(self) -> Optional[str]: + if self.rev is None: + return self.vc_class.default_arg_rev + + return self.rev + + def to_args(self) -> CommandArgs: + """ + Return the VCS-specific command arguments. + """ + args: CommandArgs = [] + rev = self.arg_rev + if rev is not None: + args += self.vc_class.get_base_rev_args(rev) + args += self.extra_args + + return args + + def to_display(self) -> str: + if not self.rev: + return "" + + return f" (to revision {self.rev})" + + def make_new(self, rev: str) -> "RevOptions": + """ + Make a copy of the current instance, but with a new rev. + + Args: + rev: the name of the revision for the new object. + """ + return self.vc_class.make_rev_options(rev, extra_args=self.extra_args) + + +class VcsSupport: + _registry: Dict[str, "VersionControl"] = {} + schemes = ["ssh", "git", "hg", "bzr", "sftp", "svn"] + + def __init__(self) -> None: + # Register more schemes with urlparse for various version control + # systems + urllib.parse.uses_netloc.extend(self.schemes) + super().__init__() + + def __iter__(self) -> Iterator[str]: + return self._registry.__iter__() + + @property + def backends(self) -> List["VersionControl"]: + return list(self._registry.values()) + + @property + def dirnames(self) -> List[str]: + return [backend.dirname for backend in self.backends] + + @property + def all_schemes(self) -> List[str]: + schemes: List[str] = [] + for backend in self.backends: + schemes.extend(backend.schemes) + return schemes + + def register(self, cls: Type["VersionControl"]) -> None: + if not hasattr(cls, "name"): + logger.warning("Cannot register VCS %s", cls.__name__) + return + if cls.name not in self._registry: + self._registry[cls.name] = cls() + logger.debug("Registered VCS backend: %s", cls.name) + + def unregister(self, name: str) -> None: + if name in self._registry: + del self._registry[name] + + def get_backend_for_dir(self, location: str) -> Optional["VersionControl"]: + """ + Return a VersionControl object if a repository of that type is found + at the given directory. + """ + vcs_backends = {} + for vcs_backend in self._registry.values(): + repo_path = vcs_backend.get_repository_root(location) + if not repo_path: + continue + logger.debug("Determine that %s uses VCS: %s", location, vcs_backend.name) + vcs_backends[repo_path] = vcs_backend + + if not vcs_backends: + return None + + # Choose the VCS in the inner-most directory. Since all repository + # roots found here would be either `location` or one of its + # parents, the longest path should have the most path components, + # i.e. the backend representing the inner-most repository. + inner_most_repo_path = max(vcs_backends, key=len) + return vcs_backends[inner_most_repo_path] + + def get_backend_for_scheme(self, scheme: str) -> Optional["VersionControl"]: + """ + Return a VersionControl object or None. + """ + for vcs_backend in self._registry.values(): + if scheme in vcs_backend.schemes: + return vcs_backend + return None + + def get_backend(self, name: str) -> Optional["VersionControl"]: + """ + Return a VersionControl object or None. + """ + name = name.lower() + return self._registry.get(name) + + +vcs = VcsSupport() + + +class VersionControl: + name = "" + dirname = "" + repo_name = "" + # List of supported schemes for this Version Control + schemes: Tuple[str, ...] = () + # Iterable of environment variable names to pass to call_subprocess(). + unset_environ: Tuple[str, ...] = () + default_arg_rev: Optional[str] = None + + @classmethod + def should_add_vcs_url_prefix(cls, remote_url: str) -> bool: + """ + Return whether the vcs prefix (e.g. "git+") should be added to a + repository's remote url when used in a requirement. + """ + return not remote_url.lower().startswith(f"{cls.name}:") + + @classmethod + def get_subdirectory(cls, location: str) -> Optional[str]: + """ + Return the path to Python project root, relative to the repo root. + Return None if the project root is in the repo root. + """ + return None + + @classmethod + def get_requirement_revision(cls, repo_dir: str) -> str: + """ + Return the revision string that should be used in a requirement. + """ + return cls.get_revision(repo_dir) + + @classmethod + def get_src_requirement(cls, repo_dir: str, project_name: str) -> str: + """ + Return the requirement string to use to redownload the files + currently at the given repository directory. + + Args: + project_name: the (unescaped) project name. + + The return value has a form similar to the following: + + {repository_url}@{revision}#egg={project_name} + """ + repo_url = cls.get_remote_url(repo_dir) + + if cls.should_add_vcs_url_prefix(repo_url): + repo_url = f"{cls.name}+{repo_url}" + + revision = cls.get_requirement_revision(repo_dir) + subdir = cls.get_subdirectory(repo_dir) + req = make_vcs_requirement_url(repo_url, revision, project_name, subdir=subdir) + + return req + + @staticmethod + def get_base_rev_args(rev: str) -> List[str]: + """ + Return the base revision arguments for a vcs command. + + Args: + rev: the name of a revision to install. Cannot be None. + """ + raise NotImplementedError + + def is_immutable_rev_checkout(self, url: str, dest: str) -> bool: + """ + Return true if the commit hash checked out at dest matches + the revision in url. + + Always return False, if the VCS does not support immutable commit + hashes. + + This method does not check if there are local uncommitted changes + in dest after checkout, as pip currently has no use case for that. + """ + return False + + @classmethod + def make_rev_options( + cls, rev: Optional[str] = None, extra_args: Optional[CommandArgs] = None + ) -> RevOptions: + """ + Return a RevOptions object. + + Args: + rev: the name of a revision to install. + extra_args: a list of extra options. + """ + return RevOptions(cls, rev, extra_args=extra_args or []) + + @classmethod + def _is_local_repository(cls, repo: str) -> bool: + """ + posix absolute paths start with os.path.sep, + win32 ones start with drive (like c:\\folder) + """ + drive, tail = os.path.splitdrive(repo) + return repo.startswith(os.path.sep) or bool(drive) + + @classmethod + def get_netloc_and_auth( + cls, netloc: str, scheme: str + ) -> Tuple[str, Tuple[Optional[str], Optional[str]]]: + """ + Parse the repository URL's netloc, and return the new netloc to use + along with auth information. + + Args: + netloc: the original repository URL netloc. + scheme: the repository URL's scheme without the vcs prefix. + + This is mainly for the Subversion class to override, so that auth + information can be provided via the --username and --password options + instead of through the URL. For other subclasses like Git without + such an option, auth information must stay in the URL. + + Returns: (netloc, (username, password)). + """ + return netloc, (None, None) + + @classmethod + def get_url_rev_and_auth(cls, url: str) -> Tuple[str, Optional[str], AuthInfo]: + """ + Parse the repository URL to use, and return the URL, revision, + and auth info to use. + + Returns: (url, rev, (username, password)). + """ + scheme, netloc, path, query, frag = urllib.parse.urlsplit(url) + if "+" not in scheme: + raise ValueError( + f"Sorry, {url!r} is a malformed VCS url. " + "The format is +://, " + "e.g. svn+http://myrepo/svn/MyApp#egg=MyApp" + ) + # Remove the vcs prefix. + scheme = scheme.split("+", 1)[1] + netloc, user_pass = cls.get_netloc_and_auth(netloc, scheme) + rev = None + if "@" in path: + path, rev = path.rsplit("@", 1) + if not rev: + raise InstallationError( + f"The URL {url!r} has an empty revision (after @) " + "which is not supported. Include a revision after @ " + "or remove @ from the URL." + ) + url = urllib.parse.urlunsplit((scheme, netloc, path, query, "")) + return url, rev, user_pass + + @staticmethod + def make_rev_args( + username: Optional[str], password: Optional[HiddenText] + ) -> CommandArgs: + """ + Return the RevOptions "extra arguments" to use in obtain(). + """ + return [] + + def get_url_rev_options(self, url: HiddenText) -> Tuple[HiddenText, RevOptions]: + """ + Return the URL and RevOptions object to use in obtain(), + as a tuple (url, rev_options). + """ + secret_url, rev, user_pass = self.get_url_rev_and_auth(url.secret) + username, secret_password = user_pass + password: Optional[HiddenText] = None + if secret_password is not None: + password = hide_value(secret_password) + extra_args = self.make_rev_args(username, password) + rev_options = self.make_rev_options(rev, extra_args=extra_args) + + return hide_url(secret_url), rev_options + + @staticmethod + def normalize_url(url: str) -> str: + """ + Normalize a URL for comparison by unquoting it and removing any + trailing slash. + """ + return urllib.parse.unquote(url).rstrip("/") + + @classmethod + def compare_urls(cls, url1: str, url2: str) -> bool: + """ + Compare two repo URLs for identity, ignoring incidental differences. + """ + return cls.normalize_url(url1) == cls.normalize_url(url2) + + def fetch_new( + self, dest: str, url: HiddenText, rev_options: RevOptions, verbosity: int + ) -> None: + """ + Fetch a revision from a repository, in the case that this is the + first fetch from the repository. + + Args: + dest: the directory to fetch the repository to. + rev_options: a RevOptions object. + verbosity: verbosity level. + """ + raise NotImplementedError + + def switch(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + """ + Switch the repo at ``dest`` to point to ``URL``. + + Args: + rev_options: a RevOptions object. + """ + raise NotImplementedError + + def update(self, dest: str, url: HiddenText, rev_options: RevOptions) -> None: + """ + Update an already-existing repo to the given ``rev_options``. + + Args: + rev_options: a RevOptions object. + """ + raise NotImplementedError + + @classmethod + def is_commit_id_equal(cls, dest: str, name: Optional[str]) -> bool: + """ + Return whether the id of the current commit equals the given name. + + Args: + dest: the repository directory. + name: a string name. + """ + raise NotImplementedError + + def obtain(self, dest: str, url: HiddenText, verbosity: int) -> None: + """ + Install or update in editable mode the package represented by this + VersionControl object. + + :param dest: the repository directory in which to install or update. + :param url: the repository URL starting with a vcs prefix. + :param verbosity: verbosity level. + """ + url, rev_options = self.get_url_rev_options(url) + + if not os.path.exists(dest): + self.fetch_new(dest, url, rev_options, verbosity=verbosity) + return + + rev_display = rev_options.to_display() + if self.is_repository_directory(dest): + existing_url = self.get_remote_url(dest) + if self.compare_urls(existing_url, url.secret): + logger.debug( + "%s in %s exists, and has correct URL (%s)", + self.repo_name.title(), + display_path(dest), + url, + ) + if not self.is_commit_id_equal(dest, rev_options.rev): + logger.info( + "Updating %s %s%s", + display_path(dest), + self.repo_name, + rev_display, + ) + self.update(dest, url, rev_options) + else: + logger.info("Skipping because already up-to-date.") + return + + logger.warning( + "%s %s in %s exists with URL %s", + self.name, + self.repo_name, + display_path(dest), + existing_url, + ) + prompt = ("(s)witch, (i)gnore, (w)ipe, (b)ackup ", ("s", "i", "w", "b")) + else: + logger.warning( + "Directory %s already exists, and is not a %s %s.", + dest, + self.name, + self.repo_name, + ) + # https://github.com/python/mypy/issues/1174 + prompt = ("(i)gnore, (w)ipe, (b)ackup ", ("i", "w", "b")) # type: ignore + + logger.warning( + "The plan is to install the %s repository %s", + self.name, + url, + ) + response = ask_path_exists(f"What to do? {prompt[0]}", prompt[1]) + + if response == "a": + sys.exit(-1) + + if response == "w": + logger.warning("Deleting %s", display_path(dest)) + rmtree(dest) + self.fetch_new(dest, url, rev_options, verbosity=verbosity) + return + + if response == "b": + dest_dir = backup_dir(dest) + logger.warning("Backing up %s to %s", display_path(dest), dest_dir) + shutil.move(dest, dest_dir) + self.fetch_new(dest, url, rev_options, verbosity=verbosity) + return + + # Do nothing if the response is "i". + if response == "s": + logger.info( + "Switching %s %s to %s%s", + self.repo_name, + display_path(dest), + url, + rev_display, + ) + self.switch(dest, url, rev_options) + + def unpack(self, location: str, url: HiddenText, verbosity: int) -> None: + """ + Clean up current location and download the url repository + (and vcs infos) into location + + :param url: the repository URL starting with a vcs prefix. + :param verbosity: verbosity level. + """ + if os.path.exists(location): + rmtree(location) + self.obtain(location, url=url, verbosity=verbosity) + + @classmethod + def get_remote_url(cls, location: str) -> str: + """ + Return the url used at location + + Raises RemoteNotFoundError if the repository does not have a remote + url configured. + """ + raise NotImplementedError + + @classmethod + def get_revision(cls, location: str) -> str: + """ + Return the current commit id of the files at the given location. + """ + raise NotImplementedError + + @classmethod + def run_command( + cls, + cmd: Union[List[str], CommandArgs], + show_stdout: bool = True, + cwd: Optional[str] = None, + on_returncode: 'Literal["raise", "warn", "ignore"]' = "raise", + extra_ok_returncodes: Optional[Iterable[int]] = None, + command_desc: Optional[str] = None, + extra_environ: Optional[Mapping[str, Any]] = None, + spinner: Optional[SpinnerInterface] = None, + log_failed_cmd: bool = True, + stdout_only: bool = False, + ) -> str: + """ + Run a VCS subcommand + This is simply a wrapper around call_subprocess that adds the VCS + command name, and checks that the VCS is available + """ + cmd = make_command(cls.name, *cmd) + if command_desc is None: + command_desc = format_command_args(cmd) + try: + return call_subprocess( + cmd, + show_stdout, + cwd, + on_returncode=on_returncode, + extra_ok_returncodes=extra_ok_returncodes, + command_desc=command_desc, + extra_environ=extra_environ, + unset_environ=cls.unset_environ, + spinner=spinner, + log_failed_cmd=log_failed_cmd, + stdout_only=stdout_only, + ) + except NotADirectoryError: + raise BadCommand(f"Cannot find command {cls.name!r} - invalid PATH") + except FileNotFoundError: + # errno.ENOENT = no such file or directory + # In other words, the VCS executable isn't available + raise BadCommand( + f"Cannot find command {cls.name!r} - do you have " + f"{cls.name!r} installed and in your PATH?" + ) + except PermissionError: + # errno.EACCES = Permission denied + # This error occurs, for instance, when the command is installed + # only for another user. So, the current user don't have + # permission to call the other user command. + raise BadCommand( + f"No permission to execute {cls.name!r} - install it " + f"locally, globally (ask admin), or check your PATH. " + f"See possible solutions at " + f"https://pip.pypa.io/en/latest/reference/pip_freeze/" + f"#fixing-permission-denied." + ) + + @classmethod + def is_repository_directory(cls, path: str) -> bool: + """ + Return whether a directory path is a repository directory. + """ + logger.debug("Checking in %s for %s (%s)...", path, cls.dirname, cls.name) + return os.path.exists(os.path.join(path, cls.dirname)) + + @classmethod + def get_repository_root(cls, location: str) -> Optional[str]: + """ + Return the "root" (top-level) directory controlled by the vcs, + or `None` if the directory is not in any. + + It is meant to be overridden to implement smarter detection + mechanisms for specific vcs. + + This can do more than is_repository_directory() alone. For + example, the Git override checks that Git is actually available. + """ + if cls.is_repository_directory(location): + return location + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/wheel_builder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/wheel_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..93f8e1f5b2f61257c09c58ea6cecde8ecc778e07 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_internal/wheel_builder.py @@ -0,0 +1,354 @@ +"""Orchestrator for building wheels from InstallRequirements. +""" + +import logging +import os.path +import re +import shutil +from typing import Iterable, List, Optional, Tuple + +from pip._vendor.packaging.utils import canonicalize_name, canonicalize_version +from pip._vendor.packaging.version import InvalidVersion, Version + +from pip._internal.cache import WheelCache +from pip._internal.exceptions import InvalidWheelFilename, UnsupportedWheel +from pip._internal.metadata import FilesystemWheel, get_wheel_distribution +from pip._internal.models.link import Link +from pip._internal.models.wheel import Wheel +from pip._internal.operations.build.wheel import build_wheel_pep517 +from pip._internal.operations.build.wheel_editable import build_wheel_editable +from pip._internal.operations.build.wheel_legacy import build_wheel_legacy +from pip._internal.req.req_install import InstallRequirement +from pip._internal.utils.logging import indent_log +from pip._internal.utils.misc import ensure_dir, hash_file +from pip._internal.utils.setuptools_build import make_setuptools_clean_args +from pip._internal.utils.subprocess import call_subprocess +from pip._internal.utils.temp_dir import TempDirectory +from pip._internal.utils.urls import path_to_url +from pip._internal.vcs import vcs + +logger = logging.getLogger(__name__) + +_egg_info_re = re.compile(r"([a-z0-9_.]+)-([a-z0-9_.!+-]+)", re.IGNORECASE) + +BuildResult = Tuple[List[InstallRequirement], List[InstallRequirement]] + + +def _contains_egg_info(s: str) -> bool: + """Determine whether the string looks like an egg_info. + + :param s: The string to parse. E.g. foo-2.1 + """ + return bool(_egg_info_re.search(s)) + + +def _should_build( + req: InstallRequirement, + need_wheel: bool, +) -> bool: + """Return whether an InstallRequirement should be built into a wheel.""" + if req.constraint: + # never build requirements that are merely constraints + return False + if req.is_wheel: + if need_wheel: + logger.info( + "Skipping %s, due to already being wheel.", + req.name, + ) + return False + + if need_wheel: + # i.e. pip wheel, not pip install + return True + + # From this point, this concerns the pip install command only + # (need_wheel=False). + + if not req.source_dir: + return False + + if req.editable: + # we only build PEP 660 editable requirements + return req.supports_pyproject_editable + + return True + + +def should_build_for_wheel_command( + req: InstallRequirement, +) -> bool: + return _should_build(req, need_wheel=True) + + +def should_build_for_install_command( + req: InstallRequirement, +) -> bool: + return _should_build(req, need_wheel=False) + + +def _should_cache( + req: InstallRequirement, +) -> Optional[bool]: + """ + Return whether a built InstallRequirement can be stored in the persistent + wheel cache, assuming the wheel cache is available, and _should_build() + has determined a wheel needs to be built. + """ + if req.editable or not req.source_dir: + # never cache editable requirements + return False + + if req.link and req.link.is_vcs: + # VCS checkout. Do not cache + # unless it points to an immutable commit hash. + assert not req.editable + assert req.source_dir + vcs_backend = vcs.get_backend_for_scheme(req.link.scheme) + assert vcs_backend + if vcs_backend.is_immutable_rev_checkout(req.link.url, req.source_dir): + return True + return False + + assert req.link + base, ext = req.link.splitext() + if _contains_egg_info(base): + return True + + # Otherwise, do not cache. + return False + + +def _get_cache_dir( + req: InstallRequirement, + wheel_cache: WheelCache, +) -> str: + """Return the persistent or temporary cache directory where the built + wheel need to be stored. + """ + cache_available = bool(wheel_cache.cache_dir) + assert req.link + if cache_available and _should_cache(req): + cache_dir = wheel_cache.get_path_for_link(req.link) + else: + cache_dir = wheel_cache.get_ephem_path_for_link(req.link) + return cache_dir + + +def _verify_one(req: InstallRequirement, wheel_path: str) -> None: + canonical_name = canonicalize_name(req.name or "") + w = Wheel(os.path.basename(wheel_path)) + if canonicalize_name(w.name) != canonical_name: + raise InvalidWheelFilename( + f"Wheel has unexpected file name: expected {canonical_name!r}, " + f"got {w.name!r}", + ) + dist = get_wheel_distribution(FilesystemWheel(wheel_path), canonical_name) + dist_verstr = str(dist.version) + if canonicalize_version(dist_verstr) != canonicalize_version(w.version): + raise InvalidWheelFilename( + f"Wheel has unexpected file name: expected {dist_verstr!r}, " + f"got {w.version!r}", + ) + metadata_version_value = dist.metadata_version + if metadata_version_value is None: + raise UnsupportedWheel("Missing Metadata-Version") + try: + metadata_version = Version(metadata_version_value) + except InvalidVersion: + msg = f"Invalid Metadata-Version: {metadata_version_value}" + raise UnsupportedWheel(msg) + if metadata_version >= Version("1.2") and not isinstance(dist.version, Version): + raise UnsupportedWheel( + f"Metadata 1.2 mandates PEP 440 version, but {dist_verstr!r} is not" + ) + + +def _build_one( + req: InstallRequirement, + output_dir: str, + verify: bool, + build_options: List[str], + global_options: List[str], + editable: bool, +) -> Optional[str]: + """Build one wheel. + + :return: The filename of the built wheel, or None if the build failed. + """ + artifact = "editable" if editable else "wheel" + try: + ensure_dir(output_dir) + except OSError as e: + logger.warning( + "Building %s for %s failed: %s", + artifact, + req.name, + e, + ) + return None + + # Install build deps into temporary directory (PEP 518) + with req.build_env: + wheel_path = _build_one_inside_env( + req, output_dir, build_options, global_options, editable + ) + if wheel_path and verify: + try: + _verify_one(req, wheel_path) + except (InvalidWheelFilename, UnsupportedWheel) as e: + logger.warning("Built %s for %s is invalid: %s", artifact, req.name, e) + return None + return wheel_path + + +def _build_one_inside_env( + req: InstallRequirement, + output_dir: str, + build_options: List[str], + global_options: List[str], + editable: bool, +) -> Optional[str]: + with TempDirectory(kind="wheel") as temp_dir: + assert req.name + if req.use_pep517: + assert req.metadata_directory + assert req.pep517_backend + if global_options: + logger.warning( + "Ignoring --global-option when building %s using PEP 517", req.name + ) + if build_options: + logger.warning( + "Ignoring --build-option when building %s using PEP 517", req.name + ) + if editable: + wheel_path = build_wheel_editable( + name=req.name, + backend=req.pep517_backend, + metadata_directory=req.metadata_directory, + tempd=temp_dir.path, + ) + else: + wheel_path = build_wheel_pep517( + name=req.name, + backend=req.pep517_backend, + metadata_directory=req.metadata_directory, + tempd=temp_dir.path, + ) + else: + wheel_path = build_wheel_legacy( + name=req.name, + setup_py_path=req.setup_py_path, + source_dir=req.unpacked_source_directory, + global_options=global_options, + build_options=build_options, + tempd=temp_dir.path, + ) + + if wheel_path is not None: + wheel_name = os.path.basename(wheel_path) + dest_path = os.path.join(output_dir, wheel_name) + try: + wheel_hash, length = hash_file(wheel_path) + shutil.move(wheel_path, dest_path) + logger.info( + "Created wheel for %s: filename=%s size=%d sha256=%s", + req.name, + wheel_name, + length, + wheel_hash.hexdigest(), + ) + logger.info("Stored in directory: %s", output_dir) + return dest_path + except Exception as e: + logger.warning( + "Building wheel for %s failed: %s", + req.name, + e, + ) + # Ignore return, we can't do anything else useful. + if not req.use_pep517: + _clean_one_legacy(req, global_options) + return None + + +def _clean_one_legacy(req: InstallRequirement, global_options: List[str]) -> bool: + clean_args = make_setuptools_clean_args( + req.setup_py_path, + global_options=global_options, + ) + + logger.info("Running setup.py clean for %s", req.name) + try: + call_subprocess( + clean_args, command_desc="python setup.py clean", cwd=req.source_dir + ) + return True + except Exception: + logger.error("Failed cleaning build dir for %s", req.name) + return False + + +def build( + requirements: Iterable[InstallRequirement], + wheel_cache: WheelCache, + verify: bool, + build_options: List[str], + global_options: List[str], +) -> BuildResult: + """Build wheels. + + :return: The list of InstallRequirement that succeeded to build and + the list of InstallRequirement that failed to build. + """ + if not requirements: + return [], [] + + # Build the wheels. + logger.info( + "Building wheels for collected packages: %s", + ", ".join(req.name for req in requirements), # type: ignore + ) + + with indent_log(): + build_successes, build_failures = [], [] + for req in requirements: + assert req.name + cache_dir = _get_cache_dir(req, wheel_cache) + wheel_file = _build_one( + req, + cache_dir, + verify, + build_options, + global_options, + req.editable and req.permit_editable_wheels, + ) + if wheel_file: + # Record the download origin in the cache + if req.download_info is not None: + # download_info is guaranteed to be set because when we build an + # InstallRequirement it has been through the preparer before, but + # let's be cautious. + wheel_cache.record_download_origin(cache_dir, req.download_info) + # Update the link for this. + req.link = Link(path_to_url(wheel_file)) + req.local_file_path = req.link.file_path + assert req.link.is_wheel + build_successes.append(req) + else: + build_failures.append(req) + + # notify success/failure + if build_successes: + logger.info( + "Successfully built %s", + " ".join([req.name for req in build_successes]), # type: ignore + ) + if build_failures: + logger.info( + "Failed to build %s", + " ".join([req.name for req in build_failures]), # type: ignore + ) + # Return a list of requirements that failed to build + return build_successes, build_failures diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..561089ccc0c65454bbb02d20e1c94e012a561920 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/__init__.py @@ -0,0 +1,116 @@ +""" +pip._vendor is for vendoring dependencies of pip to prevent needing pip to +depend on something external. + +Files inside of pip._vendor should be considered immutable and should only be +updated to versions from upstream. +""" +from __future__ import absolute_import + +import glob +import os.path +import sys + +# Downstream redistributors which have debundled our dependencies should also +# patch this value to be true. This will trigger the additional patching +# to cause things like "six" to be available as pip. +DEBUNDLED = False + +# By default, look in this directory for a bunch of .whl files which we will +# add to the beginning of sys.path before attempting to import anything. This +# is done to support downstream re-distributors like Debian and Fedora who +# wish to create their own Wheels for our dependencies to aid in debundling. +WHEEL_DIR = os.path.abspath(os.path.dirname(__file__)) + + +# Define a small helper function to alias our vendored modules to the real ones +# if the vendored ones do not exist. This idea of this was taken from +# https://github.com/kennethreitz/requests/pull/2567. +def vendored(modulename): + vendored_name = "{0}.{1}".format(__name__, modulename) + + try: + __import__(modulename, globals(), locals(), level=0) + except ImportError: + # We can just silently allow import failures to pass here. If we + # got to this point it means that ``import pip._vendor.whatever`` + # failed and so did ``import whatever``. Since we're importing this + # upfront in an attempt to alias imports, not erroring here will + # just mean we get a regular import error whenever pip *actually* + # tries to import one of these modules to use it, which actually + # gives us a better error message than we would have otherwise + # gotten. + pass + else: + sys.modules[vendored_name] = sys.modules[modulename] + base, head = vendored_name.rsplit(".", 1) + setattr(sys.modules[base], head, sys.modules[modulename]) + + +# If we're operating in a debundled setup, then we want to go ahead and trigger +# the aliasing of our vendored libraries as well as looking for wheels to add +# to our sys.path. This will cause all of this code to be a no-op typically +# however downstream redistributors can enable it in a consistent way across +# all platforms. +if DEBUNDLED: + # Actually look inside of WHEEL_DIR to find .whl files and add them to the + # front of our sys.path. + sys.path[:] = glob.glob(os.path.join(WHEEL_DIR, "*.whl")) + sys.path + + # Actually alias all of our vendored dependencies. + vendored("cachecontrol") + vendored("certifi") + vendored("distlib") + vendored("distro") + vendored("packaging") + vendored("packaging.version") + vendored("packaging.specifiers") + vendored("pkg_resources") + vendored("platformdirs") + vendored("progress") + vendored("pyproject_hooks") + vendored("requests") + vendored("requests.exceptions") + vendored("requests.packages") + vendored("requests.packages.urllib3") + vendored("requests.packages.urllib3._collections") + vendored("requests.packages.urllib3.connection") + vendored("requests.packages.urllib3.connectionpool") + vendored("requests.packages.urllib3.contrib") + vendored("requests.packages.urllib3.contrib.ntlmpool") + vendored("requests.packages.urllib3.contrib.pyopenssl") + vendored("requests.packages.urllib3.exceptions") + vendored("requests.packages.urllib3.fields") + vendored("requests.packages.urllib3.filepost") + vendored("requests.packages.urllib3.packages") + vendored("requests.packages.urllib3.packages.ordered_dict") + vendored("requests.packages.urllib3.packages.six") + vendored("requests.packages.urllib3.packages.ssl_match_hostname") + vendored("requests.packages.urllib3.packages.ssl_match_hostname." + "_implementation") + vendored("requests.packages.urllib3.poolmanager") + vendored("requests.packages.urllib3.request") + vendored("requests.packages.urllib3.response") + vendored("requests.packages.urllib3.util") + vendored("requests.packages.urllib3.util.connection") + vendored("requests.packages.urllib3.util.request") + vendored("requests.packages.urllib3.util.response") + vendored("requests.packages.urllib3.util.retry") + vendored("requests.packages.urllib3.util.ssl_") + vendored("requests.packages.urllib3.util.timeout") + vendored("requests.packages.urllib3.util.url") + vendored("resolvelib") + vendored("rich") + vendored("rich.console") + vendored("rich.highlighter") + vendored("rich.logging") + vendored("rich.markup") + vendored("rich.progress") + vendored("rich.segment") + vendored("rich.style") + vendored("rich.text") + vendored("rich.traceback") + if sys.version_info < (3, 11): + vendored("tomli") + vendored("truststore") + vendored("urllib3") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/typing_extensions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/typing_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..e429384e76aa9a27c3168ffc998d187ebb84ab93 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/typing_extensions.py @@ -0,0 +1,3641 @@ +import abc +import collections +import collections.abc +import contextlib +import functools +import inspect +import operator +import sys +import types as _types +import typing +import warnings + +__all__ = [ + # Super-special typing primitives. + 'Any', + 'ClassVar', + 'Concatenate', + 'Final', + 'LiteralString', + 'ParamSpec', + 'ParamSpecArgs', + 'ParamSpecKwargs', + 'Self', + 'Type', + 'TypeVar', + 'TypeVarTuple', + 'Unpack', + + # ABCs (from collections.abc). + 'Awaitable', + 'AsyncIterator', + 'AsyncIterable', + 'Coroutine', + 'AsyncGenerator', + 'AsyncContextManager', + 'Buffer', + 'ChainMap', + + # Concrete collection types. + 'ContextManager', + 'Counter', + 'Deque', + 'DefaultDict', + 'NamedTuple', + 'OrderedDict', + 'TypedDict', + + # Structural checks, a.k.a. protocols. + 'SupportsAbs', + 'SupportsBytes', + 'SupportsComplex', + 'SupportsFloat', + 'SupportsIndex', + 'SupportsInt', + 'SupportsRound', + + # One-off things. + 'Annotated', + 'assert_never', + 'assert_type', + 'clear_overloads', + 'dataclass_transform', + 'deprecated', + 'Doc', + 'get_overloads', + 'final', + 'get_args', + 'get_origin', + 'get_original_bases', + 'get_protocol_members', + 'get_type_hints', + 'IntVar', + 'is_protocol', + 'is_typeddict', + 'Literal', + 'NewType', + 'overload', + 'override', + 'Protocol', + 'reveal_type', + 'runtime', + 'runtime_checkable', + 'Text', + 'TypeAlias', + 'TypeAliasType', + 'TypeGuard', + 'TypeIs', + 'TYPE_CHECKING', + 'Never', + 'NoReturn', + 'ReadOnly', + 'Required', + 'NotRequired', + + # Pure aliases, have always been in typing + 'AbstractSet', + 'AnyStr', + 'BinaryIO', + 'Callable', + 'Collection', + 'Container', + 'Dict', + 'ForwardRef', + 'FrozenSet', + 'Generator', + 'Generic', + 'Hashable', + 'IO', + 'ItemsView', + 'Iterable', + 'Iterator', + 'KeysView', + 'List', + 'Mapping', + 'MappingView', + 'Match', + 'MutableMapping', + 'MutableSequence', + 'MutableSet', + 'NoDefault', + 'Optional', + 'Pattern', + 'Reversible', + 'Sequence', + 'Set', + 'Sized', + 'TextIO', + 'Tuple', + 'Union', + 'ValuesView', + 'cast', + 'no_type_check', + 'no_type_check_decorator', +] + +# for backward compatibility +PEP_560 = True +GenericMeta = type +_PEP_696_IMPLEMENTED = sys.version_info >= (3, 13, 0, "beta") + +# The functions below are modified copies of typing internal helpers. +# They are needed by _ProtocolMeta and they provide support for PEP 646. + + +class _Sentinel: + def __repr__(self): + return "" + + +_marker = _Sentinel() + + +if sys.version_info >= (3, 10): + def _should_collect_from_parameters(t): + return isinstance( + t, (typing._GenericAlias, _types.GenericAlias, _types.UnionType) + ) +elif sys.version_info >= (3, 9): + def _should_collect_from_parameters(t): + return isinstance(t, (typing._GenericAlias, _types.GenericAlias)) +else: + def _should_collect_from_parameters(t): + return isinstance(t, typing._GenericAlias) and not t._special + + +NoReturn = typing.NoReturn + +# Some unconstrained type variables. These are used by the container types. +# (These are not for export.) +T = typing.TypeVar('T') # Any type. +KT = typing.TypeVar('KT') # Key type. +VT = typing.TypeVar('VT') # Value type. +T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers. +T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant. + + +if sys.version_info >= (3, 11): + from typing import Any +else: + + class _AnyMeta(type): + def __instancecheck__(self, obj): + if self is Any: + raise TypeError("typing_extensions.Any cannot be used with isinstance()") + return super().__instancecheck__(obj) + + def __repr__(self): + if self is Any: + return "typing_extensions.Any" + return super().__repr__() + + class Any(metaclass=_AnyMeta): + """Special type indicating an unconstrained type. + - Any is compatible with every type. + - Any assumed to have all methods. + - All values assumed to be instances of Any. + Note that all the above statements are true from the point of view of + static type checkers. At runtime, Any should not be used with instance + checks. + """ + def __new__(cls, *args, **kwargs): + if cls is Any: + raise TypeError("Any cannot be instantiated") + return super().__new__(cls, *args, **kwargs) + + +ClassVar = typing.ClassVar + + +class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + +Final = typing.Final + +if sys.version_info >= (3, 11): + final = typing.final +else: + # @final exists in 3.8+, but we backport it for all versions + # before 3.11 to keep support for the __final__ attribute. + # See https://bugs.python.org/issue46342 + def final(f): + """This decorator can be used to indicate to type checkers that + the decorated method cannot be overridden, and decorated class + cannot be subclassed. For example: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker + ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... + + There is no runtime checking of these properties. The decorator + sets the ``__final__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return f + + +def IntVar(name): + return typing.TypeVar(name) + + +# A Literal bug was fixed in 3.11.0, 3.10.1 and 3.9.8 +if sys.version_info >= (3, 10, 1): + Literal = typing.Literal +else: + def _flatten_literal_params(parameters): + """An internal helper for Literal creation: flatten Literals among parameters""" + params = [] + for p in parameters: + if isinstance(p, _LiteralGenericAlias): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) + + def _value_and_type_iter(params): + for p in params: + yield p, type(p) + + class _LiteralGenericAlias(typing._GenericAlias, _root=True): + def __eq__(self, other): + if not isinstance(other, _LiteralGenericAlias): + return NotImplemented + these_args_deduped = set(_value_and_type_iter(self.__args__)) + other_args_deduped = set(_value_and_type_iter(other.__args__)) + return these_args_deduped == other_args_deduped + + def __hash__(self): + return hash(frozenset(_value_and_type_iter(self.__args__))) + + class _LiteralForm(_ExtensionsSpecialForm, _root=True): + def __init__(self, doc: str): + self._name = 'Literal' + self._doc = self.__doc__ = doc + + def __getitem__(self, parameters): + if not isinstance(parameters, tuple): + parameters = (parameters,) + + parameters = _flatten_literal_params(parameters) + + val_type_pairs = list(_value_and_type_iter(parameters)) + try: + deduped_pairs = set(val_type_pairs) + except TypeError: + # unhashable parameters + pass + else: + # similar logic to typing._deduplicate on Python 3.9+ + if len(deduped_pairs) < len(val_type_pairs): + new_parameters = [] + for pair in val_type_pairs: + if pair in deduped_pairs: + new_parameters.append(pair[0]) + deduped_pairs.remove(pair) + assert not deduped_pairs, deduped_pairs + parameters = tuple(new_parameters) + + return _LiteralGenericAlias(self, parameters) + + Literal = _LiteralForm(doc="""\ + A type that can be used to indicate to type checkers + that the corresponding value has a value literally equivalent + to the provided parameter. For example: + + var: Literal[4] = 4 + + The type checker understands that 'var' is literally equal to + the value 4 and no other value. + + Literal[...] cannot be subclassed. There is no runtime + checking verifying that the parameter is actually a value + instead of a type.""") + + +_overload_dummy = typing._overload_dummy + + +if hasattr(typing, "get_overloads"): # 3.11+ + overload = typing.overload + get_overloads = typing.get_overloads + clear_overloads = typing.clear_overloads +else: + # {module: {qualname: {firstlineno: func}}} + _overload_registry = collections.defaultdict( + functools.partial(collections.defaultdict, dict) + ) + + def overload(func): + """Decorator for overloaded functions/methods. + + In a stub file, place two or more stub definitions for the same + function in a row, each decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + + In a non-stub file (i.e. a regular .py file), do the same but + follow it with an implementation. The implementation should *not* + be decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + def utf8(value): + # implementation goes here + + The overloads for a function can be retrieved at runtime using the + get_overloads() function. + """ + # classmethod and staticmethod + f = getattr(func, "__func__", func) + try: + _overload_registry[f.__module__][f.__qualname__][ + f.__code__.co_firstlineno + ] = func + except AttributeError: + # Not a normal function; ignore. + pass + return _overload_dummy + + def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) + + def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() + + +# This is not a real generic class. Don't use outside annotations. +Type = typing.Type + +# Various ABCs mimicking those in collections.abc. +# A few are simply re-exported for completeness. +Awaitable = typing.Awaitable +Coroutine = typing.Coroutine +AsyncIterable = typing.AsyncIterable +AsyncIterator = typing.AsyncIterator +Deque = typing.Deque +DefaultDict = typing.DefaultDict +OrderedDict = typing.OrderedDict +Counter = typing.Counter +ChainMap = typing.ChainMap +Text = typing.Text +TYPE_CHECKING = typing.TYPE_CHECKING + + +if sys.version_info >= (3, 13, 0, "beta"): + from typing import AsyncContextManager, AsyncGenerator, ContextManager, Generator +else: + def _is_dunder(attr): + return attr.startswith('__') and attr.endswith('__') + + # Python <3.9 doesn't have typing._SpecialGenericAlias + _special_generic_alias_base = getattr( + typing, "_SpecialGenericAlias", typing._GenericAlias + ) + + class _SpecialGenericAlias(_special_generic_alias_base, _root=True): + def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()): + if _special_generic_alias_base is typing._GenericAlias: + # Python <3.9 + self.__origin__ = origin + self._nparams = nparams + super().__init__(origin, nparams, special=True, inst=inst, name=name) + else: + # Python >= 3.9 + super().__init__(origin, nparams, inst=inst, name=name) + self._defaults = defaults + + def __setattr__(self, attr, val): + allowed_attrs = {'_name', '_inst', '_nparams', '_defaults'} + if _special_generic_alias_base is typing._GenericAlias: + # Python <3.9 + allowed_attrs.add("__origin__") + if _is_dunder(attr) or attr in allowed_attrs: + object.__setattr__(self, attr, val) + else: + setattr(self.__origin__, attr, val) + + @typing._tp_cache + def __getitem__(self, params): + if not isinstance(params, tuple): + params = (params,) + msg = "Parameters to generic types must be types." + params = tuple(typing._type_check(p, msg) for p in params) + if ( + self._defaults + and len(params) < self._nparams + and len(params) + len(self._defaults) >= self._nparams + ): + params = (*params, *self._defaults[len(params) - self._nparams:]) + actual_len = len(params) + + if actual_len != self._nparams: + if self._defaults: + expected = f"at least {self._nparams - len(self._defaults)}" + else: + expected = str(self._nparams) + if not self._nparams: + raise TypeError(f"{self} is not a generic class") + raise TypeError( + f"Too {'many' if actual_len > self._nparams else 'few'}" + f" arguments for {self};" + f" actual {actual_len}, expected {expected}" + ) + return self.copy_with(params) + + _NoneType = type(None) + Generator = _SpecialGenericAlias( + collections.abc.Generator, 3, defaults=(_NoneType, _NoneType) + ) + AsyncGenerator = _SpecialGenericAlias( + collections.abc.AsyncGenerator, 2, defaults=(_NoneType,) + ) + ContextManager = _SpecialGenericAlias( + contextlib.AbstractContextManager, + 2, + name="ContextManager", + defaults=(typing.Optional[bool],) + ) + AsyncContextManager = _SpecialGenericAlias( + contextlib.AbstractAsyncContextManager, + 2, + name="AsyncContextManager", + defaults=(typing.Optional[bool],) + ) + + +_PROTO_ALLOWLIST = { + 'collections.abc': [ + 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', 'Buffer', + ], + 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'], + 'typing_extensions': ['Buffer'], +} + + +_EXCLUDED_ATTRS = frozenset(typing.EXCLUDED_ATTRIBUTES) | { + "__match_args__", "__protocol_attrs__", "__non_callable_proto_members__", + "__final__", +} + + +def _get_protocol_attrs(cls): + attrs = set() + for base in cls.__mro__[:-1]: # without object + if base.__name__ in {'Protocol', 'Generic'}: + continue + annotations = getattr(base, '__annotations__', {}) + for attr in (*base.__dict__, *annotations): + if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS): + attrs.add(attr) + return attrs + + +def _caller(depth=2): + try: + return sys._getframe(depth).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): # For platforms without _getframe() + return None + + +# `__match_args__` attribute was removed from protocol members in 3.13, +# we want to backport this change to older Python versions. +if sys.version_info >= (3, 13): + Protocol = typing.Protocol +else: + def _allow_reckless_class_checks(depth=3): + """Allow instance and class checks for special stdlib modules. + The abc and functools modules indiscriminately call isinstance() and + issubclass() on the whole MRO of a user class, which may contain protocols. + """ + return _caller(depth) in {'abc', 'functools', None} + + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + + def _type_check_issubclass_arg_1(arg): + """Raise TypeError if `arg` is not an instance of `type` + in `issubclass(arg, )`. + + In most cases, this is verified by type.__subclasscheck__. + Checking it again unnecessarily would slow down issubclass() checks, + so, we don't perform this check unless we absolutely have to. + + For various error paths, however, + we want to ensure that *this* error message is shown to the user + where relevant, rather than a typing.py-specific error message. + """ + if not isinstance(arg, type): + # Same error message as for issubclass(1, int). + raise TypeError('issubclass() arg 1 must be a class') + + # Inheriting from typing._ProtocolMeta isn't actually desirable, + # but is necessary to allow typing.Protocol and typing_extensions.Protocol + # to mix without getting TypeErrors about "metaclass conflict" + class _ProtocolMeta(type(typing.Protocol)): + # This metaclass is somewhat unfortunate, + # but is necessary for several reasons... + # + # NOTE: DO NOT call super() in any methods in this class + # That would call the methods on typing._ProtocolMeta on Python 3.8-3.11 + # and those are slow + def __new__(mcls, name, bases, namespace, **kwargs): + if name == "Protocol" and len(bases) < 2: + pass + elif {Protocol, typing.Protocol} & set(bases): + for base in bases: + if not ( + base in {object, typing.Generic, Protocol, typing.Protocol} + or base.__name__ in _PROTO_ALLOWLIST.get(base.__module__, []) + or is_protocol(base) + ): + raise TypeError( + f"Protocols can only inherit from other protocols, " + f"got {base!r}" + ) + return abc.ABCMeta.__new__(mcls, name, bases, namespace, **kwargs) + + def __init__(cls, *args, **kwargs): + abc.ABCMeta.__init__(cls, *args, **kwargs) + if getattr(cls, "_is_protocol", False): + cls.__protocol_attrs__ = _get_protocol_attrs(cls) + + def __subclasscheck__(cls, other): + if cls is Protocol: + return type.__subclasscheck__(cls, other) + if ( + getattr(cls, '_is_protocol', False) + and not _allow_reckless_class_checks() + ): + if not getattr(cls, '_is_runtime_protocol', False): + _type_check_issubclass_arg_1(other) + raise TypeError( + "Instance and class checks can only be used with " + "@runtime_checkable protocols" + ) + if ( + # this attribute is set by @runtime_checkable: + cls.__non_callable_proto_members__ + and cls.__dict__.get("__subclasshook__") is _proto_hook + ): + _type_check_issubclass_arg_1(other) + non_method_attrs = sorted(cls.__non_callable_proto_members__) + raise TypeError( + "Protocols with non-method members don't support issubclass()." + f" Non-method members: {str(non_method_attrs)[1:-1]}." + ) + return abc.ABCMeta.__subclasscheck__(cls, other) + + def __instancecheck__(cls, instance): + # We need this method for situations where attributes are + # assigned in __init__. + if cls is Protocol: + return type.__instancecheck__(cls, instance) + if not getattr(cls, "_is_protocol", False): + # i.e., it's a concrete subclass of a protocol + return abc.ABCMeta.__instancecheck__(cls, instance) + + if ( + not getattr(cls, '_is_runtime_protocol', False) and + not _allow_reckless_class_checks() + ): + raise TypeError("Instance and class checks can only be used with" + " @runtime_checkable protocols") + + if abc.ABCMeta.__instancecheck__(cls, instance): + return True + + for attr in cls.__protocol_attrs__: + try: + val = inspect.getattr_static(instance, attr) + except AttributeError: + break + # this attribute is set by @runtime_checkable: + if val is None and attr not in cls.__non_callable_proto_members__: + break + else: + return True + + return False + + def __eq__(cls, other): + # Hack so that typing.Generic.__class_getitem__ + # treats typing_extensions.Protocol + # as equivalent to typing.Protocol + if abc.ABCMeta.__eq__(cls, other) is True: + return True + return cls is Protocol and other is typing.Protocol + + # This has to be defined, or the abc-module cache + # complains about classes with this metaclass being unhashable, + # if we define only __eq__! + def __hash__(cls) -> int: + return type.__hash__(cls) + + @classmethod + def _proto_hook(cls, other): + if not cls.__dict__.get('_is_protocol', False): + return NotImplemented + + for attr in cls.__protocol_attrs__: + for base in other.__mro__: + # Check if the members appears in the class dictionary... + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + + # ...or in annotations, if it is a sub-protocol. + annotations = getattr(base, '__annotations__', {}) + if ( + isinstance(annotations, collections.abc.Mapping) + and attr in annotations + and is_protocol(other) + ): + break + else: + return NotImplemented + return True + + class Protocol(typing.Generic, metaclass=_ProtocolMeta): + __doc__ = typing.Protocol.__doc__ + __slots__ = () + _is_protocol = True + _is_runtime_protocol = False + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + + # Determine if this is a protocol or a concrete subclass. + if not cls.__dict__.get('_is_protocol', False): + cls._is_protocol = any(b is Protocol for b in cls.__bases__) + + # Set (or override) the protocol subclass hook. + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + # Prohibit instantiation for protocol classes + if cls._is_protocol and cls.__init__ is Protocol.__init__: + cls.__init__ = _no_init + + +if sys.version_info >= (3, 13): + runtime_checkable = typing.runtime_checkable +else: + def runtime_checkable(cls): + """Mark a protocol class as a runtime protocol. + + Such protocol can be used with isinstance() and issubclass(). + Raise TypeError if applied to a non-protocol class. + This allows a simple-minded structural check very similar to + one trick ponies in collections.abc such as Iterable. + + For example:: + + @runtime_checkable + class Closable(Protocol): + def close(self): ... + + assert isinstance(open('/some/file'), Closable) + + Warning: this will check only the presence of the required methods, + not their type signatures! + """ + if not issubclass(cls, typing.Generic) or not getattr(cls, '_is_protocol', False): + raise TypeError(f'@runtime_checkable can be only applied to protocol classes,' + f' got {cls!r}') + cls._is_runtime_protocol = True + + # typing.Protocol classes on <=3.11 break if we execute this block, + # because typing.Protocol classes on <=3.11 don't have a + # `__protocol_attrs__` attribute, and this block relies on the + # `__protocol_attrs__` attribute. Meanwhile, typing.Protocol classes on 3.12.2+ + # break if we *don't* execute this block, because *they* assume that all + # protocol classes have a `__non_callable_proto_members__` attribute + # (which this block sets) + if isinstance(cls, _ProtocolMeta) or sys.version_info >= (3, 12, 2): + # PEP 544 prohibits using issubclass() + # with protocols that have non-method members. + # See gh-113320 for why we compute this attribute here, + # rather than in `_ProtocolMeta.__init__` + cls.__non_callable_proto_members__ = set() + for attr in cls.__protocol_attrs__: + try: + is_callable = callable(getattr(cls, attr, None)) + except Exception as e: + raise TypeError( + f"Failed to determine whether protocol member {attr!r} " + "is a method member" + ) from e + else: + if not is_callable: + cls.__non_callable_proto_members__.add(attr) + + return cls + + +# The "runtime" alias exists for backwards compatibility. +runtime = runtime_checkable + + +# Our version of runtime-checkable protocols is faster on Python 3.8-3.11 +if sys.version_info >= (3, 12): + SupportsInt = typing.SupportsInt + SupportsFloat = typing.SupportsFloat + SupportsComplex = typing.SupportsComplex + SupportsBytes = typing.SupportsBytes + SupportsIndex = typing.SupportsIndex + SupportsAbs = typing.SupportsAbs + SupportsRound = typing.SupportsRound +else: + @runtime_checkable + class SupportsInt(Protocol): + """An ABC with one abstract method __int__.""" + __slots__ = () + + @abc.abstractmethod + def __int__(self) -> int: + pass + + @runtime_checkable + class SupportsFloat(Protocol): + """An ABC with one abstract method __float__.""" + __slots__ = () + + @abc.abstractmethod + def __float__(self) -> float: + pass + + @runtime_checkable + class SupportsComplex(Protocol): + """An ABC with one abstract method __complex__.""" + __slots__ = () + + @abc.abstractmethod + def __complex__(self) -> complex: + pass + + @runtime_checkable + class SupportsBytes(Protocol): + """An ABC with one abstract method __bytes__.""" + __slots__ = () + + @abc.abstractmethod + def __bytes__(self) -> bytes: + pass + + @runtime_checkable + class SupportsIndex(Protocol): + __slots__ = () + + @abc.abstractmethod + def __index__(self) -> int: + pass + + @runtime_checkable + class SupportsAbs(Protocol[T_co]): + """ + An ABC with one abstract method __abs__ that is covariant in its return type. + """ + __slots__ = () + + @abc.abstractmethod + def __abs__(self) -> T_co: + pass + + @runtime_checkable + class SupportsRound(Protocol[T_co]): + """ + An ABC with one abstract method __round__ that is covariant in its return type. + """ + __slots__ = () + + @abc.abstractmethod + def __round__(self, ndigits: int = 0) -> T_co: + pass + + +def _ensure_subclassable(mro_entries): + def inner(func): + if sys.implementation.name == "pypy" and sys.version_info < (3, 9): + cls_dict = { + "__call__": staticmethod(func), + "__mro_entries__": staticmethod(mro_entries) + } + t = type(func.__name__, (), cls_dict) + return functools.update_wrapper(t(), func) + else: + func.__mro_entries__ = mro_entries + return func + return inner + + +# Update this to something like >=3.13.0b1 if and when +# PEP 728 is implemented in CPython +_PEP_728_IMPLEMENTED = False + +if _PEP_728_IMPLEMENTED: + # The standard library TypedDict in Python 3.8 does not store runtime information + # about which (if any) keys are optional. See https://bugs.python.org/issue38834 + # The standard library TypedDict in Python 3.9.0/1 does not honour the "total" + # keyword with old-style TypedDict(). See https://bugs.python.org/issue42059 + # The standard library TypedDict below Python 3.11 does not store runtime + # information about optional and required keys when using Required or NotRequired. + # Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11. + # Aaaand on 3.12 we add __orig_bases__ to TypedDict + # to enable better runtime introspection. + # On 3.13 we deprecate some odd ways of creating TypedDicts. + # Also on 3.13, PEP 705 adds the ReadOnly[] qualifier. + # PEP 728 (still pending) makes more changes. + TypedDict = typing.TypedDict + _TypedDictMeta = typing._TypedDictMeta + is_typeddict = typing.is_typeddict +else: + # 3.10.0 and later + _TAKES_MODULE = "module" in inspect.signature(typing._type_check).parameters + + def _get_typeddict_qualifiers(annotation_type): + while True: + annotation_origin = get_origin(annotation_type) + if annotation_origin is Annotated: + annotation_args = get_args(annotation_type) + if annotation_args: + annotation_type = annotation_args[0] + else: + break + elif annotation_origin is Required: + yield Required + annotation_type, = get_args(annotation_type) + elif annotation_origin is NotRequired: + yield NotRequired + annotation_type, = get_args(annotation_type) + elif annotation_origin is ReadOnly: + yield ReadOnly + annotation_type, = get_args(annotation_type) + else: + break + + class _TypedDictMeta(type): + def __new__(cls, name, bases, ns, *, total=True, closed=False): + """Create new typed dict class object. + + This method is called when TypedDict is subclassed, + or when TypedDict is instantiated. This way + TypedDict supports all three syntax forms described in its docstring. + Subclasses and instances of TypedDict return actual dictionaries. + """ + for base in bases: + if type(base) is not _TypedDictMeta and base is not typing.Generic: + raise TypeError('cannot inherit from both a TypedDict type ' + 'and a non-TypedDict base class') + + if any(issubclass(b, typing.Generic) for b in bases): + generic_base = (typing.Generic,) + else: + generic_base = () + + # typing.py generally doesn't let you inherit from plain Generic, unless + # the name of the class happens to be "Protocol" + tp_dict = type.__new__(_TypedDictMeta, "Protocol", (*generic_base, dict), ns) + tp_dict.__name__ = name + if tp_dict.__qualname__ == "Protocol": + tp_dict.__qualname__ = name + + if not hasattr(tp_dict, '__orig_bases__'): + tp_dict.__orig_bases__ = bases + + annotations = {} + if "__annotations__" in ns: + own_annotations = ns["__annotations__"] + elif "__annotate__" in ns: + # TODO: Use inspect.VALUE here, and make the annotations lazily evaluated + own_annotations = ns["__annotate__"](1) + else: + own_annotations = {} + msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" + if _TAKES_MODULE: + own_annotations = { + n: typing._type_check(tp, msg, module=tp_dict.__module__) + for n, tp in own_annotations.items() + } + else: + own_annotations = { + n: typing._type_check(tp, msg) + for n, tp in own_annotations.items() + } + required_keys = set() + optional_keys = set() + readonly_keys = set() + mutable_keys = set() + extra_items_type = None + + for base in bases: + base_dict = base.__dict__ + + annotations.update(base_dict.get('__annotations__', {})) + required_keys.update(base_dict.get('__required_keys__', ())) + optional_keys.update(base_dict.get('__optional_keys__', ())) + readonly_keys.update(base_dict.get('__readonly_keys__', ())) + mutable_keys.update(base_dict.get('__mutable_keys__', ())) + base_extra_items_type = base_dict.get('__extra_items__', None) + if base_extra_items_type is not None: + extra_items_type = base_extra_items_type + + if closed and extra_items_type is None: + extra_items_type = Never + if closed and "__extra_items__" in own_annotations: + annotation_type = own_annotations.pop("__extra_items__") + qualifiers = set(_get_typeddict_qualifiers(annotation_type)) + if Required in qualifiers: + raise TypeError( + "Special key __extra_items__ does not support " + "Required" + ) + if NotRequired in qualifiers: + raise TypeError( + "Special key __extra_items__ does not support " + "NotRequired" + ) + extra_items_type = annotation_type + + annotations.update(own_annotations) + for annotation_key, annotation_type in own_annotations.items(): + qualifiers = set(_get_typeddict_qualifiers(annotation_type)) + + if Required in qualifiers: + required_keys.add(annotation_key) + elif NotRequired in qualifiers: + optional_keys.add(annotation_key) + elif total: + required_keys.add(annotation_key) + else: + optional_keys.add(annotation_key) + if ReadOnly in qualifiers: + mutable_keys.discard(annotation_key) + readonly_keys.add(annotation_key) + else: + mutable_keys.add(annotation_key) + readonly_keys.discard(annotation_key) + + tp_dict.__annotations__ = annotations + tp_dict.__required_keys__ = frozenset(required_keys) + tp_dict.__optional_keys__ = frozenset(optional_keys) + tp_dict.__readonly_keys__ = frozenset(readonly_keys) + tp_dict.__mutable_keys__ = frozenset(mutable_keys) + if not hasattr(tp_dict, '__total__'): + tp_dict.__total__ = total + tp_dict.__closed__ = closed + tp_dict.__extra_items__ = extra_items_type + return tp_dict + + __call__ = dict # static method + + def __subclasscheck__(cls, other): + # Typed dicts are only for static structural subtyping. + raise TypeError('TypedDict does not support instance and class checks') + + __instancecheck__ = __subclasscheck__ + + _TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {}) + + @_ensure_subclassable(lambda bases: (_TypedDict,)) + def TypedDict(typename, fields=_marker, /, *, total=True, closed=False, **kwargs): + """A simple typed namespace. At runtime it is equivalent to a plain dict. + + TypedDict creates a dictionary type such that a type checker will expect all + instances to have a certain set of keys, where each key is + associated with a value of a consistent type. This expectation + is not checked at runtime. + + Usage:: + + class Point2D(TypedDict): + x: int + y: int + label: str + + a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK + b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check + + assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') + + The type info can be accessed via the Point2D.__annotations__ dict, and + the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. + TypedDict supports an additional equivalent form:: + + Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str}) + + By default, all keys must be present in a TypedDict. It is possible + to override this by specifying totality:: + + class Point2D(TypedDict, total=False): + x: int + y: int + + This means that a Point2D TypedDict can have any of the keys omitted. A type + checker is only expected to support a literal False or True as the value of + the total argument. True is the default, and makes all items defined in the + class body be required. + + The Required and NotRequired special forms can also be used to mark + individual keys as being required or not required:: + + class Point2D(TypedDict): + x: int # the "x" key must always be present (Required is the default) + y: NotRequired[int] # the "y" key can be omitted + + See PEP 655 for more details on Required and NotRequired. + """ + if fields is _marker or fields is None: + if fields is _marker: + deprecated_thing = "Failing to pass a value for the 'fields' parameter" + else: + deprecated_thing = "Passing `None` as the 'fields' parameter" + + example = f"`{typename} = TypedDict({typename!r}, {{}})`" + deprecation_msg = ( + f"{deprecated_thing} is deprecated and will be disallowed in " + "Python 3.15. To create a TypedDict class with 0 fields " + "using the functional syntax, pass an empty dictionary, e.g. " + ) + example + "." + warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2) + if closed is not False and closed is not True: + kwargs["closed"] = closed + closed = False + fields = kwargs + elif kwargs: + raise TypeError("TypedDict takes either a dict or keyword arguments," + " but not both") + if kwargs: + if sys.version_info >= (3, 13): + raise TypeError("TypedDict takes no keyword arguments") + warnings.warn( + "The kwargs-based syntax for TypedDict definitions is deprecated " + "in Python 3.11, will be removed in Python 3.13, and may not be " + "understood by third-party type checkers.", + DeprecationWarning, + stacklevel=2, + ) + + ns = {'__annotations__': dict(fields)} + module = _caller() + if module is not None: + # Setting correct module is necessary to make typed dict classes pickleable. + ns['__module__'] = module + + td = _TypedDictMeta(typename, (), ns, total=total, closed=closed) + td.__orig_bases__ = (TypedDict,) + return td + + if hasattr(typing, "_TypedDictMeta"): + _TYPEDDICT_TYPES = (typing._TypedDictMeta, _TypedDictMeta) + else: + _TYPEDDICT_TYPES = (_TypedDictMeta,) + + def is_typeddict(tp): + """Check if an annotation is a TypedDict class + + For example:: + class Film(TypedDict): + title: str + year: int + + is_typeddict(Film) # => True + is_typeddict(Union[list, str]) # => False + """ + # On 3.8, this would otherwise return True + if hasattr(typing, "TypedDict") and tp is typing.TypedDict: + return False + return isinstance(tp, _TYPEDDICT_TYPES) + + +if hasattr(typing, "assert_type"): + assert_type = typing.assert_type + +else: + def assert_type(val, typ, /): + """Assert (to the type checker) that the value is of the given type. + + When the type checker encounters a call to assert_type(), it + emits an error if the value is not of the specified type:: + + def greet(name: str) -> None: + assert_type(name, str) # ok + assert_type(name, int) # type checker error + + At runtime this returns the first argument unchanged and otherwise + does nothing. + """ + return val + + +if hasattr(typing, "ReadOnly"): # 3.13+ + get_type_hints = typing.get_type_hints +else: # <=3.13 + # replaces _strip_annotations() + def _strip_extras(t): + """Strips Annotated, Required and NotRequired from a given type.""" + if isinstance(t, _AnnotatedAlias): + return _strip_extras(t.__origin__) + if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired, ReadOnly): + return _strip_extras(t.__args__[0]) + if isinstance(t, typing._GenericAlias): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return t.copy_with(stripped_args) + if hasattr(_types, "GenericAlias") and isinstance(t, _types.GenericAlias): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return _types.GenericAlias(t.__origin__, stripped_args) + if hasattr(_types, "UnionType") and isinstance(t, _types.UnionType): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return functools.reduce(operator.or_, stripped_args) + + return t + + def get_type_hints(obj, globalns=None, localns=None, include_extras=False): + """Return type hints for an object. + + This is often the same as obj.__annotations__, but it handles + forward references encoded as string literals, adds Optional[t] if a + default value equal to None is set and recursively replaces all + 'Annotated[T, ...]', 'Required[T]' or 'NotRequired[T]' with 'T' + (unless 'include_extras=True'). + + The argument may be a module, class, method, or function. The annotations + are returned as a dictionary. For classes, annotations include also + inherited members. + + TypeError is raised if the argument is not of a type that can contain + annotations, and an empty dictionary is returned if no annotations are + present. + + BEWARE -- the behavior of globalns and localns is counterintuitive + (unless you are familiar with how eval() and exec() work). The + search order is locals first, then globals. + + - If no dict arguments are passed, an attempt is made to use the + globals from obj (or the respective module's globals for classes), + and these are also used as the locals. If the object does not appear + to have globals, an empty dictionary is used. + + - If one dict argument is passed, it is used for both globals and + locals. + + - If two dict arguments are passed, they specify globals and + locals, respectively. + """ + if hasattr(typing, "Annotated"): # 3.9+ + hint = typing.get_type_hints( + obj, globalns=globalns, localns=localns, include_extras=True + ) + else: # 3.8 + hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) + if include_extras: + return hint + return {k: _strip_extras(t) for k, t in hint.items()} + + +# Python 3.9+ has PEP 593 (Annotated) +if hasattr(typing, 'Annotated'): + Annotated = typing.Annotated + # Not exported and not a public API, but needed for get_origin() and get_args() + # to work. + _AnnotatedAlias = typing._AnnotatedAlias +# 3.8 +else: + class _AnnotatedAlias(typing._GenericAlias, _root=True): + """Runtime representation of an annotated type. + + At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't' + with extra annotations. The alias behaves like a normal typing alias, + instantiating is the same as instantiating the underlying type, binding + it to types is also the same. + """ + def __init__(self, origin, metadata): + if isinstance(origin, _AnnotatedAlias): + metadata = origin.__metadata__ + metadata + origin = origin.__origin__ + super().__init__(origin, origin) + self.__metadata__ = metadata + + def copy_with(self, params): + assert len(params) == 1 + new_type = params[0] + return _AnnotatedAlias(new_type, self.__metadata__) + + def __repr__(self): + return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " + f"{', '.join(repr(a) for a in self.__metadata__)}]") + + def __reduce__(self): + return operator.getitem, ( + Annotated, (self.__origin__, *self.__metadata__) + ) + + def __eq__(self, other): + if not isinstance(other, _AnnotatedAlias): + return NotImplemented + if self.__origin__ != other.__origin__: + return False + return self.__metadata__ == other.__metadata__ + + def __hash__(self): + return hash((self.__origin__, self.__metadata__)) + + class Annotated: + """Add context specific metadata to a type. + + Example: Annotated[int, runtime_check.Unsigned] indicates to the + hypothetical runtime_check module that this type is an unsigned int. + Every other consumer of this type can ignore this metadata and treat + this type as int. + + The first argument to Annotated must be a valid type (and will be in + the __origin__ field), the remaining arguments are kept as a tuple in + the __extra__ field. + + Details: + + - It's an error to call `Annotated` with less than two arguments. + - Nested Annotated are flattened:: + + Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] + + - Instantiating an annotated type is equivalent to instantiating the + underlying type:: + + Annotated[C, Ann1](5) == C(5) + + - Annotated can be used as a generic type alias:: + + Optimized = Annotated[T, runtime.Optimize()] + Optimized[int] == Annotated[int, runtime.Optimize()] + + OptimizedList = Annotated[List[T], runtime.Optimize()] + OptimizedList[int] == Annotated[List[int], runtime.Optimize()] + """ + + __slots__ = () + + def __new__(cls, *args, **kwargs): + raise TypeError("Type Annotated cannot be instantiated.") + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple) or len(params) < 2: + raise TypeError("Annotated[...] should be used " + "with at least two arguments (a type and an " + "annotation).") + allowed_special_forms = (ClassVar, Final) + if get_origin(params[0]) in allowed_special_forms: + origin = params[0] + else: + msg = "Annotated[t, ...]: t must be a type." + origin = typing._type_check(params[0], msg) + metadata = tuple(params[1:]) + return _AnnotatedAlias(origin, metadata) + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError( + f"Cannot subclass {cls.__module__}.Annotated" + ) + +# Python 3.8 has get_origin() and get_args() but those implementations aren't +# Annotated-aware, so we can't use those. Python 3.9's versions don't support +# ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do. +if sys.version_info[:2] >= (3, 10): + get_origin = typing.get_origin + get_args = typing.get_args +# 3.8-3.9 +else: + try: + # 3.9+ + from typing import _BaseGenericAlias + except ImportError: + _BaseGenericAlias = typing._GenericAlias + try: + # 3.9+ + from typing import GenericAlias as _typing_GenericAlias + except ImportError: + _typing_GenericAlias = typing._GenericAlias + + def get_origin(tp): + """Get the unsubscripted version of a type. + + This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar + and Annotated. Return None for unsupported types. Examples:: + + get_origin(Literal[42]) is Literal + get_origin(int) is None + get_origin(ClassVar[int]) is ClassVar + get_origin(Generic) is Generic + get_origin(Generic[T]) is Generic + get_origin(Union[T, int]) is Union + get_origin(List[Tuple[T, T]][int]) == list + get_origin(P.args) is P + """ + if isinstance(tp, _AnnotatedAlias): + return Annotated + if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias, _BaseGenericAlias, + ParamSpecArgs, ParamSpecKwargs)): + return tp.__origin__ + if tp is typing.Generic: + return typing.Generic + return None + + def get_args(tp): + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. + Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + if isinstance(tp, _AnnotatedAlias): + return (tp.__origin__, *tp.__metadata__) + if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias)): + if getattr(tp, "_special", False): + return () + res = tp.__args__ + if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + return () + + +# 3.10+ +if hasattr(typing, 'TypeAlias'): + TypeAlias = typing.TypeAlias +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def TypeAlias(self, parameters): + """Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. + + For example:: + + Predicate: TypeAlias = Callable[..., bool] + + It's invalid when used anywhere except as in the example above. + """ + raise TypeError(f"{self} is not subscriptable") +# 3.8 +else: + TypeAlias = _ExtensionsSpecialForm( + 'TypeAlias', + doc="""Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. + + For example:: + + Predicate: TypeAlias = Callable[..., bool] + + It's invalid when used anywhere except as in the example + above.""" + ) + + +if hasattr(typing, "NoDefault"): + NoDefault = typing.NoDefault +else: + class NoDefaultTypeMeta(type): + def __setattr__(cls, attr, value): + # TypeError is consistent with the behavior of NoneType + raise TypeError( + f"cannot set {attr!r} attribute of immutable type {cls.__name__!r}" + ) + + class NoDefaultType(metaclass=NoDefaultTypeMeta): + """The type of the NoDefault singleton.""" + + __slots__ = () + + def __new__(cls): + return globals().get("NoDefault") or object.__new__(cls) + + def __repr__(self): + return "typing_extensions.NoDefault" + + def __reduce__(self): + return "NoDefault" + + NoDefault = NoDefaultType() + del NoDefaultType, NoDefaultTypeMeta + + +def _set_default(type_param, default): + type_param.has_default = lambda: default is not NoDefault + type_param.__default__ = default + + +def _set_module(typevarlike): + # for pickling: + def_mod = _caller(depth=3) + if def_mod != 'typing_extensions': + typevarlike.__module__ = def_mod + + +class _DefaultMixin: + """Mixin for TypeVarLike defaults.""" + + __slots__ = () + __init__ = _set_default + + +# Classes using this metaclass must provide a _backported_typevarlike ClassVar +class _TypeVarLikeMeta(type): + def __instancecheck__(cls, __instance: Any) -> bool: + return isinstance(__instance, cls._backported_typevarlike) + + +if _PEP_696_IMPLEMENTED: + from typing import TypeVar +else: + # Add default and infer_variance parameters from PEP 696 and 695 + class TypeVar(metaclass=_TypeVarLikeMeta): + """Type variable.""" + + _backported_typevarlike = typing.TypeVar + + def __new__(cls, name, *constraints, bound=None, + covariant=False, contravariant=False, + default=NoDefault, infer_variance=False): + if hasattr(typing, "TypeAliasType"): + # PEP 695 implemented (3.12+), can pass infer_variance to typing.TypeVar + typevar = typing.TypeVar(name, *constraints, bound=bound, + covariant=covariant, contravariant=contravariant, + infer_variance=infer_variance) + else: + typevar = typing.TypeVar(name, *constraints, bound=bound, + covariant=covariant, contravariant=contravariant) + if infer_variance and (covariant or contravariant): + raise ValueError("Variance cannot be specified with infer_variance.") + typevar.__infer_variance__ = infer_variance + + _set_default(typevar, default) + _set_module(typevar) + + def _tvar_prepare_subst(alias, args): + if ( + typevar.has_default() + and alias.__parameters__.index(typevar) == len(args) + ): + args += (typevar.__default__,) + return args + + typevar.__typing_prepare_subst__ = _tvar_prepare_subst + return typevar + + def __init_subclass__(cls) -> None: + raise TypeError(f"type '{__name__}.TypeVar' is not an acceptable base type") + + +# Python 3.10+ has PEP 612 +if hasattr(typing, 'ParamSpecArgs'): + ParamSpecArgs = typing.ParamSpecArgs + ParamSpecKwargs = typing.ParamSpecKwargs +# 3.8-3.9 +else: + class _Immutable: + """Mixin to indicate that object should not be copied.""" + __slots__ = () + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + class ParamSpecArgs(_Immutable): + """The args for a ParamSpec object. + + Given a ParamSpec object P, P.args is an instance of ParamSpecArgs. + + ParamSpecArgs objects have a reference back to their ParamSpec: + + P.args.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. + """ + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.args" + + def __eq__(self, other): + if not isinstance(other, ParamSpecArgs): + return NotImplemented + return self.__origin__ == other.__origin__ + + class ParamSpecKwargs(_Immutable): + """The kwargs for a ParamSpec object. + + Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs. + + ParamSpecKwargs objects have a reference back to their ParamSpec: + + P.kwargs.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. + """ + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.kwargs" + + def __eq__(self, other): + if not isinstance(other, ParamSpecKwargs): + return NotImplemented + return self.__origin__ == other.__origin__ + + +if _PEP_696_IMPLEMENTED: + from typing import ParamSpec + +# 3.10+ +elif hasattr(typing, 'ParamSpec'): + + # Add default parameter - PEP 696 + class ParamSpec(metaclass=_TypeVarLikeMeta): + """Parameter specification.""" + + _backported_typevarlike = typing.ParamSpec + + def __new__(cls, name, *, bound=None, + covariant=False, contravariant=False, + infer_variance=False, default=NoDefault): + if hasattr(typing, "TypeAliasType"): + # PEP 695 implemented, can pass infer_variance to typing.TypeVar + paramspec = typing.ParamSpec(name, bound=bound, + covariant=covariant, + contravariant=contravariant, + infer_variance=infer_variance) + else: + paramspec = typing.ParamSpec(name, bound=bound, + covariant=covariant, + contravariant=contravariant) + paramspec.__infer_variance__ = infer_variance + + _set_default(paramspec, default) + _set_module(paramspec) + + def _paramspec_prepare_subst(alias, args): + params = alias.__parameters__ + i = params.index(paramspec) + if i == len(args) and paramspec.has_default(): + args = [*args, paramspec.__default__] + if i >= len(args): + raise TypeError(f"Too few arguments for {alias}") + # Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612. + if len(params) == 1 and not typing._is_param_expr(args[0]): + assert i == 0 + args = (args,) + # Convert lists to tuples to help other libraries cache the results. + elif isinstance(args[i], list): + args = (*args[:i], tuple(args[i]), *args[i + 1:]) + return args + + paramspec.__typing_prepare_subst__ = _paramspec_prepare_subst + return paramspec + + def __init_subclass__(cls) -> None: + raise TypeError(f"type '{__name__}.ParamSpec' is not an acceptable base type") + +# 3.8-3.9 +else: + + # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + class ParamSpec(list, _DefaultMixin): + """Parameter specification variable. + + Usage:: + + P = ParamSpec('P') + + Parameter specification variables exist primarily for the benefit of static + type checkers. They are used to forward the parameter types of one + callable to another callable, a pattern commonly found in higher order + functions and decorators. They are only valid when used in ``Concatenate``, + or s the first argument to ``Callable``. In Python 3.10 and higher, + they are also supported in user-defined Generics at runtime. + See class Generic for more information on generic types. An + example for annotating a decorator:: + + T = TypeVar('T') + P = ParamSpec('P') + + def add_logging(f: Callable[P, T]) -> Callable[P, T]: + '''A type-safe decorator to add logging to a function.''' + def inner(*args: P.args, **kwargs: P.kwargs) -> T: + logging.info(f'{f.__name__} was called') + return f(*args, **kwargs) + return inner + + @add_logging + def add_two(x: float, y: float) -> float: + '''Add two numbers together.''' + return x + y + + Parameter specification variables defined with covariant=True or + contravariant=True can be used to declare covariant or contravariant + generic types. These keyword arguments are valid, but their actual semantics + are yet to be decided. See PEP 612 for details. + + Parameter specification variables can be introspected. e.g.: + + P.__name__ == 'T' + P.__bound__ == None + P.__covariant__ == False + P.__contravariant__ == False + + Note that only parameter specification variables defined in global scope can + be pickled. + """ + + # Trick Generic __parameters__. + __class__ = typing.TypeVar + + @property + def args(self): + return ParamSpecArgs(self) + + @property + def kwargs(self): + return ParamSpecKwargs(self) + + def __init__(self, name, *, bound=None, covariant=False, contravariant=False, + infer_variance=False, default=NoDefault): + list.__init__(self, [self]) + self.__name__ = name + self.__covariant__ = bool(covariant) + self.__contravariant__ = bool(contravariant) + self.__infer_variance__ = bool(infer_variance) + if bound: + self.__bound__ = typing._type_check(bound, 'Bound must be a type.') + else: + self.__bound__ = None + _DefaultMixin.__init__(self, default) + + # for pickling: + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + def __repr__(self): + if self.__infer_variance__: + prefix = '' + elif self.__covariant__: + prefix = '+' + elif self.__contravariant__: + prefix = '-' + else: + prefix = '~' + return prefix + self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + # Hack to get typing._type_check to pass. + def __call__(self, *args, **kwargs): + pass + + +# 3.8-3.9 +if not hasattr(typing, 'Concatenate'): + # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + class _ConcatenateGenericAlias(list): + + # Trick Generic into looking into this for __parameters__. + __class__ = typing._GenericAlias + + # Flag in 3.8. + _special = False + + def __init__(self, origin, args): + super().__init__(args) + self.__origin__ = origin + self.__args__ = args + + def __repr__(self): + _type_repr = typing._type_repr + return (f'{_type_repr(self.__origin__)}' + f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]') + + def __hash__(self): + return hash((self.__origin__, self.__args__)) + + # Hack to get typing._type_check to pass in Generic. + def __call__(self, *args, **kwargs): + pass + + @property + def __parameters__(self): + return tuple( + tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) + ) + + +# 3.8-3.9 +@typing._tp_cache +def _concatenate_getitem(self, parameters): + if parameters == (): + raise TypeError("Cannot take a Concatenate of no types.") + if not isinstance(parameters, tuple): + parameters = (parameters,) + if not isinstance(parameters[-1], ParamSpec): + raise TypeError("The last parameter to Concatenate should be a " + "ParamSpec variable.") + msg = "Concatenate[arg, ...]: each arg must be a type." + parameters = tuple(typing._type_check(p, msg) for p in parameters) + return _ConcatenateGenericAlias(self, parameters) + + +# 3.10+ +if hasattr(typing, 'Concatenate'): + Concatenate = typing.Concatenate + _ConcatenateGenericAlias = typing._ConcatenateGenericAlias +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def Concatenate(self, parameters): + """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. + """ + return _concatenate_getitem(self, parameters) +# 3.8 +else: + class _ConcatenateForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + return _concatenate_getitem(self, parameters) + + Concatenate = _ConcatenateForm( + 'Concatenate', + doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. + """) + +# 3.10+ +if hasattr(typing, 'TypeGuard'): + TypeGuard = typing.TypeGuard +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def TypeGuard(self, parameters): + """Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeGuard[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeGuard`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the type inside ``TypeGuard``. + + For example:: + + def is_str(val: Union[str, float]): + # "isinstance" type guard + if isinstance(val, str): + # Type of ``val`` is narrowed to ``str`` + ... + else: + # Else, type of ``val`` is narrowed to ``float``. + ... + + Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower + form of ``TypeA`` (it can even be a wider form) and this may lead to + type-unsafe results. The main reason is to allow for things like + narrowing ``List[object]`` to ``List[str]`` even though the latter is not + a subtype of the former, since ``List`` is invariant. The responsibility of + writing type-safe type guards is left to the user. + + ``TypeGuard`` also works with type variables. For more information, see + PEP 647 (User-Defined Type Guards). + """ + item = typing._type_check(parameters, f'{self} accepts only a single type.') + return typing._GenericAlias(self, (item,)) +# 3.8 +else: + class _TypeGuardForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type') + return typing._GenericAlias(self, (item,)) + + TypeGuard = _TypeGuardForm( + 'TypeGuard', + doc="""Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeGuard[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeGuard`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the type inside ``TypeGuard``. + + For example:: + + def is_str(val: Union[str, float]): + # "isinstance" type guard + if isinstance(val, str): + # Type of ``val`` is narrowed to ``str`` + ... + else: + # Else, type of ``val`` is narrowed to ``float``. + ... + + Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower + form of ``TypeA`` (it can even be a wider form) and this may lead to + type-unsafe results. The main reason is to allow for things like + narrowing ``List[object]`` to ``List[str]`` even though the latter is not + a subtype of the former, since ``List`` is invariant. The responsibility of + writing type-safe type guards is left to the user. + + ``TypeGuard`` also works with type variables. For more information, see + PEP 647 (User-Defined Type Guards). + """) + +# 3.13+ +if hasattr(typing, 'TypeIs'): + TypeIs = typing.TypeIs +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def TypeIs(self, parameters): + """Special typing form used to annotate the return type of a user-defined + type narrower function. ``TypeIs`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeIs`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeIs[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeIs`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the intersection of the type inside ``TypeGuard`` and the argument's + previously known type. + + For example:: + + def is_awaitable(val: object) -> TypeIs[Awaitable[Any]]: + return hasattr(val, '__await__') + + def f(val: Union[int, Awaitable[int]]) -> int: + if is_awaitable(val): + assert_type(val, Awaitable[int]) + else: + assert_type(val, int) + + ``TypeIs`` also works with type variables. For more information, see + PEP 742 (Narrowing types with TypeIs). + """ + item = typing._type_check(parameters, f'{self} accepts only a single type.') + return typing._GenericAlias(self, (item,)) +# 3.8 +else: + class _TypeIsForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type') + return typing._GenericAlias(self, (item,)) + + TypeIs = _TypeIsForm( + 'TypeIs', + doc="""Special typing form used to annotate the return type of a user-defined + type narrower function. ``TypeIs`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeIs`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeIs[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeIs`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the intersection of the type inside ``TypeGuard`` and the argument's + previously known type. + + For example:: + + def is_awaitable(val: object) -> TypeIs[Awaitable[Any]]: + return hasattr(val, '__await__') + + def f(val: Union[int, Awaitable[int]]) -> int: + if is_awaitable(val): + assert_type(val, Awaitable[int]) + else: + assert_type(val, int) + + ``TypeIs`` also works with type variables. For more information, see + PEP 742 (Narrowing types with TypeIs). + """) + + +# Vendored from cpython typing._SpecialFrom +class _SpecialForm(typing._Final, _root=True): + __slots__ = ('_name', '__doc__', '_getitem') + + def __init__(self, getitem): + self._getitem = getitem + self._name = getitem.__name__ + self.__doc__ = getitem.__doc__ + + def __getattr__(self, item): + if item in {'__name__', '__qualname__'}: + return self._name + + raise AttributeError(item) + + def __mro_entries__(self, bases): + raise TypeError(f"Cannot subclass {self!r}") + + def __repr__(self): + return f'typing_extensions.{self._name}' + + def __reduce__(self): + return self._name + + def __call__(self, *args, **kwds): + raise TypeError(f"Cannot instantiate {self!r}") + + def __or__(self, other): + return typing.Union[self, other] + + def __ror__(self, other): + return typing.Union[other, self] + + def __instancecheck__(self, obj): + raise TypeError(f"{self} cannot be used with isinstance()") + + def __subclasscheck__(self, cls): + raise TypeError(f"{self} cannot be used with issubclass()") + + @typing._tp_cache + def __getitem__(self, parameters): + return self._getitem(self, parameters) + + +if hasattr(typing, "LiteralString"): # 3.11+ + LiteralString = typing.LiteralString +else: + @_SpecialForm + def LiteralString(self, params): + """Represents an arbitrary literal string. + + Example:: + + from pip._vendor.typing_extensions import LiteralString + + def query(sql: LiteralString) -> ...: + ... + + query("SELECT * FROM table") # ok + query(f"SELECT * FROM {input()}") # not ok + + See PEP 675 for details. + + """ + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, "Self"): # 3.11+ + Self = typing.Self +else: + @_SpecialForm + def Self(self, params): + """Used to spell the type of "self" in classes. + + Example:: + + from typing import Self + + class ReturnsSelf: + def parse(self, data: bytes) -> Self: + ... + return self + + """ + + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, "Never"): # 3.11+ + Never = typing.Never +else: + @_SpecialForm + def Never(self, params): + """The bottom type, a type that has no members. + + This can be used to define a function that should never be + called, or a function that never returns:: + + from pip._vendor.typing_extensions import Never + + def never_call_me(arg: Never) -> None: + pass + + def int_or_str(arg: int | str) -> None: + never_call_me(arg) # type checker error + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + never_call_me(arg) # ok, arg is of type Never + + """ + + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, 'Required'): # 3.11+ + Required = typing.Required + NotRequired = typing.NotRequired +elif sys.version_info[:2] >= (3, 9): # 3.9-3.10 + @_ExtensionsSpecialForm + def Required(self, parameters): + """A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + @_ExtensionsSpecialForm + def NotRequired(self, parameters): + """A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + +else: # 3.8 + class _RequiredForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + Required = _RequiredForm( + 'Required', + doc="""A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. + """) + NotRequired = _RequiredForm( + 'NotRequired', + doc="""A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """) + + +if hasattr(typing, 'ReadOnly'): + ReadOnly = typing.ReadOnly +elif sys.version_info[:2] >= (3, 9): # 3.9-3.12 + @_ExtensionsSpecialForm + def ReadOnly(self, parameters): + """A special typing construct to mark an item of a TypedDict as read-only. + + For example: + + class Movie(TypedDict): + title: ReadOnly[str] + year: int + + def mutate_movie(m: Movie) -> None: + m["year"] = 1992 # allowed + m["title"] = "The Matrix" # typechecker error + + There is no runtime checking for this property. + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + +else: # 3.8 + class _ReadOnlyForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + ReadOnly = _ReadOnlyForm( + 'ReadOnly', + doc="""A special typing construct to mark a key of a TypedDict as read-only. + + For example: + + class Movie(TypedDict): + title: ReadOnly[str] + year: int + + def mutate_movie(m: Movie) -> None: + m["year"] = 1992 # allowed + m["title"] = "The Matrix" # typechecker error + + There is no runtime checking for this propery. + """) + + +_UNPACK_DOC = """\ +Type unpack operator. + +The type unpack operator takes the child types from some container type, +such as `tuple[int, str]` or a `TypeVarTuple`, and 'pulls them out'. For +example: + + # For some generic class `Foo`: + Foo[Unpack[tuple[int, str]]] # Equivalent to Foo[int, str] + + Ts = TypeVarTuple('Ts') + # Specifies that `Bar` is generic in an arbitrary number of types. + # (Think of `Ts` as a tuple of an arbitrary number of individual + # `TypeVar`s, which the `Unpack` is 'pulling out' directly into the + # `Generic[]`.) + class Bar(Generic[Unpack[Ts]]): ... + Bar[int] # Valid + Bar[int, str] # Also valid + +From Python 3.11, this can also be done using the `*` operator: + + Foo[*tuple[int, str]] + class Bar(Generic[*Ts]): ... + +The operator can also be used along with a `TypedDict` to annotate +`**kwargs` in a function signature. For instance: + + class Movie(TypedDict): + name: str + year: int + + # This function expects two keyword arguments - *name* of type `str` and + # *year* of type `int`. + def foo(**kwargs: Unpack[Movie]): ... + +Note that there is only some runtime checking of this operator. Not +everything the runtime allows may be accepted by static type checkers. + +For more information, see PEP 646 and PEP 692. +""" + + +if sys.version_info >= (3, 12): # PEP 692 changed the repr of Unpack[] + Unpack = typing.Unpack + + def _is_unpack(obj): + return get_origin(obj) is Unpack + +elif sys.version_info[:2] >= (3, 9): # 3.9+ + class _UnpackSpecialForm(_ExtensionsSpecialForm, _root=True): + def __init__(self, getitem): + super().__init__(getitem) + self.__doc__ = _UNPACK_DOC + + class _UnpackAlias(typing._GenericAlias, _root=True): + __class__ = typing.TypeVar + + @property + def __typing_unpacked_tuple_args__(self): + assert self.__origin__ is Unpack + assert len(self.__args__) == 1 + arg, = self.__args__ + if isinstance(arg, (typing._GenericAlias, _types.GenericAlias)): + if arg.__origin__ is not tuple: + raise TypeError("Unpack[...] must be used with a tuple type") + return arg.__args__ + return None + + @_UnpackSpecialForm + def Unpack(self, parameters): + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return _UnpackAlias(self, (item,)) + + def _is_unpack(obj): + return isinstance(obj, _UnpackAlias) + +else: # 3.8 + class _UnpackAlias(typing._GenericAlias, _root=True): + __class__ = typing.TypeVar + + class _UnpackForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return _UnpackAlias(self, (item,)) + + Unpack = _UnpackForm('Unpack', doc=_UNPACK_DOC) + + def _is_unpack(obj): + return isinstance(obj, _UnpackAlias) + + +if _PEP_696_IMPLEMENTED: + from typing import TypeVarTuple + +elif hasattr(typing, "TypeVarTuple"): # 3.11+ + + def _unpack_args(*args): + newargs = [] + for arg in args: + subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) + if subargs is not None and not (subargs and subargs[-1] is ...): + newargs.extend(subargs) + else: + newargs.append(arg) + return newargs + + # Add default parameter - PEP 696 + class TypeVarTuple(metaclass=_TypeVarLikeMeta): + """Type variable tuple.""" + + _backported_typevarlike = typing.TypeVarTuple + + def __new__(cls, name, *, default=NoDefault): + tvt = typing.TypeVarTuple(name) + _set_default(tvt, default) + _set_module(tvt) + + def _typevartuple_prepare_subst(alias, args): + params = alias.__parameters__ + typevartuple_index = params.index(tvt) + for param in params[typevartuple_index + 1:]: + if isinstance(param, TypeVarTuple): + raise TypeError( + f"More than one TypeVarTuple parameter in {alias}" + ) + + alen = len(args) + plen = len(params) + left = typevartuple_index + right = plen - typevartuple_index - 1 + var_tuple_index = None + fillarg = None + for k, arg in enumerate(args): + if not isinstance(arg, type): + subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) + if subargs and len(subargs) == 2 and subargs[-1] is ...: + if var_tuple_index is not None: + raise TypeError( + "More than one unpacked " + "arbitrary-length tuple argument" + ) + var_tuple_index = k + fillarg = subargs[0] + if var_tuple_index is not None: + left = min(left, var_tuple_index) + right = min(right, alen - var_tuple_index - 1) + elif left + right > alen: + raise TypeError(f"Too few arguments for {alias};" + f" actual {alen}, expected at least {plen - 1}") + if left == alen - right and tvt.has_default(): + replacement = _unpack_args(tvt.__default__) + else: + replacement = args[left: alen - right] + + return ( + *args[:left], + *([fillarg] * (typevartuple_index - left)), + replacement, + *([fillarg] * (plen - right - left - typevartuple_index - 1)), + *args[alen - right:], + ) + + tvt.__typing_prepare_subst__ = _typevartuple_prepare_subst + return tvt + + def __init_subclass__(self, *args, **kwds): + raise TypeError("Cannot subclass special typing classes") + +else: # <=3.10 + class TypeVarTuple(_DefaultMixin): + """Type variable tuple. + + Usage:: + + Ts = TypeVarTuple('Ts') + + In the same way that a normal type variable is a stand-in for a single + type such as ``int``, a type variable *tuple* is a stand-in for a *tuple* + type such as ``Tuple[int, str]``. + + Type variable tuples can be used in ``Generic`` declarations. + Consider the following example:: + + class Array(Generic[*Ts]): ... + + The ``Ts`` type variable tuple here behaves like ``tuple[T1, T2]``, + where ``T1`` and ``T2`` are type variables. To use these type variables + as type parameters of ``Array``, we must *unpack* the type variable tuple using + the star operator: ``*Ts``. The signature of ``Array`` then behaves + as if we had simply written ``class Array(Generic[T1, T2]): ...``. + In contrast to ``Generic[T1, T2]``, however, ``Generic[*Shape]`` allows + us to parameterise the class with an *arbitrary* number of type parameters. + + Type variable tuples can be used anywhere a normal ``TypeVar`` can. + This includes class definitions, as shown above, as well as function + signatures and variable annotations:: + + class Array(Generic[*Ts]): + + def __init__(self, shape: Tuple[*Ts]): + self._shape: Tuple[*Ts] = shape + + def get_shape(self) -> Tuple[*Ts]: + return self._shape + + shape = (Height(480), Width(640)) + x: Array[Height, Width] = Array(shape) + y = abs(x) # Inferred type is Array[Height, Width] + z = x + x # ... is Array[Height, Width] + x.get_shape() # ... is tuple[Height, Width] + + """ + + # Trick Generic __parameters__. + __class__ = typing.TypeVar + + def __iter__(self): + yield self.__unpacked__ + + def __init__(self, name, *, default=NoDefault): + self.__name__ = name + _DefaultMixin.__init__(self, default) + + # for pickling: + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + self.__unpacked__ = Unpack[self] + + def __repr__(self): + return self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + def __init_subclass__(self, *args, **kwds): + if '_root' not in kwds: + raise TypeError("Cannot subclass special typing classes") + + +if hasattr(typing, "reveal_type"): # 3.11+ + reveal_type = typing.reveal_type +else: # <=3.10 + def reveal_type(obj: T, /) -> T: + """Reveal the inferred type of a variable. + + When a static type checker encounters a call to ``reveal_type()``, + it will emit the inferred type of the argument:: + + x: int = 1 + reveal_type(x) + + Running a static type checker (e.g., ``mypy``) on this example + will produce output similar to 'Revealed type is "builtins.int"'. + + At runtime, the function prints the runtime type of the + argument and returns it unchanged. + + """ + print(f"Runtime type is {type(obj).__name__!r}", file=sys.stderr) + return obj + + +if hasattr(typing, "_ASSERT_NEVER_REPR_MAX_LENGTH"): # 3.11+ + _ASSERT_NEVER_REPR_MAX_LENGTH = typing._ASSERT_NEVER_REPR_MAX_LENGTH +else: # <=3.10 + _ASSERT_NEVER_REPR_MAX_LENGTH = 100 + + +if hasattr(typing, "assert_never"): # 3.11+ + assert_never = typing.assert_never +else: # <=3.10 + def assert_never(arg: Never, /) -> Never: + """Assert to the type checker that a line of code is unreachable. + + Example:: + + def int_or_str(arg: int | str) -> None: + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + assert_never(arg) + + If a type checker finds that a call to assert_never() is + reachable, it will emit an error. + + At runtime, this throws an exception when called. + + """ + value = repr(arg) + if len(value) > _ASSERT_NEVER_REPR_MAX_LENGTH: + value = value[:_ASSERT_NEVER_REPR_MAX_LENGTH] + '...' + raise AssertionError(f"Expected code to be unreachable, but got: {value}") + + +if sys.version_info >= (3, 12): # 3.12+ + # dataclass_transform exists in 3.11 but lacks the frozen_default parameter + dataclass_transform = typing.dataclass_transform +else: # <=3.11 + def dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + frozen_default: bool = False, + field_specifiers: typing.Tuple[ + typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], + ... + ] = (), + **kwargs: typing.Any, + ) -> typing.Callable[[T], T]: + """Decorator that marks a function, class, or metaclass as providing + dataclass-like behavior. + + Example: + + from pip._vendor.typing_extensions import dataclass_transform + + _T = TypeVar("_T") + + # Used on a decorator function + @dataclass_transform() + def create_model(cls: type[_T]) -> type[_T]: + ... + return cls + + @create_model + class CustomerModel: + id: int + name: str + + # Used on a base class + @dataclass_transform() + class ModelBase: ... + + class CustomerModel(ModelBase): + id: int + name: str + + # Used on a metaclass + @dataclass_transform() + class ModelMeta(type): ... + + class ModelBase(metaclass=ModelMeta): ... + + class CustomerModel(ModelBase): + id: int + name: str + + Each of the ``CustomerModel`` classes defined in this example will now + behave similarly to a dataclass created with the ``@dataclasses.dataclass`` + decorator. For example, the type checker will synthesize an ``__init__`` + method. + + The arguments to this decorator can be used to customize this behavior: + - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be + True or False if it is omitted by the caller. + - ``order_default`` indicates whether the ``order`` parameter is + assumed to be True or False if it is omitted by the caller. + - ``kw_only_default`` indicates whether the ``kw_only`` parameter is + assumed to be True or False if it is omitted by the caller. + - ``frozen_default`` indicates whether the ``frozen`` parameter is + assumed to be True or False if it is omitted by the caller. + - ``field_specifiers`` specifies a static list of supported classes + or functions that describe fields, similar to ``dataclasses.field()``. + + At runtime, this decorator records its arguments in the + ``__dataclass_transform__`` attribute on the decorated object. + + See PEP 681 for details. + + """ + def decorator(cls_or_fn): + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + return decorator + + +if hasattr(typing, "override"): # 3.12+ + override = typing.override +else: # <=3.11 + _F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) + + def override(arg: _F, /) -> _F: + """Indicate that a method is intended to override a method in a base class. + + Usage: + + class Base: + def method(self) -> None: + pass + + class Child(Base): + @override + def method(self) -> None: + super().method() + + When this decorator is applied to a method, the type checker will + validate that it overrides a method with the same name on a base class. + This helps prevent bugs that may occur when a base class is changed + without an equivalent change to a child class. + + There is no runtime checking of these properties. The decorator + sets the ``__override__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + + See PEP 698 for details. + + """ + try: + arg.__override__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return arg + + +if hasattr(warnings, "deprecated"): + deprecated = warnings.deprecated +else: + _T = typing.TypeVar("_T") + + class deprecated: + """Indicate that a class, function or overload is deprecated. + + When this decorator is applied to an object, the type checker + will generate a diagnostic on usage of the deprecated object. + + Usage: + + @deprecated("Use B instead") + class A: + pass + + @deprecated("Use g instead") + def f(): + pass + + @overload + @deprecated("int support is deprecated") + def g(x: int) -> int: ... + @overload + def g(x: str) -> int: ... + + The warning specified by *category* will be emitted at runtime + on use of deprecated objects. For functions, that happens on calls; + for classes, on instantiation and on creation of subclasses. + If the *category* is ``None``, no warning is emitted at runtime. + The *stacklevel* determines where the + warning is emitted. If it is ``1`` (the default), the warning + is emitted at the direct caller of the deprecated object; if it + is higher, it is emitted further up the stack. + Static type checker behavior is not affected by the *category* + and *stacklevel* arguments. + + The deprecation message passed to the decorator is saved in the + ``__deprecated__`` attribute on the decorated object. + If applied to an overload, the decorator + must be after the ``@overload`` decorator for the attribute to + exist on the overload as returned by ``get_overloads()``. + + See PEP 702 for details. + + """ + def __init__( + self, + message: str, + /, + *, + category: typing.Optional[typing.Type[Warning]] = DeprecationWarning, + stacklevel: int = 1, + ) -> None: + if not isinstance(message, str): + raise TypeError( + "Expected an object of type str for 'message', not " + f"{type(message).__name__!r}" + ) + self.message = message + self.category = category + self.stacklevel = stacklevel + + def __call__(self, arg: _T, /) -> _T: + # Make sure the inner functions created below don't + # retain a reference to self. + msg = self.message + category = self.category + stacklevel = self.stacklevel + if category is None: + arg.__deprecated__ = msg + return arg + elif isinstance(arg, type): + import functools + from types import MethodType + + original_new = arg.__new__ + + @functools.wraps(original_new) + def __new__(cls, *args, **kwargs): + if cls is arg: + warnings.warn(msg, category=category, stacklevel=stacklevel + 1) + if original_new is not object.__new__: + return original_new(cls, *args, **kwargs) + # Mirrors a similar check in object.__new__. + elif cls.__init__ is object.__init__ and (args or kwargs): + raise TypeError(f"{cls.__name__}() takes no arguments") + else: + return original_new(cls) + + arg.__new__ = staticmethod(__new__) + + original_init_subclass = arg.__init_subclass__ + # We need slightly different behavior if __init_subclass__ + # is a bound method (likely if it was implemented in Python) + if isinstance(original_init_subclass, MethodType): + original_init_subclass = original_init_subclass.__func__ + + @functools.wraps(original_init_subclass) + def __init_subclass__(*args, **kwargs): + warnings.warn(msg, category=category, stacklevel=stacklevel + 1) + return original_init_subclass(*args, **kwargs) + + arg.__init_subclass__ = classmethod(__init_subclass__) + # Or otherwise, which likely means it's a builtin such as + # object's implementation of __init_subclass__. + else: + @functools.wraps(original_init_subclass) + def __init_subclass__(*args, **kwargs): + warnings.warn(msg, category=category, stacklevel=stacklevel + 1) + return original_init_subclass(*args, **kwargs) + + arg.__init_subclass__ = __init_subclass__ + + arg.__deprecated__ = __new__.__deprecated__ = msg + __init_subclass__.__deprecated__ = msg + return arg + elif callable(arg): + import functools + + @functools.wraps(arg) + def wrapper(*args, **kwargs): + warnings.warn(msg, category=category, stacklevel=stacklevel + 1) + return arg(*args, **kwargs) + + arg.__deprecated__ = wrapper.__deprecated__ = msg + return wrapper + else: + raise TypeError( + "@deprecated decorator with non-None category must be applied to " + f"a class or callable, not {arg!r}" + ) + + +# We have to do some monkey patching to deal with the dual nature of +# Unpack/TypeVarTuple: +# - We want Unpack to be a kind of TypeVar so it gets accepted in +# Generic[Unpack[Ts]] +# - We want it to *not* be treated as a TypeVar for the purposes of +# counting generic parameters, so that when we subscript a generic, +# the runtime doesn't try to substitute the Unpack with the subscripted type. +if not hasattr(typing, "TypeVarTuple"): + def _check_generic(cls, parameters, elen=_marker): + """Check correct count for parameters of a generic cls (internal helper). + + This gives a nice error message in case of count mismatch. + """ + if not elen: + raise TypeError(f"{cls} is not a generic class") + if elen is _marker: + if not hasattr(cls, "__parameters__") or not cls.__parameters__: + raise TypeError(f"{cls} is not a generic class") + elen = len(cls.__parameters__) + alen = len(parameters) + if alen != elen: + expect_val = elen + if hasattr(cls, "__parameters__"): + parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] + num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) + if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): + return + + # deal with TypeVarLike defaults + # required TypeVarLikes cannot appear after a defaulted one. + if alen < elen: + # since we validate TypeVarLike default in _collect_type_vars + # or _collect_parameters we can safely check parameters[alen] + if ( + getattr(parameters[alen], '__default__', NoDefault) + is not NoDefault + ): + return + + num_default_tv = sum(getattr(p, '__default__', NoDefault) + is not NoDefault for p in parameters) + + elen -= num_default_tv + + expect_val = f"at least {elen}" + + things = "arguments" if sys.version_info >= (3, 10) else "parameters" + raise TypeError(f"Too {'many' if alen > elen else 'few'} {things}" + f" for {cls}; actual {alen}, expected {expect_val}") +else: + # Python 3.11+ + + def _check_generic(cls, parameters, elen): + """Check correct count for parameters of a generic cls (internal helper). + + This gives a nice error message in case of count mismatch. + """ + if not elen: + raise TypeError(f"{cls} is not a generic class") + alen = len(parameters) + if alen != elen: + expect_val = elen + if hasattr(cls, "__parameters__"): + parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] + + # deal with TypeVarLike defaults + # required TypeVarLikes cannot appear after a defaulted one. + if alen < elen: + # since we validate TypeVarLike default in _collect_type_vars + # or _collect_parameters we can safely check parameters[alen] + if ( + getattr(parameters[alen], '__default__', NoDefault) + is not NoDefault + ): + return + + num_default_tv = sum(getattr(p, '__default__', NoDefault) + is not NoDefault for p in parameters) + + elen -= num_default_tv + + expect_val = f"at least {elen}" + + raise TypeError(f"Too {'many' if alen > elen else 'few'} arguments" + f" for {cls}; actual {alen}, expected {expect_val}") + +if not _PEP_696_IMPLEMENTED: + typing._check_generic = _check_generic + + +def _has_generic_or_protocol_as_origin() -> bool: + try: + frame = sys._getframe(2) + # - Catch AttributeError: not all Python implementations have sys._getframe() + # - Catch ValueError: maybe we're called from an unexpected module + # and the call stack isn't deep enough + except (AttributeError, ValueError): + return False # err on the side of leniency + else: + # If we somehow get invoked from outside typing.py, + # also err on the side of leniency + if frame.f_globals.get("__name__") != "typing": + return False + origin = frame.f_locals.get("origin") + # Cannot use "in" because origin may be an object with a buggy __eq__ that + # throws an error. + return origin is typing.Generic or origin is Protocol or origin is typing.Protocol + + +_TYPEVARTUPLE_TYPES = {TypeVarTuple, getattr(typing, "TypeVarTuple", None)} + + +def _is_unpacked_typevartuple(x) -> bool: + if get_origin(x) is not Unpack: + return False + args = get_args(x) + return ( + bool(args) + and len(args) == 1 + and type(args[0]) in _TYPEVARTUPLE_TYPES + ) + + +# Python 3.11+ _collect_type_vars was renamed to _collect_parameters +if hasattr(typing, '_collect_type_vars'): + def _collect_type_vars(types, typevar_types=None): + """Collect all type variable contained in types in order of + first appearance (lexicographic order). For example:: + + _collect_type_vars((T, List[S, T])) == (T, S) + """ + if typevar_types is None: + typevar_types = typing.TypeVar + tvars = [] + + # A required TypeVarLike cannot appear after a TypeVarLike with a default + # if it was a direct call to `Generic[]` or `Protocol[]` + enforce_default_ordering = _has_generic_or_protocol_as_origin() + default_encountered = False + + # Also, a TypeVarLike with a default cannot appear after a TypeVarTuple + type_var_tuple_encountered = False + + for t in types: + if _is_unpacked_typevartuple(t): + type_var_tuple_encountered = True + elif isinstance(t, typevar_types) and t not in tvars: + if enforce_default_ordering: + has_default = getattr(t, '__default__', NoDefault) is not NoDefault + if has_default: + if type_var_tuple_encountered: + raise TypeError('Type parameter with a default' + ' follows TypeVarTuple') + default_encountered = True + elif default_encountered: + raise TypeError(f'Type parameter {t!r} without a default' + ' follows type parameter with a default') + + tvars.append(t) + if _should_collect_from_parameters(t): + tvars.extend([t for t in t.__parameters__ if t not in tvars]) + return tuple(tvars) + + typing._collect_type_vars = _collect_type_vars +else: + def _collect_parameters(args): + """Collect all type variables and parameter specifications in args + in order of first appearance (lexicographic order). + + For example:: + + assert _collect_parameters((T, Callable[P, T])) == (T, P) + """ + parameters = [] + + # A required TypeVarLike cannot appear after a TypeVarLike with default + # if it was a direct call to `Generic[]` or `Protocol[]` + enforce_default_ordering = _has_generic_or_protocol_as_origin() + default_encountered = False + + # Also, a TypeVarLike with a default cannot appear after a TypeVarTuple + type_var_tuple_encountered = False + + for t in args: + if isinstance(t, type): + # We don't want __parameters__ descriptor of a bare Python class. + pass + elif isinstance(t, tuple): + # `t` might be a tuple, when `ParamSpec` is substituted with + # `[T, int]`, or `[int, *Ts]`, etc. + for x in t: + for collected in _collect_parameters([x]): + if collected not in parameters: + parameters.append(collected) + elif hasattr(t, '__typing_subst__'): + if t not in parameters: + if enforce_default_ordering: + has_default = ( + getattr(t, '__default__', NoDefault) is not NoDefault + ) + + if type_var_tuple_encountered and has_default: + raise TypeError('Type parameter with a default' + ' follows TypeVarTuple') + + if has_default: + default_encountered = True + elif default_encountered: + raise TypeError(f'Type parameter {t!r} without a default' + ' follows type parameter with a default') + + parameters.append(t) + else: + if _is_unpacked_typevartuple(t): + type_var_tuple_encountered = True + for x in getattr(t, '__parameters__', ()): + if x not in parameters: + parameters.append(x) + + return tuple(parameters) + + if not _PEP_696_IMPLEMENTED: + typing._collect_parameters = _collect_parameters + +# Backport typing.NamedTuple as it exists in Python 3.13. +# In 3.11, the ability to define generic `NamedTuple`s was supported. +# This was explicitly disallowed in 3.9-3.10, and only half-worked in <=3.8. +# On 3.12, we added __orig_bases__ to call-based NamedTuples +# On 3.13, we deprecated kwargs-based NamedTuples +if sys.version_info >= (3, 13): + NamedTuple = typing.NamedTuple +else: + def _make_nmtuple(name, types, module, defaults=()): + fields = [n for n, t in types] + annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") + for n, t in types} + nm_tpl = collections.namedtuple(name, fields, + defaults=defaults, module=module) + nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = annotations + # The `_field_types` attribute was removed in 3.9; + # in earlier versions, it is the same as the `__annotations__` attribute + if sys.version_info < (3, 9): + nm_tpl._field_types = annotations + return nm_tpl + + _prohibited_namedtuple_fields = typing._prohibited + _special_namedtuple_fields = frozenset({'__module__', '__name__', '__annotations__'}) + + class _NamedTupleMeta(type): + def __new__(cls, typename, bases, ns): + assert _NamedTuple in bases + for base in bases: + if base is not _NamedTuple and base is not typing.Generic: + raise TypeError( + 'can only inherit from a NamedTuple type and Generic') + bases = tuple(tuple if base is _NamedTuple else base for base in bases) + if "__annotations__" in ns: + types = ns["__annotations__"] + elif "__annotate__" in ns: + # TODO: Use inspect.VALUE here, and make the annotations lazily evaluated + types = ns["__annotate__"](1) + else: + types = {} + default_names = [] + for field_name in types: + if field_name in ns: + default_names.append(field_name) + elif default_names: + raise TypeError(f"Non-default namedtuple field {field_name} " + f"cannot follow default field" + f"{'s' if len(default_names) > 1 else ''} " + f"{', '.join(default_names)}") + nm_tpl = _make_nmtuple( + typename, types.items(), + defaults=[ns[n] for n in default_names], + module=ns['__module__'] + ) + nm_tpl.__bases__ = bases + if typing.Generic in bases: + if hasattr(typing, '_generic_class_getitem'): # 3.12+ + nm_tpl.__class_getitem__ = classmethod(typing._generic_class_getitem) + else: + class_getitem = typing.Generic.__class_getitem__.__func__ + nm_tpl.__class_getitem__ = classmethod(class_getitem) + # update from user namespace without overriding special namedtuple attributes + for key, val in ns.items(): + if key in _prohibited_namedtuple_fields: + raise AttributeError("Cannot overwrite NamedTuple attribute " + key) + elif key not in _special_namedtuple_fields: + if key not in nm_tpl._fields: + setattr(nm_tpl, key, ns[key]) + try: + set_name = type(val).__set_name__ + except AttributeError: + pass + else: + try: + set_name(val, nm_tpl, key) + except BaseException as e: + msg = ( + f"Error calling __set_name__ on {type(val).__name__!r} " + f"instance {key!r} in {typename!r}" + ) + # BaseException.add_note() existed on py311, + # but the __set_name__ machinery didn't start + # using add_note() until py312. + # Making sure exceptions are raised in the same way + # as in "normal" classes seems most important here. + if sys.version_info >= (3, 12): + e.add_note(msg) + raise + else: + raise RuntimeError(msg) from e + + if typing.Generic in bases: + nm_tpl.__init_subclass__() + return nm_tpl + + _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) + + def _namedtuple_mro_entries(bases): + assert NamedTuple in bases + return (_NamedTuple,) + + @_ensure_subclassable(_namedtuple_mro_entries) + def NamedTuple(typename, fields=_marker, /, **kwargs): + """Typed version of namedtuple. + + Usage:: + + class Employee(NamedTuple): + name: str + id: int + + This is equivalent to:: + + Employee = collections.namedtuple('Employee', ['name', 'id']) + + The resulting class has an extra __annotations__ attribute, giving a + dict that maps field names to types. (The field names are also in + the _fields attribute, which is part of the namedtuple API.) + An alternative equivalent functional syntax is also accepted:: + + Employee = NamedTuple('Employee', [('name', str), ('id', int)]) + """ + if fields is _marker: + if kwargs: + deprecated_thing = "Creating NamedTuple classes using keyword arguments" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "Use the class-based or functional syntax instead." + ) + else: + deprecated_thing = "Failing to pass a value for the 'fields' parameter" + example = f"`{typename} = NamedTuple({typename!r}, [])`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + example + "." + elif fields is None: + if kwargs: + raise TypeError( + "Cannot pass `None` as the 'fields' parameter " + "and also specify fields using keyword arguments" + ) + else: + deprecated_thing = "Passing `None` as the 'fields' parameter" + example = f"`{typename} = NamedTuple({typename!r}, [])`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + example + "." + elif kwargs: + raise TypeError("Either list of fields or keywords" + " can be provided to NamedTuple, not both") + if fields is _marker or fields is None: + warnings.warn( + deprecation_msg.format(name=deprecated_thing, remove="3.15"), + DeprecationWarning, + stacklevel=2, + ) + fields = kwargs.items() + nt = _make_nmtuple(typename, fields, module=_caller()) + nt.__orig_bases__ = (NamedTuple,) + return nt + + +if hasattr(collections.abc, "Buffer"): + Buffer = collections.abc.Buffer +else: + class Buffer(abc.ABC): # noqa: B024 + """Base class for classes that implement the buffer protocol. + + The buffer protocol allows Python objects to expose a low-level + memory buffer interface. Before Python 3.12, it is not possible + to implement the buffer protocol in pure Python code, or even + to check whether a class implements the buffer protocol. In + Python 3.12 and higher, the ``__buffer__`` method allows access + to the buffer protocol from Python code, and the + ``collections.abc.Buffer`` ABC allows checking whether a class + implements the buffer protocol. + + To indicate support for the buffer protocol in earlier versions, + inherit from this ABC, either in a stub file or at runtime, + or use ABC registration. This ABC provides no methods, because + there is no Python-accessible methods shared by pre-3.12 buffer + classes. It is useful primarily for static checks. + + """ + + # As a courtesy, register the most common stdlib buffer classes. + Buffer.register(memoryview) + Buffer.register(bytearray) + Buffer.register(bytes) + + +# Backport of types.get_original_bases, available on 3.12+ in CPython +if hasattr(_types, "get_original_bases"): + get_original_bases = _types.get_original_bases +else: + def get_original_bases(cls, /): + """Return the class's "original" bases prior to modification by `__mro_entries__`. + + Examples:: + + from typing import TypeVar, Generic + from pip._vendor.typing_extensions import NamedTuple, TypedDict + + T = TypeVar("T") + class Foo(Generic[T]): ... + class Bar(Foo[int], float): ... + class Baz(list[str]): ... + Eggs = NamedTuple("Eggs", [("a", int), ("b", str)]) + Spam = TypedDict("Spam", {"a": int, "b": str}) + + assert get_original_bases(Bar) == (Foo[int], float) + assert get_original_bases(Baz) == (list[str],) + assert get_original_bases(Eggs) == (NamedTuple,) + assert get_original_bases(Spam) == (TypedDict,) + assert get_original_bases(int) == (object,) + """ + try: + return cls.__dict__.get("__orig_bases__", cls.__bases__) + except AttributeError: + raise TypeError( + f'Expected an instance of type, not {type(cls).__name__!r}' + ) from None + + +# NewType is a class on Python 3.10+, making it pickleable +# The error message for subclassing instances of NewType was improved on 3.11+ +if sys.version_info >= (3, 11): + NewType = typing.NewType +else: + class NewType: + """NewType creates simple unique types with almost zero + runtime overhead. NewType(name, tp) is considered a subtype of tp + by static type checkers. At runtime, NewType(name, tp) returns + a dummy callable that simply returns its argument. Usage:: + UserId = NewType('UserId', int) + def name_by_id(user_id: UserId) -> str: + ... + UserId('user') # Fails type check + name_by_id(42) # Fails type check + name_by_id(UserId(42)) # OK + num = UserId(5) + 1 # type: int + """ + + def __call__(self, obj, /): + return obj + + def __init__(self, name, tp): + self.__qualname__ = name + if '.' in name: + name = name.rpartition('.')[-1] + self.__name__ = name + self.__supertype__ = tp + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + def __mro_entries__(self, bases): + # We defined __mro_entries__ to get a better error message + # if a user attempts to subclass a NewType instance. bpo-46170 + supercls_name = self.__name__ + + class Dummy: + def __init_subclass__(cls): + subcls_name = cls.__name__ + raise TypeError( + f"Cannot subclass an instance of NewType. " + f"Perhaps you were looking for: " + f"`{subcls_name} = NewType({subcls_name!r}, {supercls_name})`" + ) + + return (Dummy,) + + def __repr__(self): + return f'{self.__module__}.{self.__qualname__}' + + def __reduce__(self): + return self.__qualname__ + + if sys.version_info >= (3, 10): + # PEP 604 methods + # It doesn't make sense to have these methods on Python <3.10 + + def __or__(self, other): + return typing.Union[self, other] + + def __ror__(self, other): + return typing.Union[other, self] + + +if hasattr(typing, "TypeAliasType"): + TypeAliasType = typing.TypeAliasType +else: + def _is_unionable(obj): + """Corresponds to is_unionable() in unionobject.c in CPython.""" + return obj is None or isinstance(obj, ( + type, + _types.GenericAlias, + _types.UnionType, + TypeAliasType, + )) + + class TypeAliasType: + """Create named, parameterized type aliases. + + This provides a backport of the new `type` statement in Python 3.12: + + type ListOrSet[T] = list[T] | set[T] + + is equivalent to: + + T = TypeVar("T") + ListOrSet = TypeAliasType("ListOrSet", list[T] | set[T], type_params=(T,)) + + The name ListOrSet can then be used as an alias for the type it refers to. + + The type_params argument should contain all the type parameters used + in the value of the type alias. If the alias is not generic, this + argument is omitted. + + Static type checkers should only support type aliases declared using + TypeAliasType that follow these rules: + + - The first argument (the name) must be a string literal. + - The TypeAliasType instance must be immediately assigned to a variable + of the same name. (For example, 'X = TypeAliasType("Y", int)' is invalid, + as is 'X, Y = TypeAliasType("X", int), TypeAliasType("Y", int)'). + + """ + + def __init__(self, name: str, value, *, type_params=()): + if not isinstance(name, str): + raise TypeError("TypeAliasType name must be a string") + self.__value__ = value + self.__type_params__ = type_params + + parameters = [] + for type_param in type_params: + if isinstance(type_param, TypeVarTuple): + parameters.extend(type_param) + else: + parameters.append(type_param) + self.__parameters__ = tuple(parameters) + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + # Setting this attribute closes the TypeAliasType from further modification + self.__name__ = name + + def __setattr__(self, name: str, value: object, /) -> None: + if hasattr(self, "__name__"): + self._raise_attribute_error(name) + super().__setattr__(name, value) + + def __delattr__(self, name: str, /) -> Never: + self._raise_attribute_error(name) + + def _raise_attribute_error(self, name: str) -> Never: + # Match the Python 3.12 error messages exactly + if name == "__name__": + raise AttributeError("readonly attribute") + elif name in {"__value__", "__type_params__", "__parameters__", "__module__"}: + raise AttributeError( + f"attribute '{name}' of 'typing.TypeAliasType' objects " + "is not writable" + ) + else: + raise AttributeError( + f"'typing.TypeAliasType' object has no attribute '{name}'" + ) + + def __repr__(self) -> str: + return self.__name__ + + def __getitem__(self, parameters): + if not isinstance(parameters, tuple): + parameters = (parameters,) + parameters = [ + typing._type_check( + item, f'Subscripting {self.__name__} requires a type.' + ) + for item in parameters + ] + return typing._GenericAlias(self, tuple(parameters)) + + def __reduce__(self): + return self.__name__ + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError( + "type 'typing_extensions.TypeAliasType' is not an acceptable base type" + ) + + # The presence of this method convinces typing._type_check + # that TypeAliasTypes are types. + def __call__(self): + raise TypeError("Type alias is not callable") + + if sys.version_info >= (3, 10): + def __or__(self, right): + # For forward compatibility with 3.12, reject Unions + # that are not accepted by the built-in Union. + if not _is_unionable(right): + return NotImplemented + return typing.Union[self, right] + + def __ror__(self, left): + if not _is_unionable(left): + return NotImplemented + return typing.Union[left, self] + + +if hasattr(typing, "is_protocol"): + is_protocol = typing.is_protocol + get_protocol_members = typing.get_protocol_members +else: + def is_protocol(tp: type, /) -> bool: + """Return True if the given type is a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, is_protocol + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> is_protocol(P) + True + >>> is_protocol(int) + False + """ + return ( + isinstance(tp, type) + and getattr(tp, '_is_protocol', False) + and tp is not Protocol + and tp is not typing.Protocol + ) + + def get_protocol_members(tp: type, /) -> typing.FrozenSet[str]: + """Return the set of members defined in a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, get_protocol_members + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> get_protocol_members(P) + frozenset({'a', 'b'}) + + Raise a TypeError for arguments that are not Protocols. + """ + if not is_protocol(tp): + raise TypeError(f'{tp!r} is not a Protocol') + if hasattr(tp, '__protocol_attrs__'): + return frozenset(tp.__protocol_attrs__) + return frozenset(_get_protocol_attrs(tp)) + + +if hasattr(typing, "Doc"): + Doc = typing.Doc +else: + class Doc: + """Define the documentation of a type annotation using ``Annotated``, to be + used in class attributes, function and method parameters, return values, + and variables. + + The value should be a positional-only string literal to allow static tools + like editors and documentation generators to use it. + + This complements docstrings. + + The string value passed is available in the attribute ``documentation``. + + Example:: + + >>> from typing_extensions import Annotated, Doc + >>> def hi(to: Annotated[str, Doc("Who to say hi to")]) -> None: ... + """ + def __init__(self, documentation: str, /) -> None: + self.documentation = documentation + + def __repr__(self) -> str: + return f"Doc({self.documentation!r})" + + def __hash__(self) -> int: + return hash(self.documentation) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Doc): + return NotImplemented + return self.documentation == other.documentation + + +_CapsuleType = getattr(_types, "CapsuleType", None) + +if _CapsuleType is None: + try: + import _socket + except ImportError: + pass + else: + _CAPI = getattr(_socket, "CAPI", None) + if _CAPI is not None: + _CapsuleType = type(_CAPI) + +if _CapsuleType is not None: + CapsuleType = _CapsuleType + __all__.append("CapsuleType") + + +# Aliases for items that have always been in typing. +# Explicitly assign these (rather than using `from typing import *` at the top), +# so that we get a CI error if one of these is deleted from typing.py +# in a future version of Python +AbstractSet = typing.AbstractSet +AnyStr = typing.AnyStr +BinaryIO = typing.BinaryIO +Callable = typing.Callable +Collection = typing.Collection +Container = typing.Container +Dict = typing.Dict +ForwardRef = typing.ForwardRef +FrozenSet = typing.FrozenSet +Generic = typing.Generic +Hashable = typing.Hashable +IO = typing.IO +ItemsView = typing.ItemsView +Iterable = typing.Iterable +Iterator = typing.Iterator +KeysView = typing.KeysView +List = typing.List +Mapping = typing.Mapping +MappingView = typing.MappingView +Match = typing.Match +MutableMapping = typing.MutableMapping +MutableSequence = typing.MutableSequence +MutableSet = typing.MutableSet +Optional = typing.Optional +Pattern = typing.Pattern +Reversible = typing.Reversible +Sequence = typing.Sequence +Set = typing.Set +Sized = typing.Sized +TextIO = typing.TextIO +Tuple = typing.Tuple +Union = typing.Union +ValuesView = typing.ValuesView +cast = typing.cast +no_type_check = typing.no_type_check +no_type_check_decorator = typing.no_type_check_decorator diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/vendor.txt b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/vendor.txt new file mode 100644 index 0000000000000000000000000000000000000000..f04a9c1e73c36bfbd2b6ef5d87bf65bc050804f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/pip/_vendor/vendor.txt @@ -0,0 +1,18 @@ +CacheControl==0.14.1 +distlib==0.3.9 +distro==1.9.0 +msgpack==1.1.0 +packaging==24.2 +platformdirs==4.3.6 +pyproject-hooks==1.2.0 +requests==2.32.3 + certifi==2024.8.30 + idna==3.10 + urllib3==1.26.20 +rich==13.9.4 + pygments==2.18.0 + typing_extensions==4.12.2 +resolvelib==1.0.1 +setuptools==70.3.0 +tomli==2.2.1 +truststore==0.10.0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy-1.14.0.dist-info/licenses/AUTHORS b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy-1.14.0.dist-info/licenses/AUTHORS new file mode 100644 index 0000000000000000000000000000000000000000..1062f3bf743c0a37bcde05b1afc1ba5e98a95822 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy-1.14.0.dist-info/licenses/AUTHORS @@ -0,0 +1,1379 @@ +All people who contributed to SymPy by sending at least a patch or +more (in the order of the date of their first contribution), except +those who explicitly didn't want to be mentioned. People with a * next +to their names are not found in the metadata of the git history. This +file is generated automatically by running `./bin/authors_update.py`. + +There are a total of 1371 authors. + +Ondřej Čertík +Fabian Pedregosa +Jurjen N.E. Bos +Mateusz Paprocki +*Marc-Etienne M.Leveille +Brian Jorgensen +Jason Gedge +Robert Schwarz +Pearu Peterson +Fredrik Johansson +Chris Wu +*Ulrich Hecht +Goutham Lakshminarayan +David Lawrence +Jaroslaw Tworek +David Marek +Bernhard R. Link +Andrej Tokarčík +Or Dvory +Saroj Adhikari +Pauli Virtanen +Robert Kern +James Aspnes +Nimish Telang +Abderrahim Kitouni +Pan Peng +Friedrich Hagedorn +Elrond der Elbenfuerst +Rizgar Mella +Felix Kaiser +Roberto Nobrega +David Roberts +Sebastian Krämer +Vinzent Steinberg +Riccardo Gori +Case Van Horsen +Stepan Roucka +Ali Raza Syed +Stefano Maggiolo +Robert Cimrman +Bastian Weber +Sebastian Krause +Sebastian Kreft +*Dan +Alan Bromborsky +Boris Timokhin +Robert +Andy R. Terrel +Hubert Tsang +Konrad Meyer +Henrik Johansson +Priit Laes +Freddie Witherden +Brian E. Granger +Andrew Straw +Kaifeng Zhu +Ted Horst +Andrew Docherty +Akshay Srinivasan +Aaron Meurer +Barry Wardell +Tomasz Buchert +Vinay Kumar +Johann Cohen-Tanugi +Jochen Voss +Luke Peterson +Chris Smith +Thomas Sidoti +Florian Mickler +Nicolas Pourcelot +Ben Goodrich +Toon Verstraelen +Ronan Lamy +James Abbatiello +Ryan Krauss +Bill Flynn +Kevin Goodsell +Jorn Baayen +Eh Tan +Renato Coutinho +Oscar Benjamin +Øyvind Jensen +Julio Idichekop Filho +Łukasz Pankowski +*Chu-Ching Huang +Fernando Perez +Raffaele De Feo +Christian Muise +Matt Curry +Kazuo Thow +Christian Schubert +Jezreel Ng +James Pearson +Matthew Brett +Addison Cugini +Nicholas J.S. Kinar +Harold Erbin +Thomas Dixon +Cristóvão Sousa +Andre de Fortier Smit +Mark Dewing +Alexey U. Gudchenko +Gary Kerr +Sherjil Ozair +Oleksandr Gituliar +Sean Vig +Prafullkumar P. Tale +Vladimir Perić +Tom Bachmann +Yuri Karadzhov +Vladimir Lagunov +Matthew Rocklin +Saptarshi Mandal +Gilbert Gede +Anatolii Koval +Tomo Lazovich +Pavel Fedotov +Jack McCaffery +Jeremias Yehdegho +Kibeom Kim +Gregory Ksionda +Tomáš Bambas +Raymond Wong +Luca Weihs +Shai 'Deshe' Wyborski +Thomas Wiecki +Óscar Nájera +Mario Pernici +Benjamin McDonald +Sam Magura +Stefan Krastanov +Bradley Froehle +Min Ragan-Kelley +Emma Hogan +Nikhil Sarda +Julien Rioux +Roberto Colistete, Jr. +Raoul Bourquin +Gert-Ludwig Ingold +Srinivas Vasudevan +Jason Moore +Miha Marolt +Tim Lahey +Luis Garcia +Matt Rajca +David Li +Alexandr Gudulin +Bilal Akhtar +Grzegorz Świrski +Matt Habel +David Ju +Nichita Utiu +Nikolay Lazarov +Steve Anton +Imran Ahmed Manzoor +Ljubiša Moćić <3rdslasher@gmail.com> +Piotr Korgul +Jim Zhang +Sam Sleight +tborisova +Chancellor Arkantos +Stepan Simsa +Tobias Lenz +Siddhanathan Shanmugam +Tiffany Zhu +Tristan Hume +Alexey Subach +Joan Creus +Geoffry Song +Puneeth Chaganti +Marcin Kostrzewa <> +Natalia Nawara +vishal +Shruti Mangipudi +Davy Mao +Swapnil Agarwal +Dhia Kennouche +jerryma1121 +Joachim Durchholz +Martin Povišer +Siddhant Jain +Kevin Hunter +Michael Mayorov +Nathan Alison +Christian Bühler +Carsten Knoll +Bharath M R +Matthias Toews +Sergiu Ivanov +Jorge E. Cardona +Sanket Agarwal +Manoj Babu K. +Sai Nikhil +Aleksandar Makelov +Sachin Irukula +Raphael Michel +Ashwini Oruganti +Andreas Klöckner +Prateek Papriwal +Arpit Goyal +Angadh Nanjangud +Comer Duncan +Jens H. Nielsen +Joseph Dougherty +Elliot Marshall +Guru Devanla +George Waksman +Alexandr Popov +Tarun Gaba +Takafumi Arakaki +Saurabh Jha +Rom le Clair +Angus Griffith <16sn6uv@gmail.com> +Timothy Reluga +Brian Stephanik +Alexander Eberspächer +Sachin Joglekar +Tyler Pirtle +Vasily Povalyaev +Colleen Lee +Matthew Hoff +Niklas Thörne +Huijun Mai +Marek Šuppa +Ramana Venkata +Prasoon Shukla +Stefen Yin +Thomas Hisch +Madeleine Ball +Mary Clark +Rishabh Dixit +Manoj Kumar +Akshit Agarwal +CJ Carey +Patrick Lacasse +Ananya H +Tarang Patel +Christopher Dembia +Benjamin Fishbein +Sean Ge +Amit Jamadagni +Ankit Agrawal +Björn Dahlgren +Christophe Saint-Jean +Demian Wassermann +Khagesh Patel +Stephen Loo +hm +Patrick Poitras +Katja Sophie Hotz +Varun Joshi +Chetna Gupta +Thilina Rathnayake +Max Hutchinson +Shravas K Rao +Matthew Tadd +Alexander Hirzel +Randy Heydon +Oliver Lee +Seshagiri Prabhu +Pradyumna +Erik Welch +Eric Nelson +Roland Puntaier +Chris Conley +Tim Swast +Dmitry Batkovich +Francesco Bonazzi +Yuriy Demidov +Rick Muller +Manish Gill +Markus Müller +Amit Saha +Jeremy +QuaBoo +Stefan van der Walt +David Joyner +Lars Buitinck +Alkiviadis G. Akritas +Vinit Ravishankar +Michael Boyle +Heiner Kirchhoffer +Pablo Puente +James Fiedler +Harsh Gupta +Tuomas Airaksinen +Paul Strickland +James Goppert +rathmann +Avichal Dayal +Paul Scott +Shipra Banga +Pramod Ch +Akshay +Buck Shlegeris +Jonathan Miller +Edward Schembor +Rajath Shashidhara +Zamrath Nizam +Aditya Shah +Rajat Aggarwal +Sambuddha Basu +Zeel Shah +Abhinav Chanda +Jim Crist +Sudhanshu Mishra +Anurag Sharma +Soumya Dipta Biswas +Sushant Hiray +Ben Lucato +Kunal Arora +Henry Gebhardt +Dammina Sahabandu +Manish Shukla +Ralph Bean +richierichrawr +John Connor +Juan Luis Cano Rodríguez +Sahil Shekhawat +Kundan Kumar +Stas Kelvich +sevaader +Dhruvesh Vijay Parikh +Venkatesh Halli +Lennart Fricke +Vlad Seghete +Shashank Agarwal +carstimon +Pierre Haessig +Maciej Baranski +Benjamin Gudehus +Faisal Anees +Mark Shoulson +Robert Johansson +Kalevi Suominen +Kaushik Varanasi +Fawaz Alazemi +Ambar Mehrotra +David P. Sanders +Peter Brady +John V. Siratt +Sarwar Chahal +Nathan Woods +Colin B. Macdonald +Marcus Näslund +Clemens Novak +Mridul Seth +Craig A. Stoudt +Raj +Mihai A. Ionescu +immerrr +Chai Wah Wu +Leonid Blouvshtein +Peleg Michaeli +ck Lux +zsc347 +Hamish Dickson +Michael Gallaspy +Roman Inflianskas +Duane Nykamp +Ted Dokos +Sunny Aggarwal +Victor Brebenar +Akshat Jain +Shivam Vats +Longqi Wang +Juan Felipe Osorio +Ray Cathcart +Lukas Zorich +Eric Miller +Cody Herbst +Nishith Shah +Amit Kumar +Yury G. Kudryashov +Guillaume Gay +Mihir Wadwekar +Tuan Manh Lai +Asish Panda +Darshan Chaudhary +Alec Kalinin +Ralf Stephan +Aaditya Nair +Jayesh Lahori +Harshil Goel +Luv Agarwal +Jason Ly +Lokesh Sharma +Sartaj Singh +Chris Swierczewski +Konstantin Togoi +Param Singh +Sumith Kulal +Juha Remes +Philippe Bouafia +Peter Schmidt +Jiaxing Liang +Lucas Jones +Gregory Ashton +Jennifer White +Renato Orsino +Alistair Lynn +Govind Sahai +Adam Bloomston +Kyle McDaniel +Nguyen Truong Duy +Alex Lindsay +Mathew Chong +Jason Siefken +Gaurav Dhingra +Gao, Xiang +Kevin Ventullo +mao8 +Isuru Fernando +Shivam Tyagi +Richard Otis +Rich LaSota +dustyrockpyle +Anton Akhmerov +Michael Zingale +Chak-Pong Chung +David T +Phil Ruffwind +Sebastian Koslowski +Kumar Krishna Agrawal +Dustin Gadal +João Moura +Yu Kobayashi +Shashank Kumar +Timothy Cyrus +Devyani Kota +Keval Shah +Dzhelil Rufat +Pastafarianist +Sourav Singh +Jacob Garber +Vinay Singh +GolimarOurHero +Prashant Tyagi +Matthew Davis +Tschijnmo TSCHAU +Alexander Bentkamp +Jack Kemp +Kshitij Saraogi +Thomas Baruchel +Nicolás Guarín-Zapata +Jens Jørgen Mortensen +Sampad Kumar Saha +Eva Charlotte Mayer +Laura Domine +Justin Blythe +Meghana Madhyastha +Tanu Hari Dixit +Shekhar Prasad Rajak +Aqnouch Mohammed +Arafat Dad Khan +Boris Atamanovskiy +Sam Tygier +Jai Luthra +Guo Xingjian +Sandeep Veethu +Archit Verma +Shubham Tibra +Ashutosh Saboo +Michael S. Hansen +Anish Shah +Guillaume Jacquenot +Bhautik Mavani +Michał Radwański +Jerry Li +Pablo Zubieta +Shivam Agarwal +Chaitanya Sai Alaparthi +Arihant Parsoya +Ruslan Pisarev +Akash Trehan +Nishant Nikhil +Vladimir Poluhsin +Akshay Nagar +James Brandon Milam +Abhinav Agarwal +Rishabh Daal +Sanya Khurana +Aman Deep +Aravind Reddy +Abhishek Verma +Matthew Parnell +Thomas Hickman +Akshay Siramdas +YiDing Jiang +Jatin Yadav +Matthew Thomas +Rehas Sachdeva +Michael Mueller +Srajan Garg +Prabhjot Singh +Haruki Moriguchi +Tom Gijselinck +Nitin Chaudhary +Alex Argunov +Nathan Musoke +Abhishek Garg +Dana Jacobsen +Vasiliy Dommes +Phillip Berndt +Haimo Zhang +Anthony Scopatz +bluebrook +Leonid Kovalev +Josh Burkart +Dimitra Konomi +Christina Zografou +Fiach Antaw +Langston Barrett +Krit Karan +G. D. McBain +Prempal Singh +Gabriel Orisaka +Matthias Bussonnier +rahuldan +Colin Marquardt +Andrew Taber +Yash Reddy +Peter Stangl +elvis-sik +Nikos Karagiannakis +Jainul Vaghasia +Dennis Meckel +Harshil Meena +Micky +Nick Curtis +Michele Zaffalon +Martha Giannoudovardi +Devang Kulshreshtha +Steph Papanik +Mohammad Sadeq Dousti +Arif Ahmed +Abdullah Javed Nesar +Lakshya Agrawal +shruti +Rohit Rango +Hong Xu +Ivan Petuhov +Alsheh +Marcel Stimberg +Alexey Pakhocmhik +Tommy Olofsson +Zulfikar +Blair Azzopardi +Danny Hermes +Sergey Pestov +Mohit Chandra +Karthik Chintapalli +Marcin Briański +andreo +Flamy Owl +Yicong Guo +Varun Garg +Rishabh Madan +Aditya Kapoor +Karan Sharma +Vedant Rathore +Johan Blåbäck +Pranjal Tale +Jason Tokayer +Raghav Jajodia +Rajat Thakur +Dhruv Bhanushali +Anjul Kumar Tyagi +Barun Parruck +Bao Chau +Tanay Agrawal +Ranjith Kumar +Shikhar Makhija +Yathartha Joshi +Valeriia Gladkova +Sagar Bharadwaj +Daniel Mahler +Ka Yi +Rishat Iskhakov +Szymon Mieszczak +Sachin Agarwal +Priyank Patel +Satya Prakash Dwibedi +tools4origins +Nico Schlömer +Fermi Paradox +Ekansh Purohit +Vedarth Sharma +Peeyush Kushwaha +Jayjayyy +Christopher J. Wright +Jakub Wilk +Mauro Garavello +Chris Tefer +Shikhar Jaiswal +Chiu-Hsiang Hsu +Carlos Cordoba +Fabian Ball +Yerniyaz +Christiano Anderson +Robin Neatherway +Thomas Hunt +Theodore Han +Duc-Minh Phan +Lejla Metohajrova +Samyak Jain +Aditya Rohan +Vincent Delecroix +Michael Sparapany +Harsh Jain +Nathan Goldbaum +latot +Kenneth Lyons +Stan Schymanski +David Daly +Ayush Shridhar +Javed Nissar +Jiri Kuncar +vedantc98 +Rupesh Harode +Rob Zinkov +James Harrop +James Taylor +Ishan Joshi +Marco Mancini +Boris Ettinger +Micah Fitch +Daniel Wennberg +ylemkimon +Akash Vaish +Peter Enenkel +Waldir Pimenta +Jithin D. George +Lev Chelyadinov +Lucas Wiman +Rhea Parekh +James Cotton +Robert Pollak +anca-mc +Sourav Ghosh +Jonathan Allan +Nikhil Pappu +Ethan Ward +Cezary Marczak +dps7ud +Nilabja Bhattacharya +Itay4 <31018228+Itay4@users.noreply.github.com> +Poom Chiarawongse +Yang Yang +Cavendish McKay +Bradley Gannon +B McG +Rob Drynkin +Seth Ebner +Akash Kundu +Mark Jeromin +Roberto Díaz Pérez +Gleb Siroki +Segev Finer +Alex Lubbock +Ayodeji Ige +Matthew Wardrop +Hugo van Kemenade +Austin Palmer +der-blaue-elefant +Filip Gokstorp +Yuki Matsuda +Aaron Miller +Salil Vishnu Kapur +Atharva Khare +Shubham Maheshwari +Pavel Tkachenko +Ashish Kumar Gaurav +Rajeev Singh +Keno Goertz +Lucas Gallindo +Himanshu +David Menéndez Hurtado +Amit Manchanda +Rohit Jain +Jonathan A. Gross +Unknown +Sayan Goswami +Subhash Saurabh +Rastislav Rabatin +Vishal +Jeremey Gluck +Akshat Maheshwari +symbolique +Saloni Jain +Arighna Chakrabarty +Abhigyan Khaund +Jashanpreet Singh +Saurabh Agarwal +luzpaz +P. Sai Prasanth +Nirmal Sarswat +Cristian Di Pietrantonio +Ravi charan +Nityananda Gohain +Cédric Travelletti +Nicholas Bollweg +Himanshu Ladia +Adwait Baokar +Mihail Tarigradschi +Saketh +rushyam +sfoo +Rahil Hastu +Zach Raines +Sidhant Nagpal +Gagandeep Singh +Rishav Chakraborty +Malkhan Singh +Joaquim Monserrat +Mayank Singh +Rémy Léone +Maxence Mayrand <35958639+maxencemayrand@users.noreply.github.com> +Nikoleta Glynatsi +helo9 +Ken Wakita +Carl Sandrock +Fredrik Eriksson +Ian Swire +Bulat +Ehren Metcalfe +Dmitry Savransky +Kiyohito Yamazaki +Caley Finn +Zhi-Qiang Zhou +Alexander Pozdneev +Wes Turner <50891+westurner@users.noreply.github.com> +JMSS-Unknown <31131631+JMSS-Unknown@users.noreply.github.com> +Arshdeep Singh +cym1 <16437732+cym1@users.noreply.github.com> +Stewart Wadsworth +Jared Lumpe +Avi Shrivastava +ramvenkat98 +Bilal Ahmed +Dimas Abreu Archanjo Dutra +Yatna Verma +S.Y. Lee +Miro Hrončok +Sudarshan Kamath +Ayushman Koul +Robert Dougherty-Bliss +Andrey Grozin +Bavish Kulur +Arun Singh +sirnicolaf <43586954+sirnicolaf@users.noreply.github.com> +Zachariah Etienne +Prayush Dawda <35144226+iamprayush@users.noreply.github.com> +2torus +Faisal Riyaz +Martin Roelfs +SirJohnFranklin +Anthony Sottile +ViacheslavP +Safiya03 +Alexander Dunlap +Rohit Sharma <31184621+rohitx007@users.noreply.github.com> +Jonathan Warner +Mohit Balwani +Marduk Bolaños +amsuhane +Matthias Geier +klaasvanaarsen <44929042+klaasvanaarsen@users.noreply.github.com> +Shubham Kumar Jha +rationa-kunal +Animesh Sinha +Gaurang Tandon <1gaurangtandon@gmail.com> +Matthew Craven +Daniel Ingram +Jogi Miglani +Takumasa Nakamura +Ritu Raj Singh +Rajiv Ranjan Singh +Vera Lozhkina +adhoc-king <46354827+adhoc-king@users.noreply.github.com> +Mikel Rouco +Oscar Gustafsson +damianos +Supreet Agrawal +shiksha11 +Martin Ueding +sharma-kunal +Divyanshu Thakur +Susumu Ishizuka +Samnan Rahee +Fredrik Andersson +Bhavya Srivastava +Alpesh Jamgade +Shubham Abhang +Vishesh Mangla +Nicko van Someren +dandiez <47832466+dandiez@users.noreply.github.com> +Frédéric Chapoton +jhanwar +Noumbissi valere Gille Geovan +Salmista-94 +Shivani Kohli +Parker Berry +Pragyan Mehrotra +Nabanita Dash +Gaetano Guerriero +Ankit Raj Pandey +Ritesh Kumar +kangzhiq <709563092@qq.com> +Jun Lin +Petr Kungurtsev +Anway De +znxftw +Denis Ivanenko +Orestis Vaggelis +Nikhil Maan +Abhinav Anand +Qingsha Shi +Juan Barbosa +Prionti Nasir +Bharat Raghunathan +arooshiverma +Christoph Gohle +Charalampos Tsiagkalis +Daniel Sears +Megan Ly +Sean P. Cornelius +Erik R. Gomez +Riccardo Magliocchetti +Henry Metlov +pekochun +Bendik Samseth +Vighnesh Shenoy +Versus Void +Denys Rybalka +Mark Dickinson +Rimi +rimibis <33387803+rimibis@users.noreply.github.com> +Steven Lee +Gilles Schintgen +Abhi58 +Tomasz Pytel +Aadit Kamat +Samesh +Velibor Zeli +Gabriel Bernardino +Joseph Redfern +Evelyn King +Miguel Marco +David Hagen +Hannah Kari +Soniya Nayak +Harsh Agarwal +Enric Florit +Yogesh Mishra +Denis Rykov +Ivan Tkachenko +Kenneth Emeka Odoh +Stephan Seitz +Yeshwanth N +Oscar Gerardo Lazo Arjona +Srinivasa Arun Yeragudipati +Kirtan Mali +TitanSnow +Pengning Chao <8857165+PengningChao@users.noreply.github.com> +Louis Abraham +Morten Olsen Lysgaard +Akash Nagaraj (akasnaga) +Akash Nagaraj +Lauren Glattly +Hou-Rui +George Korepanov +dranknight09 +aditisingh2362 +Gina +gregmedlock +Georgios Giapitzakis Tzintanos +Eric Wieser +Bradley Dowling <34559056+btdow@users.noreply.github.com> +Maria Marginean <33810762+mmargin@users.noreply.github.com> +Akash Agrawall +jgulian +Sourav Goyal +Zlatan Vasović +Alex Meiburg +Smit Lunagariya +Naman Gera +Julien Palard +Dhruv Mendiratta +erdOne <36414270+erdOne@users.noreply.github.com> +risubaba +abhinav28071999 <41710346+abhinav28071999@users.noreply.github.com> +Jisoo Song +Jaime R <38530589+Jaime02@users.noreply.github.com> +Vikrant Malik +Hardik Saini <43683678+Guardianofgotham@users.noreply.github.com> +Abhishek +Johannes Hartung +Milan Jolly +faizan2700 +mohit <39158356+mohitacecode@users.noreply.github.com> +Mohit Gupta +Psycho-Pirate +Chanakya-Ekbote +Rashmi Shehana +Jonty16117 +Anubhav Gupta +Michal Grňo +vezeli <37907135+vezeli@users.noreply.github.com> +Tim Gates +Sandeep Murthy +Neil +V1krant <46847915+V1krant@users.noreply.github.com> +alejandro +Riyan Dhiman +sbt4104 +Seth Troisi +Bhaskar Gupta +Smit Gajjar +rbl +Ilya Pchelintsev +Omar Wagih +prshnt19 +Johan Guzman +Vasileios Kalos +BasileiosKal <61801875+BasileiosKal@users.noreply.github.com> +Shubham Thorat <37049710+sbt4104@users.noreply.github.com> +Arpan Chattopadhyay +Ashutosh Hathidara +Moses Paul R +Saanidhya vats +tnzl +Vatsal Srivastava +Jean-Luc Herren +Dhruv Kothari +seadavis <45022599+seadavis@users.noreply.github.com> +kamimura +slacker404 +Jaime Resano +Ebrahim Byagowi +wuyudi +Akira Kyle +Calvin Jay Ross +Martin Thoma +Thomas A Caswell +Lagaras Stelios +Jerry James +Jan Kruse +Nathan Taylor +Vaishnav Damani +Mohit Shah +Mathias Louboutin +Marijan Smetko +Dave Witte Morris +soumi7 +Zhongshi +Wes Galbraith +KaustubhDamania +w495 +Akhil Rajput +Markus Mohrhard +Benjamin Wolba +彭于斌 <1931127624@qq.com> +Rudr Tiwari +Aaryan Dewan +Benedikt Placke +Sneha Goddu +goddus <39923708+goddus@users.noreply.github.com> +Shivang Dubey +Michael Greminger +Peter Cock +Willem Melching +Elias Basler +Brandon David +Abhay_Dhiman +Tasha Kim +Ayush Malik +Devesh Sawant +Wolfgang Stöcher +Sudeep Sidhu +foice +Ben Payne +Muskan Kumar <31043527+muskanvk@users.noreply.github.com> +noam simcha finkelstein +Garrett Folbe +Islam Mansour +Sayandip Halder +Shubham Agrawal +numbermaniac <5206120+numbermaniac@users.noreply.github.com> +Sakirul Alam +Mohammed Bilal +Chris du Plessis +Coder-RG +Ansh Mishra +Alex Malins +Lorenzo Contento +Naveen Sai +Shital Mule +Amanda Dsouza +Nijso Beishuizen +Harry Zheng +Felix Yan +Constantin Mateescu +Eva Tiwari +Aditya Kumar Sinha +Soumi Bardhan <51290447+Soumi7@users.noreply.github.com> +Kaustubh Chaudhari +Kristian Brünn +Neel Gorasiya +Akshat Sood <68052998+akshatsood2249@users.noreply.github.com> +Jose M. Gomez +Stefan Petrea +Praveen Sahu +Mark Bell +AlexCQY +Fabian Froehlich +Nikhil Gopalam +Kartik Sethi +Muhammed Abdul Quadir Owais +Harshit Yadav +Sidharth Mundhra +Suryam Arnav Kalra +Prince Gupta +Kunal Singh +Mayank Raj +Achal Jain <2achaljain@gmail.com> +Mario Maio +Aaron Stiff <69512633+AaronStiff@users.noreply.github.com> +Wyatt Peak +Bhaskar Joshi +Aditya Jindal +Vaibhav Bhat +Priyansh Rathi +Saket Kumar Singh +Yukai Chou +Qijia Liu +Paul Mandel +Nisarg Chaudhari <54911392+Nisarg-Chaudhari@users.noreply.github.com> +Dominik Stańczak +Rodrigo Luger +Marco Antônio Habitzreuter +Ayush Bisht +Akshansh Bhatt +Brandon T. Willard +Thomas Aarholt +Hiren Chalodiya +Roland Dixon +dimasvq +Sagar231 +Michael Chu +Abby Ng +Angad Sandhu <55819847+angadsinghsandhu@users.noreply.github.com> +Alexander Cockburn +Yaser AlOsh +Davide Sandonà +Jonathan Gutow +Nihir Agarwal +Lee Johnston +Zach Carmichael <20629897+craymichael@users.noreply.github.com> +Vijairam Ganesh Moorthy +Hanspeter Schmid +Ben Oostendorp +Nikita +Aman +Shashank KS +Aman Sharma +Anup Parikh +Lucy Mountain +Miguel Torres Costa +Rikard Nordgren +Arun sanganal <74652697+ArunSanganal@users.noreply.github.com> +Kamlesh Joshi <72374645+kamleshjoshi8102@users.noreply.github.com> +Joseph Rance <56409230+Joseph-Rance@users.noreply.github.com> +Huangduirong +Nils Schulte <47043622+Schnilz@users.noreply.github.com> +Matt Bogosian +Elisha Hollander +Aditya Ravuri +Mamidi Ratna Praneeth +Jeffrey Ryan +Jonathan Daniel <36337649+jond01@users.noreply.github.com> +Robin Richard +Gautam Menghani +Remco de Boer <29308176+redeboer@users.noreply.github.com> +Sebastian East +Evani Balasubramanyam +Rahil Parikh +Jason Ross +Joannah Nanjekye +Ayush Kumar +Kshitij +Daniel Hyams +alijosephine +Matthias Köppe +mohajain +Anibal M. Medina-Mardones +Travis Ens +Evgenia Karunus +Risiraj Dey +lastcodestanding +Andrey Lekar +Abbas Mohammed <42001049+iam-abbas@users.noreply.github.com> +Anutosh Bhat +Steve Kieffer +Paul Spiering +Pieter Gijsbers +Wang Ran (汪然) +naelsondouglas +Aman Thakur +S. Hanko +Dennis Sweeney +Gurpartap Singh +Hampus Malmberg +scimax +Nikhil Date +Kuldeep Borkar Jr +AkuBrain <76952313+Franck2111@users.noreply.github.com> +Leo Battle +Advait Pote +Anurag Bhat +Jeremy Monat +Diane Tchuindjo +Tom Fryers <61272761+TomFryers@users.noreply.github.com> +Zouhair +zzj <29055749+zjzh@users.noreply.github.com> +shubhayu09 +Siddhant Jain +Tirthankar Mazumder <63574588+wermos@users.noreply.github.com> +Sumit Kumar +Shivam Sagar +Gaurav Jain +Andrii Oriekhov +Luis Talavera +Arie Bovenberg +Carson McManus +Jack Schmidt <1107865+jackschmidt@users.noreply.github.com> +Riley Britten +Georges Khaznadar +Donald Wilson +Timo Stienstra +dispasha +Saksham Alok +Varenyam Bhardwaj +oittaa <8972248+oittaa@users.noreply.github.com> +Omkaar <79257339+Pysics@users.noreply.github.com> +Islem BOUZENIA +extraymond +Alexander Behrens +user202729 <25191436+user202729@users.noreply.github.com> +Pieter Eendebak +Zaz Brown +ritikBhandari +viocha <66580331+viocha@users.noreply.github.com> +Arthur Ryman +Xiang Wu +tttc3 +Seth Poulsen +cocolato +Anton Golovanov +Gareth Ma +Clément M.T. Robert +Glenn Horton-Smith +Karan +Stefan Behnle <84378403+behnle@users.noreply.github.com> +Shreyash Mishra <72146041+Shreyash-cyber@users.noreply.github.com> +Arthur Milchior +NotWearingPants <26556598+NotWearingPants@users.noreply.github.com> +Ishan Pandhare +Carlos García Montoro +Parcly Taxel +Saicharan +Kunal Sheth +Biswadeep Purkayastha <98874428+metabiswadeep@users.noreply.github.com> +Jyn Spring 琴春 +Phil LeMaitre +Chris Kerr +José Senart +Uwe L. Korn +ForeverHaibara <69423537+ForeverHaibara@users.noreply.github.com> +Yves Tumushimire +wookie184 +Costor +Klaus Rettinghaus +Sam Brockie +Abhishek Patidar <1e9abhi1e10@gmail.com> +Eric Demer +Pontus von Brömssen +Victor Immanuel +Evandro Bernardes +Michele Ceccacci +Ayush Aryan +Kishore Gopalakrishnan +Jan-Philipp Hoffmann +Daiki Takahashi +Sayan Mitra +Aman Kumar Shukla +Zoufiné Lauer-Baré +Charles Harris +Tejaswini Sanapathi +Devansh +Aaron Gokaslan +Daan Koning (he/him) +Steven Burns +Jay Patankar +Vivek Soni +Le Cong Minh Hieu +Sam Ritchie +Maciej Skórski +Tilo Reneau-Cardoso +Laurence Warne +Lukas Molleman +Konstantinos Riganas +Grace Su +Pedro Rosa +Abhinav Cillanki +Baiyuan Qiu <1061688677@qq.com> +Liwei Cai +Daniel Weindl +Isidora Araya +Seb Tiburzio +Victory Omole +Abhishek Chaudhary +Alexander Zhura +Shuai Zhou +Martin Manns +John Möller +zzc <1378113190@qq.com> +Pablo Galindo Salgado +Johannes Kasimir +Theodore Dias +Kaustubh <90597818+kaustubh-765@users.noreply.github.com> +Idan Pazi +Bobby Palmer +Saikat Das +Suman mondal +Taylan Sahin +Fabio Luporini +Oriel Malihi +Geetika Vadali +Matthias Rettl +Mikhail Remnev +philwillnyc <56197213+philwillnyc@users.noreply.github.com> +Raphael Lehner +Harry Mountain +Bhavik Sachdev +袁野 (Yuan Ye) +fazledyn-or +mohammedouahman +K. Kraus +Zac Hatfield-Dodds +platypus +codecruisader +James Whitehead +atharvParlikar +Ivan Petukhov +Augusto Borges +Han Wei Ang +Congxu Yang +Saicharan <62512681+saicharan2804@users.noreply.github.com> +Arnab Nandi +Harrison Oates <48871176+HarrisonOates@users.noreply.github.com> +Corey Cerovsek +Harsh Kasat +omahs <73983677+omahs@users.noreply.github.com> +Pascal Gitz +Ravindu-Hirimuthugoda +Sophia Pustova +George Pittock +Warren Jacinto +Sachin Singh +Zedmat <104870914+harshkasat@users.noreply.github.com> +Soumendra Ganguly +Samith Karunathilake <55777141+samithkavishke@users.noreply.github.com> +Viraj Vekaria +Shishir Kushwaha +Ankit Kumar Singh +Abhishek Kumar +Mohak Malviya +Matthias Liesenfeld <116307294+maliesen@users.noreply.github.com> +dodo +Mohamed Rezk +Tommaso Vaccari <05-gesto-follemente@icloud.com> +Alexis Schotte +Lauren Yim <31467609+cherryblossom000@users.noreply.github.com> +Prey Patel +Riccardo Di Girolamo +Abhishek kumar +Sam Lubelsky +Henrique Soares +Vladimir Sereda +Hwayeon Kang +Raj Sapale +Gerald Teschl +Richard Samuel <98638849+samuelard7@users.noreply.github.com> +HeeJae Chang +Nick Harder +Ethan DeGuire +Lorenz Winkler +Richard Rodenbusch +Zhenxu Zhu +Mark van Gelder +Mark van Gelder +Ishan Pandhare <91841626+Ishanned@users.noreply.github.com> +James A. Preiss +Emile Fourcini +Alberto Jiménez Ruiz +João Bravo +Dean Price +Edward Z. Yang +James Titus +Zhuoyuan Li +Hugo Kerstens +Jan Jancar +Andrew Mosson +Marek Madejski +Gonzalo Tornaría +Peter Stahlecker +Jean-François B <2589111+jfbu@users.noreply.github.com> +Zexuan Zhou (Bruce) +George Frolov +Corbet Elkins +Håkon Kvernmoen +Muhammad Maaz +Shishir Kushwaha <138311586+shishir-11@users.noreply.github.com> +Matt Wang +bharatAmeria <21001019007@jcboseust.ac.in> +Amir Ebrahimi +Steven Esquea +Rishabh Kamboj <111004091+VectorNd@users.noreply.github.com> +Aasim Ali +Ivan A. Melnikov +Borek Saheli +Guido Roncarolo +Quek Zi Yao +Roelof Rietbroek +MostafaGalal1 +Au Huishan +Kris Katterjohn +Shiyao Guo +Rushabh Mehta +Temiloluwa Yusuf ytemiloluwa@gmail.com ytemiloluwa +Davi Laerte +Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> +Harshit Gupta +Praveen Perumal +Kevin McWhirter +Prayag V +Lucas Kletzander +Pratyksh Gupta +Leonardo Mangani +Karan Anand +Gagan Mishra +Krishnav Bajoria +Matt Ord +Jatin Bhardwaj +Prashant Tandon +Paramjit Singh +João Rodrigues +Alejandro García Prada <114813960+AlexGarciaPrada@users.noreply.github.com> +Matthew Treinish +Clayton Rabideau +Victoria Koval +Voaides Negustor Robert <134785947+voaidesr@users.noreply.github.com> +Ovsk Mendov +David Brooks +Nicholas Laustrup <124007393+nicklaustrup@users.noreply.github.com> +Harikrishna Srinivasan +Mathis Cros +Arnav Mummineni <45217840+RCoder01@users.noreply.github.com> +Thangaraju Sibiraj <85477603+t-sibiraj@users.noreply.github.com> +KJaybhaye diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy-1.14.0.dist-info/licenses/LICENSE b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy-1.14.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0744f229d697ca3ed1b1b257bfdb70e3eecf0b9e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy-1.14.0.dist-info/licenses/LICENSE @@ -0,0 +1,153 @@ +Copyright (c) 2006-2023 SymPy Development Team + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of SymPy nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. + +-------------------------------------------------------------------------------- + +Patches that were taken from the Diofant project (https://github.com/diofant/diofant) +are licensed as: + +Copyright (c) 2006-2018 SymPy Development Team, + 2013-2023 Sergey B Kirpichev + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of Diofant or SymPy nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. + +-------------------------------------------------------------------------------- + +Submodules taken from the multipledispatch project (https://github.com/mrocklin/multipledispatch) +are licensed as: + +Copyright (c) 2014 Matthew Rocklin + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of multipledispatch nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. + +-------------------------------------------------------------------------------- + +The files under the directory sympy/parsing/autolev/tests/pydy-example-repo +are directly copied from PyDy project and are licensed as: + +Copyright (c) 2009-2023, PyDy Authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +* Neither the name of this project nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL PYDY AUTHORS BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +The files under the directory sympy/parsing/latex +are directly copied from latex2sympy project and are licensed as: + +Copyright 2016, latex2sympy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d7d9b1b72796078ae54cbedf9ae763ba3fc35e0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/abc.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/abc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d46b93cb89f95cc3515486ce12dd118ec606f3e0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/abc.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/conftest.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/conftest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bed317f3380f622795976e5a819040f7b1c43af Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/conftest.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/galgebra.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/galgebra.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d741a4746bc0d0a52e6614ce286e988d4e97f22 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/galgebra.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/release.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/release.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69484bdae335a484a4f052484637540072fc1512 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/release.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/this.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/this.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a62e603cc19b4756cc9bb914c4dc553c022e40eb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/__pycache__/this.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af862653f3ce0eeb67f7764e16c32f3466e87024 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__init__.py @@ -0,0 +1,18 @@ +""" +A module to implement logical predicates and assumption system. +""" + +from .assume import ( + AppliedPredicate, Predicate, AssumptionsContext, assuming, + global_assumptions +) +from .ask import Q, ask, register_handler, remove_handler +from .refine import refine +from .relation import BinaryRelation, AppliedBinaryRelation + +__all__ = [ + 'AppliedPredicate', 'Predicate', 'AssumptionsContext', 'assuming', + 'global_assumptions', 'Q', 'ask', 'register_handler', 'remove_handler', + 'refine', + 'BinaryRelation', 'AppliedBinaryRelation' +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a4f055ce759bebac5c1563108cdc1ae0ea6e1ec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/ask.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/ask.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4440ec50a7b5992bc88ed20b1b4af6e80d0bff5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/ask.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/ask_generated.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/ask_generated.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e545ab467acb14d3d5ce727a75776ed3d5ebd5c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/ask_generated.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/assume.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/assume.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7e8d7a487fc2e6c8502d6aa66828903f3abc4d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/assume.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/cnf.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/cnf.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..005bcb20d6d6b645a85a02dbc1d391d8c1a25a50 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/cnf.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/facts.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/facts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b48d701e8dd077232085103e1436a3d1c6c34f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/facts.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/lra_satask.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/lra_satask.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11901dda7b3d969f4824e11547fa809e8ac700ff Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/lra_satask.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/refine.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/refine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca7333bb5ce18edc0e0bde26a55d87c5edd89eae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/refine.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/satask.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/satask.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b45565e69441f74baeca51bec3ca8b6572fb124 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/satask.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/sathandlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/sathandlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97797ebc8f9feba2e76da9bed35191a5848c52b2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/sathandlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/wrapper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d365494156dd1b0be4e23f623b81787cdce6e758 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/__pycache__/wrapper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/ask.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/ask.py new file mode 100644 index 0000000000000000000000000000000000000000..ec81ec8ecce245c2a798cf9e71af1e9373292bc1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/ask.py @@ -0,0 +1,651 @@ +"""Module for querying SymPy objects about assumptions.""" + +from sympy.assumptions.assume import (global_assumptions, Predicate, + AppliedPredicate) +from sympy.assumptions.cnf import CNF, EncodedCNF, Literal +from sympy.core import sympify +from sympy.core.kind import BooleanKind +from sympy.core.relational import Eq, Ne, Gt, Lt, Ge, Le +from sympy.logic.inference import satisfiable +from sympy.utilities.decorator import memoize_property +from sympy.utilities.exceptions import (sympy_deprecation_warning, + SymPyDeprecationWarning, + ignore_warnings) + + +# Memoization is necessary for the properties of AssumptionKeys to +# ensure that only one object of Predicate objects are created. +# This is because assumption handlers are registered on those objects. + + +class AssumptionKeys: + """ + This class contains all the supported keys by ``ask``. + It should be accessed via the instance ``sympy.Q``. + + """ + + # DO NOT add methods or properties other than predicate keys. + # SAT solver checks the properties of Q and use them to compute the + # fact system. Non-predicate attributes will break this. + + @memoize_property + def hermitian(self): + from .handlers.sets import HermitianPredicate + return HermitianPredicate() + + @memoize_property + def antihermitian(self): + from .handlers.sets import AntihermitianPredicate + return AntihermitianPredicate() + + @memoize_property + def real(self): + from .handlers.sets import RealPredicate + return RealPredicate() + + @memoize_property + def extended_real(self): + from .handlers.sets import ExtendedRealPredicate + return ExtendedRealPredicate() + + @memoize_property + def imaginary(self): + from .handlers.sets import ImaginaryPredicate + return ImaginaryPredicate() + + @memoize_property + def complex(self): + from .handlers.sets import ComplexPredicate + return ComplexPredicate() + + @memoize_property + def algebraic(self): + from .handlers.sets import AlgebraicPredicate + return AlgebraicPredicate() + + @memoize_property + def transcendental(self): + from .predicates.sets import TranscendentalPredicate + return TranscendentalPredicate() + + @memoize_property + def integer(self): + from .handlers.sets import IntegerPredicate + return IntegerPredicate() + + @memoize_property + def noninteger(self): + from .predicates.sets import NonIntegerPredicate + return NonIntegerPredicate() + + @memoize_property + def rational(self): + from .handlers.sets import RationalPredicate + return RationalPredicate() + + @memoize_property + def irrational(self): + from .handlers.sets import IrrationalPredicate + return IrrationalPredicate() + + @memoize_property + def finite(self): + from .handlers.calculus import FinitePredicate + return FinitePredicate() + + @memoize_property + def infinite(self): + from .handlers.calculus import InfinitePredicate + return InfinitePredicate() + + @memoize_property + def positive_infinite(self): + from .handlers.calculus import PositiveInfinitePredicate + return PositiveInfinitePredicate() + + @memoize_property + def negative_infinite(self): + from .handlers.calculus import NegativeInfinitePredicate + return NegativeInfinitePredicate() + + @memoize_property + def positive(self): + from .handlers.order import PositivePredicate + return PositivePredicate() + + @memoize_property + def negative(self): + from .handlers.order import NegativePredicate + return NegativePredicate() + + @memoize_property + def zero(self): + from .handlers.order import ZeroPredicate + return ZeroPredicate() + + @memoize_property + def extended_positive(self): + from .handlers.order import ExtendedPositivePredicate + return ExtendedPositivePredicate() + + @memoize_property + def extended_negative(self): + from .handlers.order import ExtendedNegativePredicate + return ExtendedNegativePredicate() + + @memoize_property + def nonzero(self): + from .handlers.order import NonZeroPredicate + return NonZeroPredicate() + + @memoize_property + def nonpositive(self): + from .handlers.order import NonPositivePredicate + return NonPositivePredicate() + + @memoize_property + def nonnegative(self): + from .handlers.order import NonNegativePredicate + return NonNegativePredicate() + + @memoize_property + def extended_nonzero(self): + from .handlers.order import ExtendedNonZeroPredicate + return ExtendedNonZeroPredicate() + + @memoize_property + def extended_nonpositive(self): + from .handlers.order import ExtendedNonPositivePredicate + return ExtendedNonPositivePredicate() + + @memoize_property + def extended_nonnegative(self): + from .handlers.order import ExtendedNonNegativePredicate + return ExtendedNonNegativePredicate() + + @memoize_property + def even(self): + from .handlers.ntheory import EvenPredicate + return EvenPredicate() + + @memoize_property + def odd(self): + from .handlers.ntheory import OddPredicate + return OddPredicate() + + @memoize_property + def prime(self): + from .handlers.ntheory import PrimePredicate + return PrimePredicate() + + @memoize_property + def composite(self): + from .handlers.ntheory import CompositePredicate + return CompositePredicate() + + @memoize_property + def commutative(self): + from .handlers.common import CommutativePredicate + return CommutativePredicate() + + @memoize_property + def is_true(self): + from .handlers.common import IsTruePredicate + return IsTruePredicate() + + @memoize_property + def symmetric(self): + from .handlers.matrices import SymmetricPredicate + return SymmetricPredicate() + + @memoize_property + def invertible(self): + from .handlers.matrices import InvertiblePredicate + return InvertiblePredicate() + + @memoize_property + def orthogonal(self): + from .handlers.matrices import OrthogonalPredicate + return OrthogonalPredicate() + + @memoize_property + def unitary(self): + from .handlers.matrices import UnitaryPredicate + return UnitaryPredicate() + + @memoize_property + def positive_definite(self): + from .handlers.matrices import PositiveDefinitePredicate + return PositiveDefinitePredicate() + + @memoize_property + def upper_triangular(self): + from .handlers.matrices import UpperTriangularPredicate + return UpperTriangularPredicate() + + @memoize_property + def lower_triangular(self): + from .handlers.matrices import LowerTriangularPredicate + return LowerTriangularPredicate() + + @memoize_property + def diagonal(self): + from .handlers.matrices import DiagonalPredicate + return DiagonalPredicate() + + @memoize_property + def fullrank(self): + from .handlers.matrices import FullRankPredicate + return FullRankPredicate() + + @memoize_property + def square(self): + from .handlers.matrices import SquarePredicate + return SquarePredicate() + + @memoize_property + def integer_elements(self): + from .handlers.matrices import IntegerElementsPredicate + return IntegerElementsPredicate() + + @memoize_property + def real_elements(self): + from .handlers.matrices import RealElementsPredicate + return RealElementsPredicate() + + @memoize_property + def complex_elements(self): + from .handlers.matrices import ComplexElementsPredicate + return ComplexElementsPredicate() + + @memoize_property + def singular(self): + from .predicates.matrices import SingularPredicate + return SingularPredicate() + + @memoize_property + def normal(self): + from .predicates.matrices import NormalPredicate + return NormalPredicate() + + @memoize_property + def triangular(self): + from .predicates.matrices import TriangularPredicate + return TriangularPredicate() + + @memoize_property + def unit_triangular(self): + from .predicates.matrices import UnitTriangularPredicate + return UnitTriangularPredicate() + + @memoize_property + def eq(self): + from .relation.equality import EqualityPredicate + return EqualityPredicate() + + @memoize_property + def ne(self): + from .relation.equality import UnequalityPredicate + return UnequalityPredicate() + + @memoize_property + def gt(self): + from .relation.equality import StrictGreaterThanPredicate + return StrictGreaterThanPredicate() + + @memoize_property + def ge(self): + from .relation.equality import GreaterThanPredicate + return GreaterThanPredicate() + + @memoize_property + def lt(self): + from .relation.equality import StrictLessThanPredicate + return StrictLessThanPredicate() + + @memoize_property + def le(self): + from .relation.equality import LessThanPredicate + return LessThanPredicate() + + +Q = AssumptionKeys() + +def _extract_all_facts(assump, exprs): + """ + Extract all relevant assumptions from *assump* with respect to given *exprs*. + + Parameters + ========== + + assump : sympy.assumptions.cnf.CNF + + exprs : tuple of expressions + + Returns + ======= + + sympy.assumptions.cnf.CNF + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.assumptions.ask import _extract_all_facts + >>> from sympy.abc import x, y + >>> assump = CNF.from_prop(Q.positive(x) & Q.integer(y)) + >>> exprs = (x,) + >>> cnf = _extract_all_facts(assump, exprs) + >>> cnf.clauses + {frozenset({Literal(Q.positive, False)})} + + """ + facts = set() + + for clause in assump.clauses: + args = [] + for literal in clause: + if isinstance(literal.lit, AppliedPredicate) and len(literal.lit.arguments) == 1: + if literal.lit.arg in exprs: + # Add literal if it has matching in it + args.append(Literal(literal.lit.function, literal.is_Not)) + else: + # If any of the literals doesn't have matching expr don't add the whole clause. + break + else: + # If any of the literals aren't unary predicate don't add the whole clause. + break + + else: + if args: + facts.add(frozenset(args)) + return CNF(facts) + + +def ask(proposition, assumptions=True, context=global_assumptions): + """ + Function to evaluate the proposition with assumptions. + + Explanation + =========== + + This function evaluates the proposition to ``True`` or ``False`` if + the truth value can be determined. If not, it returns ``None``. + + It should be discerned from :func:`~.refine` which, when applied to a + proposition, simplifies the argument to symbolic ``Boolean`` instead of + Python built-in ``True``, ``False`` or ``None``. + + **Syntax** + + * ask(proposition) + Evaluate the *proposition* in global assumption context. + + * ask(proposition, assumptions) + Evaluate the *proposition* with respect to *assumptions* in + global assumption context. + + Parameters + ========== + + proposition : Boolean + Proposition which will be evaluated to boolean value. If this is + not ``AppliedPredicate``, it will be wrapped by ``Q.is_true``. + + assumptions : Boolean, optional + Local assumptions to evaluate the *proposition*. + + context : AssumptionsContext, optional + Default assumptions to evaluate the *proposition*. By default, + this is ``sympy.assumptions.global_assumptions`` variable. + + Returns + ======= + + ``True``, ``False``, or ``None`` + + Raises + ====== + + TypeError : *proposition* or *assumptions* is not valid logical expression. + + ValueError : assumptions are inconsistent. + + Examples + ======== + + >>> from sympy import ask, Q, pi + >>> from sympy.abc import x, y + >>> ask(Q.rational(pi)) + False + >>> ask(Q.even(x*y), Q.even(x) & Q.integer(y)) + True + >>> ask(Q.prime(4*x), Q.integer(x)) + False + + If the truth value cannot be determined, ``None`` will be returned. + + >>> print(ask(Q.odd(3*x))) # cannot determine unless we know x + None + + ``ValueError`` is raised if assumptions are inconsistent. + + >>> ask(Q.integer(x), Q.even(x) & Q.odd(x)) + Traceback (most recent call last): + ... + ValueError: inconsistent assumptions Q.even(x) & Q.odd(x) + + Notes + ===== + + Relations in assumptions are not implemented (yet), so the following + will not give a meaningful result. + + >>> ask(Q.positive(x), x > 0) + + It is however a work in progress. + + See Also + ======== + + sympy.assumptions.refine.refine : Simplification using assumptions. + Proposition is not reduced to ``None`` if the truth value cannot + be determined. + """ + from sympy.assumptions.satask import satask + from sympy.assumptions.lra_satask import lra_satask + from sympy.logic.algorithms.lra_theory import UnhandledInput + + proposition = sympify(proposition) + assumptions = sympify(assumptions) + + if isinstance(proposition, Predicate) or proposition.kind is not BooleanKind: + raise TypeError("proposition must be a valid logical expression") + + if isinstance(assumptions, Predicate) or assumptions.kind is not BooleanKind: + raise TypeError("assumptions must be a valid logical expression") + + binrelpreds = {Eq: Q.eq, Ne: Q.ne, Gt: Q.gt, Lt: Q.lt, Ge: Q.ge, Le: Q.le} + if isinstance(proposition, AppliedPredicate): + key, args = proposition.function, proposition.arguments + elif proposition.func in binrelpreds: + key, args = binrelpreds[type(proposition)], proposition.args + else: + key, args = Q.is_true, (proposition,) + + # convert local and global assumptions to CNF + assump_cnf = CNF.from_prop(assumptions) + assump_cnf.extend(context) + + # extract the relevant facts from assumptions with respect to args + local_facts = _extract_all_facts(assump_cnf, args) + + # convert default facts and assumed facts to encoded CNF + known_facts_cnf = get_all_known_facts() + enc_cnf = EncodedCNF() + enc_cnf.from_cnf(CNF(known_facts_cnf)) + enc_cnf.add_from_cnf(local_facts) + + # check the satisfiability of given assumptions + if local_facts.clauses and satisfiable(enc_cnf) is False: + raise ValueError("inconsistent assumptions %s" % assumptions) + + # quick computation for single fact + res = _ask_single_fact(key, local_facts) + if res is not None: + return res + + # direct resolution method, no logic + res = key(*args)._eval_ask(assumptions) + if res is not None: + return bool(res) + + # using satask (still costly) + res = satask(proposition, assumptions=assumptions, context=context) + if res is not None: + return res + + try: + res = lra_satask(proposition, assumptions=assumptions, context=context) + except UnhandledInput: + return None + + return res + + +def _ask_single_fact(key, local_facts): + """ + Compute the truth value of single predicate using assumptions. + + Parameters + ========== + + key : sympy.assumptions.assume.Predicate + Proposition predicate. + + local_facts : sympy.assumptions.cnf.CNF + Local assumption in CNF form. + + Returns + ======= + + ``True``, ``False`` or ``None`` + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.assumptions.ask import _ask_single_fact + + If prerequisite of proposition is rejected by the assumption, + return ``False``. + + >>> key, assump = Q.zero, ~Q.zero + >>> local_facts = CNF.from_prop(assump) + >>> _ask_single_fact(key, local_facts) + False + >>> key, assump = Q.zero, ~Q.even + >>> local_facts = CNF.from_prop(assump) + >>> _ask_single_fact(key, local_facts) + False + + If assumption implies the proposition, return ``True``. + + >>> key, assump = Q.even, Q.zero + >>> local_facts = CNF.from_prop(assump) + >>> _ask_single_fact(key, local_facts) + True + + If proposition rejects the assumption, return ``False``. + + >>> key, assump = Q.even, Q.odd + >>> local_facts = CNF.from_prop(assump) + >>> _ask_single_fact(key, local_facts) + False + """ + if local_facts.clauses: + + known_facts_dict = get_known_facts_dict() + + if len(local_facts.clauses) == 1: + cl, = local_facts.clauses + if len(cl) == 1: + f, = cl + prop_facts = known_facts_dict.get(key, None) + prop_req = prop_facts[0] if prop_facts is not None else set() + if f.is_Not and f.arg in prop_req: + # the prerequisite of proposition is rejected + return False + + for clause in local_facts.clauses: + if len(clause) == 1: + f, = clause + prop_facts = known_facts_dict.get(f.arg, None) if not f.is_Not else None + if prop_facts is None: + continue + + prop_req, prop_rej = prop_facts + if key in prop_req: + # assumption implies the proposition + return True + elif key in prop_rej: + # proposition rejects the assumption + return False + + return None + + +def register_handler(key, handler): + """ + Register a handler in the ask system. key must be a string and handler a + class inheriting from AskHandler. + + .. deprecated:: 1.8. + Use multipledispatch handler instead. See :obj:`~.Predicate`. + + """ + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. The register_handler() function + should be replaced with the multipledispatch handler of Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + if isinstance(key, Predicate): + key = key.name.name + Qkey = getattr(Q, key, None) + if Qkey is not None: + Qkey.add_handler(handler) + else: + setattr(Q, key, Predicate(key, handlers=[handler])) + + +def remove_handler(key, handler): + """ + Removes a handler from the ask system. + + .. deprecated:: 1.8. + Use multipledispatch handler instead. See :obj:`~.Predicate`. + + """ + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. The remove_handler() function + should be replaced with the multipledispatch handler of Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + if isinstance(key, Predicate): + key = key.name.name + # Don't show the same warning again recursively + with ignore_warnings(SymPyDeprecationWarning): + getattr(Q, key).remove_handler(handler) + + +from sympy.assumptions.ask_generated import (get_all_known_facts, + get_known_facts_dict) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/ask_generated.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/ask_generated.py new file mode 100644 index 0000000000000000000000000000000000000000..d90cdffc1e127d78e18f70cda13d8d5e0530d41b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/ask_generated.py @@ -0,0 +1,352 @@ +""" +Do NOT manually edit this file. +Instead, run ./bin/ask_update.py. +""" + +from sympy.assumptions.ask import Q +from sympy.assumptions.cnf import Literal +from sympy.core.cache import cacheit + +@cacheit +def get_all_known_facts(): + """ + Known facts between unary predicates as CNF clauses. + """ + return { + frozenset((Literal(Q.algebraic, False), Literal(Q.imaginary, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.negative, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.positive, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.rational, True))), + frozenset((Literal(Q.algebraic, False), Literal(Q.transcendental, False), Literal(Q.zero, True))), + frozenset((Literal(Q.algebraic, True), Literal(Q.finite, False))), + frozenset((Literal(Q.algebraic, True), Literal(Q.transcendental, True))), + frozenset((Literal(Q.antihermitian, False), Literal(Q.hermitian, False), Literal(Q.zero, True))), + frozenset((Literal(Q.antihermitian, False), Literal(Q.imaginary, True))), + frozenset((Literal(Q.commutative, False), Literal(Q.finite, True))), + frozenset((Literal(Q.commutative, False), Literal(Q.infinite, True))), + frozenset((Literal(Q.complex_elements, False), Literal(Q.real_elements, True))), + frozenset((Literal(Q.composite, False), Literal(Q.even, True), Literal(Q.positive, True), Literal(Q.prime, False))), + frozenset((Literal(Q.composite, True), Literal(Q.even, False), Literal(Q.odd, False))), + frozenset((Literal(Q.composite, True), Literal(Q.positive, False))), + frozenset((Literal(Q.composite, True), Literal(Q.prime, True))), + frozenset((Literal(Q.diagonal, False), Literal(Q.lower_triangular, True), Literal(Q.upper_triangular, True))), + frozenset((Literal(Q.diagonal, True), Literal(Q.lower_triangular, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.normal, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.symmetric, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.upper_triangular, False))), + frozenset((Literal(Q.even, False), Literal(Q.odd, False), Literal(Q.prime, True))), + frozenset((Literal(Q.even, False), Literal(Q.zero, True))), + frozenset((Literal(Q.even, True), Literal(Q.odd, True))), + frozenset((Literal(Q.even, True), Literal(Q.rational, False))), + frozenset((Literal(Q.finite, False), Literal(Q.transcendental, True))), + frozenset((Literal(Q.finite, True), Literal(Q.infinite, True))), + frozenset((Literal(Q.fullrank, False), Literal(Q.invertible, True))), + frozenset((Literal(Q.fullrank, True), Literal(Q.invertible, False), Literal(Q.square, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.negative, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.positive, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.zero, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.negative, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.positive, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.zero, True))), + frozenset((Literal(Q.infinite, False), Literal(Q.negative_infinite, True))), + frozenset((Literal(Q.infinite, False), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.integer_elements, True), Literal(Q.real_elements, False))), + frozenset((Literal(Q.invertible, False), Literal(Q.positive_definite, True))), + frozenset((Literal(Q.invertible, False), Literal(Q.singular, False))), + frozenset((Literal(Q.invertible, False), Literal(Q.unitary, True))), + frozenset((Literal(Q.invertible, True), Literal(Q.singular, True))), + frozenset((Literal(Q.invertible, True), Literal(Q.square, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.negative, True), Literal(Q.rational, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.positive, True), Literal(Q.rational, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.rational, False), Literal(Q.zero, True))), + frozenset((Literal(Q.irrational, True), Literal(Q.negative, False), Literal(Q.positive, False), Literal(Q.zero, False))), + frozenset((Literal(Q.irrational, True), Literal(Q.rational, True))), + frozenset((Literal(Q.lower_triangular, False), Literal(Q.triangular, True), Literal(Q.upper_triangular, False))), + frozenset((Literal(Q.lower_triangular, True), Literal(Q.triangular, False))), + frozenset((Literal(Q.negative, False), Literal(Q.positive, False), Literal(Q.rational, True), Literal(Q.zero, False))), + frozenset((Literal(Q.negative, True), Literal(Q.negative_infinite, True))), + frozenset((Literal(Q.negative, True), Literal(Q.positive, True))), + frozenset((Literal(Q.negative, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.negative, True), Literal(Q.zero, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.positive, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.zero, True))), + frozenset((Literal(Q.normal, False), Literal(Q.unitary, True))), + frozenset((Literal(Q.normal, True), Literal(Q.square, False))), + frozenset((Literal(Q.odd, True), Literal(Q.rational, False))), + frozenset((Literal(Q.orthogonal, False), Literal(Q.real_elements, True), Literal(Q.unitary, True))), + frozenset((Literal(Q.orthogonal, True), Literal(Q.positive_definite, False))), + frozenset((Literal(Q.orthogonal, True), Literal(Q.unitary, False))), + frozenset((Literal(Q.positive, False), Literal(Q.prime, True))), + frozenset((Literal(Q.positive, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.positive, True), Literal(Q.zero, True))), + frozenset((Literal(Q.positive_infinite, True), Literal(Q.zero, True))), + frozenset((Literal(Q.square, False), Literal(Q.symmetric, True))), + frozenset((Literal(Q.triangular, False), Literal(Q.unit_triangular, True))), + frozenset((Literal(Q.triangular, False), Literal(Q.upper_triangular, True))) + } + +@cacheit +def get_all_known_matrix_facts(): + """ + Known facts between unary predicates for matrices as CNF clauses. + """ + return { + frozenset((Literal(Q.complex_elements, False), Literal(Q.real_elements, True))), + frozenset((Literal(Q.diagonal, False), Literal(Q.lower_triangular, True), Literal(Q.upper_triangular, True))), + frozenset((Literal(Q.diagonal, True), Literal(Q.lower_triangular, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.normal, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.symmetric, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.upper_triangular, False))), + frozenset((Literal(Q.fullrank, False), Literal(Q.invertible, True))), + frozenset((Literal(Q.fullrank, True), Literal(Q.invertible, False), Literal(Q.square, True))), + frozenset((Literal(Q.integer_elements, True), Literal(Q.real_elements, False))), + frozenset((Literal(Q.invertible, False), Literal(Q.positive_definite, True))), + frozenset((Literal(Q.invertible, False), Literal(Q.singular, False))), + frozenset((Literal(Q.invertible, False), Literal(Q.unitary, True))), + frozenset((Literal(Q.invertible, True), Literal(Q.singular, True))), + frozenset((Literal(Q.invertible, True), Literal(Q.square, False))), + frozenset((Literal(Q.lower_triangular, False), Literal(Q.triangular, True), Literal(Q.upper_triangular, False))), + frozenset((Literal(Q.lower_triangular, True), Literal(Q.triangular, False))), + frozenset((Literal(Q.normal, False), Literal(Q.unitary, True))), + frozenset((Literal(Q.normal, True), Literal(Q.square, False))), + frozenset((Literal(Q.orthogonal, False), Literal(Q.real_elements, True), Literal(Q.unitary, True))), + frozenset((Literal(Q.orthogonal, True), Literal(Q.positive_definite, False))), + frozenset((Literal(Q.orthogonal, True), Literal(Q.unitary, False))), + frozenset((Literal(Q.square, False), Literal(Q.symmetric, True))), + frozenset((Literal(Q.triangular, False), Literal(Q.unit_triangular, True))), + frozenset((Literal(Q.triangular, False), Literal(Q.upper_triangular, True))) + } + +@cacheit +def get_all_known_number_facts(): + """ + Known facts between unary predicates for numbers as CNF clauses. + """ + return { + frozenset((Literal(Q.algebraic, False), Literal(Q.imaginary, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.negative, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.positive, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.rational, True))), + frozenset((Literal(Q.algebraic, False), Literal(Q.transcendental, False), Literal(Q.zero, True))), + frozenset((Literal(Q.algebraic, True), Literal(Q.finite, False))), + frozenset((Literal(Q.algebraic, True), Literal(Q.transcendental, True))), + frozenset((Literal(Q.antihermitian, False), Literal(Q.hermitian, False), Literal(Q.zero, True))), + frozenset((Literal(Q.antihermitian, False), Literal(Q.imaginary, True))), + frozenset((Literal(Q.commutative, False), Literal(Q.finite, True))), + frozenset((Literal(Q.commutative, False), Literal(Q.infinite, True))), + frozenset((Literal(Q.composite, False), Literal(Q.even, True), Literal(Q.positive, True), Literal(Q.prime, False))), + frozenset((Literal(Q.composite, True), Literal(Q.even, False), Literal(Q.odd, False))), + frozenset((Literal(Q.composite, True), Literal(Q.positive, False))), + frozenset((Literal(Q.composite, True), Literal(Q.prime, True))), + frozenset((Literal(Q.even, False), Literal(Q.odd, False), Literal(Q.prime, True))), + frozenset((Literal(Q.even, False), Literal(Q.zero, True))), + frozenset((Literal(Q.even, True), Literal(Q.odd, True))), + frozenset((Literal(Q.even, True), Literal(Q.rational, False))), + frozenset((Literal(Q.finite, False), Literal(Q.transcendental, True))), + frozenset((Literal(Q.finite, True), Literal(Q.infinite, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.negative, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.positive, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.zero, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.negative, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.positive, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.zero, True))), + frozenset((Literal(Q.infinite, False), Literal(Q.negative_infinite, True))), + frozenset((Literal(Q.infinite, False), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.irrational, False), Literal(Q.negative, True), Literal(Q.rational, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.positive, True), Literal(Q.rational, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.rational, False), Literal(Q.zero, True))), + frozenset((Literal(Q.irrational, True), Literal(Q.negative, False), Literal(Q.positive, False), Literal(Q.zero, False))), + frozenset((Literal(Q.irrational, True), Literal(Q.rational, True))), + frozenset((Literal(Q.negative, False), Literal(Q.positive, False), Literal(Q.rational, True), Literal(Q.zero, False))), + frozenset((Literal(Q.negative, True), Literal(Q.negative_infinite, True))), + frozenset((Literal(Q.negative, True), Literal(Q.positive, True))), + frozenset((Literal(Q.negative, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.negative, True), Literal(Q.zero, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.positive, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.zero, True))), + frozenset((Literal(Q.odd, True), Literal(Q.rational, False))), + frozenset((Literal(Q.positive, False), Literal(Q.prime, True))), + frozenset((Literal(Q.positive, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.positive, True), Literal(Q.zero, True))), + frozenset((Literal(Q.positive_infinite, True), Literal(Q.zero, True))) + } + +@cacheit +def get_known_facts_dict(): + """ + Logical relations between unary predicates as dictionary. + + Each key is a predicate, and item is two groups of predicates. + First group contains the predicates which are implied by the key, and + second group contains the predicates which are rejected by the key. + + """ + return { + Q.algebraic: (set([Q.algebraic, Q.commutative, Q.complex, Q.finite]), + set([Q.infinite, Q.negative_infinite, Q.positive_infinite, + Q.transcendental])), + Q.antihermitian: (set([Q.antihermitian]), set([])), + Q.commutative: (set([Q.commutative]), set([])), + Q.complex: (set([Q.commutative, Q.complex, Q.finite]), + set([Q.infinite, Q.negative_infinite, Q.positive_infinite])), + Q.complex_elements: (set([Q.complex_elements]), set([])), + Q.composite: (set([Q.algebraic, Q.commutative, Q.complex, Q.composite, + Q.extended_nonnegative, Q.extended_nonzero, + Q.extended_positive, Q.extended_real, Q.finite, Q.hermitian, + Q.integer, Q.nonnegative, Q.nonzero, Q.positive, Q.rational, + Q.real]), set([Q.extended_negative, Q.extended_nonpositive, + Q.imaginary, Q.infinite, Q.irrational, Q.negative, + Q.negative_infinite, Q.nonpositive, Q.positive_infinite, + Q.prime, Q.transcendental, Q.zero])), + Q.diagonal: (set([Q.diagonal, Q.lower_triangular, Q.normal, Q.square, + Q.symmetric, Q.triangular, Q.upper_triangular]), set([])), + Q.even: (set([Q.algebraic, Q.commutative, Q.complex, Q.even, + Q.extended_real, Q.finite, Q.hermitian, Q.integer, Q.rational, + Q.real]), set([Q.imaginary, Q.infinite, Q.irrational, + Q.negative_infinite, Q.odd, Q.positive_infinite, + Q.transcendental])), + Q.extended_negative: (set([Q.commutative, Q.extended_negative, + Q.extended_nonpositive, Q.extended_nonzero, Q.extended_real]), + set([Q.composite, Q.extended_nonnegative, Q.extended_positive, + Q.imaginary, Q.nonnegative, Q.positive, Q.positive_infinite, + Q.prime, Q.zero])), + Q.extended_nonnegative: (set([Q.commutative, Q.extended_nonnegative, + Q.extended_real]), set([Q.extended_negative, Q.imaginary, + Q.negative, Q.negative_infinite])), + Q.extended_nonpositive: (set([Q.commutative, Q.extended_nonpositive, + Q.extended_real]), set([Q.composite, Q.extended_positive, + Q.imaginary, Q.positive, Q.positive_infinite, Q.prime])), + Q.extended_nonzero: (set([Q.commutative, Q.extended_nonzero, + Q.extended_real]), set([Q.imaginary, Q.zero])), + Q.extended_positive: (set([Q.commutative, Q.extended_nonnegative, + Q.extended_nonzero, Q.extended_positive, Q.extended_real]), + set([Q.extended_negative, Q.extended_nonpositive, Q.imaginary, + Q.negative, Q.negative_infinite, Q.nonpositive, Q.zero])), + Q.extended_real: (set([Q.commutative, Q.extended_real]), + set([Q.imaginary])), + Q.finite: (set([Q.commutative, Q.finite]), set([Q.infinite, + Q.negative_infinite, Q.positive_infinite])), + Q.fullrank: (set([Q.fullrank]), set([])), + Q.hermitian: (set([Q.hermitian]), set([])), + Q.imaginary: (set([Q.antihermitian, Q.commutative, Q.complex, + Q.finite, Q.imaginary]), set([Q.composite, Q.even, + Q.extended_negative, Q.extended_nonnegative, + Q.extended_nonpositive, Q.extended_nonzero, + Q.extended_positive, Q.extended_real, Q.infinite, Q.integer, + Q.irrational, Q.negative, Q.negative_infinite, Q.nonnegative, + Q.nonpositive, Q.nonzero, Q.odd, Q.positive, + Q.positive_infinite, Q.prime, Q.rational, Q.real, Q.zero])), + Q.infinite: (set([Q.commutative, Q.infinite]), set([Q.algebraic, + Q.complex, Q.composite, Q.even, Q.finite, Q.imaginary, + Q.integer, Q.irrational, Q.negative, Q.nonnegative, + Q.nonpositive, Q.nonzero, Q.odd, Q.positive, Q.prime, + Q.rational, Q.real, Q.transcendental, Q.zero])), + Q.integer: (set([Q.algebraic, Q.commutative, Q.complex, + Q.extended_real, Q.finite, Q.hermitian, Q.integer, Q.rational, + Q.real]), set([Q.imaginary, Q.infinite, Q.irrational, + Q.negative_infinite, Q.positive_infinite, Q.transcendental])), + Q.integer_elements: (set([Q.complex_elements, Q.integer_elements, + Q.real_elements]), set([])), + Q.invertible: (set([Q.fullrank, Q.invertible, Q.square]), + set([Q.singular])), + Q.irrational: (set([Q.commutative, Q.complex, Q.extended_nonzero, + Q.extended_real, Q.finite, Q.hermitian, Q.irrational, + Q.nonzero, Q.real]), set([Q.composite, Q.even, Q.imaginary, + Q.infinite, Q.integer, Q.negative_infinite, Q.odd, + Q.positive_infinite, Q.prime, Q.rational, Q.zero])), + Q.is_true: (set([Q.is_true]), set([])), + Q.lower_triangular: (set([Q.lower_triangular, Q.triangular]), set([])), + Q.negative: (set([Q.commutative, Q.complex, Q.extended_negative, + Q.extended_nonpositive, Q.extended_nonzero, Q.extended_real, + Q.finite, Q.hermitian, Q.negative, Q.nonpositive, Q.nonzero, + Q.real]), set([Q.composite, Q.extended_nonnegative, + Q.extended_positive, Q.imaginary, Q.infinite, + Q.negative_infinite, Q.nonnegative, Q.positive, + Q.positive_infinite, Q.prime, Q.zero])), + Q.negative_infinite: (set([Q.commutative, Q.extended_negative, + Q.extended_nonpositive, Q.extended_nonzero, Q.extended_real, + Q.infinite, Q.negative_infinite]), set([Q.algebraic, + Q.complex, Q.composite, Q.even, Q.extended_nonnegative, + Q.extended_positive, Q.finite, Q.imaginary, Q.integer, + Q.irrational, Q.negative, Q.nonnegative, Q.nonpositive, + Q.nonzero, Q.odd, Q.positive, Q.positive_infinite, Q.prime, + Q.rational, Q.real, Q.transcendental, Q.zero])), + Q.noninteger: (set([Q.noninteger]), set([])), + Q.nonnegative: (set([Q.commutative, Q.complex, Q.extended_nonnegative, + Q.extended_real, Q.finite, Q.hermitian, Q.nonnegative, + Q.real]), set([Q.extended_negative, Q.imaginary, Q.infinite, + Q.negative, Q.negative_infinite, Q.positive_infinite])), + Q.nonpositive: (set([Q.commutative, Q.complex, Q.extended_nonpositive, + Q.extended_real, Q.finite, Q.hermitian, Q.nonpositive, + Q.real]), set([Q.composite, Q.extended_positive, Q.imaginary, + Q.infinite, Q.negative_infinite, Q.positive, + Q.positive_infinite, Q.prime])), + Q.nonzero: (set([Q.commutative, Q.complex, Q.extended_nonzero, + Q.extended_real, Q.finite, Q.hermitian, Q.nonzero, Q.real]), + set([Q.imaginary, Q.infinite, Q.negative_infinite, + Q.positive_infinite, Q.zero])), + Q.normal: (set([Q.normal, Q.square]), set([])), + Q.odd: (set([Q.algebraic, Q.commutative, Q.complex, + Q.extended_nonzero, Q.extended_real, Q.finite, Q.hermitian, + Q.integer, Q.nonzero, Q.odd, Q.rational, Q.real]), + set([Q.even, Q.imaginary, Q.infinite, Q.irrational, + Q.negative_infinite, Q.positive_infinite, Q.transcendental, + Q.zero])), + Q.orthogonal: (set([Q.fullrank, Q.invertible, Q.normal, Q.orthogonal, + Q.positive_definite, Q.square, Q.unitary]), set([Q.singular])), + Q.positive: (set([Q.commutative, Q.complex, Q.extended_nonnegative, + Q.extended_nonzero, Q.extended_positive, Q.extended_real, + Q.finite, Q.hermitian, Q.nonnegative, Q.nonzero, Q.positive, + Q.real]), set([Q.extended_negative, Q.extended_nonpositive, + Q.imaginary, Q.infinite, Q.negative, Q.negative_infinite, + Q.nonpositive, Q.positive_infinite, Q.zero])), + Q.positive_definite: (set([Q.fullrank, Q.invertible, + Q.positive_definite, Q.square]), set([Q.singular])), + Q.positive_infinite: (set([Q.commutative, Q.extended_nonnegative, + Q.extended_nonzero, Q.extended_positive, Q.extended_real, + Q.infinite, Q.positive_infinite]), set([Q.algebraic, + Q.complex, Q.composite, Q.even, Q.extended_negative, + Q.extended_nonpositive, Q.finite, Q.imaginary, Q.integer, + Q.irrational, Q.negative, Q.negative_infinite, Q.nonnegative, + Q.nonpositive, Q.nonzero, Q.odd, Q.positive, Q.prime, + Q.rational, Q.real, Q.transcendental, Q.zero])), + Q.prime: (set([Q.algebraic, Q.commutative, Q.complex, + Q.extended_nonnegative, Q.extended_nonzero, + Q.extended_positive, Q.extended_real, Q.finite, Q.hermitian, + Q.integer, Q.nonnegative, Q.nonzero, Q.positive, Q.prime, + Q.rational, Q.real]), set([Q.composite, Q.extended_negative, + Q.extended_nonpositive, Q.imaginary, Q.infinite, Q.irrational, + Q.negative, Q.negative_infinite, Q.nonpositive, + Q.positive_infinite, Q.transcendental, Q.zero])), + Q.rational: (set([Q.algebraic, Q.commutative, Q.complex, + Q.extended_real, Q.finite, Q.hermitian, Q.rational, Q.real]), + set([Q.imaginary, Q.infinite, Q.irrational, + Q.negative_infinite, Q.positive_infinite, Q.transcendental])), + Q.real: (set([Q.commutative, Q.complex, Q.extended_real, Q.finite, + Q.hermitian, Q.real]), set([Q.imaginary, Q.infinite, + Q.negative_infinite, Q.positive_infinite])), + Q.real_elements: (set([Q.complex_elements, Q.real_elements]), set([])), + Q.singular: (set([Q.singular]), set([Q.invertible, Q.orthogonal, + Q.positive_definite, Q.unitary])), + Q.square: (set([Q.square]), set([])), + Q.symmetric: (set([Q.square, Q.symmetric]), set([])), + Q.transcendental: (set([Q.commutative, Q.complex, Q.finite, + Q.transcendental]), set([Q.algebraic, Q.composite, Q.even, + Q.infinite, Q.integer, Q.negative_infinite, Q.odd, + Q.positive_infinite, Q.prime, Q.rational, Q.zero])), + Q.triangular: (set([Q.triangular]), set([])), + Q.unit_triangular: (set([Q.triangular, Q.unit_triangular]), set([])), + Q.unitary: (set([Q.fullrank, Q.invertible, Q.normal, Q.square, + Q.unitary]), set([Q.singular])), + Q.upper_triangular: (set([Q.triangular, Q.upper_triangular]), set([])), + Q.zero: (set([Q.algebraic, Q.commutative, Q.complex, Q.even, + Q.extended_nonnegative, Q.extended_nonpositive, + Q.extended_real, Q.finite, Q.hermitian, Q.integer, + Q.nonnegative, Q.nonpositive, Q.rational, Q.real, Q.zero]), + set([Q.composite, Q.extended_negative, Q.extended_nonzero, + Q.extended_positive, Q.imaginary, Q.infinite, Q.irrational, + Q.negative, Q.negative_infinite, Q.nonzero, Q.odd, Q.positive, + Q.positive_infinite, Q.prime, Q.transcendental])), + } diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/assume.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/assume.py new file mode 100644 index 0000000000000000000000000000000000000000..743195a865a1d39389d471b95728ca79834ed019 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/assume.py @@ -0,0 +1,485 @@ +"""A module which implements predicates and assumption context.""" + +from contextlib import contextmanager +import inspect +from sympy.core.symbol import Str +from sympy.core.sympify import _sympify +from sympy.logic.boolalg import Boolean, false, true +from sympy.multipledispatch.dispatcher import Dispatcher, str_signature +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence +from sympy.utilities.source import get_class + + +class AssumptionsContext(set): + """ + Set containing default assumptions which are applied to the ``ask()`` + function. + + Explanation + =========== + + This is used to represent global assumptions, but you can also use this + class to create your own local assumptions contexts. It is basically a thin + wrapper to Python's set, so see its documentation for advanced usage. + + Examples + ======== + + The default assumption context is ``global_assumptions``, which is initially empty: + + >>> from sympy import ask, Q + >>> from sympy.assumptions import global_assumptions + >>> global_assumptions + AssumptionsContext() + + You can add default assumptions: + + >>> from sympy.abc import x + >>> global_assumptions.add(Q.real(x)) + >>> global_assumptions + AssumptionsContext({Q.real(x)}) + >>> ask(Q.real(x)) + True + + And remove them: + + >>> global_assumptions.remove(Q.real(x)) + >>> print(ask(Q.real(x))) + None + + The ``clear()`` method removes every assumption: + + >>> global_assumptions.add(Q.positive(x)) + >>> global_assumptions + AssumptionsContext({Q.positive(x)}) + >>> global_assumptions.clear() + >>> global_assumptions + AssumptionsContext() + + See Also + ======== + + assuming + + """ + + def add(self, *assumptions): + """Add assumptions.""" + for a in assumptions: + super().add(a) + + def _sympystr(self, printer): + if not self: + return "%s()" % self.__class__.__name__ + return "{}({})".format(self.__class__.__name__, printer._print_set(self)) + +global_assumptions = AssumptionsContext() + + +class AppliedPredicate(Boolean): + """ + The class of expressions resulting from applying ``Predicate`` to + the arguments. ``AppliedPredicate`` merely wraps its argument and + remain unevaluated. To evaluate it, use the ``ask()`` function. + + Examples + ======== + + >>> from sympy import Q, ask + >>> Q.integer(1) + Q.integer(1) + + The ``function`` attribute returns the predicate, and the ``arguments`` + attribute returns the tuple of arguments. + + >>> type(Q.integer(1)) + + >>> Q.integer(1).function + Q.integer + >>> Q.integer(1).arguments + (1,) + + Applied predicates can be evaluated to a boolean value with ``ask``: + + >>> ask(Q.integer(1)) + True + + """ + __slots__ = () + + def __new__(cls, predicate, *args): + if not isinstance(predicate, Predicate): + raise TypeError("%s is not a Predicate." % predicate) + args = map(_sympify, args) + return super().__new__(cls, predicate, *args) + + @property + def arg(self): + """ + Return the expression used by this assumption. + + Examples + ======== + + >>> from sympy import Q, Symbol + >>> x = Symbol('x') + >>> a = Q.integer(x + 1) + >>> a.arg + x + 1 + + """ + # Will be deprecated + args = self._args + if len(args) == 2: + # backwards compatibility + return args[1] + raise TypeError("'arg' property is allowed only for unary predicates.") + + @property + def function(self): + """ + Return the predicate. + """ + # Will be changed to self.args[0] after args overriding is removed + return self._args[0] + + @property + def arguments(self): + """ + Return the arguments which are applied to the predicate. + """ + # Will be changed to self.args[1:] after args overriding is removed + return self._args[1:] + + def _eval_ask(self, assumptions): + return self.function.eval(self.arguments, assumptions) + + @property + def binary_symbols(self): + from .ask import Q + if self.function == Q.is_true: + i = self.arguments[0] + if i.is_Boolean or i.is_Symbol: + return i.binary_symbols + if self.function in (Q.eq, Q.ne): + if true in self.arguments or false in self.arguments: + if self.arguments[0].is_Symbol: + return {self.arguments[0]} + elif self.arguments[1].is_Symbol: + return {self.arguments[1]} + return set() + + +class PredicateMeta(type): + def __new__(cls, clsname, bases, dct): + # If handler is not defined, assign empty dispatcher. + if "handler" not in dct: + name = f"Ask{clsname.capitalize()}Handler" + handler = Dispatcher(name, doc="Handler for key %s" % name) + dct["handler"] = handler + + dct["_orig_doc"] = dct.get("__doc__", "") + + return super().__new__(cls, clsname, bases, dct) + + @property + def __doc__(cls): + handler = cls.handler + doc = cls._orig_doc + if cls is not Predicate and handler is not None: + doc += "Handler\n" + doc += " =======\n\n" + + # Append the handler's doc without breaking sphinx documentation. + docs = [" Multiply dispatched method: %s" % handler.name] + if handler.doc: + for line in handler.doc.splitlines(): + if not line: + continue + docs.append(" %s" % line) + other = [] + for sig in handler.ordering[::-1]: + func = handler.funcs[sig] + if func.__doc__: + s = ' Inputs: <%s>' % str_signature(sig) + lines = [] + for line in func.__doc__.splitlines(): + lines.append(" %s" % line) + s += "\n".join(lines) + docs.append(s) + else: + other.append(str_signature(sig)) + if other: + othersig = " Other signatures:" + for line in other: + othersig += "\n * %s" % line + docs.append(othersig) + + doc += '\n\n'.join(docs) + + return doc + + +class Predicate(Boolean, metaclass=PredicateMeta): + """ + Base class for mathematical predicates. It also serves as a + constructor for undefined predicate objects. + + Explanation + =========== + + Predicate is a function that returns a boolean value [1]. + + Predicate function is object, and it is instance of predicate class. + When a predicate is applied to arguments, ``AppliedPredicate`` + instance is returned. This merely wraps the argument and remain + unevaluated. To obtain the truth value of applied predicate, use the + function ``ask``. + + Evaluation of predicate is done by multiple dispatching. You can + register new handler to the predicate to support new types. + + Every predicate in SymPy can be accessed via the property of ``Q``. + For example, ``Q.even`` returns the predicate which checks if the + argument is even number. + + To define a predicate which can be evaluated, you must subclass this + class, make an instance of it, and register it to ``Q``. After then, + dispatch the handler by argument types. + + If you directly construct predicate using this class, you will get + ``UndefinedPredicate`` which cannot be dispatched. This is useful + when you are building boolean expressions which do not need to be + evaluated. + + Examples + ======== + + Applying and evaluating to boolean value: + + >>> from sympy import Q, ask + >>> ask(Q.prime(7)) + True + + You can define a new predicate by subclassing and dispatching. Here, + we define a predicate for sexy primes [2] as an example. + + >>> from sympy import Predicate, Integer + >>> class SexyPrimePredicate(Predicate): + ... name = "sexyprime" + >>> Q.sexyprime = SexyPrimePredicate() + >>> @Q.sexyprime.register(Integer, Integer) + ... def _(int1, int2, assumptions): + ... args = sorted([int1, int2]) + ... if not all(ask(Q.prime(a), assumptions) for a in args): + ... return False + ... return args[1] - args[0] == 6 + >>> ask(Q.sexyprime(5, 11)) + True + + Direct constructing returns ``UndefinedPredicate``, which can be + applied but cannot be dispatched. + + >>> from sympy import Predicate, Integer + >>> Q.P = Predicate("P") + >>> type(Q.P) + + >>> Q.P(1) + Q.P(1) + >>> Q.P.register(Integer)(lambda expr, assump: True) + Traceback (most recent call last): + ... + TypeError: cannot be dispatched. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Predicate_%28mathematical_logic%29 + .. [2] https://en.wikipedia.org/wiki/Sexy_prime + + """ + + is_Atom = True + + def __new__(cls, *args, **kwargs): + if cls is Predicate: + return UndefinedPredicate(*args, **kwargs) + obj = super().__new__(cls, *args) + return obj + + @property + def name(self): + # May be overridden + return type(self).__name__ + + @classmethod + def register(cls, *types, **kwargs): + """ + Register the signature to the handler. + """ + if cls.handler is None: + raise TypeError("%s cannot be dispatched." % type(cls)) + return cls.handler.register(*types, **kwargs) + + @classmethod + def register_many(cls, *types, **kwargs): + """ + Register multiple signatures to same handler. + """ + def _(func): + for t in types: + if not is_sequence(t): + t = (t,) # for convenience, allow passing `type` to mean `(type,)` + cls.register(*t, **kwargs)(func) + return _ + + def __call__(self, *args): + return AppliedPredicate(self, *args) + + def eval(self, args, assumptions=True): + """ + Evaluate ``self(*args)`` under the given assumptions. + + This uses only direct resolution methods, not logical inference. + """ + result = None + try: + result = self.handler(*args, assumptions=assumptions) + except NotImplementedError: + pass + return result + + def _eval_refine(self, assumptions): + # When Predicate is no longer Boolean, delete this method + return self + + +class UndefinedPredicate(Predicate): + """ + Predicate without handler. + + Explanation + =========== + + This predicate is generated by using ``Predicate`` directly for + construction. It does not have a handler, and evaluating this with + arguments is done by SAT solver. + + Examples + ======== + + >>> from sympy import Predicate, Q + >>> Q.P = Predicate('P') + >>> Q.P.func + + >>> Q.P.name + Str('P') + + """ + + handler = None + + def __new__(cls, name, handlers=None): + # "handlers" parameter supports old design + if not isinstance(name, Str): + name = Str(name) + obj = super(Boolean, cls).__new__(cls, name) + obj.handlers = handlers or [] + return obj + + @property + def name(self): + return self.args[0] + + def _hashable_content(self): + return (self.name,) + + def __getnewargs__(self): + return (self.name,) + + def __call__(self, expr): + return AppliedPredicate(self, expr) + + def add_handler(self, handler): + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. Predicate.add_handler() + should be replaced with the multipledispatch handler of Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + self.handlers.append(handler) + + def remove_handler(self, handler): + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. Predicate.remove_handler() + should be replaced with the multipledispatch handler of Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + self.handlers.remove(handler) + + def eval(self, args, assumptions=True): + # Support for deprecated design + # When old design is removed, this will always return None + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. Evaluating UndefinedPredicate + objects should be replaced with the multipledispatch handler of + Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + stacklevel=5, + ) + expr, = args + res, _res = None, None + mro = inspect.getmro(type(expr)) + for handler in self.handlers: + cls = get_class(handler) + for subclass in mro: + eval_ = getattr(cls, subclass.__name__, None) + if eval_ is None: + continue + res = eval_(expr, assumptions) + # Do not stop if value returned is None + # Try to check for higher classes + if res is None: + continue + if _res is None: + _res = res + else: + # only check consistency if both resolutors have concluded + if _res != res: + raise ValueError('incompatible resolutors') + break + return res + + +@contextmanager +def assuming(*assumptions): + """ + Context manager for assumptions. + + Examples + ======== + + >>> from sympy import assuming, Q, ask + >>> from sympy.abc import x, y + >>> print(ask(Q.integer(x + y))) + None + >>> with assuming(Q.integer(x), Q.integer(y)): + ... print(ask(Q.integer(x + y))) + True + """ + old_global_assumptions = global_assumptions.copy() + global_assumptions.update(assumptions) + try: + yield + finally: + global_assumptions.clear() + global_assumptions.update(old_global_assumptions) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/cnf.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/cnf.py new file mode 100644 index 0000000000000000000000000000000000000000..a95d27bed6eeb64c42f4edd9d49bd8e5753069e5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/cnf.py @@ -0,0 +1,445 @@ +""" +The classes used here are for the internal use of assumptions system +only and should not be used anywhere else as these do not possess the +signatures common to SymPy objects. For general use of logic constructs +please refer to sympy.logic classes And, Or, Not, etc. +""" +from itertools import combinations, product, zip_longest +from sympy.assumptions.assume import AppliedPredicate, Predicate +from sympy.core.relational import Eq, Ne, Gt, Lt, Ge, Le +from sympy.core.singleton import S +from sympy.logic.boolalg import Or, And, Not, Xnor +from sympy.logic.boolalg import (Equivalent, ITE, Implies, Nand, Nor, Xor) + + +class Literal: + """ + The smallest element of a CNF object. + + Parameters + ========== + + lit : Boolean expression + + is_Not : bool + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import Literal + >>> from sympy.abc import x + >>> Literal(Q.even(x)) + Literal(Q.even(x), False) + >>> Literal(~Q.even(x)) + Literal(Q.even(x), True) + """ + + def __new__(cls, lit, is_Not=False): + if isinstance(lit, Not): + lit = lit.args[0] + is_Not = True + elif isinstance(lit, (AND, OR, Literal)): + return ~lit if is_Not else lit + obj = super().__new__(cls) + obj.lit = lit + obj.is_Not = is_Not + return obj + + @property + def arg(self): + return self.lit + + def rcall(self, expr): + if callable(self.lit): + lit = self.lit(expr) + else: + lit = self.lit.apply(expr) + return type(self)(lit, self.is_Not) + + def __invert__(self): + is_Not = not self.is_Not + return Literal(self.lit, is_Not) + + def __str__(self): + return '{}({}, {})'.format(type(self).__name__, self.lit, self.is_Not) + + __repr__ = __str__ + + def __eq__(self, other): + return self.arg == other.arg and self.is_Not == other.is_Not + + def __hash__(self): + h = hash((type(self).__name__, self.arg, self.is_Not)) + return h + + +class OR: + """ + A low-level implementation for Or + """ + def __init__(self, *args): + self._args = args + + @property + def args(self): + return sorted(self._args, key=str) + + def rcall(self, expr): + return type(self)(*[arg.rcall(expr) + for arg in self._args + ]) + + def __invert__(self): + return AND(*[~arg for arg in self._args]) + + def __hash__(self): + return hash((type(self).__name__,) + tuple(self.args)) + + def __eq__(self, other): + return self.args == other.args + + def __str__(self): + s = '(' + ' | '.join([str(arg) for arg in self.args]) + ')' + return s + + __repr__ = __str__ + + +class AND: + """ + A low-level implementation for And + """ + def __init__(self, *args): + self._args = args + + def __invert__(self): + return OR(*[~arg for arg in self._args]) + + @property + def args(self): + return sorted(self._args, key=str) + + def rcall(self, expr): + return type(self)(*[arg.rcall(expr) + for arg in self._args + ]) + + def __hash__(self): + return hash((type(self).__name__,) + tuple(self.args)) + + def __eq__(self, other): + return self.args == other.args + + def __str__(self): + s = '('+' & '.join([str(arg) for arg in self.args])+')' + return s + + __repr__ = __str__ + + +def to_NNF(expr, composite_map=None): + """ + Generates the Negation Normal Form of any boolean expression in terms + of AND, OR, and Literal objects. + + Examples + ======== + + >>> from sympy import Q, Eq + >>> from sympy.assumptions.cnf import to_NNF + >>> from sympy.abc import x, y + >>> expr = Q.even(x) & ~Q.positive(x) + >>> to_NNF(expr) + (Literal(Q.even(x), False) & Literal(Q.positive(x), True)) + + Supported boolean objects are converted to corresponding predicates. + + >>> to_NNF(Eq(x, y)) + Literal(Q.eq(x, y), False) + + If ``composite_map`` argument is given, ``to_NNF`` decomposes the + specified predicate into a combination of primitive predicates. + + >>> cmap = {Q.nonpositive: Q.negative | Q.zero} + >>> to_NNF(Q.nonpositive, cmap) + (Literal(Q.negative, False) | Literal(Q.zero, False)) + >>> to_NNF(Q.nonpositive(x), cmap) + (Literal(Q.negative(x), False) | Literal(Q.zero(x), False)) + """ + from sympy.assumptions.ask import Q + + if composite_map is None: + composite_map = {} + + + binrelpreds = {Eq: Q.eq, Ne: Q.ne, Gt: Q.gt, Lt: Q.lt, Ge: Q.ge, Le: Q.le} + if type(expr) in binrelpreds: + pred = binrelpreds[type(expr)] + expr = pred(*expr.args) + + if isinstance(expr, Not): + arg = expr.args[0] + tmp = to_NNF(arg, composite_map) # Strategy: negate the NNF of expr + return ~tmp + + if isinstance(expr, Or): + return OR(*[to_NNF(x, composite_map) for x in Or.make_args(expr)]) + + if isinstance(expr, And): + return AND(*[to_NNF(x, composite_map) for x in And.make_args(expr)]) + + if isinstance(expr, Nand): + tmp = AND(*[to_NNF(x, composite_map) for x in expr.args]) + return ~tmp + + if isinstance(expr, Nor): + tmp = OR(*[to_NNF(x, composite_map) for x in expr.args]) + return ~tmp + + if isinstance(expr, Xor): + cnfs = [] + for i in range(0, len(expr.args) + 1, 2): + for neg in combinations(expr.args, i): + clause = [~to_NNF(s, composite_map) if s in neg else to_NNF(s, composite_map) + for s in expr.args] + cnfs.append(OR(*clause)) + return AND(*cnfs) + + if isinstance(expr, Xnor): + cnfs = [] + for i in range(0, len(expr.args) + 1, 2): + for neg in combinations(expr.args, i): + clause = [~to_NNF(s, composite_map) if s in neg else to_NNF(s, composite_map) + for s in expr.args] + cnfs.append(OR(*clause)) + return ~AND(*cnfs) + + if isinstance(expr, Implies): + L, R = to_NNF(expr.args[0], composite_map), to_NNF(expr.args[1], composite_map) + return OR(~L, R) + + if isinstance(expr, Equivalent): + cnfs = [] + for a, b in zip_longest(expr.args, expr.args[1:], fillvalue=expr.args[0]): + a = to_NNF(a, composite_map) + b = to_NNF(b, composite_map) + cnfs.append(OR(~a, b)) + return AND(*cnfs) + + if isinstance(expr, ITE): + L = to_NNF(expr.args[0], composite_map) + M = to_NNF(expr.args[1], composite_map) + R = to_NNF(expr.args[2], composite_map) + return AND(OR(~L, M), OR(L, R)) + + if isinstance(expr, AppliedPredicate): + pred, args = expr.function, expr.arguments + newpred = composite_map.get(pred, None) + if newpred is not None: + return to_NNF(newpred.rcall(*args), composite_map) + + if isinstance(expr, Predicate): + newpred = composite_map.get(expr, None) + if newpred is not None: + return to_NNF(newpred, composite_map) + + return Literal(expr) + + +def distribute_AND_over_OR(expr): + """ + Distributes AND over OR in the NNF expression. + Returns the result( Conjunctive Normal Form of expression) + as a CNF object. + """ + if not isinstance(expr, (AND, OR)): + tmp = set() + tmp.add(frozenset((expr,))) + return CNF(tmp) + + if isinstance(expr, OR): + return CNF.all_or(*[distribute_AND_over_OR(arg) + for arg in expr._args]) + + if isinstance(expr, AND): + return CNF.all_and(*[distribute_AND_over_OR(arg) + for arg in expr._args]) + + +class CNF: + """ + Class to represent CNF of a Boolean expression. + Consists of set of clauses, which themselves are stored as + frozenset of Literal objects. + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.abc import x + >>> cnf = CNF.from_prop(Q.real(x) & ~Q.zero(x)) + >>> cnf.clauses + {frozenset({Literal(Q.zero(x), True)}), + frozenset({Literal(Q.negative(x), False), + Literal(Q.positive(x), False), Literal(Q.zero(x), False)})} + """ + def __init__(self, clauses=None): + if not clauses: + clauses = set() + self.clauses = clauses + + def add(self, prop): + clauses = CNF.to_CNF(prop).clauses + self.add_clauses(clauses) + + def __str__(self): + s = ' & '.join( + ['(' + ' | '.join([str(lit) for lit in clause]) +')' + for clause in self.clauses] + ) + return s + + def extend(self, props): + for p in props: + self.add(p) + return self + + def copy(self): + return CNF(set(self.clauses)) + + def add_clauses(self, clauses): + self.clauses |= clauses + + @classmethod + def from_prop(cls, prop): + res = cls() + res.add(prop) + return res + + def __iand__(self, other): + self.add_clauses(other.clauses) + return self + + def all_predicates(self): + predicates = set() + for c in self.clauses: + predicates |= {arg.lit for arg in c} + return predicates + + def _or(self, cnf): + clauses = set() + for a, b in product(self.clauses, cnf.clauses): + tmp = set(a) + tmp.update(b) + clauses.add(frozenset(tmp)) + return CNF(clauses) + + def _and(self, cnf): + clauses = self.clauses.union(cnf.clauses) + return CNF(clauses) + + def _not(self): + clss = list(self.clauses) + ll = {frozenset((~x,)) for x in clss[-1]} + ll = CNF(ll) + + for rest in clss[:-1]: + p = {frozenset((~x,)) for x in rest} + ll = ll._or(CNF(p)) + return ll + + def rcall(self, expr): + clause_list = [] + for clause in self.clauses: + lits = [arg.rcall(expr) for arg in clause] + clause_list.append(OR(*lits)) + expr = AND(*clause_list) + return distribute_AND_over_OR(expr) + + @classmethod + def all_or(cls, *cnfs): + b = cnfs[0].copy() + for rest in cnfs[1:]: + b = b._or(rest) + return b + + @classmethod + def all_and(cls, *cnfs): + b = cnfs[0].copy() + for rest in cnfs[1:]: + b = b._and(rest) + return b + + @classmethod + def to_CNF(cls, expr): + from sympy.assumptions.facts import get_composite_predicates + expr = to_NNF(expr, get_composite_predicates()) + expr = distribute_AND_over_OR(expr) + return expr + + @classmethod + def CNF_to_cnf(cls, cnf): + """ + Converts CNF object to SymPy's boolean expression + retaining the form of expression. + """ + def remove_literal(arg): + return Not(arg.lit) if arg.is_Not else arg.lit + + return And(*(Or(*(remove_literal(arg) for arg in clause)) for clause in cnf.clauses)) + + +class EncodedCNF: + """ + Class for encoding the CNF expression. + """ + def __init__(self, data=None, encoding=None): + if not data and not encoding: + data = [] + encoding = {} + self.data = data + self.encoding = encoding + self._symbols = list(encoding.keys()) + + def from_cnf(self, cnf): + self._symbols = list(cnf.all_predicates()) + n = len(self._symbols) + self.encoding = dict(zip(self._symbols, range(1, n + 1))) + self.data = [self.encode(clause) for clause in cnf.clauses] + + @property + def symbols(self): + return self._symbols + + @property + def variables(self): + return range(1, len(self._symbols) + 1) + + def copy(self): + new_data = [set(clause) for clause in self.data] + return EncodedCNF(new_data, dict(self.encoding)) + + def add_prop(self, prop): + cnf = CNF.from_prop(prop) + self.add_from_cnf(cnf) + + def add_from_cnf(self, cnf): + clauses = [self.encode(clause) for clause in cnf.clauses] + self.data += clauses + + def encode_arg(self, arg): + literal = arg.lit + value = self.encoding.get(literal, None) + if value is None: + n = len(self._symbols) + self._symbols.append(literal) + value = self.encoding[literal] = n + 1 + if arg.is_Not: + return -value + else: + return value + + def encode(self, clause): + return {self.encode_arg(arg) if not arg.lit == S.false else 0 for arg in clause} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/facts.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/facts.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff268677cf74e252ac6c3bc3eecbea08b9414d0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/facts.py @@ -0,0 +1,270 @@ +""" +Known facts in assumptions module. + +This module defines the facts between unary predicates in ``get_known_facts()``, +and supports functions to generate the contents in +``sympy.assumptions.ask_generated`` file. +""" + +from sympy.assumptions.ask import Q +from sympy.assumptions.assume import AppliedPredicate +from sympy.core.cache import cacheit +from sympy.core.symbol import Symbol +from sympy.logic.boolalg import (to_cnf, And, Not, Implies, Equivalent, + Exclusive,) +from sympy.logic.inference import satisfiable + + +@cacheit +def get_composite_predicates(): + # To reduce the complexity of sat solver, these predicates are + # transformed into the combination of primitive predicates. + return { + Q.real : Q.negative | Q.zero | Q.positive, + Q.integer : Q.even | Q.odd, + Q.nonpositive : Q.negative | Q.zero, + Q.nonzero : Q.negative | Q.positive, + Q.nonnegative : Q.zero | Q.positive, + Q.extended_real : Q.negative_infinite | Q.negative | Q.zero | Q.positive | Q.positive_infinite, + Q.extended_positive: Q.positive | Q.positive_infinite, + Q.extended_negative: Q.negative | Q.negative_infinite, + Q.extended_nonzero: Q.negative_infinite | Q.negative | Q.positive | Q.positive_infinite, + Q.extended_nonpositive: Q.negative_infinite | Q.negative | Q.zero, + Q.extended_nonnegative: Q.zero | Q.positive | Q.positive_infinite, + Q.complex : Q.algebraic | Q.transcendental + } + + +@cacheit +def get_known_facts(x=None): + """ + Facts between unary predicates. + + Parameters + ========== + + x : Symbol, optional + Placeholder symbol for unary facts. Default is ``Symbol('x')``. + + Returns + ======= + + fact : Known facts in conjugated normal form. + + """ + if x is None: + x = Symbol('x') + + fact = And( + get_number_facts(x), + get_matrix_facts(x) + ) + return fact + + +@cacheit +def get_number_facts(x = None): + """ + Facts between unary number predicates. + + Parameters + ========== + + x : Symbol, optional + Placeholder symbol for unary facts. Default is ``Symbol('x')``. + + Returns + ======= + + fact : Known facts in conjugated normal form. + + """ + if x is None: + x = Symbol('x') + + fact = And( + # primitive predicates for extended real exclude each other. + Exclusive(Q.negative_infinite(x), Q.negative(x), Q.zero(x), + Q.positive(x), Q.positive_infinite(x)), + + # build complex plane + Exclusive(Q.real(x), Q.imaginary(x)), + Implies(Q.real(x) | Q.imaginary(x), Q.complex(x)), + + # other subsets of complex + Exclusive(Q.transcendental(x), Q.algebraic(x)), + Equivalent(Q.real(x), Q.rational(x) | Q.irrational(x)), + Exclusive(Q.irrational(x), Q.rational(x)), + Implies(Q.rational(x), Q.algebraic(x)), + + # integers + Exclusive(Q.even(x), Q.odd(x)), + Implies(Q.integer(x), Q.rational(x)), + Implies(Q.zero(x), Q.even(x)), + Exclusive(Q.composite(x), Q.prime(x)), + Implies(Q.composite(x) | Q.prime(x), Q.integer(x) & Q.positive(x)), + Implies(Q.even(x) & Q.positive(x) & ~Q.prime(x), Q.composite(x)), + + # hermitian and antihermitian + Implies(Q.real(x), Q.hermitian(x)), + Implies(Q.imaginary(x), Q.antihermitian(x)), + Implies(Q.zero(x), Q.hermitian(x) | Q.antihermitian(x)), + + # define finity and infinity, and build extended real line + Exclusive(Q.infinite(x), Q.finite(x)), + Implies(Q.complex(x), Q.finite(x)), + Implies(Q.negative_infinite(x) | Q.positive_infinite(x), Q.infinite(x)), + + # commutativity + Implies(Q.finite(x) | Q.infinite(x), Q.commutative(x)), + ) + return fact + + +@cacheit +def get_matrix_facts(x = None): + """ + Facts between unary matrix predicates. + + Parameters + ========== + + x : Symbol, optional + Placeholder symbol for unary facts. Default is ``Symbol('x')``. + + Returns + ======= + + fact : Known facts in conjugated normal form. + + """ + if x is None: + x = Symbol('x') + + fact = And( + # matrices + Implies(Q.orthogonal(x), Q.positive_definite(x)), + Implies(Q.orthogonal(x), Q.unitary(x)), + Implies(Q.unitary(x) & Q.real_elements(x), Q.orthogonal(x)), + Implies(Q.unitary(x), Q.normal(x)), + Implies(Q.unitary(x), Q.invertible(x)), + Implies(Q.normal(x), Q.square(x)), + Implies(Q.diagonal(x), Q.normal(x)), + Implies(Q.positive_definite(x), Q.invertible(x)), + Implies(Q.diagonal(x), Q.upper_triangular(x)), + Implies(Q.diagonal(x), Q.lower_triangular(x)), + Implies(Q.lower_triangular(x), Q.triangular(x)), + Implies(Q.upper_triangular(x), Q.triangular(x)), + Implies(Q.triangular(x), Q.upper_triangular(x) | Q.lower_triangular(x)), + Implies(Q.upper_triangular(x) & Q.lower_triangular(x), Q.diagonal(x)), + Implies(Q.diagonal(x), Q.symmetric(x)), + Implies(Q.unit_triangular(x), Q.triangular(x)), + Implies(Q.invertible(x), Q.fullrank(x)), + Implies(Q.invertible(x), Q.square(x)), + Implies(Q.symmetric(x), Q.square(x)), + Implies(Q.fullrank(x) & Q.square(x), Q.invertible(x)), + Equivalent(Q.invertible(x), ~Q.singular(x)), + Implies(Q.integer_elements(x), Q.real_elements(x)), + Implies(Q.real_elements(x), Q.complex_elements(x)), + ) + return fact + + + +def generate_known_facts_dict(keys, fact): + """ + Computes and returns a dictionary which contains the relations between + unary predicates. + + Each key is a predicate, and item is two groups of predicates. + First group contains the predicates which are implied by the key, and + second group contains the predicates which are rejected by the key. + + All predicates in *keys* and *fact* must be unary and have same placeholder + symbol. + + Parameters + ========== + + keys : list of AppliedPredicate instances. + + fact : Fact between predicates in conjugated normal form. + + Examples + ======== + + >>> from sympy import Q, And, Implies + >>> from sympy.assumptions.facts import generate_known_facts_dict + >>> from sympy.abc import x + >>> keys = [Q.even(x), Q.odd(x), Q.zero(x)] + >>> fact = And(Implies(Q.even(x), ~Q.odd(x)), + ... Implies(Q.zero(x), Q.even(x))) + >>> generate_known_facts_dict(keys, fact) + {Q.even: ({Q.even}, {Q.odd}), + Q.odd: ({Q.odd}, {Q.even, Q.zero}), + Q.zero: ({Q.even, Q.zero}, {Q.odd})} + """ + fact_cnf = to_cnf(fact) + mapping = single_fact_lookup(keys, fact_cnf) + + ret = {} + for key, value in mapping.items(): + implied = set() + rejected = set() + for expr in value: + if isinstance(expr, AppliedPredicate): + implied.add(expr.function) + elif isinstance(expr, Not): + pred = expr.args[0] + rejected.add(pred.function) + ret[key.function] = (implied, rejected) + return ret + + +@cacheit +def get_known_facts_keys(): + """ + Return every unary predicates registered to ``Q``. + + This function is used to generate the keys for + ``generate_known_facts_dict``. + + """ + # exclude polyadic predicates + exclude = {Q.eq, Q.ne, Q.gt, Q.lt, Q.ge, Q.le} + + result = [] + for attr in Q.__class__.__dict__: + if attr.startswith('__'): + continue + pred = getattr(Q, attr) + if pred in exclude: + continue + result.append(pred) + return result + + +def single_fact_lookup(known_facts_keys, known_facts_cnf): + # Return the dictionary for quick lookup of single fact + mapping = {} + for key in known_facts_keys: + mapping[key] = {key} + for other_key in known_facts_keys: + if other_key != key: + if ask_full_inference(other_key, key, known_facts_cnf): + mapping[key].add(other_key) + if ask_full_inference(~other_key, key, known_facts_cnf): + mapping[key].add(~other_key) + return mapping + + +def ask_full_inference(proposition, assumptions, known_facts_cnf): + """ + Method for inferring properties about objects. + + """ + if not satisfiable(And(known_facts_cnf, assumptions, proposition)): + return False + if not satisfiable(And(known_facts_cnf, assumptions, Not(proposition))): + return True + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbe618eb8b43e252ac8fb0baf1eeee22bf347cc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__init__.py @@ -0,0 +1,13 @@ +""" +Multipledispatch handlers for ``Predicate`` are implemented here. +Handlers in this module are not directly imported to other modules in +order to avoid circular import problem. +""" + +from .common import (AskHandler, CommonHandler, + test_closed_group) + +__all__ = [ + 'AskHandler', 'CommonHandler', + 'test_closed_group' +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55ef1fa46f9fe20361f6e7e39fb1d04ae663b4d7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/calculus.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/calculus.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b27e033fcb5da04deb059db529ccc8cf727f55c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/calculus.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..355665805d7825730e536d2ad981e958365bf1ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/matrices.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/matrices.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2d25bffc3598765aec21402fa9abeca9fda69ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/matrices.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/ntheory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/ntheory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55035d53cf1e3d972620bb3ee5069153abb448a0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/ntheory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/order.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/order.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..471184faf80770adaecb3905b89fe68916047fed Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/order.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/sets.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/sets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..941285f0fb3176b052735616b8e67e9577ff8d67 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/__pycache__/sets.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/calculus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/calculus.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b9c43ccea216988a25aa671ea23bc81d2209ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/calculus.py @@ -0,0 +1,273 @@ +""" +This module contains query handlers responsible for calculus queries: +infinitesimal, finite, etc. +""" + +from sympy.assumptions import Q, ask +from sympy.core import Expr, Add, Mul, Pow, Symbol +from sympy.core.numbers import (NegativeInfinity, GoldenRatio, + Infinity, Exp1, ComplexInfinity, ImaginaryUnit, NaN, Number, Pi, E, + TribonacciConstant) +from sympy.functions import cos, exp, log, sign, sin +from sympy.logic.boolalg import conjuncts + +from ..predicates.calculus import (FinitePredicate, InfinitePredicate, + PositiveInfinitePredicate, NegativeInfinitePredicate) + + +# FinitePredicate + + +@FinitePredicate.register(Symbol) +def _(expr, assumptions): + """ + Handles Symbol. + """ + if expr.is_finite is not None: + return expr.is_finite + if Q.finite(expr) in conjuncts(assumptions): + return True + return None + +@FinitePredicate.register(Add) +def _(expr, assumptions): + """ + Return True if expr is bounded, False if not and None if unknown. + + Truth Table: + + +-------+-----+-----------+-----------+ + | | | | | + | | B | U | ? | + | | | | | + +-------+-----+---+---+---+---+---+---+ + | | | | | | | | | + | | |'+'|'-'|'x'|'+'|'-'|'x'| + | | | | | | | | | + +-------+-----+---+---+---+---+---+---+ + | | | | | + | B | B | U | ? | + | | | | | + +---+---+-----+---+---+---+---+---+---+ + | | | | | | | | | | + | |'+'| | U | ? | ? | U | ? | ? | + | | | | | | | | | | + | +---+-----+---+---+---+---+---+---+ + | | | | | | | | | | + | U |'-'| | ? | U | ? | ? | U | ? | + | | | | | | | | | | + | +---+-----+---+---+---+---+---+---+ + | | | | | | + | |'x'| | ? | ? | + | | | | | | + +---+---+-----+---+---+---+---+---+---+ + | | | | | + | ? | | | ? | + | | | | | + +-------+-----+-----------+---+---+---+ + + * 'B' = Bounded + + * 'U' = Unbounded + + * '?' = unknown boundedness + + * '+' = positive sign + + * '-' = negative sign + + * 'x' = sign unknown + + * All Bounded -> True + + * 1 Unbounded and the rest Bounded -> False + + * >1 Unbounded, all with same known sign -> False + + * Any Unknown and unknown sign -> None + + * Else -> None + + When the signs are not the same you can have an undefined + result as in oo - oo, hence 'bounded' is also undefined. + """ + sign = -1 # sign of unknown or infinite + result = True + for arg in expr.args: + _bounded = ask(Q.finite(arg), assumptions) + if _bounded: + continue + s = ask(Q.extended_positive(arg), assumptions) + # if there has been more than one sign or if the sign of this arg + # is None and Bounded is None or there was already + # an unknown sign, return None + if sign != -1 and s != sign or \ + s is None and None in (_bounded, sign): + return None + else: + sign = s + # once False, do not change + if result is not False: + result = _bounded + return result + +@FinitePredicate.register(Mul) +def _(expr, assumptions): + """ + Return True if expr is bounded, False if not and None if unknown. + + Truth Table: + + +---+---+---+--------+ + | | | | | + | | B | U | ? | + | | | | | + +---+---+---+---+----+ + | | | | | | + | | | | s | /s | + | | | | | | + +---+---+---+---+----+ + | | | | | + | B | B | U | ? | + | | | | | + +---+---+---+---+----+ + | | | | | | + | U | | U | U | ? | + | | | | | | + +---+---+---+---+----+ + | | | | | + | ? | | | ? | + | | | | | + +---+---+---+---+----+ + + * B = Bounded + + * U = Unbounded + + * ? = unknown boundedness + + * s = signed (hence nonzero) + + * /s = not signed + """ + result = True + possible_zero = False + for arg in expr.args: + _bounded = ask(Q.finite(arg), assumptions) + if _bounded: + if ask(Q.zero(arg), assumptions) is not False: + if result is False: + return None + possible_zero = True + elif _bounded is None: + if result is None: + return None + if ask(Q.extended_nonzero(arg), assumptions) is None: + return None + if result is not False: + result = None + else: + if possible_zero: + return None + result = False + return result + +@FinitePredicate.register(Pow) +def _(expr, assumptions): + """ + * Unbounded ** NonZero -> Unbounded + + * Bounded ** Bounded -> Bounded + + * Abs()<=1 ** Positive -> Bounded + + * Abs()>=1 ** Negative -> Bounded + + * Otherwise unknown + """ + if expr.base == E: + return ask(Q.finite(expr.exp), assumptions) + + base_bounded = ask(Q.finite(expr.base), assumptions) + exp_bounded = ask(Q.finite(expr.exp), assumptions) + if base_bounded is None and exp_bounded is None: # Common Case + return None + if base_bounded is False and ask(Q.extended_nonzero(expr.exp), assumptions): + return False + if base_bounded and exp_bounded: + is_base_zero = ask(Q.zero(expr.base),assumptions) + is_exp_negative = ask(Q.negative(expr.exp),assumptions) + if is_base_zero is True and is_exp_negative is True: + return False + if is_base_zero is not False and is_exp_negative is not False: + return None + return True + if (abs(expr.base) <= 1) == True and ask(Q.extended_positive(expr.exp), assumptions): + return True + if (abs(expr.base) >= 1) == True and ask(Q.extended_negative(expr.exp), assumptions): + return True + if (abs(expr.base) >= 1) == True and exp_bounded is False: + return False + return None + +@FinitePredicate.register(exp) +def _(expr, assumptions): + return ask(Q.finite(expr.exp), assumptions) + +@FinitePredicate.register(log) +def _(expr, assumptions): + # After complex -> finite fact is registered to new assumption system, + # querying Q.infinite may be removed. + if ask(Q.infinite(expr.args[0]), assumptions): + return False + return ask(~Q.zero(expr.args[0]), assumptions) + +@FinitePredicate.register_many(cos, sin, Number, Pi, Exp1, GoldenRatio, + TribonacciConstant, ImaginaryUnit, sign) +def _(expr, assumptions): + return True + +@FinitePredicate.register_many(ComplexInfinity, Infinity, NegativeInfinity) +def _(expr, assumptions): + return False + +@FinitePredicate.register(NaN) +def _(expr, assumptions): + return None + + +# InfinitePredicate + + +@InfinitePredicate.register(Expr) +def _(expr, assumptions): + is_finite = Q.finite(expr)._eval_ask(assumptions) + if is_finite is None: + return None + return not is_finite + + +# PositiveInfinitePredicate + + +@PositiveInfinitePredicate.register(Infinity) +def _(expr, assumptions): + return True + + +@PositiveInfinitePredicate.register_many(NegativeInfinity, ComplexInfinity) +def _(expr, assumptions): + return False + + +# NegativeInfinitePredicate + + +@NegativeInfinitePredicate.register(NegativeInfinity) +def _(expr, assumptions): + return True + + +@NegativeInfinitePredicate.register_many(Infinity, ComplexInfinity) +def _(expr, assumptions): + return False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e9f6f321be461c09b16c03b9cec5708404d21a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/common.py @@ -0,0 +1,164 @@ +""" +This module defines base class for handlers and some core handlers: +``Q.commutative`` and ``Q.is_true``. +""" + +from sympy.assumptions import Q, ask, AppliedPredicate +from sympy.core import Basic, Symbol +from sympy.core.logic import _fuzzy_group, fuzzy_and, fuzzy_or +from sympy.core.numbers import NaN, Number +from sympy.logic.boolalg import (And, BooleanTrue, BooleanFalse, conjuncts, + Equivalent, Implies, Not, Or) +from sympy.utilities.exceptions import sympy_deprecation_warning + +from ..predicates.common import CommutativePredicate, IsTruePredicate + + +class AskHandler: + """Base class that all Ask Handlers must inherit.""" + def __new__(cls, *args, **kwargs): + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. The AskHandler class should + be replaced with the multipledispatch handler of Predicate + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + return super().__new__(cls, *args, **kwargs) + + +class CommonHandler(AskHandler): + # Deprecated + """Defines some useful methods common to most Handlers. """ + + @staticmethod + def AlwaysTrue(expr, assumptions): + return True + + @staticmethod + def AlwaysFalse(expr, assumptions): + return False + + @staticmethod + def AlwaysNone(expr, assumptions): + return None + + NaN = AlwaysFalse + + +# CommutativePredicate + +@CommutativePredicate.register(Symbol) +def _(expr, assumptions): + """Objects are expected to be commutative unless otherwise stated""" + assumps = conjuncts(assumptions) + if expr.is_commutative is not None: + return expr.is_commutative and not ~Q.commutative(expr) in assumps + if Q.commutative(expr) in assumps: + return True + elif ~Q.commutative(expr) in assumps: + return False + return True + +@CommutativePredicate.register(Basic) +def _(expr, assumptions): + for arg in expr.args: + if not ask(Q.commutative(arg), assumptions): + return False + return True + +@CommutativePredicate.register(Number) +def _(expr, assumptions): + return True + +@CommutativePredicate.register(NaN) +def _(expr, assumptions): + return True + + +# IsTruePredicate + +@IsTruePredicate.register(bool) +def _(expr, assumptions): + return expr + +@IsTruePredicate.register(BooleanTrue) +def _(expr, assumptions): + return True + +@IsTruePredicate.register(BooleanFalse) +def _(expr, assumptions): + return False + +@IsTruePredicate.register(AppliedPredicate) +def _(expr, assumptions): + return ask(expr, assumptions) + +@IsTruePredicate.register(Not) +def _(expr, assumptions): + arg = expr.args[0] + if arg.is_Symbol: + # symbol used as abstract boolean object + return None + value = ask(arg, assumptions=assumptions) + if value in (True, False): + return not value + else: + return None + +@IsTruePredicate.register(Or) +def _(expr, assumptions): + result = False + for arg in expr.args: + p = ask(arg, assumptions=assumptions) + if p is True: + return True + if p is None: + result = None + return result + +@IsTruePredicate.register(And) +def _(expr, assumptions): + result = True + for arg in expr.args: + p = ask(arg, assumptions=assumptions) + if p is False: + return False + if p is None: + result = None + return result + +@IsTruePredicate.register(Implies) +def _(expr, assumptions): + p, q = expr.args + return ask(~p | q, assumptions=assumptions) + +@IsTruePredicate.register(Equivalent) +def _(expr, assumptions): + p, q = expr.args + pt = ask(p, assumptions=assumptions) + if pt is None: + return None + qt = ask(q, assumptions=assumptions) + if qt is None: + return None + return pt == qt + + +#### Helper methods +def test_closed_group(expr, assumptions, key): + """ + Test for membership in a group with respect + to the current operation. + """ + return _fuzzy_group( + (ask(key(a), assumptions) for a in expr.args), quick_exit=True) + +def ask_all(*queries, assumptions): + return fuzzy_and( + (ask(query, assumptions) for query in queries)) + +def ask_any(*queries, assumptions): + return fuzzy_or( + (ask(query, assumptions) for query in queries)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/matrices.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..3b20385360136629ea037eb7238c45b70ba57fd2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/matrices.py @@ -0,0 +1,716 @@ +""" +This module contains query handlers responsible for Matrices queries: +Square, Symmetric, Invertible etc. +""" + +from sympy.logic.boolalg import conjuncts +from sympy.assumptions import Q, ask +from sympy.assumptions.handlers import test_closed_group +from sympy.matrices import MatrixBase +from sympy.matrices.expressions import (BlockMatrix, BlockDiagMatrix, Determinant, + DiagMatrix, DiagonalMatrix, HadamardProduct, Identity, Inverse, MatAdd, MatMul, + MatPow, MatrixExpr, MatrixSlice, MatrixSymbol, OneMatrix, Trace, Transpose, + ZeroMatrix) +from sympy.matrices.expressions.blockmatrix import reblock_2x2 +from sympy.matrices.expressions.factorizations import Factorization +from sympy.matrices.expressions.fourier import DFT +from sympy.core.logic import fuzzy_and +from sympy.utilities.iterables import sift +from sympy.core import Basic + +from ..predicates.matrices import (SquarePredicate, SymmetricPredicate, + InvertiblePredicate, OrthogonalPredicate, UnitaryPredicate, + FullRankPredicate, PositiveDefinitePredicate, UpperTriangularPredicate, + LowerTriangularPredicate, DiagonalPredicate, IntegerElementsPredicate, + RealElementsPredicate, ComplexElementsPredicate) + + +def _Factorization(predicate, expr, assumptions): + if predicate in expr.predicates: + return True + + +# SquarePredicate + +@SquarePredicate.register(MatrixExpr) +def _(expr, assumptions): + return expr.shape[0] == expr.shape[1] + + +# SymmetricPredicate + +@SymmetricPredicate.register(MatMul) +def _(expr, assumptions): + factor, mmul = expr.as_coeff_mmul() + if all(ask(Q.symmetric(arg), assumptions) for arg in mmul.args): + return True + # TODO: implement sathandlers system for the matrices. + # Now it duplicates the general fact: Implies(Q.diagonal, Q.symmetric). + if ask(Q.diagonal(expr), assumptions): + return True + if len(mmul.args) >= 2 and mmul.args[0] == mmul.args[-1].T: + if len(mmul.args) == 2: + return True + return ask(Q.symmetric(MatMul(*mmul.args[1:-1])), assumptions) + +@SymmetricPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if not int_exp: + return None + non_negative = ask(~Q.negative(exp), assumptions) + if (non_negative or non_negative == False + and ask(Q.invertible(base), assumptions)): + return ask(Q.symmetric(base), assumptions) + return None + +@SymmetricPredicate.register(MatAdd) +def _(expr, assumptions): + return all(ask(Q.symmetric(arg), assumptions) for arg in expr.args) + +@SymmetricPredicate.register(MatrixSymbol) +def _(expr, assumptions): + if not expr.is_square: + return False + # TODO: implement sathandlers system for the matrices. + # Now it duplicates the general fact: Implies(Q.diagonal, Q.symmetric). + if ask(Q.diagonal(expr), assumptions): + return True + if Q.symmetric(expr) in conjuncts(assumptions): + return True + +@SymmetricPredicate.register_many(OneMatrix, ZeroMatrix) +def _(expr, assumptions): + return ask(Q.square(expr), assumptions) + +@SymmetricPredicate.register_many(Inverse, Transpose) +def _(expr, assumptions): + return ask(Q.symmetric(expr.arg), assumptions) + +@SymmetricPredicate.register(MatrixSlice) +def _(expr, assumptions): + # TODO: implement sathandlers system for the matrices. + # Now it duplicates the general fact: Implies(Q.diagonal, Q.symmetric). + if ask(Q.diagonal(expr), assumptions): + return True + if not expr.on_diag: + return None + else: + return ask(Q.symmetric(expr.parent), assumptions) + +@SymmetricPredicate.register(Identity) +def _(expr, assumptions): + return True + + +# InvertiblePredicate + +@InvertiblePredicate.register(MatMul) +def _(expr, assumptions): + factor, mmul = expr.as_coeff_mmul() + if all(ask(Q.invertible(arg), assumptions) for arg in mmul.args): + return True + if any(ask(Q.invertible(arg), assumptions) is False + for arg in mmul.args): + return False + +@InvertiblePredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if not int_exp: + return None + if exp.is_negative == False: + return ask(Q.invertible(base), assumptions) + return None + +@InvertiblePredicate.register(MatAdd) +def _(expr, assumptions): + return None + +@InvertiblePredicate.register(MatrixSymbol) +def _(expr, assumptions): + if not expr.is_square: + return False + if Q.invertible(expr) in conjuncts(assumptions): + return True + +@InvertiblePredicate.register_many(Identity, Inverse) +def _(expr, assumptions): + return True + +@InvertiblePredicate.register(ZeroMatrix) +def _(expr, assumptions): + return False + +@InvertiblePredicate.register(OneMatrix) +def _(expr, assumptions): + return expr.shape[0] == 1 and expr.shape[1] == 1 + +@InvertiblePredicate.register(Transpose) +def _(expr, assumptions): + return ask(Q.invertible(expr.arg), assumptions) + +@InvertiblePredicate.register(MatrixSlice) +def _(expr, assumptions): + if not expr.on_diag: + return None + else: + return ask(Q.invertible(expr.parent), assumptions) + +@InvertiblePredicate.register(MatrixBase) +def _(expr, assumptions): + if not expr.is_square: + return False + return expr.rank() == expr.rows + +@InvertiblePredicate.register(MatrixExpr) +def _(expr, assumptions): + if not expr.is_square: + return False + return None + +@InvertiblePredicate.register(BlockMatrix) +def _(expr, assumptions): + if not expr.is_square: + return False + if expr.blockshape == (1, 1): + return ask(Q.invertible(expr.blocks[0, 0]), assumptions) + expr = reblock_2x2(expr) + if expr.blockshape == (2, 2): + [[A, B], [C, D]] = expr.blocks.tolist() + if ask(Q.invertible(A), assumptions) == True: + invertible = ask(Q.invertible(D - C * A.I * B), assumptions) + if invertible is not None: + return invertible + if ask(Q.invertible(B), assumptions) == True: + invertible = ask(Q.invertible(C - D * B.I * A), assumptions) + if invertible is not None: + return invertible + if ask(Q.invertible(C), assumptions) == True: + invertible = ask(Q.invertible(B - A * C.I * D), assumptions) + if invertible is not None: + return invertible + if ask(Q.invertible(D), assumptions) == True: + invertible = ask(Q.invertible(A - B * D.I * C), assumptions) + if invertible is not None: + return invertible + return None + +@InvertiblePredicate.register(BlockDiagMatrix) +def _(expr, assumptions): + if expr.rowblocksizes != expr.colblocksizes: + return None + return fuzzy_and([ask(Q.invertible(a), assumptions) for a in expr.diag]) + + +# OrthogonalPredicate + +@OrthogonalPredicate.register(MatMul) +def _(expr, assumptions): + factor, mmul = expr.as_coeff_mmul() + if (all(ask(Q.orthogonal(arg), assumptions) for arg in mmul.args) and + factor == 1): + return True + if any(ask(Q.invertible(arg), assumptions) is False + for arg in mmul.args): + return False + +@OrthogonalPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if int_exp: + return ask(Q.orthogonal(base), assumptions) + return None + +@OrthogonalPredicate.register(MatAdd) +def _(expr, assumptions): + if (len(expr.args) == 1 and + ask(Q.orthogonal(expr.args[0]), assumptions)): + return True + +@OrthogonalPredicate.register(MatrixSymbol) +def _(expr, assumptions): + if (not expr.is_square or + ask(Q.invertible(expr), assumptions) is False): + return False + if Q.orthogonal(expr) in conjuncts(assumptions): + return True + +@OrthogonalPredicate.register(Identity) +def _(expr, assumptions): + return True + +@OrthogonalPredicate.register(ZeroMatrix) +def _(expr, assumptions): + return False + +@OrthogonalPredicate.register_many(Inverse, Transpose) +def _(expr, assumptions): + return ask(Q.orthogonal(expr.arg), assumptions) + +@OrthogonalPredicate.register(MatrixSlice) +def _(expr, assumptions): + if not expr.on_diag: + return None + else: + return ask(Q.orthogonal(expr.parent), assumptions) + +@OrthogonalPredicate.register(Factorization) +def _(expr, assumptions): + return _Factorization(Q.orthogonal, expr, assumptions) + + +# UnitaryPredicate + +@UnitaryPredicate.register(MatMul) +def _(expr, assumptions): + factor, mmul = expr.as_coeff_mmul() + if (all(ask(Q.unitary(arg), assumptions) for arg in mmul.args) and + abs(factor) == 1): + return True + if any(ask(Q.invertible(arg), assumptions) is False + for arg in mmul.args): + return False + +@UnitaryPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if int_exp: + return ask(Q.unitary(base), assumptions) + return None + +@UnitaryPredicate.register(MatrixSymbol) +def _(expr, assumptions): + if (not expr.is_square or + ask(Q.invertible(expr), assumptions) is False): + return False + if Q.unitary(expr) in conjuncts(assumptions): + return True + +@UnitaryPredicate.register_many(Inverse, Transpose) +def _(expr, assumptions): + return ask(Q.unitary(expr.arg), assumptions) + +@UnitaryPredicate.register(MatrixSlice) +def _(expr, assumptions): + if not expr.on_diag: + return None + else: + return ask(Q.unitary(expr.parent), assumptions) + +@UnitaryPredicate.register_many(DFT, Identity) +def _(expr, assumptions): + return True + +@UnitaryPredicate.register(ZeroMatrix) +def _(expr, assumptions): + return False + +@UnitaryPredicate.register(Factorization) +def _(expr, assumptions): + return _Factorization(Q.unitary, expr, assumptions) + + +# FullRankPredicate + +@FullRankPredicate.register(MatMul) +def _(expr, assumptions): + if all(ask(Q.fullrank(arg), assumptions) for arg in expr.args): + return True + +@FullRankPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if int_exp and ask(~Q.negative(exp), assumptions): + return ask(Q.fullrank(base), assumptions) + return None + +@FullRankPredicate.register(Identity) +def _(expr, assumptions): + return True + +@FullRankPredicate.register(ZeroMatrix) +def _(expr, assumptions): + return False + +@FullRankPredicate.register(OneMatrix) +def _(expr, assumptions): + return expr.shape[0] == 1 and expr.shape[1] == 1 + +@FullRankPredicate.register_many(Inverse, Transpose) +def _(expr, assumptions): + return ask(Q.fullrank(expr.arg), assumptions) + +@FullRankPredicate.register(MatrixSlice) +def _(expr, assumptions): + if ask(Q.orthogonal(expr.parent), assumptions): + return True + + +# PositiveDefinitePredicate + +@PositiveDefinitePredicate.register(MatMul) +def _(expr, assumptions): + factor, mmul = expr.as_coeff_mmul() + if (all(ask(Q.positive_definite(arg), assumptions) + for arg in mmul.args) and factor > 0): + return True + if (len(mmul.args) >= 2 + and mmul.args[0] == mmul.args[-1].T + and ask(Q.fullrank(mmul.args[0]), assumptions)): + return ask(Q.positive_definite( + MatMul(*mmul.args[1:-1])), assumptions) + +@PositiveDefinitePredicate.register(MatPow) +def _(expr, assumptions): + # a power of a positive definite matrix is positive definite + if ask(Q.positive_definite(expr.args[0]), assumptions): + return True + +@PositiveDefinitePredicate.register(MatAdd) +def _(expr, assumptions): + if all(ask(Q.positive_definite(arg), assumptions) + for arg in expr.args): + return True + +@PositiveDefinitePredicate.register(MatrixSymbol) +def _(expr, assumptions): + if not expr.is_square: + return False + if Q.positive_definite(expr) in conjuncts(assumptions): + return True + +@PositiveDefinitePredicate.register(Identity) +def _(expr, assumptions): + return True + +@PositiveDefinitePredicate.register(ZeroMatrix) +def _(expr, assumptions): + return False + +@PositiveDefinitePredicate.register(OneMatrix) +def _(expr, assumptions): + return expr.shape[0] == 1 and expr.shape[1] == 1 + +@PositiveDefinitePredicate.register_many(Inverse, Transpose) +def _(expr, assumptions): + return ask(Q.positive_definite(expr.arg), assumptions) + +@PositiveDefinitePredicate.register(MatrixSlice) +def _(expr, assumptions): + if not expr.on_diag: + return None + else: + return ask(Q.positive_definite(expr.parent), assumptions) + + +# UpperTriangularPredicate + +@UpperTriangularPredicate.register(MatMul) +def _(expr, assumptions): + factor, matrices = expr.as_coeff_matrices() + if all(ask(Q.upper_triangular(m), assumptions) for m in matrices): + return True + +@UpperTriangularPredicate.register(MatAdd) +def _(expr, assumptions): + if all(ask(Q.upper_triangular(arg), assumptions) for arg in expr.args): + return True + +@UpperTriangularPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if not int_exp: + return None + non_negative = ask(~Q.negative(exp), assumptions) + if (non_negative or non_negative == False + and ask(Q.invertible(base), assumptions)): + return ask(Q.upper_triangular(base), assumptions) + return None + +@UpperTriangularPredicate.register(MatrixSymbol) +def _(expr, assumptions): + if Q.upper_triangular(expr) in conjuncts(assumptions): + return True + +@UpperTriangularPredicate.register_many(Identity, ZeroMatrix) +def _(expr, assumptions): + return True + +@UpperTriangularPredicate.register(OneMatrix) +def _(expr, assumptions): + return expr.shape[0] == 1 and expr.shape[1] == 1 + +@UpperTriangularPredicate.register(Transpose) +def _(expr, assumptions): + return ask(Q.lower_triangular(expr.arg), assumptions) + +@UpperTriangularPredicate.register(Inverse) +def _(expr, assumptions): + return ask(Q.upper_triangular(expr.arg), assumptions) + +@UpperTriangularPredicate.register(MatrixSlice) +def _(expr, assumptions): + if not expr.on_diag: + return None + else: + return ask(Q.upper_triangular(expr.parent), assumptions) + +@UpperTriangularPredicate.register(Factorization) +def _(expr, assumptions): + return _Factorization(Q.upper_triangular, expr, assumptions) + +# LowerTriangularPredicate + +@LowerTriangularPredicate.register(MatMul) +def _(expr, assumptions): + factor, matrices = expr.as_coeff_matrices() + if all(ask(Q.lower_triangular(m), assumptions) for m in matrices): + return True + +@LowerTriangularPredicate.register(MatAdd) +def _(expr, assumptions): + if all(ask(Q.lower_triangular(arg), assumptions) for arg in expr.args): + return True + +@LowerTriangularPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if not int_exp: + return None + non_negative = ask(~Q.negative(exp), assumptions) + if (non_negative or non_negative == False + and ask(Q.invertible(base), assumptions)): + return ask(Q.lower_triangular(base), assumptions) + return None + +@LowerTriangularPredicate.register(MatrixSymbol) +def _(expr, assumptions): + if Q.lower_triangular(expr) in conjuncts(assumptions): + return True + +@LowerTriangularPredicate.register_many(Identity, ZeroMatrix) +def _(expr, assumptions): + return True + +@LowerTriangularPredicate.register(OneMatrix) +def _(expr, assumptions): + return expr.shape[0] == 1 and expr.shape[1] == 1 + +@LowerTriangularPredicate.register(Transpose) +def _(expr, assumptions): + return ask(Q.upper_triangular(expr.arg), assumptions) + +@LowerTriangularPredicate.register(Inverse) +def _(expr, assumptions): + return ask(Q.lower_triangular(expr.arg), assumptions) + +@LowerTriangularPredicate.register(MatrixSlice) +def _(expr, assumptions): + if not expr.on_diag: + return None + else: + return ask(Q.lower_triangular(expr.parent), assumptions) + +@LowerTriangularPredicate.register(Factorization) +def _(expr, assumptions): + return _Factorization(Q.lower_triangular, expr, assumptions) + + +# DiagonalPredicate + +def _is_empty_or_1x1(expr): + return expr.shape in ((0, 0), (1, 1)) + +@DiagonalPredicate.register(MatMul) +def _(expr, assumptions): + if _is_empty_or_1x1(expr): + return True + factor, matrices = expr.as_coeff_matrices() + if all(ask(Q.diagonal(m), assumptions) for m in matrices): + return True + +@DiagonalPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if not int_exp: + return None + non_negative = ask(~Q.negative(exp), assumptions) + if (non_negative or non_negative == False + and ask(Q.invertible(base), assumptions)): + return ask(Q.diagonal(base), assumptions) + return None + +@DiagonalPredicate.register(MatAdd) +def _(expr, assumptions): + if all(ask(Q.diagonal(arg), assumptions) for arg in expr.args): + return True + +@DiagonalPredicate.register(MatrixSymbol) +def _(expr, assumptions): + if _is_empty_or_1x1(expr): + return True + if Q.diagonal(expr) in conjuncts(assumptions): + return True + +@DiagonalPredicate.register(OneMatrix) +def _(expr, assumptions): + return expr.shape[0] == 1 and expr.shape[1] == 1 + +@DiagonalPredicate.register_many(Inverse, Transpose) +def _(expr, assumptions): + return ask(Q.diagonal(expr.arg), assumptions) + +@DiagonalPredicate.register(MatrixSlice) +def _(expr, assumptions): + if _is_empty_or_1x1(expr): + return True + if not expr.on_diag: + return None + else: + return ask(Q.diagonal(expr.parent), assumptions) + +@DiagonalPredicate.register_many(DiagonalMatrix, DiagMatrix, Identity, ZeroMatrix) +def _(expr, assumptions): + return True + +@DiagonalPredicate.register(Factorization) +def _(expr, assumptions): + return _Factorization(Q.diagonal, expr, assumptions) + + +# IntegerElementsPredicate + +def BM_elements(predicate, expr, assumptions): + """ Block Matrix elements. """ + return all(ask(predicate(b), assumptions) for b in expr.blocks) + +def MS_elements(predicate, expr, assumptions): + """ Matrix Slice elements. """ + return ask(predicate(expr.parent), assumptions) + +def MatMul_elements(matrix_predicate, scalar_predicate, expr, assumptions): + d = sift(expr.args, lambda x: isinstance(x, MatrixExpr)) + factors, matrices = d[False], d[True] + return fuzzy_and([ + test_closed_group(Basic(*factors), assumptions, scalar_predicate), + test_closed_group(Basic(*matrices), assumptions, matrix_predicate)]) + + +@IntegerElementsPredicate.register_many(Determinant, HadamardProduct, MatAdd, + Trace, Transpose) +def _(expr, assumptions): + return test_closed_group(expr, assumptions, Q.integer_elements) + +@IntegerElementsPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if not int_exp: + return None + if exp.is_negative == False: + return ask(Q.integer_elements(base), assumptions) + return None + +@IntegerElementsPredicate.register_many(Identity, OneMatrix, ZeroMatrix) +def _(expr, assumptions): + return True + +@IntegerElementsPredicate.register(MatMul) +def _(expr, assumptions): + return MatMul_elements(Q.integer_elements, Q.integer, expr, assumptions) + +@IntegerElementsPredicate.register(MatrixSlice) +def _(expr, assumptions): + return MS_elements(Q.integer_elements, expr, assumptions) + +@IntegerElementsPredicate.register(BlockMatrix) +def _(expr, assumptions): + return BM_elements(Q.integer_elements, expr, assumptions) + + +# RealElementsPredicate + +@RealElementsPredicate.register_many(Determinant, Factorization, HadamardProduct, + MatAdd, Trace, Transpose) +def _(expr, assumptions): + return test_closed_group(expr, assumptions, Q.real_elements) + +@RealElementsPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if not int_exp: + return None + non_negative = ask(~Q.negative(exp), assumptions) + if (non_negative or non_negative == False + and ask(Q.invertible(base), assumptions)): + return ask(Q.real_elements(base), assumptions) + return None + +@RealElementsPredicate.register(MatMul) +def _(expr, assumptions): + return MatMul_elements(Q.real_elements, Q.real, expr, assumptions) + +@RealElementsPredicate.register(MatrixSlice) +def _(expr, assumptions): + return MS_elements(Q.real_elements, expr, assumptions) + +@RealElementsPredicate.register(BlockMatrix) +def _(expr, assumptions): + return BM_elements(Q.real_elements, expr, assumptions) + + +# ComplexElementsPredicate + +@ComplexElementsPredicate.register_many(Determinant, Factorization, HadamardProduct, + Inverse, MatAdd, Trace, Transpose) +def _(expr, assumptions): + return test_closed_group(expr, assumptions, Q.complex_elements) + +@ComplexElementsPredicate.register(MatPow) +def _(expr, assumptions): + # only for integer powers + base, exp = expr.args + int_exp = ask(Q.integer(exp), assumptions) + if not int_exp: + return None + non_negative = ask(~Q.negative(exp), assumptions) + if (non_negative or non_negative == False + and ask(Q.invertible(base), assumptions)): + return ask(Q.complex_elements(base), assumptions) + return None + +@ComplexElementsPredicate.register(MatMul) +def _(expr, assumptions): + return MatMul_elements(Q.complex_elements, Q.complex, expr, assumptions) + +@ComplexElementsPredicate.register(MatrixSlice) +def _(expr, assumptions): + return MS_elements(Q.complex_elements, expr, assumptions) + +@ComplexElementsPredicate.register(BlockMatrix) +def _(expr, assumptions): + return BM_elements(Q.complex_elements, expr, assumptions) + +@ComplexElementsPredicate.register(DFT) +def _(expr, assumptions): + return True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/ntheory.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/ntheory.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe63ba6467ea6863c6112c5e35bb3a78191a23e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/ntheory.py @@ -0,0 +1,279 @@ +""" +Handlers for keys related to number theory: prime, even, odd, etc. +""" + +from sympy.assumptions import Q, ask +from sympy.core import Add, Basic, Expr, Float, Mul, Pow, S +from sympy.core.numbers import (ImaginaryUnit, Infinity, Integer, NaN, + NegativeInfinity, NumberSymbol, Rational, int_valued) +from sympy.functions import Abs, im, re +from sympy.ntheory import isprime + +from sympy.multipledispatch import MDNotImplementedError + +from ..predicates.ntheory import (PrimePredicate, CompositePredicate, + EvenPredicate, OddPredicate) + + +# PrimePredicate + +def _PrimePredicate_number(expr, assumptions): + # helper method + exact = not expr.atoms(Float) + try: + i = int(expr.round()) + if (expr - i).equals(0) is False: + raise TypeError + except TypeError: + return False + if exact: + return isprime(i) + # when not exact, we won't give a True or False + # since the number represents an approximate value + +@PrimePredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_prime + if ret is None: + raise MDNotImplementedError + return ret + +@PrimePredicate.register(Basic) +def _(expr, assumptions): + if expr.is_number: + return _PrimePredicate_number(expr, assumptions) + +@PrimePredicate.register(Mul) +def _(expr, assumptions): + if expr.is_number: + return _PrimePredicate_number(expr, assumptions) + for arg in expr.args: + if not ask(Q.integer(arg), assumptions): + return None + for arg in expr.args: + if arg.is_number and arg.is_composite: + return False + +@PrimePredicate.register(Pow) +def _(expr, assumptions): + """ + Integer**Integer -> !Prime + """ + if expr.is_number: + return _PrimePredicate_number(expr, assumptions) + if ask(Q.integer(expr.exp), assumptions) and \ + ask(Q.integer(expr.base), assumptions): + prime_base = ask(Q.prime(expr.base), assumptions) + if prime_base is False: + return False + is_exp_one = ask(Q.eq(expr.exp, 1), assumptions) + if is_exp_one is False: + return False + if prime_base is True and is_exp_one is True: + return True + +@PrimePredicate.register(Integer) +def _(expr, assumptions): + return isprime(expr) + +@PrimePredicate.register_many(Rational, Infinity, NegativeInfinity, ImaginaryUnit) +def _(expr, assumptions): + return False + +@PrimePredicate.register(Float) +def _(expr, assumptions): + return _PrimePredicate_number(expr, assumptions) + +@PrimePredicate.register(NumberSymbol) +def _(expr, assumptions): + return _PrimePredicate_number(expr, assumptions) + +@PrimePredicate.register(NaN) +def _(expr, assumptions): + return None + + +# CompositePredicate + +@CompositePredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_composite + if ret is None: + raise MDNotImplementedError + return ret + +@CompositePredicate.register(Basic) +def _(expr, assumptions): + _positive = ask(Q.positive(expr), assumptions) + if _positive: + _integer = ask(Q.integer(expr), assumptions) + if _integer: + _prime = ask(Q.prime(expr), assumptions) + if _prime is None: + return + # Positive integer which is not prime is not + # necessarily composite + _is_one = ask(Q.eq(expr, 1), assumptions) + if _is_one: + return False + if _is_one is None: + return None + return not _prime + else: + return _integer + else: + return _positive + + +# EvenPredicate + +def _EvenPredicate_number(expr, assumptions): + # helper method + if isinstance(expr, (float, Float)): + if int_valued(expr): + return None + return False + try: + i = int(expr.round()) + except TypeError: + return False + if not (expr - i).equals(0): + return False + return i % 2 == 0 + +@EvenPredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_even + if ret is None: + raise MDNotImplementedError + return ret + +@EvenPredicate.register(Basic) +def _(expr, assumptions): + if expr.is_number: + return _EvenPredicate_number(expr, assumptions) + +@EvenPredicate.register(Mul) +def _(expr, assumptions): + """ + Even * Integer -> Even + Even * Odd -> Even + Integer * Odd -> ? + Odd * Odd -> Odd + Even * Even -> Even + Integer * Integer -> Even if Integer + Integer = Odd + otherwise -> ? + """ + if expr.is_number: + return _EvenPredicate_number(expr, assumptions) + even, odd, irrational, acc = False, 0, False, 1 + for arg in expr.args: + # check for all integers and at least one even + if ask(Q.integer(arg), assumptions): + if ask(Q.even(arg), assumptions): + even = True + elif ask(Q.odd(arg), assumptions): + odd += 1 + elif not even and acc != 1: + if ask(Q.odd(acc + arg), assumptions): + even = True + elif ask(Q.irrational(arg), assumptions): + # one irrational makes the result False + # two makes it undefined + if irrational: + break + irrational = True + else: + break + acc = arg + else: + if irrational: + return False + if even: + return True + if odd == len(expr.args): + return False + +@EvenPredicate.register(Add) +def _(expr, assumptions): + """ + Even + Odd -> Odd + Even + Even -> Even + Odd + Odd -> Even + + """ + if expr.is_number: + return _EvenPredicate_number(expr, assumptions) + _result = True + for arg in expr.args: + if ask(Q.even(arg), assumptions): + pass + elif ask(Q.odd(arg), assumptions): + _result = not _result + else: + break + else: + return _result + +@EvenPredicate.register(Pow) +def _(expr, assumptions): + if expr.is_number: + return _EvenPredicate_number(expr, assumptions) + if ask(Q.integer(expr.exp), assumptions): + if ask(Q.positive(expr.exp), assumptions): + return ask(Q.even(expr.base), assumptions) + elif ask(~Q.negative(expr.exp) & Q.odd(expr.base), assumptions): + return False + elif expr.base is S.NegativeOne: + return False + +@EvenPredicate.register(Integer) +def _(expr, assumptions): + return not bool(expr.p & 1) + +@EvenPredicate.register_many(Rational, Infinity, NegativeInfinity, ImaginaryUnit) +def _(expr, assumptions): + return False + +@EvenPredicate.register(NumberSymbol) +def _(expr, assumptions): + return _EvenPredicate_number(expr, assumptions) + +@EvenPredicate.register(Abs) +def _(expr, assumptions): + if ask(Q.real(expr.args[0]), assumptions): + return ask(Q.even(expr.args[0]), assumptions) + +@EvenPredicate.register(re) +def _(expr, assumptions): + if ask(Q.real(expr.args[0]), assumptions): + return ask(Q.even(expr.args[0]), assumptions) + +@EvenPredicate.register(im) +def _(expr, assumptions): + if ask(Q.real(expr.args[0]), assumptions): + return True + +@EvenPredicate.register(NaN) +def _(expr, assumptions): + return None + + +# OddPredicate + +@OddPredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_odd + if ret is None: + raise MDNotImplementedError + return ret + +@OddPredicate.register(Basic) +def _(expr, assumptions): + _integer = ask(Q.integer(expr), assumptions) + if _integer: + _even = ask(Q.even(expr), assumptions) + if _even is None: + return None + return not _even + return _integer diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/order.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/order.py new file mode 100644 index 0000000000000000000000000000000000000000..24a8bae7f30777f62a1bec0579d58b9875143679 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/order.py @@ -0,0 +1,440 @@ +""" +Handlers related to order relations: positive, negative, etc. +""" + +from sympy.assumptions import Q, ask +from sympy.core import Add, Basic, Expr, Mul, Pow, S +from sympy.core.logic import fuzzy_not, fuzzy_and, fuzzy_or +from sympy.core.numbers import E, ImaginaryUnit, NaN, I, pi +from sympy.functions import Abs, acos, acot, asin, atan, exp, factorial, log +from sympy.matrices import Determinant, Trace +from sympy.matrices.expressions.matexpr import MatrixElement + +from sympy.multipledispatch import MDNotImplementedError + +from ..predicates.order import (NegativePredicate, NonNegativePredicate, + NonZeroPredicate, ZeroPredicate, NonPositivePredicate, PositivePredicate, + ExtendedNegativePredicate, ExtendedNonNegativePredicate, + ExtendedNonPositivePredicate, ExtendedNonZeroPredicate, + ExtendedPositivePredicate,) + + +# NegativePredicate + +def _NegativePredicate_number(expr, assumptions): + r, i = expr.as_real_imag() + + if r == S.NaN or i == S.NaN: + return None + + # If the imaginary part can symbolically be shown to be zero then + # we just evaluate the real part; otherwise we evaluate the imaginary + # part to see if it actually evaluates to zero and if it does then + # we make the comparison between the real part and zero. + if not i: + r = r.evalf(2) + if r._prec != 1: + return r < 0 + else: + i = i.evalf(2) + if i._prec != 1: + if i != 0: + return False + r = r.evalf(2) + if r._prec != 1: + return r < 0 + +@NegativePredicate.register(Basic) +def _(expr, assumptions): + if expr.is_number: + return _NegativePredicate_number(expr, assumptions) + +@NegativePredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_negative + if ret is None: + raise MDNotImplementedError + return ret + +@NegativePredicate.register(Add) +def _(expr, assumptions): + """ + Positive + Positive -> Positive, + Negative + Negative -> Negative + """ + if expr.is_number: + return _NegativePredicate_number(expr, assumptions) + + r = ask(Q.real(expr), assumptions) + if r is not True: + return r + + nonpos = 0 + for arg in expr.args: + if ask(Q.negative(arg), assumptions) is not True: + if ask(Q.positive(arg), assumptions) is False: + nonpos += 1 + else: + break + else: + if nonpos < len(expr.args): + return True + +@NegativePredicate.register(Mul) +def _(expr, assumptions): + if expr.is_number: + return _NegativePredicate_number(expr, assumptions) + result = None + for arg in expr.args: + if result is None: + result = False + if ask(Q.negative(arg), assumptions): + result = not result + elif ask(Q.positive(arg), assumptions): + pass + else: + return + return result + +@NegativePredicate.register(Pow) +def _(expr, assumptions): + """ + Real ** Even -> NonNegative + Real ** Odd -> same_as_base + NonNegative ** Positive -> NonNegative + """ + if expr.base == E: + # Exponential is always positive: + if ask(Q.real(expr.exp), assumptions): + return False + return + + if expr.is_number: + return _NegativePredicate_number(expr, assumptions) + if ask(Q.real(expr.base), assumptions): + if ask(Q.positive(expr.base), assumptions): + if ask(Q.real(expr.exp), assumptions): + return False + if ask(Q.even(expr.exp), assumptions): + return False + if ask(Q.odd(expr.exp), assumptions): + return ask(Q.negative(expr.base), assumptions) + +@NegativePredicate.register_many(Abs, ImaginaryUnit) +def _(expr, assumptions): + return False + +@NegativePredicate.register(exp) +def _(expr, assumptions): + if ask(Q.real(expr.exp), assumptions): + return False + raise MDNotImplementedError + + +# NonNegativePredicate + +@NonNegativePredicate.register(Basic) +def _(expr, assumptions): + if expr.is_number: + notnegative = fuzzy_not(_NegativePredicate_number(expr, assumptions)) + if notnegative: + return ask(Q.real(expr), assumptions) + else: + return notnegative + +@NonNegativePredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_nonnegative + if ret is None: + raise MDNotImplementedError + return ret + + +# NonZeroPredicate + +@NonZeroPredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_nonzero + if ret is None: + raise MDNotImplementedError + return ret + +@NonZeroPredicate.register(Basic) +def _(expr, assumptions): + if ask(Q.real(expr)) is False: + return False + if expr.is_number: + # if there are no symbols just evalf + i = expr.evalf(2) + def nonz(i): + if i._prec != 1: + return i != 0 + return fuzzy_or(nonz(i) for i in i.as_real_imag()) + +@NonZeroPredicate.register(Add) +def _(expr, assumptions): + if all(ask(Q.positive(x), assumptions) for x in expr.args) \ + or all(ask(Q.negative(x), assumptions) for x in expr.args): + return True + +@NonZeroPredicate.register(Mul) +def _(expr, assumptions): + for arg in expr.args: + result = ask(Q.nonzero(arg), assumptions) + if result: + continue + return result + return True + +@NonZeroPredicate.register(Pow) +def _(expr, assumptions): + return ask(Q.nonzero(expr.base), assumptions) + +@NonZeroPredicate.register(Abs) +def _(expr, assumptions): + return ask(Q.nonzero(expr.args[0]), assumptions) + +@NonZeroPredicate.register(NaN) +def _(expr, assumptions): + return None + + +# ZeroPredicate + +@ZeroPredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_zero + if ret is None: + raise MDNotImplementedError + return ret + +@ZeroPredicate.register(Basic) +def _(expr, assumptions): + return fuzzy_and([fuzzy_not(ask(Q.nonzero(expr), assumptions)), + ask(Q.real(expr), assumptions)]) + +@ZeroPredicate.register(Mul) +def _(expr, assumptions): + # TODO: This should be deducible from the nonzero handler + return fuzzy_or(ask(Q.zero(arg), assumptions) for arg in expr.args) + + +# NonPositivePredicate + +@NonPositivePredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_nonpositive + if ret is None: + raise MDNotImplementedError + return ret + +@NonPositivePredicate.register(Basic) +def _(expr, assumptions): + if expr.is_number: + notpositive = fuzzy_not(_PositivePredicate_number(expr, assumptions)) + if notpositive: + return ask(Q.real(expr), assumptions) + else: + return notpositive + + +# PositivePredicate + +def _PositivePredicate_number(expr, assumptions): + r, i = expr.as_real_imag() + # If the imaginary part can symbolically be shown to be zero then + # we just evaluate the real part; otherwise we evaluate the imaginary + # part to see if it actually evaluates to zero and if it does then + # we make the comparison between the real part and zero. + if not i: + r = r.evalf(2) + if r._prec != 1: + return r > 0 + else: + i = i.evalf(2) + if i._prec != 1: + if i != 0: + return False + r = r.evalf(2) + if r._prec != 1: + return r > 0 + +@PositivePredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_positive + if ret is None: + raise MDNotImplementedError + return ret + +@PositivePredicate.register(Basic) +def _(expr, assumptions): + if expr.is_number: + return _PositivePredicate_number(expr, assumptions) + +@PositivePredicate.register(Mul) +def _(expr, assumptions): + if expr.is_number: + return _PositivePredicate_number(expr, assumptions) + result = True + for arg in expr.args: + if ask(Q.positive(arg), assumptions): + continue + elif ask(Q.negative(arg), assumptions): + result = result ^ True + else: + return + return result + +@PositivePredicate.register(Add) +def _(expr, assumptions): + if expr.is_number: + return _PositivePredicate_number(expr, assumptions) + + r = ask(Q.real(expr), assumptions) + if r is not True: + return r + + nonneg = 0 + for arg in expr.args: + if ask(Q.positive(arg), assumptions) is not True: + if ask(Q.negative(arg), assumptions) is False: + nonneg += 1 + else: + break + else: + if nonneg < len(expr.args): + return True + +@PositivePredicate.register(Pow) +def _(expr, assumptions): + if expr.base == E: + if ask(Q.real(expr.exp), assumptions): + return True + if ask(Q.imaginary(expr.exp), assumptions): + return ask(Q.even(expr.exp/(I*pi)), assumptions) + return + + if expr.is_number: + return _PositivePredicate_number(expr, assumptions) + if ask(Q.positive(expr.base), assumptions): + if ask(Q.real(expr.exp), assumptions): + return True + if ask(Q.negative(expr.base), assumptions): + if ask(Q.even(expr.exp), assumptions): + return True + if ask(Q.odd(expr.exp), assumptions): + return False + +@PositivePredicate.register(exp) +def _(expr, assumptions): + if ask(Q.real(expr.exp), assumptions): + return True + if ask(Q.imaginary(expr.exp), assumptions): + return ask(Q.even(expr.exp/(I*pi)), assumptions) + +@PositivePredicate.register(log) +def _(expr, assumptions): + r = ask(Q.real(expr.args[0]), assumptions) + if r is not True: + return r + if ask(Q.positive(expr.args[0] - 1), assumptions): + return True + if ask(Q.negative(expr.args[0] - 1), assumptions): + return False + +@PositivePredicate.register(factorial) +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.integer(x) & Q.positive(x), assumptions): + return True + +@PositivePredicate.register(ImaginaryUnit) +def _(expr, assumptions): + return False + +@PositivePredicate.register(Abs) +def _(expr, assumptions): + return ask(Q.nonzero(expr), assumptions) + +@PositivePredicate.register(Trace) +def _(expr, assumptions): + if ask(Q.positive_definite(expr.arg), assumptions): + return True + +@PositivePredicate.register(Determinant) +def _(expr, assumptions): + if ask(Q.positive_definite(expr.arg), assumptions): + return True + +@PositivePredicate.register(MatrixElement) +def _(expr, assumptions): + if (expr.i == expr.j + and ask(Q.positive_definite(expr.parent), assumptions)): + return True + +@PositivePredicate.register(atan) +def _(expr, assumptions): + return ask(Q.positive(expr.args[0]), assumptions) + +@PositivePredicate.register(asin) +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.positive(x) & Q.nonpositive(x - 1), assumptions): + return True + if ask(Q.negative(x) & Q.nonnegative(x + 1), assumptions): + return False + +@PositivePredicate.register(acos) +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.nonpositive(x - 1) & Q.nonnegative(x + 1), assumptions): + return True + +@PositivePredicate.register(acot) +def _(expr, assumptions): + return ask(Q.real(expr.args[0]), assumptions) + +@PositivePredicate.register(NaN) +def _(expr, assumptions): + return None + + +# ExtendedNegativePredicate + +@ExtendedNegativePredicate.register(object) +def _(expr, assumptions): + return ask(Q.negative(expr) | Q.negative_infinite(expr), assumptions) + + +# ExtendedPositivePredicate + +@ExtendedPositivePredicate.register(object) +def _(expr, assumptions): + return ask(Q.positive(expr) | Q.positive_infinite(expr), assumptions) + + +# ExtendedNonZeroPredicate + +@ExtendedNonZeroPredicate.register(object) +def _(expr, assumptions): + return ask( + Q.negative_infinite(expr) | Q.negative(expr) | Q.positive(expr) | Q.positive_infinite(expr), + assumptions) + + +# ExtendedNonPositivePredicate + +@ExtendedNonPositivePredicate.register(object) +def _(expr, assumptions): + return ask( + Q.negative_infinite(expr) | Q.negative(expr) | Q.zero(expr), + assumptions) + + +# ExtendedNonNegativePredicate + +@ExtendedNonNegativePredicate.register(object) +def _(expr, assumptions): + return ask( + Q.zero(expr) | Q.positive(expr) | Q.positive_infinite(expr), + assumptions) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/sets.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/sets.py new file mode 100644 index 0000000000000000000000000000000000000000..7a13ed9bf99c5b0ffc4f32fd55cb60c2c15ab836 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/handlers/sets.py @@ -0,0 +1,816 @@ +""" +Handlers for predicates related to set membership: integer, rational, etc. +""" + +from sympy.assumptions import Q, ask +from sympy.core import Add, Basic, Expr, Mul, Pow, S +from sympy.core.numbers import (AlgebraicNumber, ComplexInfinity, Exp1, Float, + GoldenRatio, ImaginaryUnit, Infinity, Integer, NaN, NegativeInfinity, + Number, NumberSymbol, Pi, pi, Rational, TribonacciConstant, E) +from sympy.core.logic import fuzzy_bool +from sympy.functions import (Abs, acos, acot, asin, atan, cos, cot, exp, im, + log, re, sin, tan) +from sympy.core.numbers import I +from sympy.core.relational import Eq +from sympy.functions.elementary.complexes import conjugate +from sympy.matrices import Determinant, MatrixBase, Trace +from sympy.matrices.expressions.matexpr import MatrixElement + +from sympy.multipledispatch import MDNotImplementedError + +from .common import test_closed_group, ask_all, ask_any +from ..predicates.sets import (IntegerPredicate, RationalPredicate, + IrrationalPredicate, RealPredicate, ExtendedRealPredicate, + HermitianPredicate, ComplexPredicate, ImaginaryPredicate, + AntihermitianPredicate, AlgebraicPredicate) + + +# IntegerPredicate + +def _IntegerPredicate_number(expr, assumptions): + # helper function + try: + i = int(expr.round()) + if not (expr - i).equals(0): + raise TypeError + return True + except TypeError: + return False + +@IntegerPredicate.register_many(int, Integer) # type:ignore +def _(expr, assumptions): + return True + +@IntegerPredicate.register_many(Exp1, GoldenRatio, ImaginaryUnit, Infinity, + NegativeInfinity, Pi, Rational, TribonacciConstant) +def _(expr, assumptions): + return False + +@IntegerPredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_integer + if ret is None: + raise MDNotImplementedError + return ret + +@IntegerPredicate.register(Add) +def _(expr, assumptions): + """ + * Integer + Integer -> Integer + * Integer + !Integer -> !Integer + * !Integer + !Integer -> ? + """ + if expr.is_number: + return _IntegerPredicate_number(expr, assumptions) + return test_closed_group(expr, assumptions, Q.integer) + +@IntegerPredicate.register(Pow) +def _(expr,assumptions): + if expr.is_number: + return _IntegerPredicate_number(expr, assumptions) + if ask_all(~Q.zero(expr.base), Q.finite(expr.base), Q.zero(expr.exp), assumptions=assumptions): + return True + if ask_all(Q.integer(expr.base), Q.integer(expr.exp), assumptions=assumptions): + if ask_any(Q.positive(expr.exp), Q.nonnegative(expr.exp) & ~Q.zero(expr.base), Q.zero(expr.base-1), Q.zero(expr.base+1), assumptions=assumptions): + return True + +@IntegerPredicate.register(Mul) +def _(expr, assumptions): + """ + * Integer*Integer -> Integer + * Integer*Irrational -> !Integer + * Odd/Even -> !Integer + * Integer*Rational -> ? + """ + if expr.is_number: + return _IntegerPredicate_number(expr, assumptions) + _output = True + for arg in expr.args: + if not ask(Q.integer(arg), assumptions): + if arg.is_Rational: + if arg.q == 2: + return ask(Q.even(2*expr), assumptions) + if ~(arg.q & 1): + return None + elif ask(Q.irrational(arg), assumptions): + if _output: + _output = False + else: + return + else: + return + + return _output + +@IntegerPredicate.register(Abs) +def _(expr, assumptions): + if ask(Q.integer(expr.args[0]), assumptions): + return True + +@IntegerPredicate.register_many(Determinant, MatrixElement, Trace) +def _(expr, assumptions): + return ask(Q.integer_elements(expr.args[0]), assumptions) + + +# RationalPredicate + +@RationalPredicate.register(Rational) +def _(expr, assumptions): + return True + +@RationalPredicate.register(Float) +def _(expr, assumptions): + return None + +@RationalPredicate.register_many(Exp1, GoldenRatio, ImaginaryUnit, Infinity, + NegativeInfinity, Pi, TribonacciConstant) +def _(expr, assumptions): + return False + +@RationalPredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_rational + if ret is None: + raise MDNotImplementedError + return ret + +@RationalPredicate.register_many(Add, Mul) +def _(expr, assumptions): + """ + * Rational + Rational -> Rational + * Rational + !Rational -> !Rational + * !Rational + !Rational -> ? + """ + if expr.is_number: + if expr.as_real_imag()[1]: + return False + return test_closed_group(expr, assumptions, Q.rational) + +@RationalPredicate.register(Pow) +def _(expr, assumptions): + """ + * Rational ** Integer -> Rational + * Irrational ** Rational -> Irrational + * Rational ** Irrational -> ? + """ + if expr.base == E: + x = expr.exp + if ask(Q.rational(x), assumptions): + return ask(Q.zero(x), assumptions) + return + + is_exp_integer = ask(Q.integer(expr.exp), assumptions) + if is_exp_integer: + is_base_rational = ask(Q.rational(expr.base),assumptions) + if is_base_rational: + is_base_zero = ask(Q.zero(expr.base),assumptions) + if is_base_zero is False: + return True + if is_base_zero and ask(Q.positive(expr.exp)): + return True + if ask(Q.algebraic(expr.base),assumptions) is False: + return ask(Q.zero(expr.exp), assumptions) + if ask(Q.irrational(expr.base),assumptions) and ask(Q.eq(expr.exp,-1)): + return False + return + elif ask(Q.rational(expr.exp), assumptions): + if ask(Q.prime(expr.base), assumptions) and is_exp_integer is False: + return False + if ask(Q.zero(expr.base)) and ask(Q.positive(expr.exp)): + return True + if ask(Q.eq(expr.base,1)): + return True + +@RationalPredicate.register_many(asin, atan, cos, sin, tan) +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.rational(x), assumptions): + return ask(~Q.nonzero(x), assumptions) + +@RationalPredicate.register(exp) +def _(expr, assumptions): + x = expr.exp + if ask(Q.rational(x), assumptions): + return ask(~Q.nonzero(x), assumptions) + +@RationalPredicate.register_many(acot, cot) +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.rational(x), assumptions): + return False + +@RationalPredicate.register_many(acos, log) +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.rational(x), assumptions): + return ask(~Q.nonzero(x - 1), assumptions) + + +# IrrationalPredicate + +@IrrationalPredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_irrational + if ret is None: + raise MDNotImplementedError + return ret + +@IrrationalPredicate.register(Basic) +def _(expr, assumptions): + _real = ask(Q.real(expr), assumptions) + if _real: + _rational = ask(Q.rational(expr), assumptions) + if _rational is None: + return None + return not _rational + else: + return _real + + +# RealPredicate + +def _RealPredicate_number(expr, assumptions): + # let as_real_imag() work first since the expression may + # be simpler to evaluate + i = expr.as_real_imag()[1].evalf(2) + if i._prec != 1: + return not i + # allow None to be returned if we couldn't show for sure + # that i was 0 + +@RealPredicate.register_many(Abs, Exp1, Float, GoldenRatio, im, Pi, Rational, + re, TribonacciConstant) +def _(expr, assumptions): + return True + +@RealPredicate.register_many(ImaginaryUnit, Infinity, NegativeInfinity) +def _(expr, assumptions): + return False + +@RealPredicate.register(Expr) +def _(expr, assumptions): + ret = expr.is_real + if ret is None: + raise MDNotImplementedError + return ret + +@RealPredicate.register(Add) +def _(expr, assumptions): + """ + * Real + Real -> Real + * Real + (Complex & !Real) -> !Real + """ + if expr.is_number: + return _RealPredicate_number(expr, assumptions) + return test_closed_group(expr, assumptions, Q.real) + +@RealPredicate.register(Mul) +def _(expr, assumptions): + """ + * Real*Real -> Real + * Real*Imaginary -> !Real + * Imaginary*Imaginary -> Real + """ + if expr.is_number: + return _RealPredicate_number(expr, assumptions) + result = True + for arg in expr.args: + if ask(Q.real(arg), assumptions): + pass + elif ask(Q.imaginary(arg), assumptions): + result = result ^ True + else: + break + else: + return result + +@RealPredicate.register(Pow) +def _(expr, assumptions): + """ + * Real**Integer -> Real + * Positive**Real -> Real + * Negative**Real -> ? + * Real**(Integer/Even) -> Real if base is nonnegative + * Real**(Integer/Odd) -> Real + * Imaginary**(Integer/Even) -> Real + * Imaginary**(Integer/Odd) -> not Real + * Imaginary**Real -> ? since Real could be 0 (giving real) + or 1 (giving imaginary) + * b**Imaginary -> Real if log(b) is imaginary and b != 0 + and exponent != integer multiple of + I*pi/log(b) + * Real**Real -> ? e.g. sqrt(-1) is imaginary and + sqrt(2) is not + """ + if expr.is_number: + return _RealPredicate_number(expr, assumptions) + + if expr.base == E: + return ask( + Q.integer(expr.exp/I/pi) | Q.real(expr.exp), assumptions + ) + + if expr.base.func == exp or (expr.base.is_Pow and expr.base.base == E): + if ask(Q.imaginary(expr.base.exp), assumptions): + if ask(Q.imaginary(expr.exp), assumptions): + return True + # If the i = (exp's arg)/(I*pi) is an integer or half-integer + # multiple of I*pi then 2*i will be an integer. In addition, + # exp(i*I*pi) = (-1)**i so the overall realness of the expr + # can be determined by replacing exp(i*I*pi) with (-1)**i. + i = expr.base.exp/I/pi + if ask(Q.integer(2*i), assumptions): + return ask(Q.real((S.NegativeOne**i)**expr.exp), assumptions) + return + + if ask(Q.imaginary(expr.base), assumptions): + if ask(Q.integer(expr.exp), assumptions): + odd = ask(Q.odd(expr.exp), assumptions) + if odd is not None: + return not odd + return + + if ask(Q.imaginary(expr.exp), assumptions): + imlog = ask(Q.imaginary(log(expr.base)), assumptions) + if imlog is not None: + # I**i -> real, log(I) is imag; + # (2*I)**i -> complex, log(2*I) is not imag + return imlog + + if ask(Q.real(expr.base), assumptions): + if ask(Q.real(expr.exp), assumptions): + if ask(Q.zero(expr.base), assumptions) is not False: + if ask(Q.positive(expr.exp), assumptions): + return True + return + if expr.exp.is_Rational and \ + ask(Q.even(expr.exp.q), assumptions): + return ask(Q.positive(expr.base), assumptions) + elif ask(Q.integer(expr.exp), assumptions): + return True + elif ask(Q.positive(expr.base), assumptions): + return True + +@RealPredicate.register_many(cos, sin) +def _(expr, assumptions): + if ask(Q.real(expr.args[0]), assumptions): + return True + +@RealPredicate.register(exp) +def _(expr, assumptions): + return ask( + Q.integer(expr.exp/I/pi) | Q.real(expr.exp), assumptions + ) + +@RealPredicate.register(log) +def _(expr, assumptions): + return ask(Q.positive(expr.args[0]), assumptions) + +@RealPredicate.register_many(Determinant, MatrixElement, Trace) +def _(expr, assumptions): + return ask(Q.real_elements(expr.args[0]), assumptions) + + +# ExtendedRealPredicate + +@ExtendedRealPredicate.register(object) +def _(expr, assumptions): + return ask(Q.negative_infinite(expr) + | Q.negative(expr) + | Q.zero(expr) + | Q.positive(expr) + | Q.positive_infinite(expr), + assumptions) + +@ExtendedRealPredicate.register_many(Infinity, NegativeInfinity) +def _(expr, assumptions): + return True + +@ExtendedRealPredicate.register_many(Add, Mul, Pow) # type:ignore +def _(expr, assumptions): + return test_closed_group(expr, assumptions, Q.extended_real) + + +# HermitianPredicate + +@HermitianPredicate.register(object) # type:ignore +def _(expr, assumptions): + if isinstance(expr, MatrixBase): + return None + return ask(Q.real(expr), assumptions) + +@HermitianPredicate.register(Add) # type:ignore +def _(expr, assumptions): + """ + * Hermitian + Hermitian -> Hermitian + * Hermitian + !Hermitian -> !Hermitian + """ + if expr.is_number: + raise MDNotImplementedError + return test_closed_group(expr, assumptions, Q.hermitian) + +@HermitianPredicate.register(Mul) # type:ignore +def _(expr, assumptions): + """ + As long as there is at most only one noncommutative term: + + * Hermitian*Hermitian -> Hermitian + * Hermitian*Antihermitian -> !Hermitian + * Antihermitian*Antihermitian -> Hermitian + """ + if expr.is_number: + raise MDNotImplementedError + nccount = 0 + result = True + for arg in expr.args: + if ask(Q.antihermitian(arg), assumptions): + result = result ^ True + elif not ask(Q.hermitian(arg), assumptions): + break + if ask(~Q.commutative(arg), assumptions): + nccount += 1 + if nccount > 1: + break + else: + return result + +@HermitianPredicate.register(Pow) # type:ignore +def _(expr, assumptions): + """ + * Hermitian**Integer -> Hermitian + """ + if expr.is_number: + raise MDNotImplementedError + if expr.base == E: + if ask(Q.hermitian(expr.exp), assumptions): + return True + raise MDNotImplementedError + if ask(Q.hermitian(expr.base), assumptions): + if ask(Q.integer(expr.exp), assumptions): + return True + raise MDNotImplementedError + +@HermitianPredicate.register_many(cos, sin) # type:ignore +def _(expr, assumptions): + if ask(Q.hermitian(expr.args[0]), assumptions): + return True + raise MDNotImplementedError + +@HermitianPredicate.register(exp) # type:ignore +def _(expr, assumptions): + if ask(Q.hermitian(expr.exp), assumptions): + return True + raise MDNotImplementedError + +@HermitianPredicate.register(MatrixBase) # type:ignore +def _(mat, assumptions): + rows, cols = mat.shape + ret_val = True + for i in range(rows): + for j in range(i, cols): + cond = fuzzy_bool(Eq(mat[i, j], conjugate(mat[j, i]))) + if cond is None: + ret_val = None + if cond == False: + return False + if ret_val is None: + raise MDNotImplementedError + return ret_val + + +# ComplexPredicate + +@ComplexPredicate.register_many(Abs, cos, exp, im, ImaginaryUnit, log, Number, # type:ignore + NumberSymbol, re, sin) +def _(expr, assumptions): + return True + +@ComplexPredicate.register_many(Infinity, NegativeInfinity) # type:ignore +def _(expr, assumptions): + return False + +@ComplexPredicate.register(Expr) # type:ignore +def _(expr, assumptions): + ret = expr.is_complex + if ret is None: + raise MDNotImplementedError + return ret + +@ComplexPredicate.register_many(Add, Mul) # type:ignore +def _(expr, assumptions): + return test_closed_group(expr, assumptions, Q.complex) + +@ComplexPredicate.register(Pow) # type:ignore +def _(expr, assumptions): + if expr.base == E: + return True + return test_closed_group(expr, assumptions, Q.complex) + +@ComplexPredicate.register_many(Determinant, MatrixElement, Trace) # type:ignore +def _(expr, assumptions): + return ask(Q.complex_elements(expr.args[0]), assumptions) + +@ComplexPredicate.register(NaN) # type:ignore +def _(expr, assumptions): + return None + + +# ImaginaryPredicate + +def _Imaginary_number(expr, assumptions): + # let as_real_imag() work first since the expression may + # be simpler to evaluate + r = expr.as_real_imag()[0].evalf(2) + if r._prec != 1: + return not r + # allow None to be returned if we couldn't show for sure + # that r was 0 + +@ImaginaryPredicate.register(ImaginaryUnit) # type:ignore +def _(expr, assumptions): + return True + +@ImaginaryPredicate.register(Expr) # type:ignore +def _(expr, assumptions): + ret = expr.is_imaginary + if ret is None: + raise MDNotImplementedError + return ret + +@ImaginaryPredicate.register(Add) # type:ignore +def _(expr, assumptions): + """ + * Imaginary + Imaginary -> Imaginary + * Imaginary + Complex -> ? + * Imaginary + Real -> !Imaginary + """ + if expr.is_number: + return _Imaginary_number(expr, assumptions) + + reals = 0 + for arg in expr.args: + if ask(Q.imaginary(arg), assumptions): + pass + elif ask(Q.real(arg), assumptions): + reals += 1 + else: + break + else: + if reals == 0: + return True + if reals in (1, len(expr.args)): + # two reals could sum 0 thus giving an imaginary + return False + +@ImaginaryPredicate.register(Mul) # type:ignore +def _(expr, assumptions): + """ + * Real*Imaginary -> Imaginary + * Imaginary*Imaginary -> Real + """ + if expr.is_number: + return _Imaginary_number(expr, assumptions) + result = False + reals = 0 + for arg in expr.args: + if ask(Q.imaginary(arg), assumptions): + result = result ^ True + elif not ask(Q.real(arg), assumptions): + break + else: + if reals == len(expr.args): + return False + return result + +@ImaginaryPredicate.register(Pow) # type:ignore +def _(expr, assumptions): + """ + * Imaginary**Odd -> Imaginary + * Imaginary**Even -> Real + * b**Imaginary -> !Imaginary if exponent is an integer + multiple of I*pi/log(b) + * Imaginary**Real -> ? + * Positive**Real -> Real + * Negative**Integer -> Real + * Negative**(Integer/2) -> Imaginary + * Negative**Real -> not Imaginary if exponent is not Rational + """ + if expr.is_number: + return _Imaginary_number(expr, assumptions) + + if expr.base == E: + a = expr.exp/I/pi + return ask(Q.integer(2*a) & ~Q.integer(a), assumptions) + + if expr.base.func == exp or (expr.base.is_Pow and expr.base.base == E): + if ask(Q.imaginary(expr.base.exp), assumptions): + if ask(Q.imaginary(expr.exp), assumptions): + return False + i = expr.base.exp/I/pi + if ask(Q.integer(2*i), assumptions): + return ask(Q.imaginary((S.NegativeOne**i)**expr.exp), assumptions) + + if ask(Q.imaginary(expr.base), assumptions): + if ask(Q.integer(expr.exp), assumptions): + odd = ask(Q.odd(expr.exp), assumptions) + if odd is not None: + return odd + return + + if ask(Q.imaginary(expr.exp), assumptions): + imlog = ask(Q.imaginary(log(expr.base)), assumptions) + if imlog is not None: + # I**i -> real; (2*I)**i -> complex ==> not imaginary + return False + + if ask(Q.real(expr.base) & Q.real(expr.exp), assumptions): + if ask(Q.positive(expr.base), assumptions): + return False + else: + rat = ask(Q.rational(expr.exp), assumptions) + if not rat: + return rat + if ask(Q.integer(expr.exp), assumptions): + return False + else: + half = ask(Q.integer(2*expr.exp), assumptions) + if half: + return ask(Q.negative(expr.base), assumptions) + return half + +@ImaginaryPredicate.register(log) # type:ignore +def _(expr, assumptions): + if ask(Q.real(expr.args[0]), assumptions): + if ask(Q.positive(expr.args[0]), assumptions): + return False + return + # XXX it should be enough to do + # return ask(Q.nonpositive(expr.args[0]), assumptions) + # but ask(Q.nonpositive(exp(x)), Q.imaginary(x)) -> None; + # it should return True since exp(x) will be either 0 or complex + if expr.args[0].func == exp or (expr.args[0].is_Pow and expr.args[0].base == E): + if expr.args[0].exp in [I, -I]: + return True + im = ask(Q.imaginary(expr.args[0]), assumptions) + if im is False: + return False + +@ImaginaryPredicate.register(exp) # type:ignore +def _(expr, assumptions): + a = expr.exp/I/pi + return ask(Q.integer(2*a) & ~Q.integer(a), assumptions) + +@ImaginaryPredicate.register_many(Number, NumberSymbol) # type:ignore +def _(expr, assumptions): + return not (expr.as_real_imag()[1] == 0) + +@ImaginaryPredicate.register(NaN) # type:ignore +def _(expr, assumptions): + return None + + +# AntihermitianPredicate + +@AntihermitianPredicate.register(object) # type:ignore +def _(expr, assumptions): + if isinstance(expr, MatrixBase): + return None + if ask(Q.zero(expr), assumptions): + return True + return ask(Q.imaginary(expr), assumptions) + +@AntihermitianPredicate.register(Add) # type:ignore +def _(expr, assumptions): + """ + * Antihermitian + Antihermitian -> Antihermitian + * Antihermitian + !Antihermitian -> !Antihermitian + """ + if expr.is_number: + raise MDNotImplementedError + return test_closed_group(expr, assumptions, Q.antihermitian) + +@AntihermitianPredicate.register(Mul) # type:ignore +def _(expr, assumptions): + """ + As long as there is at most only one noncommutative term: + + * Hermitian*Hermitian -> !Antihermitian + * Hermitian*Antihermitian -> Antihermitian + * Antihermitian*Antihermitian -> !Antihermitian + """ + if expr.is_number: + raise MDNotImplementedError + nccount = 0 + result = False + for arg in expr.args: + if ask(Q.antihermitian(arg), assumptions): + result = result ^ True + elif not ask(Q.hermitian(arg), assumptions): + break + if ask(~Q.commutative(arg), assumptions): + nccount += 1 + if nccount > 1: + break + else: + return result + +@AntihermitianPredicate.register(Pow) # type:ignore +def _(expr, assumptions): + """ + * Hermitian**Integer -> !Antihermitian + * Antihermitian**Even -> !Antihermitian + * Antihermitian**Odd -> Antihermitian + """ + if expr.is_number: + raise MDNotImplementedError + if ask(Q.hermitian(expr.base), assumptions): + if ask(Q.integer(expr.exp), assumptions): + return False + elif ask(Q.antihermitian(expr.base), assumptions): + if ask(Q.even(expr.exp), assumptions): + return False + elif ask(Q.odd(expr.exp), assumptions): + return True + raise MDNotImplementedError + +@AntihermitianPredicate.register(MatrixBase) # type:ignore +def _(mat, assumptions): + rows, cols = mat.shape + ret_val = True + for i in range(rows): + for j in range(i, cols): + cond = fuzzy_bool(Eq(mat[i, j], -conjugate(mat[j, i]))) + if cond is None: + ret_val = None + if cond == False: + return False + if ret_val is None: + raise MDNotImplementedError + return ret_val + + +# AlgebraicPredicate + +@AlgebraicPredicate.register_many(AlgebraicNumber, Float, GoldenRatio, # type:ignore + ImaginaryUnit, TribonacciConstant) +def _(expr, assumptions): + return True + +@AlgebraicPredicate.register_many(ComplexInfinity, Exp1, Infinity, # type:ignore + NegativeInfinity, Pi) +def _(expr, assumptions): + return False + +@AlgebraicPredicate.register_many(Add, Mul) # type:ignore +def _(expr, assumptions): + return test_closed_group(expr, assumptions, Q.algebraic) + +@AlgebraicPredicate.register(Pow) # type:ignore +def _(expr, assumptions): + if expr.base == E: + if ask(Q.algebraic(expr.exp), assumptions): + return ask(~Q.nonzero(expr.exp), assumptions) + return + if expr.base == pi: + if ask(Q.integer(expr.exp), assumptions) and ask(Q.positive(expr.exp), assumptions): + return False + return + exp_rational = ask(Q.rational(expr.exp), assumptions) + base_algebraic = ask(Q.algebraic(expr.base), assumptions) + exp_algebraic = ask(Q.algebraic(expr.exp),assumptions) + if base_algebraic and exp_algebraic: + if exp_rational: + return True + # Check based on the Gelfond-Schneider theorem: + # If the base is algebraic and not equal to 0 or 1, and the exponent + # is irrational,then the result is transcendental. + if ask(Q.ne(expr.base,0) & Q.ne(expr.base,1)) and exp_rational is False: + return False + +@AlgebraicPredicate.register(Rational) # type:ignore +def _(expr, assumptions): + return expr.q != 0 + +@AlgebraicPredicate.register_many(asin, atan, cos, sin, tan) # type:ignore +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.algebraic(x), assumptions): + return ask(~Q.nonzero(x), assumptions) + +@AlgebraicPredicate.register(exp) # type:ignore +def _(expr, assumptions): + x = expr.exp + if ask(Q.algebraic(x), assumptions): + return ask(~Q.nonzero(x), assumptions) + +@AlgebraicPredicate.register_many(acot, cot) # type:ignore +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.algebraic(x), assumptions): + return False + +@AlgebraicPredicate.register_many(acos, log) # type:ignore +def _(expr, assumptions): + x = expr.args[0] + if ask(Q.algebraic(x), assumptions): + return ask(~Q.nonzero(x - 1), assumptions) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/lra_satask.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/lra_satask.py new file mode 100644 index 0000000000000000000000000000000000000000..53afe3e5abe99109ec01a47f19f1a8a4c99c5628 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/lra_satask.py @@ -0,0 +1,286 @@ +from sympy.assumptions.assume import global_assumptions +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.assumptions.ask import Q +from sympy.logic.inference import satisfiable +from sympy.logic.algorithms.lra_theory import UnhandledInput, ALLOWED_PRED +from sympy.matrices.kind import MatrixKind +from sympy.core.kind import NumberKind +from sympy.assumptions.assume import AppliedPredicate +from sympy.core.mul import Mul +from sympy.core.singleton import S + + +def lra_satask(proposition, assumptions=True, context=global_assumptions): + """ + Function to evaluate the proposition with assumptions using SAT algorithm + in conjunction with an Linear Real Arithmetic theory solver. + + Used to handle inequalities. Should eventually be depreciated and combined + into satask, but infinity handling and other things need to be implemented + before that can happen. + """ + props = CNF.from_prop(proposition) + _props = CNF.from_prop(~proposition) + + cnf = CNF.from_prop(assumptions) + assumptions = EncodedCNF() + assumptions.from_cnf(cnf) + + context_cnf = CNF() + if context: + context_cnf = context_cnf.extend(context) + + assumptions.add_from_cnf(context_cnf) + + return check_satisfiability(props, _props, assumptions) + +# Some predicates such as Q.prime can't be handled by lra_satask. +# For example, (x > 0) & (x < 1) & Q.prime(x) is unsat but lra_satask would think it was sat. +# WHITE_LIST is a list of predicates that can always be handled. +WHITE_LIST = ALLOWED_PRED | {Q.positive, Q.negative, Q.zero, Q.nonzero, Q.nonpositive, Q.nonnegative, + Q.extended_positive, Q.extended_negative, Q.extended_nonpositive, + Q.extended_negative, Q.extended_nonzero, Q.negative_infinite, + Q.positive_infinite} + + +def check_satisfiability(prop, _prop, factbase): + sat_true = factbase.copy() + sat_false = factbase.copy() + sat_true.add_from_cnf(prop) + sat_false.add_from_cnf(_prop) + + all_pred, all_exprs = get_all_pred_and_expr_from_enc_cnf(sat_true) + + for pred in all_pred: + if pred.function not in WHITE_LIST and pred.function != Q.ne: + raise UnhandledInput(f"LRASolver: {pred} is an unhandled predicate") + for expr in all_exprs: + if expr.kind == MatrixKind(NumberKind): + raise UnhandledInput(f"LRASolver: {expr} is of MatrixKind") + if expr == S.NaN: + raise UnhandledInput("LRASolver: nan") + + # convert old assumptions into predicates and add them to sat_true and sat_false + # also check for unhandled predicates + for assm in extract_pred_from_old_assum(all_exprs): + n = len(sat_true.encoding) + if assm not in sat_true.encoding: + sat_true.encoding[assm] = n+1 + sat_true.data.append([sat_true.encoding[assm]]) + + n = len(sat_false.encoding) + if assm not in sat_false.encoding: + sat_false.encoding[assm] = n+1 + sat_false.data.append([sat_false.encoding[assm]]) + + + sat_true = _preprocess(sat_true) + sat_false = _preprocess(sat_false) + + can_be_true = satisfiable(sat_true, use_lra_theory=True) is not False + can_be_false = satisfiable(sat_false, use_lra_theory=True) is not False + + if can_be_true and can_be_false: + return None + + if can_be_true and not can_be_false: + return True + + if not can_be_true and can_be_false: + return False + + if not can_be_true and not can_be_false: + raise ValueError("Inconsistent assumptions") + + +def _preprocess(enc_cnf): + """ + Returns an encoded cnf with only Q.eq, Q.gt, Q.lt, + Q.ge, and Q.le predicate. + + Converts every unequality into a disjunction of strict + inequalities. For example, x != 3 would become + x < 3 OR x > 3. + + Also converts all negated Q.ne predicates into + equalities. + """ + + # loops through each literal in each clause + # to construct a new, preprocessed encodedCNF + + enc_cnf = enc_cnf.copy() + cur_enc = 1 + rev_encoding = {value: key for key, value in enc_cnf.encoding.items()} + + new_encoding = {} + new_data = [] + for clause in enc_cnf.data: + new_clause = [] + for lit in clause: + if lit == 0: + new_clause.append(lit) + new_encoding[lit] = False + continue + prop = rev_encoding[abs(lit)] + negated = lit < 0 + sign = (lit > 0) - (lit < 0) + + prop = _pred_to_binrel(prop) + + if not isinstance(prop, AppliedPredicate): + if prop not in new_encoding: + new_encoding[prop] = cur_enc + cur_enc += 1 + lit = new_encoding[prop] + new_clause.append(sign*lit) + continue + + + if negated and prop.function == Q.eq: + negated = False + prop = Q.ne(*prop.arguments) + + if prop.function == Q.ne: + arg1, arg2 = prop.arguments + if negated: + new_prop = Q.eq(arg1, arg2) + if new_prop not in new_encoding: + new_encoding[new_prop] = cur_enc + cur_enc += 1 + + new_enc = new_encoding[new_prop] + new_clause.append(new_enc) + continue + else: + new_props = (Q.gt(arg1, arg2), Q.lt(arg1, arg2)) + for new_prop in new_props: + if new_prop not in new_encoding: + new_encoding[new_prop] = cur_enc + cur_enc += 1 + + new_enc = new_encoding[new_prop] + new_clause.append(new_enc) + continue + + if prop.function == Q.eq and negated: + assert False + + if prop not in new_encoding: + new_encoding[prop] = cur_enc + cur_enc += 1 + new_clause.append(new_encoding[prop]*sign) + new_data.append(new_clause) + + assert len(new_encoding) >= cur_enc - 1 + + enc_cnf = EncodedCNF(new_data, new_encoding) + return enc_cnf + + +def _pred_to_binrel(pred): + if not isinstance(pred, AppliedPredicate): + return pred + + if pred.function in pred_to_pos_neg_zero: + f = pred_to_pos_neg_zero[pred.function] + if f is False: + return False + pred = f(pred.arguments[0]) + + if pred.function == Q.positive: + pred = Q.gt(pred.arguments[0], 0) + elif pred.function == Q.negative: + pred = Q.lt(pred.arguments[0], 0) + elif pred.function == Q.zero: + pred = Q.eq(pred.arguments[0], 0) + elif pred.function == Q.nonpositive: + pred = Q.le(pred.arguments[0], 0) + elif pred.function == Q.nonnegative: + pred = Q.ge(pred.arguments[0], 0) + elif pred.function == Q.nonzero: + pred = Q.ne(pred.arguments[0], 0) + + return pred + +pred_to_pos_neg_zero = { + Q.extended_positive: Q.positive, + Q.extended_negative: Q.negative, + Q.extended_nonpositive: Q.nonpositive, + Q.extended_negative: Q.negative, + Q.extended_nonzero: Q.nonzero, + Q.negative_infinite: False, + Q.positive_infinite: False +} + +def get_all_pred_and_expr_from_enc_cnf(enc_cnf): + all_exprs = set() + all_pred = set() + for pred in enc_cnf.encoding.keys(): + if isinstance(pred, AppliedPredicate): + all_pred.add(pred) + all_exprs.update(pred.arguments) + + return all_pred, all_exprs + +def extract_pred_from_old_assum(all_exprs): + """ + Returns a list of relevant new assumption predicate + based on any old assumptions. + + Raises an UnhandledInput exception if any of the assumptions are + unhandled. + + Ignored predicate: + - commutative + - complex + - algebraic + - transcendental + - extended_real + - real + - all matrix predicate + - rational + - irrational + + Example + ======= + >>> from sympy.assumptions.lra_satask import extract_pred_from_old_assum + >>> from sympy import symbols + >>> x, y = symbols("x y", positive=True) + >>> extract_pred_from_old_assum([x, y, 2]) + [Q.positive(x), Q.positive(y)] + """ + ret = [] + for expr in all_exprs: + if not hasattr(expr, "free_symbols"): + continue + if len(expr.free_symbols) == 0: + continue + + if expr.is_real is not True: + raise UnhandledInput(f"LRASolver: {expr} must be real") + # test for I times imaginary variable; such expressions are considered real + if isinstance(expr, Mul) and any(arg.is_real is not True for arg in expr.args): + raise UnhandledInput(f"LRASolver: {expr} must be real") + + if expr.is_integer == True and expr.is_zero != True: + raise UnhandledInput(f"LRASolver: {expr} is an integer") + if expr.is_integer == False: + raise UnhandledInput(f"LRASolver: {expr} can't be an integer") + if expr.is_rational == False: + raise UnhandledInput(f"LRASolver: {expr} is irational") + + if expr.is_zero: + ret.append(Q.zero(expr)) + elif expr.is_positive: + ret.append(Q.positive(expr)) + elif expr.is_negative: + ret.append(Q.negative(expr)) + elif expr.is_nonzero: + ret.append(Q.nonzero(expr)) + elif expr.is_nonpositive: + ret.append(Q.nonpositive(expr)) + elif expr.is_nonnegative: + ret.append(Q.nonnegative(expr)) + + return ret diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e294544bfdce13633ecff762ff42861aa12719f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__init__.py @@ -0,0 +1,5 @@ +""" +Module to implement predicate classes. + +Class of every predicate registered to ``Q`` is defined here. +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35d2d713d71dce2850e4b77f4ce61f1574828f6e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/calculus.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/calculus.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b1fc35bcb3e4e5099a6ea238ade5e49beb18f38 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/calculus.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ecbc4d2e65ec2ca326ad086e4d066d92f20a5fb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/matrices.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/matrices.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91a4f2ed0e78dcee1e6b6c399c920acfe4f20fb8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/matrices.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/ntheory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/ntheory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b594c60f59225fcbe6773fc8076399e915f6d7d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/ntheory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/order.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/order.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b98a3c238dbf1375bbd034b376577f36e47cffa7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/order.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/sets.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/sets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10d82f2276795d4c8a7a917d76d387b0e11a1d41 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/__pycache__/sets.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/calculus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/calculus.py new file mode 100644 index 0000000000000000000000000000000000000000..f300703788683c07649ee3a0afd6e9d4eabd4567 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/calculus.py @@ -0,0 +1,82 @@ +from sympy.assumptions import Predicate +from sympy.multipledispatch import Dispatcher + +class FinitePredicate(Predicate): + """ + Finite number predicate. + + Explanation + =========== + + ``Q.finite(x)`` is true if ``x`` is a number but neither an infinity + nor a ``NaN``. In other words, ``ask(Q.finite(x))`` is true for all + numerical ``x`` having a bounded absolute value. + + Examples + ======== + + >>> from sympy import Q, ask, S, oo, I, zoo + >>> from sympy.abc import x + >>> ask(Q.finite(oo)) + False + >>> ask(Q.finite(-oo)) + False + >>> ask(Q.finite(zoo)) + False + >>> ask(Q.finite(1)) + True + >>> ask(Q.finite(2 + 3*I)) + True + >>> ask(Q.finite(x), Q.positive(x)) + True + >>> print(ask(Q.finite(S.NaN))) + None + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Finite + + """ + name = 'finite' + handler = Dispatcher( + "FiniteHandler", + doc=("Handler for Q.finite. Test that an expression is bounded respect" + " to all its variables.") + ) + + +class InfinitePredicate(Predicate): + """ + Infinite number predicate. + + ``Q.infinite(x)`` is true iff the absolute value of ``x`` is + infinity. + + """ + # TODO: Add examples + name = 'infinite' + handler = Dispatcher( + "InfiniteHandler", + doc="""Handler for Q.infinite key.""" + ) + + +class PositiveInfinitePredicate(Predicate): + """ + Positive infinity predicate. + + ``Q.positive_infinite(x)`` is true iff ``x`` is positive infinity ``oo``. + """ + name = 'positive_infinite' + handler = Dispatcher("PositiveInfiniteHandler") + + +class NegativeInfinitePredicate(Predicate): + """ + Negative infinity predicate. + + ``Q.negative_infinite(x)`` is true iff ``x`` is negative infinity ``-oo``. + """ + name = 'negative_infinite' + handler = Dispatcher("NegativeInfiniteHandler") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/common.py new file mode 100644 index 0000000000000000000000000000000000000000..a53892747131b03636abeb8f563c4f76cf3e281e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/common.py @@ -0,0 +1,81 @@ +from sympy.assumptions import Predicate, AppliedPredicate, Q +from sympy.core.relational import Eq, Ne, Gt, Lt, Ge, Le +from sympy.multipledispatch import Dispatcher + + +class CommutativePredicate(Predicate): + """ + Commutative predicate. + + Explanation + =========== + + ``ask(Q.commutative(x))`` is true iff ``x`` commutes with any other + object with respect to multiplication operation. + + """ + # TODO: Add examples + name = 'commutative' + handler = Dispatcher("CommutativeHandler", doc="Handler for key 'commutative'.") + + +binrelpreds = {Eq: Q.eq, Ne: Q.ne, Gt: Q.gt, Lt: Q.lt, Ge: Q.ge, Le: Q.le} + +class IsTruePredicate(Predicate): + """ + Generic predicate. + + Explanation + =========== + + ``ask(Q.is_true(x))`` is true iff ``x`` is true. This only makes + sense if ``x`` is a boolean object. + + Examples + ======== + + >>> from sympy import ask, Q + >>> from sympy.abc import x, y + >>> ask(Q.is_true(True)) + True + + Wrapping another applied predicate just returns the applied predicate. + + >>> Q.is_true(Q.even(x)) + Q.even(x) + + Wrapping binary relation classes in SymPy core returns applied binary + relational predicates. + + >>> from sympy import Eq, Gt + >>> Q.is_true(Eq(x, y)) + Q.eq(x, y) + >>> Q.is_true(Gt(x, y)) + Q.gt(x, y) + + Notes + ===== + + This class is designed to wrap the boolean objects so that they can + behave as if they are applied predicates. Consequently, wrapping another + applied predicate is unnecessary and thus it just returns the argument. + Also, binary relation classes in SymPy core have binary predicates to + represent themselves and thus wrapping them with ``Q.is_true`` converts them + to these applied predicates. + + """ + name = 'is_true' + handler = Dispatcher( + "IsTrueHandler", + doc="Wrapper allowing to query the truth value of a boolean expression." + ) + + def __call__(self, arg): + # No need to wrap another predicate + if isinstance(arg, AppliedPredicate): + return arg + # Convert relational predicates instead of wrapping them + if getattr(arg, "is_Relational", False): + pred = binrelpreds[type(arg)] + return pred(*arg.args) + return super().__call__(arg) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/matrices.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..151e78c4ff345800e1d2f17973fb0591b8d379d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/matrices.py @@ -0,0 +1,511 @@ +from sympy.assumptions import Predicate +from sympy.multipledispatch import Dispatcher + +class SquarePredicate(Predicate): + """ + Square matrix predicate. + + Explanation + =========== + + ``Q.square(x)`` is true iff ``x`` is a square matrix. A square matrix + is a matrix with the same number of rows and columns. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, ZeroMatrix, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('X', 2, 3) + >>> ask(Q.square(X)) + True + >>> ask(Q.square(Y)) + False + >>> ask(Q.square(ZeroMatrix(3, 3))) + True + >>> ask(Q.square(Identity(3))) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Square_matrix + + """ + name = 'square' + handler = Dispatcher("SquareHandler", doc="Handler for Q.square.") + + +class SymmetricPredicate(Predicate): + """ + Symmetric matrix predicate. + + Explanation + =========== + + ``Q.symmetric(x)`` is true iff ``x`` is a square matrix and is equal to + its transpose. Every square diagonal matrix is a symmetric matrix. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.symmetric(X*Z), Q.symmetric(X) & Q.symmetric(Z)) + True + >>> ask(Q.symmetric(X + Z), Q.symmetric(X) & Q.symmetric(Z)) + True + >>> ask(Q.symmetric(Y)) + False + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Symmetric_matrix + + """ + # TODO: Add handlers to make these keys work with + # actual matrices and add more examples in the docstring. + name = 'symmetric' + handler = Dispatcher("SymmetricHandler", doc="Handler for Q.symmetric.") + + +class InvertiblePredicate(Predicate): + """ + Invertible matrix predicate. + + Explanation + =========== + + ``Q.invertible(x)`` is true iff ``x`` is an invertible matrix. + A square matrix is called invertible only if its determinant is 0. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.invertible(X*Y), Q.invertible(X)) + False + >>> ask(Q.invertible(X*Z), Q.invertible(X) & Q.invertible(Z)) + True + >>> ask(Q.invertible(X), Q.fullrank(X) & Q.square(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Invertible_matrix + + """ + name = 'invertible' + handler = Dispatcher("InvertibleHandler", doc="Handler for Q.invertible.") + + +class OrthogonalPredicate(Predicate): + """ + Orthogonal matrix predicate. + + Explanation + =========== + + ``Q.orthogonal(x)`` is true iff ``x`` is an orthogonal matrix. + A square matrix ``M`` is an orthogonal matrix if it satisfies + ``M^TM = MM^T = I`` where ``M^T`` is the transpose matrix of + ``M`` and ``I`` is an identity matrix. Note that an orthogonal + matrix is necessarily invertible. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.orthogonal(Y)) + False + >>> ask(Q.orthogonal(X*Z*X), Q.orthogonal(X) & Q.orthogonal(Z)) + True + >>> ask(Q.orthogonal(Identity(3))) + True + >>> ask(Q.invertible(X), Q.orthogonal(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Orthogonal_matrix + + """ + name = 'orthogonal' + handler = Dispatcher("OrthogonalHandler", doc="Handler for key 'orthogonal'.") + + +class UnitaryPredicate(Predicate): + """ + Unitary matrix predicate. + + Explanation + =========== + + ``Q.unitary(x)`` is true iff ``x`` is a unitary matrix. + Unitary matrix is an analogue to orthogonal matrix. A square + matrix ``M`` with complex elements is unitary if :math:``M^TM = MM^T= I`` + where :math:``M^T`` is the conjugate transpose matrix of ``M``. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.unitary(Y)) + False + >>> ask(Q.unitary(X*Z*X), Q.unitary(X) & Q.unitary(Z)) + True + >>> ask(Q.unitary(Identity(3))) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Unitary_matrix + + """ + name = 'unitary' + handler = Dispatcher("UnitaryHandler", doc="Handler for key 'unitary'.") + + +class FullRankPredicate(Predicate): + """ + Fullrank matrix predicate. + + Explanation + =========== + + ``Q.fullrank(x)`` is true iff ``x`` is a full rank matrix. + A matrix is full rank if all rows and columns of the matrix + are linearly independent. A square matrix is full rank iff + its determinant is nonzero. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, ZeroMatrix, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> ask(Q.fullrank(X.T), Q.fullrank(X)) + True + >>> ask(Q.fullrank(ZeroMatrix(3, 3))) + False + >>> ask(Q.fullrank(Identity(3))) + True + + """ + name = 'fullrank' + handler = Dispatcher("FullRankHandler", doc="Handler for key 'fullrank'.") + + +class PositiveDefinitePredicate(Predicate): + r""" + Positive definite matrix predicate. + + Explanation + =========== + + If $M$ is a :math:`n \times n` symmetric real matrix, it is said + to be positive definite if :math:`Z^TMZ` is positive for + every non-zero column vector $Z$ of $n$ real numbers. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.positive_definite(Y)) + False + >>> ask(Q.positive_definite(Identity(3))) + True + >>> ask(Q.positive_definite(X + Z), Q.positive_definite(X) & + ... Q.positive_definite(Z)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Positive-definite_matrix + + """ + name = "positive_definite" + handler = Dispatcher("PositiveDefiniteHandler", doc="Handler for key 'positive_definite'.") + + +class UpperTriangularPredicate(Predicate): + """ + Upper triangular matrix predicate. + + Explanation + =========== + + A matrix $M$ is called upper triangular matrix if :math:`M_{ij}=0` + for :math:`i>> from sympy import Q, ask, ZeroMatrix, Identity + >>> ask(Q.upper_triangular(Identity(3))) + True + >>> ask(Q.upper_triangular(ZeroMatrix(3, 3))) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/UpperTriangularMatrix.html + + """ + name = "upper_triangular" + handler = Dispatcher("UpperTriangularHandler", doc="Handler for key 'upper_triangular'.") + + +class LowerTriangularPredicate(Predicate): + """ + Lower triangular matrix predicate. + + Explanation + =========== + + A matrix $M$ is called lower triangular matrix if :math:`M_{ij}=0` + for :math:`i>j`. + + Examples + ======== + + >>> from sympy import Q, ask, ZeroMatrix, Identity + >>> ask(Q.lower_triangular(Identity(3))) + True + >>> ask(Q.lower_triangular(ZeroMatrix(3, 3))) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/LowerTriangularMatrix.html + + """ + name = "lower_triangular" + handler = Dispatcher("LowerTriangularHandler", doc="Handler for key 'lower_triangular'.") + + +class DiagonalPredicate(Predicate): + """ + Diagonal matrix predicate. + + Explanation + =========== + + ``Q.diagonal(x)`` is true iff ``x`` is a diagonal matrix. A diagonal + matrix is a matrix in which the entries outside the main diagonal + are all zero. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, ZeroMatrix + >>> X = MatrixSymbol('X', 2, 2) + >>> ask(Q.diagonal(ZeroMatrix(3, 3))) + True + >>> ask(Q.diagonal(X), Q.lower_triangular(X) & + ... Q.upper_triangular(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Diagonal_matrix + + """ + name = "diagonal" + handler = Dispatcher("DiagonalHandler", doc="Handler for key 'diagonal'.") + + +class IntegerElementsPredicate(Predicate): + """ + Integer elements matrix predicate. + + Explanation + =========== + + ``Q.integer_elements(x)`` is true iff all the elements of ``x`` + are integers. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.integer(X[1, 2]), Q.integer_elements(X)) + True + + """ + name = "integer_elements" + handler = Dispatcher("IntegerElementsHandler", doc="Handler for key 'integer_elements'.") + + +class RealElementsPredicate(Predicate): + """ + Real elements matrix predicate. + + Explanation + =========== + + ``Q.real_elements(x)`` is true iff all the elements of ``x`` + are real numbers. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.real(X[1, 2]), Q.real_elements(X)) + True + + """ + name = "real_elements" + handler = Dispatcher("RealElementsHandler", doc="Handler for key 'real_elements'.") + + +class ComplexElementsPredicate(Predicate): + """ + Complex elements matrix predicate. + + Explanation + =========== + + ``Q.complex_elements(x)`` is true iff all the elements of ``x`` + are complex numbers. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.complex(X[1, 2]), Q.complex_elements(X)) + True + >>> ask(Q.complex_elements(X), Q.integer_elements(X)) + True + + """ + name = "complex_elements" + handler = Dispatcher("ComplexElementsHandler", doc="Handler for key 'complex_elements'.") + + +class SingularPredicate(Predicate): + """ + Singular matrix predicate. + + A matrix is singular iff the value of its determinant is 0. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.singular(X), Q.invertible(X)) + False + >>> ask(Q.singular(X), ~Q.invertible(X)) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/SingularMatrix.html + + """ + name = "singular" + handler = Dispatcher("SingularHandler", doc="Predicate fore key 'singular'.") + + +class NormalPredicate(Predicate): + """ + Normal matrix predicate. + + A matrix is normal if it commutes with its conjugate transpose. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.normal(X), Q.unitary(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Normal_matrix + + """ + name = "normal" + handler = Dispatcher("NormalHandler", doc="Predicate fore key 'normal'.") + + +class TriangularPredicate(Predicate): + """ + Triangular matrix predicate. + + Explanation + =========== + + ``Q.triangular(X)`` is true if ``X`` is one that is either lower + triangular or upper triangular. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.triangular(X), Q.upper_triangular(X)) + True + >>> ask(Q.triangular(X), Q.lower_triangular(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Triangular_matrix + + """ + name = "triangular" + handler = Dispatcher("TriangularHandler", doc="Predicate fore key 'triangular'.") + + +class UnitTriangularPredicate(Predicate): + """ + Unit triangular matrix predicate. + + Explanation + =========== + + A unit triangular matrix is a triangular matrix with 1s + on the diagonal. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.triangular(X), Q.unit_triangular(X)) + True + + """ + name = "unit_triangular" + handler = Dispatcher("UnitTriangularHandler", doc="Predicate fore key 'unit_triangular'.") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/ntheory.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/ntheory.py new file mode 100644 index 0000000000000000000000000000000000000000..6c598e0ed1bd4a1170aa28044f9ae6de2fa1a1e0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/ntheory.py @@ -0,0 +1,126 @@ +from sympy.assumptions import Predicate +from sympy.multipledispatch import Dispatcher + + +class PrimePredicate(Predicate): + """ + Prime number predicate. + + Explanation + =========== + + ``ask(Q.prime(x))`` is true iff ``x`` is a natural number greater + than 1 that has no positive divisors other than ``1`` and the + number itself. + + Examples + ======== + + >>> from sympy import Q, ask + >>> ask(Q.prime(0)) + False + >>> ask(Q.prime(1)) + False + >>> ask(Q.prime(2)) + True + >>> ask(Q.prime(20)) + False + >>> ask(Q.prime(-3)) + False + + """ + name = 'prime' + handler = Dispatcher( + "PrimeHandler", + doc=("Handler for key 'prime'. Test that an expression represents a prime" + " number. When the expression is an exact number, the result (when True)" + " is subject to the limitations of isprime() which is used to return the " + "result.") + ) + + +class CompositePredicate(Predicate): + """ + Composite number predicate. + + Explanation + =========== + + ``ask(Q.composite(x))`` is true iff ``x`` is a positive integer and has + at least one positive divisor other than ``1`` and the number itself. + + Examples + ======== + + >>> from sympy import Q, ask + >>> ask(Q.composite(0)) + False + >>> ask(Q.composite(1)) + False + >>> ask(Q.composite(2)) + False + >>> ask(Q.composite(20)) + True + + """ + name = 'composite' + handler = Dispatcher("CompositeHandler", doc="Handler for key 'composite'.") + + +class EvenPredicate(Predicate): + """ + Even number predicate. + + Explanation + =========== + + ``ask(Q.even(x))`` is true iff ``x`` belongs to the set of even + integers. + + Examples + ======== + + >>> from sympy import Q, ask, pi + >>> ask(Q.even(0)) + True + >>> ask(Q.even(2)) + True + >>> ask(Q.even(3)) + False + >>> ask(Q.even(pi)) + False + + """ + name = 'even' + handler = Dispatcher("EvenHandler", doc="Handler for key 'even'.") + + +class OddPredicate(Predicate): + """ + Odd number predicate. + + Explanation + =========== + + ``ask(Q.odd(x))`` is true iff ``x`` belongs to the set of odd numbers. + + Examples + ======== + + >>> from sympy import Q, ask, pi + >>> ask(Q.odd(0)) + False + >>> ask(Q.odd(2)) + False + >>> ask(Q.odd(3)) + True + >>> ask(Q.odd(pi)) + False + + """ + name = 'odd' + handler = Dispatcher( + "OddHandler", + doc=("Handler for key 'odd'. Test that an expression represents an odd" + " number.") + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/order.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/order.py new file mode 100644 index 0000000000000000000000000000000000000000..86bfb2ae49789efd5b0df99e2cfc63984e956dd0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/order.py @@ -0,0 +1,390 @@ +from sympy.assumptions import Predicate +from sympy.multipledispatch import Dispatcher + + +class NegativePredicate(Predicate): + r""" + Negative number predicate. + + Explanation + =========== + + ``Q.negative(x)`` is true iff ``x`` is a real number and :math:`x < 0`, that is, + it is in the interval :math:`(-\infty, 0)`. Note in particular that negative + infinity is not negative. + + A few important facts about negative numbers: + + - Note that ``Q.nonnegative`` and ``~Q.negative`` are *not* the same + thing. ``~Q.negative(x)`` simply means that ``x`` is not negative, + whereas ``Q.nonnegative(x)`` means that ``x`` is real and not + negative, i.e., ``Q.nonnegative(x)`` is logically equivalent to + ``Q.zero(x) | Q.positive(x)``. So for example, ``~Q.negative(I)`` is + true, whereas ``Q.nonnegative(I)`` is false. + + - See the documentation of ``Q.real`` for more information about + related facts. + + Examples + ======== + + >>> from sympy import Q, ask, symbols, I + >>> x = symbols('x') + >>> ask(Q.negative(x), Q.real(x) & ~Q.positive(x) & ~Q.zero(x)) + True + >>> ask(Q.negative(-1)) + True + >>> ask(Q.nonnegative(I)) + False + >>> ask(~Q.negative(I)) + True + + """ + name = 'negative' + handler = Dispatcher( + "NegativeHandler", + doc=("Handler for Q.negative. Test that an expression is strictly less" + " than zero.") + ) + + +class NonNegativePredicate(Predicate): + """ + Nonnegative real number predicate. + + Explanation + =========== + + ``ask(Q.nonnegative(x))`` is true iff ``x`` belongs to the set of + positive numbers including zero. + + - Note that ``Q.nonnegative`` and ``~Q.negative`` are *not* the same + thing. ``~Q.negative(x)`` simply means that ``x`` is not negative, + whereas ``Q.nonnegative(x)`` means that ``x`` is real and not + negative, i.e., ``Q.nonnegative(x)`` is logically equivalent to + ``Q.zero(x) | Q.positive(x)``. So for example, ``~Q.negative(I)`` is + true, whereas ``Q.nonnegative(I)`` is false. + + Examples + ======== + + >>> from sympy import Q, ask, I + >>> ask(Q.nonnegative(1)) + True + >>> ask(Q.nonnegative(0)) + True + >>> ask(Q.nonnegative(-1)) + False + >>> ask(Q.nonnegative(I)) + False + >>> ask(Q.nonnegative(-I)) + False + + """ + name = 'nonnegative' + handler = Dispatcher( + "NonNegativeHandler", + doc=("Handler for Q.nonnegative.") + ) + + +class NonZeroPredicate(Predicate): + """ + Nonzero real number predicate. + + Explanation + =========== + + ``ask(Q.nonzero(x))`` is true iff ``x`` is real and ``x`` is not zero. Note in + particular that ``Q.nonzero(x)`` is false if ``x`` is not real. Use + ``~Q.zero(x)`` if you want the negation of being zero without any real + assumptions. + + A few important facts about nonzero numbers: + + - ``Q.nonzero`` is logically equivalent to ``Q.positive | Q.negative``. + + - See the documentation of ``Q.real`` for more information about + related facts. + + Examples + ======== + + >>> from sympy import Q, ask, symbols, I, oo + >>> x = symbols('x') + >>> print(ask(Q.nonzero(x), ~Q.zero(x))) + None + >>> ask(Q.nonzero(x), Q.positive(x)) + True + >>> ask(Q.nonzero(x), Q.zero(x)) + False + >>> ask(Q.nonzero(0)) + False + >>> ask(Q.nonzero(I)) + False + >>> ask(~Q.zero(I)) + True + >>> ask(Q.nonzero(oo)) + False + + """ + name = 'nonzero' + handler = Dispatcher( + "NonZeroHandler", + doc=("Handler for key 'nonzero'. Test that an expression is not identically" + " zero.") + ) + + +class ZeroPredicate(Predicate): + """ + Zero number predicate. + + Explanation + =========== + + ``ask(Q.zero(x))`` is true iff the value of ``x`` is zero. + + Examples + ======== + + >>> from sympy import ask, Q, oo, symbols + >>> x, y = symbols('x, y') + >>> ask(Q.zero(0)) + True + >>> ask(Q.zero(1/oo)) + True + >>> print(ask(Q.zero(0*oo))) + None + >>> ask(Q.zero(1)) + False + >>> ask(Q.zero(x*y), Q.zero(x) | Q.zero(y)) + True + + """ + name = 'zero' + handler = Dispatcher( + "ZeroHandler", + doc="Handler for key 'zero'." + ) + + +class NonPositivePredicate(Predicate): + """ + Nonpositive real number predicate. + + Explanation + =========== + + ``ask(Q.nonpositive(x))`` is true iff ``x`` belongs to the set of + negative numbers including zero. + + - Note that ``Q.nonpositive`` and ``~Q.positive`` are *not* the same + thing. ``~Q.positive(x)`` simply means that ``x`` is not positive, + whereas ``Q.nonpositive(x)`` means that ``x`` is real and not + positive, i.e., ``Q.nonpositive(x)`` is logically equivalent to + `Q.negative(x) | Q.zero(x)``. So for example, ``~Q.positive(I)`` is + true, whereas ``Q.nonpositive(I)`` is false. + + Examples + ======== + + >>> from sympy import Q, ask, I + + >>> ask(Q.nonpositive(-1)) + True + >>> ask(Q.nonpositive(0)) + True + >>> ask(Q.nonpositive(1)) + False + >>> ask(Q.nonpositive(I)) + False + >>> ask(Q.nonpositive(-I)) + False + + """ + name = 'nonpositive' + handler = Dispatcher( + "NonPositiveHandler", + doc="Handler for key 'nonpositive'." + ) + + +class PositivePredicate(Predicate): + r""" + Positive real number predicate. + + Explanation + =========== + + ``Q.positive(x)`` is true iff ``x`` is real and `x > 0`, that is if ``x`` + is in the interval `(0, \infty)`. In particular, infinity is not + positive. + + A few important facts about positive numbers: + + - Note that ``Q.nonpositive`` and ``~Q.positive`` are *not* the same + thing. ``~Q.positive(x)`` simply means that ``x`` is not positive, + whereas ``Q.nonpositive(x)`` means that ``x`` is real and not + positive, i.e., ``Q.nonpositive(x)`` is logically equivalent to + `Q.negative(x) | Q.zero(x)``. So for example, ``~Q.positive(I)`` is + true, whereas ``Q.nonpositive(I)`` is false. + + - See the documentation of ``Q.real`` for more information about + related facts. + + Examples + ======== + + >>> from sympy import Q, ask, symbols, I + >>> x = symbols('x') + >>> ask(Q.positive(x), Q.real(x) & ~Q.negative(x) & ~Q.zero(x)) + True + >>> ask(Q.positive(1)) + True + >>> ask(Q.nonpositive(I)) + False + >>> ask(~Q.positive(I)) + True + + """ + name = 'positive' + handler = Dispatcher( + "PositiveHandler", + doc=("Handler for key 'positive'. Test that an expression is strictly" + " greater than zero.") + ) + + +class ExtendedPositivePredicate(Predicate): + r""" + Positive extended real number predicate. + + Explanation + =========== + + ``Q.extended_positive(x)`` is true iff ``x`` is extended real and + `x > 0`, that is if ``x`` is in the interval `(0, \infty]`. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_positive(1)) + True + >>> ask(Q.extended_positive(oo)) + True + >>> ask(Q.extended_positive(I)) + False + + """ + name = 'extended_positive' + handler = Dispatcher("ExtendedPositiveHandler") + + +class ExtendedNegativePredicate(Predicate): + r""" + Negative extended real number predicate. + + Explanation + =========== + + ``Q.extended_negative(x)`` is true iff ``x`` is extended real and + `x < 0`, that is if ``x`` is in the interval `[-\infty, 0)`. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_negative(-1)) + True + >>> ask(Q.extended_negative(-oo)) + True + >>> ask(Q.extended_negative(-I)) + False + + """ + name = 'extended_negative' + handler = Dispatcher("ExtendedNegativeHandler") + + +class ExtendedNonZeroPredicate(Predicate): + """ + Nonzero extended real number predicate. + + Explanation + =========== + + ``ask(Q.extended_nonzero(x))`` is true iff ``x`` is extended real and + ``x`` is not zero. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_nonzero(-1)) + True + >>> ask(Q.extended_nonzero(oo)) + True + >>> ask(Q.extended_nonzero(I)) + False + + """ + name = 'extended_nonzero' + handler = Dispatcher("ExtendedNonZeroHandler") + + +class ExtendedNonPositivePredicate(Predicate): + """ + Nonpositive extended real number predicate. + + Explanation + =========== + + ``ask(Q.extended_nonpositive(x))`` is true iff ``x`` is extended real and + ``x`` is not positive. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_nonpositive(-1)) + True + >>> ask(Q.extended_nonpositive(oo)) + False + >>> ask(Q.extended_nonpositive(0)) + True + >>> ask(Q.extended_nonpositive(I)) + False + + """ + name = 'extended_nonpositive' + handler = Dispatcher("ExtendedNonPositiveHandler") + + +class ExtendedNonNegativePredicate(Predicate): + """ + Nonnegative extended real number predicate. + + Explanation + =========== + + ``ask(Q.extended_nonnegative(x))`` is true iff ``x`` is extended real and + ``x`` is not negative. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_nonnegative(-1)) + False + >>> ask(Q.extended_nonnegative(oo)) + True + >>> ask(Q.extended_nonnegative(0)) + True + >>> ask(Q.extended_nonnegative(I)) + False + + """ + name = 'extended_nonnegative' + handler = Dispatcher("ExtendedNonNegativeHandler") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/sets.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/sets.py new file mode 100644 index 0000000000000000000000000000000000000000..18261cee2d9de65df14a31a56b2cd22328328ed0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/predicates/sets.py @@ -0,0 +1,399 @@ +from sympy.assumptions import Predicate +from sympy.multipledispatch import Dispatcher + + +class IntegerPredicate(Predicate): + """ + Integer predicate. + + Explanation + =========== + + ``Q.integer(x)`` is true iff ``x`` belongs to the set of integer + numbers. + + Examples + ======== + + >>> from sympy import Q, ask, S + >>> ask(Q.integer(5)) + True + >>> ask(Q.integer(S(1)/2)) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Integer + + """ + name = 'integer' + handler = Dispatcher( + "IntegerHandler", + doc=("Handler for Q.integer.\n\n" + "Test that an expression belongs to the field of integer numbers.") + ) + + +class NonIntegerPredicate(Predicate): + """ + Non-integer extended real predicate. + """ + name = 'noninteger' + handler = Dispatcher( + "NonIntegerHandler", + doc=("Handler for Q.noninteger.\n\n" + "Test that an expression is a non-integer extended real number.") + ) + + +class RationalPredicate(Predicate): + """ + Rational number predicate. + + Explanation + =========== + + ``Q.rational(x)`` is true iff ``x`` belongs to the set of + rational numbers. + + Examples + ======== + + >>> from sympy import ask, Q, pi, S + >>> ask(Q.rational(0)) + True + >>> ask(Q.rational(S(1)/2)) + True + >>> ask(Q.rational(pi)) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Rational_number + + """ + name = 'rational' + handler = Dispatcher( + "RationalHandler", + doc=("Handler for Q.rational.\n\n" + "Test that an expression belongs to the field of rational numbers.") + ) + + +class IrrationalPredicate(Predicate): + """ + Irrational number predicate. + + Explanation + =========== + + ``Q.irrational(x)`` is true iff ``x`` is any real number that + cannot be expressed as a ratio of integers. + + Examples + ======== + + >>> from sympy import ask, Q, pi, S, I + >>> ask(Q.irrational(0)) + False + >>> ask(Q.irrational(S(1)/2)) + False + >>> ask(Q.irrational(pi)) + True + >>> ask(Q.irrational(I)) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Irrational_number + + """ + name = 'irrational' + handler = Dispatcher( + "IrrationalHandler", + doc=("Handler for Q.irrational.\n\n" + "Test that an expression is irrational numbers.") + ) + + +class RealPredicate(Predicate): + r""" + Real number predicate. + + Explanation + =========== + + ``Q.real(x)`` is true iff ``x`` is a real number, i.e., it is in the + interval `(-\infty, \infty)`. Note that, in particular the + infinities are not real. Use ``Q.extended_real`` if you want to + consider those as well. + + A few important facts about reals: + + - Every real number is positive, negative, or zero. Furthermore, + because these sets are pairwise disjoint, each real number is + exactly one of those three. + + - Every real number is also complex. + + - Every real number is finite. + + - Every real number is either rational or irrational. + + - Every real number is either algebraic or transcendental. + + - The facts ``Q.negative``, ``Q.zero``, ``Q.positive``, + ``Q.nonnegative``, ``Q.nonpositive``, ``Q.nonzero``, + ``Q.integer``, ``Q.rational``, and ``Q.irrational`` all imply + ``Q.real``, as do all facts that imply those facts. + + - The facts ``Q.algebraic``, and ``Q.transcendental`` do not imply + ``Q.real``; they imply ``Q.complex``. An algebraic or + transcendental number may or may not be real. + + - The "non" facts (i.e., ``Q.nonnegative``, ``Q.nonzero``, + ``Q.nonpositive`` and ``Q.noninteger``) are not equivalent to + not the fact, but rather, not the fact *and* ``Q.real``. + For example, ``Q.nonnegative`` means ``~Q.negative & Q.real``. + So for example, ``I`` is not nonnegative, nonzero, or + nonpositive. + + Examples + ======== + + >>> from sympy import Q, ask, symbols + >>> x = symbols('x') + >>> ask(Q.real(x), Q.positive(x)) + True + >>> ask(Q.real(0)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Real_number + + """ + name = 'real' + handler = Dispatcher( + "RealHandler", + doc=("Handler for Q.real.\n\n" + "Test that an expression belongs to the field of real numbers.") + ) + + +class ExtendedRealPredicate(Predicate): + r""" + Extended real predicate. + + Explanation + =========== + + ``Q.extended_real(x)`` is true iff ``x`` is a real number or + `\{-\infty, \infty\}`. + + See documentation of ``Q.real`` for more information about related + facts. + + Examples + ======== + + >>> from sympy import ask, Q, oo, I + >>> ask(Q.extended_real(1)) + True + >>> ask(Q.extended_real(I)) + False + >>> ask(Q.extended_real(oo)) + True + + """ + name = 'extended_real' + handler = Dispatcher( + "ExtendedRealHandler", + doc=("Handler for Q.extended_real.\n\n" + "Test that an expression belongs to the field of extended real\n" + "numbers, that is real numbers union {Infinity, -Infinity}.") + ) + + +class HermitianPredicate(Predicate): + """ + Hermitian predicate. + + Explanation + =========== + + ``ask(Q.hermitian(x))`` is true iff ``x`` belongs to the set of + Hermitian operators. + + References + ========== + + .. [1] https://mathworld.wolfram.com/HermitianOperator.html + + """ + # TODO: Add examples + name = 'hermitian' + handler = Dispatcher( + "HermitianHandler", + doc=("Handler for Q.hermitian.\n\n" + "Test that an expression belongs to the field of Hermitian operators.") + ) + + +class ComplexPredicate(Predicate): + """ + Complex number predicate. + + Explanation + =========== + + ``Q.complex(x)`` is true iff ``x`` belongs to the set of complex + numbers. Note that every complex number is finite. + + Examples + ======== + + >>> from sympy import Q, Symbol, ask, I, oo + >>> x = Symbol('x') + >>> ask(Q.complex(0)) + True + >>> ask(Q.complex(2 + 3*I)) + True + >>> ask(Q.complex(oo)) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Complex_number + + """ + name = 'complex' + handler = Dispatcher( + "ComplexHandler", + doc=("Handler for Q.complex.\n\n" + "Test that an expression belongs to the field of complex numbers.") + ) + + +class ImaginaryPredicate(Predicate): + """ + Imaginary number predicate. + + Explanation + =========== + + ``Q.imaginary(x)`` is true iff ``x`` can be written as a real + number multiplied by the imaginary unit ``I``. Please note that ``0`` + is not considered to be an imaginary number. + + Examples + ======== + + >>> from sympy import Q, ask, I + >>> ask(Q.imaginary(3*I)) + True + >>> ask(Q.imaginary(2 + 3*I)) + False + >>> ask(Q.imaginary(0)) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Imaginary_number + + """ + name = 'imaginary' + handler = Dispatcher( + "ImaginaryHandler", + doc=("Handler for Q.imaginary.\n\n" + "Test that an expression belongs to the field of imaginary numbers,\n" + "that is, numbers in the form x*I, where x is real.") + ) + + +class AntihermitianPredicate(Predicate): + """ + Antihermitian predicate. + + Explanation + =========== + + ``Q.antihermitian(x)`` is true iff ``x`` belongs to the field of + antihermitian operators, i.e., operators in the form ``x*I``, where + ``x`` is Hermitian. + + References + ========== + + .. [1] https://mathworld.wolfram.com/HermitianOperator.html + + """ + # TODO: Add examples + name = 'antihermitian' + handler = Dispatcher( + "AntiHermitianHandler", + doc=("Handler for Q.antihermitian.\n\n" + "Test that an expression belongs to the field of anti-Hermitian\n" + "operators, that is, operators in the form x*I, where x is Hermitian.") + ) + + +class AlgebraicPredicate(Predicate): + r""" + Algebraic number predicate. + + Explanation + =========== + + ``Q.algebraic(x)`` is true iff ``x`` belongs to the set of + algebraic numbers. ``x`` is algebraic if there is some polynomial + in ``p(x)\in \mathbb\{Q\}[x]`` such that ``p(x) = 0``. + + Examples + ======== + + >>> from sympy import ask, Q, sqrt, I, pi + >>> ask(Q.algebraic(sqrt(2))) + True + >>> ask(Q.algebraic(I)) + True + >>> ask(Q.algebraic(pi)) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Algebraic_number + + """ + name = 'algebraic' + AlgebraicHandler = Dispatcher( + "AlgebraicHandler", + doc="""Handler for Q.algebraic key.""" + ) + + +class TranscendentalPredicate(Predicate): + """ + Transcedental number predicate. + + Explanation + =========== + + ``Q.transcendental(x)`` is true iff ``x`` belongs to the set of + transcendental numbers. A transcendental number is a real + or complex number that is not algebraic. + + """ + # TODO: Add examples + name = 'transcendental' + handler = Dispatcher( + "Transcendental", + doc="""Handler for Q.transcendental key.""" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/refine.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..c36a4e1cdb40f1b59a96f60a3b36182b587920fa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/refine.py @@ -0,0 +1,405 @@ +from __future__ import annotations +from typing import Callable + +from sympy.core import S, Add, Expr, Basic, Mul, Pow, Rational +from sympy.core.logic import fuzzy_not +from sympy.logic.boolalg import Boolean + +from sympy.assumptions import ask, Q # type: ignore + + +def refine(expr, assumptions=True): + """ + Simplify an expression using assumptions. + + Explanation + =========== + + Unlike :func:`~.simplify` which performs structural simplification + without any assumption, this function transforms the expression into + the form which is only valid under certain assumptions. Note that + ``simplify()`` is generally not done in refining process. + + Refining boolean expression involves reducing it to ``S.true`` or + ``S.false``. Unlike :func:`~.ask`, the expression will not be reduced + if the truth value cannot be determined. + + Examples + ======== + + >>> from sympy import refine, sqrt, Q + >>> from sympy.abc import x + >>> refine(sqrt(x**2), Q.real(x)) + Abs(x) + >>> refine(sqrt(x**2), Q.positive(x)) + x + + >>> refine(Q.real(x), Q.positive(x)) + True + >>> refine(Q.positive(x), Q.real(x)) + Q.positive(x) + + See Also + ======== + + sympy.simplify.simplify.simplify : Structural simplification without assumptions. + sympy.assumptions.ask.ask : Query for boolean expressions using assumptions. + """ + if not isinstance(expr, Basic): + return expr + + if not expr.is_Atom: + args = [refine(arg, assumptions) for arg in expr.args] + # TODO: this will probably not work with Integral or Polynomial + expr = expr.func(*args) + if hasattr(expr, '_eval_refine'): + ref_expr = expr._eval_refine(assumptions) + if ref_expr is not None: + return ref_expr + name = expr.__class__.__name__ + handler = handlers_dict.get(name, None) + if handler is None: + return expr + new_expr = handler(expr, assumptions) + if (new_expr is None) or (expr == new_expr): + return expr + if not isinstance(new_expr, Expr): + return new_expr + return refine(new_expr, assumptions) + + +def refine_abs(expr, assumptions): + """ + Handler for the absolute value. + + Examples + ======== + + >>> from sympy import Q, Abs + >>> from sympy.assumptions.refine import refine_abs + >>> from sympy.abc import x + >>> refine_abs(Abs(x), Q.real(x)) + >>> refine_abs(Abs(x), Q.positive(x)) + x + >>> refine_abs(Abs(x), Q.negative(x)) + -x + + """ + from sympy.functions.elementary.complexes import Abs + arg = expr.args[0] + if ask(Q.real(arg), assumptions) and \ + fuzzy_not(ask(Q.negative(arg), assumptions)): + # if it's nonnegative + return arg + if ask(Q.negative(arg), assumptions): + return -arg + # arg is Mul + if isinstance(arg, Mul): + r = [refine(abs(a), assumptions) for a in arg.args] + non_abs = [] + in_abs = [] + for i in r: + if isinstance(i, Abs): + in_abs.append(i.args[0]) + else: + non_abs.append(i) + return Mul(*non_abs) * Abs(Mul(*in_abs)) + + +def refine_Pow(expr, assumptions): + """ + Handler for instances of Pow. + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.refine import refine_Pow + >>> from sympy.abc import x,y,z + >>> refine_Pow((-1)**x, Q.real(x)) + >>> refine_Pow((-1)**x, Q.even(x)) + 1 + >>> refine_Pow((-1)**x, Q.odd(x)) + -1 + + For powers of -1, even parts of the exponent can be simplified: + + >>> refine_Pow((-1)**(x+y), Q.even(x)) + (-1)**y + >>> refine_Pow((-1)**(x+y+z), Q.odd(x) & Q.odd(z)) + (-1)**y + >>> refine_Pow((-1)**(x+y+2), Q.odd(x)) + (-1)**(y + 1) + >>> refine_Pow((-1)**(x+3), True) + (-1)**(x + 1) + + """ + from sympy.functions.elementary.complexes import Abs + from sympy.functions import sign + if isinstance(expr.base, Abs): + if ask(Q.real(expr.base.args[0]), assumptions) and \ + ask(Q.even(expr.exp), assumptions): + return expr.base.args[0] ** expr.exp + if ask(Q.real(expr.base), assumptions): + if expr.base.is_number: + if ask(Q.even(expr.exp), assumptions): + return abs(expr.base) ** expr.exp + if ask(Q.odd(expr.exp), assumptions): + return sign(expr.base) * abs(expr.base) ** expr.exp + if isinstance(expr.exp, Rational): + if isinstance(expr.base, Pow): + return abs(expr.base.base) ** (expr.base.exp * expr.exp) + + if expr.base is S.NegativeOne: + if expr.exp.is_Add: + + old = expr + + # For powers of (-1) we can remove + # - even terms + # - pairs of odd terms + # - a single odd term + 1 + # - A numerical constant N can be replaced with mod(N,2) + + coeff, terms = expr.exp.as_coeff_add() + terms = set(terms) + even_terms = set() + odd_terms = set() + initial_number_of_terms = len(terms) + + for t in terms: + if ask(Q.even(t), assumptions): + even_terms.add(t) + elif ask(Q.odd(t), assumptions): + odd_terms.add(t) + + terms -= even_terms + if len(odd_terms) % 2: + terms -= odd_terms + new_coeff = (coeff + S.One) % 2 + else: + terms -= odd_terms + new_coeff = coeff % 2 + + if new_coeff != coeff or len(terms) < initial_number_of_terms: + terms.add(new_coeff) + expr = expr.base**(Add(*terms)) + + # Handle (-1)**((-1)**n/2 + m/2) + e2 = 2*expr.exp + if ask(Q.even(e2), assumptions): + if e2.could_extract_minus_sign(): + e2 *= expr.base + if e2.is_Add: + i, p = e2.as_two_terms() + if p.is_Pow and p.base is S.NegativeOne: + if ask(Q.integer(p.exp), assumptions): + i = (i + 1)/2 + if ask(Q.even(i), assumptions): + return expr.base**p.exp + elif ask(Q.odd(i), assumptions): + return expr.base**(p.exp + 1) + else: + return expr.base**(p.exp + i) + + if old != expr: + return expr + + +def refine_atan2(expr, assumptions): + """ + Handler for the atan2 function. + + Examples + ======== + + >>> from sympy import Q, atan2 + >>> from sympy.assumptions.refine import refine_atan2 + >>> from sympy.abc import x, y + >>> refine_atan2(atan2(y,x), Q.real(y) & Q.positive(x)) + atan(y/x) + >>> refine_atan2(atan2(y,x), Q.negative(y) & Q.negative(x)) + atan(y/x) - pi + >>> refine_atan2(atan2(y,x), Q.positive(y) & Q.negative(x)) + atan(y/x) + pi + >>> refine_atan2(atan2(y,x), Q.zero(y) & Q.negative(x)) + pi + >>> refine_atan2(atan2(y,x), Q.positive(y) & Q.zero(x)) + pi/2 + >>> refine_atan2(atan2(y,x), Q.negative(y) & Q.zero(x)) + -pi/2 + >>> refine_atan2(atan2(y,x), Q.zero(y) & Q.zero(x)) + nan + """ + from sympy.functions.elementary.trigonometric import atan + y, x = expr.args + if ask(Q.real(y) & Q.positive(x), assumptions): + return atan(y / x) + elif ask(Q.negative(y) & Q.negative(x), assumptions): + return atan(y / x) - S.Pi + elif ask(Q.positive(y) & Q.negative(x), assumptions): + return atan(y / x) + S.Pi + elif ask(Q.zero(y) & Q.negative(x), assumptions): + return S.Pi + elif ask(Q.positive(y) & Q.zero(x), assumptions): + return S.Pi/2 + elif ask(Q.negative(y) & Q.zero(x), assumptions): + return -S.Pi/2 + elif ask(Q.zero(y) & Q.zero(x), assumptions): + return S.NaN + else: + return expr + + +def refine_re(expr, assumptions): + """ + Handler for real part. + + Examples + ======== + + >>> from sympy.assumptions.refine import refine_re + >>> from sympy import Q, re + >>> from sympy.abc import x + >>> refine_re(re(x), Q.real(x)) + x + >>> refine_re(re(x), Q.imaginary(x)) + 0 + """ + arg = expr.args[0] + if ask(Q.real(arg), assumptions): + return arg + if ask(Q.imaginary(arg), assumptions): + return S.Zero + return _refine_reim(expr, assumptions) + + +def refine_im(expr, assumptions): + """ + Handler for imaginary part. + + Explanation + =========== + + >>> from sympy.assumptions.refine import refine_im + >>> from sympy import Q, im + >>> from sympy.abc import x + >>> refine_im(im(x), Q.real(x)) + 0 + >>> refine_im(im(x), Q.imaginary(x)) + -I*x + """ + arg = expr.args[0] + if ask(Q.real(arg), assumptions): + return S.Zero + if ask(Q.imaginary(arg), assumptions): + return - S.ImaginaryUnit * arg + return _refine_reim(expr, assumptions) + +def refine_arg(expr, assumptions): + """ + Handler for complex argument + + Explanation + =========== + + >>> from sympy.assumptions.refine import refine_arg + >>> from sympy import Q, arg + >>> from sympy.abc import x + >>> refine_arg(arg(x), Q.positive(x)) + 0 + >>> refine_arg(arg(x), Q.negative(x)) + pi + """ + rg = expr.args[0] + if ask(Q.positive(rg), assumptions): + return S.Zero + if ask(Q.negative(rg), assumptions): + return S.Pi + return None + + +def _refine_reim(expr, assumptions): + # Helper function for refine_re & refine_im + expanded = expr.expand(complex = True) + if expanded != expr: + refined = refine(expanded, assumptions) + if refined != expanded: + return refined + # Best to leave the expression as is + return None + + +def refine_sign(expr, assumptions): + """ + Handler for sign. + + Examples + ======== + + >>> from sympy.assumptions.refine import refine_sign + >>> from sympy import Symbol, Q, sign, im + >>> x = Symbol('x', real = True) + >>> expr = sign(x) + >>> refine_sign(expr, Q.positive(x) & Q.nonzero(x)) + 1 + >>> refine_sign(expr, Q.negative(x) & Q.nonzero(x)) + -1 + >>> refine_sign(expr, Q.zero(x)) + 0 + >>> y = Symbol('y', imaginary = True) + >>> expr = sign(y) + >>> refine_sign(expr, Q.positive(im(y))) + I + >>> refine_sign(expr, Q.negative(im(y))) + -I + """ + arg = expr.args[0] + if ask(Q.zero(arg), assumptions): + return S.Zero + if ask(Q.real(arg)): + if ask(Q.positive(arg), assumptions): + return S.One + if ask(Q.negative(arg), assumptions): + return S.NegativeOne + if ask(Q.imaginary(arg)): + arg_re, arg_im = arg.as_real_imag() + if ask(Q.positive(arg_im), assumptions): + return S.ImaginaryUnit + if ask(Q.negative(arg_im), assumptions): + return -S.ImaginaryUnit + return expr + + +def refine_matrixelement(expr, assumptions): + """ + Handler for symmetric part. + + Examples + ======== + + >>> from sympy.assumptions.refine import refine_matrixelement + >>> from sympy import MatrixSymbol, Q + >>> X = MatrixSymbol('X', 3, 3) + >>> refine_matrixelement(X[0, 1], Q.symmetric(X)) + X[0, 1] + >>> refine_matrixelement(X[1, 0], Q.symmetric(X)) + X[0, 1] + """ + from sympy.matrices.expressions.matexpr import MatrixElement + matrix, i, j = expr.args + if ask(Q.symmetric(matrix), assumptions): + if (i - j).could_extract_minus_sign(): + return expr + return MatrixElement(matrix, j, i) + +handlers_dict: dict[str, Callable[[Expr, Boolean], Expr]] = { + 'Abs': refine_abs, + 'Pow': refine_Pow, + 'atan2': refine_atan2, + 're': refine_re, + 'im': refine_im, + 'arg': refine_arg, + 'sign': refine_sign, + 'MatrixElement': refine_matrixelement +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04f5ed37893766feec941614691a9177f14e4027 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__init__.py @@ -0,0 +1,13 @@ +""" +A module to implement finitary relations [1] as predicate. + +References +========== + +.. [1] https://en.wikipedia.org/wiki/Finitary_relation + +""" + +__all__ = ['BinaryRelation', 'AppliedBinaryRelation'] + +from .binrel import BinaryRelation, AppliedBinaryRelation diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f951521183e23a61a8beed57e56fff82006e6119 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/binrel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/binrel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5f06fab6a5af916e7c100c3bb23a0b5682e4639 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/binrel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/equality.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/equality.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..063393d2e20bde395dac46617dbc6b35cc0385cd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/__pycache__/equality.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/binrel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/binrel.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4eba05bcce40f1a05483a30136b6ccd891c42f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/binrel.py @@ -0,0 +1,212 @@ +""" +General binary relations. +""" +from typing import Optional + +from sympy.core.singleton import S +from sympy.assumptions import AppliedPredicate, ask, Predicate, Q # type: ignore +from sympy.core.kind import BooleanKind +from sympy.core.relational import Eq, Ne, Gt, Lt, Ge, Le +from sympy.logic.boolalg import conjuncts, Not + +__all__ = ["BinaryRelation", "AppliedBinaryRelation"] + + +class BinaryRelation(Predicate): + """ + Base class for all binary relational predicates. + + Explanation + =========== + + Binary relation takes two arguments and returns ``AppliedBinaryRelation`` + instance. To evaluate it to boolean value, use :obj:`~.ask()` or + :obj:`~.refine()` function. + + You can add support for new types by registering the handler to dispatcher. + See :obj:`~.Predicate()` for more information about predicate dispatching. + + Examples + ======== + + Applying and evaluating to boolean value: + + >>> from sympy import Q, ask, sin, cos + >>> from sympy.abc import x + >>> Q.eq(sin(x)**2+cos(x)**2, 1) + Q.eq(sin(x)**2 + cos(x)**2, 1) + >>> ask(_) + True + + You can define a new binary relation by subclassing and dispatching. + Here, we define a relation $R$ such that $x R y$ returns true if + $x = y + 1$. + + >>> from sympy import ask, Number, Q + >>> from sympy.assumptions import BinaryRelation + >>> class MyRel(BinaryRelation): + ... name = "R" + ... is_reflexive = False + >>> Q.R = MyRel() + >>> @Q.R.register(Number, Number) + ... def _(n1, n2, assumptions): + ... return ask(Q.zero(n1 - n2 - 1), assumptions) + >>> Q.R(2, 1) + Q.R(2, 1) + + Now, we can use ``ask()`` to evaluate it to boolean value. + + >>> ask(Q.R(2, 1)) + True + >>> ask(Q.R(1, 2)) + False + + ``Q.R`` returns ``False`` with minimum cost if two arguments have same + structure because it is antireflexive relation [1] by + ``is_reflexive = False``. + + >>> ask(Q.R(x, x)) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Reflexive_relation + """ + + is_reflexive: Optional[bool] = None + is_symmetric: Optional[bool] = None + + def __call__(self, *args): + if not len(args) == 2: + raise ValueError("Binary relation takes two arguments, but got %s." % len(args)) + return AppliedBinaryRelation(self, *args) + + @property + def reversed(self): + if self.is_symmetric: + return self + return None + + @property + def negated(self): + return None + + def _compare_reflexive(self, lhs, rhs): + # quick exit for structurally same arguments + # do not check != here because it cannot catch the + # equivalent arguments with different structures. + + # reflexivity does not hold to NaN + if lhs is S.NaN or rhs is S.NaN: + return None + + reflexive = self.is_reflexive + if reflexive is None: + pass + elif reflexive and (lhs == rhs): + return True + elif not reflexive and (lhs == rhs): + return False + return None + + def eval(self, args, assumptions=True): + # quick exit for structurally same arguments + ret = self._compare_reflexive(*args) + if ret is not None: + return ret + + # don't perform simplify on args here. (done by AppliedBinaryRelation._eval_ask) + # evaluate by multipledispatch + lhs, rhs = args + ret = self.handler(lhs, rhs, assumptions=assumptions) + if ret is not None: + return ret + + # check reversed order if the relation is reflexive + if self.is_reflexive: + types = (type(lhs), type(rhs)) + if self.handler.dispatch(*types) is not self.handler.dispatch(*reversed(types)): + ret = self.handler(rhs, lhs, assumptions=assumptions) + + return ret + + +class AppliedBinaryRelation(AppliedPredicate): + """ + The class of expressions resulting from applying ``BinaryRelation`` + to the arguments. + + """ + + @property + def lhs(self): + """The left-hand side of the relation.""" + return self.arguments[0] + + @property + def rhs(self): + """The right-hand side of the relation.""" + return self.arguments[1] + + @property + def reversed(self): + """ + Try to return the relationship with sides reversed. + """ + revfunc = self.function.reversed + if revfunc is None: + return self + return revfunc(self.rhs, self.lhs) + + @property + def reversedsign(self): + """ + Try to return the relationship with signs reversed. + """ + revfunc = self.function.reversed + if revfunc is None: + return self + if not any(side.kind is BooleanKind for side in self.arguments): + return revfunc(-self.lhs, -self.rhs) + return self + + @property + def negated(self): + neg_rel = self.function.negated + if neg_rel is None: + return Not(self, evaluate=False) + return neg_rel(*self.arguments) + + def _eval_ask(self, assumptions): + conj_assumps = set() + binrelpreds = {Eq: Q.eq, Ne: Q.ne, Gt: Q.gt, Lt: Q.lt, Ge: Q.ge, Le: Q.le} + for a in conjuncts(assumptions): + if a.func in binrelpreds: + conj_assumps.add(binrelpreds[type(a)](*a.args)) + else: + conj_assumps.add(a) + + # After CNF in assumptions module is modified to take polyadic + # predicate, this will be removed + if any(rel in conj_assumps for rel in (self, self.reversed)): + return True + neg_rels = (self.negated, self.reversed.negated, Not(self, evaluate=False), + Not(self.reversed, evaluate=False)) + if any(rel in conj_assumps for rel in neg_rels): + return False + + # evaluation using multipledispatching + ret = self.function.eval(self.arguments, assumptions) + if ret is not None: + return ret + + # simplify the args and try again + args = tuple(a.simplify() for a in self.arguments) + return self.function.eval(args, assumptions) + + def __bool__(self): + ret = ask(self) + if ret is None: + raise TypeError("Cannot determine truth value of %s" % self) + return ret diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/equality.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/equality.py new file mode 100644 index 0000000000000000000000000000000000000000..d467cea2da706de2cbbc9875f93c7f8e324a9088 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/relation/equality.py @@ -0,0 +1,302 @@ +""" +Module for mathematical equality [1] and inequalities [2]. + +The purpose of this module is to provide the instances which represent the +binary predicates in order to combine the relationals into logical inference +system. Objects such as ``Q.eq``, ``Q.lt`` should remain internal to +assumptions module, and user must use the classes such as :obj:`~.Eq()`, +:obj:`~.Lt()` instead to construct the relational expressions. + +References +========== + +.. [1] https://en.wikipedia.org/wiki/Equality_(mathematics) +.. [2] https://en.wikipedia.org/wiki/Inequality_(mathematics) +""" +from sympy.assumptions import Q +from sympy.core.relational import is_eq, is_neq, is_gt, is_ge, is_lt, is_le + +from .binrel import BinaryRelation + +__all__ = ['EqualityPredicate', 'UnequalityPredicate', 'StrictGreaterThanPredicate', + 'GreaterThanPredicate', 'StrictLessThanPredicate', 'LessThanPredicate'] + + +class EqualityPredicate(BinaryRelation): + """ + Binary predicate for $=$. + + The purpose of this class is to provide the instance which represent + the equality predicate in order to allow the logical inference. + This class must remain internal to assumptions module and user must + use :obj:`~.Eq()` instead to construct the equality expression. + + Evaluating this predicate to ``True`` or ``False`` is done by + :func:`~.core.relational.is_eq` + + Examples + ======== + + >>> from sympy import ask, Q + >>> Q.eq(0, 0) + Q.eq(0, 0) + >>> ask(_) + True + + See Also + ======== + + sympy.core.relational.Eq + + """ + is_reflexive = True + is_symmetric = True + + name = 'eq' + handler = None # Do not allow dispatching by this predicate + + @property + def negated(self): + return Q.ne + + def eval(self, args, assumptions=True): + if assumptions == True: + # default assumptions for is_eq is None + assumptions = None + return is_eq(*args, assumptions) + + +class UnequalityPredicate(BinaryRelation): + r""" + Binary predicate for $\neq$. + + The purpose of this class is to provide the instance which represent + the inequation predicate in order to allow the logical inference. + This class must remain internal to assumptions module and user must + use :obj:`~.Ne()` instead to construct the inequation expression. + + Evaluating this predicate to ``True`` or ``False`` is done by + :func:`~.core.relational.is_neq` + + Examples + ======== + + >>> from sympy import ask, Q + >>> Q.ne(0, 0) + Q.ne(0, 0) + >>> ask(_) + False + + See Also + ======== + + sympy.core.relational.Ne + + """ + is_reflexive = False + is_symmetric = True + + name = 'ne' + handler = None + + @property + def negated(self): + return Q.eq + + def eval(self, args, assumptions=True): + if assumptions == True: + # default assumptions for is_neq is None + assumptions = None + return is_neq(*args, assumptions) + + +class StrictGreaterThanPredicate(BinaryRelation): + """ + Binary predicate for $>$. + + The purpose of this class is to provide the instance which represent + the ">" predicate in order to allow the logical inference. + This class must remain internal to assumptions module and user must + use :obj:`~.Gt()` instead to construct the equality expression. + + Evaluating this predicate to ``True`` or ``False`` is done by + :func:`~.core.relational.is_gt` + + Examples + ======== + + >>> from sympy import ask, Q + >>> Q.gt(0, 0) + Q.gt(0, 0) + >>> ask(_) + False + + See Also + ======== + + sympy.core.relational.Gt + + """ + is_reflexive = False + is_symmetric = False + + name = 'gt' + handler = None + + @property + def reversed(self): + return Q.lt + + @property + def negated(self): + return Q.le + + def eval(self, args, assumptions=True): + if assumptions == True: + # default assumptions for is_gt is None + assumptions = None + return is_gt(*args, assumptions) + + +class GreaterThanPredicate(BinaryRelation): + """ + Binary predicate for $>=$. + + The purpose of this class is to provide the instance which represent + the ">=" predicate in order to allow the logical inference. + This class must remain internal to assumptions module and user must + use :obj:`~.Ge()` instead to construct the equality expression. + + Evaluating this predicate to ``True`` or ``False`` is done by + :func:`~.core.relational.is_ge` + + Examples + ======== + + >>> from sympy import ask, Q + >>> Q.ge(0, 0) + Q.ge(0, 0) + >>> ask(_) + True + + See Also + ======== + + sympy.core.relational.Ge + + """ + is_reflexive = True + is_symmetric = False + + name = 'ge' + handler = None + + @property + def reversed(self): + return Q.le + + @property + def negated(self): + return Q.lt + + def eval(self, args, assumptions=True): + if assumptions == True: + # default assumptions for is_ge is None + assumptions = None + return is_ge(*args, assumptions) + + +class StrictLessThanPredicate(BinaryRelation): + """ + Binary predicate for $<$. + + The purpose of this class is to provide the instance which represent + the "<" predicate in order to allow the logical inference. + This class must remain internal to assumptions module and user must + use :obj:`~.Lt()` instead to construct the equality expression. + + Evaluating this predicate to ``True`` or ``False`` is done by + :func:`~.core.relational.is_lt` + + Examples + ======== + + >>> from sympy import ask, Q + >>> Q.lt(0, 0) + Q.lt(0, 0) + >>> ask(_) + False + + See Also + ======== + + sympy.core.relational.Lt + + """ + is_reflexive = False + is_symmetric = False + + name = 'lt' + handler = None + + @property + def reversed(self): + return Q.gt + + @property + def negated(self): + return Q.ge + + def eval(self, args, assumptions=True): + if assumptions == True: + # default assumptions for is_lt is None + assumptions = None + return is_lt(*args, assumptions) + + +class LessThanPredicate(BinaryRelation): + """ + Binary predicate for $<=$. + + The purpose of this class is to provide the instance which represent + the "<=" predicate in order to allow the logical inference. + This class must remain internal to assumptions module and user must + use :obj:`~.Le()` instead to construct the equality expression. + + Evaluating this predicate to ``True`` or ``False`` is done by + :func:`~.core.relational.is_le` + + Examples + ======== + + >>> from sympy import ask, Q + >>> Q.le(0, 0) + Q.le(0, 0) + >>> ask(_) + True + + See Also + ======== + + sympy.core.relational.Le + + """ + is_reflexive = True + is_symmetric = False + + name = 'le' + handler = None + + @property + def reversed(self): + return Q.ge + + @property + def negated(self): + return Q.gt + + def eval(self, args, assumptions=True): + if assumptions == True: + # default assumptions for is_le is None + assumptions = None + return is_le(*args, assumptions) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/satask.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/satask.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc13f6d3bc3fb7f573c8d5d0564b780440c1a8c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/satask.py @@ -0,0 +1,369 @@ +""" +Module to evaluate the proposition with assumptions using SAT algorithm. +""" + +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.assumptions.ask_generated import get_all_known_matrix_facts, get_all_known_number_facts +from sympy.assumptions.assume import global_assumptions, AppliedPredicate +from sympy.assumptions.sathandlers import class_fact_registry +from sympy.core import oo +from sympy.logic.inference import satisfiable +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.matrices.kind import MatrixKind + + +def satask(proposition, assumptions=True, context=global_assumptions, + use_known_facts=True, iterations=oo): + """ + Function to evaluate the proposition with assumptions using SAT algorithm. + + This function extracts every fact relevant to the expressions composing + proposition and assumptions. For example, if a predicate containing + ``Abs(x)`` is proposed, then ``Q.zero(Abs(x)) | Q.positive(Abs(x))`` + will be found and passed to SAT solver because ``Q.nonnegative`` is + registered as a fact for ``Abs``. + + Proposition is evaluated to ``True`` or ``False`` if the truth value can be + determined. If not, ``None`` is returned. + + Parameters + ========== + + proposition : Any boolean expression. + Proposition which will be evaluated to boolean value. + + assumptions : Any boolean expression, optional. + Local assumptions to evaluate the *proposition*. + + context : AssumptionsContext, optional. + Default assumptions to evaluate the *proposition*. By default, + this is ``sympy.assumptions.global_assumptions`` variable. + + use_known_facts : bool, optional. + If ``True``, facts from ``sympy.assumptions.ask_generated`` + module are passed to SAT solver as well. + + iterations : int, optional. + Number of times that relevant facts are recursively extracted. + Default is infinite times until no new fact is found. + + Returns + ======= + + ``True``, ``False``, or ``None`` + + Examples + ======== + + >>> from sympy import Abs, Q + >>> from sympy.assumptions.satask import satask + >>> from sympy.abc import x + >>> satask(Q.zero(Abs(x)), Q.zero(x)) + True + + """ + props = CNF.from_prop(proposition) + _props = CNF.from_prop(~proposition) + + assumptions = CNF.from_prop(assumptions) + + context_cnf = CNF() + if context: + context_cnf = context_cnf.extend(context) + + sat = get_all_relevant_facts(props, assumptions, context_cnf, + use_known_facts=use_known_facts, iterations=iterations) + sat.add_from_cnf(assumptions) + if context: + sat.add_from_cnf(context_cnf) + + return check_satisfiability(props, _props, sat) + + +def check_satisfiability(prop, _prop, factbase): + sat_true = factbase.copy() + sat_false = factbase.copy() + sat_true.add_from_cnf(prop) + sat_false.add_from_cnf(_prop) + can_be_true = satisfiable(sat_true) + can_be_false = satisfiable(sat_false) + + if can_be_true and can_be_false: + return None + + if can_be_true and not can_be_false: + return True + + if not can_be_true and can_be_false: + return False + + if not can_be_true and not can_be_false: + # TODO: Run additional checks to see which combination of the + # assumptions, global_assumptions, and relevant_facts are + # inconsistent. + raise ValueError("Inconsistent assumptions") + + +def extract_predargs(proposition, assumptions=None, context=None): + """ + Extract every expression in the argument of predicates from *proposition*, + *assumptions* and *context*. + + Parameters + ========== + + proposition : sympy.assumptions.cnf.CNF + + assumptions : sympy.assumptions.cnf.CNF, optional. + + context : sympy.assumptions.cnf.CNF, optional. + CNF generated from assumptions context. + + Examples + ======== + + >>> from sympy import Q, Abs + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.assumptions.satask import extract_predargs + >>> from sympy.abc import x, y + >>> props = CNF.from_prop(Q.zero(Abs(x*y))) + >>> assump = CNF.from_prop(Q.zero(x) & Q.zero(y)) + >>> extract_predargs(props, assump) + {x, y, Abs(x*y)} + + """ + req_keys = find_symbols(proposition) + keys = proposition.all_predicates() + # XXX: We need this since True/False are not Basic + lkeys = set() + if assumptions: + lkeys |= assumptions.all_predicates() + if context: + lkeys |= context.all_predicates() + + lkeys = lkeys - {S.true, S.false} + tmp_keys = None + while tmp_keys != set(): + tmp = set() + for l in lkeys: + syms = find_symbols(l) + if (syms & req_keys) != set(): + tmp |= syms + tmp_keys = tmp - req_keys + req_keys |= tmp_keys + keys |= {l for l in lkeys if find_symbols(l) & req_keys != set()} + + exprs = set() + for key in keys: + if isinstance(key, AppliedPredicate): + exprs |= set(key.arguments) + else: + exprs.add(key) + return exprs + +def find_symbols(pred): + """ + Find every :obj:`~.Symbol` in *pred*. + + Parameters + ========== + + pred : sympy.assumptions.cnf.CNF, or any Expr. + + """ + if isinstance(pred, CNF): + symbols = set() + for a in pred.all_predicates(): + symbols |= find_symbols(a) + return symbols + return pred.atoms(Symbol) + + +def get_relevant_clsfacts(exprs, relevant_facts=None): + """ + Extract relevant facts from the items in *exprs*. Facts are defined in + ``assumptions.sathandlers`` module. + + This function is recursively called by ``get_all_relevant_facts()``. + + Parameters + ========== + + exprs : set + Expressions whose relevant facts are searched. + + relevant_facts : sympy.assumptions.cnf.CNF, optional. + Pre-discovered relevant facts. + + Returns + ======= + + exprs : set + Candidates for next relevant fact searching. + + relevant_facts : sympy.assumptions.cnf.CNF + Updated relevant facts. + + Examples + ======== + + Here, we will see how facts relevant to ``Abs(x*y)`` are recursively + extracted. On the first run, set containing the expression is passed + without pre-discovered relevant facts. The result is a set containing + candidates for next run, and ``CNF()`` instance containing facts + which are relevant to ``Abs`` and its argument. + + >>> from sympy import Abs + >>> from sympy.assumptions.satask import get_relevant_clsfacts + >>> from sympy.abc import x, y + >>> exprs = {Abs(x*y)} + >>> exprs, facts = get_relevant_clsfacts(exprs) + >>> exprs + {x*y} + >>> facts.clauses #doctest: +SKIP + {frozenset({Literal(Q.odd(Abs(x*y)), False), Literal(Q.odd(x*y), True)}), + frozenset({Literal(Q.zero(Abs(x*y)), False), Literal(Q.zero(x*y), True)}), + frozenset({Literal(Q.even(Abs(x*y)), False), Literal(Q.even(x*y), True)}), + frozenset({Literal(Q.zero(Abs(x*y)), True), Literal(Q.zero(x*y), False)}), + frozenset({Literal(Q.even(Abs(x*y)), False), + Literal(Q.odd(Abs(x*y)), False), + Literal(Q.odd(x*y), True)}), + frozenset({Literal(Q.even(Abs(x*y)), False), + Literal(Q.even(x*y), True), + Literal(Q.odd(Abs(x*y)), False)}), + frozenset({Literal(Q.positive(Abs(x*y)), False), + Literal(Q.zero(Abs(x*y)), False)})} + + We pass the first run's results to the second run, and get the expressions + for next run and updated facts. + + >>> exprs, facts = get_relevant_clsfacts(exprs, relevant_facts=facts) + >>> exprs + {x, y} + + On final run, no more candidate is returned thus we know that all + relevant facts are successfully retrieved. + + >>> exprs, facts = get_relevant_clsfacts(exprs, relevant_facts=facts) + >>> exprs + set() + + """ + if not relevant_facts: + relevant_facts = CNF() + + newexprs = set() + for expr in exprs: + for fact in class_fact_registry(expr): + newfact = CNF.to_CNF(fact) + relevant_facts = relevant_facts._and(newfact) + for key in newfact.all_predicates(): + if isinstance(key, AppliedPredicate): + newexprs |= set(key.arguments) + + return newexprs - exprs, relevant_facts + + +def get_all_relevant_facts(proposition, assumptions, context, + use_known_facts=True, iterations=oo): + """ + Extract all relevant facts from *proposition* and *assumptions*. + + This function extracts the facts by recursively calling + ``get_relevant_clsfacts()``. Extracted facts are converted to + ``EncodedCNF`` and returned. + + Parameters + ========== + + proposition : sympy.assumptions.cnf.CNF + CNF generated from proposition expression. + + assumptions : sympy.assumptions.cnf.CNF + CNF generated from assumption expression. + + context : sympy.assumptions.cnf.CNF + CNF generated from assumptions context. + + use_known_facts : bool, optional. + If ``True``, facts from ``sympy.assumptions.ask_generated`` + module are encoded as well. + + iterations : int, optional. + Number of times that relevant facts are recursively extracted. + Default is infinite times until no new fact is found. + + Returns + ======= + + sympy.assumptions.cnf.EncodedCNF + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.assumptions.satask import get_all_relevant_facts + >>> from sympy.abc import x, y + >>> props = CNF.from_prop(Q.nonzero(x*y)) + >>> assump = CNF.from_prop(Q.nonzero(x)) + >>> context = CNF.from_prop(Q.nonzero(y)) + >>> get_all_relevant_facts(props, assump, context) #doctest: +SKIP + + + """ + # The relevant facts might introduce new keys, e.g., Q.zero(x*y) will + # introduce the keys Q.zero(x) and Q.zero(y), so we need to run it until + # we stop getting new things. Hopefully this strategy won't lead to an + # infinite loop in the future. + i = 0 + relevant_facts = CNF() + all_exprs = set() + while True: + if i == 0: + exprs = extract_predargs(proposition, assumptions, context) + all_exprs |= exprs + exprs, relevant_facts = get_relevant_clsfacts(exprs, relevant_facts) + i += 1 + if i >= iterations: + break + if not exprs: + break + + if use_known_facts: + known_facts_CNF = CNF() + + if any(expr.kind == MatrixKind(NumberKind) for expr in all_exprs): + known_facts_CNF.add_clauses(get_all_known_matrix_facts()) + # check for undefinedKind since kind system isn't fully implemented + if any(((expr.kind == NumberKind) or (expr.kind == UndefinedKind)) for expr in all_exprs): + known_facts_CNF.add_clauses(get_all_known_number_facts()) + + kf_encoded = EncodedCNF() + kf_encoded.from_cnf(known_facts_CNF) + + def translate_literal(lit, delta): + if lit > 0: + return lit + delta + else: + return lit - delta + + def translate_data(data, delta): + return [{translate_literal(i, delta) for i in clause} for clause in data] + data = [] + symbols = [] + n_lit = len(kf_encoded.symbols) + for i, expr in enumerate(all_exprs): + symbols += [pred(expr) for pred in kf_encoded.symbols] + data += translate_data(kf_encoded.data, i * n_lit) + + encoding = dict(list(zip(symbols, range(1, len(symbols)+1)))) + ctx = EncodedCNF(data, encoding) + else: + ctx = EncodedCNF() + + ctx.add_from_cnf(relevant_facts) + + return ctx diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/sathandlers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/sathandlers.py new file mode 100644 index 0000000000000000000000000000000000000000..a11199eb0e547187ab280c18196c0259c178e004 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/sathandlers.py @@ -0,0 +1,322 @@ +from collections import defaultdict + +from sympy.assumptions.ask import Q +from sympy.core import (Add, Mul, Pow, Number, NumberSymbol, Symbol) +from sympy.core.numbers import ImaginaryUnit +from sympy.functions.elementary.complexes import Abs +from sympy.logic.boolalg import (Equivalent, And, Or, Implies) +from sympy.matrices.expressions import MatMul + +# APIs here may be subject to change + + +### Helper functions ### + +def allargs(symbol, fact, expr): + """ + Apply all arguments of the expression to the fact structure. + + Parameters + ========== + + symbol : Symbol + A placeholder symbol. + + fact : Boolean + Resulting ``Boolean`` expression. + + expr : Expr + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.sathandlers import allargs + >>> from sympy.abc import x, y + >>> allargs(x, Q.negative(x) | Q.positive(x), x*y) + (Q.negative(x) | Q.positive(x)) & (Q.negative(y) | Q.positive(y)) + + """ + return And(*[fact.subs(symbol, arg) for arg in expr.args]) + + +def anyarg(symbol, fact, expr): + """ + Apply any argument of the expression to the fact structure. + + Parameters + ========== + + symbol : Symbol + A placeholder symbol. + + fact : Boolean + Resulting ``Boolean`` expression. + + expr : Expr + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.sathandlers import anyarg + >>> from sympy.abc import x, y + >>> anyarg(x, Q.negative(x) & Q.positive(x), x*y) + (Q.negative(x) & Q.positive(x)) | (Q.negative(y) & Q.positive(y)) + + """ + return Or(*[fact.subs(symbol, arg) for arg in expr.args]) + + +def exactlyonearg(symbol, fact, expr): + """ + Apply exactly one argument of the expression to the fact structure. + + Parameters + ========== + + symbol : Symbol + A placeholder symbol. + + fact : Boolean + Resulting ``Boolean`` expression. + + expr : Expr + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.sathandlers import exactlyonearg + >>> from sympy.abc import x, y + >>> exactlyonearg(x, Q.positive(x), x*y) + (Q.positive(x) & ~Q.positive(y)) | (Q.positive(y) & ~Q.positive(x)) + + """ + pred_args = [fact.subs(symbol, arg) for arg in expr.args] + res = Or(*[And(pred_args[i], *[~lit for lit in pred_args[:i] + + pred_args[i+1:]]) for i in range(len(pred_args))]) + return res + + +### Fact registry ### + +class ClassFactRegistry: + """ + Register handlers against classes. + + Explanation + =========== + + ``register`` method registers the handler function for a class. Here, + handler function should return a single fact. ``multiregister`` method + registers the handler function for multiple classes. Here, handler function + should return a container of multiple facts. + + ``registry(expr)`` returns a set of facts for *expr*. + + Examples + ======== + + Here, we register the facts for ``Abs``. + + >>> from sympy import Abs, Equivalent, Q + >>> from sympy.assumptions.sathandlers import ClassFactRegistry + >>> reg = ClassFactRegistry() + >>> @reg.register(Abs) + ... def f1(expr): + ... return Q.nonnegative(expr) + >>> @reg.register(Abs) + ... def f2(expr): + ... arg = expr.args[0] + ... return Equivalent(~Q.zero(arg), ~Q.zero(expr)) + + Calling the registry with expression returns the defined facts for the + expression. + + >>> from sympy.abc import x + >>> reg(Abs(x)) + {Q.nonnegative(Abs(x)), Equivalent(~Q.zero(x), ~Q.zero(Abs(x)))} + + Multiple facts can be registered at once by ``multiregister`` method. + + >>> reg2 = ClassFactRegistry() + >>> @reg2.multiregister(Abs) + ... def _(expr): + ... arg = expr.args[0] + ... return [Q.even(arg) >> Q.even(expr), Q.odd(arg) >> Q.odd(expr)] + >>> reg2(Abs(x)) + {Implies(Q.even(x), Q.even(Abs(x))), Implies(Q.odd(x), Q.odd(Abs(x)))} + + """ + def __init__(self): + self.singlefacts = defaultdict(frozenset) + self.multifacts = defaultdict(frozenset) + + def register(self, cls): + def _(func): + self.singlefacts[cls] |= {func} + return func + return _ + + def multiregister(self, *classes): + def _(func): + for cls in classes: + self.multifacts[cls] |= {func} + return func + return _ + + def __getitem__(self, key): + ret1 = self.singlefacts[key] + for k in self.singlefacts: + if issubclass(key, k): + ret1 |= self.singlefacts[k] + + ret2 = self.multifacts[key] + for k in self.multifacts: + if issubclass(key, k): + ret2 |= self.multifacts[k] + + return ret1, ret2 + + def __call__(self, expr): + ret = set() + + handlers1, handlers2 = self[type(expr)] + + ret.update(h(expr) for h in handlers1) + for h in handlers2: + ret.update(h(expr)) + return ret + +class_fact_registry = ClassFactRegistry() + + + +### Class fact registration ### + +x = Symbol('x') + +## Abs ## + +@class_fact_registry.multiregister(Abs) +def _(expr): + arg = expr.args[0] + return [Q.nonnegative(expr), + Equivalent(~Q.zero(arg), ~Q.zero(expr)), + Q.even(arg) >> Q.even(expr), + Q.odd(arg) >> Q.odd(expr), + Q.integer(arg) >> Q.integer(expr), + ] + + +### Add ## + +@class_fact_registry.multiregister(Add) +def _(expr): + return [allargs(x, Q.positive(x), expr) >> Q.positive(expr), + allargs(x, Q.negative(x), expr) >> Q.negative(expr), + allargs(x, Q.real(x), expr) >> Q.real(expr), + allargs(x, Q.rational(x), expr) >> Q.rational(expr), + allargs(x, Q.integer(x), expr) >> Q.integer(expr), + exactlyonearg(x, ~Q.integer(x), expr) >> ~Q.integer(expr), + ] + +@class_fact_registry.register(Add) +def _(expr): + allargs_real = allargs(x, Q.real(x), expr) + onearg_irrational = exactlyonearg(x, Q.irrational(x), expr) + return Implies(allargs_real, Implies(onearg_irrational, Q.irrational(expr))) + + +### Mul ### + +@class_fact_registry.multiregister(Mul) +def _(expr): + return [Equivalent(Q.zero(expr), anyarg(x, Q.zero(x), expr)), + allargs(x, Q.positive(x), expr) >> Q.positive(expr), + allargs(x, Q.real(x), expr) >> Q.real(expr), + allargs(x, Q.rational(x), expr) >> Q.rational(expr), + allargs(x, Q.integer(x), expr) >> Q.integer(expr), + exactlyonearg(x, ~Q.rational(x), expr) >> ~Q.integer(expr), + allargs(x, Q.commutative(x), expr) >> Q.commutative(expr), + ] + +@class_fact_registry.register(Mul) +def _(expr): + # Implicitly assumes Mul has more than one arg + # Would be allargs(x, Q.prime(x) | Q.composite(x)) except 1 is composite + # More advanced prime assumptions will require inequalities, as 1 provides + # a corner case. + allargs_prime = allargs(x, Q.prime(x), expr) + return Implies(allargs_prime, ~Q.prime(expr)) + +@class_fact_registry.register(Mul) +def _(expr): + # General Case: Odd number of imaginary args implies mul is imaginary(To be implemented) + allargs_imag_or_real = allargs(x, Q.imaginary(x) | Q.real(x), expr) + onearg_imaginary = exactlyonearg(x, Q.imaginary(x), expr) + return Implies(allargs_imag_or_real, Implies(onearg_imaginary, Q.imaginary(expr))) + +@class_fact_registry.register(Mul) +def _(expr): + allargs_real = allargs(x, Q.real(x), expr) + onearg_irrational = exactlyonearg(x, Q.irrational(x), expr) + return Implies(allargs_real, Implies(onearg_irrational, Q.irrational(expr))) + +@class_fact_registry.register(Mul) +def _(expr): + # Including the integer qualification means we don't need to add any facts + # for odd, since the assumptions already know that every integer is + # exactly one of even or odd. + allargs_integer = allargs(x, Q.integer(x), expr) + anyarg_even = anyarg(x, Q.even(x), expr) + return Implies(allargs_integer, Equivalent(anyarg_even, Q.even(expr))) + + +### MatMul ### + +@class_fact_registry.register(MatMul) +def _(expr): + allargs_square = allargs(x, Q.square(x), expr) + allargs_invertible = allargs(x, Q.invertible(x), expr) + return Implies(allargs_square, Equivalent(Q.invertible(expr), allargs_invertible)) + + +### Pow ### + +@class_fact_registry.multiregister(Pow) +def _(expr): + base, exp = expr.base, expr.exp + return [ + (Q.real(base) & Q.even(exp) & Q.nonnegative(exp)) >> Q.nonnegative(expr), + (Q.nonnegative(base) & Q.odd(exp) & Q.nonnegative(exp)) >> Q.nonnegative(expr), + (Q.nonpositive(base) & Q.odd(exp) & Q.nonnegative(exp)) >> Q.nonpositive(expr), + Equivalent(Q.zero(expr), Q.zero(base) & Q.positive(exp)) + ] + + +### Numbers ### + +_old_assump_getters = { + Q.positive: lambda o: o.is_positive, + Q.zero: lambda o: o.is_zero, + Q.negative: lambda o: o.is_negative, + Q.rational: lambda o: o.is_rational, + Q.irrational: lambda o: o.is_irrational, + Q.even: lambda o: o.is_even, + Q.odd: lambda o: o.is_odd, + Q.imaginary: lambda o: o.is_imaginary, + Q.prime: lambda o: o.is_prime, + Q.composite: lambda o: o.is_composite, +} + +@class_fact_registry.multiregister(Number, NumberSymbol, ImaginaryUnit) +def _(expr): + ret = [] + for p, getter in _old_assump_getters.items(): + pred = p(expr) + prop = getter(expr) + if prop is not None: + ret.append(Equivalent(pred, prop)) + return ret diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6817c8f20f53734f0119343b40505f38e29f206 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_assumptions_2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_assumptions_2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccd7efe8f9b194aa4ada8d56d02bd2ec7d6d4e18 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_assumptions_2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_context.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51c25ee40dd59865606165952bc00ddc8e64ac1b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_context.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_matrices.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_matrices.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e29082a5cace1e4499fa076ff6560d7dc7b944d6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_matrices.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_refine.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_refine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcd509be412f31190e48f44542a6af290a5cca3e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_refine.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_rel_queries.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_rel_queries.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..484fc6c3a37e98e6e599384c430cb105b31485bc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_rel_queries.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_satask.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_satask.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6964e8efe01e8262ce6eede001a522c4edf2839c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_satask.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_sathandlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_sathandlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec637a419da3611411b4502c0f80199cb01ea55c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_sathandlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_wrapper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41b2aaaa4e003c60d3035aaea3f4afbda940bcde Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/__pycache__/test_wrapper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_assumptions_2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_assumptions_2.py new file mode 100644 index 0000000000000000000000000000000000000000..493fe4a7ed70301754ad2cfe181c5acf30433768 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_assumptions_2.py @@ -0,0 +1,35 @@ +""" +rename this to test_assumptions.py when the old assumptions system is deleted +""" +from sympy.abc import x, y +from sympy.assumptions.assume import global_assumptions +from sympy.assumptions.ask import Q +from sympy.printing import pretty + + +def test_equal(): + """Test for equality""" + assert Q.positive(x) == Q.positive(x) + assert Q.positive(x) != ~Q.positive(x) + assert ~Q.positive(x) == ~Q.positive(x) + + +def test_pretty(): + assert pretty(Q.positive(x)) == "Q.positive(x)" + assert pretty( + {Q.positive, Q.integer}) == "{Q.integer, Q.positive}" + + +def test_global(): + """Test for global assumptions""" + global_assumptions.add(x > 0) + assert (x > 0) in global_assumptions + global_assumptions.remove(x > 0) + assert not (x > 0) in global_assumptions + # same with multiple of assumptions + global_assumptions.add(x > 0, y > 0) + assert (x > 0) in global_assumptions + assert (y > 0) in global_assumptions + global_assumptions.clear() + assert not (x > 0) in global_assumptions + assert not (y > 0) in global_assumptions diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_context.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_context.py new file mode 100644 index 0000000000000000000000000000000000000000..be162f1c69492218ff90ea69492925d7779567a4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_context.py @@ -0,0 +1,39 @@ +from sympy.assumptions import ask, Q +from sympy.assumptions.assume import assuming, global_assumptions +from sympy.abc import x, y + +def test_assuming(): + with assuming(Q.integer(x)): + assert ask(Q.integer(x)) + assert not ask(Q.integer(x)) + +def test_assuming_nested(): + assert not ask(Q.integer(x)) + assert not ask(Q.integer(y)) + with assuming(Q.integer(x)): + assert ask(Q.integer(x)) + assert not ask(Q.integer(y)) + with assuming(Q.integer(y)): + assert ask(Q.integer(x)) + assert ask(Q.integer(y)) + assert ask(Q.integer(x)) + assert not ask(Q.integer(y)) + assert not ask(Q.integer(x)) + assert not ask(Q.integer(y)) + +def test_finally(): + try: + with assuming(Q.integer(x)): + 1/0 + except ZeroDivisionError: + pass + assert not ask(Q.integer(x)) + +def test_remove_safe(): + global_assumptions.add(Q.integer(x)) + with assuming(): + assert ask(Q.integer(x)) + global_assumptions.remove(Q.integer(x)) + assert not ask(Q.integer(x)) + assert ask(Q.integer(x)) + global_assumptions.clear() # for the benefit of other tests diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_matrices.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfa990f080eebe4d6dd5bfdd733ce1a19adf329 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_matrices.py @@ -0,0 +1,283 @@ +from sympy.assumptions.ask import (Q, ask) +from sympy.core.symbol import Symbol +from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions import (MatrixSymbol, Identity, ZeroMatrix, + OneMatrix, Trace, MatrixSlice, Determinant, BlockMatrix, BlockDiagMatrix) +from sympy.matrices.expressions.factorizations import LofLU +from sympy.testing.pytest import XFAIL + +X = MatrixSymbol('X', 2, 2) +Y = MatrixSymbol('Y', 2, 3) +Z = MatrixSymbol('Z', 2, 2) +A1x1 = MatrixSymbol('A1x1', 1, 1) +B1x1 = MatrixSymbol('B1x1', 1, 1) +C0x0 = MatrixSymbol('C0x0', 0, 0) +V1 = MatrixSymbol('V1', 2, 1) +V2 = MatrixSymbol('V2', 2, 1) + +def test_square(): + assert ask(Q.square(X)) + assert not ask(Q.square(Y)) + assert ask(Q.square(Y*Y.T)) + +def test_invertible(): + assert ask(Q.invertible(X), Q.invertible(X)) + assert ask(Q.invertible(Y)) is False + assert ask(Q.invertible(X*Y), Q.invertible(X)) is False + assert ask(Q.invertible(X*Z), Q.invertible(X)) is None + assert ask(Q.invertible(X*Z), Q.invertible(X) & Q.invertible(Z)) is True + assert ask(Q.invertible(X.T)) is None + assert ask(Q.invertible(X.T), Q.invertible(X)) is True + assert ask(Q.invertible(X.I)) is True + assert ask(Q.invertible(Identity(3))) is True + assert ask(Q.invertible(ZeroMatrix(3, 3))) is False + assert ask(Q.invertible(OneMatrix(1, 1))) is True + assert ask(Q.invertible(OneMatrix(3, 3))) is False + assert ask(Q.invertible(X), Q.fullrank(X) & Q.square(X)) + +def test_singular(): + assert ask(Q.singular(X)) is None + assert ask(Q.singular(X), Q.invertible(X)) is False + assert ask(Q.singular(X), ~Q.invertible(X)) is True + +@XFAIL +def test_invertible_fullrank(): + assert ask(Q.invertible(X), Q.fullrank(X)) is True + + +def test_invertible_BlockMatrix(): + assert ask(Q.invertible(BlockMatrix([Identity(3)]))) == True + assert ask(Q.invertible(BlockMatrix([ZeroMatrix(3, 3)]))) == False + + X = Matrix([[1, 2, 3], [3, 5, 4]]) + Y = Matrix([[4, 2, 7], [2, 3, 5]]) + # non-invertible A block + assert ask(Q.invertible(BlockMatrix([ + [Matrix.ones(3, 3), Y.T], + [X, Matrix.eye(2)], + ]))) == True + # non-invertible B block + assert ask(Q.invertible(BlockMatrix([ + [Y.T, Matrix.ones(3, 3)], + [Matrix.eye(2), X], + ]))) == True + # non-invertible C block + assert ask(Q.invertible(BlockMatrix([ + [X, Matrix.eye(2)], + [Matrix.ones(3, 3), Y.T], + ]))) == True + # non-invertible D block + assert ask(Q.invertible(BlockMatrix([ + [Matrix.eye(2), X], + [Y.T, Matrix.ones(3, 3)], + ]))) == True + + +def test_invertible_BlockDiagMatrix(): + assert ask(Q.invertible(BlockDiagMatrix(Identity(3), Identity(5)))) == True + assert ask(Q.invertible(BlockDiagMatrix(ZeroMatrix(3, 3), Identity(5)))) == False + assert ask(Q.invertible(BlockDiagMatrix(Identity(3), OneMatrix(5, 5)))) == False + + +def test_symmetric(): + assert ask(Q.symmetric(X), Q.symmetric(X)) + assert ask(Q.symmetric(X*Z), Q.symmetric(X)) is None + assert ask(Q.symmetric(X*Z), Q.symmetric(X) & Q.symmetric(Z)) is True + assert ask(Q.symmetric(X + Z), Q.symmetric(X) & Q.symmetric(Z)) is True + assert ask(Q.symmetric(Y)) is False + assert ask(Q.symmetric(Y*Y.T)) is True + assert ask(Q.symmetric(Y.T*X*Y)) is None + assert ask(Q.symmetric(Y.T*X*Y), Q.symmetric(X)) is True + assert ask(Q.symmetric(X**10), Q.symmetric(X)) is True + assert ask(Q.symmetric(A1x1)) is True + assert ask(Q.symmetric(A1x1 + B1x1)) is True + assert ask(Q.symmetric(A1x1 * B1x1)) is True + assert ask(Q.symmetric(V1.T*V1)) is True + assert ask(Q.symmetric(V1.T*(V1 + V2))) is True + assert ask(Q.symmetric(V1.T*(V1 + V2) + A1x1)) is True + assert ask(Q.symmetric(MatrixSlice(Y, (0, 1), (1, 2)))) is True + assert ask(Q.symmetric(Identity(3))) is True + assert ask(Q.symmetric(ZeroMatrix(3, 3))) is True + assert ask(Q.symmetric(OneMatrix(3, 3))) is True + +def _test_orthogonal_unitary(predicate): + assert ask(predicate(X), predicate(X)) + assert ask(predicate(X.T), predicate(X)) is True + assert ask(predicate(X.I), predicate(X)) is True + assert ask(predicate(X**2), predicate(X)) + assert ask(predicate(Y)) is False + assert ask(predicate(X)) is None + assert ask(predicate(X), ~Q.invertible(X)) is False + assert ask(predicate(X*Z*X), predicate(X) & predicate(Z)) is True + assert ask(predicate(Identity(3))) is True + assert ask(predicate(ZeroMatrix(3, 3))) is False + assert ask(Q.invertible(X), predicate(X)) + assert not ask(predicate(X + Z), predicate(X) & predicate(Z)) + +def test_orthogonal(): + _test_orthogonal_unitary(Q.orthogonal) + +def test_unitary(): + _test_orthogonal_unitary(Q.unitary) + assert ask(Q.unitary(X), Q.orthogonal(X)) + +def test_fullrank(): + assert ask(Q.fullrank(X), Q.fullrank(X)) + assert ask(Q.fullrank(X**2), Q.fullrank(X)) + assert ask(Q.fullrank(X.T), Q.fullrank(X)) is True + assert ask(Q.fullrank(X)) is None + assert ask(Q.fullrank(Y)) is None + assert ask(Q.fullrank(X*Z), Q.fullrank(X) & Q.fullrank(Z)) is True + assert ask(Q.fullrank(Identity(3))) is True + assert ask(Q.fullrank(ZeroMatrix(3, 3))) is False + assert ask(Q.fullrank(OneMatrix(1, 1))) is True + assert ask(Q.fullrank(OneMatrix(3, 3))) is False + assert ask(Q.invertible(X), ~Q.fullrank(X)) == False + + +def test_positive_definite(): + assert ask(Q.positive_definite(X), Q.positive_definite(X)) + assert ask(Q.positive_definite(X.T), Q.positive_definite(X)) is True + assert ask(Q.positive_definite(X.I), Q.positive_definite(X)) is True + assert ask(Q.positive_definite(Y)) is False + assert ask(Q.positive_definite(X)) is None + assert ask(Q.positive_definite(X**3), Q.positive_definite(X)) + assert ask(Q.positive_definite(X*Z*X), + Q.positive_definite(X) & Q.positive_definite(Z)) is True + assert ask(Q.positive_definite(X), Q.orthogonal(X)) + assert ask(Q.positive_definite(Y.T*X*Y), + Q.positive_definite(X) & Q.fullrank(Y)) is True + assert not ask(Q.positive_definite(Y.T*X*Y), Q.positive_definite(X)) + assert ask(Q.positive_definite(Identity(3))) is True + assert ask(Q.positive_definite(ZeroMatrix(3, 3))) is False + assert ask(Q.positive_definite(OneMatrix(1, 1))) is True + assert ask(Q.positive_definite(OneMatrix(3, 3))) is False + assert ask(Q.positive_definite(X + Z), Q.positive_definite(X) & + Q.positive_definite(Z)) is True + assert not ask(Q.positive_definite(-X), Q.positive_definite(X)) + assert ask(Q.positive(X[1, 1]), Q.positive_definite(X)) + +def test_triangular(): + assert ask(Q.upper_triangular(X + Z.T + Identity(2)), Q.upper_triangular(X) & + Q.lower_triangular(Z)) is True + assert ask(Q.upper_triangular(X*Z.T), Q.upper_triangular(X) & + Q.lower_triangular(Z)) is True + assert ask(Q.lower_triangular(Identity(3))) is True + assert ask(Q.lower_triangular(ZeroMatrix(3, 3))) is True + assert ask(Q.upper_triangular(ZeroMatrix(3, 3))) is True + assert ask(Q.lower_triangular(OneMatrix(1, 1))) is True + assert ask(Q.upper_triangular(OneMatrix(1, 1))) is True + assert ask(Q.lower_triangular(OneMatrix(3, 3))) is False + assert ask(Q.upper_triangular(OneMatrix(3, 3))) is False + assert ask(Q.triangular(X), Q.unit_triangular(X)) + assert ask(Q.upper_triangular(X**3), Q.upper_triangular(X)) + assert ask(Q.lower_triangular(X**3), Q.lower_triangular(X)) + + +def test_diagonal(): + assert ask(Q.diagonal(X + Z.T + Identity(2)), Q.diagonal(X) & + Q.diagonal(Z)) is True + assert ask(Q.diagonal(ZeroMatrix(3, 3))) + assert ask(Q.diagonal(OneMatrix(1, 1))) is True + assert ask(Q.diagonal(OneMatrix(3, 3))) is False + assert ask(Q.lower_triangular(X) & Q.upper_triangular(X), Q.diagonal(X)) + assert ask(Q.diagonal(X), Q.lower_triangular(X) & Q.upper_triangular(X)) + assert ask(Q.symmetric(X), Q.diagonal(X)) + assert ask(Q.triangular(X), Q.diagonal(X)) + assert ask(Q.diagonal(C0x0)) + assert ask(Q.diagonal(A1x1)) + assert ask(Q.diagonal(A1x1 + B1x1)) + assert ask(Q.diagonal(A1x1*B1x1)) + assert ask(Q.diagonal(V1.T*V2)) + assert ask(Q.diagonal(V1.T*(X + Z)*V1)) + assert ask(Q.diagonal(MatrixSlice(Y, (0, 1), (1, 2)))) is True + assert ask(Q.diagonal(V1.T*(V1 + V2))) is True + assert ask(Q.diagonal(X**3), Q.diagonal(X)) + assert ask(Q.diagonal(Identity(3))) + assert ask(Q.diagonal(DiagMatrix(V1))) + assert ask(Q.diagonal(DiagonalMatrix(X))) + + +def test_non_atoms(): + assert ask(Q.real(Trace(X)), Q.positive(Trace(X))) + +@XFAIL +def test_non_trivial_implies(): + X = MatrixSymbol('X', 3, 3) + Y = MatrixSymbol('Y', 3, 3) + assert ask(Q.lower_triangular(X+Y), Q.lower_triangular(X) & + Q.lower_triangular(Y)) is True + assert ask(Q.triangular(X), Q.lower_triangular(X)) is True + assert ask(Q.triangular(X+Y), Q.lower_triangular(X) & + Q.lower_triangular(Y)) is True + +def test_MatrixSlice(): + X = MatrixSymbol('X', 4, 4) + B = MatrixSlice(X, (1, 3), (1, 3)) + C = MatrixSlice(X, (0, 3), (1, 3)) + assert ask(Q.symmetric(B), Q.symmetric(X)) + assert ask(Q.invertible(B), Q.invertible(X)) + assert ask(Q.diagonal(B), Q.diagonal(X)) + assert ask(Q.orthogonal(B), Q.orthogonal(X)) + assert ask(Q.upper_triangular(B), Q.upper_triangular(X)) + + assert not ask(Q.symmetric(C), Q.symmetric(X)) + assert not ask(Q.invertible(C), Q.invertible(X)) + assert not ask(Q.diagonal(C), Q.diagonal(X)) + assert not ask(Q.orthogonal(C), Q.orthogonal(X)) + assert not ask(Q.upper_triangular(C), Q.upper_triangular(X)) + +def test_det_trace_positive(): + X = MatrixSymbol('X', 4, 4) + assert ask(Q.positive(Trace(X)), Q.positive_definite(X)) + assert ask(Q.positive(Determinant(X)), Q.positive_definite(X)) + +def test_field_assumptions(): + X = MatrixSymbol('X', 4, 4) + Y = MatrixSymbol('Y', 4, 4) + assert ask(Q.real_elements(X), Q.real_elements(X)) + assert not ask(Q.integer_elements(X), Q.real_elements(X)) + assert ask(Q.complex_elements(X), Q.real_elements(X)) + assert ask(Q.complex_elements(X**2), Q.real_elements(X)) + assert ask(Q.real_elements(X**2), Q.integer_elements(X)) + assert ask(Q.real_elements(X+Y), Q.real_elements(X)) is None + assert ask(Q.real_elements(X+Y), Q.real_elements(X) & Q.real_elements(Y)) + from sympy.matrices.expressions.hadamard import HadamardProduct + assert ask(Q.real_elements(HadamardProduct(X, Y)), + Q.real_elements(X) & Q.real_elements(Y)) + assert ask(Q.complex_elements(X+Y), Q.real_elements(X) & Q.complex_elements(Y)) + + assert ask(Q.real_elements(X.T), Q.real_elements(X)) + assert ask(Q.real_elements(X.I), Q.real_elements(X) & Q.invertible(X)) + assert ask(Q.real_elements(Trace(X)), Q.real_elements(X)) + assert ask(Q.integer_elements(Determinant(X)), Q.integer_elements(X)) + assert not ask(Q.integer_elements(X.I), Q.integer_elements(X)) + alpha = Symbol('alpha') + assert ask(Q.real_elements(alpha*X), Q.real_elements(X) & Q.real(alpha)) + assert ask(Q.real_elements(LofLU(X)), Q.real_elements(X)) + e = Symbol('e', integer=True, negative=True) + assert ask(Q.real_elements(X**e), Q.real_elements(X) & Q.invertible(X)) + assert ask(Q.real_elements(X**e), Q.real_elements(X)) is None + +def test_matrix_element_sets(): + X = MatrixSymbol('X', 4, 4) + assert ask(Q.real(X[1, 2]), Q.real_elements(X)) + assert ask(Q.integer(X[1, 2]), Q.integer_elements(X)) + assert ask(Q.complex(X[1, 2]), Q.complex_elements(X)) + assert ask(Q.integer_elements(Identity(3))) + assert ask(Q.integer_elements(ZeroMatrix(3, 3))) + assert ask(Q.integer_elements(OneMatrix(3, 3))) + from sympy.matrices.expressions.fourier import DFT + assert ask(Q.complex_elements(DFT(3))) + + +def test_matrix_element_sets_slices_blocks(): + X = MatrixSymbol('X', 4, 4) + assert ask(Q.integer_elements(X[:, 3]), Q.integer_elements(X)) + assert ask(Q.integer_elements(BlockMatrix([[X], [X]])), + Q.integer_elements(X)) + +def test_matrix_element_sets_determinant_trace(): + assert ask(Q.integer(Determinant(X)), Q.integer_elements(X)) + assert ask(Q.integer(Trace(X)), Q.integer_elements(X)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_query.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_query.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ae1f1e482e5c9d19e3dcce3f37ad46b6821817 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_query.py @@ -0,0 +1,2541 @@ +from sympy.abc import t, w, x, y, z, n, k, m, p, i +from sympy.assumptions import (ask, AssumptionsContext, Q, register_handler, + remove_handler) +from sympy.assumptions.assume import assuming, global_assumptions, Predicate +from sympy.assumptions.cnf import CNF, Literal +from sympy.assumptions.facts import (single_fact_lookup, + get_known_facts, generate_known_facts_dict, get_known_facts_keys) +from sympy.assumptions.handlers import AskHandler +from sympy.assumptions.ask_generated import (get_all_known_facts, + get_known_facts_dict) +from sympy.core.add import Add +from sympy.core.numbers import (I, Integer, Rational, oo, zoo, pi) +from sympy.core.singleton import S +from sympy.core.power import Pow +from sympy.core.symbol import Str, symbols, Symbol +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import (Abs, im, re, sign) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import ( + acos, acot, asin, atan, cos, cot, sin, tan) +from sympy.logic.boolalg import Equivalent, Implies, Xor, And, to_cnf +from sympy.matrices import Matrix, SparseMatrix +from sympy.testing.pytest import (XFAIL, slow, raises, warns_deprecated_sympy, + _both_exp_pow) +import math + + +def test_int_1(): + z = 1 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is True + assert ask(Q.rational(z)) is True + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is False + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is True + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_int_11(): + z = 11 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is True + assert ask(Q.rational(z)) is True + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is False + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is True + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is True + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_int_12(): + z = 12 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is True + assert ask(Q.rational(z)) is True + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is False + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is True + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is True + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_float_1(): + z = 1.0 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is None + assert ask(Q.rational(z)) is None + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is None + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is None + assert ask(Q.odd(z)) is None + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is None + assert ask(Q.composite(z)) is None + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + z = 7.2123 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is None + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is None + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + # test for issue #12168 + assert ask(Q.rational(math.pi)) is None + + +def test_zero_0(): + z = Integer(0) + assert ask(Q.nonzero(z)) is False + assert ask(Q.zero(z)) is True + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is True + assert ask(Q.rational(z)) is True + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is False + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is True + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is True + + +def test_negativeone(): + z = Integer(-1) + assert ask(Q.nonzero(z)) is True + assert ask(Q.zero(z)) is False + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is True + assert ask(Q.rational(z)) is True + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is False + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is False + assert ask(Q.negative(z)) is True + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is True + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_infinity(): + assert ask(Q.commutative(oo)) is True + assert ask(Q.integer(oo)) is False + assert ask(Q.rational(oo)) is False + assert ask(Q.algebraic(oo)) is False + assert ask(Q.real(oo)) is False + assert ask(Q.extended_real(oo)) is True + assert ask(Q.complex(oo)) is False + assert ask(Q.irrational(oo)) is False + assert ask(Q.imaginary(oo)) is False + assert ask(Q.positive(oo)) is False + assert ask(Q.extended_positive(oo)) is True + assert ask(Q.negative(oo)) is False + assert ask(Q.even(oo)) is False + assert ask(Q.odd(oo)) is False + assert ask(Q.finite(oo)) is False + assert ask(Q.infinite(oo)) is True + assert ask(Q.prime(oo)) is False + assert ask(Q.composite(oo)) is False + assert ask(Q.hermitian(oo)) is False + assert ask(Q.antihermitian(oo)) is False + assert ask(Q.positive_infinite(oo)) is True + assert ask(Q.negative_infinite(oo)) is False + + +def test_neg_infinity(): + mm = S.NegativeInfinity + assert ask(Q.commutative(mm)) is True + assert ask(Q.integer(mm)) is False + assert ask(Q.rational(mm)) is False + assert ask(Q.algebraic(mm)) is False + assert ask(Q.real(mm)) is False + assert ask(Q.extended_real(mm)) is True + assert ask(Q.complex(mm)) is False + assert ask(Q.irrational(mm)) is False + assert ask(Q.imaginary(mm)) is False + assert ask(Q.positive(mm)) is False + assert ask(Q.negative(mm)) is False + assert ask(Q.extended_negative(mm)) is True + assert ask(Q.even(mm)) is False + assert ask(Q.odd(mm)) is False + assert ask(Q.finite(mm)) is False + assert ask(Q.infinite(oo)) is True + assert ask(Q.prime(mm)) is False + assert ask(Q.composite(mm)) is False + assert ask(Q.hermitian(mm)) is False + assert ask(Q.antihermitian(mm)) is False + assert ask(Q.positive_infinite(-oo)) is False + assert ask(Q.negative_infinite(-oo)) is True + + +def test_complex_infinity(): + assert ask(Q.commutative(zoo)) is True + assert ask(Q.integer(zoo)) is False + assert ask(Q.rational(zoo)) is False + assert ask(Q.algebraic(zoo)) is False + assert ask(Q.real(zoo)) is False + assert ask(Q.extended_real(zoo)) is False + assert ask(Q.complex(zoo)) is False + assert ask(Q.irrational(zoo)) is False + assert ask(Q.imaginary(zoo)) is False + assert ask(Q.positive(zoo)) is False + assert ask(Q.negative(zoo)) is False + assert ask(Q.zero(zoo)) is False + assert ask(Q.nonzero(zoo)) is False + assert ask(Q.even(zoo)) is False + assert ask(Q.odd(zoo)) is False + assert ask(Q.finite(zoo)) is False + assert ask(Q.infinite(zoo)) is True + assert ask(Q.prime(zoo)) is False + assert ask(Q.composite(zoo)) is False + assert ask(Q.hermitian(zoo)) is False + assert ask(Q.antihermitian(zoo)) is False + assert ask(Q.positive_infinite(zoo)) is False + assert ask(Q.negative_infinite(zoo)) is False + + +def test_nan(): + nan = S.NaN + assert ask(Q.commutative(nan)) is True + assert ask(Q.integer(nan)) is None + assert ask(Q.rational(nan)) is None + assert ask(Q.algebraic(nan)) is None + assert ask(Q.real(nan)) is None + assert ask(Q.extended_real(nan)) is None + assert ask(Q.complex(nan)) is None + assert ask(Q.irrational(nan)) is None + assert ask(Q.imaginary(nan)) is None + assert ask(Q.positive(nan)) is None + assert ask(Q.nonzero(nan)) is None + assert ask(Q.zero(nan)) is None + assert ask(Q.even(nan)) is None + assert ask(Q.odd(nan)) is None + assert ask(Q.finite(nan)) is None + assert ask(Q.infinite(nan)) is None + assert ask(Q.prime(nan)) is None + assert ask(Q.composite(nan)) is None + assert ask(Q.hermitian(nan)) is None + assert ask(Q.antihermitian(nan)) is None + + +def test_Rational_number(): + r = Rational(3, 4) + assert ask(Q.commutative(r)) is True + assert ask(Q.integer(r)) is False + assert ask(Q.rational(r)) is True + assert ask(Q.real(r)) is True + assert ask(Q.complex(r)) is True + assert ask(Q.irrational(r)) is False + assert ask(Q.imaginary(r)) is False + assert ask(Q.positive(r)) is True + assert ask(Q.negative(r)) is False + assert ask(Q.even(r)) is False + assert ask(Q.odd(r)) is False + assert ask(Q.finite(r)) is True + assert ask(Q.prime(r)) is False + assert ask(Q.composite(r)) is False + assert ask(Q.hermitian(r)) is True + assert ask(Q.antihermitian(r)) is False + + r = Rational(1, 4) + assert ask(Q.positive(r)) is True + assert ask(Q.negative(r)) is False + + r = Rational(5, 4) + assert ask(Q.negative(r)) is False + assert ask(Q.positive(r)) is True + + r = Rational(5, 3) + assert ask(Q.positive(r)) is True + assert ask(Q.negative(r)) is False + + r = Rational(-3, 4) + assert ask(Q.positive(r)) is False + assert ask(Q.negative(r)) is True + + r = Rational(-1, 4) + assert ask(Q.positive(r)) is False + assert ask(Q.negative(r)) is True + + r = Rational(-5, 4) + assert ask(Q.negative(r)) is True + assert ask(Q.positive(r)) is False + + r = Rational(-5, 3) + assert ask(Q.positive(r)) is False + assert ask(Q.negative(r)) is True + + +def test_sqrt_2(): + z = sqrt(2) + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_pi(): + z = S.Pi + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is False + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + z = S.Pi + 1 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is False + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + z = 2*S.Pi + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is False + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + z = S.Pi ** 2 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is False + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + z = (1 + S.Pi) ** 2 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is None + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_E(): + z = S.Exp1 + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is False + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_GoldenRatio(): + z = S.GoldenRatio + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is True + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_TribonacciConstant(): + z = S.TribonacciConstant + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is True + assert ask(Q.real(z)) is True + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is True + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is True + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is True + assert ask(Q.antihermitian(z)) is False + + +def test_I(): + z = I + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is True + assert ask(Q.real(z)) is False + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is False + assert ask(Q.imaginary(z)) is True + assert ask(Q.positive(z)) is False + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is False + assert ask(Q.antihermitian(z)) is True + + z = 1 + I + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is True + assert ask(Q.real(z)) is False + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is False + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is False + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is False + assert ask(Q.antihermitian(z)) is False + + z = I*(1 + I) + assert ask(Q.commutative(z)) is True + assert ask(Q.integer(z)) is False + assert ask(Q.rational(z)) is False + assert ask(Q.algebraic(z)) is True + assert ask(Q.real(z)) is False + assert ask(Q.complex(z)) is True + assert ask(Q.irrational(z)) is False + assert ask(Q.imaginary(z)) is False + assert ask(Q.positive(z)) is False + assert ask(Q.negative(z)) is False + assert ask(Q.even(z)) is False + assert ask(Q.odd(z)) is False + assert ask(Q.finite(z)) is True + assert ask(Q.prime(z)) is False + assert ask(Q.composite(z)) is False + assert ask(Q.hermitian(z)) is False + assert ask(Q.antihermitian(z)) is False + + z = I**(I) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is True + + z = (-I)**(I) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is True + + z = (3*I)**(I) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is False + + z = (1)**(I) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is True + + z = (-1)**(I) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is True + + z = (1+I)**(I) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is False + + z = (I)**(I+3) + assert ask(Q.imaginary(z)) is True + assert ask(Q.real(z)) is False + + z = (I)**(I+2) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is True + + z = (I)**(2) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is True + + z = (I)**(3) + assert ask(Q.imaginary(z)) is True + assert ask(Q.real(z)) is False + + z = (3)**(I) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is False + + z = (I)**(0) + assert ask(Q.imaginary(z)) is False + assert ask(Q.real(z)) is True + +def test_bounded(): + x, y, z = symbols('x,y,z') + a = x + y + x, y = a.args + assert ask(Q.finite(a), Q.positive_infinite(y)) is None + assert ask(Q.finite(x)) is None + assert ask(Q.finite(x), Q.finite(x)) is True + assert ask(Q.finite(x), Q.finite(y)) is None + assert ask(Q.finite(x), Q.complex(x)) is True + assert ask(Q.finite(x), Q.extended_real(x)) is None + + assert ask(Q.finite(x + 1)) is None + assert ask(Q.finite(x + 1), Q.finite(x)) is True + a = x + y + x, y = a.args + # B + B + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y)) is True + assert ask(Q.finite(a), Q.positive(x) & Q.finite(y)) is True + assert ask(Q.finite(a), Q.finite(x) & Q.positive(y)) is True + assert ask(Q.finite(a), Q.positive(x) & Q.positive(y)) is True + assert ask(Q.finite(a), Q.positive(x) & Q.finite(y) + & ~Q.positive(y)) is True + assert ask(Q.finite(a), Q.finite(x) & ~Q.positive(x) + & Q.positive(y)) is True + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) & ~Q.positive(x) + & ~Q.positive(y)) is True + # B + U + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y)) is False + assert ask(Q.finite(a), Q.positive(x) & ~Q.finite(y)) is False + assert ask(Q.finite(a), Q.finite(x) + & Q.positive_infinite(y)) is False + assert ask(Q.finite(a), Q.positive(x) + & Q.positive_infinite(y)) is False + assert ask(Q.finite(a), Q.positive(x) & ~Q.finite(y) + & ~Q.positive(y)) is False + assert ask(Q.finite(a), Q.finite(x) & ~Q.positive(x) + & Q.positive_infinite(y)) is False + assert ask(Q.finite(a), Q.finite(x) & ~Q.positive(x) & ~Q.finite(y) + & ~Q.positive(y)) is False + # B + ? + assert ask(Q.finite(a), Q.finite(x)) is None + assert ask(Q.finite(a), Q.positive(x)) is None + assert ask(Q.finite(a), Q.finite(x) + & Q.extended_positive(y)) is None + assert ask(Q.finite(a), Q.positive(x) + & Q.extended_positive(y)) is None + assert ask(Q.finite(a), Q.positive(x) & ~Q.positive(y)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.positive(x) + & Q.extended_positive(y)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.positive(x) + & ~Q.positive(y)) is None + # U + U + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & ~Q.finite(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) + & Q.positive_infinite(y)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.positive_infinite(y)) is False + assert ask(Q.finite(a), Q.positive_infinite(x) & ~Q.finite(y) + & ~Q.extended_positive(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.extended_positive(x) + & Q.positive_infinite(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y) + & ~Q.extended_positive(x) & ~Q.extended_positive(y)) is False + # U + ? + assert ask(Q.finite(a), ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.extended_positive(x) + & ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.positive_infinite(y)) is None + assert ask(Q.finite(a), Q.extended_positive(x) + & Q.positive_infinite(y)) is False + assert ask(Q.finite(a), Q.extended_positive(x) + & ~Q.finite(y) & ~Q.extended_positive(y)) is None + assert ask(Q.finite(a), ~Q.extended_positive(x) + & Q.positive_infinite(y)) is None + assert ask(Q.finite(a), ~Q.extended_positive(x) & ~Q.finite(y) + & ~Q.extended_positive(y)) is False + # ? + ? + assert ask(Q.finite(a)) is None + assert ask(Q.finite(a), Q.extended_positive(x)) is None + assert ask(Q.finite(a), Q.extended_positive(y)) is None + assert ask(Q.finite(a), Q.extended_positive(x) + & Q.extended_positive(y)) is None + assert ask(Q.finite(a), Q.extended_positive(x) + & ~Q.extended_positive(y)) is None + assert ask(Q.finite(a), ~Q.extended_positive(x) + & Q.extended_positive(y)) is None + assert ask(Q.finite(a), ~Q.extended_positive(x) + & ~Q.extended_positive(y)) is None + + x, y, z = symbols('x,y,z') + a = x + y + z + x, y, z = a.args + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y) + & Q.negative(z)) is True + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y) + & Q.finite(z)) is True + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y) + & Q.positive(z)) is True + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y) + & ~Q.finite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.negative(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.finite(y) + & Q.finite(z)) is True + assert ask(Q.finite(a), Q.negative(x) & Q.finite(y) + & Q.positive(z)) is True + assert ask(Q.finite(a), Q.negative(x) & Q.finite(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.finite(y) + & ~Q.finite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.finite(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.finite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.finite(y)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.finite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.positive(y) + & Q.positive(z)) is True + assert ask(Q.finite(a), Q.negative(x) & Q.positive(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.positive(y) + & ~Q.finite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.positive(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.positive(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.extended_positive(y) + & Q.finite(y)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.positive(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.negative_infinite(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.negative_infinite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.negative_infinite(y) + & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.negative_infinite(y) + & Q.extended_negative(z)) is False + assert ask(Q.finite(a), Q.negative(x) + & Q.negative_infinite(y)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.negative_infinite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative(x) & ~Q.finite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.negative(x) & ~Q.finite(y) + & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.negative(x) & ~Q.finite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.negative(x) & ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.negative(x) & ~Q.finite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.positive_infinite(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.positive_infinite(y) + & Q.negative_infinite(z)) is None + assert ask(Q.finite(a), Q.negative(x) & + Q.positive_infinite(y)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.positive_infinite(y) + & Q.extended_positive(z)) is False + assert ask(Q.finite(a), Q.negative(x) & Q.extended_negative(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.negative(x) + & Q.extended_negative(y)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.extended_negative(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative(x)) is None + assert ask(Q.finite(a), Q.negative(x) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative(x) & Q.extended_positive(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & Q.finite(z)) is True + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & Q.positive(z)) is True + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & ~Q.finite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.positive(y) + & Q.positive(z)) is True + assert ask(Q.finite(a), Q.finite(x) & Q.positive(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.positive(y) + & ~Q.finite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.positive(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.positive(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.positive(y)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.positive(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.negative_infinite(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.negative_infinite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.negative_infinite(y) + & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.negative_infinite(y) + & Q.extended_negative(z)) is False + assert ask(Q.finite(a), Q.finite(x) + & Q.negative_infinite(y)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.negative_infinite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y) + & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.positive_infinite(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.positive_infinite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.finite(x) + & Q.positive_infinite(y)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.positive_infinite(y) + & Q.extended_positive(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.extended_negative(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.finite(x) + & Q.extended_negative(y)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.extended_negative(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.finite(x)) is None + assert ask(Q.finite(a), Q.finite(x) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.extended_positive(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.positive(y) + & Q.positive(z)) is True + assert ask(Q.finite(a), Q.positive(x) & Q.positive(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.positive(x) & Q.positive(y) + & ~Q.finite(z)) is False + assert ask(Q.finite(a), Q.positive(x) & Q.positive(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.positive(x) & Q.positive(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.positive(y)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.positive(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.negative_infinite(y) + & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.positive(x) & Q.negative_infinite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.negative_infinite(y) + & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.negative_infinite(y) + & Q.extended_negative(z)) is False + assert ask(Q.finite(a), Q.positive(x) + & Q.negative_infinite(y)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.negative_infinite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive(x) & ~Q.finite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.positive(x) & ~Q.finite(y) + & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.positive(x) & ~Q.finite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.positive(x) & ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.positive(x) & ~Q.finite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.positive_infinite(y) + & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.positive(x) & Q.positive_infinite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.positive(x) + & Q.positive_infinite(y)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.positive_infinite(y) + & Q.extended_positive(z)) is False + assert ask(Q.finite(a), Q.positive(x) & Q.extended_negative(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.positive(x) + & Q.extended_negative(y)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.extended_negative(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive(x)) is None + assert ask(Q.finite(a), Q.positive(x) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive(x) & Q.extended_positive(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.negative_infinite(y) & Q.negative_infinite(z)) is False + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.negative_infinite(y) & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.negative_infinite(y)& Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.negative_infinite(y) & Q.extended_negative(z)) is False + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.negative_infinite(y)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.negative_infinite(y) & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & ~Q.finite(y) & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & ~Q.finite(y) & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & ~Q.finite(y) & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & ~Q.finite(y) & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.positive_infinite(y) & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.positive_infinite(y) & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.positive_infinite(y)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.positive_infinite(y) & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.extended_negative(y) & Q.extended_negative(z)) is False + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.extended_negative(y)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.extended_negative(y) & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.negative_infinite(x) + & Q.extended_positive(y) & Q.extended_positive(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.positive_infinite(z) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.positive_infinite(y) + & Q.positive_infinite(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.positive_infinite(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) + & Q.positive_infinite(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.positive_infinite(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.extended_negative(y) + & Q.extended_negative(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) + & Q.extended_negative(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.extended_negative(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), ~Q.finite(x)) is None + assert ask(Q.finite(a), ~Q.finite(x) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.extended_positive(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.positive_infinite(y) & Q.positive_infinite(z)) is False + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.positive_infinite(y) & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.positive_infinite(y)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.positive_infinite(y) & Q.extended_positive(z)) is False + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.extended_negative(y) & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.extended_negative(y)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.extended_negative(y) & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive_infinite(x)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.positive_infinite(x) + & Q.extended_positive(y) & Q.extended_positive(z)) is False + assert ask(Q.finite(a), Q.extended_negative(x) + & Q.extended_negative(y) & Q.extended_negative(z)) is None + assert ask(Q.finite(a), Q.extended_negative(x) + & Q.extended_negative(y)) is None + assert ask(Q.finite(a), Q.extended_negative(x) + & Q.extended_negative(y) & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.extended_negative(x)) is None + assert ask(Q.finite(a), Q.extended_negative(x) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.extended_negative(x) + & Q.extended_positive(y) & Q.extended_positive(z)) is None + assert ask(Q.finite(a)) is None + assert ask(Q.finite(a), Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.extended_positive(y) + & Q.extended_positive(z)) is None + assert ask(Q.finite(a), Q.extended_positive(x) + & Q.extended_positive(y) & Q.extended_positive(z)) is None + + assert ask(Q.finite(2*x)) is None + assert ask(Q.finite(2*x), Q.finite(x)) is True + + x, y, z = symbols('x,y,z') + a = x*y + x, y = a.args + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y)) is True + assert ask(Q.finite(a), Q.finite(x) & ~Q.zero(x) & ~Q.finite(y)) is False + assert ask(Q.finite(a), Q.finite(x)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.finite(y) &~Q.zero(y)) is False + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y)) is False + assert ask(Q.finite(a), ~Q.finite(x)) is None + assert ask(Q.finite(a), Q.finite(y)) is None + assert ask(Q.finite(a), ~Q.finite(y)) is None + assert ask(Q.finite(a)) is None + a = x*y*z + x, y, z = a.args + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & Q.finite(z)) is True + assert ask(Q.finite(a), Q.finite(x) & ~Q.zero(x) & Q.finite(y) + & ~Q.zero(y) & ~Q.finite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.zero(x) & ~Q.finite(y) + & Q.finite(z) & ~Q.zero(z)) is False + assert ask(Q.finite(a), Q.finite(x) & ~Q.zero(x) & ~Q.finite(y) + & ~Q.finite(z)) is False + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.finite(x) & Q.finite(z)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.finite(x)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.finite(y) & ~Q.zero(y) + & Q.finite(z) & ~Q.zero(z)) is False + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.zero(x) & Q.finite(y) + & ~Q.zero(y) & ~Q.finite(z)) is False + assert ask(Q.finite(a), ~Q.finite(x) & Q.finite(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y) + & Q.finite(z) & ~Q.zero(z)) is False + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y) + & ~Q.finite(z)) is False + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(x)) is None + assert ask(Q.finite(a), Q.finite(y) & Q.finite(z)) is None + assert ask(Q.finite(a), Q.finite(y) & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.finite(y)) is None + assert ask(Q.finite(a), ~Q.finite(y) & Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(y) & ~Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(y)) is None + assert ask(Q.finite(a), Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(z) & Q.extended_nonzero(x) + & Q.extended_nonzero(y) & Q.extended_nonzero(z)) is None + assert ask(Q.finite(a), Q.extended_nonzero(x) & ~Q.finite(y) + & Q.extended_nonzero(y) & ~Q.finite(z) + & Q.extended_nonzero(z)) is False + + x, y, z = symbols('x,y,z') + assert ask(Q.finite(x**2)) is None + assert ask(Q.finite(2**x)) is None + assert ask(Q.finite(2**x), Q.finite(x)) is True + assert ask(Q.finite(x**x)) is None + assert ask(Q.finite(S.Half ** x)) is None + assert ask(Q.finite(S.Half ** x), Q.extended_positive(x)) is True + assert ask(Q.finite(S.Half ** x), Q.extended_negative(x)) is None + assert ask(Q.finite(2**x), Q.extended_negative(x)) is True + assert ask(Q.finite(sqrt(x))) is None + assert ask(Q.finite(2**x), ~Q.finite(x)) is False + assert ask(Q.finite(x**2), ~Q.finite(x)) is False + + # https://github.com/sympy/sympy/issues/27707 + assert ask(Q.finite(x**y), Q.real(x) & Q.real(y)) is None + assert ask(Q.finite(x**y), Q.real(x) & Q.negative(y)) is None + assert ask(Q.finite(x**y), Q.zero(x) & Q.negative(y)) is False + assert ask(Q.finite(x**y), Q.real(x) & Q.positive(y)) is True + assert ask(Q.finite(x**y), Q.nonzero(x) & Q.real(y)) is True + assert ask(Q.finite(x**y), Q.nonzero(x) & Q.negative(y)) is True + assert ask(Q.finite(x**y), Q.zero(x) & Q.positive(y)) is True + + # sign function + assert ask(Q.finite(sign(x))) is True + assert ask(Q.finite(sign(x)), ~Q.finite(x)) is True + + # exponential functions + assert ask(Q.finite(log(x))) is None + assert ask(Q.finite(log(x)), Q.finite(x)) is None + assert ask(Q.finite(log(x)), ~Q.zero(x)) is True + assert ask(Q.finite(log(x)), Q.infinite(x)) is False + assert ask(Q.finite(log(x)), Q.zero(x)) is False + assert ask(Q.finite(exp(x))) is None + assert ask(Q.finite(exp(x)), Q.finite(x)) is True + assert ask(Q.finite(exp(2))) is True + + # trigonometric functions + assert ask(Q.finite(sin(x))) is True + assert ask(Q.finite(sin(x)), ~Q.finite(x)) is True + assert ask(Q.finite(cos(x))) is True + assert ask(Q.finite(cos(x)), ~Q.finite(x)) is True + assert ask(Q.finite(2*sin(x))) is True + assert ask(Q.finite(sin(x)**2)) is True + assert ask(Q.finite(cos(x)**2)) is True + assert ask(Q.finite(cos(x) + sin(x))) is True + + +def test_unbounded(): + assert ask(Q.infinite(I * oo)) is True + assert ask(Q.infinite(1 + I*oo)) is True + assert ask(Q.infinite(3 * (I * oo))) is True + assert ask(Q.infinite(-I * oo)) is True + assert ask(Q.infinite(1 + zoo)) is True + assert ask(Q.infinite(I * zoo)) is True + assert ask(Q.infinite(x / y), Q.infinite(x) & Q.finite(y) & ~Q.zero(y)) is True + assert ask(Q.infinite(I * oo - I * oo)) is None + assert ask(Q.infinite(x * I * oo)) is None + assert ask(Q.infinite(1 / x), Q.finite(x) & ~Q.zero(x)) is False + assert ask(Q.infinite(1 / (I * oo))) is False + + +def test_issue_27441(): + # https://github.com/sympy/sympy/issues/27441 + assert ask(Q.composite(y), Q.integer(y) & Q.positive(y) & ~Q.prime(y)) is None + + +def test_issue_27447(): + x,y,z = symbols('x y z') + a = x*y + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.finite(y)) is None + + a = x*y*z + assert ask(Q.finite(a), Q.finite(x) & Q.finite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y) + & Q.finite(z) ) is None + assert ask(Q.finite(a), Q.finite(x) & ~Q.finite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.finite(y) + & Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & Q.finite(y) + & ~Q.finite(z)) is None + assert ask(Q.finite(a), ~Q.finite(x) & ~Q.finite(y) + & Q.finite(z)) is None + + +@XFAIL +def test_issue_27662_xfail(): + assert ask(Q.finite(x*y), ~Q.finite(x) + & Q.zero(y)) is None + + +@XFAIL +def test_bounded_xfail(): + """We need to support relations in ask for this to work""" + assert ask(Q.finite(sin(x)**x)) is True + assert ask(Q.finite(cos(x)**x)) is True + + +def test_commutative(): + """By default objects are Q.commutative that is why it returns True + for both key=True and key=False""" + assert ask(Q.commutative(x)) is True + assert ask(Q.commutative(x), ~Q.commutative(x)) is False + assert ask(Q.commutative(x), Q.complex(x)) is True + assert ask(Q.commutative(x), Q.imaginary(x)) is True + assert ask(Q.commutative(x), Q.real(x)) is True + assert ask(Q.commutative(x), Q.positive(x)) is True + assert ask(Q.commutative(x), ~Q.commutative(y)) is True + + assert ask(Q.commutative(2*x)) is True + assert ask(Q.commutative(2*x), ~Q.commutative(x)) is False + + assert ask(Q.commutative(x + 1)) is True + assert ask(Q.commutative(x + 1), ~Q.commutative(x)) is False + + assert ask(Q.commutative(x**2)) is True + assert ask(Q.commutative(x**2), ~Q.commutative(x)) is False + + assert ask(Q.commutative(log(x))) is True + + +@_both_exp_pow +def test_complex(): + assert ask(Q.complex(x)) is None + assert ask(Q.complex(x), Q.complex(x)) is True + assert ask(Q.complex(x), Q.complex(y)) is None + assert ask(Q.complex(x), ~Q.complex(x)) is False + assert ask(Q.complex(x), Q.real(x)) is True + assert ask(Q.complex(x), ~Q.real(x)) is None + assert ask(Q.complex(x), Q.rational(x)) is True + assert ask(Q.complex(x), Q.irrational(x)) is True + assert ask(Q.complex(x), Q.positive(x)) is True + assert ask(Q.complex(x), Q.imaginary(x)) is True + assert ask(Q.complex(x), Q.algebraic(x)) is True + + # a+b + assert ask(Q.complex(x + 1), Q.complex(x)) is True + assert ask(Q.complex(x + 1), Q.real(x)) is True + assert ask(Q.complex(x + 1), Q.rational(x)) is True + assert ask(Q.complex(x + 1), Q.irrational(x)) is True + assert ask(Q.complex(x + 1), Q.imaginary(x)) is True + assert ask(Q.complex(x + 1), Q.integer(x)) is True + assert ask(Q.complex(x + 1), Q.even(x)) is True + assert ask(Q.complex(x + 1), Q.odd(x)) is True + assert ask(Q.complex(x + y), Q.complex(x) & Q.complex(y)) is True + assert ask(Q.complex(x + y), Q.real(x) & Q.imaginary(y)) is True + + # a*x +b + assert ask(Q.complex(2*x + 1), Q.complex(x)) is True + assert ask(Q.complex(2*x + 1), Q.real(x)) is True + assert ask(Q.complex(2*x + 1), Q.positive(x)) is True + assert ask(Q.complex(2*x + 1), Q.rational(x)) is True + assert ask(Q.complex(2*x + 1), Q.irrational(x)) is True + assert ask(Q.complex(2*x + 1), Q.imaginary(x)) is True + assert ask(Q.complex(2*x + 1), Q.integer(x)) is True + assert ask(Q.complex(2*x + 1), Q.even(x)) is True + assert ask(Q.complex(2*x + 1), Q.odd(x)) is True + + # x**2 + assert ask(Q.complex(x**2), Q.complex(x)) is True + assert ask(Q.complex(x**2), Q.real(x)) is True + assert ask(Q.complex(x**2), Q.positive(x)) is True + assert ask(Q.complex(x**2), Q.rational(x)) is True + assert ask(Q.complex(x**2), Q.irrational(x)) is True + assert ask(Q.complex(x**2), Q.imaginary(x)) is True + assert ask(Q.complex(x**2), Q.integer(x)) is True + assert ask(Q.complex(x**2), Q.even(x)) is True + assert ask(Q.complex(x**2), Q.odd(x)) is True + + # 2**x + assert ask(Q.complex(2**x), Q.complex(x)) is True + assert ask(Q.complex(2**x), Q.real(x)) is True + assert ask(Q.complex(2**x), Q.positive(x)) is True + assert ask(Q.complex(2**x), Q.rational(x)) is True + assert ask(Q.complex(2**x), Q.irrational(x)) is True + assert ask(Q.complex(2**x), Q.imaginary(x)) is True + assert ask(Q.complex(2**x), Q.integer(x)) is True + assert ask(Q.complex(2**x), Q.even(x)) is True + assert ask(Q.complex(2**x), Q.odd(x)) is True + assert ask(Q.complex(x**y), Q.complex(x) & Q.complex(y)) is True + + # trigonometric expressions + assert ask(Q.complex(sin(x))) is True + assert ask(Q.complex(sin(2*x + 1))) is True + assert ask(Q.complex(cos(x))) is True + assert ask(Q.complex(cos(2*x + 1))) is True + + # exponential + assert ask(Q.complex(exp(x))) is True + assert ask(Q.complex(exp(x))) is True + + # Q.complexes + assert ask(Q.complex(Abs(x))) is True + assert ask(Q.complex(re(x))) is True + assert ask(Q.complex(im(x))) is True + + +def test_even_query(): + assert ask(Q.even(x)) is None + assert ask(Q.even(x), Q.integer(x)) is None + assert ask(Q.even(x), ~Q.integer(x)) is False + assert ask(Q.even(x), Q.rational(x)) is None + assert ask(Q.even(x), Q.positive(x)) is None + + assert ask(Q.even(2*x)) is None + assert ask(Q.even(2*x), Q.integer(x)) is True + assert ask(Q.even(2*x), Q.even(x)) is True + assert ask(Q.even(2*x), Q.irrational(x)) is False + assert ask(Q.even(2*x), Q.odd(x)) is True + assert ask(Q.even(2*x), ~Q.integer(x)) is None + assert ask(Q.even(3*x), Q.integer(x)) is None + assert ask(Q.even(3*x), Q.even(x)) is True + assert ask(Q.even(3*x), Q.odd(x)) is False + + assert ask(Q.even(x + 1), Q.odd(x)) is True + assert ask(Q.even(x + 1), Q.even(x)) is False + assert ask(Q.even(x + 2), Q.odd(x)) is False + assert ask(Q.even(x + 2), Q.even(x)) is True + assert ask(Q.even(7 - x), Q.odd(x)) is True + assert ask(Q.even(7 + x), Q.odd(x)) is True + assert ask(Q.even(x + y), Q.odd(x) & Q.odd(y)) is True + assert ask(Q.even(x + y), Q.odd(x) & Q.even(y)) is False + assert ask(Q.even(x + y), Q.even(x) & Q.even(y)) is True + + assert ask(Q.even(2*x + 1), Q.integer(x)) is False + assert ask(Q.even(2*x*y), Q.rational(x) & Q.rational(x)) is None + assert ask(Q.even(2*x*y), Q.irrational(x) & Q.irrational(x)) is None + + assert ask(Q.even(x + y + z), Q.odd(x) & Q.odd(y) & Q.even(z)) is True + assert ask(Q.even(x + y + z + t), + Q.odd(x) & Q.odd(y) & Q.even(z) & Q.integer(t)) is None + + assert ask(Q.even(Abs(x)), Q.even(x)) is True + assert ask(Q.even(Abs(x)), ~Q.even(x)) is None + assert ask(Q.even(re(x)), Q.even(x)) is True + assert ask(Q.even(re(x)), ~Q.even(x)) is None + assert ask(Q.even(im(x)), Q.even(x)) is True + assert ask(Q.even(im(x)), Q.real(x)) is True + + assert ask(Q.even((-1)**n), Q.integer(n)) is False + + assert ask(Q.even(k**2), Q.even(k)) is True + assert ask(Q.even(n**2), Q.odd(n)) is False + assert ask(Q.even(2**k), Q.even(k)) is None + assert ask(Q.even(x**2)) is None + + assert ask(Q.even(k**m), Q.even(k) & Q.integer(m) & ~Q.negative(m)) is None + assert ask(Q.even(n**m), Q.odd(n) & Q.integer(m) & ~Q.negative(m)) is False + + assert ask(Q.even(k**p), Q.even(k) & Q.integer(p) & Q.positive(p)) is True + assert ask(Q.even(n**p), Q.odd(n) & Q.integer(p) & Q.positive(p)) is False + + assert ask(Q.even(m**k), Q.even(k) & Q.integer(m) & ~Q.negative(m)) is None + assert ask(Q.even(p**k), Q.even(k) & Q.integer(p) & Q.positive(p)) is None + + assert ask(Q.even(m**n), Q.odd(n) & Q.integer(m) & ~Q.negative(m)) is None + assert ask(Q.even(p**n), Q.odd(n) & Q.integer(p) & Q.positive(p)) is None + + assert ask(Q.even(k**x), Q.even(k)) is None + assert ask(Q.even(n**x), Q.odd(n)) is None + + assert ask(Q.even(x*y), Q.integer(x) & Q.integer(y)) is None + assert ask(Q.even(x*x), Q.integer(x)) is None + assert ask(Q.even(x*(x + y)), Q.integer(x) & Q.odd(y)) is True + assert ask(Q.even(x*(x + y)), Q.integer(x) & Q.even(y)) is None + + +@XFAIL +def test_evenness_in_ternary_integer_product_with_odd(): + # Tests that oddness inference is independent of term ordering. + # Term ordering at the point of testing depends on SymPy's symbol order, so + # we try to force a different order by modifying symbol names. + assert ask(Q.even(x*y*(y + z)), Q.integer(x) & Q.integer(y) & Q.odd(z)) is True + assert ask(Q.even(y*x*(x + z)), Q.integer(x) & Q.integer(y) & Q.odd(z)) is True + + +def test_evenness_in_ternary_integer_product_with_even(): + assert ask(Q.even(x*y*(y + z)), Q.integer(x) & Q.integer(y) & Q.even(z)) is None + + +def test_extended_real(): + assert ask(Q.extended_real(x), Q.positive_infinite(x)) is True + assert ask(Q.extended_real(x), Q.positive(x)) is True + assert ask(Q.extended_real(x), Q.zero(x)) is True + assert ask(Q.extended_real(x), Q.negative(x)) is True + assert ask(Q.extended_real(x), Q.negative_infinite(x)) is True + + assert ask(Q.extended_real(-x), Q.positive(x)) is True + assert ask(Q.extended_real(-x), Q.negative(x)) is True + + assert ask(Q.extended_real(x + S.Infinity), Q.real(x)) is True + + assert ask(Q.extended_real(x), Q.infinite(x)) is None + + +@_both_exp_pow +def test_rational(): + assert ask(Q.rational(x), Q.integer(x)) is True + assert ask(Q.rational(x), Q.irrational(x)) is False + assert ask(Q.rational(x), Q.real(x)) is None + assert ask(Q.rational(x), Q.positive(x)) is None + assert ask(Q.rational(x), Q.negative(x)) is None + assert ask(Q.rational(x), Q.nonzero(x)) is None + assert ask(Q.rational(x), ~Q.algebraic(x)) is False + + assert ask(Q.rational(2*x), Q.rational(x)) is True + assert ask(Q.rational(2*x), Q.integer(x)) is True + assert ask(Q.rational(2*x), Q.even(x)) is True + assert ask(Q.rational(2*x), Q.odd(x)) is True + assert ask(Q.rational(2*x), Q.irrational(x)) is False + + assert ask(Q.rational(x/2), Q.rational(x)) is True + assert ask(Q.rational(x/2), Q.integer(x)) is True + assert ask(Q.rational(x/2), Q.even(x)) is True + assert ask(Q.rational(x/2), Q.odd(x)) is True + assert ask(Q.rational(x/2), Q.irrational(x)) is False + + assert ask(Q.rational(1/x), Q.rational(x) & Q.nonzero(x)) is True + assert ask(Q.rational(1/x), Q.integer(x) & Q.nonzero(x)) is True + assert ask(Q.rational(1/x), Q.even(x) & Q.nonzero(x)) is True + assert ask(Q.rational(1/x), Q.odd(x)) is True + assert ask(Q.rational(1/x), Q.irrational(x)) is False + + assert ask(Q.rational(2/x), Q.rational(x) & Q.nonzero(x)) is True + assert ask(Q.rational(2/x), Q.integer(x) & Q.nonzero(x)) is True + assert ask(Q.rational(2/x), Q.even(x) & Q.nonzero(x)) is True + assert ask(Q.rational(2/x), Q.odd(x)) is True + assert ask(Q.rational(2/x), Q.irrational(x)) is False + + assert ask(Q.rational(x), ~Q.algebraic(x)) is False + + # with multiple symbols + assert ask(Q.rational(x*y), Q.irrational(x) & Q.irrational(y)) is None + assert ask(Q.rational(y/x), Q.rational(x) & Q.rational(y) & Q.nonzero(x)) is True + assert ask(Q.rational(y/x), Q.integer(x) & Q.rational(y) & Q.nonzero(x)) is True + assert ask(Q.rational(y/x), Q.even(x) & Q.rational(y) & Q.nonzero(x)) is True + assert ask(Q.rational(y/x), Q.odd(x) & Q.rational(y)) is True + assert ask(Q.rational(y/x), Q.irrational(x) & Q.rational(y) & Q.nonzero(y)) is False + + for f in [exp, sin, tan, asin, atan, cos]: + assert ask(Q.rational(f(7))) is False + assert ask(Q.rational(f(7, evaluate=False))) is False + assert ask(Q.rational(f(0, evaluate=False))) is True + assert ask(Q.rational(f(x)), Q.rational(x)) is None + assert ask(Q.rational(f(x)), Q.rational(x) & Q.nonzero(x)) is False + + for g in [log, acos]: + assert ask(Q.rational(g(7))) is False + assert ask(Q.rational(g(7, evaluate=False))) is False + assert ask(Q.rational(g(1, evaluate=False))) is True + assert ask(Q.rational(g(x)), Q.rational(x)) is None + assert ask(Q.rational(g(x)), Q.rational(x) & Q.nonzero(x - 1)) is False + + for h in [cot, acot]: + assert ask(Q.rational(h(7))) is False + assert ask(Q.rational(h(7, evaluate=False))) is False + assert ask(Q.rational(h(x)), Q.rational(x)) is False + + # https://github.com/sympy/sympy/issues/27442 + assert ask(Q.rational(x**y),Q.irrational(x) & Q.rational(y)) is None + assert ask(Q.rational(x**y),Q.integer(x) & Q.prime(x) & Q.rational(y)) is None + assert ask(Q.rational(x**y),Q.integer(x) & Q.integer(y)) is None + assert ask(Q.rational(x**y),Q.integer(x) & Q.eq(x,0) & Q.integer(y)) is None + assert ask(Q.rational(x**y),Q.eq(x,1) & Q.rational(y)) is None + assert ask(Q.rational(x**y),Q.eq(x,-1) & Q.rational(y)) is None + assert ask(Q.rational(x**y), Q.prime(x) & Q.rational(y)) is None + assert ask(Q.rational(x**y), ~Q.rational(x) & Q.integer(y) ) is None + assert ask(Q.rational(Pow(-1, x, evaluate=False), Q.rational(x))) is None + assert ask(Q.rational(x**y), Q.integer(y) & ~Q. algebraic(x)) is None + assert ask(Q.rational(x**y), Q.integer(y) & ~Q. algebraic(x) & ~Q.zero(x)) is None + assert ask(Q.rational(x**y), Q.integer(y) & ~Q.algebraic(x) & Q.complex(x) & ~Q.real(x)) is None + assert ask(Q.rational(x**y), Q.integer(y) & ~Q.algebraic(x) & Q.complex(x)) is None + + +def test_hermitian(): + assert ask(Q.hermitian(x)) is None + assert ask(Q.hermitian(x), Q.antihermitian(x)) is None + assert ask(Q.hermitian(x), Q.imaginary(x)) is False + assert ask(Q.hermitian(x), Q.prime(x)) is True + assert ask(Q.hermitian(x), Q.real(x)) is True + assert ask(Q.hermitian(x), Q.zero(x)) is True + + assert ask(Q.hermitian(x + 1), Q.antihermitian(x)) is None + assert ask(Q.hermitian(x + 1), Q.complex(x)) is None + assert ask(Q.hermitian(x + 1), Q.hermitian(x)) is True + assert ask(Q.hermitian(x + 1), Q.imaginary(x)) is False + assert ask(Q.hermitian(x + 1), Q.real(x)) is True + assert ask(Q.hermitian(x + I), Q.antihermitian(x)) is None + assert ask(Q.hermitian(x + I), Q.complex(x)) is None + assert ask(Q.hermitian(x + I), Q.hermitian(x)) is False + assert ask(Q.hermitian(x + I), Q.imaginary(x)) is None + assert ask(Q.hermitian(x + I), Q.real(x)) is False + assert ask( + Q.hermitian(x + y), Q.antihermitian(x) & Q.antihermitian(y)) is None + assert ask(Q.hermitian(x + y), Q.antihermitian(x) & Q.complex(y)) is None + assert ask( + Q.hermitian(x + y), Q.antihermitian(x) & Q.hermitian(y)) is None + assert ask(Q.hermitian(x + y), Q.antihermitian(x) & Q.imaginary(y)) is None + assert ask(Q.hermitian(x + y), Q.antihermitian(x) & Q.real(y)) is None + assert ask(Q.hermitian(x + y), Q.hermitian(x) & Q.complex(y)) is None + assert ask(Q.hermitian(x + y), Q.hermitian(x) & Q.hermitian(y)) is True + assert ask(Q.hermitian(x + y), Q.hermitian(x) & Q.imaginary(y)) is False + assert ask(Q.hermitian(x + y), Q.hermitian(x) & Q.real(y)) is True + assert ask(Q.hermitian(x + y), Q.imaginary(x) & Q.complex(y)) is None + assert ask(Q.hermitian(x + y), Q.imaginary(x) & Q.imaginary(y)) is None + assert ask(Q.hermitian(x + y), Q.imaginary(x) & Q.real(y)) is False + assert ask(Q.hermitian(x + y), Q.real(x) & Q.complex(y)) is None + assert ask(Q.hermitian(x + y), Q.real(x) & Q.real(y)) is True + + assert ask(Q.hermitian(I*x), Q.antihermitian(x)) is True + assert ask(Q.hermitian(I*x), Q.complex(x)) is None + assert ask(Q.hermitian(I*x), Q.hermitian(x)) is False + assert ask(Q.hermitian(I*x), Q.imaginary(x)) is True + assert ask(Q.hermitian(I*x), Q.real(x)) is False + assert ask(Q.hermitian(x*y), Q.hermitian(x) & Q.real(y)) is True + + assert ask( + Q.hermitian(x + y + z), Q.real(x) & Q.real(y) & Q.real(z)) is True + assert ask(Q.hermitian(x + y + z), + Q.real(x) & Q.real(y) & Q.imaginary(z)) is False + assert ask(Q.hermitian(x + y + z), + Q.real(x) & Q.imaginary(y) & Q.imaginary(z)) is None + assert ask(Q.hermitian(x + y + z), + Q.imaginary(x) & Q.imaginary(y) & Q.imaginary(z)) is None + + assert ask(Q.antihermitian(x)) is None + assert ask(Q.antihermitian(x), Q.real(x)) is False + assert ask(Q.antihermitian(x), Q.prime(x)) is False + + assert ask(Q.antihermitian(x + 1), Q.antihermitian(x)) is False + assert ask(Q.antihermitian(x + 1), Q.complex(x)) is None + assert ask(Q.antihermitian(x + 1), Q.hermitian(x)) is None + assert ask(Q.antihermitian(x + 1), Q.imaginary(x)) is False + assert ask(Q.antihermitian(x + 1), Q.real(x)) is None + assert ask(Q.antihermitian(x + I), Q.antihermitian(x)) is True + assert ask(Q.antihermitian(x + I), Q.complex(x)) is None + assert ask(Q.antihermitian(x + I), Q.hermitian(x)) is None + assert ask(Q.antihermitian(x + I), Q.imaginary(x)) is True + assert ask(Q.antihermitian(x + I), Q.real(x)) is False + assert ask(Q.antihermitian(x), Q.zero(x)) is True + + assert ask( + Q.antihermitian(x + y), Q.antihermitian(x) & Q.antihermitian(y) + ) is True + assert ask( + Q.antihermitian(x + y), Q.antihermitian(x) & Q.complex(y)) is None + assert ask( + Q.antihermitian(x + y), Q.antihermitian(x) & Q.hermitian(y)) is None + assert ask( + Q.antihermitian(x + y), Q.antihermitian(x) & Q.imaginary(y)) is True + assert ask(Q.antihermitian(x + y), Q.antihermitian(x) & Q.real(y) + ) is False + assert ask(Q.antihermitian(x + y), Q.hermitian(x) & Q.complex(y)) is None + assert ask(Q.antihermitian(x + y), Q.hermitian(x) & Q.hermitian(y) + ) is None + assert ask( + Q.antihermitian(x + y), Q.hermitian(x) & Q.imaginary(y)) is None + assert ask(Q.antihermitian(x + y), Q.hermitian(x) & Q.real(y)) is None + assert ask(Q.antihermitian(x + y), Q.imaginary(x) & Q.complex(y)) is None + assert ask(Q.antihermitian(x + y), Q.imaginary(x) & Q.imaginary(y)) is True + assert ask(Q.antihermitian(x + y), Q.imaginary(x) & Q.real(y)) is False + assert ask(Q.antihermitian(x + y), Q.real(x) & Q.complex(y)) is None + assert ask(Q.antihermitian(x + y), Q.real(x) & Q.real(y)) is None + + assert ask(Q.antihermitian(I*x), Q.real(x)) is True + assert ask(Q.antihermitian(I*x), Q.antihermitian(x)) is False + assert ask(Q.antihermitian(I*x), Q.complex(x)) is None + assert ask(Q.antihermitian(x*y), Q.antihermitian(x) & Q.real(y)) is True + + assert ask(Q.antihermitian(x + y + z), + Q.real(x) & Q.real(y) & Q.real(z)) is None + assert ask(Q.antihermitian(x + y + z), + Q.real(x) & Q.real(y) & Q.imaginary(z)) is None + assert ask(Q.antihermitian(x + y + z), + Q.real(x) & Q.imaginary(y) & Q.imaginary(z)) is False + assert ask(Q.antihermitian(x + y + z), + Q.imaginary(x) & Q.imaginary(y) & Q.imaginary(z)) is True + + +@_both_exp_pow +def test_imaginary(): + assert ask(Q.imaginary(x)) is None + assert ask(Q.imaginary(x), Q.real(x)) is False + assert ask(Q.imaginary(x), Q.prime(x)) is False + + assert ask(Q.imaginary(x + 1), Q.real(x)) is False + assert ask(Q.imaginary(x + 1), Q.imaginary(x)) is False + assert ask(Q.imaginary(x + I), Q.real(x)) is False + assert ask(Q.imaginary(x + I), Q.imaginary(x)) is True + assert ask(Q.imaginary(x + y), Q.imaginary(x) & Q.imaginary(y)) is True + assert ask(Q.imaginary(x + y), Q.real(x) & Q.real(y)) is False + assert ask(Q.imaginary(x + y), Q.imaginary(x) & Q.real(y)) is False + assert ask(Q.imaginary(x + y), Q.complex(x) & Q.real(y)) is None + assert ask( + Q.imaginary(x + y + z), Q.real(x) & Q.real(y) & Q.real(z)) is False + assert ask(Q.imaginary(x + y + z), + Q.real(x) & Q.real(y) & Q.imaginary(z)) is None + assert ask(Q.imaginary(x + y + z), + Q.real(x) & Q.imaginary(y) & Q.imaginary(z)) is False + + assert ask(Q.imaginary(I*x), Q.real(x)) is True + assert ask(Q.imaginary(I*x), Q.imaginary(x)) is False + assert ask(Q.imaginary(I*x), Q.complex(x)) is None + assert ask(Q.imaginary(x*y), Q.imaginary(x) & Q.real(y)) is True + assert ask(Q.imaginary(x*y), Q.real(x) & Q.real(y)) is False + + assert ask(Q.imaginary(I**x), Q.negative(x)) is None + assert ask(Q.imaginary(I**x), Q.positive(x)) is None + assert ask(Q.imaginary(I**x), Q.even(x)) is False + assert ask(Q.imaginary(I**x), Q.odd(x)) is True + assert ask(Q.imaginary(I**x), Q.imaginary(x)) is False + assert ask(Q.imaginary((2*I)**x), Q.imaginary(x)) is False + assert ask(Q.imaginary(x**0), Q.imaginary(x)) is False + assert ask(Q.imaginary(x**y), Q.imaginary(x) & Q.imaginary(y)) is None + assert ask(Q.imaginary(x**y), Q.imaginary(x) & Q.real(y)) is None + assert ask(Q.imaginary(x**y), Q.real(x) & Q.imaginary(y)) is None + assert ask(Q.imaginary(x**y), Q.real(x) & Q.real(y)) is None + assert ask(Q.imaginary(x**y), Q.imaginary(x) & Q.integer(y)) is None + assert ask(Q.imaginary(x**y), Q.imaginary(y) & Q.integer(x)) is None + assert ask(Q.imaginary(x**y), Q.imaginary(x) & Q.odd(y)) is True + assert ask(Q.imaginary(x**y), Q.imaginary(x) & Q.rational(y)) is None + assert ask(Q.imaginary(x**y), Q.imaginary(x) & Q.even(y)) is False + + assert ask(Q.imaginary(x**y), Q.real(x) & Q.integer(y)) is False + assert ask(Q.imaginary(x**y), Q.positive(x) & Q.real(y)) is False + assert ask(Q.imaginary(x**y), Q.negative(x) & Q.real(y)) is None + assert ask(Q.imaginary(x**y), Q.negative(x) & Q.real(y) & ~Q.rational(y)) is False + assert ask(Q.imaginary(x**y), Q.integer(x) & Q.imaginary(y)) is None + assert ask(Q.imaginary(x**y), Q.negative(x) & Q.rational(y) & Q.integer(2*y)) is True + assert ask(Q.imaginary(x**y), Q.negative(x) & Q.rational(y) & ~Q.integer(2*y)) is False + assert ask(Q.imaginary(x**y), Q.negative(x) & Q.rational(y)) is None + assert ask(Q.imaginary(x**y), Q.real(x) & Q.rational(y) & ~Q.integer(2*y)) is False + assert ask(Q.imaginary(x**y), Q.real(x) & Q.rational(y) & Q.integer(2*y)) is None + + # logarithm + assert ask(Q.imaginary(log(I))) is True + assert ask(Q.imaginary(log(2*I))) is False + assert ask(Q.imaginary(log(I + 1))) is False + assert ask(Q.imaginary(log(x)), Q.complex(x)) is None + assert ask(Q.imaginary(log(x)), Q.imaginary(x)) is None + assert ask(Q.imaginary(log(x)), Q.positive(x)) is False + assert ask(Q.imaginary(log(exp(x))), Q.complex(x)) is None + assert ask(Q.imaginary(log(exp(x))), Q.imaginary(x)) is None # zoo/I/a+I*b + assert ask(Q.imaginary(log(exp(I)))) is True + + # exponential + assert ask(Q.imaginary(exp(x)**x), Q.imaginary(x)) is False + eq = Pow(exp(pi*I*x, evaluate=False), x, evaluate=False) + assert ask(Q.imaginary(eq), Q.even(x)) is False + eq = Pow(exp(pi*I*x/2, evaluate=False), x, evaluate=False) + assert ask(Q.imaginary(eq), Q.odd(x)) is True + assert ask(Q.imaginary(exp(3*I*pi*x)**x), Q.integer(x)) is False + assert ask(Q.imaginary(exp(2*pi*I, evaluate=False))) is False + assert ask(Q.imaginary(exp(pi*I/2, evaluate=False))) is True + + # issue 7886 + assert ask(Q.imaginary(Pow(x, Rational(1, 4))), Q.real(x) & Q.negative(x)) is False + + +def test_integer(): + assert ask(Q.integer(x)) is None + assert ask(Q.integer(x), Q.integer(x)) is True + assert ask(Q.integer(x), ~Q.integer(x)) is False + assert ask(Q.integer(x), ~Q.real(x)) is False + assert ask(Q.integer(x), ~Q.positive(x)) is None + assert ask(Q.integer(x), Q.even(x) | Q.odd(x)) is True + + assert ask(Q.integer(2*x), Q.integer(x)) is True + assert ask(Q.integer(2*x), Q.even(x)) is True + assert ask(Q.integer(2*x), Q.prime(x)) is True + assert ask(Q.integer(2*x), Q.rational(x)) is None + assert ask(Q.integer(2*x), Q.real(x)) is None + assert ask(Q.integer(sqrt(2)*x), Q.integer(x)) is False + assert ask(Q.integer(sqrt(2)*x), Q.irrational(x)) is None + + assert ask(Q.integer(x/2), Q.odd(x)) is False + assert ask(Q.integer(x/2), Q.even(x)) is True + assert ask(Q.integer(x/3), Q.odd(x)) is None + assert ask(Q.integer(x/3), Q.even(x)) is None + + # https://github.com/sympy/sympy/issues/7286 + assert ask(Q.integer(Abs(x)),Q.integer(x)) is True + assert ask(Q.integer(Abs(-x)),Q.integer(x)) is True + assert ask(Q.integer(Abs(x)), ~Q.integer(x)) is None + assert ask(Q.integer(Abs(x)),Q.complex(x)) is None + assert ask(Q.integer(Abs(x+I*y)),Q.real(x) & Q.real(y)) is None + + # https://github.com/sympy/sympy/issues/27739 + assert ask(Q.integer(x/y), Q.integer(x) & Q.integer(y)) is None + assert ask(Q.integer(1/x), Q.integer(x)) is None + assert ask(Q.integer(x**y), Q.integer(x) & Q.integer(y)) is None + assert ask(Q.integer(sqrt(5))) is False + assert ask(Q.integer(x**y), Q.nonzero(x) & Q.zero(y)) is True + assert ask(Q.integer(x**y), Q.integer(x) & Q.integer(y) & Q.positive(y)) is True + assert ask(Q.integer(-1**x), Q.integer(x)) is True + assert ask(Q.integer(x**y), Q.integer(x) & Q.integer(y) & Q.positive(y)) is True + assert ask(Q.integer(x**y), Q.zero(x) & Q.integer(y) & Q.positive(y)) is True + assert ask(Q.integer(pi**x), Q.zero(x)) is True + assert ask(Q.integer(x**y), Q.imaginary(x) & Q.zero(y)) is True + + +def test_negative(): + assert ask(Q.negative(x), Q.negative(x)) is True + assert ask(Q.negative(x), Q.positive(x)) is False + assert ask(Q.negative(x), ~Q.real(x)) is False + assert ask(Q.negative(x), Q.prime(x)) is False + assert ask(Q.negative(x), ~Q.prime(x)) is None + + assert ask(Q.negative(-x), Q.positive(x)) is True + assert ask(Q.negative(-x), ~Q.positive(x)) is None + assert ask(Q.negative(-x), Q.negative(x)) is False + assert ask(Q.negative(-x), Q.positive(x)) is True + + assert ask(Q.negative(x - 1), Q.negative(x)) is True + assert ask(Q.negative(x + y)) is None + assert ask(Q.negative(x + y), Q.negative(x)) is None + assert ask(Q.negative(x + y), Q.negative(x) & Q.negative(y)) is True + assert ask(Q.negative(x + y), Q.negative(x) & Q.nonpositive(y)) is True + assert ask(Q.negative(2 + I)) is False + # although this could be False, it is representative of expressions + # that don't evaluate to a zero with precision + assert ask(Q.negative(cos(I)**2 + sin(I)**2 - 1)) is None + assert ask(Q.negative(-I + I*(cos(2)**2 + sin(2)**2))) is None + + assert ask(Q.negative(x**2)) is None + assert ask(Q.negative(x**2), Q.real(x)) is False + assert ask(Q.negative(x**1.4), Q.real(x)) is None + + assert ask(Q.negative(x**I), Q.positive(x)) is None + + assert ask(Q.negative(x*y)) is None + assert ask(Q.negative(x*y), Q.positive(x) & Q.positive(y)) is False + assert ask(Q.negative(x*y), Q.positive(x) & Q.negative(y)) is True + assert ask(Q.negative(x*y), Q.complex(x) & Q.complex(y)) is None + + assert ask(Q.negative(x**y)) is None + assert ask(Q.negative(x**y), Q.negative(x) & Q.even(y)) is False + assert ask(Q.negative(x**y), Q.negative(x) & Q.odd(y)) is True + assert ask(Q.negative(x**y), Q.positive(x) & Q.integer(y)) is False + + assert ask(Q.negative(Abs(x))) is False + + +def test_nonzero(): + assert ask(Q.nonzero(x)) is None + assert ask(Q.nonzero(x), Q.real(x)) is None + assert ask(Q.nonzero(x), Q.positive(x)) is True + assert ask(Q.nonzero(x), Q.negative(x)) is True + assert ask(Q.nonzero(x), Q.negative(x) | Q.positive(x)) is True + + assert ask(Q.nonzero(x + y)) is None + assert ask(Q.nonzero(x + y), Q.positive(x) & Q.positive(y)) is True + assert ask(Q.nonzero(x + y), Q.positive(x) & Q.negative(y)) is None + assert ask(Q.nonzero(x + y), Q.negative(x) & Q.negative(y)) is True + + assert ask(Q.nonzero(2*x)) is None + assert ask(Q.nonzero(2*x), Q.positive(x)) is True + assert ask(Q.nonzero(2*x), Q.negative(x)) is True + assert ask(Q.nonzero(x*y), Q.nonzero(x)) is None + assert ask(Q.nonzero(x*y), Q.nonzero(x) & Q.nonzero(y)) is True + + assert ask(Q.nonzero(x**y), Q.nonzero(x)) is True + + assert ask(Q.nonzero(Abs(x))) is None + assert ask(Q.nonzero(Abs(x)), Q.nonzero(x)) is True + + assert ask(Q.nonzero(log(exp(2*I)))) is False + # although this could be False, it is representative of expressions + # that don't evaluate to a zero with precision + assert ask(Q.nonzero(cos(1)**2 + sin(1)**2 - 1)) is None + + +def test_zero(): + assert ask(Q.zero(x)) is None + assert ask(Q.zero(x), Q.real(x)) is None + assert ask(Q.zero(x), Q.positive(x)) is False + assert ask(Q.zero(x), Q.negative(x)) is False + assert ask(Q.zero(x), Q.negative(x) | Q.positive(x)) is False + + assert ask(Q.zero(x), Q.nonnegative(x) & Q.nonpositive(x)) is True + + assert ask(Q.zero(x + y)) is None + assert ask(Q.zero(x + y), Q.positive(x) & Q.positive(y)) is False + assert ask(Q.zero(x + y), Q.positive(x) & Q.negative(y)) is None + assert ask(Q.zero(x + y), Q.negative(x) & Q.negative(y)) is False + + assert ask(Q.zero(2*x)) is None + assert ask(Q.zero(2*x), Q.positive(x)) is False + assert ask(Q.zero(2*x), Q.negative(x)) is False + assert ask(Q.zero(x*y), Q.nonzero(x)) is None + + assert ask(Q.zero(Abs(x))) is None + assert ask(Q.zero(Abs(x)), Q.zero(x)) is True + + assert ask(Q.integer(x), Q.zero(x)) is True + assert ask(Q.even(x), Q.zero(x)) is True + assert ask(Q.odd(x), Q.zero(x)) is False + assert ask(Q.zero(x), Q.even(x)) is None + assert ask(Q.zero(x), Q.odd(x)) is False + assert ask(Q.zero(x) | Q.zero(y), Q.zero(x*y)) is True + + +def test_odd_query(): + assert ask(Q.odd(x)) is None + assert ask(Q.odd(x), Q.odd(x)) is True + assert ask(Q.odd(x), Q.integer(x)) is None + assert ask(Q.odd(x), ~Q.integer(x)) is False + assert ask(Q.odd(x), Q.rational(x)) is None + assert ask(Q.odd(x), Q.positive(x)) is None + + assert ask(Q.odd(-x), Q.odd(x)) is True + + assert ask(Q.odd(2*x)) is None + assert ask(Q.odd(2*x), Q.integer(x)) is False + assert ask(Q.odd(2*x), Q.odd(x)) is False + assert ask(Q.odd(2*x), Q.irrational(x)) is False + assert ask(Q.odd(2*x), ~Q.integer(x)) is None + assert ask(Q.odd(3*x), Q.integer(x)) is None + + assert ask(Q.odd(x/3), Q.odd(x)) is None + assert ask(Q.odd(x/3), Q.even(x)) is None + + assert ask(Q.odd(x + 1), Q.even(x)) is True + assert ask(Q.odd(x + 2), Q.even(x)) is False + assert ask(Q.odd(x + 2), Q.odd(x)) is True + assert ask(Q.odd(3 - x), Q.odd(x)) is False + assert ask(Q.odd(3 - x), Q.even(x)) is True + assert ask(Q.odd(3 + x), Q.odd(x)) is False + assert ask(Q.odd(3 + x), Q.even(x)) is True + assert ask(Q.odd(x + y), Q.odd(x) & Q.odd(y)) is False + assert ask(Q.odd(x + y), Q.odd(x) & Q.even(y)) is True + assert ask(Q.odd(x - y), Q.even(x) & Q.odd(y)) is True + assert ask(Q.odd(x - y), Q.odd(x) & Q.odd(y)) is False + + assert ask(Q.odd(x + y + z), Q.odd(x) & Q.odd(y) & Q.even(z)) is False + assert ask(Q.odd(x + y + z + t), + Q.odd(x) & Q.odd(y) & Q.even(z) & Q.integer(t)) is None + + assert ask(Q.odd(2*x + 1), Q.integer(x)) is True + assert ask(Q.odd(2*x + y), Q.integer(x) & Q.odd(y)) is True + assert ask(Q.odd(2*x + y), Q.integer(x) & Q.even(y)) is False + assert ask(Q.odd(2*x + y), Q.integer(x) & Q.integer(y)) is None + assert ask(Q.odd(x*y), Q.odd(x) & Q.even(y)) is False + assert ask(Q.odd(x*y), Q.odd(x) & Q.odd(y)) is True + assert ask(Q.odd(2*x*y), Q.rational(x) & Q.rational(x)) is None + assert ask(Q.odd(2*x*y), Q.irrational(x) & Q.irrational(x)) is None + + assert ask(Q.odd(Abs(x)), Q.odd(x)) is True + + assert ask(Q.odd((-1)**n), Q.integer(n)) is True + + assert ask(Q.odd(k**2), Q.even(k)) is False + assert ask(Q.odd(n**2), Q.odd(n)) is True + assert ask(Q.odd(3**k), Q.even(k)) is None + + assert ask(Q.odd(k**m), Q.even(k) & Q.integer(m) & ~Q.negative(m)) is None + assert ask(Q.odd(n**m), Q.odd(n) & Q.integer(m) & ~Q.negative(m)) is True + + assert ask(Q.odd(k**p), Q.even(k) & Q.integer(p) & Q.positive(p)) is False + assert ask(Q.odd(n**p), Q.odd(n) & Q.integer(p) & Q.positive(p)) is True + + assert ask(Q.odd(m**k), Q.even(k) & Q.integer(m) & ~Q.negative(m)) is None + assert ask(Q.odd(p**k), Q.even(k) & Q.integer(p) & Q.positive(p)) is None + + assert ask(Q.odd(m**n), Q.odd(n) & Q.integer(m) & ~Q.negative(m)) is None + assert ask(Q.odd(p**n), Q.odd(n) & Q.integer(p) & Q.positive(p)) is None + + assert ask(Q.odd(k**x), Q.even(k)) is None + assert ask(Q.odd(n**x), Q.odd(n)) is None + + assert ask(Q.odd(x*y), Q.integer(x) & Q.integer(y)) is None + assert ask(Q.odd(x*x), Q.integer(x)) is None + assert ask(Q.odd(x*(x + y)), Q.integer(x) & Q.odd(y)) is False + assert ask(Q.odd(x*(x + y)), Q.integer(x) & Q.even(y)) is None + + +@XFAIL +def test_oddness_in_ternary_integer_product_with_odd(): + # Tests that oddness inference is independent of term ordering. + # Term ordering at the point of testing depends on SymPy's symbol order, so + # we try to force a different order by modifying symbol names. + assert ask(Q.odd(x*y*(y + z)), Q.integer(x) & Q.integer(y) & Q.odd(z)) is False + assert ask(Q.odd(y*x*(x + z)), Q.integer(x) & Q.integer(y) & Q.odd(z)) is False + + +def test_oddness_in_ternary_integer_product_with_even(): + assert ask(Q.odd(x*y*(y + z)), Q.integer(x) & Q.integer(y) & Q.even(z)) is None + + +def test_prime(): + assert ask(Q.prime(x), Q.prime(x)) is True + assert ask(Q.prime(x), ~Q.prime(x)) is False + assert ask(Q.prime(x), Q.integer(x)) is None + assert ask(Q.prime(x), ~Q.integer(x)) is False + + assert ask(Q.prime(2*x), Q.integer(x)) is None + assert ask(Q.prime(x*y)) is None + assert ask(Q.prime(x*y), Q.prime(x)) is None + assert ask(Q.prime(x*y), Q.integer(x) & Q.integer(y)) is None + assert ask(Q.prime(4*x), Q.integer(x)) is False + assert ask(Q.prime(4*x)) is None + + assert ask(Q.prime(x**2), Q.integer(x)) is False + assert ask(Q.prime(x**2), Q.prime(x)) is False + + # https://github.com/sympy/sympy/issues/27446 + assert ask(Q.prime(4**x), Q.integer(x)) is False + assert ask(Q.prime(p**x), Q.prime(p) & Q.integer(x) & Q.ne(x, 1)) is False + assert ask(Q.prime(n**x), Q.integer(x) & Q.composite(n)) is False + assert ask(Q.prime(x**y), Q.integer(x) & Q.integer(y)) is None + assert ask(Q.prime(2**x), Q.integer(x)) is None + assert ask(Q.prime(p**x), Q.prime(p) & Q.integer(x)) is None + + # Ideally, these should return True since the base is prime and the exponent is one, + # but currently, they return None. + assert ask(Q.prime(x**y), Q.prime(x) & Q.eq(y,1)) is None + assert ask(Q.prime(x**y), Q.prime(x) & Q.integer(y) & Q.gt(y,0) & Q.lt(y,2)) is None + + assert ask(Q.prime(Pow(x,1, evaluate=False)), Q.prime(x)) is True + + +@_both_exp_pow +def test_positive(): + assert ask(Q.positive(cos(I) ** 2 + sin(I) ** 2 - 1)) is None + assert ask(Q.positive(x), Q.positive(x)) is True + assert ask(Q.positive(x), Q.negative(x)) is False + assert ask(Q.positive(x), Q.nonzero(x)) is None + + assert ask(Q.positive(-x), Q.positive(x)) is False + assert ask(Q.positive(-x), Q.negative(x)) is True + + assert ask(Q.positive(x + y), Q.positive(x) & Q.positive(y)) is True + assert ask(Q.positive(x + y), Q.positive(x) & Q.nonnegative(y)) is True + assert ask(Q.positive(x + y), Q.positive(x) & Q.negative(y)) is None + assert ask(Q.positive(x + y), Q.positive(x) & Q.imaginary(y)) is False + + assert ask(Q.positive(2*x), Q.positive(x)) is True + assumptions = Q.positive(x) & Q.negative(y) & Q.negative(z) & Q.positive(w) + assert ask(Q.positive(x*y*z)) is None + assert ask(Q.positive(x*y*z), assumptions) is True + assert ask(Q.positive(-x*y*z), assumptions) is False + + assert ask(Q.positive(x**I), Q.positive(x)) is None + + assert ask(Q.positive(x**2), Q.positive(x)) is True + assert ask(Q.positive(x**2), Q.negative(x)) is True + assert ask(Q.positive(x**3), Q.negative(x)) is False + assert ask(Q.positive(1/(1 + x**2)), Q.real(x)) is True + assert ask(Q.positive(2**I)) is False + assert ask(Q.positive(2 + I)) is False + # although this could be False, it is representative of expressions + # that don't evaluate to a zero with precision + assert ask(Q.positive(cos(I)**2 + sin(I)**2 - 1)) is None + assert ask(Q.positive(-I + I*(cos(2)**2 + sin(2)**2))) is None + + #exponential + assert ask(Q.positive(exp(x)), Q.real(x)) is True + assert ask(~Q.negative(exp(x)), Q.real(x)) is True + assert ask(Q.positive(x + exp(x)), Q.real(x)) is None + assert ask(Q.positive(exp(x)), Q.imaginary(x)) is None + assert ask(Q.positive(exp(2*pi*I, evaluate=False)), Q.imaginary(x)) is True + assert ask(Q.negative(exp(pi*I, evaluate=False)), Q.imaginary(x)) is True + assert ask(Q.positive(exp(x*pi*I)), Q.even(x)) is True + assert ask(Q.positive(exp(x*pi*I)), Q.odd(x)) is False + assert ask(Q.positive(exp(x*pi*I)), Q.real(x)) is None + + # logarithm + assert ask(Q.positive(log(x)), Q.imaginary(x)) is False + assert ask(Q.positive(log(x)), Q.negative(x)) is False + assert ask(Q.positive(log(x)), Q.positive(x)) is None + assert ask(Q.positive(log(x + 2)), Q.positive(x)) is True + + # factorial + assert ask(Q.positive(factorial(x)), Q.integer(x) & Q.positive(x)) + assert ask(Q.positive(factorial(x)), Q.integer(x)) is None + + #absolute value + assert ask(Q.positive(Abs(x))) is None # Abs(0) = 0 + assert ask(Q.positive(Abs(x)), Q.positive(x)) is True + + +def test_nonpositive(): + assert ask(Q.nonpositive(-1)) + assert ask(Q.nonpositive(0)) + assert ask(Q.nonpositive(1)) is False + assert ask(~Q.positive(x), Q.nonpositive(x)) + assert ask(Q.nonpositive(x), Q.positive(x)) is False + assert ask(Q.nonpositive(sqrt(-1))) is False + assert ask(Q.nonpositive(x), Q.imaginary(x)) is False + + +def test_nonnegative(): + assert ask(Q.nonnegative(-1)) is False + assert ask(Q.nonnegative(0)) + assert ask(Q.nonnegative(1)) + assert ask(~Q.negative(x), Q.nonnegative(x)) + assert ask(Q.nonnegative(x), Q.negative(x)) is False + assert ask(Q.nonnegative(sqrt(-1))) is False + assert ask(Q.nonnegative(x), Q.imaginary(x)) is False + +def test_real_basic(): + assert ask(Q.real(x)) is None + assert ask(Q.real(x), Q.real(x)) is True + assert ask(Q.real(x), Q.nonzero(x)) is True + assert ask(Q.real(x), Q.positive(x)) is True + assert ask(Q.real(x), Q.negative(x)) is True + assert ask(Q.real(x), Q.integer(x)) is True + assert ask(Q.real(x), Q.even(x)) is True + assert ask(Q.real(x), Q.prime(x)) is True + + assert ask(Q.real(x/sqrt(2)), Q.real(x)) is True + assert ask(Q.real(x/sqrt(-2)), Q.real(x)) is False + + assert ask(Q.real(x + 1), Q.real(x)) is True + assert ask(Q.real(x + I), Q.real(x)) is False + assert ask(Q.real(x + I), Q.complex(x)) is None + + assert ask(Q.real(2*x), Q.real(x)) is True + assert ask(Q.real(I*x), Q.real(x)) is False + assert ask(Q.real(I*x), Q.imaginary(x)) is True + assert ask(Q.real(I*x), Q.complex(x)) is None + + +def test_real_pow(): + assert ask(Q.real(x**2), Q.real(x)) is True + assert ask(Q.real(sqrt(x)), Q.negative(x)) is False + assert ask(Q.real(x**y), Q.real(x) & Q.integer(y)) is None + assert ask(Q.real(x**y), Q.real(x) & Q.real(y)) is None + assert ask(Q.real(x**y), Q.positive(x) & Q.real(y)) is True + assert ask(Q.real(x**y), Q.imaginary(x) & Q.imaginary(y)) is None # I**I or (2*I)**I + assert ask(Q.real(x**y), Q.imaginary(x) & Q.real(y)) is None # I**1 or I**0 + assert ask(Q.real(x**y), Q.real(x) & Q.imaginary(y)) is None # could be exp(2*pi*I) or 2**I + assert ask(Q.real(x**0), Q.imaginary(x)) is True + assert ask(Q.real(x**y), Q.positive(x) & Q.real(y)) is True + assert ask(Q.real(x**y), Q.real(x) & Q.rational(y)) is None + assert ask(Q.real(x**y), Q.imaginary(x) & Q.integer(y)) is None + assert ask(Q.real(x**y), Q.imaginary(x) & Q.odd(y)) is False + assert ask(Q.real(x**y), Q.imaginary(x) & Q.even(y)) is True + assert ask(Q.real(x**(y/z)), Q.real(x) & Q.real(y/z) & Q.rational(y/z) & Q.even(z) & Q.positive(x)) is True + assert ask(Q.real(x**(y/z)), Q.real(x) & Q.rational(y/z) & Q.even(z) & Q.negative(x)) is None + assert ask(Q.real(x**(y/z)), Q.real(x) & Q.integer(y/z)) is None + assert ask(Q.real(x**(y/z)), Q.real(x) & Q.real(y/z) & Q.positive(x)) is True + assert ask(Q.real(x**(y/z)), Q.real(x) & Q.real(y/z) & Q.negative(x)) is None + assert ask(Q.real((-I)**i), Q.imaginary(i)) is True + assert ask(Q.real(I**i), Q.imaginary(i)) is True + assert ask(Q.real(i**i), Q.imaginary(i)) is None # i might be 2*I + assert ask(Q.real(x**i), Q.imaginary(i)) is None # x could be 0 + assert ask(Q.real(x**(I*pi/log(x))), Q.real(x)) is True + + # https://github.com/sympy/sympy/issues/27485 + assert ask(Q.real(n**p), Q.negative(n) & Q.positive(p)) is None + + # https://github.com/sympy/sympy/issues/16530 + assert ask(Q.real(1/Abs(x))) is None + assert ask(Q.real(x**y), Q.zero(x) & Q.real(y)) is None + assert ask(Q.real(x**y), Q.zero(x) & Q.positive(y)) is True + + +@_both_exp_pow +def test_real_functions(): + # trigonometric functions + assert ask(Q.real(sin(x))) is None + assert ask(Q.real(cos(x))) is None + assert ask(Q.real(sin(x)), Q.real(x)) is True + assert ask(Q.real(cos(x)), Q.real(x)) is True + + # exponential function + assert ask(Q.real(exp(x))) is None + assert ask(Q.real(exp(x)), Q.real(x)) is True + assert ask(Q.real(x + exp(x)), Q.real(x)) is True + assert ask(Q.real(exp(2*pi*I, evaluate=False))) is True + assert ask(Q.real(exp(pi*I, evaluate=False))) is True + assert ask(Q.real(exp(pi*I/2, evaluate=False))) is False + + # logarithm + assert ask(Q.real(log(I))) is False + assert ask(Q.real(log(2*I))) is False + assert ask(Q.real(log(I + 1))) is False + assert ask(Q.real(log(x)), Q.complex(x)) is None + assert ask(Q.real(log(x)), Q.imaginary(x)) is False + assert ask(Q.real(log(exp(x))), Q.imaginary(x)) is None # exp(2*pi*I) is 1, log(exp(pi*I)) is pi*I (disregarding periodicity) + assert ask(Q.real(log(exp(x))), Q.complex(x)) is None + eq = Pow(exp(2*pi*I*x, evaluate=False), x, evaluate=False) + assert ask(Q.real(eq), Q.integer(x)) is True + assert ask(Q.real(exp(x)**x), Q.imaginary(x)) is True + assert ask(Q.real(exp(x)**x), Q.complex(x)) is None + + # Q.complexes + assert ask(Q.real(re(x))) is True + assert ask(Q.real(im(x))) is True + + +def test_matrix(): + + # hermitian + assert ask(Q.hermitian(Matrix([[2, 2 + I, 4], [2 - I, 3, I], [4, -I, 1]]))) == True + assert ask(Q.hermitian(Matrix([[2, 2 + I, 4], [2 + I, 3, I], [4, -I, 1]]))) == False + z = symbols('z', complex=True) + assert ask(Q.hermitian(Matrix([[2, 2 + I, z], [2 - I, 3, I], [4, -I, 1]]))) == None + assert ask(Q.hermitian(SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))))) == True + assert ask(Q.hermitian(SparseMatrix(((25, 15, -5), (15, I, 0), (-5, 0, 11))))) == False + assert ask(Q.hermitian(SparseMatrix(((25, 15, -5), (15, z, 0), (-5, 0, 11))))) == None + + # antihermitian + A = Matrix([[0, -2 - I, 0], [2 - I, 0, -I], [0, -I, 0]]) + B = Matrix([[-I, 2 + I, 0], [-2 + I, 0, 2 + I], [0, -2 + I, -I]]) + assert ask(Q.antihermitian(A)) is True + assert ask(Q.antihermitian(B)) is True + assert ask(Q.antihermitian(A**2)) is False + C = (B**3) + C.simplify() + assert ask(Q.antihermitian(C)) is True + _A = Matrix([[0, -2 - I, 0], [z, 0, -I], [0, -I, 0]]) + assert ask(Q.antihermitian(_A)) is None + + +@_both_exp_pow +def test_algebraic(): + assert ask(Q.algebraic(x)) is None + + assert ask(Q.algebraic(I)) is True + assert ask(Q.algebraic(2*I)) is True + assert ask(Q.algebraic(I/3)) is True + + assert ask(Q.algebraic(sqrt(7))) is True + assert ask(Q.algebraic(2*sqrt(7))) is True + assert ask(Q.algebraic(sqrt(7)/3)) is True + + assert ask(Q.algebraic(I*sqrt(3))) is True + assert ask(Q.algebraic(sqrt(1 + I*sqrt(3)))) is True + + assert ask(Q.algebraic(1 + I*sqrt(3)**Rational(17, 31))) is True + assert ask(Q.algebraic(1 + I*sqrt(3)**(17/pi))) is None + + for f in [exp, sin, tan, asin, atan, cos]: + assert ask(Q.algebraic(f(7))) is False + assert ask(Q.algebraic(f(7, evaluate=False))) is False + assert ask(Q.algebraic(f(0, evaluate=False))) is True + assert ask(Q.algebraic(f(x)), Q.algebraic(x)) is None + assert ask(Q.algebraic(f(x)), Q.algebraic(x) & Q.nonzero(x)) is False + + for g in [log, acos]: + assert ask(Q.algebraic(g(7))) is False + assert ask(Q.algebraic(g(7, evaluate=False))) is False + assert ask(Q.algebraic(g(1, evaluate=False))) is True + assert ask(Q.algebraic(g(x)), Q.algebraic(x)) is None + assert ask(Q.algebraic(g(x)), Q.algebraic(x) & Q.nonzero(x - 1)) is False + + for h in [cot, acot]: + assert ask(Q.algebraic(h(7))) is False + assert ask(Q.algebraic(h(7, evaluate=False))) is False + assert ask(Q.algebraic(h(x)), Q.algebraic(x)) is False + + assert ask(Q.algebraic(sqrt(sin(7)))) is None + assert ask(Q.algebraic(sqrt(y + I*sqrt(7)))) is None + + assert ask(Q.algebraic(2.47)) is True + + assert ask(Q.algebraic(x), Q.transcendental(x)) is False + assert ask(Q.transcendental(x), Q.algebraic(x)) is False + + #https://github.com/sympy/sympy/issues/27445 + assert ask(Q.algebraic(Pow(1, x, evaluate=False)), Q.algebraic(x)) is None + assert ask(Q.algebraic(Pow(x, y))) is None + assert ask(Q.algebraic(Pow(1, x, evaluate=False))) is None + assert ask(Q.algebraic(x**(pi*I))) is None + assert ask(Q.algebraic(pi**n),Q.integer(n) & Q.positive(n)) is False + assert ask(Q.algebraic(x**y),Q.algebraic(x) & Q.rational(y)) is True + + +def test_global(): + """Test ask with global assumptions""" + assert ask(Q.integer(x)) is None + global_assumptions.add(Q.integer(x)) + assert ask(Q.integer(x)) is True + global_assumptions.clear() + assert ask(Q.integer(x)) is None + + +def test_custom_context(): + """Test ask with custom assumptions context""" + assert ask(Q.integer(x)) is None + local_context = AssumptionsContext() + local_context.add(Q.integer(x)) + assert ask(Q.integer(x), context=local_context) is True + assert ask(Q.integer(x)) is None + + +def test_functions_in_assumptions(): + assert ask(Q.negative(x), Q.real(x) >> Q.positive(x)) is False + assert ask(Q.negative(x), Equivalent(Q.real(x), Q.positive(x))) is False + assert ask(Q.negative(x), Xor(Q.real(x), Q.negative(x))) is False + + +def test_composite_ask(): + assert ask(Q.negative(x) & Q.integer(x), + assumptions=Q.real(x) >> Q.positive(x)) is False + + +def test_composite_proposition(): + assert ask(True) is True + assert ask(False) is False + assert ask(~Q.negative(x), Q.positive(x)) is True + assert ask(~Q.real(x), Q.commutative(x)) is None + assert ask(Q.negative(x) & Q.integer(x), Q.positive(x)) is False + assert ask(Q.negative(x) & Q.integer(x)) is None + assert ask(Q.real(x) | Q.integer(x), Q.positive(x)) is True + assert ask(Q.real(x) | Q.integer(x)) is None + assert ask(Q.real(x) >> Q.positive(x), Q.negative(x)) is False + assert ask(Implies( + Q.real(x), Q.positive(x), evaluate=False), Q.negative(x)) is False + assert ask(Implies(Q.real(x), Q.positive(x), evaluate=False)) is None + assert ask(Equivalent(Q.integer(x), Q.even(x)), Q.even(x)) is True + assert ask(Equivalent(Q.integer(x), Q.even(x))) is None + assert ask(Equivalent(Q.positive(x), Q.integer(x)), Q.integer(x)) is None + assert ask(Q.real(x) | Q.integer(x), Q.real(x) | Q.integer(x)) is True + +def test_tautology(): + assert ask(Q.real(x) | ~Q.real(x)) is True + assert ask(Q.real(x) & ~Q.real(x)) is False + +def test_composite_assumptions(): + assert ask(Q.real(x), Q.real(x) & Q.real(y)) is True + assert ask(Q.positive(x), Q.positive(x) | Q.positive(y)) is None + assert ask(Q.positive(x), Q.real(x) >> Q.positive(y)) is None + assert ask(Q.real(x), ~(Q.real(x) >> Q.real(y))) is True + +def test_key_extensibility(): + """test that you can add keys to the ask system at runtime""" + # make sure the key is not defined + raises(AttributeError, lambda: ask(Q.my_key(x))) + + # Old handler system + class MyAskHandler(AskHandler): + @staticmethod + def Symbol(expr, assumptions): + return True + try: + with warns_deprecated_sympy(): + register_handler('my_key', MyAskHandler) + with warns_deprecated_sympy(): + assert ask(Q.my_key(x)) is True + with warns_deprecated_sympy(): + assert ask(Q.my_key(x + 1)) is None + finally: + # We have to disable the stacklevel testing here because this raises + # the warning twice from two different places + with warns_deprecated_sympy(): + remove_handler('my_key', MyAskHandler) + del Q.my_key + raises(AttributeError, lambda: ask(Q.my_key(x))) + + # New handler system + class MyPredicate(Predicate): + pass + try: + Q.my_key = MyPredicate() + @Q.my_key.register(Symbol) + def _(expr, assumptions): + return True + assert ask(Q.my_key(x)) is True + assert ask(Q.my_key(x+1)) is None + finally: + del Q.my_key + raises(AttributeError, lambda: ask(Q.my_key(x))) + + +def test_type_extensibility(): + """test that new types can be added to the ask system at runtime + """ + from sympy.core import Basic + + class MyType(Basic): + pass + + @Q.prime.register(MyType) + def _(expr, assumptions): + return True + + assert ask(Q.prime(MyType())) is True + + +def test_single_fact_lookup(): + known_facts = And(Implies(Q.integer, Q.rational), + Implies(Q.rational, Q.real), + Implies(Q.real, Q.complex)) + known_facts_keys = {Q.integer, Q.rational, Q.real, Q.complex} + + known_facts_cnf = to_cnf(known_facts) + mapping = single_fact_lookup(known_facts_keys, known_facts_cnf) + + assert mapping[Q.rational] == {Q.real, Q.rational, Q.complex} + + +def test_generate_known_facts_dict(): + known_facts = And(Implies(Q.integer(x), Q.rational(x)), + Implies(Q.rational(x), Q.real(x)), + Implies(Q.real(x), Q.complex(x))) + known_facts_keys = {Q.integer(x), Q.rational(x), Q.real(x), Q.complex(x)} + + assert generate_known_facts_dict(known_facts_keys, known_facts) == \ + {Q.complex: ({Q.complex}, set()), + Q.integer: ({Q.complex, Q.integer, Q.rational, Q.real}, set()), + Q.rational: ({Q.complex, Q.rational, Q.real}, set()), + Q.real: ({Q.complex, Q.real}, set())} + + +@slow +def test_known_facts_consistent(): + """"Test that ask_generated.py is up-to-date""" + x = Symbol('x') + fact = get_known_facts(x) + # test cnf clauses of fact between unary predicates + cnf = CNF.to_CNF(fact) + clauses = set() + clauses.update(frozenset(Literal(lit.arg.function, lit.is_Not) for lit in sorted(cl, key=str)) for cl in cnf.clauses) + assert get_all_known_facts() == clauses + # test dictionary of fact between unary predicates + keys = [pred(x) for pred in get_known_facts_keys()] + mapping = generate_known_facts_dict(keys, fact) + assert get_known_facts_dict() == mapping + + +def test_Add_queries(): + assert ask(Q.prime(12345678901234567890 + (cos(1)**2 + sin(1)**2))) is True + assert ask(Q.even(Add(S(2), S(2), evaluate=False))) is True + assert ask(Q.prime(Add(S(2), S(2), evaluate=False))) is False + assert ask(Q.integer(Add(S(2), S(2), evaluate=False))) is True + + +def test_positive_assuming(): + with assuming(Q.positive(x + 1)): + assert not ask(Q.positive(x)) + + +def test_issue_5421(): + raises(TypeError, lambda: ask(pi/log(x), Q.real)) + + +def test_issue_3906(): + raises(TypeError, lambda: ask(Q.positive)) + + +def test_issue_5833(): + assert ask(Q.positive(log(x)**2), Q.positive(x)) is None + assert ask(~Q.negative(log(x)**2), Q.positive(x)) is True + + +def test_issue_6732(): + raises(ValueError, lambda: ask(Q.positive(x), Q.positive(x) & Q.negative(x))) + raises(ValueError, lambda: ask(Q.negative(x), Q.positive(x) & Q.negative(x))) + + +def test_issue_7246(): + assert ask(Q.positive(atan(p)), Q.positive(p)) is True + assert ask(Q.positive(atan(p)), Q.negative(p)) is False + assert ask(Q.positive(atan(p)), Q.zero(p)) is False + assert ask(Q.positive(atan(x))) is None + + assert ask(Q.positive(asin(p)), Q.positive(p)) is None + assert ask(Q.positive(asin(p)), Q.zero(p)) is None + assert ask(Q.positive(asin(Rational(1, 7)))) is True + assert ask(Q.positive(asin(x)), Q.positive(x) & Q.nonpositive(x - 1)) is True + assert ask(Q.positive(asin(x)), Q.negative(x) & Q.nonnegative(x + 1)) is False + + assert ask(Q.positive(acos(p)), Q.positive(p)) is None + assert ask(Q.positive(acos(Rational(1, 7)))) is True + assert ask(Q.positive(acos(x)), Q.nonnegative(x + 1) & Q.nonpositive(x - 1)) is True + assert ask(Q.positive(acos(x)), Q.nonnegative(x - 1)) is None + + assert ask(Q.positive(acot(x)), Q.positive(x)) is True + assert ask(Q.positive(acot(x)), Q.real(x)) is True + assert ask(Q.positive(acot(x)), Q.imaginary(x)) is False + assert ask(Q.positive(acot(x))) is None + + +@XFAIL +def test_issue_7246_failing(): + #Move this test to test_issue_7246 once + #the new assumptions module is improved. + assert ask(Q.positive(acos(x)), Q.zero(x)) is True + + +def test_check_old_assumption(): + x = symbols('x', real=True) + assert ask(Q.real(x)) is True + assert ask(Q.imaginary(x)) is False + assert ask(Q.complex(x)) is True + + x = symbols('x', imaginary=True) + assert ask(Q.real(x)) is False + assert ask(Q.imaginary(x)) is True + assert ask(Q.complex(x)) is True + + x = symbols('x', complex=True) + assert ask(Q.real(x)) is None + assert ask(Q.complex(x)) is True + + x = symbols('x', positive=True) + assert ask(Q.positive(x)) is True + assert ask(Q.negative(x)) is False + assert ask(Q.real(x)) is True + + x = symbols('x', commutative=False) + assert ask(Q.commutative(x)) is False + + x = symbols('x', negative=True) + assert ask(Q.positive(x)) is False + assert ask(Q.negative(x)) is True + + x = symbols('x', nonnegative=True) + assert ask(Q.negative(x)) is False + assert ask(Q.positive(x)) is None + assert ask(Q.zero(x)) is None + + x = symbols('x', finite=True) + assert ask(Q.finite(x)) is True + + x = symbols('x', prime=True) + assert ask(Q.prime(x)) is True + assert ask(Q.composite(x)) is False + + x = symbols('x', composite=True) + assert ask(Q.prime(x)) is False + assert ask(Q.composite(x)) is True + + x = symbols('x', even=True) + assert ask(Q.even(x)) is True + assert ask(Q.odd(x)) is False + + x = symbols('x', odd=True) + assert ask(Q.even(x)) is False + assert ask(Q.odd(x)) is True + + x = symbols('x', nonzero=True) + assert ask(Q.nonzero(x)) is True + assert ask(Q.zero(x)) is False + + x = symbols('x', zero=True) + assert ask(Q.zero(x)) is True + + x = symbols('x', integer=True) + assert ask(Q.integer(x)) is True + + x = symbols('x', rational=True) + assert ask(Q.rational(x)) is True + assert ask(Q.irrational(x)) is False + + x = symbols('x', irrational=True) + assert ask(Q.irrational(x)) is True + assert ask(Q.rational(x)) is False + + +def test_issue_9636(): + assert ask(Q.integer(1.0)) is None + assert ask(Q.prime(3.0)) is None + assert ask(Q.composite(4.0)) is None + assert ask(Q.even(2.0)) is None + assert ask(Q.odd(3.0)) is None + + +def test_autosimp_used_to_fail(): + # See issue #9807 + assert ask(Q.imaginary(0**I)) is None + assert ask(Q.imaginary(0**(-I))) is None + assert ask(Q.real(0**I)) is None + assert ask(Q.real(0**(-I))) is None + + +def test_custom_AskHandler(): + from sympy.logic.boolalg import conjuncts + + # Old handler system + class MersenneHandler(AskHandler): + @staticmethod + def Integer(expr, assumptions): + if ask(Q.integer(log(expr + 1, 2))): + return True + @staticmethod + def Symbol(expr, assumptions): + if expr in conjuncts(assumptions): + return True + try: + with warns_deprecated_sympy(): + register_handler('mersenne', MersenneHandler) + n = Symbol('n', integer=True) + with warns_deprecated_sympy(): + assert ask(Q.mersenne(7)) + with warns_deprecated_sympy(): + assert ask(Q.mersenne(n), Q.mersenne(n)) + finally: + del Q.mersenne + + # New handler system + class MersennePredicate(Predicate): + pass + try: + Q.mersenne = MersennePredicate() + @Q.mersenne.register(Integer) + def _(expr, assumptions): + if ask(Q.integer(log(expr + 1, 2))): + return True + @Q.mersenne.register(Symbol) + def _(expr, assumptions): + if expr in conjuncts(assumptions): + return True + assert ask(Q.mersenne(7)) + assert ask(Q.mersenne(n), Q.mersenne(n)) + finally: + del Q.mersenne + + +def test_polyadic_predicate(): + + class SexyPredicate(Predicate): + pass + try: + Q.sexyprime = SexyPredicate() + + @Q.sexyprime.register(Integer, Integer) + def _(int1, int2, assumptions): + args = sorted([int1, int2]) + if not all(ask(Q.prime(a), assumptions) for a in args): + return False + return args[1] - args[0] == 6 + + @Q.sexyprime.register(Integer, Integer, Integer) + def _(int1, int2, int3, assumptions): + args = sorted([int1, int2, int3]) + if not all(ask(Q.prime(a), assumptions) for a in args): + return False + return args[2] - args[1] == 6 and args[1] - args[0] == 6 + + assert ask(Q.sexyprime(5, 11)) + assert ask(Q.sexyprime(7, 13, 19)) + finally: + del Q.sexyprime + + +def test_Predicate_handler_is_unique(): + + # Undefined predicate does not have a handler + assert Predicate('mypredicate').handler is None + + # Handler of defined predicate is unique to the class + class MyPredicate(Predicate): + pass + mp1 = MyPredicate(Str('mp1')) + mp2 = MyPredicate(Str('mp2')) + assert mp1.handler is mp2.handler + + +def test_relational(): + assert ask(Q.eq(x, 0), Q.zero(x)) + assert not ask(Q.eq(x, 0), Q.nonzero(x)) + assert not ask(Q.ne(x, 0), Q.zero(x)) + assert ask(Q.ne(x, 0), Q.nonzero(x)) + + +def test_issue_25221(): + assert ask(Q.transcendental(x), Q.algebraic(x) | Q.positive(y,y)) is None + assert ask(Q.transcendental(x), Q.algebraic(x) | (0 > y)) is None + assert ask(Q.transcendental(x), Q.algebraic(x) | Q.gt(0,y)) is None + + +def test_issue_27440(): + nan = S.NaN + assert ask(Q.negative(nan)) is None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_refine.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..81533a88b232cd5c3cfb9be17d09dad404d679dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_refine.py @@ -0,0 +1,227 @@ +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine +from sympy.core.expr import Expr +from sympy.core.numbers import (I, Rational, nan, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import (Abs, arg, im, re, sign) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan, atan2) +from sympy.abc import w, x, y, z +from sympy.core.relational import Eq, Ne +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices.expressions.matexpr import MatrixSymbol + + +def test_Abs(): + assert refine(Abs(x), Q.positive(x)) == x + assert refine(1 + Abs(x), Q.positive(x)) == 1 + x + assert refine(Abs(x), Q.negative(x)) == -x + assert refine(1 + Abs(x), Q.negative(x)) == 1 - x + + assert refine(Abs(x**2)) != x**2 + assert refine(Abs(x**2), Q.real(x)) == x**2 + + +def test_pow1(): + assert refine((-1)**x, Q.even(x)) == 1 + assert refine((-1)**x, Q.odd(x)) == -1 + assert refine((-2)**x, Q.even(x)) == 2**x + + # nested powers + assert refine(sqrt(x**2)) != Abs(x) + assert refine(sqrt(x**2), Q.complex(x)) != Abs(x) + assert refine(sqrt(x**2), Q.real(x)) == Abs(x) + assert refine(sqrt(x**2), Q.positive(x)) == x + assert refine((x**3)**Rational(1, 3)) != x + + assert refine((x**3)**Rational(1, 3), Q.real(x)) != x + assert refine((x**3)**Rational(1, 3), Q.positive(x)) == x + + assert refine(sqrt(1/x), Q.real(x)) != 1/sqrt(x) + assert refine(sqrt(1/x), Q.positive(x)) == 1/sqrt(x) + + # powers of (-1) + assert refine((-1)**(x + y), Q.even(x)) == (-1)**y + assert refine((-1)**(x + y + z), Q.odd(x) & Q.odd(z)) == (-1)**y + assert refine((-1)**(x + y + 1), Q.odd(x)) == (-1)**y + assert refine((-1)**(x + y + 2), Q.odd(x)) == (-1)**(y + 1) + assert refine((-1)**(x + 3)) == (-1)**(x + 1) + + # continuation + assert refine((-1)**((-1)**x/2 - S.Half), Q.integer(x)) == (-1)**x + assert refine((-1)**((-1)**x/2 + S.Half), Q.integer(x)) == (-1)**(x + 1) + assert refine((-1)**((-1)**x/2 + 5*S.Half), Q.integer(x)) == (-1)**(x + 1) + + +def test_pow2(): + assert refine((-1)**((-1)**x/2 - 7*S.Half), Q.integer(x)) == (-1)**(x + 1) + assert refine((-1)**((-1)**x/2 - 9*S.Half), Q.integer(x)) == (-1)**x + + # powers of Abs + assert refine(Abs(x)**2, Q.real(x)) == x**2 + assert refine(Abs(x)**3, Q.real(x)) == Abs(x)**3 + assert refine(Abs(x)**2) == Abs(x)**2 + + +def test_exp(): + x = Symbol('x', integer=True) + assert refine(exp(pi*I*2*x)) == 1 + assert refine(exp(pi*I*2*(x + S.Half))) == -1 + assert refine(exp(pi*I*2*(x + Rational(1, 4)))) == I + assert refine(exp(pi*I*2*(x + Rational(3, 4)))) == -I + + +def test_Piecewise(): + assert refine(Piecewise((1, x < 0), (3, True)), (x < 0)) == 1 + assert refine(Piecewise((1, x < 0), (3, True)), ~(x < 0)) == 3 + assert refine(Piecewise((1, x < 0), (3, True)), (y < 0)) == \ + Piecewise((1, x < 0), (3, True)) + assert refine(Piecewise((1, x > 0), (3, True)), (x > 0)) == 1 + assert refine(Piecewise((1, x > 0), (3, True)), ~(x > 0)) == 3 + assert refine(Piecewise((1, x > 0), (3, True)), (y > 0)) == \ + Piecewise((1, x > 0), (3, True)) + assert refine(Piecewise((1, x <= 0), (3, True)), (x <= 0)) == 1 + assert refine(Piecewise((1, x <= 0), (3, True)), ~(x <= 0)) == 3 + assert refine(Piecewise((1, x <= 0), (3, True)), (y <= 0)) == \ + Piecewise((1, x <= 0), (3, True)) + assert refine(Piecewise((1, x >= 0), (3, True)), (x >= 0)) == 1 + assert refine(Piecewise((1, x >= 0), (3, True)), ~(x >= 0)) == 3 + assert refine(Piecewise((1, x >= 0), (3, True)), (y >= 0)) == \ + Piecewise((1, x >= 0), (3, True)) + assert refine(Piecewise((1, Eq(x, 0)), (3, True)), (Eq(x, 0)))\ + == 1 + assert refine(Piecewise((1, Eq(x, 0)), (3, True)), (Eq(0, x)))\ + == 1 + assert refine(Piecewise((1, Eq(x, 0)), (3, True)), ~(Eq(x, 0)))\ + == 3 + assert refine(Piecewise((1, Eq(x, 0)), (3, True)), ~(Eq(0, x)))\ + == 3 + assert refine(Piecewise((1, Eq(x, 0)), (3, True)), (Eq(y, 0)))\ + == Piecewise((1, Eq(x, 0)), (3, True)) + assert refine(Piecewise((1, Ne(x, 0)), (3, True)), (Ne(x, 0)))\ + == 1 + assert refine(Piecewise((1, Ne(x, 0)), (3, True)), ~(Ne(x, 0)))\ + == 3 + assert refine(Piecewise((1, Ne(x, 0)), (3, True)), (Ne(y, 0)))\ + == Piecewise((1, Ne(x, 0)), (3, True)) + + +def test_atan2(): + assert refine(atan2(y, x), Q.real(y) & Q.positive(x)) == atan(y/x) + assert refine(atan2(y, x), Q.negative(y) & Q.positive(x)) == atan(y/x) + assert refine(atan2(y, x), Q.negative(y) & Q.negative(x)) == atan(y/x) - pi + assert refine(atan2(y, x), Q.positive(y) & Q.negative(x)) == atan(y/x) + pi + assert refine(atan2(y, x), Q.zero(y) & Q.negative(x)) == pi + assert refine(atan2(y, x), Q.positive(y) & Q.zero(x)) == pi/2 + assert refine(atan2(y, x), Q.negative(y) & Q.zero(x)) == -pi/2 + assert refine(atan2(y, x), Q.zero(y) & Q.zero(x)) is nan + + +def test_re(): + assert refine(re(x), Q.real(x)) == x + assert refine(re(x), Q.imaginary(x)) is S.Zero + assert refine(re(x+y), Q.real(x) & Q.real(y)) == x + y + assert refine(re(x+y), Q.real(x) & Q.imaginary(y)) == x + assert refine(re(x*y), Q.real(x) & Q.real(y)) == x * y + assert refine(re(x*y), Q.real(x) & Q.imaginary(y)) == 0 + assert refine(re(x*y*z), Q.real(x) & Q.real(y) & Q.real(z)) == x * y * z + + +def test_im(): + assert refine(im(x), Q.imaginary(x)) == -I*x + assert refine(im(x), Q.real(x)) is S.Zero + assert refine(im(x+y), Q.imaginary(x) & Q.imaginary(y)) == -I*x - I*y + assert refine(im(x+y), Q.real(x) & Q.imaginary(y)) == -I*y + assert refine(im(x*y), Q.imaginary(x) & Q.real(y)) == -I*x*y + assert refine(im(x*y), Q.imaginary(x) & Q.imaginary(y)) == 0 + assert refine(im(1/x), Q.imaginary(x)) == -I/x + assert refine(im(x*y*z), Q.imaginary(x) & Q.imaginary(y) + & Q.imaginary(z)) == -I*x*y*z + + +def test_complex(): + assert refine(re(1/(x + I*y)), Q.real(x) & Q.real(y)) == \ + x/(x**2 + y**2) + assert refine(im(1/(x + I*y)), Q.real(x) & Q.real(y)) == \ + -y/(x**2 + y**2) + assert refine(re((w + I*x) * (y + I*z)), Q.real(w) & Q.real(x) & Q.real(y) + & Q.real(z)) == w*y - x*z + assert refine(im((w + I*x) * (y + I*z)), Q.real(w) & Q.real(x) & Q.real(y) + & Q.real(z)) == w*z + x*y + + +def test_sign(): + x = Symbol('x', real = True) + assert refine(sign(x), Q.positive(x)) == 1 + assert refine(sign(x), Q.negative(x)) == -1 + assert refine(sign(x), Q.zero(x)) == 0 + assert refine(sign(x), True) == sign(x) + assert refine(sign(Abs(x)), Q.nonzero(x)) == 1 + + x = Symbol('x', imaginary=True) + assert refine(sign(x), Q.positive(im(x))) == S.ImaginaryUnit + assert refine(sign(x), Q.negative(im(x))) == -S.ImaginaryUnit + assert refine(sign(x), True) == sign(x) + + x = Symbol('x', complex=True) + assert refine(sign(x), Q.zero(x)) == 0 + +def test_arg(): + x = Symbol('x', complex = True) + assert refine(arg(x), Q.positive(x)) == 0 + assert refine(arg(x), Q.negative(x)) == pi + +def test_func_args(): + class MyClass(Expr): + # A class with nontrivial .func + + def __init__(self, *args): + self.my_member = "" + + @property + def func(self): + def my_func(*args): + obj = MyClass(*args) + obj.my_member = self.my_member + return obj + return my_func + + x = MyClass() + x.my_member = "A very important value" + assert x.my_member == refine(x).my_member + +def test_issue_refine_9384(): + assert refine(Piecewise((1, x < 0), (0, True)), Q.positive(x)) == 0 + assert refine(Piecewise((1, x < 0), (0, True)), Q.negative(x)) == 1 + assert refine(Piecewise((1, x > 0), (0, True)), Q.positive(x)) == 1 + assert refine(Piecewise((1, x > 0), (0, True)), Q.negative(x)) == 0 + + +def test_eval_refine(): + class MockExpr(Expr): + def _eval_refine(self, assumptions): + return True + + mock_obj = MockExpr() + assert refine(mock_obj) + +def test_refine_issue_12724(): + expr1 = refine(Abs(x * y), Q.positive(x)) + expr2 = refine(Abs(x * y * z), Q.positive(x)) + assert expr1 == x * Abs(y) + assert expr2 == x * Abs(y * z) + y1 = Symbol('y1', real = True) + expr3 = refine(Abs(x * y1**2 * z), Q.positive(x)) + assert expr3 == x * y1**2 * Abs(z) + + +def test_matrixelement(): + x = MatrixSymbol('x', 3, 3) + i = Symbol('i', positive = True) + j = Symbol('j', positive = True) + assert refine(x[0, 1], Q.symmetric(x)) == x[0, 1] + assert refine(x[1, 0], Q.symmetric(x)) == x[0, 1] + assert refine(x[i, j], Q.symmetric(x)) == x[j, i] + assert refine(x[j, i], Q.symmetric(x)) == x[j, i] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_rel_queries.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_rel_queries.py new file mode 100644 index 0000000000000000000000000000000000000000..46fe3a35dc1adb23668e88d5794fe1c0ab22f33a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_rel_queries.py @@ -0,0 +1,172 @@ +from sympy.assumptions.lra_satask import lra_satask +from sympy.logic.algorithms.lra_theory import UnhandledInput +from sympy.assumptions.ask import Q, ask + +from sympy.core import symbols, Symbol +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.core.numbers import I + +from sympy.testing.pytest import raises, XFAIL +x, y, z = symbols("x y z", real=True) + +def test_lra_satask(): + im = Symbol('im', imaginary=True) + + # test preprocessing of unequalities is working correctly + assert lra_satask(Q.eq(x, 1), ~Q.ne(x, 0)) is False + assert lra_satask(Q.eq(x, 0), ~Q.ne(x, 0)) is True + assert lra_satask(~Q.ne(x, 0), Q.eq(x, 0)) is True + assert lra_satask(~Q.eq(x, 0), Q.eq(x, 0)) is False + assert lra_satask(Q.ne(x, 0), Q.eq(x, 0)) is False + + # basic tests + assert lra_satask(Q.ne(x, x)) is False + assert lra_satask(Q.eq(x, x)) is True + assert lra_satask(Q.gt(x, 0), Q.gt(x, 1)) is True + + # check that True/False are handled + assert lra_satask(Q.gt(x, 0), True) is None + assert raises(ValueError, lambda: lra_satask(Q.gt(x, 0), False)) + + # check imaginary numbers are correctly handled + # (im * I).is_real returns True so this is an edge case + raises(UnhandledInput, lambda: lra_satask(Q.gt(im * I, 0), Q.gt(im * I, 0))) + + # check matrix inputs + X = MatrixSymbol("X", 2, 2) + raises(UnhandledInput, lambda: lra_satask(Q.lt(X, 2) & Q.gt(X, 3))) + + +def test_old_assumptions(): + # test unhandled old assumptions + w = symbols("w") + raises(UnhandledInput, lambda: lra_satask(Q.lt(w, 2) & Q.gt(w, 3))) + w = symbols("w", rational=False, real=True) + raises(UnhandledInput, lambda: lra_satask(Q.lt(w, 2) & Q.gt(w, 3))) + w = symbols("w", odd=True, real=True) + raises(UnhandledInput, lambda: lra_satask(Q.lt(w, 2) & Q.gt(w, 3))) + w = symbols("w", even=True, real=True) + raises(UnhandledInput, lambda: lra_satask(Q.lt(w, 2) & Q.gt(w, 3))) + w = symbols("w", prime=True, real=True) + raises(UnhandledInput, lambda: lra_satask(Q.lt(w, 2) & Q.gt(w, 3))) + w = symbols("w", composite=True, real=True) + raises(UnhandledInput, lambda: lra_satask(Q.lt(w, 2) & Q.gt(w, 3))) + w = symbols("w", integer=True, real=True) + raises(UnhandledInput, lambda: lra_satask(Q.lt(w, 2) & Q.gt(w, 3))) + w = symbols("w", integer=False, real=True) + raises(UnhandledInput, lambda: lra_satask(Q.lt(w, 2) & Q.gt(w, 3))) + + # test handled + w = symbols("w", positive=True, real=True) + assert lra_satask(Q.le(w, 0)) is False + assert lra_satask(Q.gt(w, 0)) is True + w = symbols("w", negative=True, real=True) + assert lra_satask(Q.lt(w, 0)) is True + assert lra_satask(Q.ge(w, 0)) is False + w = symbols("w", zero=True, real=True) + assert lra_satask(Q.eq(w, 0)) is True + assert lra_satask(Q.ne(w, 0)) is False + w = symbols("w", nonzero=True, real=True) + assert lra_satask(Q.ne(w, 0)) is True + assert lra_satask(Q.eq(w, 1)) is None + w = symbols("w", nonpositive=True, real=True) + assert lra_satask(Q.le(w, 0)) is True + assert lra_satask(Q.gt(w, 0)) is False + w = symbols("w", nonnegative=True, real=True) + assert lra_satask(Q.ge(w, 0)) is True + assert lra_satask(Q.lt(w, 0)) is False + + +def test_rel_queries(): + assert ask(Q.lt(x, 2) & Q.gt(x, 3)) is False + assert ask(Q.positive(x - z), (x > y) & (y > z)) is True + assert ask(x + y > 2, (x < 0) & (y <0)) is False + assert ask(x > z, (x > y) & (y > z)) is True + + +def test_unhandled_queries(): + X = MatrixSymbol("X", 2, 2) + assert ask(Q.lt(X, 2) & Q.gt(X, 3)) is None + + +def test_all_pred(): + # test usable pred + assert lra_satask(Q.extended_positive(x), (x > 2)) is True + assert lra_satask(Q.positive_infinite(x)) is False + assert lra_satask(Q.negative_infinite(x)) is False + + # test disallowed pred + raises(UnhandledInput, lambda: lra_satask((x > 0), (x > 2) & Q.prime(x))) + raises(UnhandledInput, lambda: lra_satask((x > 0), (x > 2) & Q.composite(x))) + raises(UnhandledInput, lambda: lra_satask((x > 0), (x > 2) & Q.odd(x))) + raises(UnhandledInput, lambda: lra_satask((x > 0), (x > 2) & Q.even(x))) + raises(UnhandledInput, lambda: lra_satask((x > 0), (x > 2) & Q.integer(x))) + + +def test_number_line_properties(): + # From: + # https://en.wikipedia.org/wiki/Inequality_(mathematics)#Properties_on_the_number_line + + a, b, c = symbols("a b c", real=True) + + # Transitivity + # If a <= b and b <= c, then a <= c. + assert ask(a <= c, (a <= b) & (b <= c)) is True + # If a <= b and b < c, then a < c. + assert ask(a < c, (a <= b) & (b < c)) is True + # If a < b and b <= c, then a < c. + assert ask(a < c, (a < b) & (b <= c)) is True + + # Addition and subtraction + # If a <= b, then a + c <= b + c and a - c <= b - c. + assert ask(a + c <= b + c, a <= b) is True + assert ask(a - c <= b - c, a <= b) is True + + +@XFAIL +def test_failing_number_line_properties(): + # From: + # https://en.wikipedia.org/wiki/Inequality_(mathematics)#Properties_on_the_number_line + + a, b, c = symbols("a b c", real=True) + + # Multiplication and division + # If a <= b and c > 0, then ac <= bc and a/c <= b/c. (True for non-zero c) + assert ask(a*c <= b*c, (a <= b) & (c > 0) & ~ Q.zero(c)) is True + assert ask(a/c <= b/c, (a <= b) & (c > 0) & ~ Q.zero(c)) is True + # If a <= b and c < 0, then ac >= bc and a/c >= b/c. (True for non-zero c) + assert ask(a*c >= b*c, (a <= b) & (c < 0) & ~ Q.zero(c)) is True + assert ask(a/c >= b/c, (a <= b) & (c < 0) & ~ Q.zero(c)) is True + + # Additive inverse + # If a <= b, then -a >= -b. + assert ask(-a >= -b, a <= b) is True + + # Multiplicative inverse + # For a, b that are both negative or both positive: + # If a <= b, then 1/a >= 1/b . + assert ask(1/a >= 1/b, (a <= b) & Q.positive(x) & Q.positive(b)) is True + assert ask(1/a >= 1/b, (a <= b) & Q.negative(x) & Q.negative(b)) is True + + +def test_equality(): + # test symmetry and reflexivity + assert ask(Q.eq(x, x)) is True + assert ask(Q.eq(y, x), Q.eq(x, y)) is True + assert ask(Q.eq(y, x), ~Q.eq(z, z) | Q.eq(x, y)) is True + + # test transitivity + assert ask(Q.eq(x,z), Q.eq(x,y) & Q.eq(y,z)) is True + + +@XFAIL +def test_equality_failing(): + # Note that implementing the substitution property of equality + # most likely requires a redesign of the new assumptions. + # See issue #25485 for why this is the case and general ideas + # about how things could be redesigned. + + # test substitution property + assert ask(Q.prime(x), Q.eq(x, y) & Q.prime(y)) is True + assert ask(Q.real(x), Q.eq(x, y) & Q.real(y)) is True + assert ask(Q.imaginary(x), Q.eq(x, y) & Q.imaginary(y)) is True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_satask.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_satask.py new file mode 100644 index 0000000000000000000000000000000000000000..5831b69e3e6bf2b1a906d1140967510c2ea8b630 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_satask.py @@ -0,0 +1,378 @@ +from sympy.assumptions.ask import Q +from sympy.assumptions.assume import assuming +from sympy.core.numbers import (I, pi) +from sympy.core.relational import (Eq, Gt) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import Abs +from sympy.logic.boolalg import Implies +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.assumptions.cnf import CNF, Literal +from sympy.assumptions.satask import (satask, extract_predargs, + get_relevant_clsfacts) + +from sympy.testing.pytest import raises, XFAIL + + +x, y, z = symbols('x y z') + + +def test_satask(): + # No relevant facts + assert satask(Q.real(x), Q.real(x)) is True + assert satask(Q.real(x), ~Q.real(x)) is False + assert satask(Q.real(x)) is None + + assert satask(Q.real(x), Q.positive(x)) is True + assert satask(Q.positive(x), Q.real(x)) is None + assert satask(Q.real(x), ~Q.positive(x)) is None + assert satask(Q.positive(x), ~Q.real(x)) is False + + raises(ValueError, lambda: satask(Q.real(x), Q.real(x) & ~Q.real(x))) + + with assuming(Q.positive(x)): + assert satask(Q.real(x)) is True + assert satask(~Q.positive(x)) is False + raises(ValueError, lambda: satask(Q.real(x), ~Q.positive(x))) + + assert satask(Q.zero(x), Q.nonzero(x)) is False + assert satask(Q.positive(x), Q.zero(x)) is False + assert satask(Q.real(x), Q.zero(x)) is True + assert satask(Q.zero(x), Q.zero(x*y)) is None + assert satask(Q.zero(x*y), Q.zero(x)) + + +def test_zero(): + """ + Everything in this test doesn't work with the ask handlers, and most + things would be very difficult or impossible to make work under that + model. + + """ + assert satask(Q.zero(x) | Q.zero(y), Q.zero(x*y)) is True + assert satask(Q.zero(x*y), Q.zero(x) | Q.zero(y)) is True + + assert satask(Implies(Q.zero(x), Q.zero(x*y))) is True + + # This one in particular requires computing the fixed-point of the + # relevant facts, because going from Q.nonzero(x*y) -> ~Q.zero(x*y) and + # Q.zero(x*y) -> Equivalent(Q.zero(x*y), Q.zero(x) | Q.zero(y)) takes two + # steps. + assert satask(Q.zero(x) | Q.zero(y), Q.nonzero(x*y)) is False + + assert satask(Q.zero(x), Q.zero(x**2)) is True + + +def test_zero_positive(): + assert satask(Q.zero(x + y), Q.positive(x) & Q.positive(y)) is False + assert satask(Q.positive(x) & Q.positive(y), Q.zero(x + y)) is False + assert satask(Q.nonzero(x + y), Q.positive(x) & Q.positive(y)) is True + assert satask(Q.positive(x) & Q.positive(y), Q.nonzero(x + y)) is None + + # This one requires several levels of forward chaining + assert satask(Q.zero(x*(x + y)), Q.positive(x) & Q.positive(y)) is False + + assert satask(Q.positive(pi*x*y + 1), Q.positive(x) & Q.positive(y)) is True + assert satask(Q.positive(pi*x*y - 5), Q.positive(x) & Q.positive(y)) is None + + +def test_zero_pow(): + assert satask(Q.zero(x**y), Q.zero(x) & Q.positive(y)) is True + assert satask(Q.zero(x**y), Q.nonzero(x) & Q.zero(y)) is False + + assert satask(Q.zero(x), Q.zero(x**y)) is True + + assert satask(Q.zero(x**y), Q.zero(x)) is None + + +@XFAIL +# Requires correct Q.square calculation first +def test_invertible(): + A = MatrixSymbol('A', 5, 5) + B = MatrixSymbol('B', 5, 5) + assert satask(Q.invertible(A*B), Q.invertible(A) & Q.invertible(B)) is True + assert satask(Q.invertible(A), Q.invertible(A*B)) is True + assert satask(Q.invertible(A) & Q.invertible(B), Q.invertible(A*B)) is True + + +def test_prime(): + assert satask(Q.prime(5)) is True + assert satask(Q.prime(6)) is False + assert satask(Q.prime(-5)) is False + + assert satask(Q.prime(x*y), Q.integer(x) & Q.integer(y)) is None + assert satask(Q.prime(x*y), Q.prime(x) & Q.prime(y)) is False + + +def test_old_assump(): + assert satask(Q.positive(1)) is True + assert satask(Q.positive(-1)) is False + assert satask(Q.positive(0)) is False + assert satask(Q.positive(I)) is False + assert satask(Q.positive(pi)) is True + + assert satask(Q.negative(1)) is False + assert satask(Q.negative(-1)) is True + assert satask(Q.negative(0)) is False + assert satask(Q.negative(I)) is False + assert satask(Q.negative(pi)) is False + + assert satask(Q.zero(1)) is False + assert satask(Q.zero(-1)) is False + assert satask(Q.zero(0)) is True + assert satask(Q.zero(I)) is False + assert satask(Q.zero(pi)) is False + + assert satask(Q.nonzero(1)) is True + assert satask(Q.nonzero(-1)) is True + assert satask(Q.nonzero(0)) is False + assert satask(Q.nonzero(I)) is False + assert satask(Q.nonzero(pi)) is True + + assert satask(Q.nonpositive(1)) is False + assert satask(Q.nonpositive(-1)) is True + assert satask(Q.nonpositive(0)) is True + assert satask(Q.nonpositive(I)) is False + assert satask(Q.nonpositive(pi)) is False + + assert satask(Q.nonnegative(1)) is True + assert satask(Q.nonnegative(-1)) is False + assert satask(Q.nonnegative(0)) is True + assert satask(Q.nonnegative(I)) is False + assert satask(Q.nonnegative(pi)) is True + + +def test_rational_irrational(): + assert satask(Q.irrational(2)) is False + assert satask(Q.rational(2)) is True + assert satask(Q.irrational(pi)) is True + assert satask(Q.rational(pi)) is False + assert satask(Q.irrational(I)) is False + assert satask(Q.rational(I)) is False + + assert satask(Q.irrational(x*y*z), Q.irrational(x) & Q.irrational(y) & + Q.rational(z)) is None + assert satask(Q.irrational(x*y*z), Q.irrational(x) & Q.rational(y) & + Q.rational(z)) is True + assert satask(Q.irrational(pi*x*y), Q.rational(x) & Q.rational(y)) is True + + assert satask(Q.irrational(x + y + z), Q.irrational(x) & Q.irrational(y) & + Q.rational(z)) is None + assert satask(Q.irrational(x + y + z), Q.irrational(x) & Q.rational(y) & + Q.rational(z)) is True + assert satask(Q.irrational(pi + x + y), Q.rational(x) & Q.rational(y)) is True + + assert satask(Q.irrational(x*y*z), Q.rational(x) & Q.rational(y) & + Q.rational(z)) is False + assert satask(Q.rational(x*y*z), Q.rational(x) & Q.rational(y) & + Q.rational(z)) is True + + assert satask(Q.irrational(x + y + z), Q.rational(x) & Q.rational(y) & + Q.rational(z)) is False + assert satask(Q.rational(x + y + z), Q.rational(x) & Q.rational(y) & + Q.rational(z)) is True + + +def test_even_satask(): + assert satask(Q.even(2)) is True + assert satask(Q.even(3)) is False + + assert satask(Q.even(x*y), Q.even(x) & Q.odd(y)) is True + assert satask(Q.even(x*y), Q.even(x) & Q.integer(y)) is True + assert satask(Q.even(x*y), Q.even(x) & Q.even(y)) is True + assert satask(Q.even(x*y), Q.odd(x) & Q.odd(y)) is False + assert satask(Q.even(x*y), Q.even(x)) is None + assert satask(Q.even(x*y), Q.odd(x) & Q.integer(y)) is None + assert satask(Q.even(x*y), Q.odd(x) & Q.odd(y)) is False + + assert satask(Q.even(abs(x)), Q.even(x)) is True + assert satask(Q.even(abs(x)), Q.odd(x)) is False + assert satask(Q.even(x), Q.even(abs(x))) is None # x could be complex + + +def test_odd_satask(): + assert satask(Q.odd(2)) is False + assert satask(Q.odd(3)) is True + + assert satask(Q.odd(x*y), Q.even(x) & Q.odd(y)) is False + assert satask(Q.odd(x*y), Q.even(x) & Q.integer(y)) is False + assert satask(Q.odd(x*y), Q.even(x) & Q.even(y)) is False + assert satask(Q.odd(x*y), Q.odd(x) & Q.odd(y)) is True + assert satask(Q.odd(x*y), Q.even(x)) is None + assert satask(Q.odd(x*y), Q.odd(x) & Q.integer(y)) is None + assert satask(Q.odd(x*y), Q.odd(x) & Q.odd(y)) is True + + assert satask(Q.odd(abs(x)), Q.even(x)) is False + assert satask(Q.odd(abs(x)), Q.odd(x)) is True + assert satask(Q.odd(x), Q.odd(abs(x))) is None # x could be complex + + +def test_integer(): + assert satask(Q.integer(1)) is True + assert satask(Q.integer(S.Half)) is False + + assert satask(Q.integer(x + y), Q.integer(x) & Q.integer(y)) is True + assert satask(Q.integer(x + y), Q.integer(x)) is None + + assert satask(Q.integer(x + y), Q.integer(x) & ~Q.integer(y)) is False + assert satask(Q.integer(x + y + z), Q.integer(x) & Q.integer(y) & + ~Q.integer(z)) is False + assert satask(Q.integer(x + y + z), Q.integer(x) & ~Q.integer(y) & + ~Q.integer(z)) is None + assert satask(Q.integer(x + y + z), Q.integer(x) & ~Q.integer(y)) is None + assert satask(Q.integer(x + y), Q.integer(x) & Q.irrational(y)) is False + + assert satask(Q.integer(x*y), Q.integer(x) & Q.integer(y)) is True + assert satask(Q.integer(x*y), Q.integer(x)) is None + + assert satask(Q.integer(x*y), Q.integer(x) & ~Q.integer(y)) is None + assert satask(Q.integer(x*y), Q.integer(x) & ~Q.rational(y)) is False + assert satask(Q.integer(x*y*z), Q.integer(x) & Q.integer(y) & + ~Q.rational(z)) is False + assert satask(Q.integer(x*y*z), Q.integer(x) & ~Q.rational(y) & + ~Q.rational(z)) is None + assert satask(Q.integer(x*y*z), Q.integer(x) & ~Q.rational(y)) is None + assert satask(Q.integer(x*y), Q.integer(x) & Q.irrational(y)) is False + + +def test_abs(): + assert satask(Q.nonnegative(abs(x))) is True + assert satask(Q.positive(abs(x)), ~Q.zero(x)) is True + assert satask(Q.zero(x), ~Q.zero(abs(x))) is False + assert satask(Q.zero(x), Q.zero(abs(x))) is True + assert satask(Q.nonzero(x), ~Q.zero(abs(x))) is None # x could be complex + assert satask(Q.zero(abs(x)), Q.zero(x)) is True + + +def test_imaginary(): + assert satask(Q.imaginary(2*I)) is True + assert satask(Q.imaginary(x*y), Q.imaginary(x)) is None + assert satask(Q.imaginary(x*y), Q.imaginary(x) & Q.real(y)) is True + assert satask(Q.imaginary(x), Q.real(x)) is False + assert satask(Q.imaginary(1)) is False + assert satask(Q.imaginary(x*y), Q.real(x) & Q.real(y)) is False + assert satask(Q.imaginary(x + y), Q.real(x) & Q.real(y)) is False + + +def test_real(): + assert satask(Q.real(x*y), Q.real(x) & Q.real(y)) is True + assert satask(Q.real(x + y), Q.real(x) & Q.real(y)) is True + assert satask(Q.real(x*y*z), Q.real(x) & Q.real(y) & Q.real(z)) is True + assert satask(Q.real(x*y*z), Q.real(x) & Q.real(y)) is None + assert satask(Q.real(x*y*z), Q.real(x) & Q.real(y) & Q.imaginary(z)) is False + assert satask(Q.real(x + y + z), Q.real(x) & Q.real(y) & Q.real(z)) is True + assert satask(Q.real(x + y + z), Q.real(x) & Q.real(y)) is None + + +def test_pos_neg(): + assert satask(~Q.positive(x), Q.negative(x)) is True + assert satask(~Q.negative(x), Q.positive(x)) is True + assert satask(Q.positive(x + y), Q.positive(x) & Q.positive(y)) is True + assert satask(Q.negative(x + y), Q.negative(x) & Q.negative(y)) is True + assert satask(Q.positive(x + y), Q.negative(x) & Q.negative(y)) is False + assert satask(Q.negative(x + y), Q.positive(x) & Q.positive(y)) is False + + +def test_pow_pos_neg(): + assert satask(Q.nonnegative(x**2), Q.positive(x)) is True + assert satask(Q.nonpositive(x**2), Q.positive(x)) is False + assert satask(Q.positive(x**2), Q.positive(x)) is True + assert satask(Q.negative(x**2), Q.positive(x)) is False + assert satask(Q.real(x**2), Q.positive(x)) is True + + assert satask(Q.nonnegative(x**2), Q.negative(x)) is True + assert satask(Q.nonpositive(x**2), Q.negative(x)) is False + assert satask(Q.positive(x**2), Q.negative(x)) is True + assert satask(Q.negative(x**2), Q.negative(x)) is False + assert satask(Q.real(x**2), Q.negative(x)) is True + + assert satask(Q.nonnegative(x**2), Q.nonnegative(x)) is True + assert satask(Q.nonpositive(x**2), Q.nonnegative(x)) is None + assert satask(Q.positive(x**2), Q.nonnegative(x)) is None + assert satask(Q.negative(x**2), Q.nonnegative(x)) is False + assert satask(Q.real(x**2), Q.nonnegative(x)) is True + + assert satask(Q.nonnegative(x**2), Q.nonpositive(x)) is True + assert satask(Q.nonpositive(x**2), Q.nonpositive(x)) is None + assert satask(Q.positive(x**2), Q.nonpositive(x)) is None + assert satask(Q.negative(x**2), Q.nonpositive(x)) is False + assert satask(Q.real(x**2), Q.nonpositive(x)) is True + + assert satask(Q.nonnegative(x**3), Q.positive(x)) is True + assert satask(Q.nonpositive(x**3), Q.positive(x)) is False + assert satask(Q.positive(x**3), Q.positive(x)) is True + assert satask(Q.negative(x**3), Q.positive(x)) is False + assert satask(Q.real(x**3), Q.positive(x)) is True + + assert satask(Q.nonnegative(x**3), Q.negative(x)) is False + assert satask(Q.nonpositive(x**3), Q.negative(x)) is True + assert satask(Q.positive(x**3), Q.negative(x)) is False + assert satask(Q.negative(x**3), Q.negative(x)) is True + assert satask(Q.real(x**3), Q.negative(x)) is True + + assert satask(Q.nonnegative(x**3), Q.nonnegative(x)) is True + assert satask(Q.nonpositive(x**3), Q.nonnegative(x)) is None + assert satask(Q.positive(x**3), Q.nonnegative(x)) is None + assert satask(Q.negative(x**3), Q.nonnegative(x)) is False + assert satask(Q.real(x**3), Q.nonnegative(x)) is True + + assert satask(Q.nonnegative(x**3), Q.nonpositive(x)) is None + assert satask(Q.nonpositive(x**3), Q.nonpositive(x)) is True + assert satask(Q.positive(x**3), Q.nonpositive(x)) is False + assert satask(Q.negative(x**3), Q.nonpositive(x)) is None + assert satask(Q.real(x**3), Q.nonpositive(x)) is True + + # If x is zero, x**negative is not real. + assert satask(Q.nonnegative(x**-2), Q.nonpositive(x)) is None + assert satask(Q.nonpositive(x**-2), Q.nonpositive(x)) is None + assert satask(Q.positive(x**-2), Q.nonpositive(x)) is None + assert satask(Q.negative(x**-2), Q.nonpositive(x)) is None + assert satask(Q.real(x**-2), Q.nonpositive(x)) is None + + # We could deduce things for negative powers if x is nonzero, but it + # isn't implemented yet. + + +def test_prime_composite(): + assert satask(Q.prime(x), Q.composite(x)) is False + assert satask(Q.composite(x), Q.prime(x)) is False + assert satask(Q.composite(x), ~Q.prime(x)) is None + assert satask(Q.prime(x), ~Q.composite(x)) is None + # since 1 is neither prime nor composite the following should hold + assert satask(Q.prime(x), Q.integer(x) & Q.positive(x) & ~Q.composite(x)) is None + assert satask(Q.prime(2)) is True + assert satask(Q.prime(4)) is False + assert satask(Q.prime(1)) is False + assert satask(Q.composite(1)) is False + + +def test_extract_predargs(): + props = CNF.from_prop(Q.zero(Abs(x*y)) & Q.zero(x*y)) + assump = CNF.from_prop(Q.zero(x)) + context = CNF.from_prop(Q.zero(y)) + assert extract_predargs(props) == {Abs(x*y), x*y} + assert extract_predargs(props, assump) == {Abs(x*y), x*y, x} + assert extract_predargs(props, assump, context) == {Abs(x*y), x*y, x, y} + + props = CNF.from_prop(Eq(x, y)) + assump = CNF.from_prop(Gt(y, z)) + assert extract_predargs(props, assump) == {x, y, z} + + +def test_get_relevant_clsfacts(): + exprs = {Abs(x*y)} + exprs, facts = get_relevant_clsfacts(exprs) + assert exprs == {x*y} + assert facts.clauses == \ + {frozenset({Literal(Q.odd(Abs(x*y)), False), Literal(Q.odd(x*y), True)}), + frozenset({Literal(Q.zero(Abs(x*y)), False), Literal(Q.zero(x*y), True)}), + frozenset({Literal(Q.even(Abs(x*y)), False), Literal(Q.even(x*y), True)}), + frozenset({Literal(Q.zero(Abs(x*y)), True), Literal(Q.zero(x*y), False)}), + frozenset({Literal(Q.even(Abs(x*y)), False), + Literal(Q.odd(Abs(x*y)), False), + Literal(Q.odd(x*y), True)}), + frozenset({Literal(Q.even(Abs(x*y)), False), + Literal(Q.even(x*y), True), + Literal(Q.odd(Abs(x*y)), False)}), + frozenset({Literal(Q.positive(Abs(x*y)), False), + Literal(Q.zero(Abs(x*y)), False)})} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_sathandlers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_sathandlers.py new file mode 100644 index 0000000000000000000000000000000000000000..9d568ad8efe6ba7cf7f5eb03879ad6764c16e729 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_sathandlers.py @@ -0,0 +1,50 @@ +from sympy.assumptions.ask import Q +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.symbol import symbols +from sympy.logic.boolalg import (And, Or) + +from sympy.assumptions.sathandlers import (ClassFactRegistry, allargs, + anyarg, exactlyonearg,) + +x, y, z = symbols('x y z') + + +def test_class_handler_registry(): + my_handler_registry = ClassFactRegistry() + + # The predicate doesn't matter here, so just pass + @my_handler_registry.register(Mul) + def fact1(expr): + pass + @my_handler_registry.multiregister(Expr) + def fact2(expr): + pass + + assert my_handler_registry[Basic] == (frozenset(), frozenset()) + assert my_handler_registry[Expr] == (frozenset(), frozenset({fact2})) + assert my_handler_registry[Mul] == (frozenset({fact1}), frozenset({fact2})) + + +def test_allargs(): + assert allargs(x, Q.zero(x), x*y) == And(Q.zero(x), Q.zero(y)) + assert allargs(x, Q.positive(x) | Q.negative(x), x*y) == And(Q.positive(x) | Q.negative(x), Q.positive(y) | Q.negative(y)) + + +def test_anyarg(): + assert anyarg(x, Q.zero(x), x*y) == Or(Q.zero(x), Q.zero(y)) + assert anyarg(x, Q.positive(x) & Q.negative(x), x*y) == \ + Or(Q.positive(x) & Q.negative(x), Q.positive(y) & Q.negative(y)) + + +def test_exactlyonearg(): + assert exactlyonearg(x, Q.zero(x), x*y) == \ + Or(Q.zero(x) & ~Q.zero(y), Q.zero(y) & ~Q.zero(x)) + assert exactlyonearg(x, Q.zero(x), x*y*z) == \ + Or(Q.zero(x) & ~Q.zero(y) & ~Q.zero(z), Q.zero(y) + & ~Q.zero(x) & ~Q.zero(z), Q.zero(z) & ~Q.zero(x) & ~Q.zero(y)) + assert exactlyonearg(x, Q.positive(x) | Q.negative(x), x*y) == \ + Or((Q.positive(x) | Q.negative(x)) & + ~(Q.positive(y) | Q.negative(y)), (Q.positive(y) | Q.negative(y)) & + ~(Q.positive(x) | Q.negative(x))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..af9afd5d51fb1341e0b08149dc842b78a39c329b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/tests/test_wrapper.py @@ -0,0 +1,39 @@ +from sympy.assumptions.ask import Q +from sympy.assumptions.wrapper import (AssumptionsWrapper, is_infinite, + is_extended_real) +from sympy.core.symbol import Symbol +from sympy.core.assumptions import _assume_defined + + +def test_all_predicates(): + for fact in _assume_defined: + method_name = f'_eval_is_{fact}' + assert hasattr(AssumptionsWrapper, method_name) + + +def test_AssumptionsWrapper(): + x = Symbol('x', positive=True) + y = Symbol('y') + assert AssumptionsWrapper(x).is_positive + assert AssumptionsWrapper(y).is_positive is None + assert AssumptionsWrapper(y, Q.positive(y)).is_positive + + +def test_is_infinite(): + x = Symbol('x', infinite=True) + y = Symbol('y', infinite=False) + z = Symbol('z') + assert is_infinite(x) + assert not is_infinite(y) + assert is_infinite(z) is None + assert is_infinite(z, Q.infinite(z)) + + +def test_is_extended_real(): + x = Symbol('x', extended_real=True) + y = Symbol('y', extended_real=False) + z = Symbol('z') + assert is_extended_real(x) + assert not is_extended_real(y) + assert is_extended_real(z) is None + assert is_extended_real(z, Q.extended_real(z)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..cb06e9de770ed41a2b3d6fe63381ad1cb59acacc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/assumptions/wrapper.py @@ -0,0 +1,164 @@ +""" +Functions and wrapper object to call assumption property and predicate +query with same syntax. + +In SymPy, there are two assumption systems. Old assumption system is +defined in sympy/core/assumptions, and it can be accessed by attribute +such as ``x.is_even``. New assumption system is defined in +sympy/assumptions, and it can be accessed by predicates such as +``Q.even(x)``. + +Old assumption is fast, while new assumptions can freely take local facts. +In general, old assumption is used in evaluation method and new assumption +is used in refinement method. + +In most cases, both evaluation and refinement follow the same process, and +the only difference is which assumption system is used. This module provides +``is_[...]()`` functions and ``AssumptionsWrapper()`` class which allows +using two systems with same syntax so that parallel code implementation can be +avoided. + +Examples +======== + +For multiple use, use ``AssumptionsWrapper()``. + +>>> from sympy import Q, Symbol +>>> from sympy.assumptions.wrapper import AssumptionsWrapper +>>> x = Symbol('x') +>>> _x = AssumptionsWrapper(x, Q.even(x)) +>>> _x.is_integer +True +>>> _x.is_odd +False + +For single use, use ``is_[...]()`` functions. + +>>> from sympy.assumptions.wrapper import is_infinite +>>> a = Symbol('a') +>>> print(is_infinite(a)) +None +>>> is_infinite(a, Q.finite(a)) +False + +""" + +from sympy.assumptions import ask, Q +from sympy.core.basic import Basic +from sympy.core.sympify import _sympify + + +def make_eval_method(fact): + def getit(self): + pred = getattr(Q, fact) + ret = ask(pred(self.expr), self.assumptions) + return ret + return getit + + +# we subclass Basic to use the fact deduction and caching +class AssumptionsWrapper(Basic): + """ + Wrapper over ``Basic`` instances to call predicate query by + ``.is_[...]`` property + + Parameters + ========== + + expr : Basic + + assumptions : Boolean, optional + + Examples + ======== + + >>> from sympy import Q, Symbol + >>> from sympy.assumptions.wrapper import AssumptionsWrapper + >>> x = Symbol('x', even=True) + >>> AssumptionsWrapper(x).is_integer + True + >>> y = Symbol('y') + >>> AssumptionsWrapper(y, Q.even(y)).is_integer + True + + With ``AssumptionsWrapper``, both evaluation and refinement can be supported + by single implementation. + + >>> from sympy import Function + >>> class MyAbs(Function): + ... @classmethod + ... def eval(cls, x, assumptions=True): + ... _x = AssumptionsWrapper(x, assumptions) + ... if _x.is_nonnegative: + ... return x + ... if _x.is_negative: + ... return -x + ... def _eval_refine(self, assumptions): + ... return MyAbs.eval(self.args[0], assumptions) + >>> MyAbs(x) + MyAbs(x) + >>> MyAbs(x).refine(Q.positive(x)) + x + >>> MyAbs(Symbol('y', negative=True)) + -y + + """ + def __new__(cls, expr, assumptions=None): + if assumptions is None: + return expr + obj = super().__new__(cls, expr, _sympify(assumptions)) + obj.expr = expr + obj.assumptions = assumptions + return obj + + _eval_is_algebraic = make_eval_method("algebraic") + _eval_is_antihermitian = make_eval_method("antihermitian") + _eval_is_commutative = make_eval_method("commutative") + _eval_is_complex = make_eval_method("complex") + _eval_is_composite = make_eval_method("composite") + _eval_is_even = make_eval_method("even") + _eval_is_extended_negative = make_eval_method("extended_negative") + _eval_is_extended_nonnegative = make_eval_method("extended_nonnegative") + _eval_is_extended_nonpositive = make_eval_method("extended_nonpositive") + _eval_is_extended_nonzero = make_eval_method("extended_nonzero") + _eval_is_extended_positive = make_eval_method("extended_positive") + _eval_is_extended_real = make_eval_method("extended_real") + _eval_is_finite = make_eval_method("finite") + _eval_is_hermitian = make_eval_method("hermitian") + _eval_is_imaginary = make_eval_method("imaginary") + _eval_is_infinite = make_eval_method("infinite") + _eval_is_integer = make_eval_method("integer") + _eval_is_irrational = make_eval_method("irrational") + _eval_is_negative = make_eval_method("negative") + _eval_is_noninteger = make_eval_method("noninteger") + _eval_is_nonnegative = make_eval_method("nonnegative") + _eval_is_nonpositive = make_eval_method("nonpositive") + _eval_is_nonzero = make_eval_method("nonzero") + _eval_is_odd = make_eval_method("odd") + _eval_is_polar = make_eval_method("polar") + _eval_is_positive = make_eval_method("positive") + _eval_is_prime = make_eval_method("prime") + _eval_is_rational = make_eval_method("rational") + _eval_is_real = make_eval_method("real") + _eval_is_transcendental = make_eval_method("transcendental") + _eval_is_zero = make_eval_method("zero") + + +# one shot functions which are faster than AssumptionsWrapper + +def is_infinite(obj, assumptions=None): + if assumptions is None: + return obj.is_infinite + return ask(Q.infinite(obj), assumptions) + + +def is_extended_real(obj, assumptions=None): + if assumptions is None: + return obj.is_extended_real + return ask(Q.extended_real(obj), assumptions) + + +def is_extended_nonnegative(obj, assumptions=None): + if assumptions is None: + return obj.is_extended_nonnegative + return ask(Q.extended_nonnegative(obj), assumptions) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62e880e82c4788c9d175304a0ab2ece06d9098d0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_discrete_log.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_discrete_log.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ea1ff35d5d6ea54ba41606f5191a4fcebd538d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_discrete_log.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_meijerint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_meijerint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3c22bbe8b2a7dbb9856b24c5c00cc021b686d60 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_meijerint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_symbench.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_symbench.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7da60aa5aa4b94539344b6528bc33a00841e5161 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/__pycache__/bench_symbench.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_discrete_log.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_discrete_log.py new file mode 100644 index 0000000000000000000000000000000000000000..76b273909e415318a7d3bace00ffff2a0bc53762 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_discrete_log.py @@ -0,0 +1,83 @@ +import sys +from time import time +from sympy.ntheory.residue_ntheory import (discrete_log, + _discrete_log_trial_mul, _discrete_log_shanks_steps, + _discrete_log_pollard_rho, _discrete_log_pohlig_hellman) + + +# Cyclic group (Z/pZ)* with p prime, order p - 1 and generator g +data_set_1 = [ + # p, p - 1, g + [191, 190, 19], + [46639, 46638, 6], + [14789363, 14789362, 2], + [4254225211, 4254225210, 2], + [432751500361, 432751500360, 7], + [158505390797053, 158505390797052, 2], + [6575202655312007, 6575202655312006, 5], + [8430573471995353769, 8430573471995353768, 3], + [3938471339744997827267, 3938471339744997827266, 2], + [875260951364705563393093, 875260951364705563393092, 5], + ] + + +# Cyclic sub-groups of (Z/nZ)* with prime order p and generator g +# (n, p are primes and n = 2 * p + 1) +data_set_2 = [ + # n, p, g + [227, 113, 3], + [2447, 1223, 2], + [24527, 12263, 2], + [245639, 122819, 2], + [2456747, 1228373, 3], + [24567899, 12283949, 3], + [245679023, 122839511, 2], + [2456791307, 1228395653, 3], + [24567913439, 12283956719, 2], + [245679135407, 122839567703, 2], + [2456791354763, 1228395677381, 3], + [24567913550903, 12283956775451, 2], + [245679135509519, 122839567754759, 2], + ] + + +# Cyclic sub-groups of (Z/nZ)* with smooth order o and generator g +data_set_3 = [ + # n, o, g + [2**118, 2**116, 3], + ] + + +def bench_discrete_log(data_set, algo=None): + if algo is None: + f = discrete_log + elif algo == 'trial': + f = _discrete_log_trial_mul + elif algo == 'shanks': + f = _discrete_log_shanks_steps + elif algo == 'rho': + f = _discrete_log_pollard_rho + elif algo == 'ph': + f = _discrete_log_pohlig_hellman + else: + raise ValueError("Argument 'algo' should be one" + " of ('trial', 'shanks', 'rho' or 'ph')") + + for i, data in enumerate(data_set): + for j, (n, p, g) in enumerate(data): + t = time() + l = f(n, pow(g, p - 1, n), g, p) + t = time() - t + print('[%02d-%03d] %15.10f' % (i, j, t)) + assert l == p - 1 + + +if __name__ == '__main__': + algo = sys.argv[1] \ + if len(sys.argv) > 1 else None + data_set = [ + data_set_1, + data_set_2, + data_set_3, + ] + bench_discrete_log(data_set, algo) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_meijerint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_meijerint.py new file mode 100644 index 0000000000000000000000000000000000000000..d648c3e02463d5a7ee1dcbe3b22af5cc22fef43d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_meijerint.py @@ -0,0 +1,261 @@ +# conceal the implicit import from the code quality tester +from sympy.core.numbers import (oo, pi) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.bessel import besseli +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import integrate +from sympy.integrals.transforms import (mellin_transform, + inverse_fourier_transform, inverse_mellin_transform, + laplace_transform, inverse_laplace_transform, fourier_transform) + +LT = laplace_transform +FT = fourier_transform +MT = mellin_transform +IFT = inverse_fourier_transform +ILT = inverse_laplace_transform +IMT = inverse_mellin_transform + +from sympy.abc import x, y +nu, beta, rho = symbols('nu beta rho') + +apos, bpos, cpos, dpos, posk, p = symbols('a b c d k p', positive=True) +k = Symbol('k', real=True) +negk = Symbol('k', negative=True) + +mu1, mu2 = symbols('mu1 mu2', real=True, nonzero=True, finite=True) +sigma1, sigma2 = symbols('sigma1 sigma2', real=True, nonzero=True, + finite=True, positive=True) +rate = Symbol('lambda', positive=True) + + +def normal(x, mu, sigma): + return 1/sqrt(2*pi*sigma**2)*exp(-(x - mu)**2/2/sigma**2) + + +def exponential(x, rate): + return rate*exp(-rate*x) +alpha, beta = symbols('alpha beta', positive=True) +betadist = x**(alpha - 1)*(1 + x)**(-alpha - beta)*gamma(alpha + beta) \ + /gamma(alpha)/gamma(beta) +kint = Symbol('k', integer=True, positive=True) +chi = 2**(1 - kint/2)*x**(kint - 1)*exp(-x**2/2)/gamma(kint/2) +chisquared = 2**(-k/2)/gamma(k/2)*x**(k/2 - 1)*exp(-x/2) +dagum = apos*p/x*(x/bpos)**(apos*p)/(1 + x**apos/bpos**apos)**(p + 1) +d1, d2 = symbols('d1 d2', positive=True) +f = sqrt(((d1*x)**d1 * d2**d2)/(d1*x + d2)**(d1 + d2))/x \ + /gamma(d1/2)/gamma(d2/2)*gamma((d1 + d2)/2) +nupos, sigmapos = symbols('nu sigma', positive=True) +rice = x/sigmapos**2*exp(-(x**2 + nupos**2)/2/sigmapos**2)*besseli(0, x* + nupos/sigmapos**2) +mu = Symbol('mu', real=True) +laplace = exp(-abs(x - mu)/bpos)/2/bpos + +u = Symbol('u', polar=True) +tpos = Symbol('t', positive=True) + + +def E(expr): + integrate(expr*exponential(x, rate)*normal(y, mu1, sigma1), + (x, 0, oo), (y, -oo, oo), meijerg=True) + integrate(expr*exponential(x, rate)*normal(y, mu1, sigma1), + (y, -oo, oo), (x, 0, oo), meijerg=True) + +bench = [ + 'MT(x**nu*Heaviside(x - 1), x, s)', + 'MT(x**nu*Heaviside(1 - x), x, s)', + 'MT((1-x)**(beta - 1)*Heaviside(1-x), x, s)', + 'MT((x-1)**(beta - 1)*Heaviside(x-1), x, s)', + 'MT((1+x)**(-rho), x, s)', + 'MT(abs(1-x)**(-rho), x, s)', + 'MT((1-x)**(beta-1)*Heaviside(1-x) + a*(x-1)**(beta-1)*Heaviside(x-1), x, s)', + 'MT((x**a-b**a)/(x-b), x, s)', + 'MT((x**a-bpos**a)/(x-bpos), x, s)', + 'MT(exp(-x), x, s)', + 'MT(exp(-1/x), x, s)', + 'MT(log(x)**4*Heaviside(1-x), x, s)', + 'MT(log(x)**3*Heaviside(x-1), x, s)', + 'MT(log(x + 1), x, s)', + 'MT(log(1/x + 1), x, s)', + 'MT(log(abs(1 - x)), x, s)', + 'MT(log(abs(1 - 1/x)), x, s)', + 'MT(log(x)/(x+1), x, s)', + 'MT(log(x)**2/(x+1), x, s)', + 'MT(log(x)/(x+1)**2, x, s)', + 'MT(erf(sqrt(x)), x, s)', + + 'MT(besselj(a, 2*sqrt(x)), x, s)', + 'MT(sin(sqrt(x))*besselj(a, sqrt(x)), x, s)', + 'MT(cos(sqrt(x))*besselj(a, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))**2, x, s)', + 'MT(besselj(a, sqrt(x))*besselj(-a, sqrt(x)), x, s)', + 'MT(besselj(a - 1, sqrt(x))*besselj(a, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))*besselj(b, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))**2 + besselj(-a, sqrt(x))**2, x, s)', + 'MT(bessely(a, 2*sqrt(x)), x, s)', + 'MT(sin(sqrt(x))*bessely(a, sqrt(x)), x, s)', + 'MT(cos(sqrt(x))*bessely(a, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))*bessely(a, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))*bessely(b, sqrt(x)), x, s)', + 'MT(bessely(a, sqrt(x))**2, x, s)', + + 'MT(besselk(a, 2*sqrt(x)), x, s)', + 'MT(besselj(a, 2*sqrt(2*sqrt(x)))*besselk(a, 2*sqrt(2*sqrt(x))), x, s)', + 'MT(besseli(a, sqrt(x))*besselk(a, sqrt(x)), x, s)', + 'MT(besseli(b, sqrt(x))*besselk(a, sqrt(x)), x, s)', + 'MT(exp(-x/2)*besselk(a, x/2), x, s)', + + # later: ILT, IMT + + 'LT((t-apos)**bpos*exp(-cpos*(t-apos))*Heaviside(t-apos), t, s)', + 'LT(t**apos, t, s)', + 'LT(Heaviside(t), t, s)', + 'LT(Heaviside(t - apos), t, s)', + 'LT(1 - exp(-apos*t), t, s)', + 'LT((exp(2*t)-1)*exp(-bpos - t)*Heaviside(t)/2, t, s, noconds=True)', + 'LT(exp(t), t, s)', + 'LT(exp(2*t), t, s)', + 'LT(exp(apos*t), t, s)', + 'LT(log(t/apos), t, s)', + 'LT(erf(t), t, s)', + 'LT(sin(apos*t), t, s)', + 'LT(cos(apos*t), t, s)', + 'LT(exp(-apos*t)*sin(bpos*t), t, s)', + 'LT(exp(-apos*t)*cos(bpos*t), t, s)', + 'LT(besselj(0, t), t, s, noconds=True)', + 'LT(besselj(1, t), t, s, noconds=True)', + + 'FT(Heaviside(1 - abs(2*apos*x)), x, k)', + 'FT(Heaviside(1-abs(apos*x))*(1-abs(apos*x)), x, k)', + 'FT(exp(-apos*x)*Heaviside(x), x, k)', + 'IFT(1/(apos + 2*pi*I*x), x, posk, noconds=False)', + 'IFT(1/(apos + 2*pi*I*x), x, -posk, noconds=False)', + 'IFT(1/(apos + 2*pi*I*x), x, negk)', + 'FT(x*exp(-apos*x)*Heaviside(x), x, k)', + 'FT(exp(-apos*x)*sin(bpos*x)*Heaviside(x), x, k)', + 'FT(exp(-apos*x**2), x, k)', + 'IFT(sqrt(pi/apos)*exp(-(pi*k)**2/apos), k, x)', + 'FT(exp(-apos*abs(x)), x, k)', + + 'integrate(normal(x, mu1, sigma1), (x, -oo, oo), meijerg=True)', + 'integrate(x*normal(x, mu1, sigma1), (x, -oo, oo), meijerg=True)', + 'integrate(x**2*normal(x, mu1, sigma1), (x, -oo, oo), meijerg=True)', + 'integrate(x**3*normal(x, mu1, sigma1), (x, -oo, oo), meijerg=True)', + 'integrate(normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(x*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(y*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(x*y*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate((x+y+1)*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate((x+y-1)*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(x**2*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(y**2*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(exponential(x, rate), (x, 0, oo), meijerg=True)', + 'integrate(x*exponential(x, rate), (x, 0, oo), meijerg=True)', + 'integrate(x**2*exponential(x, rate), (x, 0, oo), meijerg=True)', + 'E(1)', + 'E(x*y)', + 'E(x*y**2)', + 'E((x+y+1)**2)', + 'E(x+y+1)', + 'E((x+y-1)**2)', + 'integrate(betadist, (x, 0, oo), meijerg=True)', + 'integrate(x*betadist, (x, 0, oo), meijerg=True)', + 'integrate(x**2*betadist, (x, 0, oo), meijerg=True)', + 'integrate(chi, (x, 0, oo), meijerg=True)', + 'integrate(x*chi, (x, 0, oo), meijerg=True)', + 'integrate(x**2*chi, (x, 0, oo), meijerg=True)', + 'integrate(chisquared, (x, 0, oo), meijerg=True)', + 'integrate(x*chisquared, (x, 0, oo), meijerg=True)', + 'integrate(x**2*chisquared, (x, 0, oo), meijerg=True)', + 'integrate(((x-k)/sqrt(2*k))**3*chisquared, (x, 0, oo), meijerg=True)', + 'integrate(dagum, (x, 0, oo), meijerg=True)', + 'integrate(x*dagum, (x, 0, oo), meijerg=True)', + 'integrate(x**2*dagum, (x, 0, oo), meijerg=True)', + 'integrate(f, (x, 0, oo), meijerg=True)', + 'integrate(x*f, (x, 0, oo), meijerg=True)', + 'integrate(x**2*f, (x, 0, oo), meijerg=True)', + 'integrate(rice, (x, 0, oo), meijerg=True)', + 'integrate(laplace, (x, -oo, oo), meijerg=True)', + 'integrate(x*laplace, (x, -oo, oo), meijerg=True)', + 'integrate(x**2*laplace, (x, -oo, oo), meijerg=True)', + 'integrate(log(x) * x**(k-1) * exp(-x) / gamma(k), (x, 0, oo))', + + 'integrate(sin(z*x)*(x**2-1)**(-(y+S(1)/2)), (x, 1, oo), meijerg=True)', + 'integrate(besselj(0,x)*besselj(1,x)*exp(-x**2), (x, 0, oo), meijerg=True)', + 'integrate(besselj(0,x)*besselj(1,x)*besselk(0,x), (x, 0, oo), meijerg=True)', + 'integrate(besselj(0,x)*besselj(1,x)*exp(-x**2), (x, 0, oo), meijerg=True)', + 'integrate(besselj(a,x)*besselj(b,x)/x, (x,0,oo), meijerg=True)', + + 'hyperexpand(meijerg((-s - a/2 + 1, -s + a/2 + 1), (-a/2 - S(1)/2, -s + a/2 + S(3)/2), (a/2, -a/2), (-a/2 - S(1)/2, -s + a/2 + S(3)/2), 1))', + "gammasimp(S('2**(2*s)*(-pi*gamma(-a + 1)*gamma(a + 1)*gamma(-a - s + 1)*gamma(-a + s - 1/2)*gamma(a - s + 3/2)*gamma(a + s + 1)/(a*(a + s)) - gamma(-a - 1/2)*gamma(-a + 1)*gamma(a + 1)*gamma(a + 3/2)*gamma(-s + 3/2)*gamma(s - 1/2)*gamma(-a + s + 1)*gamma(a - s + 1)/(a*(-a + s)))*gamma(-2*s + 1)*gamma(s + 1)/(pi*s*gamma(-a - 1/2)*gamma(a + 3/2)*gamma(-s + 1)*gamma(-s + 3/2)*gamma(s - 1/2)*gamma(-a - s + 1)*gamma(-a + s - 1/2)*gamma(a - s + 1)*gamma(a - s + 3/2))'))", + + 'mellin_transform(E1(x), x, s)', + 'inverse_mellin_transform(gamma(s)/s, s, x, (0, oo))', + 'mellin_transform(expint(a, x), x, s)', + 'mellin_transform(Si(x), x, s)', + 'inverse_mellin_transform(-2**s*sqrt(pi)*gamma((s + 1)/2)/(2*s*gamma(-s/2 + 1)), s, x, (-1, 0))', + 'mellin_transform(Ci(sqrt(x)), x, s)', + 'inverse_mellin_transform(-4**s*sqrt(pi)*gamma(s)/(2*s*gamma(-s + S(1)/2)),s, u, (0, 1))', + 'laplace_transform(Ci(x), x, s)', + 'laplace_transform(expint(a, x), x, s)', + 'laplace_transform(expint(1, x), x, s)', + 'laplace_transform(expint(2, x), x, s)', + 'inverse_laplace_transform(-log(1 + s**2)/2/s, s, u)', + 'inverse_laplace_transform(log(s + 1)/s, s, x)', + 'inverse_laplace_transform((s - log(s + 1))/s**2, s, x)', + 'laplace_transform(Chi(x), x, s)', + 'laplace_transform(Shi(x), x, s)', + + 'integrate(exp(-z*x)/x, (x, 1, oo), meijerg=True, conds="none")', + 'integrate(exp(-z*x)/x**2, (x, 1, oo), meijerg=True, conds="none")', + 'integrate(exp(-z*x)/x**3, (x, 1, oo), meijerg=True,conds="none")', + 'integrate(-cos(x)/x, (x, tpos, oo), meijerg=True)', + 'integrate(-sin(x)/x, (x, tpos, oo), meijerg=True)', + 'integrate(sin(x)/x, (x, 0, z), meijerg=True)', + 'integrate(sinh(x)/x, (x, 0, z), meijerg=True)', + 'integrate(exp(-x)/x, x, meijerg=True)', + 'integrate(exp(-x)/x**2, x, meijerg=True)', + 'integrate(cos(u)/u, u, meijerg=True)', + 'integrate(cosh(u)/u, u, meijerg=True)', + 'integrate(expint(1, x), x, meijerg=True)', + 'integrate(expint(2, x), x, meijerg=True)', + 'integrate(Si(x), x, meijerg=True)', + 'integrate(Ci(u), u, meijerg=True)', + 'integrate(Shi(x), x, meijerg=True)', + 'integrate(Chi(u), u, meijerg=True)', + 'integrate(Si(x)*exp(-x), (x, 0, oo), meijerg=True)', + 'integrate(expint(1, x)*sin(x), (x, 0, oo), meijerg=True)' +] + +from time import time +from sympy.core.cache import clear_cache +import sys + +timings = [] + +if __name__ == '__main__': + for n, string in enumerate(bench): + clear_cache() + _t = time() + exec(string) + _t = time() - _t + timings += [(_t, string)] + sys.stdout.write('.') + sys.stdout.flush() + if n % (len(bench) // 10) == 0: + sys.stdout.write('%s' % (10*n // len(bench))) + print() + + timings.sort(key=lambda x: -x[0]) + + for ti, string in timings: + print('%.2fs %s' % (ti, string)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_symbench.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_symbench.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea700b44b677107f5345196a8895e8ed5a9d56d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/benchmarks/bench_symbench.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +from sympy.core.random import random +from sympy.core.numbers import (I, Integer, pi) +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.polys.polytools import factor +from sympy.simplify.simplify import simplify +from sympy.abc import x, y, z +from timeit import default_timer as clock + + +def bench_R1(): + "real(f(f(f(f(f(f(f(f(f(f(i/2)))))))))))" + def f(z): + return sqrt(Integer(1)/3)*z**2 + I/3 + f(f(f(f(f(f(f(f(f(f(I/2)))))))))).as_real_imag()[0] + + +def bench_R2(): + "Hermite polynomial hermite(15, y)" + def hermite(n, y): + if n == 1: + return 2*y + if n == 0: + return 1 + return (2*y*hermite(n - 1, y) - 2*(n - 1)*hermite(n - 2, y)).expand() + + hermite(15, y) + + +def bench_R3(): + "a = [bool(f==f) for _ in range(10)]" + f = x + y + z + [bool(f == f) for _ in range(10)] + + +def bench_R4(): + # we don't have Tuples + pass + + +def bench_R5(): + "blowup(L, 8); L=uniq(L)" + def blowup(L, n): + for i in range(n): + L.append( (L[i] + L[i + 1]) * L[i + 2] ) + + def uniq(x): + v = set(x) + return v + L = [x, y, z] + blowup(L, 8) + L = uniq(L) + + +def bench_R6(): + "sum(simplify((x+sin(i))/x+(x-sin(i))/x) for i in range(100))" + sum(simplify((x + sin(i))/x + (x - sin(i))/x) for i in range(100)) + + +def bench_R7(): + "[f.subs(x, random()) for _ in range(10**4)]" + f = x**24 + 34*x**12 + 45*x**3 + 9*x**18 + 34*x**10 + 32*x**21 + [f.subs(x, random()) for _ in range(10**4)] + + +def bench_R8(): + "right(x^2,0,5,10^4)" + def right(f, a, b, n): + a = sympify(a) + b = sympify(b) + n = sympify(n) + x = f.atoms(Symbol).pop() + Deltax = (b - a)/n + c = a + est = 0 + for i in range(n): + c += Deltax + est += f.subs(x, c) + return est*Deltax + + right(x**2, 0, 5, 10**4) + + +def _bench_R9(): + "factor(x^20 - pi^5*y^20)" + factor(x**20 - pi**5*y**20) + + +def bench_R10(): + "v = [-pi,-pi+1/10..,pi]" + def srange(min, max, step): + v = [min] + while (max - v[-1]).evalf() > 0: + v.append(v[-1] + step) + return v[:-1] + srange(-pi, pi, sympify(1)/10) + + +def bench_R11(): + "a = [random() + random()*I for w in [0..1000]]" + [random() + random()*I for w in range(1000)] + + +def bench_S1(): + "e=(x+y+z+1)**7;f=e*(e+1);f.expand()" + e = (x + y + z + 1)**7 + f = e*(e + 1) + f.expand() + + +if __name__ == '__main__': + benchmarks = [ + bench_R1, + bench_R2, + bench_R3, + bench_R5, + bench_R6, + bench_R7, + bench_R8, + #_bench_R9, + bench_R10, + bench_R11, + #bench_S1, + ] + + report = [] + for b in benchmarks: + t = clock() + b() + t = clock() - t + print("%s%65s: %f" % (b.__name__, b.__doc__, t)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5007308a1b232e57f9ed164276862df0c5f265 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__init__.py @@ -0,0 +1,33 @@ +""" +Category Theory module. + +Provides some of the fundamental category-theory-related classes, +including categories, morphisms, diagrams. Functors are not +implemented yet. + +The general reference work this module tries to follow is + + [JoyOfCats] J. Adamek, H. Herrlich. G. E. Strecker: Abstract and + Concrete Categories. The Joy of Cats. + +The latest version of this book should be available for free download +from + + katmat.math.uni-bremen.de/acc/acc.pdf + +""" + +from .baseclasses import (Object, Morphism, IdentityMorphism, + NamedMorphism, CompositeMorphism, Category, + Diagram) + +from .diagram_drawing import (DiagramGrid, XypicDiagramDrawer, + xypic_draw_diagram, preview_diagram) + +__all__ = [ + 'Object', 'Morphism', 'IdentityMorphism', 'NamedMorphism', + 'CompositeMorphism', 'Category', 'Diagram', + + 'DiagramGrid', 'XypicDiagramDrawer', 'xypic_draw_diagram', + 'preview_diagram', +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba6e4caff98c2d4fdb2cfa8640f56d09a3b160cf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/baseclasses.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/baseclasses.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c748ff284a40f72506ee3ca5da0dd3b6d0068c05 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/baseclasses.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/diagram_drawing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/diagram_drawing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..835988a6563d494f739131811561f70a1e14cd95 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/__pycache__/diagram_drawing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/baseclasses.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/baseclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ab5153ae4e95f193030864c8f32a52254f2458 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/baseclasses.py @@ -0,0 +1,978 @@ +from sympy.core import S, Basic, Dict, Symbol, Tuple, sympify +from sympy.core.symbol import Str +from sympy.sets import Set, FiniteSet, EmptySet +from sympy.utilities.iterables import iterable + + +class Class(Set): + r""" + The base class for any kind of class in the set-theoretic sense. + + Explanation + =========== + + In axiomatic set theories, everything is a class. A class which + can be a member of another class is a set. A class which is not a + member of another class is a proper class. The class `\{1, 2\}` + is a set; the class of all sets is a proper class. + + This class is essentially a synonym for :class:`sympy.core.Set`. + The goal of this class is to assure easier migration to the + eventual proper implementation of set theory. + """ + is_proper = False + + +class Object(Symbol): + """ + The base class for any kind of object in an abstract category. + + Explanation + =========== + + While technically any instance of :class:`~.Basic` will do, this + class is the recommended way to create abstract objects in + abstract categories. + """ + + +class Morphism(Basic): + """ + The base class for any morphism in an abstract category. + + Explanation + =========== + + In abstract categories, a morphism is an arrow between two + category objects. The object where the arrow starts is called the + domain, while the object where the arrow ends is called the + codomain. + + Two morphisms between the same pair of objects are considered to + be the same morphisms. To distinguish between morphisms between + the same objects use :class:`NamedMorphism`. + + It is prohibited to instantiate this class. Use one of the + derived classes instead. + + See Also + ======== + + IdentityMorphism, NamedMorphism, CompositeMorphism + """ + def __new__(cls, domain, codomain): + raise(NotImplementedError( + "Cannot instantiate Morphism. Use derived classes instead.")) + + @property + def domain(self): + """ + Returns the domain of the morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> f.domain + Object("A") + + """ + return self.args[0] + + @property + def codomain(self): + """ + Returns the codomain of the morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> f.codomain + Object("B") + + """ + return self.args[1] + + def compose(self, other): + r""" + Composes self with the supplied morphism. + + The order of elements in the composition is the usual order, + i.e., to construct `g\circ f` use ``g.compose(f)``. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> g * f + CompositeMorphism((NamedMorphism(Object("A"), Object("B"), "f"), + NamedMorphism(Object("B"), Object("C"), "g"))) + >>> (g * f).domain + Object("A") + >>> (g * f).codomain + Object("C") + + """ + return CompositeMorphism(other, self) + + def __mul__(self, other): + r""" + Composes self with the supplied morphism. + + The semantics of this operation is given by the following + equation: ``g * f == g.compose(f)`` for composable morphisms + ``g`` and ``f``. + + See Also + ======== + + compose + """ + return self.compose(other) + + +class IdentityMorphism(Morphism): + """ + Represents an identity morphism. + + Explanation + =========== + + An identity morphism is a morphism with equal domain and codomain, + which acts as an identity with respect to composition. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, IdentityMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> id_A = IdentityMorphism(A) + >>> id_B = IdentityMorphism(B) + >>> f * id_A == f + True + >>> id_B * f == f + True + + See Also + ======== + + Morphism + """ + def __new__(cls, domain): + return Basic.__new__(cls, domain) + + @property + def codomain(self): + return self.domain + + +class NamedMorphism(Morphism): + """ + Represents a morphism which has a name. + + Explanation + =========== + + Names are used to distinguish between morphisms which have the + same domain and codomain: two named morphisms are equal if they + have the same domains, codomains, and names. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> f + NamedMorphism(Object("A"), Object("B"), "f") + >>> f.name + 'f' + + See Also + ======== + + Morphism + """ + def __new__(cls, domain, codomain, name): + if not name: + raise ValueError("Empty morphism names not allowed.") + + if not isinstance(name, Str): + name = Str(name) + + return Basic.__new__(cls, domain, codomain, name) + + @property + def name(self): + """ + Returns the name of the morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> f.name + 'f' + + """ + return self.args[2].name + + +class CompositeMorphism(Morphism): + r""" + Represents a morphism which is a composition of other morphisms. + + Explanation + =========== + + Two composite morphisms are equal if the morphisms they were + obtained from (components) are the same and were listed in the + same order. + + The arguments to the constructor for this class should be listed + in diagram order: to obtain the composition `g\circ f` from the + instances of :class:`Morphism` ``g`` and ``f`` use + ``CompositeMorphism(f, g)``. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, CompositeMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> g * f + CompositeMorphism((NamedMorphism(Object("A"), Object("B"), "f"), + NamedMorphism(Object("B"), Object("C"), "g"))) + >>> CompositeMorphism(f, g) == g * f + True + + """ + @staticmethod + def _add_morphism(t, morphism): + """ + Intelligently adds ``morphism`` to tuple ``t``. + + Explanation + =========== + + If ``morphism`` is a composite morphism, its components are + added to the tuple. If ``morphism`` is an identity, nothing + is added to the tuple. + + No composability checks are performed. + """ + if isinstance(morphism, CompositeMorphism): + # ``morphism`` is a composite morphism; we have to + # denest its components. + return t + morphism.components + elif isinstance(morphism, IdentityMorphism): + # ``morphism`` is an identity. Nothing happens. + return t + else: + return t + Tuple(morphism) + + def __new__(cls, *components): + if components and not isinstance(components[0], Morphism): + # Maybe the user has explicitly supplied a list of + # morphisms. + return CompositeMorphism.__new__(cls, *components[0]) + + normalised_components = Tuple() + + for current, following in zip(components, components[1:]): + if not isinstance(current, Morphism) or \ + not isinstance(following, Morphism): + raise TypeError("All components must be morphisms.") + + if current.codomain != following.domain: + raise ValueError("Uncomposable morphisms.") + + normalised_components = CompositeMorphism._add_morphism( + normalised_components, current) + + # We haven't added the last morphism to the list of normalised + # components. Add it now. + normalised_components = CompositeMorphism._add_morphism( + normalised_components, components[-1]) + + if not normalised_components: + # If ``normalised_components`` is empty, only identities + # were supplied. Since they all were composable, they are + # all the same identities. + return components[0] + elif len(normalised_components) == 1: + # No sense to construct a whole CompositeMorphism. + return normalised_components[0] + + return Basic.__new__(cls, normalised_components) + + @property + def components(self): + """ + Returns the components of this composite morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> (g * f).components + (NamedMorphism(Object("A"), Object("B"), "f"), + NamedMorphism(Object("B"), Object("C"), "g")) + + """ + return self.args[0] + + @property + def domain(self): + """ + Returns the domain of this composite morphism. + + The domain of the composite morphism is the domain of its + first component. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> (g * f).domain + Object("A") + + """ + return self.components[0].domain + + @property + def codomain(self): + """ + Returns the codomain of this composite morphism. + + The codomain of the composite morphism is the codomain of its + last component. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> (g * f).codomain + Object("C") + + """ + return self.components[-1].codomain + + def flatten(self, new_name): + """ + Forgets the composite structure of this morphism. + + Explanation + =========== + + If ``new_name`` is not empty, returns a :class:`NamedMorphism` + with the supplied name, otherwise returns a :class:`Morphism`. + In both cases the domain of the new morphism is the domain of + this composite morphism and the codomain of the new morphism + is the codomain of this composite morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> (g * f).flatten("h") + NamedMorphism(Object("A"), Object("C"), "h") + + """ + return NamedMorphism(self.domain, self.codomain, new_name) + + +class Category(Basic): + r""" + An (abstract) category. + + Explanation + =========== + + A category [JoyOfCats] is a quadruple `\mbox{K} = (O, \hom, id, + \circ)` consisting of + + * a (set-theoretical) class `O`, whose members are called + `K`-objects, + + * for each pair `(A, B)` of `K`-objects, a set `\hom(A, B)` whose + members are called `K`-morphisms from `A` to `B`, + + * for a each `K`-object `A`, a morphism `id:A\rightarrow A`, + called the `K`-identity of `A`, + + * a composition law `\circ` associating with every `K`-morphisms + `f:A\rightarrow B` and `g:B\rightarrow C` a `K`-morphism `g\circ + f:A\rightarrow C`, called the composite of `f` and `g`. + + Composition is associative, `K`-identities are identities with + respect to composition, and the sets `\hom(A, B)` are pairwise + disjoint. + + This class knows nothing about its objects and morphisms. + Concrete cases of (abstract) categories should be implemented as + classes derived from this one. + + Certain instances of :class:`Diagram` can be asserted to be + commutative in a :class:`Category` by supplying the argument + ``commutative_diagrams`` in the constructor. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram, Category + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> K = Category("K", commutative_diagrams=[d]) + >>> K.commutative_diagrams == FiniteSet(d) + True + + See Also + ======== + + Diagram + """ + def __new__(cls, name, objects=EmptySet, commutative_diagrams=EmptySet): + if not name: + raise ValueError("A Category cannot have an empty name.") + + if not isinstance(name, Str): + name = Str(name) + + if not isinstance(objects, Class): + objects = Class(objects) + + new_category = Basic.__new__(cls, name, objects, + FiniteSet(*commutative_diagrams)) + return new_category + + @property + def name(self): + """ + Returns the name of this category. + + Examples + ======== + + >>> from sympy.categories import Category + >>> K = Category("K") + >>> K.name + 'K' + + """ + return self.args[0].name + + @property + def objects(self): + """ + Returns the class of objects of this category. + + Examples + ======== + + >>> from sympy.categories import Object, Category + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> K = Category("K", FiniteSet(A, B)) + >>> K.objects + Class({Object("A"), Object("B")}) + + """ + return self.args[1] + + @property + def commutative_diagrams(self): + """ + Returns the :class:`~.FiniteSet` of diagrams which are known to + be commutative in this category. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram, Category + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> K = Category("K", commutative_diagrams=[d]) + >>> K.commutative_diagrams == FiniteSet(d) + True + + """ + return self.args[2] + + def hom(self, A, B): + raise NotImplementedError( + "hom-sets are not implemented in Category.") + + def all_morphisms(self): + raise NotImplementedError( + "Obtaining the class of morphisms is not implemented in Category.") + + +class Diagram(Basic): + r""" + Represents a diagram in a certain category. + + Explanation + =========== + + Informally, a diagram is a collection of objects of a category and + certain morphisms between them. A diagram is still a monoid with + respect to morphism composition; i.e., identity morphisms, as well + as all composites of morphisms included in the diagram belong to + the diagram. For a more formal approach to this notion see + [Pare1970]. + + The components of composite morphisms are also added to the + diagram. No properties are assigned to such morphisms by default. + + A commutative diagram is often accompanied by a statement of the + following kind: "if such morphisms with such properties exist, + then such morphisms which such properties exist and the diagram is + commutative". To represent this, an instance of :class:`Diagram` + includes a collection of morphisms which are the premises and + another collection of conclusions. ``premises`` and + ``conclusions`` associate morphisms belonging to the corresponding + categories with the :class:`~.FiniteSet`'s of their properties. + + The set of properties of a composite morphism is the intersection + of the sets of properties of its components. The domain and + codomain of a conclusion morphism should be among the domains and + codomains of the morphisms listed as the premises of a diagram. + + No checks are carried out of whether the supplied object and + morphisms do belong to one and the same category. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy import pprint, default_sort_key + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> premises_keys = sorted(d.premises.keys(), key=default_sort_key) + >>> pprint(premises_keys, use_unicode=False) + [g*f:A-->C, id:A-->A, id:B-->B, id:C-->C, f:A-->B, g:B-->C] + >>> pprint(d.premises, use_unicode=False) + {g*f:A-->C: EmptySet, id:A-->A: EmptySet, id:B-->B: EmptySet, + id:C-->C: EmptySet, f:A-->B: EmptySet, g:B-->C: EmptySet} + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> pprint(d.conclusions,use_unicode=False) + {g*f:A-->C: {unique}} + + References + ========== + + [Pare1970] B. Pareigis: Categories and functors. Academic Press, 1970. + + """ + @staticmethod + def _set_dict_union(dictionary, key, value): + """ + If ``key`` is in ``dictionary``, set the new value of ``key`` + to be the union between the old value and ``value``. + Otherwise, set the value of ``key`` to ``value. + + Returns ``True`` if the key already was in the dictionary and + ``False`` otherwise. + """ + if key in dictionary: + dictionary[key] = dictionary[key] | value + return True + else: + dictionary[key] = value + return False + + @staticmethod + def _add_morphism_closure(morphisms, morphism, props, add_identities=True, + recurse_composites=True): + """ + Adds a morphism and its attributes to the supplied dictionary + ``morphisms``. If ``add_identities`` is True, also adds the + identity morphisms for the domain and the codomain of + ``morphism``. + """ + if not Diagram._set_dict_union(morphisms, morphism, props): + # We have just added a new morphism. + + if isinstance(morphism, IdentityMorphism): + if props: + # Properties for identity morphisms don't really + # make sense, because very much is known about + # identity morphisms already, so much that they + # are trivial. Having properties for identity + # morphisms would only be confusing. + raise ValueError( + "Instances of IdentityMorphism cannot have properties.") + return + + if add_identities: + empty = EmptySet + + id_dom = IdentityMorphism(morphism.domain) + id_cod = IdentityMorphism(morphism.codomain) + + Diagram._set_dict_union(morphisms, id_dom, empty) + Diagram._set_dict_union(morphisms, id_cod, empty) + + for existing_morphism, existing_props in list(morphisms.items()): + new_props = existing_props & props + if morphism.domain == existing_morphism.codomain: + left = morphism * existing_morphism + Diagram._set_dict_union(morphisms, left, new_props) + if morphism.codomain == existing_morphism.domain: + right = existing_morphism * morphism + Diagram._set_dict_union(morphisms, right, new_props) + + if isinstance(morphism, CompositeMorphism) and recurse_composites: + # This is a composite morphism, add its components as + # well. + empty = EmptySet + for component in morphism.components: + Diagram._add_morphism_closure(morphisms, component, empty, + add_identities) + + def __new__(cls, *args): + """ + Construct a new instance of Diagram. + + Explanation + =========== + + If no arguments are supplied, an empty diagram is created. + + If at least an argument is supplied, ``args[0]`` is + interpreted as the premises of the diagram. If ``args[0]`` is + a list, it is interpreted as a list of :class:`Morphism`'s, in + which each :class:`Morphism` has an empty set of properties. + If ``args[0]`` is a Python dictionary or a :class:`Dict`, it + is interpreted as a dictionary associating to some + :class:`Morphism`'s some properties. + + If at least two arguments are supplied ``args[1]`` is + interpreted as the conclusions of the diagram. The type of + ``args[1]`` is interpreted in exactly the same way as the type + of ``args[0]``. If only one argument is supplied, the diagram + has no conclusions. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import IdentityMorphism, Diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> IdentityMorphism(A) in d.premises.keys() + True + >>> g * f in d.premises.keys() + True + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> d.conclusions[g * f] + {unique} + + """ + premises = {} + conclusions = {} + + # Here we will keep track of the objects which appear in the + # premises. + objects = EmptySet + + if len(args) >= 1: + # We've got some premises in the arguments. + premises_arg = args[0] + + if isinstance(premises_arg, list): + # The user has supplied a list of morphisms, none of + # which have any attributes. + empty = EmptySet + + for morphism in premises_arg: + objects |= FiniteSet(morphism.domain, morphism.codomain) + Diagram._add_morphism_closure(premises, morphism, empty) + elif isinstance(premises_arg, (dict, Dict)): + # The user has supplied a dictionary of morphisms and + # their properties. + for morphism, props in premises_arg.items(): + objects |= FiniteSet(morphism.domain, morphism.codomain) + Diagram._add_morphism_closure( + premises, morphism, FiniteSet(*props) if iterable(props) else FiniteSet(props)) + + if len(args) >= 2: + # We also have some conclusions. + conclusions_arg = args[1] + + if isinstance(conclusions_arg, list): + # The user has supplied a list of morphisms, none of + # which have any attributes. + empty = EmptySet + + for morphism in conclusions_arg: + # Check that no new objects appear in conclusions. + if ((sympify(objects.contains(morphism.domain)) is S.true) and + (sympify(objects.contains(morphism.codomain)) is S.true)): + # No need to add identities and recurse + # composites this time. + Diagram._add_morphism_closure( + conclusions, morphism, empty, add_identities=False, + recurse_composites=False) + elif isinstance(conclusions_arg, (dict, Dict)): + # The user has supplied a dictionary of morphisms and + # their properties. + for morphism, props in conclusions_arg.items(): + # Check that no new objects appear in conclusions. + if (morphism.domain in objects) and \ + (morphism.codomain in objects): + # No need to add identities and recurse + # composites this time. + Diagram._add_morphism_closure( + conclusions, morphism, FiniteSet(*props) if iterable(props) else FiniteSet(props), + add_identities=False, recurse_composites=False) + + return Basic.__new__(cls, Dict(premises), Dict(conclusions), objects) + + @property + def premises(self): + """ + Returns the premises of this diagram. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import IdentityMorphism, Diagram + >>> from sympy import pretty + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> id_A = IdentityMorphism(A) + >>> id_B = IdentityMorphism(B) + >>> d = Diagram([f]) + >>> print(pretty(d.premises, use_unicode=False)) + {id:A-->A: EmptySet, id:B-->B: EmptySet, f:A-->B: EmptySet} + + """ + return self.args[0] + + @property + def conclusions(self): + """ + Returns the conclusions of this diagram. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import IdentityMorphism, Diagram + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> IdentityMorphism(A) in d.premises.keys() + True + >>> g * f in d.premises.keys() + True + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> d.conclusions[g * f] == FiniteSet("unique") + True + + """ + return self.args[1] + + @property + def objects(self): + """ + Returns the :class:`~.FiniteSet` of objects that appear in this + diagram. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> d.objects + {Object("A"), Object("B"), Object("C")} + + """ + return self.args[2] + + def hom(self, A, B): + """ + Returns a 2-tuple of sets of morphisms between objects ``A`` and + ``B``: one set of morphisms listed as premises, and the other set + of morphisms listed as conclusions. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy import pretty + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> print(pretty(d.hom(A, C), use_unicode=False)) + ({g*f:A-->C}, {g*f:A-->C}) + + See Also + ======== + Object, Morphism + """ + premises = EmptySet + conclusions = EmptySet + + for morphism in self.premises.keys(): + if (morphism.domain == A) and (morphism.codomain == B): + premises |= FiniteSet(morphism) + for morphism in self.conclusions.keys(): + if (morphism.domain == A) and (morphism.codomain == B): + conclusions |= FiniteSet(morphism) + + return (premises, conclusions) + + def is_subdiagram(self, diagram): + """ + Checks whether ``diagram`` is a subdiagram of ``self``. + Diagram `D'` is a subdiagram of `D` if all premises + (conclusions) of `D'` are contained in the premises + (conclusions) of `D`. The morphisms contained + both in `D'` and `D` should have the same properties for `D'` + to be a subdiagram of `D`. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> d1 = Diagram([f]) + >>> d.is_subdiagram(d1) + True + >>> d1.is_subdiagram(d) + False + """ + premises = all((m in self.premises) and + (diagram.premises[m] == self.premises[m]) + for m in diagram.premises) + if not premises: + return False + + conclusions = all((m in self.conclusions) and + (diagram.conclusions[m] == self.conclusions[m]) + for m in diagram.conclusions) + + # Premises is surely ``True`` here. + return conclusions + + def subdiagram_from_objects(self, objects): + """ + If ``objects`` is a subset of the objects of ``self``, returns + a diagram which has as premises all those premises of ``self`` + which have a domains and codomains in ``objects``, likewise + for conclusions. Properties are preserved. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g], {f: "unique", g*f: "veryunique"}) + >>> d1 = d.subdiagram_from_objects(FiniteSet(A, B)) + >>> d1 == Diagram([f], {f: "unique"}) + True + """ + if not objects.is_subset(self.objects): + raise ValueError( + "Supplied objects should all belong to the diagram.") + + new_premises = {} + for morphism, props in self.premises.items(): + if ((sympify(objects.contains(morphism.domain)) is S.true) and + (sympify(objects.contains(morphism.codomain)) is S.true)): + new_premises[morphism] = props + + new_conclusions = {} + for morphism, props in self.conclusions.items(): + if ((sympify(objects.contains(morphism.domain)) is S.true) and + (sympify(objects.contains(morphism.codomain)) is S.true)): + new_conclusions[morphism] = props + + return Diagram(new_premises, new_conclusions) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/diagram_drawing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/diagram_drawing.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9b507cd86cf0e633b5abf7a0c9a353740af334 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/diagram_drawing.py @@ -0,0 +1,2580 @@ +r""" +This module contains the functionality to arrange the nodes of a +diagram on an abstract grid, and then to produce a graphical +representation of the grid. + +The currently supported back-ends are Xy-pic [Xypic]. + +Layout Algorithm +================ + +This section provides an overview of the algorithms implemented in +:class:`DiagramGrid` to lay out diagrams. + +The first step of the algorithm is the removal composite and identity +morphisms which do not have properties in the supplied diagram. The +premises and conclusions of the diagram are then merged. + +The generic layout algorithm begins with the construction of the +"skeleton" of the diagram. The skeleton is an undirected graph which +has the objects of the diagram as vertices and has an (undirected) +edge between each pair of objects between which there exist morphisms. +The direction of the morphisms does not matter at this stage. The +skeleton also includes an edge between each pair of vertices `A` and +`C` such that there exists an object `B` which is connected via +a morphism to `A`, and via a morphism to `C`. + +The skeleton constructed in this way has the property that every +object is a vertex of a triangle formed by three edges of the +skeleton. This property lies at the base of the generic layout +algorithm. + +After the skeleton has been constructed, the algorithm lists all +triangles which can be formed. Note that some triangles will not have +all edges corresponding to morphisms which will actually be drawn. +Triangles which have only one edge or less which will actually be +drawn are immediately discarded. + +The list of triangles is sorted according to the number of edges which +correspond to morphisms, then the triangle with the least number of such +edges is selected. One of such edges is picked and the corresponding +objects are placed horizontally, on a grid. This edge is recorded to +be in the fringe. The algorithm then finds a "welding" of a triangle +to the fringe. A welding is an edge in the fringe where a triangle +could be attached. If the algorithm succeeds in finding such a +welding, it adds to the grid that vertex of the triangle which was not +yet included in any edge in the fringe and records the two new edges in +the fringe. This process continues iteratively until all objects of +the diagram has been placed or until no more weldings can be found. + +An edge is only removed from the fringe when a welding to this edge +has been found, and there is no room around this edge to place +another vertex. + +When no more weldings can be found, but there are still triangles +left, the algorithm searches for a possibility of attaching one of the +remaining triangles to the existing structure by a vertex. If such a +possibility is found, the corresponding edge of the found triangle is +placed in the found space and the iterative process of welding +triangles restarts. + +When logical groups are supplied, each of these groups is laid out +independently. Then a diagram is constructed in which groups are +objects and any two logical groups between which there exist morphisms +are connected via a morphism. This diagram is laid out. Finally, +the grid which includes all objects of the initial diagram is +constructed by replacing the cells which contain logical groups with +the corresponding laid out grids, and by correspondingly expanding the +rows and columns. + +The sequential layout algorithm begins by constructing the +underlying undirected graph defined by the morphisms obtained after +simplifying premises and conclusions and merging them (see above). +The vertex with the minimal degree is then picked up and depth-first +search is started from it. All objects which are located at distance +`n` from the root in the depth-first search tree, are positioned in +the `n`-th column of the resulting grid. The sequential layout will +therefore attempt to lay the objects out along a line. + +References +========== + +.. [Xypic] https://xy-pic.sourceforge.net/ + +""" +from sympy.categories import (CompositeMorphism, IdentityMorphism, + NamedMorphism, Diagram) +from sympy.core import Dict, Symbol, default_sort_key +from sympy.printing.latex import latex +from sympy.sets import FiniteSet +from sympy.utilities.iterables import iterable +from sympy.utilities.decorator import doctest_depends_on + +from itertools import chain + + +__doctest_requires__ = {('preview_diagram',): 'pyglet'} + + +class _GrowableGrid: + """ + Holds a growable grid of objects. + + Explanation + =========== + + It is possible to append or prepend a row or a column to the grid + using the corresponding methods. Prepending rows or columns has + the effect of changing the coordinates of the already existing + elements. + + This class currently represents a naive implementation of the + functionality with little attempt at optimisation. + """ + def __init__(self, width, height): + self._width = width + self._height = height + + self._array = [[None for j in range(width)] for i in range(height)] + + @property + def width(self): + return self._width + + @property + def height(self): + return self._height + + def __getitem__(self, i_j): + """ + Returns the element located at in the i-th line and j-th + column. + """ + i, j = i_j + return self._array[i][j] + + def __setitem__(self, i_j, newvalue): + """ + Sets the element located at in the i-th line and j-th + column. + """ + i, j = i_j + self._array[i][j] = newvalue + + def append_row(self): + """ + Appends an empty row to the grid. + """ + self._height += 1 + self._array.append([None for j in range(self._width)]) + + def append_column(self): + """ + Appends an empty column to the grid. + """ + self._width += 1 + for i in range(self._height): + self._array[i].append(None) + + def prepend_row(self): + """ + Prepends the grid with an empty row. + """ + self._height += 1 + self._array.insert(0, [None for j in range(self._width)]) + + def prepend_column(self): + """ + Prepends the grid with an empty column. + """ + self._width += 1 + for i in range(self._height): + self._array[i].insert(0, None) + + +class DiagramGrid: + r""" + Constructs and holds the fitting of the diagram into a grid. + + Explanation + =========== + + The mission of this class is to analyse the structure of the + supplied diagram and to place its objects on a grid such that, + when the objects and the morphisms are actually drawn, the diagram + would be "readable", in the sense that there will not be many + intersections of moprhisms. This class does not perform any + actual drawing. It does strive nevertheless to offer sufficient + metadata to draw a diagram. + + Consider the following simple diagram. + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> from sympy import pprint + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + + The simplest way to have a diagram laid out is the following: + + >>> grid = DiagramGrid(diagram) + >>> (grid.width, grid.height) + (2, 2) + >>> pprint(grid) + A B + + C + + Sometimes one sees the diagram as consisting of logical groups. + One can advise ``DiagramGrid`` as to such groups by employing the + ``groups`` keyword argument. + + Consider the following diagram: + + >>> D = Object("D") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> h = NamedMorphism(D, A, "h") + >>> k = NamedMorphism(D, B, "k") + >>> diagram = Diagram([f, g, h, k]) + + Lay it out with generic layout: + + >>> grid = DiagramGrid(diagram) + >>> pprint(grid) + A B D + + C + + Now, we can group the objects `A` and `D` to have them near one + another: + + >>> grid = DiagramGrid(diagram, groups=[[A, D], B, C]) + >>> pprint(grid) + B C + + A D + + Note how the positioning of the other objects changes. + + Further indications can be supplied to the constructor of + :class:`DiagramGrid` using keyword arguments. The currently + supported hints are explained in the following paragraphs. + + :class:`DiagramGrid` does not automatically guess which layout + would suit the supplied diagram better. Consider, for example, + the following linear diagram: + + >>> E = Object("E") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> h = NamedMorphism(C, D, "h") + >>> i = NamedMorphism(D, E, "i") + >>> diagram = Diagram([f, g, h, i]) + + When laid out with the generic layout, it does not get to look + linear: + + >>> grid = DiagramGrid(diagram) + >>> pprint(grid) + A B + + C D + + E + + To get it laid out in a line, use ``layout="sequential"``: + + >>> grid = DiagramGrid(diagram, layout="sequential") + >>> pprint(grid) + A B C D E + + One may sometimes need to transpose the resulting layout. While + this can always be done by hand, :class:`DiagramGrid` provides a + hint for that purpose: + + >>> grid = DiagramGrid(diagram, layout="sequential", transpose=True) + >>> pprint(grid) + A + + B + + C + + D + + E + + Separate hints can also be provided for each group. For an + example, refer to ``tests/test_drawing.py``, and see the different + ways in which the five lemma [FiveLemma] can be laid out. + + See Also + ======== + + Diagram + + References + ========== + + .. [FiveLemma] https://en.wikipedia.org/wiki/Five_lemma + """ + @staticmethod + def _simplify_morphisms(morphisms): + """ + Given a dictionary mapping morphisms to their properties, + returns a new dictionary in which there are no morphisms which + do not have properties, and which are compositions of other + morphisms included in the dictionary. Identities are dropped + as well. + """ + newmorphisms = {} + for morphism, props in morphisms.items(): + if isinstance(morphism, CompositeMorphism) and not props: + continue + elif isinstance(morphism, IdentityMorphism): + continue + else: + newmorphisms[morphism] = props + return newmorphisms + + @staticmethod + def _merge_premises_conclusions(premises, conclusions): + """ + Given two dictionaries of morphisms and their properties, + produces a single dictionary which includes elements from both + dictionaries. If a morphism has some properties in premises + and also in conclusions, the properties in conclusions take + priority. + """ + return dict(chain(premises.items(), conclusions.items())) + + @staticmethod + def _juxtapose_edges(edge1, edge2): + """ + If ``edge1`` and ``edge2`` have precisely one common endpoint, + returns an edge which would form a triangle with ``edge1`` and + ``edge2``. + + If ``edge1`` and ``edge2`` do not have a common endpoint, + returns ``None``. + + If ``edge1`` and ``edge`` are the same edge, returns ``None``. + """ + intersection = edge1 & edge2 + if len(intersection) != 1: + # The edges either have no common points or are equal. + return None + + # The edges have a common endpoint. Extract the different + # endpoints and set up the new edge. + return (edge1 - intersection) | (edge2 - intersection) + + @staticmethod + def _add_edge_append(dictionary, edge, elem): + """ + If ``edge`` is not in ``dictionary``, adds ``edge`` to the + dictionary and sets its value to ``[elem]``. Otherwise + appends ``elem`` to the value of existing entry. + + Note that edges are undirected, thus `(A, B) = (B, A)`. + """ + if edge in dictionary: + dictionary[edge].append(elem) + else: + dictionary[edge] = [elem] + + @staticmethod + def _build_skeleton(morphisms): + """ + Creates a dictionary which maps edges to corresponding + morphisms. Thus for a morphism `f:A\rightarrow B`, the edge + `(A, B)` will be associated with `f`. This function also adds + to the list those edges which are formed by juxtaposition of + two edges already in the list. These new edges are not + associated with any morphism and are only added to assure that + the diagram can be decomposed into triangles. + """ + edges = {} + # Create edges for morphisms. + for morphism in morphisms: + DiagramGrid._add_edge_append( + edges, frozenset([morphism.domain, morphism.codomain]), morphism) + + # Create new edges by juxtaposing existing edges. + edges1 = dict(edges) + for w in edges1: + for v in edges1: + wv = DiagramGrid._juxtapose_edges(w, v) + if wv and wv not in edges: + edges[wv] = [] + + return edges + + @staticmethod + def _list_triangles(edges): + """ + Builds the set of triangles formed by the supplied edges. The + triangles are arbitrary and need not be commutative. A + triangle is a set that contains all three of its sides. + """ + triangles = set() + + for w in edges: + for v in edges: + wv = DiagramGrid._juxtapose_edges(w, v) + if wv and wv in edges: + triangles.add(frozenset([w, v, wv])) + + return triangles + + @staticmethod + def _drop_redundant_triangles(triangles, skeleton): + """ + Returns a list which contains only those triangles who have + morphisms associated with at least two edges. + """ + return [tri for tri in triangles + if len([e for e in tri if skeleton[e]]) >= 2] + + @staticmethod + def _morphism_length(morphism): + """ + Returns the length of a morphism. The length of a morphism is + the number of components it consists of. A non-composite + morphism is of length 1. + """ + if isinstance(morphism, CompositeMorphism): + return len(morphism.components) + else: + return 1 + + @staticmethod + def _compute_triangle_min_sizes(triangles, edges): + r""" + Returns a dictionary mapping triangles to their minimal sizes. + The minimal size of a triangle is the sum of maximal lengths + of morphisms associated to the sides of the triangle. The + length of a morphism is the number of components it consists + of. A non-composite morphism is of length 1. + + Sorting triangles by this metric attempts to address two + aspects of layout. For triangles with only simple morphisms + in the edge, this assures that triangles with all three edges + visible will get typeset after triangles with less visible + edges, which sometimes minimizes the necessity in diagonal + arrows. For triangles with composite morphisms in the edges, + this assures that objects connected with shorter morphisms + will be laid out first, resulting the visual proximity of + those objects which are connected by shorter morphisms. + """ + triangle_sizes = {} + for triangle in triangles: + size = 0 + for e in triangle: + morphisms = edges[e] + if morphisms: + size += max(DiagramGrid._morphism_length(m) + for m in morphisms) + triangle_sizes[triangle] = size + return triangle_sizes + + @staticmethod + def _triangle_objects(triangle): + """ + Given a triangle, returns the objects included in it. + """ + # A triangle is a frozenset of three two-element frozensets + # (the edges). This chains the three edges together and + # creates a frozenset from the iterator, thus producing a + # frozenset of objects of the triangle. + return frozenset(chain(*tuple(triangle))) + + @staticmethod + def _other_vertex(triangle, edge): + """ + Given a triangle and an edge of it, returns the vertex which + opposes the edge. + """ + # This gets the set of objects of the triangle and then + # subtracts the set of objects employed in ``edge`` to get the + # vertex opposite to ``edge``. + return list(DiagramGrid._triangle_objects(triangle) - set(edge))[0] + + @staticmethod + def _empty_point(pt, grid): + """ + Checks if the cell at coordinates ``pt`` is either empty or + out of the bounds of the grid. + """ + if (pt[0] < 0) or (pt[1] < 0) or \ + (pt[0] >= grid.height) or (pt[1] >= grid.width): + return True + return grid[pt] is None + + @staticmethod + def _put_object(coords, obj, grid, fringe): + """ + Places an object at the coordinate ``cords`` in ``grid``, + growing the grid and updating ``fringe``, if necessary. + Returns (0, 0) if no row or column has been prepended, (1, 0) + if a row was prepended, (0, 1) if a column was prepended and + (1, 1) if both a column and a row were prepended. + """ + (i, j) = coords + offset = (0, 0) + if i == -1: + grid.prepend_row() + i = 0 + offset = (1, 0) + for k in range(len(fringe)): + ((i1, j1), (i2, j2)) = fringe[k] + fringe[k] = ((i1 + 1, j1), (i2 + 1, j2)) + elif i == grid.height: + grid.append_row() + + if j == -1: + j = 0 + offset = (offset[0], 1) + grid.prepend_column() + for k in range(len(fringe)): + ((i1, j1), (i2, j2)) = fringe[k] + fringe[k] = ((i1, j1 + 1), (i2, j2 + 1)) + elif j == grid.width: + grid.append_column() + + grid[i, j] = obj + return offset + + @staticmethod + def _choose_target_cell(pt1, pt2, edge, obj, skeleton, grid): + """ + Given two points, ``pt1`` and ``pt2``, and the welding edge + ``edge``, chooses one of the two points to place the opposing + vertex ``obj`` of the triangle. If neither of this points + fits, returns ``None``. + """ + pt1_empty = DiagramGrid._empty_point(pt1, grid) + pt2_empty = DiagramGrid._empty_point(pt2, grid) + + if pt1_empty and pt2_empty: + # Both cells are empty. Of these two, choose that cell + # which will assure that a visible edge of the triangle + # will be drawn perpendicularly to the current welding + # edge. + + A = grid[edge[0]] + + if skeleton.get(frozenset([A, obj])): + return pt1 + else: + return pt2 + if pt1_empty: + return pt1 + elif pt2_empty: + return pt2 + else: + return None + + @staticmethod + def _find_triangle_to_weld(triangles, fringe, grid): + """ + Finds, if possible, a triangle and an edge in the ``fringe`` to + which the triangle could be attached. Returns the tuple + containing the triangle and the index of the corresponding + edge in the ``fringe``. + + This function relies on the fact that objects are unique in + the diagram. + """ + for triangle in triangles: + for (a, b) in fringe: + if frozenset([grid[a], grid[b]]) in triangle: + return (triangle, (a, b)) + return None + + @staticmethod + def _weld_triangle(tri, welding_edge, fringe, grid, skeleton): + """ + If possible, welds the triangle ``tri`` to ``fringe`` and + returns ``False``. If this method encounters a degenerate + situation in the fringe and corrects it such that a restart of + the search is required, it returns ``True`` (which means that + a restart in finding triangle weldings is required). + + A degenerate situation is a situation when an edge listed in + the fringe does not belong to the visual boundary of the + diagram. + """ + a, b = welding_edge + target_cell = None + + obj = DiagramGrid._other_vertex(tri, (grid[a], grid[b])) + + # We now have a triangle and an edge where it can be welded to + # the fringe. Decide where to place the other vertex of the + # triangle and check for degenerate situations en route. + + if (abs(a[0] - b[0]) == 1) and (abs(a[1] - b[1]) == 1): + # A diagonal edge. + target_cell = (a[0], b[1]) + if grid[target_cell]: + # That cell is already occupied. + target_cell = (b[0], a[1]) + + if grid[target_cell]: + # Degenerate situation, this edge is not + # on the actual fringe. Correct the + # fringe and go on. + fringe.remove((a, b)) + return True + elif a[0] == b[0]: + # A horizontal edge. We first attempt to build the + # triangle in the downward direction. + + down_left = a[0] + 1, a[1] + down_right = a[0] + 1, b[1] + + target_cell = DiagramGrid._choose_target_cell( + down_left, down_right, (a, b), obj, skeleton, grid) + + if not target_cell: + # No room below this edge. Check above. + up_left = a[0] - 1, a[1] + up_right = a[0] - 1, b[1] + + target_cell = DiagramGrid._choose_target_cell( + up_left, up_right, (a, b), obj, skeleton, grid) + + if not target_cell: + # This edge is not in the fringe, remove it + # and restart. + fringe.remove((a, b)) + return True + elif a[1] == b[1]: + # A vertical edge. We will attempt to place the other + # vertex of the triangle to the right of this edge. + right_up = a[0], a[1] + 1 + right_down = b[0], a[1] + 1 + + target_cell = DiagramGrid._choose_target_cell( + right_up, right_down, (a, b), obj, skeleton, grid) + + if not target_cell: + # No room to the left. See what's to the right. + left_up = a[0], a[1] - 1 + left_down = b[0], a[1] - 1 + + target_cell = DiagramGrid._choose_target_cell( + left_up, left_down, (a, b), obj, skeleton, grid) + + if not target_cell: + # This edge is not in the fringe, remove it + # and restart. + fringe.remove((a, b)) + return True + + # We now know where to place the other vertex of the + # triangle. + offset = DiagramGrid._put_object(target_cell, obj, grid, fringe) + + # Take care of the displacement of coordinates if a row or + # a column was prepended. + target_cell = (target_cell[0] + offset[0], + target_cell[1] + offset[1]) + a = (a[0] + offset[0], a[1] + offset[1]) + b = (b[0] + offset[0], b[1] + offset[1]) + + fringe.extend([(a, target_cell), (b, target_cell)]) + + # No restart is required. + return False + + @staticmethod + def _triangle_key(tri, triangle_sizes): + """ + Returns a key for the supplied triangle. It should be the + same independently of the hash randomisation. + """ + objects = sorted( + DiagramGrid._triangle_objects(tri), key=default_sort_key) + return (triangle_sizes[tri], default_sort_key(objects)) + + @staticmethod + def _pick_root_edge(tri, skeleton): + """ + For a given triangle always picks the same root edge. The + root edge is the edge that will be placed first on the grid. + """ + candidates = [sorted(e, key=default_sort_key) + for e in tri if skeleton[e]] + sorted_candidates = sorted(candidates, key=default_sort_key) + # Don't forget to assure the proper ordering of the vertices + # in this edge. + return tuple(sorted(sorted_candidates[0], key=default_sort_key)) + + @staticmethod + def _drop_irrelevant_triangles(triangles, placed_objects): + """ + Returns only those triangles whose set of objects is not + completely included in ``placed_objects``. + """ + return [tri for tri in triangles if not placed_objects.issuperset( + DiagramGrid._triangle_objects(tri))] + + @staticmethod + def _grow_pseudopod(triangles, fringe, grid, skeleton, placed_objects): + """ + Starting from an object in the existing structure on the ``grid``, + adds an edge to which a triangle from ``triangles`` could be + welded. If this method has found a way to do so, it returns + the object it has just added. + + This method should be applied when ``_weld_triangle`` cannot + find weldings any more. + """ + for i in range(grid.height): + for j in range(grid.width): + obj = grid[i, j] + if not obj: + continue + + # Here we need to choose a triangle which has only + # ``obj`` in common with the existing structure. The + # situations when this is not possible should be + # handled elsewhere. + + def good_triangle(tri): + objs = DiagramGrid._triangle_objects(tri) + return obj in objs and \ + placed_objects & (objs - {obj}) == set() + + tris = [tri for tri in triangles if good_triangle(tri)] + if not tris: + # This object is not interesting. + continue + + # Pick the "simplest" of the triangles which could be + # attached. Remember that the list of triangles is + # sorted according to their "simplicity" (see + # _compute_triangle_min_sizes for the metric). + # + # Note that ``tris`` are sequentially built from + # ``triangles``, so we don't have to worry about hash + # randomisation. + tri = tris[0] + + # We have found a triangle which could be attached to + # the existing structure by a vertex. + + candidates = sorted([e for e in tri if skeleton[e]], + key=lambda e: FiniteSet(*e).sort_key()) + edges = [e for e in candidates if obj in e] + + # Note that a meaningful edge (i.e., and edge that is + # associated with a morphism) containing ``obj`` + # always exists. That's because all triangles are + # guaranteed to have at least two meaningful edges. + # See _drop_redundant_triangles. + + # Get the object at the other end of the edge. + edge = edges[0] + other_obj = tuple(edge - frozenset([obj]))[0] + + # Now check for free directions. When checking for + # free directions, prefer the horizontal and vertical + # directions. + neighbours = [(i - 1, j), (i, j + 1), (i + 1, j), (i, j - 1), + (i - 1, j - 1), (i - 1, j + 1), (i + 1, j - 1), (i + 1, j + 1)] + + for pt in neighbours: + if DiagramGrid._empty_point(pt, grid): + # We have a found a place to grow the + # pseudopod into. + offset = DiagramGrid._put_object( + pt, other_obj, grid, fringe) + + i += offset[0] + j += offset[1] + pt = (pt[0] + offset[0], pt[1] + offset[1]) + fringe.append(((i, j), pt)) + + return other_obj + + # This diagram is actually cooler that I can handle. Fail cowardly. + return None + + @staticmethod + def _handle_groups(diagram, groups, merged_morphisms, hints): + """ + Given the slightly preprocessed morphisms of the diagram, + produces a grid laid out according to ``groups``. + + If a group has hints, it is laid out with those hints only, + without any influence from ``hints``. Otherwise, it is laid + out with ``hints``. + """ + def lay_out_group(group, local_hints): + """ + If ``group`` is a set of objects, uses a ``DiagramGrid`` + to lay it out and returns the grid. Otherwise returns the + object (i.e., ``group``). If ``local_hints`` is not + empty, it is supplied to ``DiagramGrid`` as the dictionary + of hints. Otherwise, the ``hints`` argument of + ``_handle_groups`` is used. + """ + if isinstance(group, FiniteSet): + # Set up the corresponding object-to-group + # mappings. + for obj in group: + obj_groups[obj] = group + + # Lay out the current group. + if local_hints: + groups_grids[group] = DiagramGrid( + diagram.subdiagram_from_objects(group), **local_hints) + else: + groups_grids[group] = DiagramGrid( + diagram.subdiagram_from_objects(group), **hints) + else: + obj_groups[group] = group + + def group_to_finiteset(group): + """ + Converts ``group`` to a :class:``FiniteSet`` if it is an + iterable. + """ + if iterable(group): + return FiniteSet(*group) + else: + return group + + obj_groups = {} + groups_grids = {} + + # We would like to support various containers to represent + # groups. To achieve that, before laying each group out, it + # should be converted to a FiniteSet, because that is what the + # following code expects. + + if isinstance(groups, (dict, Dict)): + finiteset_groups = {} + for group, local_hints in groups.items(): + finiteset_group = group_to_finiteset(group) + finiteset_groups[finiteset_group] = local_hints + lay_out_group(group, local_hints) + groups = finiteset_groups + else: + finiteset_groups = [] + for group in groups: + finiteset_group = group_to_finiteset(group) + finiteset_groups.append(finiteset_group) + lay_out_group(finiteset_group, None) + groups = finiteset_groups + + new_morphisms = [] + for morphism in merged_morphisms: + dom = obj_groups[morphism.domain] + cod = obj_groups[morphism.codomain] + # Note that we are not really interested in morphisms + # which do not employ two different groups, because + # these do not influence the layout. + if dom != cod: + # These are essentially unnamed morphisms; they are + # not going to mess in the final layout. By giving + # them the same names, we avoid unnecessary + # duplicates. + new_morphisms.append(NamedMorphism(dom, cod, "dummy")) + + # Lay out the new diagram. Since these are dummy morphisms, + # properties and conclusions are irrelevant. + top_grid = DiagramGrid(Diagram(new_morphisms)) + + # We now have to substitute the groups with the corresponding + # grids, laid out at the beginning of this function. Compute + # the size of each row and column in the grid, so that all + # nested grids fit. + + def group_size(group): + """ + For the supplied group (or object, eventually), returns + the size of the cell that will hold this group (object). + """ + if group in groups_grids: + grid = groups_grids[group] + return (grid.height, grid.width) + else: + return (1, 1) + + row_heights = [max(group_size(top_grid[i, j])[0] + for j in range(top_grid.width)) + for i in range(top_grid.height)] + + column_widths = [max(group_size(top_grid[i, j])[1] + for i in range(top_grid.height)) + for j in range(top_grid.width)] + + grid = _GrowableGrid(sum(column_widths), sum(row_heights)) + + real_row = 0 + real_column = 0 + for logical_row in range(top_grid.height): + for logical_column in range(top_grid.width): + obj = top_grid[logical_row, logical_column] + + if obj in groups_grids: + # This is a group. Copy the corresponding grid in + # place. + local_grid = groups_grids[obj] + for i in range(local_grid.height): + for j in range(local_grid.width): + grid[real_row + i, + real_column + j] = local_grid[i, j] + else: + # This is an object. Just put it there. + grid[real_row, real_column] = obj + + real_column += column_widths[logical_column] + real_column = 0 + real_row += row_heights[logical_row] + + return grid + + @staticmethod + def _generic_layout(diagram, merged_morphisms): + """ + Produces the generic layout for the supplied diagram. + """ + all_objects = set(diagram.objects) + if len(all_objects) == 1: + # There only one object in the diagram, just put in on 1x1 + # grid. + grid = _GrowableGrid(1, 1) + grid[0, 0] = tuple(all_objects)[0] + return grid + + skeleton = DiagramGrid._build_skeleton(merged_morphisms) + + grid = _GrowableGrid(2, 1) + + if len(skeleton) == 1: + # This diagram contains only one morphism. Draw it + # horizontally. + objects = sorted(all_objects, key=default_sort_key) + grid[0, 0] = objects[0] + grid[0, 1] = objects[1] + + return grid + + triangles = DiagramGrid._list_triangles(skeleton) + triangles = DiagramGrid._drop_redundant_triangles(triangles, skeleton) + triangle_sizes = DiagramGrid._compute_triangle_min_sizes( + triangles, skeleton) + + triangles = sorted(triangles, key=lambda tri: + DiagramGrid._triangle_key(tri, triangle_sizes)) + + # Place the first edge on the grid. + root_edge = DiagramGrid._pick_root_edge(triangles[0], skeleton) + grid[0, 0], grid[0, 1] = root_edge + fringe = [((0, 0), (0, 1))] + + # Record which objects we now have on the grid. + placed_objects = set(root_edge) + + while placed_objects != all_objects: + welding = DiagramGrid._find_triangle_to_weld( + triangles, fringe, grid) + + if welding: + (triangle, welding_edge) = welding + + restart_required = DiagramGrid._weld_triangle( + triangle, welding_edge, fringe, grid, skeleton) + if restart_required: + continue + + placed_objects.update( + DiagramGrid._triangle_objects(triangle)) + else: + # No more weldings found. Try to attach triangles by + # vertices. + new_obj = DiagramGrid._grow_pseudopod( + triangles, fringe, grid, skeleton, placed_objects) + + if not new_obj: + # No more triangles can be attached, not even by + # the edge. We will set up a new diagram out of + # what has been left, laid it out independently, + # and then attach it to this one. + + remaining_objects = all_objects - placed_objects + + remaining_diagram = diagram.subdiagram_from_objects( + FiniteSet(*remaining_objects)) + remaining_grid = DiagramGrid(remaining_diagram) + + # Now, let's glue ``remaining_grid`` to ``grid``. + final_width = grid.width + remaining_grid.width + final_height = max(grid.height, remaining_grid.height) + final_grid = _GrowableGrid(final_width, final_height) + + for i in range(grid.width): + for j in range(grid.height): + final_grid[i, j] = grid[i, j] + + start_j = grid.width + for i in range(remaining_grid.height): + for j in range(remaining_grid.width): + final_grid[i, start_j + j] = remaining_grid[i, j] + + return final_grid + + placed_objects.add(new_obj) + + triangles = DiagramGrid._drop_irrelevant_triangles( + triangles, placed_objects) + + return grid + + @staticmethod + def _get_undirected_graph(objects, merged_morphisms): + """ + Given the objects and the relevant morphisms of a diagram, + returns the adjacency lists of the underlying undirected + graph. + """ + adjlists = {obj: [] for obj in objects} + + for morphism in merged_morphisms: + adjlists[morphism.domain].append(morphism.codomain) + adjlists[morphism.codomain].append(morphism.domain) + + # Assure that the objects in the adjacency list are always in + # the same order. + for obj in adjlists.keys(): + adjlists[obj].sort(key=default_sort_key) + + return adjlists + + @staticmethod + def _sequential_layout(diagram, merged_morphisms): + r""" + Lays out the diagram in "sequential" layout. This method + will attempt to produce a result as close to a line as + possible. For linear diagrams, the result will actually be a + line. + """ + objects = diagram.objects + sorted_objects = sorted(objects, key=default_sort_key) + + # Set up the adjacency lists of the underlying undirected + # graph of ``merged_morphisms``. + adjlists = DiagramGrid._get_undirected_graph(objects, merged_morphisms) + + root = min(sorted_objects, key=lambda x: len(adjlists[x])) + grid = _GrowableGrid(1, 1) + grid[0, 0] = root + + placed_objects = {root} + + def place_objects(pt, placed_objects): + """ + Does depth-first search in the underlying graph of the + diagram and places the objects en route. + """ + # We will start placing new objects from here. + new_pt = (pt[0], pt[1] + 1) + + for adjacent_obj in adjlists[grid[pt]]: + if adjacent_obj in placed_objects: + # This object has already been placed. + continue + + DiagramGrid._put_object(new_pt, adjacent_obj, grid, []) + placed_objects.add(adjacent_obj) + placed_objects.update(place_objects(new_pt, placed_objects)) + + new_pt = (new_pt[0] + 1, new_pt[1]) + + return placed_objects + + place_objects((0, 0), placed_objects) + + return grid + + @staticmethod + def _drop_inessential_morphisms(merged_morphisms): + r""" + Removes those morphisms which should appear in the diagram, + but which have no relevance to object layout. + + Currently this removes "loop" morphisms: the non-identity + morphisms with the same domains and codomains. + """ + morphisms = [m for m in merged_morphisms if m.domain != m.codomain] + return morphisms + + @staticmethod + def _get_connected_components(objects, merged_morphisms): + """ + Given a container of morphisms, returns a list of connected + components formed by these morphisms. A connected component + is represented by a diagram consisting of the corresponding + morphisms. + """ + component_index = {} + for o in objects: + component_index[o] = None + + # Get the underlying undirected graph of the diagram. + adjlist = DiagramGrid._get_undirected_graph(objects, merged_morphisms) + + def traverse_component(object, current_index): + """ + Does a depth-first search traversal of the component + containing ``object``. + """ + component_index[object] = current_index + for o in adjlist[object]: + if component_index[o] is None: + traverse_component(o, current_index) + + # Traverse all components. + current_index = 0 + for o in adjlist: + if component_index[o] is None: + traverse_component(o, current_index) + current_index += 1 + + # List the objects of the components. + component_objects = [[] for i in range(current_index)] + for o, idx in component_index.items(): + component_objects[idx].append(o) + + # Finally, list the morphisms belonging to each component. + # + # Note: If some objects are isolated, they will not get any + # morphisms at this stage, and since the layout algorithm + # relies, we are essentially going to lose this object. + # Therefore, check if there are isolated objects and, for each + # of them, provide the trivial identity morphism. It will get + # discarded later, but the object will be there. + + component_morphisms = [] + for component in component_objects: + current_morphisms = {} + for m in merged_morphisms: + if (m.domain in component) and (m.codomain in component): + current_morphisms[m] = merged_morphisms[m] + + if len(component) == 1: + # Let's add an identity morphism, for the sake of + # surely having morphisms in this component. + current_morphisms[IdentityMorphism(component[0])] = FiniteSet() + + component_morphisms.append(Diagram(current_morphisms)) + + return component_morphisms + + def __init__(self, diagram, groups=None, **hints): + premises = DiagramGrid._simplify_morphisms(diagram.premises) + conclusions = DiagramGrid._simplify_morphisms(diagram.conclusions) + all_merged_morphisms = DiagramGrid._merge_premises_conclusions( + premises, conclusions) + merged_morphisms = DiagramGrid._drop_inessential_morphisms( + all_merged_morphisms) + + # Store the merged morphisms for later use. + self._morphisms = all_merged_morphisms + + components = DiagramGrid._get_connected_components( + diagram.objects, all_merged_morphisms) + + if groups and (groups != diagram.objects): + # Lay out the diagram according to the groups. + self._grid = DiagramGrid._handle_groups( + diagram, groups, merged_morphisms, hints) + elif len(components) > 1: + # Note that we check for connectedness _before_ checking + # the layout hints because the layout strategies don't + # know how to deal with disconnected diagrams. + + # The diagram is disconnected. Lay out the components + # independently. + grids = [] + + # Sort the components to eventually get the grids arranged + # in a fixed, hash-independent order. + components = sorted(components, key=default_sort_key) + + for component in components: + grid = DiagramGrid(component, **hints) + grids.append(grid) + + # Throw the grids together, in a line. + total_width = sum(g.width for g in grids) + total_height = max(g.height for g in grids) + + grid = _GrowableGrid(total_width, total_height) + start_j = 0 + for g in grids: + for i in range(g.height): + for j in range(g.width): + grid[i, start_j + j] = g[i, j] + + start_j += g.width + + self._grid = grid + elif "layout" in hints: + if hints["layout"] == "sequential": + self._grid = DiagramGrid._sequential_layout( + diagram, merged_morphisms) + else: + self._grid = DiagramGrid._generic_layout(diagram, merged_morphisms) + + if hints.get("transpose"): + # Transpose the resulting grid. + grid = _GrowableGrid(self._grid.height, self._grid.width) + for i in range(self._grid.height): + for j in range(self._grid.width): + grid[j, i] = self._grid[i, j] + self._grid = grid + + @property + def width(self): + """ + Returns the number of columns in this diagram layout. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> grid.width + 2 + + """ + return self._grid.width + + @property + def height(self): + """ + Returns the number of rows in this diagram layout. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> grid.height + 2 + + """ + return self._grid.height + + def __getitem__(self, i_j): + """ + Returns the object placed in the row ``i`` and column ``j``. + The indices are 0-based. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> (grid[0, 0], grid[0, 1]) + (Object("A"), Object("B")) + >>> (grid[1, 0], grid[1, 1]) + (None, Object("C")) + + """ + i, j = i_j + return self._grid[i, j] + + @property + def morphisms(self): + """ + Returns those morphisms (and their properties) which are + sufficiently meaningful to be drawn. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> grid.morphisms + {NamedMorphism(Object("A"), Object("B"), "f"): EmptySet, + NamedMorphism(Object("B"), Object("C"), "g"): EmptySet} + + """ + return self._morphisms + + def __str__(self): + """ + Produces a string representation of this class. + + This method returns a string representation of the underlying + list of lists of objects. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> print(grid) + [[Object("A"), Object("B")], + [None, Object("C")]] + + """ + return repr(self._grid._array) + + +class ArrowStringDescription: + r""" + Stores the information necessary for producing an Xy-pic + description of an arrow. + + The principal goal of this class is to abstract away the string + representation of an arrow and to also provide the functionality + to produce the actual Xy-pic string. + + ``unit`` sets the unit which will be used to specify the amount of + curving and other distances. ``horizontal_direction`` should be a + string of ``"r"`` or ``"l"`` specifying the horizontal offset of the + target cell of the arrow relatively to the current one. + ``vertical_direction`` should specify the vertical offset using a + series of either ``"d"`` or ``"u"``. ``label_position`` should be + either ``"^"``, ``"_"``, or ``"|"`` to specify that the label should + be positioned above the arrow, below the arrow or just over the arrow, + in a break. Note that the notions "above" and "below" are relative + to arrow direction. ``label`` stores the morphism label. + + This works as follows (disregard the yet unexplained arguments): + + >>> from sympy.categories.diagram_drawing import ArrowStringDescription + >>> astr = ArrowStringDescription( + ... unit="mm", curving=None, curving_amount=None, + ... looping_start=None, looping_end=None, horizontal_direction="d", + ... vertical_direction="r", label_position="_", label="f") + >>> print(str(astr)) + \ar[dr]_{f} + + ``curving`` should be one of ``"^"``, ``"_"`` to specify in which + direction the arrow is going to curve. ``curving_amount`` is a number + describing how many ``unit``'s the morphism is going to curve: + + >>> astr = ArrowStringDescription( + ... unit="mm", curving="^", curving_amount=12, + ... looping_start=None, looping_end=None, horizontal_direction="d", + ... vertical_direction="r", label_position="_", label="f") + >>> print(str(astr)) + \ar@/^12mm/[dr]_{f} + + ``looping_start`` and ``looping_end`` are currently only used for + loop morphisms, those which have the same domain and codomain. + These two attributes should store a valid Xy-pic direction and + specify, correspondingly, the direction the arrow gets out into + and the direction the arrow gets back from: + + >>> astr = ArrowStringDescription( + ... unit="mm", curving=None, curving_amount=None, + ... looping_start="u", looping_end="l", horizontal_direction="", + ... vertical_direction="", label_position="_", label="f") + >>> print(str(astr)) + \ar@(u,l)[]_{f} + + ``label_displacement`` controls how far the arrow label is from + the ends of the arrow. For example, to position the arrow label + near the arrow head, use ">": + + >>> astr = ArrowStringDescription( + ... unit="mm", curving="^", curving_amount=12, + ... looping_start=None, looping_end=None, horizontal_direction="d", + ... vertical_direction="r", label_position="_", label="f") + >>> astr.label_displacement = ">" + >>> print(str(astr)) + \ar@/^12mm/[dr]_>{f} + + Finally, ``arrow_style`` is used to specify the arrow style. To + get a dashed arrow, for example, use "{-->}" as arrow style: + + >>> astr = ArrowStringDescription( + ... unit="mm", curving="^", curving_amount=12, + ... looping_start=None, looping_end=None, horizontal_direction="d", + ... vertical_direction="r", label_position="_", label="f") + >>> astr.arrow_style = "{-->}" + >>> print(str(astr)) + \ar@/^12mm/@{-->}[dr]_{f} + + Notes + ===== + + Instances of :class:`ArrowStringDescription` will be constructed + by :class:`XypicDiagramDrawer` and provided for further use in + formatters. The user is not expected to construct instances of + :class:`ArrowStringDescription` themselves. + + To be able to properly utilise this class, the reader is encouraged + to checkout the Xy-pic user guide, available at [Xypic]. + + See Also + ======== + + XypicDiagramDrawer + + References + ========== + + .. [Xypic] https://xy-pic.sourceforge.net/ + """ + def __init__(self, unit, curving, curving_amount, looping_start, + looping_end, horizontal_direction, vertical_direction, + label_position, label): + self.unit = unit + self.curving = curving + self.curving_amount = curving_amount + self.looping_start = looping_start + self.looping_end = looping_end + self.horizontal_direction = horizontal_direction + self.vertical_direction = vertical_direction + self.label_position = label_position + self.label = label + + self.label_displacement = "" + self.arrow_style = "" + + # This flag shows that the position of the label of this + # morphism was set while typesetting a curved morphism and + # should not be modified later. + self.forced_label_position = False + + def __str__(self): + if self.curving: + curving_str = "@/%s%d%s/" % (self.curving, self.curving_amount, + self.unit) + else: + curving_str = "" + + if self.looping_start and self.looping_end: + looping_str = "@(%s,%s)" % (self.looping_start, self.looping_end) + else: + looping_str = "" + + if self.arrow_style: + + style_str = "@" + self.arrow_style + else: + style_str = "" + + return "\\ar%s%s%s[%s%s]%s%s{%s}" % \ + (curving_str, looping_str, style_str, self.horizontal_direction, + self.vertical_direction, self.label_position, + self.label_displacement, self.label) + + +class XypicDiagramDrawer: + r""" + Given a :class:`~.Diagram` and the corresponding + :class:`DiagramGrid`, produces the Xy-pic representation of the + diagram. + + The most important method in this class is ``draw``. Consider the + following triangle diagram: + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy.categories import DiagramGrid, XypicDiagramDrawer + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g], {g * f: "unique"}) + + To draw this diagram, its objects need to be laid out with a + :class:`DiagramGrid`:: + + >>> grid = DiagramGrid(diagram) + + Finally, the drawing: + + >>> drawer = XypicDiagramDrawer() + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + For further details see the docstring of this method. + + To control the appearance of the arrows, formatters are used. The + dictionary ``arrow_formatters`` maps morphisms to formatter + functions. A formatter is accepts an + :class:`ArrowStringDescription` and is allowed to modify any of + the arrow properties exposed thereby. For example, to have all + morphisms with the property ``unique`` appear as dashed arrows, + and to have their names prepended with `\exists !`, the following + should be done: + + >>> def formatter(astr): + ... astr.label = r"\exists !" + astr.label + ... astr.arrow_style = "{-->}" + >>> drawer.arrow_formatters["unique"] = formatter + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar@{-->}[d]_{\exists !g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + To modify the appearance of all arrows in the diagram, set + ``default_arrow_formatter``. For example, to place all morphism + labels a little bit farther from the arrow head so that they look + more centred, do as follows: + + >>> def default_formatter(astr): + ... astr.label_displacement = "(0.45)" + >>> drawer.default_arrow_formatter = default_formatter + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar@{-->}[d]_(0.45){\exists !g\circ f} \ar[r]^(0.45){f} & B \ar[ld]^(0.45){g} \\ + C & + } + + In some diagrams some morphisms are drawn as curved arrows. + Consider the following diagram: + + >>> D = Object("D") + >>> E = Object("E") + >>> h = NamedMorphism(D, A, "h") + >>> k = NamedMorphism(D, B, "k") + >>> diagram = Diagram([f, g, h, k]) + >>> grid = DiagramGrid(diagram) + >>> drawer = XypicDiagramDrawer() + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[r]_{f} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_3mm/[ll]_{h} \\ + & C & + } + + To control how far the morphisms are curved by default, one can + use the ``unit`` and ``default_curving_amount`` attributes: + + >>> drawer.unit = "cm" + >>> drawer.default_curving_amount = 1 + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[r]_{f} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_1cm/[ll]_{h} \\ + & C & + } + + In some diagrams, there are multiple curved morphisms between the + same two objects. To control by how much the curving changes + between two such successive morphisms, use + ``default_curving_step``: + + >>> drawer.default_curving_step = 1 + >>> h1 = NamedMorphism(A, D, "h1") + >>> diagram = Diagram([f, g, h, k, h1]) + >>> grid = DiagramGrid(diagram) + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[r]_{f} \ar@/^1cm/[rr]^{h_{1}} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_2cm/[ll]_{h} \\ + & C & + } + + The default value of ``default_curving_step`` is 4 units. + + See Also + ======== + + draw, ArrowStringDescription + """ + def __init__(self): + self.unit = "mm" + self.default_curving_amount = 3 + self.default_curving_step = 4 + + # This dictionary maps properties to the corresponding arrow + # formatters. + self.arrow_formatters = {} + + # This is the default arrow formatter which will be applied to + # each arrow independently of its properties. + self.default_arrow_formatter = None + + @staticmethod + def _process_loop_morphism(i, j, grid, morphisms_str_info, object_coords): + """ + Produces the information required for constructing the string + representation of a loop morphism. This function is invoked + from ``_process_morphism``. + + See Also + ======== + + _process_morphism + """ + curving = "" + label_pos = "^" + looping_start = "" + looping_end = "" + + # This is a loop morphism. Count how many morphisms stick + # in each of the four quadrants. Note that straight + # vertical and horizontal morphisms count in two quadrants + # at the same time (i.e., a morphism going up counts both + # in the first and the second quadrants). + + # The usual numbering (counterclockwise) of quadrants + # applies. + quadrant = [0, 0, 0, 0] + + obj = grid[i, j] + + for m, m_str_info in morphisms_str_info.items(): + if (m.domain == obj) and (m.codomain == obj): + # That's another loop morphism. Check how it + # loops and mark the corresponding quadrants as + # busy. + (l_s, l_e) = (m_str_info.looping_start, m_str_info.looping_end) + + if (l_s, l_e) == ("r", "u"): + quadrant[0] += 1 + elif (l_s, l_e) == ("u", "l"): + quadrant[1] += 1 + elif (l_s, l_e) == ("l", "d"): + quadrant[2] += 1 + elif (l_s, l_e) == ("d", "r"): + quadrant[3] += 1 + + continue + if m.domain == obj: + (end_i, end_j) = object_coords[m.codomain] + goes_out = True + elif m.codomain == obj: + (end_i, end_j) = object_coords[m.domain] + goes_out = False + else: + continue + + d_i = end_i - i + d_j = end_j - j + m_curving = m_str_info.curving + + if (d_i != 0) and (d_j != 0): + # This is really a diagonal morphism. Detect the + # quadrant. + if (d_i > 0) and (d_j > 0): + quadrant[0] += 1 + elif (d_i > 0) and (d_j < 0): + quadrant[1] += 1 + elif (d_i < 0) and (d_j < 0): + quadrant[2] += 1 + elif (d_i < 0) and (d_j > 0): + quadrant[3] += 1 + elif d_i == 0: + # Knowing where the other end of the morphism is + # and which way it goes, we now have to decide + # which quadrant is now the upper one and which is + # the lower one. + if d_j > 0: + if goes_out: + upper_quadrant = 0 + lower_quadrant = 3 + else: + upper_quadrant = 3 + lower_quadrant = 0 + else: + if goes_out: + upper_quadrant = 2 + lower_quadrant = 1 + else: + upper_quadrant = 1 + lower_quadrant = 2 + + if m_curving: + if m_curving == "^": + quadrant[upper_quadrant] += 1 + elif m_curving == "_": + quadrant[lower_quadrant] += 1 + else: + # This morphism counts in both upper and lower + # quadrants. + quadrant[upper_quadrant] += 1 + quadrant[lower_quadrant] += 1 + elif d_j == 0: + # Knowing where the other end of the morphism is + # and which way it goes, we now have to decide + # which quadrant is now the left one and which is + # the right one. + if d_i < 0: + if goes_out: + left_quadrant = 1 + right_quadrant = 0 + else: + left_quadrant = 0 + right_quadrant = 1 + else: + if goes_out: + left_quadrant = 3 + right_quadrant = 2 + else: + left_quadrant = 2 + right_quadrant = 3 + + if m_curving: + if m_curving == "^": + quadrant[left_quadrant] += 1 + elif m_curving == "_": + quadrant[right_quadrant] += 1 + else: + # This morphism counts in both upper and lower + # quadrants. + quadrant[left_quadrant] += 1 + quadrant[right_quadrant] += 1 + + # Pick the freest quadrant to curve our morphism into. + freest_quadrant = 0 + for i in range(4): + if quadrant[i] < quadrant[freest_quadrant]: + freest_quadrant = i + + # Now set up proper looping. + (looping_start, looping_end) = [("r", "u"), ("u", "l"), ("l", "d"), + ("d", "r")][freest_quadrant] + + return (curving, label_pos, looping_start, looping_end) + + @staticmethod + def _process_horizontal_morphism(i, j, target_j, grid, morphisms_str_info, + object_coords): + """ + Produces the information required for constructing the string + representation of a horizontal morphism. This function is + invoked from ``_process_morphism``. + + See Also + ======== + + _process_morphism + """ + # The arrow is horizontal. Check if it goes from left to + # right (``backwards == False``) or from right to left + # (``backwards == True``). + backwards = False + start = j + end = target_j + if end < start: + (start, end) = (end, start) + backwards = True + + # Let's see which objects are there between ``start`` and + # ``end``, and then count how many morphisms stick out + # upwards, and how many stick out downwards. + # + # For example, consider the situation: + # + # B1 C1 + # | | + # A--B--C--D + # | + # B2 + # + # Between the objects `A` and `D` there are two objects: + # `B` and `C`. Further, there are two morphisms which + # stick out upward (the ones between `B1` and `B` and + # between `C` and `C1`) and one morphism which sticks out + # downward (the one between `B and `B2`). + # + # We need this information to decide how to curve the + # arrow between `A` and `D`. First of all, since there + # are two objects between `A` and `D``, we must curve the + # arrow. Then, we will have it curve downward, because + # there is more space (less morphisms stick out downward + # than upward). + up = [] + down = [] + straight_horizontal = [] + for k in range(start + 1, end): + obj = grid[i, k] + if not obj: + continue + + for m in morphisms_str_info: + if m.domain == obj: + (end_i, end_j) = object_coords[m.codomain] + elif m.codomain == obj: + (end_i, end_j) = object_coords[m.domain] + else: + continue + + if end_i > i: + down.append(m) + elif end_i < i: + up.append(m) + elif not morphisms_str_info[m].curving: + # This is a straight horizontal morphism, + # because it has no curving. + straight_horizontal.append(m) + + if len(up) < len(down): + # More morphisms stick out downward than upward, let's + # curve the morphism up. + if backwards: + curving = "_" + label_pos = "_" + else: + curving = "^" + label_pos = "^" + + # Assure that the straight horizontal morphisms have + # their labels on the lower side of the arrow. + for m in straight_horizontal: + (i1, j1) = object_coords[m.domain] + (i2, j2) = object_coords[m.codomain] + + m_str_info = morphisms_str_info[m] + if j1 < j2: + m_str_info.label_position = "_" + else: + m_str_info.label_position = "^" + + # Don't allow any further modifications of the + # position of this label. + m_str_info.forced_label_position = True + else: + # More morphisms stick out downward than upward, let's + # curve the morphism up. + if backwards: + curving = "^" + label_pos = "^" + else: + curving = "_" + label_pos = "_" + + # Assure that the straight horizontal morphisms have + # their labels on the upper side of the arrow. + for m in straight_horizontal: + (i1, j1) = object_coords[m.domain] + (i2, j2) = object_coords[m.codomain] + + m_str_info = morphisms_str_info[m] + if j1 < j2: + m_str_info.label_position = "^" + else: + m_str_info.label_position = "_" + + # Don't allow any further modifications of the + # position of this label. + m_str_info.forced_label_position = True + + return (curving, label_pos) + + @staticmethod + def _process_vertical_morphism(i, j, target_i, grid, morphisms_str_info, + object_coords): + """ + Produces the information required for constructing the string + representation of a vertical morphism. This function is + invoked from ``_process_morphism``. + + See Also + ======== + + _process_morphism + """ + # This arrow is vertical. Check if it goes from top to + # bottom (``backwards == False``) or from bottom to top + # (``backwards == True``). + backwards = False + start = i + end = target_i + if end < start: + (start, end) = (end, start) + backwards = True + + # Let's see which objects are there between ``start`` and + # ``end``, and then count how many morphisms stick out to + # the left, and how many stick out to the right. + # + # See the corresponding comment in the previous branch of + # this if-statement for more details. + left = [] + right = [] + straight_vertical = [] + for k in range(start + 1, end): + obj = grid[k, j] + if not obj: + continue + + for m in morphisms_str_info: + if m.domain == obj: + (end_i, end_j) = object_coords[m.codomain] + elif m.codomain == obj: + (end_i, end_j) = object_coords[m.domain] + else: + continue + + if end_j > j: + right.append(m) + elif end_j < j: + left.append(m) + elif not morphisms_str_info[m].curving: + # This is a straight vertical morphism, + # because it has no curving. + straight_vertical.append(m) + + if len(left) < len(right): + # More morphisms stick out to the left than to the + # right, let's curve the morphism to the right. + if backwards: + curving = "^" + label_pos = "^" + else: + curving = "_" + label_pos = "_" + + # Assure that the straight vertical morphisms have + # their labels on the left side of the arrow. + for m in straight_vertical: + (i1, j1) = object_coords[m.domain] + (i2, j2) = object_coords[m.codomain] + + m_str_info = morphisms_str_info[m] + if i1 < i2: + m_str_info.label_position = "^" + else: + m_str_info.label_position = "_" + + # Don't allow any further modifications of the + # position of this label. + m_str_info.forced_label_position = True + else: + # More morphisms stick out to the right than to the + # left, let's curve the morphism to the left. + if backwards: + curving = "_" + label_pos = "_" + else: + curving = "^" + label_pos = "^" + + # Assure that the straight vertical morphisms have + # their labels on the right side of the arrow. + for m in straight_vertical: + (i1, j1) = object_coords[m.domain] + (i2, j2) = object_coords[m.codomain] + + m_str_info = morphisms_str_info[m] + if i1 < i2: + m_str_info.label_position = "_" + else: + m_str_info.label_position = "^" + + # Don't allow any further modifications of the + # position of this label. + m_str_info.forced_label_position = True + + return (curving, label_pos) + + def _process_morphism(self, diagram, grid, morphism, object_coords, + morphisms, morphisms_str_info): + """ + Given the required information, produces the string + representation of ``morphism``. + """ + def repeat_string_cond(times, str_gt, str_lt): + """ + If ``times > 0``, repeats ``str_gt`` ``times`` times. + Otherwise, repeats ``str_lt`` ``-times`` times. + """ + if times > 0: + return str_gt * times + else: + return str_lt * (-times) + + def count_morphisms_undirected(A, B): + """ + Counts how many processed morphisms there are between the + two supplied objects. + """ + return len([m for m in morphisms_str_info + if {m.domain, m.codomain} == {A, B}]) + + def count_morphisms_filtered(dom, cod, curving): + """ + Counts the processed morphisms which go out of ``dom`` + into ``cod`` with curving ``curving``. + """ + return len([m for m, m_str_info in morphisms_str_info.items() + if (m.domain, m.codomain) == (dom, cod) and + (m_str_info.curving == curving)]) + + (i, j) = object_coords[morphism.domain] + (target_i, target_j) = object_coords[morphism.codomain] + + # We now need to determine the direction of + # the arrow. + delta_i = target_i - i + delta_j = target_j - j + vertical_direction = repeat_string_cond(delta_i, + "d", "u") + horizontal_direction = repeat_string_cond(delta_j, + "r", "l") + + curving = "" + label_pos = "^" + looping_start = "" + looping_end = "" + + if (delta_i == 0) and (delta_j == 0): + # This is a loop morphism. + (curving, label_pos, looping_start, + looping_end) = XypicDiagramDrawer._process_loop_morphism( + i, j, grid, morphisms_str_info, object_coords) + elif (delta_i == 0) and (abs(j - target_j) > 1): + # This is a horizontal morphism. + (curving, label_pos) = XypicDiagramDrawer._process_horizontal_morphism( + i, j, target_j, grid, morphisms_str_info, object_coords) + elif (delta_j == 0) and (abs(i - target_i) > 1): + # This is a vertical morphism. + (curving, label_pos) = XypicDiagramDrawer._process_vertical_morphism( + i, j, target_i, grid, morphisms_str_info, object_coords) + + count = count_morphisms_undirected(morphism.domain, morphism.codomain) + curving_amount = "" + if curving: + # This morphisms should be curved anyway. + curving_amount = self.default_curving_amount + count * \ + self.default_curving_step + elif count: + # There are no objects between the domain and codomain of + # the current morphism, but this is not there already are + # some morphisms with the same domain and codomain, so we + # have to curve this one. + curving = "^" + filtered_morphisms = count_morphisms_filtered( + morphism.domain, morphism.codomain, curving) + curving_amount = self.default_curving_amount + \ + filtered_morphisms * \ + self.default_curving_step + + # Let's now get the name of the morphism. + morphism_name = "" + if isinstance(morphism, IdentityMorphism): + morphism_name = "id_{%s}" + latex(grid[i, j]) + elif isinstance(morphism, CompositeMorphism): + component_names = [latex(Symbol(component.name)) for + component in morphism.components] + component_names.reverse() + morphism_name = "\\circ ".join(component_names) + elif isinstance(morphism, NamedMorphism): + morphism_name = latex(Symbol(morphism.name)) + + return ArrowStringDescription( + self.unit, curving, curving_amount, looping_start, + looping_end, horizontal_direction, vertical_direction, + label_pos, morphism_name) + + @staticmethod + def _check_free_space_horizontal(dom_i, dom_j, cod_j, grid): + """ + For a horizontal morphism, checks whether there is free space + (i.e., space not occupied by any objects) above the morphism + or below it. + """ + if dom_j < cod_j: + (start, end) = (dom_j, cod_j) + backwards = False + else: + (start, end) = (cod_j, dom_j) + backwards = True + + # Check for free space above. + if dom_i == 0: + free_up = True + else: + free_up = all(grid[dom_i - 1, j] for j in + range(start, end + 1)) + + # Check for free space below. + if dom_i == grid.height - 1: + free_down = True + else: + free_down = not any(grid[dom_i + 1, j] for j in + range(start, end + 1)) + + return (free_up, free_down, backwards) + + @staticmethod + def _check_free_space_vertical(dom_i, cod_i, dom_j, grid): + """ + For a vertical morphism, checks whether there is free space + (i.e., space not occupied by any objects) to the left of the + morphism or to the right of it. + """ + if dom_i < cod_i: + (start, end) = (dom_i, cod_i) + backwards = False + else: + (start, end) = (cod_i, dom_i) + backwards = True + + # Check if there's space to the left. + if dom_j == 0: + free_left = True + else: + free_left = not any(grid[i, dom_j - 1] for i in + range(start, end + 1)) + + if dom_j == grid.width - 1: + free_right = True + else: + free_right = not any(grid[i, dom_j + 1] for i in + range(start, end + 1)) + + return (free_left, free_right, backwards) + + @staticmethod + def _check_free_space_diagonal(dom_i, cod_i, dom_j, cod_j, grid): + """ + For a diagonal morphism, checks whether there is free space + (i.e., space not occupied by any objects) above the morphism + or below it. + """ + def abs_xrange(start, end): + if start < end: + return range(start, end + 1) + else: + return range(end, start + 1) + + if dom_i < cod_i and dom_j < cod_j: + # This morphism goes from top-left to + # bottom-right. + (start_i, start_j) = (dom_i, dom_j) + (end_i, end_j) = (cod_i, cod_j) + backwards = False + elif dom_i > cod_i and dom_j > cod_j: + # This morphism goes from bottom-right to + # top-left. + (start_i, start_j) = (cod_i, cod_j) + (end_i, end_j) = (dom_i, dom_j) + backwards = True + if dom_i < cod_i and dom_j > cod_j: + # This morphism goes from top-right to + # bottom-left. + (start_i, start_j) = (dom_i, dom_j) + (end_i, end_j) = (cod_i, cod_j) + backwards = True + elif dom_i > cod_i and dom_j < cod_j: + # This morphism goes from bottom-left to + # top-right. + (start_i, start_j) = (cod_i, cod_j) + (end_i, end_j) = (dom_i, dom_j) + backwards = False + + # This is an attempt at a fast and furious strategy to + # decide where there is free space on the two sides of + # a diagonal morphism. For a diagonal morphism + # starting at ``(start_i, start_j)`` and ending at + # ``(end_i, end_j)`` the rectangle defined by these + # two points is considered. The slope of the diagonal + # ``alpha`` is then computed. Then, for every cell + # ``(i, j)`` within the rectangle, the slope + # ``alpha1`` of the line through ``(start_i, + # start_j)`` and ``(i, j)`` is considered. If + # ``alpha1`` is between 0 and ``alpha``, the point + # ``(i, j)`` is above the diagonal, if ``alpha1`` is + # between ``alpha`` and infinity, the point is below + # the diagonal. Also note that, with some beforehand + # precautions, this trick works for both the main and + # the secondary diagonals of the rectangle. + + # I have considered the possibility to only follow the + # shorter diagonals immediately above and below the + # main (or secondary) diagonal. This, however, + # wouldn't have resulted in much performance gain or + # better detection of outer edges, because of + # relatively small sizes of diagram grids, while the + # code would have become harder to understand. + + alpha = float(end_i - start_i)/(end_j - start_j) + free_up = True + free_down = True + for i in abs_xrange(start_i, end_i): + if not free_up and not free_down: + break + + for j in abs_xrange(start_j, end_j): + if not free_up and not free_down: + break + + if (i, j) == (start_i, start_j): + continue + + if j == start_j: + alpha1 = "inf" + else: + alpha1 = float(i - start_i)/(j - start_j) + + if grid[i, j]: + if (alpha1 == "inf") or (abs(alpha1) > abs(alpha)): + free_down = False + elif abs(alpha1) < abs(alpha): + free_up = False + + return (free_up, free_down, backwards) + + def _push_labels_out(self, morphisms_str_info, grid, object_coords): + """ + For all straight morphisms which form the visual boundary of + the laid out diagram, puts their labels on their outer sides. + """ + def set_label_position(free1, free2, pos1, pos2, backwards, m_str_info): + """ + Given the information about room available to one side and + to the other side of a morphism (``free1`` and ``free2``), + sets the position of the morphism label in such a way that + it is on the freer side. This latter operations involves + choice between ``pos1`` and ``pos2``, taking ``backwards`` + in consideration. + + Thus this function will do nothing if either both ``free1 + == True`` and ``free2 == True`` or both ``free1 == False`` + and ``free2 == False``. In either case, choosing one side + over the other presents no advantage. + """ + if backwards: + (pos1, pos2) = (pos2, pos1) + + if free1 and not free2: + m_str_info.label_position = pos1 + elif free2 and not free1: + m_str_info.label_position = pos2 + + for m, m_str_info in morphisms_str_info.items(): + if m_str_info.curving or m_str_info.forced_label_position: + # This is either a curved morphism, and curved + # morphisms have other magic, or the position of this + # label has already been fixed. + continue + + if m.domain == m.codomain: + # This is a loop morphism, their labels, again have a + # different magic. + continue + + (dom_i, dom_j) = object_coords[m.domain] + (cod_i, cod_j) = object_coords[m.codomain] + + if dom_i == cod_i: + # Horizontal morphism. + (free_up, free_down, + backwards) = XypicDiagramDrawer._check_free_space_horizontal( + dom_i, dom_j, cod_j, grid) + + set_label_position(free_up, free_down, "^", "_", + backwards, m_str_info) + elif dom_j == cod_j: + # Vertical morphism. + (free_left, free_right, + backwards) = XypicDiagramDrawer._check_free_space_vertical( + dom_i, cod_i, dom_j, grid) + + set_label_position(free_left, free_right, "_", "^", + backwards, m_str_info) + else: + # A diagonal morphism. + (free_up, free_down, + backwards) = XypicDiagramDrawer._check_free_space_diagonal( + dom_i, cod_i, dom_j, cod_j, grid) + + set_label_position(free_up, free_down, "^", "_", + backwards, m_str_info) + + @staticmethod + def _morphism_sort_key(morphism, object_coords): + """ + Provides a morphism sorting key such that horizontal or + vertical morphisms between neighbouring objects come + first, then horizontal or vertical morphisms between more + far away objects, and finally, all other morphisms. + """ + (i, j) = object_coords[morphism.domain] + (target_i, target_j) = object_coords[morphism.codomain] + + if morphism.domain == morphism.codomain: + # Loop morphisms should get after diagonal morphisms + # so that the proper direction in which to curve the + # loop can be determined. + return (3, 0, default_sort_key(morphism)) + + if target_i == i: + return (1, abs(target_j - j), default_sort_key(morphism)) + + if target_j == j: + return (1, abs(target_i - i), default_sort_key(morphism)) + + # Diagonal morphism. + return (2, 0, default_sort_key(morphism)) + + @staticmethod + def _build_xypic_string(diagram, grid, morphisms, + morphisms_str_info, diagram_format): + """ + Given a collection of :class:`ArrowStringDescription` + describing the morphisms of a diagram and the object layout + information of a diagram, produces the final Xy-pic picture. + """ + # Build the mapping between objects and morphisms which have + # them as domains. + object_morphisms = {} + for obj in diagram.objects: + object_morphisms[obj] = [] + for morphism in morphisms: + object_morphisms[morphism.domain].append(morphism) + + result = "\\xymatrix%s{\n" % diagram_format + + for i in range(grid.height): + for j in range(grid.width): + obj = grid[i, j] + if obj: + result += latex(obj) + " " + + morphisms_to_draw = object_morphisms[obj] + for morphism in morphisms_to_draw: + result += str(morphisms_str_info[morphism]) + " " + + # Don't put the & after the last column. + if j < grid.width - 1: + result += "& " + + # Don't put the line break after the last row. + if i < grid.height - 1: + result += "\\\\" + result += "\n" + + result += "}\n" + + return result + + def draw(self, diagram, grid, masked=None, diagram_format=""): + r""" + Returns the Xy-pic representation of ``diagram`` laid out in + ``grid``. + + Consider the following simple triangle diagram. + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy.categories import DiagramGrid, XypicDiagramDrawer + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g], {g * f: "unique"}) + + To draw this diagram, its objects need to be laid out with a + :class:`DiagramGrid`:: + + >>> grid = DiagramGrid(diagram) + + Finally, the drawing: + + >>> drawer = XypicDiagramDrawer() + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + The argument ``masked`` can be used to skip morphisms in the + presentation of the diagram: + + >>> print(drawer.draw(diagram, grid, masked=[g * f])) + \xymatrix{ + A \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + Finally, the ``diagram_format`` argument can be used to + specify the format string of the diagram. For example, to + increase the spacing by 1 cm, proceeding as follows: + + >>> print(drawer.draw(diagram, grid, diagram_format="@+1cm")) + \xymatrix@+1cm{ + A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + """ + # This method works in several steps. It starts by removing + # the masked morphisms, if necessary, and then maps objects to + # their positions in the grid (coordinate tuples). Remember + # that objects are unique in ``Diagram`` and in the layout + # produced by ``DiagramGrid``, so every object is mapped to a + # single coordinate pair. + # + # The next step is the central step and is concerned with + # analysing the morphisms of the diagram and deciding how to + # draw them. For example, how to curve the arrows is decided + # at this step. The bulk of the analysis is implemented in + # ``_process_morphism``, to the result of which the + # appropriate formatters are applied. + # + # The result of the previous step is a list of + # ``ArrowStringDescription``. After the analysis and + # application of formatters, some extra logic tries to assure + # better positioning of morphism labels (for example, an + # attempt is made to avoid the situations when arrows cross + # labels). This functionality constitutes the next step and + # is implemented in ``_push_labels_out``. Note that label + # positions which have been set via a formatter are not + # affected in this step. + # + # Finally, at the closing step, the array of + # ``ArrowStringDescription`` and the layout information + # incorporated in ``DiagramGrid`` are combined to produce the + # resulting Xy-pic picture. This part of code lies in + # ``_build_xypic_string``. + + if not masked: + morphisms_props = grid.morphisms + else: + morphisms_props = {} + for m, props in grid.morphisms.items(): + if m in masked: + continue + morphisms_props[m] = props + + # Build the mapping between objects and their position in the + # grid. + object_coords = {} + for i in range(grid.height): + for j in range(grid.width): + if grid[i, j]: + object_coords[grid[i, j]] = (i, j) + + morphisms = sorted(morphisms_props, + key=lambda m: XypicDiagramDrawer._morphism_sort_key( + m, object_coords)) + + # Build the tuples defining the string representations of + # morphisms. + morphisms_str_info = {} + for morphism in morphisms: + string_description = self._process_morphism( + diagram, grid, morphism, object_coords, morphisms, + morphisms_str_info) + + if self.default_arrow_formatter: + self.default_arrow_formatter(string_description) + + for prop in morphisms_props[morphism]: + # prop is a Symbol. TODO: Find out why. + if prop.name in self.arrow_formatters: + formatter = self.arrow_formatters[prop.name] + formatter(string_description) + + morphisms_str_info[morphism] = string_description + + # Reposition the labels a bit. + self._push_labels_out(morphisms_str_info, grid, object_coords) + + return XypicDiagramDrawer._build_xypic_string( + diagram, grid, morphisms, morphisms_str_info, diagram_format) + + +def xypic_draw_diagram(diagram, masked=None, diagram_format="", + groups=None, **hints): + r""" + Provides a shortcut combining :class:`DiagramGrid` and + :class:`XypicDiagramDrawer`. Returns an Xy-pic presentation of + ``diagram``. The argument ``masked`` is a list of morphisms which + will be not be drawn. The argument ``diagram_format`` is the + format string inserted after "\xymatrix". ``groups`` should be a + set of logical groups. The ``hints`` will be passed directly to + the constructor of :class:`DiagramGrid`. + + For more information about the arguments, see the docstrings of + :class:`DiagramGrid` and ``XypicDiagramDrawer.draw``. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy.categories import xypic_draw_diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g], {g * f: "unique"}) + >>> print(xypic_draw_diagram(diagram)) + \xymatrix{ + A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + See Also + ======== + + XypicDiagramDrawer, DiagramGrid + """ + grid = DiagramGrid(diagram, groups, **hints) + drawer = XypicDiagramDrawer() + return drawer.draw(diagram, grid, masked, diagram_format) + + +@doctest_depends_on(exe=('latex', 'dvipng'), modules=('pyglet',)) +def preview_diagram(diagram, masked=None, diagram_format="", groups=None, + output='png', viewer=None, euler=True, **hints): + """ + Combines the functionality of ``xypic_draw_diagram`` and + ``sympy.printing.preview``. The arguments ``masked``, + ``diagram_format``, ``groups``, and ``hints`` are passed to + ``xypic_draw_diagram``, while ``output``, ``viewer, and ``euler`` + are passed to ``preview``. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy.categories import preview_diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> preview_diagram(d) + + See Also + ======== + + XypicDiagramDrawer + """ + from sympy.printing import preview + latex_output = xypic_draw_diagram(diagram, masked, diagram_format, + groups, **hints) + preview(latex_output, output, viewer, euler, ("xypic",)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24dc8e42e06e9417f423e8649bf9afc908047826 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/test_baseclasses.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/test_baseclasses.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b49d7d926edf43ac8d2cd217532ac7f4b2cecb4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/test_baseclasses.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/test_drawing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/test_drawing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eaa52ff46b6b0915a7286183318c1a2ce5abb12 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/__pycache__/test_drawing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/test_baseclasses.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/test_baseclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..cfac32229768fb5903b23b11ffb236912c0b931e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/test_baseclasses.py @@ -0,0 +1,209 @@ +from sympy.categories import (Object, Morphism, IdentityMorphism, + NamedMorphism, CompositeMorphism, + Diagram, Category) +from sympy.categories.baseclasses import Class +from sympy.testing.pytest import raises +from sympy.core.containers import (Dict, Tuple) +from sympy.sets import EmptySet +from sympy.sets.sets import FiniteSet + + +def test_morphisms(): + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + + # Test the base morphism. + f = NamedMorphism(A, B, "f") + assert f.domain == A + assert f.codomain == B + assert f == NamedMorphism(A, B, "f") + + # Test identities. + id_A = IdentityMorphism(A) + id_B = IdentityMorphism(B) + assert id_A.domain == A + assert id_A.codomain == A + assert id_A == IdentityMorphism(A) + assert id_A != id_B + + # Test named morphisms. + g = NamedMorphism(B, C, "g") + assert g.name == "g" + assert g != f + assert g == NamedMorphism(B, C, "g") + assert g != NamedMorphism(B, C, "f") + + # Test composite morphisms. + assert f == CompositeMorphism(f) + + k = g.compose(f) + assert k.domain == A + assert k.codomain == C + assert k.components == Tuple(f, g) + assert g * f == k + assert CompositeMorphism(f, g) == k + + assert CompositeMorphism(g * f) == g * f + + # Test the associativity of composition. + h = NamedMorphism(C, D, "h") + + p = h * g + u = h * g * f + + assert h * k == u + assert p * f == u + assert CompositeMorphism(f, g, h) == u + + # Test flattening. + u2 = u.flatten("u") + assert isinstance(u2, NamedMorphism) + assert u2.name == "u" + assert u2.domain == A + assert u2.codomain == D + + # Test identities. + assert f * id_A == f + assert id_B * f == f + assert id_A * id_A == id_A + assert CompositeMorphism(id_A) == id_A + + # Test bad compositions. + raises(ValueError, lambda: f * g) + + raises(TypeError, lambda: f.compose(None)) + raises(TypeError, lambda: id_A.compose(None)) + raises(TypeError, lambda: f * None) + raises(TypeError, lambda: id_A * None) + + raises(TypeError, lambda: CompositeMorphism(f, None, 1)) + + raises(ValueError, lambda: NamedMorphism(A, B, "")) + raises(NotImplementedError, lambda: Morphism(A, B)) + + +def test_diagram(): + A = Object("A") + B = Object("B") + C = Object("C") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + id_A = IdentityMorphism(A) + id_B = IdentityMorphism(B) + + empty = EmptySet + + # Test the addition of identities. + d1 = Diagram([f]) + + assert d1.objects == FiniteSet(A, B) + assert d1.hom(A, B) == (FiniteSet(f), empty) + assert d1.hom(A, A) == (FiniteSet(id_A), empty) + assert d1.hom(B, B) == (FiniteSet(id_B), empty) + + assert d1 == Diagram([id_A, f]) + assert d1 == Diagram([f, f]) + + # Test the addition of composites. + d2 = Diagram([f, g]) + homAC = d2.hom(A, C)[0] + + assert d2.objects == FiniteSet(A, B, C) + assert g * f in d2.premises.keys() + assert homAC == FiniteSet(g * f) + + # Test equality, inequality and hash. + d11 = Diagram([f]) + + assert d1 == d11 + assert d1 != d2 + assert hash(d1) == hash(d11) + + d11 = Diagram({f: "unique"}) + assert d1 != d11 + + # Make sure that (re-)adding composites (with new properties) + # works as expected. + d = Diagram([f, g], {g * f: "unique"}) + assert d.conclusions == Dict({g * f: FiniteSet("unique")}) + + # Check the hom-sets when there are premises and conclusions. + assert d.hom(A, C) == (FiniteSet(g * f), FiniteSet(g * f)) + d = Diagram([f, g], [g * f]) + assert d.hom(A, C) == (FiniteSet(g * f), FiniteSet(g * f)) + + # Check how the properties of composite morphisms are computed. + d = Diagram({f: ["unique", "isomorphism"], g: "unique"}) + assert d.premises[g * f] == FiniteSet("unique") + + # Check that conclusion morphisms with new objects are not allowed. + d = Diagram([f], [g]) + assert d.conclusions == Dict({}) + + # Test an empty diagram. + d = Diagram() + assert d.premises == Dict({}) + assert d.conclusions == Dict({}) + assert d.objects == empty + + # Check a SymPy Dict object. + d = Diagram(Dict({f: FiniteSet("unique", "isomorphism"), g: "unique"})) + assert d.premises[g * f] == FiniteSet("unique") + + # Check the addition of components of composite morphisms. + d = Diagram([g * f]) + assert f in d.premises + assert g in d.premises + + # Check subdiagrams. + d = Diagram([f, g], {g * f: "unique"}) + + d1 = Diagram([f]) + assert d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram([NamedMorphism(B, A, "f'")]) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d1 = Diagram([f, g], {g * f: ["unique", "something"]}) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram({f: "blooh"}) + d1 = Diagram({f: "bleeh"}) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram([f, g], {f: "unique", g * f: "veryunique"}) + d1 = d.subdiagram_from_objects(FiniteSet(A, B)) + assert d1 == Diagram([f], {f: "unique"}) + raises(ValueError, lambda: d.subdiagram_from_objects(FiniteSet(A, + Object("D")))) + + raises(ValueError, lambda: Diagram({IdentityMorphism(A): "unique"})) + + +def test_category(): + A = Object("A") + B = Object("B") + C = Object("C") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + + d1 = Diagram([f, g]) + d2 = Diagram([f]) + + objects = d1.objects | d2.objects + + K = Category("K", objects, commutative_diagrams=[d1, d2]) + + assert K.name == "K" + assert K.objects == Class(objects) + assert K.commutative_diagrams == FiniteSet(d1, d2) + + raises(ValueError, lambda: Category("")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/test_drawing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/test_drawing.py new file mode 100644 index 0000000000000000000000000000000000000000..63a13266cd6b58f6a85aad4af0813b395acbb5e1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/categories/tests/test_drawing.py @@ -0,0 +1,919 @@ +from sympy.categories.diagram_drawing import _GrowableGrid, ArrowStringDescription +from sympy.categories import (DiagramGrid, Object, NamedMorphism, + Diagram, XypicDiagramDrawer, xypic_draw_diagram) +from sympy.sets.sets import FiniteSet + + +def test_GrowableGrid(): + grid = _GrowableGrid(1, 2) + + # Check dimensions. + assert grid.width == 1 + assert grid.height == 2 + + # Check initialization of elements. + assert grid[0, 0] is None + assert grid[1, 0] is None + + # Check assignment to elements. + grid[0, 0] = 1 + grid[1, 0] = "two" + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + + # Check appending a row. + grid.append_row() + + assert grid.width == 1 + assert grid.height == 3 + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + assert grid[2, 0] is None + + # Check appending a column. + grid.append_column() + assert grid.width == 2 + assert grid.height == 3 + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + assert grid[2, 0] is None + + assert grid[0, 1] is None + assert grid[1, 1] is None + assert grid[2, 1] is None + + grid = _GrowableGrid(1, 2) + grid[0, 0] = 1 + grid[1, 0] = "two" + + # Check prepending a row. + grid.prepend_row() + assert grid.width == 1 + assert grid.height == 3 + + assert grid[0, 0] is None + assert grid[1, 0] == 1 + assert grid[2, 0] == "two" + + # Check prepending a column. + grid.prepend_column() + assert grid.width == 2 + assert grid.height == 3 + + assert grid[0, 0] is None + assert grid[1, 0] is None + assert grid[2, 0] is None + + assert grid[0, 1] is None + assert grid[1, 1] == 1 + assert grid[2, 1] == "two" + + +def test_DiagramGrid(): + # Set up some objects and morphisms. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(D, A, "h") + k = NamedMorphism(D, B, "k") + + # A one-morphism diagram. + d = Diagram([f]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid.morphisms == {f: FiniteSet()} + + # A triangle. + d = Diagram([f, g], {g * f: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] == C + assert grid[1, 1] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), + g * f: FiniteSet("unique")} + + # A triangle with a "loop" morphism. + l_A = NamedMorphism(A, A, "l_A") + d = Diagram([f, g, l_A]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), l_A: FiniteSet()} + + # A simple diagram. + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == D + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + assert str(grid) == '[[Object("A"), Object("B"), Object("D")], ' \ + '[None, Object("C"), None]]' + + # A chain of morphisms. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + k = NamedMorphism(D, E, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + # A square. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, D, "g") + h = NamedMorphism(A, C, "h") + k = NamedMorphism(C, D, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] == C + assert grid[1, 1] == D + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + # A strange diagram which resulted from a typo when creating a + # test for five lemma, but which allowed to stop one extra problem + # in the algorithm. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + + # These 4 morphisms should be between primed objects. + j = NamedMorphism(A, B, "j") + k = NamedMorphism(B, C, "k") + l = NamedMorphism(C, D, "l") + m = NamedMorphism(D, E, "m") + + o = NamedMorphism(A, A_, "o") + p = NamedMorphism(B, B_, "p") + q = NamedMorphism(C, C_, "q") + r = NamedMorphism(D, D_, "r") + s = NamedMorphism(E, E_, "s") + + d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 4 + assert grid[0, 0] is None + assert grid[0, 1] == A + assert grid[0, 2] == A_ + assert grid[1, 0] == C + assert grid[1, 1] == B + assert grid[1, 2] == B_ + assert grid[2, 0] == C_ + assert grid[2, 1] == D + assert grid[2, 2] == D_ + assert grid[3, 0] is None + assert grid[3, 1] == E + assert grid[3, 2] == E_ + + morphisms = {} + for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # A cube. + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + A4 = Object("A4") + A5 = Object("A5") + A6 = Object("A6") + A7 = Object("A7") + A8 = Object("A8") + + # The top face of the cube. + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A1, A3, "f2") + f3 = NamedMorphism(A2, A4, "f3") + f4 = NamedMorphism(A3, A4, "f3") + + # The bottom face of the cube. + f5 = NamedMorphism(A5, A6, "f5") + f6 = NamedMorphism(A5, A7, "f6") + f7 = NamedMorphism(A6, A8, "f7") + f8 = NamedMorphism(A7, A8, "f8") + + # The remaining morphisms. + f9 = NamedMorphism(A1, A5, "f9") + f10 = NamedMorphism(A2, A6, "f10") + f11 = NamedMorphism(A3, A7, "f11") + f12 = NamedMorphism(A4, A8, "f11") + + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 3 + assert grid[0, 0] is None + assert grid[0, 1] == A5 + assert grid[0, 2] == A6 + assert grid[0, 3] is None + assert grid[1, 0] is None + assert grid[1, 1] == A1 + assert grid[1, 2] == A2 + assert grid[1, 3] is None + assert grid[2, 0] == A7 + assert grid[2, 1] == A3 + assert grid[2, 2] == A4 + assert grid[2, 3] == A8 + + morphisms = {} + for m in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # A line diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + grid = DiagramGrid(d, layout="sequential") + + assert grid.width == 5 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid[0, 4] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + i: FiniteSet()} + + # Test the transposed version. + grid = DiagramGrid(d, layout="sequential", transpose=True) + + assert grid.width == 1 + assert grid.height == 5 + assert grid[0, 0] == A + assert grid[1, 0] == B + assert grid[2, 0] == C + assert grid[3, 0] == D + assert grid[4, 0] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + i: FiniteSet()} + + # A pullback. + m1 = NamedMorphism(A, B, "m1") + m2 = NamedMorphism(A, C, "m2") + s1 = NamedMorphism(B, D, "s1") + s2 = NamedMorphism(C, D, "s2") + f1 = NamedMorphism(E, B, "f1") + f2 = NamedMorphism(E, C, "f2") + g = NamedMorphism(E, A, "g") + + d = Diagram([m1, m2, s1, s2, f1, f2], {g: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == E + assert grid[1, 0] == C + assert grid[1, 1] == D + assert grid[1, 2] is None + + morphisms = {g: FiniteSet("unique")} + for m in [m1, m2, s1, s2, f1, f2]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # Test the pullback with sequential layout, just for stress + # testing. + grid = DiagramGrid(d, layout="sequential") + + assert grid.width == 5 + assert grid.height == 1 + assert grid[0, 0] == D + assert grid[0, 1] == B + assert grid[0, 2] == A + assert grid[0, 3] == C + assert grid[0, 4] == E + assert grid.morphisms == morphisms + + # Test a pullback with object grouping. + grid = DiagramGrid(d, groups=FiniteSet(E, FiniteSet(A, B, C, D))) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == E + assert grid[0, 1] == A + assert grid[0, 2] == B + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid.morphisms == morphisms + + # Five lemma, actually. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + + j = NamedMorphism(A_, B_, "j") + k = NamedMorphism(B_, C_, "k") + l = NamedMorphism(C_, D_, "l") + m = NamedMorphism(D_, E_, "m") + + o = NamedMorphism(A, A_, "o") + p = NamedMorphism(B, B_, "p") + q = NamedMorphism(C, C_, "q") + r = NamedMorphism(D, D_, "r") + s = NamedMorphism(E, E_, "s") + + d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s]) + grid = DiagramGrid(d) + + assert grid.width == 5 + assert grid.height == 3 + assert grid[0, 0] is None + assert grid[0, 1] == A + assert grid[0, 2] == A_ + assert grid[0, 3] is None + assert grid[0, 4] is None + assert grid[1, 0] == C + assert grid[1, 1] == B + assert grid[1, 2] == B_ + assert grid[1, 3] == C_ + assert grid[1, 4] is None + assert grid[2, 0] == D + assert grid[2, 1] == E + assert grid[2, 2] is None + assert grid[2, 3] == D_ + assert grid[2, 4] == E_ + + morphisms = {} + for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping. + grid = DiagramGrid(d, FiniteSet( + FiniteSet(A, B, C, D, E), FiniteSet(A_, B_, C_, D_, E_))) + + assert grid.width == 6 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[0, 3] == A_ + assert grid[0, 4] == B_ + assert grid[0, 5] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[1, 3] is None + assert grid[1, 4] == C_ + assert grid[1, 5] == D_ + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid[2, 3] is None + assert grid[2, 4] is None + assert grid[2, 5] == E_ + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping, but mixing containers + # to represent groups. + grid = DiagramGrid(d, [(A, B, C, D, E), {A_, B_, C_, D_, E_}]) + + assert grid.width == 6 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[0, 3] == A_ + assert grid[0, 4] == B_ + assert grid[0, 5] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[1, 3] is None + assert grid[1, 4] == C_ + assert grid[1, 5] == D_ + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid[2, 3] is None + assert grid[2, 4] is None + assert grid[2, 5] == E_ + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping and hints. + grid = DiagramGrid(d, { + FiniteSet(A, B, C, D, E): {"layout": "sequential", + "transpose": True}, + FiniteSet(A_, B_, C_, D_, E_): {"layout": "sequential", + "transpose": True}}, + transpose=True) + + assert grid.width == 5 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid[0, 4] == E + assert grid[1, 0] == A_ + assert grid[1, 1] == B_ + assert grid[1, 2] == C_ + assert grid[1, 3] == D_ + assert grid[1, 4] == E_ + assert grid.morphisms == morphisms + + # A two-triangle disconnected diagram. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + f_ = NamedMorphism(A_, B_, "f") + g_ = NamedMorphism(B_, C_, "g") + d = Diagram([f, g, f_, g_], {g * f: "unique", g_ * f_: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == A_ + assert grid[0, 3] == B_ + assert grid[1, 0] == C + assert grid[1, 1] is None + assert grid[1, 2] == C_ + assert grid[1, 3] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), f_: FiniteSet(), + g_: FiniteSet(), g * f: FiniteSet("unique"), + g_ * f_: FiniteSet("unique")} + + # A two-morphism disconnected diagram. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(C, D, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet()} + + # Test a one-object diagram. + f = NamedMorphism(A, A, "f") + d = Diagram([f]) + grid = DiagramGrid(d) + + assert grid.width == 1 + assert grid.height == 1 + assert grid[0, 0] == A + + # Test a two-object disconnected diagram. + g = NamedMorphism(B, B, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + + +def test_DiagramGrid_pseudopod(): + # Test a diagram in which even growing a pseudopod does not + # eventually help. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + F = Object("F") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f1 = NamedMorphism(A, B, "f1") + f2 = NamedMorphism(A, C, "f2") + f3 = NamedMorphism(A, D, "f3") + f4 = NamedMorphism(A, E, "f4") + f5 = NamedMorphism(A, A_, "f5") + f6 = NamedMorphism(A, B_, "f6") + f7 = NamedMorphism(A, C_, "f7") + f8 = NamedMorphism(A, D_, "f8") + f9 = NamedMorphism(A, E_, "f9") + f10 = NamedMorphism(A, F, "f10") + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]) + grid = DiagramGrid(d) + + assert grid.width == 5 + assert grid.height == 3 + assert grid[0, 0] == E + assert grid[0, 1] == C + assert grid[0, 2] == C_ + assert grid[0, 3] == E_ + assert grid[0, 4] == F + assert grid[1, 0] == D + assert grid[1, 1] == A + assert grid[1, 2] == A_ + assert grid[1, 3] is None + assert grid[1, 4] is None + assert grid[2, 0] == D_ + assert grid[2, 1] == B + assert grid[2, 2] == B_ + assert grid[2, 3] is None + assert grid[2, 4] is None + + morphisms = {} + for f in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]: + morphisms[f] = FiniteSet() + assert grid.morphisms == morphisms + + +def test_ArrowStringDescription(): + astr = ArrowStringDescription("cm", "", None, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "^", 12, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar@/^12cm/[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + assert str(astr) == "\\ar@(r,u)[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + assert str(astr) == "\\ar@(r,u)[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + astr.arrow_style = "{-->}" + assert str(astr) == "\\ar@(r,u)@{-->}[dr]_{f}" + + astr = ArrowStringDescription("cm", "_", 12, "", "", "d", "r", "_", "f") + astr.arrow_style = "{-->}" + assert str(astr) == "\\ar@/_12cm/@{-->}[dr]_{f}" + + +def test_XypicDiagramDrawer_line(): + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + grid = DiagramGrid(d, layout="sequential") + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]^{f} & B \\ar[r]^{g} & C \\ar[r]^{h} & D \\ar[r]^{i} & E \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, layout="sequential", transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\\\\n" \ + "B \\ar[d]^{g} \\\\\n" \ + "C \\ar[d]^{h} \\\\\n" \ + "D \\ar[d]^{i} \\\\\n" \ + "E \n" \ + "}\n" + + +def test_XypicDiagramDrawer_triangle(): + # A triangle diagram. + A = Object("A") + B = Object("B") + C = Object("C") + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + + d = Diagram([f, g], {g * f: "unique"}) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]_{g\\circ f} \\ar[r]^{f} & B \\ar[ld]^{g} \\\\\n" \ + "C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B \\ar[ru]_{g} & \n" \ + "}\n" + + # The same diagram, with a masked morphism. + assert drawer.draw(d, grid, masked=[g]) == "\\xymatrix{\n" \ + "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B & \n" \ + "}\n" + + # The same diagram with a formatter for "unique". + def formatter(astr): + astr.label = "\\exists !" + astr.label + astr.arrow_style = "{-->}" + + drawer.arrow_formatters["unique"] = formatter + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar@{-->}[r]^{\\exists !g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B \\ar[ru]_{g} & \n" \ + "}\n" + + # The same diagram with a default formatter. + def default_formatter(astr): + astr.label_displacement = "(0.45)" + + drawer.default_arrow_formatter = default_formatter + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar@{-->}[r]^(0.45){\\exists !g\\circ f} \\ar[d]_(0.45){f} & C \\\\\n" \ + "B \\ar[ru]_(0.45){g} & \n" \ + "}\n" + + # A triangle diagram with a lot of morphisms between the same + # objects. + f1 = NamedMorphism(B, A, "f1") + f2 = NamedMorphism(A, B, "f2") + g1 = NamedMorphism(C, B, "g1") + g2 = NamedMorphism(B, C, "g2") + d = Diagram([f, f1, f2, g, g1, g2], {f1 * g1: "unique", g2 * f2: "unique"}) + + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid, masked=[f1*g1*g2*f2, g2*f2*f1*g1]) == \ + "\\xymatrix{\n" \ + "A \\ar[r]^{g_{2}\\circ f_{2}} \\ar[d]_{f} \\ar@/^3mm/[d]^{f_{2}} " \ + "& C \\ar@/^3mm/[l]^{f_{1}\\circ g_{1}} \\ar@/^3mm/[ld]^{g_{1}} \\\\\n" \ + "B \\ar@/^3mm/[u]^{f_{1}} \\ar[ru]_{g} \\ar@/^3mm/[ru]^{g_{2}} & \n" \ + "}\n" + + +def test_XypicDiagramDrawer_cube(): + # A cube diagram. + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + A4 = Object("A4") + A5 = Object("A5") + A6 = Object("A6") + A7 = Object("A7") + A8 = Object("A8") + + # The top face of the cube. + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A1, A3, "f2") + f3 = NamedMorphism(A2, A4, "f3") + f4 = NamedMorphism(A3, A4, "f3") + + # The bottom face of the cube. + f5 = NamedMorphism(A5, A6, "f5") + f6 = NamedMorphism(A5, A7, "f6") + f7 = NamedMorphism(A6, A8, "f7") + f8 = NamedMorphism(A7, A8, "f8") + + # The remaining morphisms. + f9 = NamedMorphism(A1, A5, "f9") + f10 = NamedMorphism(A2, A6, "f10") + f11 = NamedMorphism(A3, A7, "f11") + f12 = NamedMorphism(A4, A8, "f11") + + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "& A_{5} \\ar[r]^{f_{5}} \\ar[ldd]_{f_{6}} & A_{6} \\ar[rdd]^{f_{7}} " \ + "& \\\\\n" \ + "& A_{1} \\ar[r]^{f_{1}} \\ar[d]^{f_{2}} \\ar[u]^{f_{9}} & A_{2} " \ + "\\ar[d]^{f_{3}} \\ar[u]_{f_{10}} & \\\\\n" \ + "A_{7} \\ar@/_3mm/[rrr]_{f_{8}} & A_{3} \\ar[r]^{f_{3}} \\ar[l]_{f_{11}} " \ + "& A_{4} \\ar[r]^{f_{11}} & A_{8} \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "& & A_{7} \\ar@/^3mm/[ddd]^{f_{8}} \\\\\n" \ + "A_{5} \\ar[d]_{f_{5}} \\ar[rru]^{f_{6}} & A_{1} \\ar[d]^{f_{1}} " \ + "\\ar[r]^{f_{2}} \\ar[l]^{f_{9}} & A_{3} \\ar[d]_{f_{3}} " \ + "\\ar[u]^{f_{11}} \\\\\n" \ + "A_{6} \\ar[rrd]_{f_{7}} & A_{2} \\ar[r]^{f_{3}} \\ar[l]^{f_{10}} " \ + "& A_{4} \\ar[d]_{f_{11}} \\\\\n" \ + "& & A_{8} \n" \ + "}\n" + + +def test_XypicDiagramDrawer_curved_and_loops(): + # A simple diagram, with a curved arrow. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(D, A, "h") + k = NamedMorphism(D, B, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} & B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_3mm/[ll]_{h} \\\\\n" \ + "& C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \ + "}\n" + + # The same diagram, larger and rotated. + assert drawer.draw(d, grid, diagram_format="@+1cm@dr") == \ + "\\xymatrix@+1cm@dr{\n" \ + "A \\ar[d]^{f} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \ + "}\n" + + # A simple diagram with three curved arrows. + h1 = NamedMorphism(D, A, "h1") + h2 = NamedMorphism(A, D, "h2") + k = NamedMorphism(D, B, "k") + d = Diagram([f, g, h, k, h1, h2]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \ + "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\\\\n" \ + "& C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} & \n" \ + "}\n" + + # The same diagram, with "loop" morphisms. + l_A = NamedMorphism(A, A, "l_A") + l_D = NamedMorphism(D, D, "l_D") + l_C = NamedMorphism(C, C, "l_C") + d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \ + "& B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_7mm/[ll]_{h} " \ + "\\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} \\\\\n" \ + "& C \\ar@(l,d)[]^{l_{C}} & \n" \ + "}\n" + + # The same diagram with "loop" morphisms, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \ + "\\ar@(l,d)[]^{l_{D}} & \n" \ + "}\n" + + # The same diagram with two "loop" morphisms per object. + l_A_ = NamedMorphism(A, A, "n_A") + l_D_ = NamedMorphism(D, D, "n_D") + l_C_ = NamedMorphism(C, C, "n_C") + d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C, l_A_, l_D_, l_C_]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \ + "\\ar@/^3mm/@(l,d)[]^{n_{A}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \ + "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} " \ + "\\ar@/^3mm/@(d,r)[]^{n_{D}} \\\\\n" \ + "& C \\ar@(l,d)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} & \n" \ + "}\n" + + # The same diagram with two "loop" morphisms per object, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} " \ + "\\ar@/^3mm/@(u,l)[]^{n_{A}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \ + "\\ar@(l,d)[]^{l_{D}} \\ar@/^3mm/@(d,r)[]^{n_{D}} & \n" \ + "}\n" + + +def test_xypic_draw_diagram(): + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + + grid = DiagramGrid(d, layout="sequential") + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == xypic_draw_diagram(d, layout="sequential") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/discrete/tests/test_convolutions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/discrete/tests/test_convolutions.py new file mode 100644 index 0000000000000000000000000000000000000000..96e5fc801ac63f95c01eb18d48143ae3a1ac6222 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/discrete/tests/test_convolutions.py @@ -0,0 +1,392 @@ +from sympy.core.numbers import (E, Rational, pi) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.core import S, symbols, I +from sympy.discrete.convolutions import ( + convolution, convolution_fft, convolution_ntt, convolution_fwht, + convolution_subset, covering_product, intersecting_product, + convolution_int) +from sympy.testing.pytest import raises +from sympy.abc import x, y + +def test_convolution(): + # fft + a = [1, Rational(5, 3), sqrt(3), Rational(7, 5)] + b = [9, 5, 5, 4, 3, 2] + c = [3, 5, 3, 7, 8] + d = [1422, 6572, 3213, 5552] + e = [-1, Rational(5, 3), Rational(7, 5)] + + assert convolution(a, b) == convolution_fft(a, b) + assert convolution(a, b, dps=9) == convolution_fft(a, b, dps=9) + assert convolution(a, d, dps=7) == convolution_fft(d, a, dps=7) + assert convolution(a, d[1:], dps=3) == convolution_fft(d[1:], a, dps=3) + + # prime moduli of the form (m*2**k + 1), sequence length + # should be a divisor of 2**k + p = 7*17*2**23 + 1 + q = 19*2**10 + 1 + + # ntt + assert convolution(d, b, prime=q) == convolution_ntt(b, d, prime=q) + assert convolution(c, b, prime=p) == convolution_ntt(b, c, prime=p) + assert convolution(d, c, prime=p) == convolution_ntt(c, d, prime=p) + raises(TypeError, lambda: convolution(b, d, dps=5, prime=q)) + raises(TypeError, lambda: convolution(b, d, dps=6, prime=q)) + + # fwht + assert convolution(a, b, dyadic=True) == convolution_fwht(a, b) + assert convolution(a, b, dyadic=False) == convolution(a, b) + raises(TypeError, lambda: convolution(b, d, dps=2, dyadic=True)) + raises(TypeError, lambda: convolution(b, d, prime=p, dyadic=True)) + raises(TypeError, lambda: convolution(a, b, dps=2, dyadic=True)) + raises(TypeError, lambda: convolution(b, c, prime=p, dyadic=True)) + + # subset + assert convolution(a, b, subset=True) == convolution_subset(a, b) == \ + convolution(a, b, subset=True, dyadic=False) == \ + convolution(a, b, subset=True) + assert convolution(a, b, subset=False) == convolution(a, b) + raises(TypeError, lambda: convolution(a, b, subset=True, dyadic=True)) + raises(TypeError, lambda: convolution(c, d, subset=True, dps=6)) + raises(TypeError, lambda: convolution(a, c, subset=True, prime=q)) + + # integer + assert convolution([0], [0]) == convolution_int([0], [0]) + assert convolution(b, c) == convolution_int(b, c) + + # rational + assert convolution([Rational(1,2)], [Rational(1,2)]) == [Rational(1, 4)] + assert convolution(b, e) == [-9, 10, Rational(239, 15), Rational(34, 3), + Rational(32, 3), Rational(43, 5), Rational(113, 15), + Rational(14, 5)] + + +def test_cyclic_convolution(): + # fft + a = [1, Rational(5, 3), sqrt(3), Rational(7, 5)] + b = [9, 5, 5, 4, 3, 2] + + assert convolution([1, 2, 3], [4, 5, 6], cycle=0) == \ + convolution([1, 2, 3], [4, 5, 6], cycle=5) == \ + convolution([1, 2, 3], [4, 5, 6]) + + assert convolution([1, 2, 3], [4, 5, 6], cycle=3) == [31, 31, 28] + + a = [Rational(1, 3), Rational(7, 3), Rational(5, 9), Rational(2, 7), Rational(5, 8)] + b = [Rational(3, 5), Rational(4, 7), Rational(7, 8), Rational(8, 9)] + + assert convolution(a, b, cycle=0) == \ + convolution(a, b, cycle=len(a) + len(b) - 1) + + assert convolution(a, b, cycle=4) == [Rational(87277, 26460), Rational(30521, 11340), + Rational(11125, 4032), Rational(3653, 1080)] + + assert convolution(a, b, cycle=6) == [Rational(20177, 20160), Rational(676, 315), Rational(47, 24), + Rational(3053, 1080), Rational(16397, 5292), Rational(2497, 2268)] + + assert convolution(a, b, cycle=9) == \ + convolution(a, b, cycle=0) + [S.Zero] + + # ntt + a = [2313, 5323532, S(3232), 42142, 42242421] + b = [S(33456), 56757, 45754, 432423] + + assert convolution(a, b, prime=19*2**10 + 1, cycle=0) == \ + convolution(a, b, prime=19*2**10 + 1, cycle=8) == \ + convolution(a, b, prime=19*2**10 + 1) + + assert convolution(a, b, prime=19*2**10 + 1, cycle=5) == [96, 17146, 2664, + 15534, 3517] + + assert convolution(a, b, prime=19*2**10 + 1, cycle=7) == [4643, 3458, 1260, + 15534, 3517, 16314, 13688] + + assert convolution(a, b, prime=19*2**10 + 1, cycle=9) == \ + convolution(a, b, prime=19*2**10 + 1) + [0] + + # fwht + u, v, w, x, y = symbols('u v w x y') + p, q, r, s, t = symbols('p q r s t') + c = [u, v, w, x, y] + d = [p, q, r, s, t] + + assert convolution(a, b, dyadic=True, cycle=3) == \ + [2499522285783, 19861417974796, 4702176579021] + + assert convolution(a, b, dyadic=True, cycle=5) == [2718149225143, + 2114320852171, 20571217906407, 246166418903, 1413262436976] + + assert convolution(c, d, dyadic=True, cycle=4) == \ + [p*u + p*y + q*v + r*w + s*x + t*u + t*y, + p*v + q*u + q*y + r*x + s*w + t*v, + p*w + q*x + r*u + r*y + s*v + t*w, + p*x + q*w + r*v + s*u + s*y + t*x] + + assert convolution(c, d, dyadic=True, cycle=6) == \ + [p*u + q*v + r*w + r*y + s*x + t*w + t*y, + p*v + q*u + r*x + s*w + s*y + t*x, + p*w + q*x + r*u + s*v, + p*x + q*w + r*v + s*u, + p*y + t*u, + q*y + t*v] + + # subset + assert convolution(a, b, subset=True, cycle=7) == [18266671799811, + 178235365533, 213958794, 246166418903, 1413262436976, + 2397553088697, 1932759730434] + + assert convolution(a[1:], b, subset=True, cycle=4) == \ + [178104086592, 302255835516, 244982785880, 3717819845434] + + assert convolution(a, b[:-1], subset=True, cycle=6) == [1932837114162, + 178235365533, 213958794, 245166224504, 1413262436976, 2397553088697] + + assert convolution(c, d, subset=True, cycle=3) == \ + [p*u + p*x + q*w + r*v + r*y + s*u + t*w, + p*v + p*y + q*u + s*y + t*u + t*x, + p*w + q*y + r*u + t*v] + + assert convolution(c, d, subset=True, cycle=5) == \ + [p*u + q*y + t*v, + p*v + q*u + r*y + t*w, + p*w + r*u + s*y + t*x, + p*x + q*w + r*v + s*u, + p*y + t*u] + + raises(ValueError, lambda: convolution([1, 2, 3], [4, 5, 6], cycle=-1)) + + +def test_convolution_fft(): + assert all(convolution_fft([], x, dps=y) == [] for x in ([], [1]) for y in (None, 3)) + assert convolution_fft([1, 2, 3], [4, 5, 6]) == [4, 13, 28, 27, 18] + assert convolution_fft([1], [5, 6, 7]) == [5, 6, 7] + assert convolution_fft([1, 3], [5, 6, 7]) == [5, 21, 25, 21] + + assert convolution_fft([1 + 2*I], [2 + 3*I]) == [-4 + 7*I] + assert convolution_fft([1 + 2*I, 3 + 4*I, 5 + 3*I/5], [Rational(2, 5) + 4*I/7]) == \ + [Rational(-26, 35) + I*48/35, Rational(-38, 35) + I*116/35, Rational(58, 35) + I*542/175] + + assert convolution_fft([Rational(3, 4), Rational(5, 6)], [Rational(7, 8), Rational(1, 3), Rational(2, 5)]) == \ + [Rational(21, 32), Rational(47, 48), Rational(26, 45), Rational(1, 3)] + + assert convolution_fft([Rational(1, 9), Rational(2, 3), Rational(3, 5)], [Rational(2, 5), Rational(3, 7), Rational(4, 9)]) == \ + [Rational(2, 45), Rational(11, 35), Rational(8152, 14175), Rational(523, 945), Rational(4, 15)] + + assert convolution_fft([pi, E, sqrt(2)], [sqrt(3), 1/pi, 1/E]) == \ + [sqrt(3)*pi, 1 + sqrt(3)*E, E/pi + pi*exp(-1) + sqrt(6), + sqrt(2)/pi + 1, sqrt(2)*exp(-1)] + + assert convolution_fft([2321, 33123], [5321, 6321, 71323]) == \ + [12350041, 190918524, 374911166, 2362431729] + + assert convolution_fft([312313, 31278232], [32139631, 319631]) == \ + [10037624576503, 1005370659728895, 9997492572392] + + raises(TypeError, lambda: convolution_fft(x, y)) + raises(ValueError, lambda: convolution_fft([x, y], [y, x])) + + +def test_convolution_ntt(): + # prime moduli of the form (m*2**k + 1), sequence length + # should be a divisor of 2**k + p = 7*17*2**23 + 1 + q = 19*2**10 + 1 + r = 2*500000003 + 1 # only for sequences of length 1 or 2 + # s = 2*3*5*7 # composite modulus + + assert all(convolution_ntt([], x, prime=y) == [] for x in ([], [1]) for y in (p, q, r)) + assert convolution_ntt([2], [3], r) == [6] + assert convolution_ntt([2, 3], [4], r) == [8, 12] + + assert convolution_ntt([32121, 42144, 4214, 4241], [32132, 3232, 87242], p) == [33867619, + 459741727, 79180879, 831885249, 381344700, 369993322] + assert convolution_ntt([121913, 3171831, 31888131, 12], [17882, 21292, 29921, 312], q) == \ + [8158, 3065, 3682, 7090, 1239, 2232, 3744] + + assert convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], p) == \ + convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], q) + assert convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], p) == \ + convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], q) + + raises(ValueError, lambda: convolution_ntt([2, 3], [4, 5], r)) + raises(ValueError, lambda: convolution_ntt([x, y], [y, x], q)) + raises(TypeError, lambda: convolution_ntt(x, y, p)) + + +def test_convolution_fwht(): + assert convolution_fwht([], []) == [] + assert convolution_fwht([], [1]) == [] + assert convolution_fwht([1, 2, 3], [4, 5, 6]) == [32, 13, 18, 27] + + assert convolution_fwht([Rational(5, 7), Rational(6, 8), Rational(7, 3)], [2, 4, Rational(6, 7)]) == \ + [Rational(45, 7), Rational(61, 14), Rational(776, 147), Rational(419, 42)] + + a = [1, Rational(5, 3), sqrt(3), Rational(7, 5), 4 + 5*I] + b = [94, 51, 53, 45, 31, 27, 13] + c = [3 + 4*I, 5 + 7*I, 3, Rational(7, 6), 8] + + assert convolution_fwht(a, b) == [53*sqrt(3) + 366 + 155*I, + 45*sqrt(3) + Rational(5848, 15) + 135*I, + 94*sqrt(3) + Rational(1257, 5) + 65*I, + 51*sqrt(3) + Rational(3974, 15), + 13*sqrt(3) + 452 + 470*I, + Rational(4513, 15) + 255*I, + 31*sqrt(3) + Rational(1314, 5) + 265*I, + 27*sqrt(3) + Rational(3676, 15) + 225*I] + + assert convolution_fwht(b, c) == [Rational(1993, 2) + 733*I, Rational(6215, 6) + 862*I, + Rational(1659, 2) + 527*I, Rational(1988, 3) + 551*I, 1019 + 313*I, Rational(3955, 6) + 325*I, + Rational(1175, 2) + 52*I, Rational(3253, 6) + 91*I] + + assert convolution_fwht(a[3:], c) == [Rational(-54, 5) + I*293/5, -1 + I*204/5, + Rational(133, 15) + I*35/6, Rational(409, 30) + 15*I, Rational(56, 5), 32 + 40*I, 0, 0] + + u, v, w, x, y, z = symbols('u v w x y z') + + assert convolution_fwht([u, v], [x, y]) == [u*x + v*y, u*y + v*x] + + assert convolution_fwht([u, v, w], [x, y]) == \ + [u*x + v*y, u*y + v*x, w*x, w*y] + + assert convolution_fwht([u, v, w], [x, y, z]) == \ + [u*x + v*y + w*z, u*y + v*x, u*z + w*x, v*z + w*y] + + raises(TypeError, lambda: convolution_fwht(x, y)) + raises(TypeError, lambda: convolution_fwht(x*y, u + v)) + + +def test_convolution_subset(): + assert convolution_subset([], []) == [] + assert convolution_subset([], [Rational(1, 3)]) == [] + assert convolution_subset([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7] + + a = [1, Rational(5, 3), sqrt(3), 4 + 5*I] + b = [64, 71, 55, 47, 33, 29, 15] + c = [3 + I*2/3, 5 + 7*I, 7, Rational(7, 5), 9] + + assert convolution_subset(a, b) == [64, Rational(533, 3), 55 + 64*sqrt(3), + 71*sqrt(3) + Rational(1184, 3) + 320*I, 33, 84, + 15 + 33*sqrt(3), 29*sqrt(3) + 157 + 165*I] + + assert convolution_subset(b, c) == [192 + I*128/3, 533 + I*1486/3, + 613 + I*110/3, Rational(5013, 5) + I*1249/3, + 675 + 22*I, 891 + I*751/3, + 771 + 10*I, Rational(3736, 5) + 105*I] + + assert convolution_subset(a, c) == convolution_subset(c, a) + assert convolution_subset(a[:2], b) == \ + [64, Rational(533, 3), 55, Rational(416, 3), 33, 84, 15, 25] + + assert convolution_subset(a[:2], c) == \ + [3 + I*2/3, 10 + I*73/9, 7, Rational(196, 15), 9, 15, 0, 0] + + u, v, w, x, y, z = symbols('u v w x y z') + + assert convolution_subset([u, v, w], [x, y]) == [u*x, u*y + v*x, w*x, w*y] + assert convolution_subset([u, v, w, x], [y, z]) == \ + [u*y, u*z + v*y, w*y, w*z + x*y] + + assert convolution_subset([u, v], [x, y, z]) == \ + convolution_subset([x, y, z], [u, v]) + + raises(TypeError, lambda: convolution_subset(x, z)) + raises(TypeError, lambda: convolution_subset(Rational(7, 3), u)) + + +def test_covering_product(): + assert covering_product([], []) == [] + assert covering_product([], [Rational(1, 3)]) == [] + assert covering_product([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7] + + a = [1, Rational(5, 8), sqrt(7), 4 + 9*I] + b = [66, 81, 95, 49, 37, 89, 17] + c = [3 + I*2/3, 51 + 72*I, 7, Rational(7, 15), 91] + + assert covering_product(a, b) == [66, Rational(1383, 8), 95 + 161*sqrt(7), + 130*sqrt(7) + 1303 + 2619*I, 37, + Rational(671, 4), 17 + 54*sqrt(7), + 89*sqrt(7) + Rational(4661, 8) + 1287*I] + + assert covering_product(b, c) == [198 + 44*I, 7740 + 10638*I, + 1412 + I*190/3, Rational(42684, 5) + I*31202/3, + 9484 + I*74/3, 22163 + I*27394/3, + 10621 + I*34/3, Rational(90236, 15) + 1224*I] + + assert covering_product(a, c) == covering_product(c, a) + assert covering_product(b, c[:-1]) == [198 + 44*I, 7740 + 10638*I, + 1412 + I*190/3, Rational(42684, 5) + I*31202/3, + 111 + I*74/3, 6693 + I*27394/3, + 429 + I*34/3, Rational(23351, 15) + 1224*I] + + assert covering_product(a, c[:-1]) == [3 + I*2/3, + Rational(339, 4) + I*1409/12, 7 + 10*sqrt(7) + 2*sqrt(7)*I/3, + -403 + 772*sqrt(7)/15 + 72*sqrt(7)*I + I*12658/15] + + u, v, w, x, y, z = symbols('u v w x y z') + + assert covering_product([u, v, w], [x, y]) == \ + [u*x, u*y + v*x + v*y, w*x, w*y] + + assert covering_product([u, v, w, x], [y, z]) == \ + [u*y, u*z + v*y + v*z, w*y, w*z + x*y + x*z] + + assert covering_product([u, v], [x, y, z]) == \ + covering_product([x, y, z], [u, v]) + + raises(TypeError, lambda: covering_product(x, z)) + raises(TypeError, lambda: covering_product(Rational(7, 3), u)) + + +def test_intersecting_product(): + assert intersecting_product([], []) == [] + assert intersecting_product([], [Rational(1, 3)]) == [] + assert intersecting_product([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7] + + a = [1, sqrt(5), Rational(3, 8) + 5*I, 4 + 7*I] + b = [67, 51, 65, 48, 36, 79, 27] + c = [3 + I*2/5, 5 + 9*I, 7, Rational(7, 19), 13] + + assert intersecting_product(a, b) == [195*sqrt(5) + Rational(6979, 8) + 1886*I, + 178*sqrt(5) + 520 + 910*I, Rational(841, 2) + 1344*I, + 192 + 336*I, 0, 0, 0, 0] + + assert intersecting_product(b, c) == [Rational(128553, 19) + I*9521/5, + Rational(17820, 19) + 1602*I, Rational(19264, 19), Rational(336, 19), 1846, 0, 0, 0] + + assert intersecting_product(a, c) == intersecting_product(c, a) + assert intersecting_product(b[1:], c[:-1]) == [Rational(64788, 19) + I*8622/5, + Rational(12804, 19) + 1152*I, Rational(11508, 19), Rational(252, 19), 0, 0, 0, 0] + + assert intersecting_product(a, c[:-2]) == \ + [Rational(-99, 5) + 10*sqrt(5) + 2*sqrt(5)*I/5 + I*3021/40, + -43 + 5*sqrt(5) + 9*sqrt(5)*I + 71*I, Rational(245, 8) + 84*I, 0] + + u, v, w, x, y, z = symbols('u v w x y z') + + assert intersecting_product([u, v, w], [x, y]) == \ + [u*x + u*y + v*x + w*x + w*y, v*y, 0, 0] + + assert intersecting_product([u, v, w, x], [y, z]) == \ + [u*y + u*z + v*y + w*y + w*z + x*y, v*z + x*z, 0, 0] + + assert intersecting_product([u, v], [x, y, z]) == \ + intersecting_product([x, y, z], [u, v]) + + raises(TypeError, lambda: intersecting_product(x, z)) + raises(TypeError, lambda: intersecting_product(u, Rational(8, 3))) + + +def test_convolution_int(): + assert convolution_int([1], [1]) == [1] + assert convolution_int([1, 1], [0]) == [0] + assert convolution_int([1, 2, 3], [4, 5, 6]) == [4, 13, 28, 27, 18] + assert convolution_int([1], [5, 6, 7]) == [5, 6, 7] + assert convolution_int([1, 3], [5, 6, 7]) == [5, 21, 25, 21] + assert convolution_int([10, -5, 1, 3], [-5, 6, 7]) == [-50, 85, 35, -44, 25, 21] + assert convolution_int([0, 1, 0, -1], [1, 0, -1, 0]) == [0, 1, 0, -2, 0, 1] + assert convolution_int( + [-341, -5, 1, 3, -71, -99, 43, 87], + [5, 6, 7, 12, 345, 21, -78, -7, -89] + ) == [-1705, -2071, -2412, -4106, -118035, -9774, 25998, 2981, 5509, + -34317, 19228, 38870, 5485, 1724, -4436, -7743] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/discrete/tests/test_transforms.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/discrete/tests/test_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..385514be4cdec2f19cf3a750bdbe0f4f6e21cc6e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/discrete/tests/test_transforms.py @@ -0,0 +1,154 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.core import S, Symbol, symbols, I, Rational +from sympy.discrete import (fft, ifft, ntt, intt, fwht, ifwht, + mobius_transform, inverse_mobius_transform) +from sympy.testing.pytest import raises + + +def test_fft_ifft(): + assert all(tf(ls) == ls for tf in (fft, ifft) + for ls in ([], [Rational(5, 3)])) + + ls = list(range(6)) + fls = [15, -7*sqrt(2)/2 - 4 - sqrt(2)*I/2 + 2*I, 2 + 3*I, + -4 + 7*sqrt(2)/2 - 2*I - sqrt(2)*I/2, -3, + -4 + 7*sqrt(2)/2 + sqrt(2)*I/2 + 2*I, + 2 - 3*I, -7*sqrt(2)/2 - 4 - 2*I + sqrt(2)*I/2] + + assert fft(ls) == fls + assert ifft(fls) == ls + [S.Zero]*2 + + ls = [1 + 2*I, 3 + 4*I, 5 + 6*I] + ifls = [Rational(9, 4) + 3*I, I*Rational(-7, 4), Rational(3, 4) + I, -2 - I/4] + + assert ifft(ls) == ifls + assert fft(ifls) == ls + [S.Zero] + + x = Symbol('x', real=True) + raises(TypeError, lambda: fft(x)) + raises(ValueError, lambda: ifft([x, 2*x, 3*x**2, 4*x**3])) + + +def test_ntt_intt(): + # prime moduli of the form (m*2**k + 1), sequence length + # should be a divisor of 2**k + p = 7*17*2**23 + 1 + q = 2*500000003 + 1 # only for sequences of length 1 or 2 + r = 2*3*5*7 # composite modulus + + assert all(tf(ls, p) == ls for tf in (ntt, intt) + for ls in ([], [5])) + + ls = list(range(6)) + nls = [15, 801133602, 738493201, 334102277, 998244350, 849020224, + 259751156, 12232587] + + assert ntt(ls, p) == nls + assert intt(nls, p) == ls + [0]*2 + + ls = [1 + 2*I, 3 + 4*I, 5 + 6*I] + x = Symbol('x', integer=True) + + raises(TypeError, lambda: ntt(x, p)) + raises(ValueError, lambda: intt([x, 2*x, 3*x**2, 4*x**3], p)) + raises(ValueError, lambda: intt(ls, p)) + raises(ValueError, lambda: ntt([1.2, 2.1, 3.5], p)) + raises(ValueError, lambda: ntt([3, 5, 6], q)) + raises(ValueError, lambda: ntt([4, 5, 7], r)) + raises(ValueError, lambda: ntt([1.0, 2.0, 3.0], p)) + + +def test_fwht_ifwht(): + assert all(tf(ls) == ls for tf in (fwht, ifwht) \ + for ls in ([], [Rational(7, 4)])) + + ls = [213, 321, 43235, 5325, 312, 53] + fls = [49459, 38061, -47661, -37759, 48729, 37543, -48391, -38277] + + assert fwht(ls) == fls + assert ifwht(fls) == ls + [S.Zero]*2 + + ls = [S.Half + 2*I, Rational(3, 7) + 4*I, Rational(5, 6) + 6*I, Rational(7, 3), Rational(9, 4)] + ifls = [Rational(533, 672) + I*3/2, Rational(23, 224) + I/2, Rational(1, 672), Rational(107, 224) - I, + Rational(155, 672) + I*3/2, Rational(-103, 224) + I/2, Rational(-377, 672), Rational(-19, 224) - I] + + assert ifwht(ls) == ifls + assert fwht(ifls) == ls + [S.Zero]*3 + + x, y = symbols('x y') + + raises(TypeError, lambda: fwht(x)) + + ls = [x, 2*x, 3*x**2, 4*x**3] + ifls = [x**3 + 3*x**2/4 + x*Rational(3, 4), + -x**3 + 3*x**2/4 - x/4, + -x**3 - 3*x**2/4 + x*Rational(3, 4), + x**3 - 3*x**2/4 - x/4] + + assert ifwht(ls) == ifls + assert fwht(ifls) == ls + + ls = [x, y, x**2, y**2, x*y] + fls = [x**2 + x*y + x + y**2 + y, + x**2 + x*y + x - y**2 - y, + -x**2 + x*y + x - y**2 + y, + -x**2 + x*y + x + y**2 - y, + x**2 - x*y + x + y**2 + y, + x**2 - x*y + x - y**2 - y, + -x**2 - x*y + x - y**2 + y, + -x**2 - x*y + x + y**2 - y] + + assert fwht(ls) == fls + assert ifwht(fls) == ls + [S.Zero]*3 + + ls = list(range(6)) + + assert fwht(ls) == [x*8 for x in ifwht(ls)] + + +def test_mobius_transform(): + assert all(tf(ls, subset=subset) == ls + for ls in ([], [Rational(7, 4)]) for subset in (True, False) + for tf in (mobius_transform, inverse_mobius_transform)) + + w, x, y, z = symbols('w x y z') + + assert mobius_transform([x, y]) == [x, x + y] + assert inverse_mobius_transform([x, x + y]) == [x, y] + assert mobius_transform([x, y], subset=False) == [x + y, y] + assert inverse_mobius_transform([x + y, y], subset=False) == [x, y] + + assert mobius_transform([w, x, y, z]) == [w, w + x, w + y, w + x + y + z] + assert inverse_mobius_transform([w, w + x, w + y, w + x + y + z]) == \ + [w, x, y, z] + assert mobius_transform([w, x, y, z], subset=False) == \ + [w + x + y + z, x + z, y + z, z] + assert inverse_mobius_transform([w + x + y + z, x + z, y + z, z], subset=False) == \ + [w, x, y, z] + + ls = [Rational(2, 3), Rational(6, 7), Rational(5, 8), 9, Rational(5, 3) + 7*I] + mls = [Rational(2, 3), Rational(32, 21), Rational(31, 24), Rational(1873, 168), + Rational(7, 3) + 7*I, Rational(67, 21) + 7*I, Rational(71, 24) + 7*I, + Rational(2153, 168) + 7*I] + + assert mobius_transform(ls) == mls + assert inverse_mobius_transform(mls) == ls + [S.Zero]*3 + + mls = [Rational(2153, 168) + 7*I, Rational(69, 7), Rational(77, 8), 9, Rational(5, 3) + 7*I, 0, 0, 0] + + assert mobius_transform(ls, subset=False) == mls + assert inverse_mobius_transform(mls, subset=False) == ls + [S.Zero]*3 + + ls = ls[:-1] + mls = [Rational(2, 3), Rational(32, 21), Rational(31, 24), Rational(1873, 168)] + + assert mobius_transform(ls) == mls + assert inverse_mobius_transform(mls) == ls + + mls = [Rational(1873, 168), Rational(69, 7), Rational(77, 8), 9] + + assert mobius_transform(ls, subset=False) == mls + assert inverse_mobius_transform(mls, subset=False) == ls + + raises(TypeError, lambda: mobius_transform(x, subset=True)) + raises(TypeError, lambda: inverse_mobius_transform(y, subset=False)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1eb54b72011173f3ccec764f41ab02794c098f36 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/__pycache__/factorials.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/__pycache__/factorials.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6172c570ab3bd88a17508c3a8016aa679d87d77 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/__pycache__/factorials.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8da0fa718a4dba49c95eb2c65e4274582904d7cf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/__pycache__/test_comb_factorials.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/__pycache__/test_comb_factorials.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a51bb173c39598193ed87b888686de26032b7f28 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/__pycache__/test_comb_factorials.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/test_comb_factorials.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/test_comb_factorials.py new file mode 100644 index 0000000000000000000000000000000000000000..6e3986c56736cccec0b3370007e047a1f38f06d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/test_comb_factorials.py @@ -0,0 +1,653 @@ +from sympy.concrete.products import Product +from sympy.core.function import expand_func +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core import EulerGamma +from sympy.core.numbers import (Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.combinatorial.factorials import (ff, rf, binomial, factorial, factorial2) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.gamma_functions import (gamma, polygamma) +from sympy.polys.polytools import Poly +from sympy.series.order import O +from sympy.simplify.simplify import simplify +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.functions.combinatorial.factorials import subfactorial +from sympy.functions.special.gamma_functions import uppergamma +from sympy.testing.pytest import XFAIL, raises, slow + +#Solves and Fixes Issue #10388 - This is the updated test for the same solved issue + +def test_rf_eval_apply(): + x, y = symbols('x,y') + n, k = symbols('n k', integer=True) + m = Symbol('m', integer=True, nonnegative=True) + + assert rf(nan, y) is nan + assert rf(x, nan) is nan + + assert unchanged(rf, x, y) + + assert rf(oo, 0) == 1 + assert rf(-oo, 0) == 1 + + assert rf(oo, 6) is oo + assert rf(-oo, 7) is -oo + assert rf(-oo, 6) is oo + + assert rf(oo, -6) is oo + assert rf(-oo, -7) is oo + + assert rf(-1, pi) == 0 + assert rf(-5, 1 + I) == 0 + + assert unchanged(rf, -3, k) + assert unchanged(rf, x, Symbol('k', integer=False)) + assert rf(-3, Symbol('k', integer=False)) == 0 + assert rf(Symbol('x', negative=True, integer=True), Symbol('k', integer=False)) == 0 + + assert rf(x, 0) == 1 + assert rf(x, 1) == x + assert rf(x, 2) == x*(x + 1) + assert rf(x, 3) == x*(x + 1)*(x + 2) + assert rf(x, 5) == x*(x + 1)*(x + 2)*(x + 3)*(x + 4) + + assert rf(x, -1) == 1/(x - 1) + assert rf(x, -2) == 1/((x - 1)*(x - 2)) + assert rf(x, -3) == 1/((x - 1)*(x - 2)*(x - 3)) + + assert rf(1, 100) == factorial(100) + + assert rf(x**2 + 3*x, 2) == (x**2 + 3*x)*(x**2 + 3*x + 1) + assert isinstance(rf(x**2 + 3*x, 2), Mul) + assert rf(x**3 + x, -2) == 1/((x**3 + x - 1)*(x**3 + x - 2)) + + assert rf(Poly(x**2 + 3*x, x), 2) == Poly(x**4 + 8*x**3 + 19*x**2 + 12*x, x) + assert isinstance(rf(Poly(x**2 + 3*x, x), 2), Poly) + raises(ValueError, lambda: rf(Poly(x**2 + 3*x, x, y), 2)) + assert rf(Poly(x**3 + x, x), -2) == 1/(x**6 - 9*x**5 + 35*x**4 - 75*x**3 + 94*x**2 - 66*x + 20) + raises(ValueError, lambda: rf(Poly(x**3 + x, x, y), -2)) + + assert rf(x, m).is_integer is None + assert rf(n, k).is_integer is None + assert rf(n, m).is_integer is True + assert rf(n, k + pi).is_integer is False + assert rf(n, m + pi).is_integer is False + assert rf(pi, m).is_integer is False + + def check(x, k, o, n): + a, b = Dummy(), Dummy() + r = lambda x, k: o(a, b).rewrite(n).subs({a:x,b:k}) + for i in range(-5,5): + for j in range(-5,5): + assert o(i, j) == r(i, j), (o, n, i, j) + check(x, k, rf, ff) + check(x, k, rf, binomial) + check(n, k, rf, factorial) + check(x, y, rf, factorial) + check(x, y, rf, binomial) + + assert rf(x, k).rewrite(ff) == ff(x + k - 1, k) + assert rf(x, k).rewrite(gamma) == Piecewise( + (gamma(k + x)/gamma(x), x > 0), + ((-1)**k*gamma(1 - x)/gamma(-k - x + 1), True)) + assert rf(5, k).rewrite(gamma) == gamma(k + 5)/24 + assert rf(x, k).rewrite(binomial) == factorial(k)*binomial(x + k - 1, k) + assert rf(n, k).rewrite(factorial) == Piecewise( + (factorial(k + n - 1)/factorial(n - 1), n > 0), + ((-1)**k*factorial(-n)/factorial(-k - n), True)) + assert rf(5, k).rewrite(factorial) == factorial(k + 4)/24 + assert rf(x, y).rewrite(factorial) == rf(x, y) + assert rf(x, y).rewrite(binomial) == rf(x, y) + + import random + from mpmath import rf as mpmath_rf + for i in range(100): + x = -500 + 500 * random.random() + k = -500 + 500 * random.random() + assert (abs(mpmath_rf(x, k) - rf(x, k)) < 10**(-15)) + + +def test_ff_eval_apply(): + x, y = symbols('x,y') + n, k = symbols('n k', integer=True) + m = Symbol('m', integer=True, nonnegative=True) + + assert ff(nan, y) is nan + assert ff(x, nan) is nan + + assert unchanged(ff, x, y) + + assert ff(oo, 0) == 1 + assert ff(-oo, 0) == 1 + + assert ff(oo, 6) is oo + assert ff(-oo, 7) is -oo + assert ff(-oo, 6) is oo + + assert ff(oo, -6) is oo + assert ff(-oo, -7) is oo + + assert ff(x, 0) == 1 + assert ff(x, 1) == x + assert ff(x, 2) == x*(x - 1) + assert ff(x, 3) == x*(x - 1)*(x - 2) + assert ff(x, 5) == x*(x - 1)*(x - 2)*(x - 3)*(x - 4) + + assert ff(x, -1) == 1/(x + 1) + assert ff(x, -2) == 1/((x + 1)*(x + 2)) + assert ff(x, -3) == 1/((x + 1)*(x + 2)*(x + 3)) + + assert ff(100, 100) == factorial(100) + + assert ff(2*x**2 - 5*x, 2) == (2*x**2 - 5*x)*(2*x**2 - 5*x - 1) + assert isinstance(ff(2*x**2 - 5*x, 2), Mul) + assert ff(x**2 + 3*x, -2) == 1/((x**2 + 3*x + 1)*(x**2 + 3*x + 2)) + + assert ff(Poly(2*x**2 - 5*x, x), 2) == Poly(4*x**4 - 28*x**3 + 59*x**2 - 35*x, x) + assert isinstance(ff(Poly(2*x**2 - 5*x, x), 2), Poly) + raises(ValueError, lambda: ff(Poly(2*x**2 - 5*x, x, y), 2)) + assert ff(Poly(x**2 + 3*x, x), -2) == 1/(x**4 + 12*x**3 + 49*x**2 + 78*x + 40) + raises(ValueError, lambda: ff(Poly(x**2 + 3*x, x, y), -2)) + + + assert ff(x, m).is_integer is None + assert ff(n, k).is_integer is None + assert ff(n, m).is_integer is True + assert ff(n, k + pi).is_integer is False + assert ff(n, m + pi).is_integer is False + assert ff(pi, m).is_integer is False + + assert isinstance(ff(x, x), ff) + assert ff(n, n) == factorial(n) + + def check(x, k, o, n): + a, b = Dummy(), Dummy() + r = lambda x, k: o(a, b).rewrite(n).subs({a:x,b:k}) + for i in range(-5,5): + for j in range(-5,5): + assert o(i, j) == r(i, j), (o, n) + check(x, k, ff, rf) + check(x, k, ff, gamma) + check(n, k, ff, factorial) + check(x, k, ff, binomial) + check(x, y, ff, factorial) + check(x, y, ff, binomial) + + assert ff(x, k).rewrite(rf) == rf(x - k + 1, k) + assert ff(x, k).rewrite(gamma) == Piecewise( + (gamma(x + 1)/gamma(-k + x + 1), x >= 0), + ((-1)**k*gamma(k - x)/gamma(-x), True)) + assert ff(5, k).rewrite(gamma) == 120/gamma(6 - k) + assert ff(n, k).rewrite(factorial) == Piecewise( + (factorial(n)/factorial(-k + n), n >= 0), + ((-1)**k*factorial(k - n - 1)/factorial(-n - 1), True)) + assert ff(5, k).rewrite(factorial) == 120/factorial(5 - k) + assert ff(x, k).rewrite(binomial) == factorial(k) * binomial(x, k) + assert ff(x, y).rewrite(factorial) == ff(x, y) + assert ff(x, y).rewrite(binomial) == ff(x, y) + + import random + from mpmath import ff as mpmath_ff + for i in range(100): + x = -500 + 500 * random.random() + k = -500 + 500 * random.random() + a = mpmath_ff(x, k) + b = ff(x, k) + assert (abs(a - b) < abs(a) * 10**(-15)) + + +def test_rf_ff_eval_hiprec(): + maple = Float('6.9109401292234329956525265438452') + us = ff(18, Rational(2, 3)).evalf(32) + assert abs(us - maple)/us < 1e-31 + + maple = Float('6.8261540131125511557924466355367') + us = rf(18, Rational(2, 3)).evalf(32) + assert abs(us - maple)/us < 1e-31 + + maple = Float('34.007346127440197150854651814225') + us = rf(Float('4.4', 32), Float('2.2', 32)) + assert abs(us - maple)/us < 1e-31 + + +def test_rf_lambdify_mpmath(): + from sympy.utilities.lambdify import lambdify + x, y = symbols('x,y') + f = lambdify((x,y), rf(x, y), 'mpmath') + maple = Float('34.007346127440197') + us = f(4.4, 2.2) + assert abs(us - maple)/us < 1e-15 + + +def test_factorial(): + x = Symbol('x') + n = Symbol('n', integer=True) + k = Symbol('k', integer=True, nonnegative=True) + r = Symbol('r', integer=False) + s = Symbol('s', integer=False, negative=True) + t = Symbol('t', nonnegative=True) + u = Symbol('u', noninteger=True) + + assert factorial(-2) is zoo + assert factorial(0) == 1 + assert factorial(7) == 5040 + assert factorial(19) == 121645100408832000 + assert factorial(31) == 8222838654177922817725562880000000 + assert factorial(n).func == factorial + assert factorial(2*n).func == factorial + + assert factorial(x).is_integer is None + assert factorial(n).is_integer is None + assert factorial(k).is_integer + assert factorial(r).is_integer is None + + assert factorial(n).is_positive is None + assert factorial(k).is_positive + + assert factorial(x).is_real is None + assert factorial(n).is_real is None + assert factorial(k).is_real is True + assert factorial(r).is_real is None + assert factorial(s).is_real is True + assert factorial(t).is_real is True + assert factorial(u).is_real is True + + assert factorial(x).is_composite is None + assert factorial(n).is_composite is None + assert factorial(k).is_composite is None + assert factorial(k + 3).is_composite is True + assert factorial(r).is_composite is None + assert factorial(s).is_composite is None + assert factorial(t).is_composite is None + assert factorial(u).is_composite is None + + assert factorial(oo) is oo + + +def test_factorial_Mod(): + pr = Symbol('pr', prime=True) + p, q = 10**9 + 9, 10**9 + 33 # prime modulo + r, s = 10**7 + 5, 33333333 # composite modulo + assert Mod(factorial(pr - 1), pr) == pr - 1 + assert Mod(factorial(pr - 1), -pr) == -1 + assert Mod(factorial(r - 1, evaluate=False), r) == 0 + assert Mod(factorial(s - 1, evaluate=False), s) == 0 + assert Mod(factorial(p - 1, evaluate=False), p) == p - 1 + assert Mod(factorial(q - 1, evaluate=False), q) == q - 1 + assert Mod(factorial(p - 50, evaluate=False), p) == 854928834 + assert Mod(factorial(q - 1800, evaluate=False), q) == 905504050 + assert Mod(factorial(153, evaluate=False), r) == Mod(factorial(153), r) + assert Mod(factorial(255, evaluate=False), s) == Mod(factorial(255), s) + assert Mod(factorial(4, evaluate=False), 3) == S.Zero + assert Mod(factorial(5, evaluate=False), 6) == S.Zero + + +def test_factorial_diff(): + n = Symbol('n', integer=True) + + assert factorial(n).diff(n) == \ + gamma(1 + n)*polygamma(0, 1 + n) + assert factorial(n**2).diff(n) == \ + 2*n*gamma(1 + n**2)*polygamma(0, 1 + n**2) + raises(ArgumentIndexError, lambda: factorial(n**2).fdiff(2)) + + +def test_factorial_series(): + n = Symbol('n', integer=True) + + assert factorial(n).series(n, 0, 3) == \ + 1 - n*EulerGamma + n**2*(EulerGamma**2/2 + pi**2/12) + O(n**3) + + +def test_factorial_rewrite(): + n = Symbol('n', integer=True) + k = Symbol('k', integer=True, nonnegative=True) + + assert factorial(n).rewrite(gamma) == gamma(n + 1) + _i = Dummy('i') + assert factorial(k).rewrite(Product).dummy_eq(Product(_i, (_i, 1, k))) + assert factorial(n).rewrite(Product) == factorial(n) + + +def test_factorial2(): + n = Symbol('n', integer=True) + + assert factorial2(-1) == 1 + assert factorial2(0) == 1 + assert factorial2(7) == 105 + assert factorial2(8) == 384 + + # The following is exhaustive + tt = Symbol('tt', integer=True, nonnegative=True) + tte = Symbol('tte', even=True, nonnegative=True) + tpe = Symbol('tpe', even=True, positive=True) + tto = Symbol('tto', odd=True, nonnegative=True) + tf = Symbol('tf', integer=True, nonnegative=False) + tfe = Symbol('tfe', even=True, nonnegative=False) + tfo = Symbol('tfo', odd=True, nonnegative=False) + ft = Symbol('ft', integer=False, nonnegative=True) + ff = Symbol('ff', integer=False, nonnegative=False) + fn = Symbol('fn', integer=False) + nt = Symbol('nt', nonnegative=True) + nf = Symbol('nf', nonnegative=False) + nn = Symbol('nn') + z = Symbol('z', zero=True) + #Solves and Fixes Issue #10388 - This is the updated test for the same solved issue + raises(ValueError, lambda: factorial2(oo)) + raises(ValueError, lambda: factorial2(Rational(5, 2))) + raises(ValueError, lambda: factorial2(-4)) + assert factorial2(n).is_integer is None + assert factorial2(tt - 1).is_integer + assert factorial2(tte - 1).is_integer + assert factorial2(tpe - 3).is_integer + assert factorial2(tto - 4).is_integer + assert factorial2(tto - 2).is_integer + assert factorial2(tf).is_integer is None + assert factorial2(tfe).is_integer is None + assert factorial2(tfo).is_integer is None + assert factorial2(ft).is_integer is None + assert factorial2(ff).is_integer is None + assert factorial2(fn).is_integer is None + assert factorial2(nt).is_integer is None + assert factorial2(nf).is_integer is None + assert factorial2(nn).is_integer is None + + assert factorial2(n).is_positive is None + assert factorial2(tt - 1).is_positive is True + assert factorial2(tte - 1).is_positive is True + assert factorial2(tpe - 3).is_positive is True + assert factorial2(tpe - 1).is_positive is True + assert factorial2(tto - 2).is_positive is True + assert factorial2(tto - 1).is_positive is True + assert factorial2(tf).is_positive is None + assert factorial2(tfe).is_positive is None + assert factorial2(tfo).is_positive is None + assert factorial2(ft).is_positive is None + assert factorial2(ff).is_positive is None + assert factorial2(fn).is_positive is None + assert factorial2(nt).is_positive is None + assert factorial2(nf).is_positive is None + assert factorial2(nn).is_positive is None + + assert factorial2(tt).is_even is None + assert factorial2(tt).is_odd is None + assert factorial2(tte).is_even is None + assert factorial2(tte).is_odd is None + assert factorial2(tte + 2).is_even is True + assert factorial2(tpe).is_even is True + assert factorial2(tpe).is_odd is False + assert factorial2(tto).is_odd is True + assert factorial2(tf).is_even is None + assert factorial2(tf).is_odd is None + assert factorial2(tfe).is_even is None + assert factorial2(tfe).is_odd is None + assert factorial2(tfo).is_even is False + assert factorial2(tfo).is_odd is None + assert factorial2(z).is_even is False + assert factorial2(z).is_odd is True + + +def test_factorial2_rewrite(): + n = Symbol('n', integer=True) + assert factorial2(n).rewrite(gamma) == \ + 2**(n/2)*Piecewise((1, Eq(Mod(n, 2), 0)), (sqrt(2)/sqrt(pi), Eq(Mod(n, 2), 1)))*gamma(n/2 + 1) + assert factorial2(2*n).rewrite(gamma) == 2**n*gamma(n + 1) + assert factorial2(2*n + 1).rewrite(gamma) == \ + sqrt(2)*2**(n + S.Half)*gamma(n + Rational(3, 2))/sqrt(pi) + + +def test_binomial(): + x = Symbol('x') + n = Symbol('n', integer=True) + nz = Symbol('nz', integer=True, nonzero=True) + k = Symbol('k', integer=True) + kp = Symbol('kp', integer=True, positive=True) + kn = Symbol('kn', integer=True, negative=True) + u = Symbol('u', negative=True) + v = Symbol('v', nonnegative=True) + p = Symbol('p', positive=True) + z = Symbol('z', zero=True) + nt = Symbol('nt', integer=False) + kt = Symbol('kt', integer=False) + a = Symbol('a', integer=True, nonnegative=True) + b = Symbol('b', integer=True, nonnegative=True) + + assert binomial(0, 0) == 1 + assert binomial(1, 1) == 1 + assert binomial(10, 10) == 1 + assert binomial(n, z) == 1 + assert binomial(1, 2) == 0 + assert binomial(-1, 2) == 1 + assert binomial(1, -1) == 0 + assert binomial(-1, 1) == -1 + assert binomial(-1, -1) == 0 + assert binomial(S.Half, S.Half) == 1 + assert binomial(-10, 1) == -10 + assert binomial(-10, 7) == -11440 + assert binomial(n, -1) == 0 # holds for all integers (negative, zero, positive) + assert binomial(kp, -1) == 0 + assert binomial(nz, 0) == 1 + assert expand_func(binomial(n, 1)) == n + assert expand_func(binomial(n, 2)) == n*(n - 1)/2 + assert expand_func(binomial(n, n - 2)) == n*(n - 1)/2 + assert expand_func(binomial(n, n - 1)) == n + assert binomial(n, 3).func == binomial + assert binomial(n, 3).expand(func=True) == n**3/6 - n**2/2 + n/3 + assert expand_func(binomial(n, 3)) == n*(n - 2)*(n - 1)/6 + assert binomial(n, n).func == binomial # e.g. (-1, -1) == 0, (2, 2) == 1 + assert binomial(n, n + 1).func == binomial # e.g. (-1, 0) == 1 + assert binomial(kp, kp + 1) == 0 + assert binomial(kn, kn) == 0 # issue #14529 + assert binomial(n, u).func == binomial + assert binomial(kp, u).func == binomial + assert binomial(n, p).func == binomial + assert binomial(n, k).func == binomial + assert binomial(n, n + p).func == binomial + assert binomial(kp, kp + p).func == binomial + + assert expand_func(binomial(n, n - 3)) == n*(n - 2)*(n - 1)/6 + + assert binomial(n, k).is_integer + assert binomial(nt, k).is_integer is None + assert binomial(x, nt).is_integer is False + + assert binomial(gamma(25), 6) == 79232165267303928292058750056084441948572511312165380965440075720159859792344339983120618959044048198214221915637090855535036339620413440000 + assert binomial(1324, 47) == 906266255662694632984994480774946083064699457235920708992926525848438478406790323869952 + assert binomial(1735, 43) == 190910140420204130794758005450919715396159959034348676124678207874195064798202216379800 + assert binomial(2512, 53) == 213894469313832631145798303740098720367984955243020898718979538096223399813295457822575338958939834177325304000 + assert binomial(3383, 52) == 27922807788818096863529701501764372757272890613101645521813434902890007725667814813832027795881839396839287659777235 + assert binomial(4321, 51) == 124595639629264868916081001263541480185227731958274383287107643816863897851139048158022599533438936036467601690983780576 + + assert binomial(a, b).is_nonnegative is True + assert binomial(-1, 2, evaluate=False).is_nonnegative is True + assert binomial(10, 5, evaluate=False).is_nonnegative is True + assert binomial(10, -3, evaluate=False).is_nonnegative is True + assert binomial(-10, -3, evaluate=False).is_nonnegative is True + assert binomial(-10, 2, evaluate=False).is_nonnegative is True + assert binomial(-10, 1, evaluate=False).is_nonnegative is False + assert binomial(-10, 7, evaluate=False).is_nonnegative is False + + # issue #14625 + for _ in (pi, -pi, nt, v, a): + assert binomial(_, _) == 1 + assert binomial(_, _ - 1) == _ + assert isinstance(binomial(u, u), binomial) + assert isinstance(binomial(u, u - 1), binomial) + assert isinstance(binomial(x, x), binomial) + assert isinstance(binomial(x, x - 1), binomial) + + #issue #18802 + assert expand_func(binomial(x + 1, x)) == x + 1 + assert expand_func(binomial(x, x - 1)) == x + assert expand_func(binomial(x + 1, x - 1)) == x*(x + 1)/2 + assert expand_func(binomial(x**2 + 1, x**2)) == x**2 + 1 + + # issue #13980 and #13981 + assert binomial(-7, -5) == 0 + assert binomial(-23, -12) == 0 + assert binomial(Rational(13, 2), -10) == 0 + assert binomial(-49, -51) == 0 + + assert binomial(19, Rational(-7, 2)) == S(-68719476736)/(911337863661225*pi) + assert binomial(0, Rational(3, 2)) == S(-2)/(3*pi) + assert binomial(-3, Rational(-7, 2)) is zoo + assert binomial(kn, kt) is zoo + + assert binomial(nt, kt).func == binomial + assert binomial(nt, Rational(15, 6)) == 8*gamma(nt + 1)/(15*sqrt(pi)*gamma(nt - Rational(3, 2))) + assert binomial(Rational(20, 3), Rational(-10, 8)) == gamma(Rational(23, 3))/(gamma(Rational(-1, 4))*gamma(Rational(107, 12))) + assert binomial(Rational(19, 2), Rational(-7, 2)) == Rational(-1615, 8388608) + assert binomial(Rational(-13, 5), Rational(-7, 8)) == gamma(Rational(-8, 5))/(gamma(Rational(-29, 40))*gamma(Rational(1, 8))) + assert binomial(Rational(-19, 8), Rational(-13, 5)) == gamma(Rational(-11, 8))/(gamma(Rational(-8, 5))*gamma(Rational(49, 40))) + + # binomial for complexes + assert binomial(I, Rational(-89, 8)) == gamma(1 + I)/(gamma(Rational(-81, 8))*gamma(Rational(97, 8) + I)) + assert binomial(I, 2*I) == gamma(1 + I)/(gamma(1 - I)*gamma(1 + 2*I)) + assert binomial(-7, I) is zoo + assert binomial(Rational(-7, 6), I) == gamma(Rational(-1, 6))/(gamma(Rational(-1, 6) - I)*gamma(1 + I)) + assert binomial((1+2*I), (1+3*I)) == gamma(2 + 2*I)/(gamma(1 - I)*gamma(2 + 3*I)) + assert binomial(I, 5) == Rational(1, 3) - I/S(12) + assert binomial((2*I + 3), 7) == -13*I/S(63) + assert isinstance(binomial(I, n), binomial) + assert expand_func(binomial(3, 2, evaluate=False)) == 3 + assert expand_func(binomial(n, 0, evaluate=False)) == 1 + assert expand_func(binomial(n, -2, evaluate=False)) == 0 + assert expand_func(binomial(n, k)) == binomial(n, k) + + +def test_binomial_Mod(): + p, q = 10**5 + 3, 10**9 + 33 # prime modulo + r = 10**7 + 5 # composite modulo + + # A few tests to get coverage + # Lucas Theorem + assert Mod(binomial(156675, 4433, evaluate=False), p) == Mod(binomial(156675, 4433), p) + + # factorial Mod + assert Mod(binomial(1234, 432, evaluate=False), q) == Mod(binomial(1234, 432), q) + + # binomial factorize + assert Mod(binomial(253, 113, evaluate=False), r) == Mod(binomial(253, 113), r) + + # using Granville's generalisation of Lucas' Theorem + assert Mod(binomial(10**18, 10**12, evaluate=False), p*p) == 3744312326 + + +@slow +def test_binomial_Mod_slow(): + p, q = 10**5 + 3, 10**9 + 33 # prime modulo + r, s = 10**7 + 5, 33333333 # composite modulo + + n, k, m = symbols('n k m') + assert (binomial(n, k) % q).subs({n: s, k: p}) == Mod(binomial(s, p), q) + assert (binomial(n, k) % m).subs({n: 8, k: 5, m: 13}) == 4 + assert (binomial(9, k) % 7).subs(k, 2) == 1 + + # Lucas Theorem + assert Mod(binomial(123456, 43253, evaluate=False), p) == Mod(binomial(123456, 43253), p) + assert Mod(binomial(-178911, 237, evaluate=False), p) == Mod(-binomial(178911 + 237 - 1, 237), p) + assert Mod(binomial(-178911, 238, evaluate=False), p) == Mod(binomial(178911 + 238 - 1, 238), p) + + # factorial Mod + assert Mod(binomial(9734, 451, evaluate=False), q) == Mod(binomial(9734, 451), q) + assert Mod(binomial(-10733, 4459, evaluate=False), q) == Mod(binomial(-10733, 4459), q) + assert Mod(binomial(-15733, 4458, evaluate=False), q) == Mod(binomial(-15733, 4458), q) + assert Mod(binomial(23, -38, evaluate=False), q) is S.Zero + assert Mod(binomial(23, 38, evaluate=False), q) is S.Zero + + # binomial factorize + assert Mod(binomial(753, 119, evaluate=False), r) == Mod(binomial(753, 119), r) + assert Mod(binomial(3781, 948, evaluate=False), s) == Mod(binomial(3781, 948), s) + assert Mod(binomial(25773, 1793, evaluate=False), s) == Mod(binomial(25773, 1793), s) + assert Mod(binomial(-753, 118, evaluate=False), r) == Mod(binomial(-753, 118), r) + assert Mod(binomial(-25773, 1793, evaluate=False), s) == Mod(binomial(-25773, 1793), s) + + +def test_binomial_diff(): + n = Symbol('n', integer=True) + k = Symbol('k', integer=True) + + assert binomial(n, k).diff(n) == \ + (-polygamma(0, 1 + n - k) + polygamma(0, 1 + n))*binomial(n, k) + assert binomial(n**2, k**3).diff(n) == \ + 2*n*(-polygamma( + 0, 1 + n**2 - k**3) + polygamma(0, 1 + n**2))*binomial(n**2, k**3) + + assert binomial(n, k).diff(k) == \ + (-polygamma(0, 1 + k) + polygamma(0, 1 + n - k))*binomial(n, k) + assert binomial(n**2, k**3).diff(k) == \ + 3*k**2*(-polygamma( + 0, 1 + k**3) + polygamma(0, 1 + n**2 - k**3))*binomial(n**2, k**3) + raises(ArgumentIndexError, lambda: binomial(n, k).fdiff(3)) + + +def test_binomial_rewrite(): + n = Symbol('n', integer=True) + k = Symbol('k', integer=True) + x = Symbol('x') + + assert binomial(n, k).rewrite( + factorial) == factorial(n)/(factorial(k)*factorial(n - k)) + assert binomial( + n, k).rewrite(gamma) == gamma(n + 1)/(gamma(k + 1)*gamma(n - k + 1)) + assert binomial(n, k).rewrite(ff) == ff(n, k) / factorial(k) + assert binomial(n, x).rewrite(ff) == binomial(n, x) + + +@XFAIL +def test_factorial_simplify_fail(): + # simplify(factorial(x + 1).diff(x) - ((x + 1)*factorial(x)).diff(x))) == 0 + from sympy.abc import x + assert simplify(x*polygamma(0, x + 1) - x*polygamma(0, x + 2) + + polygamma(0, x + 1) - polygamma(0, x + 2) + 1) == 0 + + +def test_subfactorial(): + assert all(subfactorial(i) == ans for i, ans in enumerate( + [1, 0, 1, 2, 9, 44, 265, 1854, 14833, 133496])) + assert subfactorial(oo) is oo + assert subfactorial(nan) is nan + assert subfactorial(23) == 9510425471055777937262 + assert unchanged(subfactorial, 2.2) + + x = Symbol('x') + assert subfactorial(x).rewrite(uppergamma) == uppergamma(x + 1, -1)/S.Exp1 + + tt = Symbol('tt', integer=True, nonnegative=True) + tf = Symbol('tf', integer=True, nonnegative=False) + tn = Symbol('tf', integer=True) + ft = Symbol('ft', integer=False, nonnegative=True) + ff = Symbol('ff', integer=False, nonnegative=False) + fn = Symbol('ff', integer=False) + nt = Symbol('nt', nonnegative=True) + nf = Symbol('nf', nonnegative=False) + nn = Symbol('nf') + te = Symbol('te', even=True, nonnegative=True) + to = Symbol('to', odd=True, nonnegative=True) + assert subfactorial(tt).is_integer + assert subfactorial(tf).is_integer is None + assert subfactorial(tn).is_integer is None + assert subfactorial(ft).is_integer is None + assert subfactorial(ff).is_integer is None + assert subfactorial(fn).is_integer is None + assert subfactorial(nt).is_integer is None + assert subfactorial(nf).is_integer is None + assert subfactorial(nn).is_integer is None + assert subfactorial(tt).is_nonnegative + assert subfactorial(tf).is_nonnegative is None + assert subfactorial(tn).is_nonnegative is None + assert subfactorial(ft).is_nonnegative is None + assert subfactorial(ff).is_nonnegative is None + assert subfactorial(fn).is_nonnegative is None + assert subfactorial(nt).is_nonnegative is None + assert subfactorial(nf).is_nonnegative is None + assert subfactorial(nn).is_nonnegative is None + assert subfactorial(tt).is_even is None + assert subfactorial(tt).is_odd is None + assert subfactorial(te).is_odd is True + assert subfactorial(to).is_even is True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/test_comb_numbers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/test_comb_numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..83a7de89ed8e4fcc433d29f41fc87b9d0d397539 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/combinatorial/tests/test_comb_numbers.py @@ -0,0 +1,1250 @@ +import string + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import (diff, expand_func) +from sympy.core import (EulerGamma, TribonacciConstant) +from sympy.core.numbers import (Float, I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.combinatorial.numbers import carmichael +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.integers import floor +from sympy.polys.polytools import cancel +from sympy.series.limits import limit, Limit +from sympy.series.order import O +from sympy.functions import ( + bernoulli, harmonic, bell, fibonacci, tribonacci, lucas, euler, catalan, + genocchi, andre, partition, divisor_sigma, udivisor_sigma, legendre_symbol, + jacobi_symbol, kronecker_symbol, mobius, + primenu, primeomega, totient, reduced_totient, primepi, + motzkin, binomial, gamma, sqrt, cbrt, hyper, log, digamma, + trigamma, polygamma, factorial, sin, cos, cot, polylog, zeta, dirichlet_eta) +from sympy.functions.combinatorial.numbers import _nT +from sympy.ntheory.factor_ import factorint + +from sympy.core.expr import unchanged +from sympy.core.numbers import GoldenRatio, Integer + +from sympy.testing.pytest import raises, nocache_fail, warns_deprecated_sympy +from sympy.abc import x + + +def test_carmichael(): + with warns_deprecated_sympy(): + assert carmichael.is_prime(2821) == False + + +def test_bernoulli(): + assert bernoulli(0) == 1 + assert bernoulli(1) == Rational(1, 2) + assert bernoulli(2) == Rational(1, 6) + assert bernoulli(3) == 0 + assert bernoulli(4) == Rational(-1, 30) + assert bernoulli(5) == 0 + assert bernoulli(6) == Rational(1, 42) + assert bernoulli(7) == 0 + assert bernoulli(8) == Rational(-1, 30) + assert bernoulli(10) == Rational(5, 66) + assert bernoulli(1000001) == 0 + + assert bernoulli(0, x) == 1 + assert bernoulli(1, x) == x - S.Half + assert bernoulli(2, x) == x**2 - x + Rational(1, 6) + assert bernoulli(3, x) == x**3 - (3*x**2)/2 + x/2 + + # Should be fast; computed with mpmath + b = bernoulli(1000) + assert b.p % 10**10 == 7950421099 + assert b.q == 342999030 + + b = bernoulli(10**6, evaluate=False).evalf() + assert str(b) == '-2.23799235765713e+4767529' + + # Issue #8527 + l = Symbol('l', integer=True) + m = Symbol('m', integer=True, nonnegative=True) + n = Symbol('n', integer=True, positive=True) + assert isinstance(bernoulli(2 * l + 1), bernoulli) + assert isinstance(bernoulli(2 * m + 1), bernoulli) + assert bernoulli(2 * n + 1) == 0 + + assert bernoulli(x, 1) == bernoulli(x) + + assert str(bernoulli(0.0, 2.3).evalf(n=10)) == '1.000000000' + assert str(bernoulli(1.0).evalf(n=10)) == '0.5000000000' + assert str(bernoulli(1.2).evalf(n=10)) == '0.4195995367' + assert str(bernoulli(1.2, 0.8).evalf(n=10)) == '0.2144830348' + assert str(bernoulli(1.2, -0.8).evalf(n=10)) == '-1.158865646 - 0.6745558744*I' + assert str(bernoulli(3.0, 1j).evalf(n=10)) == '1.5 - 0.5*I' + assert str(bernoulli(I).evalf(n=10)) == '0.9268485643 - 0.5821580598*I' + assert str(bernoulli(I, I).evalf(n=10)) == '0.1267792071 + 0.01947413152*I' + assert bernoulli(x).evalf() == bernoulli(x) + + +def test_bernoulli_rewrite(): + from sympy.functions.elementary.piecewise import Piecewise + n = Symbol('n', integer=True, nonnegative=True) + + assert bernoulli(-1).rewrite(zeta) == pi**2/6 + assert bernoulli(-2).rewrite(zeta) == 2*zeta(3) + assert not bernoulli(n, -3).rewrite(zeta).has(harmonic) + assert bernoulli(-4, x).rewrite(zeta) == 4*zeta(5, x) + assert isinstance(bernoulli(n, x).rewrite(zeta), Piecewise) + assert bernoulli(n+1, x).rewrite(zeta) == -(n+1) * zeta(-n, x) + + +def test_fibonacci(): + assert [fibonacci(n) for n in range(-3, 5)] == [2, -1, 1, 0, 1, 1, 2, 3] + assert fibonacci(100) == 354224848179261915075 + assert [lucas(n) for n in range(-3, 5)] == [-4, 3, -1, 2, 1, 3, 4, 7] + assert lucas(100) == 792070839848372253127 + + assert fibonacci(1, x) == 1 + assert fibonacci(2, x) == x + assert fibonacci(3, x) == x**2 + 1 + assert fibonacci(4, x) == x**3 + 2*x + + # issue #8800 + n = Dummy('n') + assert fibonacci(n).limit(n, S.Infinity) is S.Infinity + assert lucas(n).limit(n, S.Infinity) is S.Infinity + + assert fibonacci(n).rewrite(sqrt) == \ + 2**(-n)*sqrt(5)*((1 + sqrt(5))**n - (-sqrt(5) + 1)**n) / 5 + assert fibonacci(n).rewrite(sqrt).subs(n, 10).expand() == fibonacci(10) + assert fibonacci(n).rewrite(GoldenRatio).subs(n,10).evalf() == \ + Float(fibonacci(10)) + assert lucas(n).rewrite(sqrt) == \ + (fibonacci(n-1).rewrite(sqrt) + fibonacci(n+1).rewrite(sqrt)).simplify() + assert lucas(n).rewrite(sqrt).subs(n, 10).expand() == lucas(10) + raises(ValueError, lambda: fibonacci(-3, x)) + + +def test_tribonacci(): + assert [tribonacci(n) for n in range(8)] == [0, 1, 1, 2, 4, 7, 13, 24] + assert tribonacci(100) == 98079530178586034536500564 + + assert tribonacci(0, x) == 0 + assert tribonacci(1, x) == 1 + assert tribonacci(2, x) == x**2 + assert tribonacci(3, x) == x**4 + x + assert tribonacci(4, x) == x**6 + 2*x**3 + 1 + assert tribonacci(5, x) == x**8 + 3*x**5 + 3*x**2 + + n = Dummy('n') + assert tribonacci(n).limit(n, S.Infinity) is S.Infinity + + w = (-1 + S.ImaginaryUnit * sqrt(3)) / 2 + a = (1 + cbrt(19 + 3*sqrt(33)) + cbrt(19 - 3*sqrt(33))) / 3 + b = (1 + w*cbrt(19 + 3*sqrt(33)) + w**2*cbrt(19 - 3*sqrt(33))) / 3 + c = (1 + w**2*cbrt(19 + 3*sqrt(33)) + w*cbrt(19 - 3*sqrt(33))) / 3 + assert tribonacci(n).rewrite(sqrt) == \ + (a**(n + 1)/((a - b)*(a - c)) + + b**(n + 1)/((b - a)*(b - c)) + + c**(n + 1)/((c - a)*(c - b))) + assert tribonacci(n).rewrite(sqrt).subs(n, 4).simplify() == tribonacci(4) + assert tribonacci(n).rewrite(GoldenRatio).subs(n,10).evalf() == \ + Float(tribonacci(10)) + assert tribonacci(n).rewrite(TribonacciConstant) == floor( + 3*TribonacciConstant**n*(102*sqrt(33) + 586)**Rational(1, 3)/ + (-2*(102*sqrt(33) + 586)**Rational(1, 3) + 4 + (102*sqrt(33) + + 586)**Rational(2, 3)) + S.Half) + raises(ValueError, lambda: tribonacci(-1, x)) + + +@nocache_fail +def test_bell(): + assert [bell(n) for n in range(8)] == [1, 1, 2, 5, 15, 52, 203, 877] + + assert bell(0, x) == 1 + assert bell(1, x) == x + assert bell(2, x) == x**2 + x + assert bell(5, x) == x**5 + 10*x**4 + 25*x**3 + 15*x**2 + x + assert bell(oo) is S.Infinity + raises(ValueError, lambda: bell(oo, x)) + + raises(ValueError, lambda: bell(-1)) + raises(ValueError, lambda: bell(S.Half)) + + X = symbols('x:6') + # X = (x0, x1, .. x5) + # at the same time: X[1] = x1, X[2] = x2 for standard readablity. + # but we must supply zero-based indexed object X[1:] = (x1, .. x5) + + assert bell(6, 2, X[1:]) == 6*X[5]*X[1] + 15*X[4]*X[2] + 10*X[3]**2 + assert bell( + 6, 3, X[1:]) == 15*X[4]*X[1]**2 + 60*X[3]*X[2]*X[1] + 15*X[2]**3 + + X = (1, 10, 100, 1000, 10000) + assert bell(6, 2, X) == (6 + 15 + 10)*10000 + + X = (1, 2, 3, 3, 5) + assert bell(6, 2, X) == 6*5 + 15*3*2 + 10*3**2 + + X = (1, 2, 3, 5) + assert bell(6, 3, X) == 15*5 + 60*3*2 + 15*2**3 + + # Dobinski's formula + n = Symbol('n', integer=True, nonnegative=True) + # For large numbers, this is too slow + # For nonintegers, there are significant precision errors + for i in [0, 2, 3, 7, 13, 42, 55]: + # Running without the cache this is either very slow or goes into an + # infinite loop. + assert bell(i).evalf() == bell(n).rewrite(Sum).evalf(subs={n: i}) + + m = Symbol("m") + assert bell(m).rewrite(Sum) == bell(m) + assert bell(n, m).rewrite(Sum) == bell(n, m) + # issue 9184 + n = Dummy('n') + assert bell(n).limit(n, S.Infinity) is S.Infinity + + +def test_harmonic(): + n = Symbol("n") + m = Symbol("m") + + assert harmonic(n, 0) == n + assert harmonic(n).evalf() == harmonic(n) + assert harmonic(n, 1) == harmonic(n) + assert harmonic(1, n) == 1 + + assert harmonic(0, 1) == 0 + assert harmonic(1, 1) == 1 + assert harmonic(2, 1) == Rational(3, 2) + assert harmonic(3, 1) == Rational(11, 6) + assert harmonic(4, 1) == Rational(25, 12) + assert harmonic(0, 2) == 0 + assert harmonic(1, 2) == 1 + assert harmonic(2, 2) == Rational(5, 4) + assert harmonic(3, 2) == Rational(49, 36) + assert harmonic(4, 2) == Rational(205, 144) + assert harmonic(0, 3) == 0 + assert harmonic(1, 3) == 1 + assert harmonic(2, 3) == Rational(9, 8) + assert harmonic(3, 3) == Rational(251, 216) + assert harmonic(4, 3) == Rational(2035, 1728) + + assert harmonic(oo, -1) is S.NaN + assert harmonic(oo, 0) is oo + assert harmonic(oo, S.Half) is oo + assert harmonic(oo, 1) is oo + assert harmonic(oo, 2) == (pi**2)/6 + assert harmonic(oo, 3) == zeta(3) + assert harmonic(oo, Dummy(negative=True)) is S.NaN + ip = Dummy(integer=True, positive=True) + if (1/ip <= 1) is True: #---------------------------------+ + assert None, 'delete this if-block and the next line' #| + ip = Dummy(even=True, positive=True) #--------------------+ + assert harmonic(oo, 1/ip) is oo + assert harmonic(oo, 1 + ip) is zeta(1 + ip) + + assert harmonic(0, m) == 0 + assert harmonic(-1, -1) == 0 + assert harmonic(-1, 0) == -1 + assert harmonic(-1, 1) is S.ComplexInfinity + assert harmonic(-1, 2) is S.NaN + assert harmonic(-3, -2) == -5 + assert harmonic(-3, -3) == 9 + + +def test_harmonic_rational(): + ne = S(6) + no = S(5) + pe = S(8) + po = S(9) + qe = S(10) + qo = S(13) + + Heee = harmonic(ne + pe/qe) + Aeee = (-log(10) + 2*(Rational(-1, 4) + sqrt(5)/4)*log(sqrt(-sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 - Rational(1, 4))*log(sqrt(sqrt(5)/8 + Rational(5, 8))) + + pi*sqrt(2*sqrt(5)/5 + 1)/2 + Rational(13944145, 4720968)) + + Heeo = harmonic(ne + pe/qo) + Aeeo = (-log(26) + 2*log(sin(pi*Rational(3, 13)))*cos(pi*Rational(4, 13)) + 2*log(sin(pi*Rational(2, 13)))*cos(pi*Rational(32, 13)) + + 2*log(sin(pi*Rational(5, 13)))*cos(pi*Rational(80, 13)) - 2*log(sin(pi*Rational(6, 13)))*cos(pi*Rational(5, 13)) + - 2*log(sin(pi*Rational(4, 13)))*cos(pi/13) + pi*cot(pi*Rational(5, 13))/2 - 2*log(sin(pi/13))*cos(pi*Rational(3, 13)) + + Rational(2422020029, 702257080)) + + Heoe = harmonic(ne + po/qe) + Aeoe = (-log(20) + 2*(Rational(1, 4) + sqrt(5)/4)*log(Rational(-1, 4) + sqrt(5)/4) + + 2*(Rational(-1, 4) + sqrt(5)/4)*log(sqrt(-sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 - Rational(1, 4))*log(sqrt(sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 + Rational(1, 4))*log(Rational(1, 4) + sqrt(5)/4) + + Rational(11818877030, 4286604231) + pi*sqrt(2*sqrt(5) + 5)/2) + + Heoo = harmonic(ne + po/qo) + Aeoo = (-log(26) + 2*log(sin(pi*Rational(3, 13)))*cos(pi*Rational(54, 13)) + 2*log(sin(pi*Rational(4, 13)))*cos(pi*Rational(6, 13)) + + 2*log(sin(pi*Rational(6, 13)))*cos(pi*Rational(108, 13)) - 2*log(sin(pi*Rational(5, 13)))*cos(pi/13) + - 2*log(sin(pi/13))*cos(pi*Rational(5, 13)) + pi*cot(pi*Rational(4, 13))/2 + - 2*log(sin(pi*Rational(2, 13)))*cos(pi*Rational(3, 13)) + Rational(11669332571, 3628714320)) + + Hoee = harmonic(no + pe/qe) + Aoee = (-log(10) + 2*(Rational(-1, 4) + sqrt(5)/4)*log(sqrt(-sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 - Rational(1, 4))*log(sqrt(sqrt(5)/8 + Rational(5, 8))) + + pi*sqrt(2*sqrt(5)/5 + 1)/2 + Rational(779405, 277704)) + + Hoeo = harmonic(no + pe/qo) + Aoeo = (-log(26) + 2*log(sin(pi*Rational(3, 13)))*cos(pi*Rational(4, 13)) + 2*log(sin(pi*Rational(2, 13)))*cos(pi*Rational(32, 13)) + + 2*log(sin(pi*Rational(5, 13)))*cos(pi*Rational(80, 13)) - 2*log(sin(pi*Rational(6, 13)))*cos(pi*Rational(5, 13)) + - 2*log(sin(pi*Rational(4, 13)))*cos(pi/13) + pi*cot(pi*Rational(5, 13))/2 + - 2*log(sin(pi/13))*cos(pi*Rational(3, 13)) + Rational(53857323, 16331560)) + + Hooe = harmonic(no + po/qe) + Aooe = (-log(20) + 2*(Rational(1, 4) + sqrt(5)/4)*log(Rational(-1, 4) + sqrt(5)/4) + + 2*(Rational(-1, 4) + sqrt(5)/4)*log(sqrt(-sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 - Rational(1, 4))*log(sqrt(sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 + Rational(1, 4))*log(Rational(1, 4) + sqrt(5)/4) + + Rational(486853480, 186374097) + pi*sqrt(2*sqrt(5) + 5)/2) + + Hooo = harmonic(no + po/qo) + Aooo = (-log(26) + 2*log(sin(pi*Rational(3, 13)))*cos(pi*Rational(54, 13)) + 2*log(sin(pi*Rational(4, 13)))*cos(pi*Rational(6, 13)) + + 2*log(sin(pi*Rational(6, 13)))*cos(pi*Rational(108, 13)) - 2*log(sin(pi*Rational(5, 13)))*cos(pi/13) + - 2*log(sin(pi/13))*cos(pi*Rational(5, 13)) + pi*cot(pi*Rational(4, 13))/2 + - 2*log(sin(pi*Rational(2, 13)))*cos(3*pi/13) + Rational(383693479, 125128080)) + + H = [Heee, Heeo, Heoe, Heoo, Hoee, Hoeo, Hooe, Hooo] + A = [Aeee, Aeeo, Aeoe, Aeoo, Aoee, Aoeo, Aooe, Aooo] + for h, a in zip(H, A): + e = expand_func(h).doit() + assert cancel(e/a) == 1 + assert abs(h.n() - a.n()) < 1e-12 + + +def test_harmonic_evalf(): + assert str(harmonic(1.5).evalf(n=10)) == '1.280372306' + assert str(harmonic(1.5, 2).evalf(n=10)) == '1.154576311' # issue 7443 + assert str(harmonic(4.0, -3).evalf(n=10)) == '100.0000000' + assert str(harmonic(7.0, 1.0).evalf(n=10)) == '2.592857143' + assert str(harmonic(1, pi).evalf(n=10)) == '1.000000000' + assert str(harmonic(2, pi).evalf(n=10)) == '1.113314732' + assert str(harmonic(1000.0, pi).evalf(n=10)) == '1.176241563' + assert str(harmonic(I).evalf(n=10)) == '0.6718659855 + 1.076674047*I' + assert str(harmonic(I, I).evalf(n=10)) == '-0.3970915266 + 1.9629689*I' + + assert harmonic(-1.0, 1).evalf() is S.NaN + assert harmonic(-2.0, 2.0).evalf() is S.NaN + +def test_harmonic_rewrite(): + from sympy.functions.elementary.piecewise import Piecewise + n = Symbol("n") + m = Symbol("m", integer=True, positive=True) + x1 = Symbol("x1", positive=True) + x2 = Symbol("x2", negative=True) + + assert harmonic(n).rewrite(digamma) == polygamma(0, n + 1) + EulerGamma + assert harmonic(n).rewrite(trigamma) == polygamma(0, n + 1) + EulerGamma + assert harmonic(n).rewrite(polygamma) == polygamma(0, n + 1) + EulerGamma + + assert harmonic(n,3).rewrite(polygamma) == polygamma(2, n + 1)/2 - polygamma(2, 1)/2 + assert isinstance(harmonic(n,m).rewrite(polygamma), Piecewise) + + assert expand_func(harmonic(n+4)) == harmonic(n) + 1/(n + 4) + 1/(n + 3) + 1/(n + 2) + 1/(n + 1) + assert expand_func(harmonic(n-4)) == harmonic(n) - 1/(n - 1) - 1/(n - 2) - 1/(n - 3) - 1/n + + assert harmonic(n, m).rewrite("tractable") == harmonic(n, m).rewrite(polygamma) + assert harmonic(n, x1).rewrite("tractable") == harmonic(n, x1) + assert harmonic(n, x1 + 1).rewrite("tractable") == zeta(x1 + 1) - zeta(x1 + 1, n + 1) + assert harmonic(n, x2).rewrite("tractable") == zeta(x2) - zeta(x2, n + 1) + + _k = Dummy("k") + assert harmonic(n).rewrite(Sum).dummy_eq(Sum(1/_k, (_k, 1, n))) + assert harmonic(n, m).rewrite(Sum).dummy_eq(Sum(_k**(-m), (_k, 1, n))) + + +def test_harmonic_calculus(): + y = Symbol("y", positive=True) + z = Symbol("z", negative=True) + assert harmonic(x, 1).limit(x, 0) == 0 + assert harmonic(x, y).limit(x, 0) == 0 + assert harmonic(x, 1).series(x, y, 2) == \ + harmonic(y) + (x - y)*zeta(2, y + 1) + O((x - y)**2, (x, y)) + assert limit(harmonic(x, y), x, oo) == harmonic(oo, y) + assert limit(harmonic(x, y + 1), x, oo) == zeta(y + 1) + assert limit(harmonic(x, y - 1), x, oo) == harmonic(oo, y - 1) + assert limit(harmonic(x, z), x, oo) == Limit(harmonic(x, z), x, oo, dir='-') + assert limit(harmonic(x, z + 1), x, oo) == oo + assert limit(harmonic(x, z + 2), x, oo) == harmonic(oo, z + 2) + assert limit(harmonic(x, z - 1), x, oo) == Limit(harmonic(x, z - 1), x, oo, dir='-') + + +def test_euler(): + assert euler(0) == 1 + assert euler(1) == 0 + assert euler(2) == -1 + assert euler(3) == 0 + assert euler(4) == 5 + assert euler(6) == -61 + assert euler(8) == 1385 + + assert euler(20, evaluate=False) != 370371188237525 + + n = Symbol('n', integer=True) + assert euler(n) != -1 + assert euler(n).subs(n, 2) == -1 + + assert euler(-1) == S.Pi / 2 + assert euler(-1, 1) == 2*log(2) + assert euler(-2).evalf() == (2*S.Catalan).evalf() + assert euler(-3).evalf() == (S.Pi**3 / 16).evalf() + assert str(euler(2.3).evalf(n=10)) == '-1.052850274' + assert str(euler(1.2, 3.4).evalf(n=10)) == '3.575613489' + assert str(euler(I).evalf(n=10)) == '1.248446443 - 0.7675445124*I' + assert str(euler(I, I).evalf(n=10)) == '0.04812930469 + 0.01052411008*I' + + assert euler(20).evalf() == 370371188237525.0 + assert euler(20, evaluate=False).evalf() == 370371188237525.0 + + assert euler(n).rewrite(Sum) == euler(n) + n = Symbol('n', integer=True, nonnegative=True) + assert euler(2*n + 1).rewrite(Sum) == 0 + _j = Dummy('j') + _k = Dummy('k') + assert euler(2*n).rewrite(Sum).dummy_eq( + I*Sum((-1)**_j*2**(-_k)*I**(-_k)*(-2*_j + _k)**(2*n + 1)* + binomial(_k, _j)/_k, (_j, 0, _k), (_k, 1, 2*n + 1))) + + +def test_euler_odd(): + n = Symbol('n', odd=True, positive=True) + assert euler(n) == 0 + n = Symbol('n', odd=True) + assert euler(n) != 0 + + +def test_euler_polynomials(): + assert euler(0, x) == 1 + assert euler(1, x) == x - S.Half + assert euler(2, x) == x**2 - x + assert euler(3, x) == x**3 - (3*x**2)/2 + Rational(1, 4) + m = Symbol('m') + assert isinstance(euler(m, x), euler) + from sympy.core.numbers import Float + A = Float('-0.46237208575048694923364757452876131e8') # from Maple + B = euler(19, S.Pi).evalf(32) + assert abs((A - B)/A) < 1e-31 + z = Float(0.1) + Float(0.2)*I + expected = Float(-3126.54721663773 ) + Float(565.736261497056) * I + assert abs(euler(13, z) - expected) < 1e-10 + + +def test_euler_polynomial_rewrite(): + m = Symbol('m') + A = euler(m, x).rewrite('Sum') + assert A.subs({m:3, x:5}).doit() == euler(3, 5) + + +def test_catalan(): + n = Symbol('n', integer=True) + m = Symbol('m', integer=True, positive=True) + k = Symbol('k', integer=True, nonnegative=True) + p = Symbol('p', nonnegative=True) + + catalans = [1, 1, 2, 5, 14, 42, 132, 429, 1430, 4862, 16796, 58786] + for i, c in enumerate(catalans): + assert catalan(i) == c + assert catalan(n).rewrite(factorial).subs(n, i) == c + assert catalan(n).rewrite(Product).subs(n, i).doit() == c + + assert unchanged(catalan, x) + assert catalan(2*x).rewrite(binomial) == binomial(4*x, 2*x)/(2*x + 1) + assert catalan(S.Half).rewrite(gamma) == 8/(3*pi) + assert catalan(S.Half).rewrite(factorial).rewrite(gamma) ==\ + 8 / (3 * pi) + assert catalan(3*x).rewrite(gamma) == 4**( + 3*x)*gamma(3*x + S.Half)/(sqrt(pi)*gamma(3*x + 2)) + assert catalan(x).rewrite(hyper) == hyper((-x + 1, -x), (2,), 1) + + assert catalan(n).rewrite(factorial) == factorial(2*n) / (factorial(n + 1) + * factorial(n)) + assert isinstance(catalan(n).rewrite(Product), catalan) + assert isinstance(catalan(m).rewrite(Product), Product) + + assert diff(catalan(x), x) == (polygamma( + 0, x + S.Half) - polygamma(0, x + 2) + log(4))*catalan(x) + + assert catalan(x).evalf() == catalan(x) + c = catalan(S.Half).evalf() + assert str(c) == '0.848826363156775' + c = catalan(I).evalf(3) + assert str((re(c), im(c))) == '(0.398, -0.0209)' + + # Assumptions + assert catalan(p).is_positive is True + assert catalan(k).is_integer is True + assert catalan(m+3).is_composite is True + + +def test_genocchi(): + genocchis = [0, -1, -1, 0, 1, 0, -3, 0, 17] + for n, g in enumerate(genocchis): + assert genocchi(n) == g + + m = Symbol('m', integer=True) + n = Symbol('n', integer=True, positive=True) + assert unchanged(genocchi, m) + assert genocchi(2*n + 1) == 0 + gn = 2 * (1 - 2**n) * bernoulli(n) + assert genocchi(n).rewrite(bernoulli).factor() == gn.factor() + gnx = 2 * (bernoulli(n, x) - 2**n * bernoulli(n, (x+1) / 2)) + assert genocchi(n, x).rewrite(bernoulli).factor() == gnx.factor() + assert genocchi(2 * n).is_odd + assert genocchi(2 * n).is_even is False + assert genocchi(2 * n + 1).is_even + assert genocchi(n).is_integer + assert genocchi(4 * n).is_positive + # these are the only 2 prime Genocchi numbers + assert genocchi(6, evaluate=False).is_prime == S(-3).is_prime + assert genocchi(8, evaluate=False).is_prime + assert genocchi(4 * n + 2).is_negative + assert genocchi(4 * n + 1).is_negative is False + assert genocchi(4 * n - 2).is_negative + + g0 = genocchi(0, evaluate=False) + assert g0.is_positive is False + assert g0.is_negative is False + assert g0.is_even is True + assert g0.is_odd is False + + assert genocchi(0, x) == 0 + assert genocchi(1, x) == -1 + assert genocchi(2, x) == 1 - 2*x + assert genocchi(3, x) == 3*x - 3*x**2 + assert genocchi(4, x) == -1 + 6*x**2 - 4*x**3 + y = Symbol("y") + assert genocchi(5, (x+y)**100) == -5*(x+y)**400 + 10*(x+y)**300 - 5*(x+y)**100 + + assert str(genocchi(5.0, 4.0).evalf(n=10)) == '-660.0000000' + assert str(genocchi(Rational(5, 4)).evalf(n=10)) == '-1.104286457' + assert str(genocchi(-2).evalf(n=10)) == '3.606170709' + assert str(genocchi(1.3, 3.7).evalf(n=10)) == '-1.847375373' + assert str(genocchi(I, 1.0).evalf(n=10)) == '-0.3161917278 - 1.45311955*I' + + n = Symbol('n') + assert genocchi(n, x).rewrite(dirichlet_eta) == -2*n * dirichlet_eta(1-n, x) + + +def test_andre(): + nums = [1, 1, 1, 2, 5, 16, 61, 272, 1385, 7936, 50521] + for n, a in enumerate(nums): + assert andre(n) == a + assert andre(S.Infinity) == S.Infinity + assert andre(-1) == -log(2) + assert andre(-2) == -2*S.Catalan + assert andre(-3) == 3*zeta(3)/16 + assert andre(-5) == -15*zeta(5)/256 + # In fact andre(-2*n) is related to the Dirichlet *beta* function + # at 2*n, but SymPy doesn't implement that (or general L-functions) + assert unchanged(andre, -4) + + n = Symbol('n', integer=True, nonnegative=True) + assert unchanged(andre, n) + assert andre(n).is_integer is True + assert andre(n).is_positive is True + + assert str(andre(10, evaluate=False).evalf(n=10)) == '50521.00000' + assert str(andre(-1, evaluate=False).evalf(n=10)) == '-0.6931471806' + assert str(andre(-2, evaluate=False).evalf(n=10)) == '-1.831931188' + assert str(andre(-4, evaluate=False).evalf(n=10)) == '1.977889103' + assert str(andre(I, evaluate=False).evalf(n=10)) == '2.378417833 + 0.6343322845*I' + + assert andre(x).rewrite(polylog) == \ + (-I)**(x+1) * polylog(-x, I) + I**(x+1) * polylog(-x, -I) + assert andre(x).rewrite(zeta) == \ + 2 * gamma(x+1) / (2*pi)**(x+1) * \ + (zeta(x+1, Rational(1,4)) - cos(pi*x) * zeta(x+1, Rational(3,4))) + + +@nocache_fail +def test_partition(): + partition_nums = [1, 1, 2, 3, 5, 7, 11, 15, 22] + for n, p in enumerate(partition_nums): + assert partition(n) == p + + x = Symbol('x') + y = Symbol('y', real=True) + m = Symbol('m', integer=True) + n = Symbol('n', integer=True, negative=True) + p = Symbol('p', integer=True, nonnegative=True) + assert partition(m).is_integer + assert not partition(m).is_negative + assert partition(m).is_nonnegative + assert partition(n).is_zero + assert partition(p).is_positive + assert partition(x).subs(x, 7) == 15 + assert partition(y).subs(y, 8) == 22 + raises(TypeError, lambda: partition(Rational(5, 4))) + assert partition(9, evaluate=False) % 5 == 0 + assert partition(5*m + 4) % 5 == 0 + assert partition(47, evaluate=False) % 7 == 0 + assert partition(7*m + 5) % 7 == 0 + assert partition(50, evaluate=False) % 11 == 0 + assert partition(11*m + 6) % 11 == 0 + + +def test_divisor_sigma(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: divisor_sigma(m)) + raises(TypeError, lambda: divisor_sigma(4.5)) + raises(TypeError, lambda: divisor_sigma(1, m)) + raises(TypeError, lambda: divisor_sigma(1, 4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: divisor_sigma(m)) + raises(ValueError, lambda: divisor_sigma(0)) + m = Symbol('m', negative=True) + raises(ValueError, lambda: divisor_sigma(1, m)) + raises(ValueError, lambda: divisor_sigma(1, -1)) + + # special case + p = Symbol('p', prime=True) + k = Symbol('k', integer=True) + assert divisor_sigma(p, 1) == p + 1 + assert divisor_sigma(p, k) == p**k + 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert divisor_sigma(n).is_integer is True + assert divisor_sigma(n).is_positive is True + + # symbolic + k = Symbol('k', integer=True, zero=False) + assert divisor_sigma(4, k) == 2**(2*k) + 2**k + 1 + assert divisor_sigma(6, k) == (2**k + 1) * (3**k + 1) + + # Integer + assert divisor_sigma(23450) == 50592 + assert divisor_sigma(23450, 0) == 24 + assert divisor_sigma(23450, 1) == 50592 + assert divisor_sigma(23450, 2) == 730747500 + assert divisor_sigma(23450, 3) == 14666785333344 + + +def test_udivisor_sigma(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: udivisor_sigma(m)) + raises(TypeError, lambda: udivisor_sigma(4.5)) + raises(TypeError, lambda: udivisor_sigma(1, m)) + raises(TypeError, lambda: udivisor_sigma(1, 4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: udivisor_sigma(m)) + raises(ValueError, lambda: udivisor_sigma(0)) + m = Symbol('m', negative=True) + raises(ValueError, lambda: udivisor_sigma(1, m)) + raises(ValueError, lambda: udivisor_sigma(1, -1)) + + # special case + p = Symbol('p', prime=True) + k = Symbol('k', integer=True) + assert udivisor_sigma(p, 1) == p + 1 + assert udivisor_sigma(p, k) == p**k + 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert udivisor_sigma(n).is_integer is True + assert udivisor_sigma(n).is_positive is True + + # Integer + A034444 = [1, 2, 2, 2, 2, 4, 2, 2, 2, 4, 2, 4, 2, 4, 4, 2, 2, 4, 2, 4, + 4, 4, 2, 4, 2, 4, 2, 4, 2, 8, 2, 2, 4, 4, 4, 4, 2, 4, 4, 4, + 2, 8, 2, 4, 4, 4, 2, 4, 2, 4, 4, 4, 2, 4, 4, 4, 4, 4, 2, 8] + for n, val in enumerate(A034444, 1): + assert udivisor_sigma(n, 0) == val + A034448 = [1, 3, 4, 5, 6, 12, 8, 9, 10, 18, 12, 20, 14, 24, 24, 17, 18, + 30, 20, 30, 32, 36, 24, 36, 26, 42, 28, 40, 30, 72, 32, 33, + 48, 54, 48, 50, 38, 60, 56, 54, 42, 96, 44, 60, 60, 72, 48] + for n, val in enumerate(A034448, 1): + assert udivisor_sigma(n, 1) == val + A034676 = [1, 5, 10, 17, 26, 50, 50, 65, 82, 130, 122, 170, 170, 250, + 260, 257, 290, 410, 362, 442, 500, 610, 530, 650, 626, 850, + 730, 850, 842, 1300, 962, 1025, 1220, 1450, 1300, 1394, 1370] + for n, val in enumerate(A034676, 1): + assert udivisor_sigma(n, 2) == val + + +def test_legendre_symbol(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: legendre_symbol(m, 3)) + raises(TypeError, lambda: legendre_symbol(4.5, 3)) + raises(TypeError, lambda: legendre_symbol(1, m)) + raises(TypeError, lambda: legendre_symbol(1, 4.5)) + m = Symbol('m', prime=False) + raises(ValueError, lambda: legendre_symbol(1, m)) + raises(ValueError, lambda: legendre_symbol(1, 6)) + m = Symbol('m', odd=False) + raises(ValueError, lambda: legendre_symbol(1, m)) + raises(ValueError, lambda: legendre_symbol(1, 2)) + + # special case + p = Symbol('p', prime=True) + k = Symbol('k', integer=True) + assert legendre_symbol(p*k, p) == 0 + assert legendre_symbol(1, p) == 1 + + # property + n = Symbol('n') + m = Symbol('m') + assert legendre_symbol(m, n).is_integer is True + assert legendre_symbol(m, n).is_prime is False + + # Integer + assert legendre_symbol(5, 11) == 1 + assert legendre_symbol(25, 41) == 1 + assert legendre_symbol(67, 101) == -1 + assert legendre_symbol(0, 13) == 0 + assert legendre_symbol(9, 3) == 0 + + +def test_jacobi_symbol(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: jacobi_symbol(m, 3)) + raises(TypeError, lambda: jacobi_symbol(4.5, 3)) + raises(TypeError, lambda: jacobi_symbol(1, m)) + raises(TypeError, lambda: jacobi_symbol(1, 4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: jacobi_symbol(1, m)) + raises(ValueError, lambda: jacobi_symbol(1, -6)) + m = Symbol('m', odd=False) + raises(ValueError, lambda: jacobi_symbol(1, m)) + raises(ValueError, lambda: jacobi_symbol(1, 2)) + + # special case + p = Symbol('p', integer=True) + k = Symbol('k', integer=True) + assert jacobi_symbol(p*k, p) == 0 + assert jacobi_symbol(1, p) == 1 + assert jacobi_symbol(1, 1) == 1 + assert jacobi_symbol(0, 1) == 1 + + # property + n = Symbol('n') + m = Symbol('m') + assert jacobi_symbol(m, n).is_integer is True + assert jacobi_symbol(m, n).is_prime is False + + # Integer + assert jacobi_symbol(25, 41) == 1 + assert jacobi_symbol(-23, 83) == -1 + assert jacobi_symbol(3, 9) == 0 + assert jacobi_symbol(42, 97) == -1 + assert jacobi_symbol(3, 5) == -1 + assert jacobi_symbol(7, 9) == 1 + assert jacobi_symbol(0, 3) == 0 + assert jacobi_symbol(0, 1) == 1 + assert jacobi_symbol(2, 1) == 1 + assert jacobi_symbol(1, 3) == 1 + + +def test_kronecker_symbol(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: kronecker_symbol(m, 3)) + raises(TypeError, lambda: kronecker_symbol(4.5, 3)) + raises(TypeError, lambda: kronecker_symbol(1, m)) + raises(TypeError, lambda: kronecker_symbol(1, 4.5)) + + # special case + p = Symbol('p', integer=True) + assert kronecker_symbol(1, p) == 1 + assert kronecker_symbol(1, 1) == 1 + assert kronecker_symbol(0, 1) == 1 + + # property + n = Symbol('n') + m = Symbol('m') + assert kronecker_symbol(m, n).is_integer is True + assert kronecker_symbol(m, n).is_prime is False + + # Integer + for n in range(3, 10, 2): + for a in range(-n, n): + val = kronecker_symbol(a, n) + assert val == jacobi_symbol(a, n) + minus = kronecker_symbol(a, -n) + if a < 0: + assert -minus == val + else: + assert minus == val + even = kronecker_symbol(a, 2 * n) + if a % 2 == 0: + assert even == 0 + elif a % 8 in [1, 7]: + assert even == val + else: + assert -even == val + assert kronecker_symbol(1, 0) == kronecker_symbol(-1, 0) == 1 + assert kronecker_symbol(0, 0) == 0 + + +def test_mobius(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: mobius(m)) + raises(TypeError, lambda: mobius(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: mobius(m)) + raises(ValueError, lambda: mobius(-3)) + + # special case + p = Symbol('p', prime=True) + assert mobius(p) == -1 + + # property + n = Symbol('n', integer=True, positive=True) + assert mobius(n).is_integer is True + assert mobius(n).is_prime is False + + # symbolic + n = Symbol('n', integer=True, positive=True) + k = Symbol('k', integer=True, positive=True) + assert mobius(n**2) == 0 + assert mobius(4*n) == 0 + assert isinstance(mobius(n**k), mobius) + assert mobius(n**(k+1)) == 0 + assert isinstance(mobius(3**k), mobius) + assert mobius(3**(k+1)) == 0 + m = Symbol('m') + assert isinstance(mobius(4*m), mobius) + + # Integer + assert mobius(13*7) == 1 + assert mobius(1) == 1 + assert mobius(13*7*5) == -1 + assert mobius(13**2) == 0 + A008683 = [1, -1, -1, 0, -1, 1, -1, 0, 0, 1, -1, 0, -1, 1, 1, 0, -1, 0, + -1, 0, 1, 1, -1, 0, 0, 1, 0, 0, -1, -1, -1, 0, 1, 1, 1, 0, -1, + 1, 1, 0, -1, -1, -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, -1, 0, 1, 0] + for n, val in enumerate(A008683, 1): + assert mobius(n) == val + + +def test_primenu(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: primenu(m)) + raises(TypeError, lambda: primenu(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: primenu(m)) + raises(ValueError, lambda: primenu(0)) + + # special case + p = Symbol('p', prime=True) + assert primenu(p) == 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert primenu(n).is_integer is True + assert primenu(n).is_nonnegative is True + + # Integer + assert primenu(7*13) == 2 + assert primenu(2*17*19) == 3 + assert primenu(2**3 * 17 * 19**2) == 3 + A001221 = [0, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, + 1, 2, 2, 2, 1, 2, 1, 2, 1, 2, 1, 3, 1, 1, 2, 2, 2, 2] + for n, val in enumerate(A001221, 1): + assert primenu(n) == val + + +def test_primeomega(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: primeomega(m)) + raises(TypeError, lambda: primeomega(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: primeomega(m)) + raises(ValueError, lambda: primeomega(0)) + + # special case + p = Symbol('p', prime=True) + assert primeomega(p) == 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert primeomega(n).is_integer is True + assert primeomega(n).is_nonnegative is True + + # Integer + assert primeomega(7*13) == 2 + assert primeomega(2*17*19) == 3 + assert primeomega(2**3 * 17 * 19**2) == 6 + A001222 = [0, 1, 1, 2, 1, 2, 1, 3, 2, 2, 1, 3, 1, 2, 2, 4, 1, 3, + 1, 3, 2, 2, 1, 4, 2, 2, 3, 3, 1, 3, 1, 5, 2, 2, 2, 4] + for n, val in enumerate(A001222, 1): + assert primeomega(n) == val + + +def test_totient(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: totient(m)) + raises(TypeError, lambda: totient(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: totient(m)) + raises(ValueError, lambda: totient(0)) + + # special case + p = Symbol('p', prime=True) + assert totient(p) == p - 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert totient(n).is_integer is True + assert totient(n).is_positive is True + + # Integer + assert totient(7*13) == totient(factorint(7*13)) == (7-1)*(13-1) + assert totient(2*17*19) == totient(factorint(2*17*19)) == (17-1)*(19-1) + assert totient(2**3 * 17 * 19**2) == totient({2: 3, 17: 1, 19: 2}) == 2**2 * (17-1) * 19*(19-1) + A000010 = [1, 1, 2, 2, 4, 2, 6, 4, 6, 4, 10, 4, 12, 6, 8, 8, 16, + 6, 18, 8, 12, 10, 22, 8, 20, 12, 18, 12, 28, 8, 30, 16, + 20, 16, 24, 12, 36, 18, 24, 16, 40, 12, 42, 20, 24, 22] + for n, val in enumerate(A000010, 1): + assert totient(n) == val + + +def test_reduced_totient(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: reduced_totient(m)) + raises(TypeError, lambda: reduced_totient(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: reduced_totient(m)) + raises(ValueError, lambda: reduced_totient(0)) + + # special case + p = Symbol('p', prime=True) + assert reduced_totient(p) == p - 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert reduced_totient(n).is_integer is True + assert reduced_totient(n).is_positive is True + + # Integer + assert reduced_totient(7*13) == reduced_totient(factorint(7*13)) == 12 + assert reduced_totient(2*17*19) == reduced_totient(factorint(2*17*19)) == 144 + assert reduced_totient(2**2 * 11) == reduced_totient({2: 2, 11: 1}) == 10 + assert reduced_totient(2**3 * 17 * 19**2) == reduced_totient({2: 3, 17: 1, 19: 2}) == 2736 + A002322 = [1, 1, 2, 2, 4, 2, 6, 2, 6, 4, 10, 2, 12, 6, 4, 4, 16, 6, + 18, 4, 6, 10, 22, 2, 20, 12, 18, 6, 28, 4, 30, 8, 10, 16, + 12, 6, 36, 18, 12, 4, 40, 6, 42, 10, 12, 22, 46, 4, 42] + for n, val in enumerate(A002322, 1): + assert reduced_totient(n) == val + + +def test_primepi(): + # error + z = Symbol('z', real=False) + raises(TypeError, lambda: primepi(z)) + raises(TypeError, lambda: primepi(I)) + + # property + n = Symbol('n', integer=True, positive=True) + assert primepi(n).is_integer is True + assert primepi(n).is_nonnegative is True + + # infinity + assert primepi(oo) == oo + assert primepi(-oo) == 0 + + # symbol + x = Symbol('x') + assert isinstance(primepi(x), primepi) + + # Integer + assert primepi(0) == 0 + A000720 = [0, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 8, + 8, 8, 8, 9, 9, 9, 9, 9, 9, 10, 10, 11, 11, 11, 11, 11, 11, + 12, 12, 12, 12, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15] + for n, val in enumerate(A000720, 1): + assert primepi(n) == primepi(n + 0.5) == val + + +def test__nT(): + assert [_nT(i, j) for i in range(5) for j in range(i + 2)] == [ + 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 2, 1, 1, 0] + check = [_nT(10, i) for i in range(11)] + assert check == [0, 1, 5, 8, 9, 7, 5, 3, 2, 1, 1] + assert all(type(i) is int for i in check) + assert _nT(10, 5) == 7 + assert _nT(100, 98) == 2 + assert _nT(100, 100) == 1 + assert _nT(10, 3) == 8 + + +def test_nC_nP_nT(): + from sympy.utilities.iterables import ( + multiset_permutations, multiset_combinations, multiset_partitions, + partitions, subsets, permutations) + from sympy.functions.combinatorial.numbers import ( + nP, nC, nT, stirling, _stirling1, _stirling2, _multiset_histogram, _AOP_product) + + from sympy.combinatorics.permutations import Permutation + from sympy.core.random import choice + + c = string.ascii_lowercase + for i in range(100): + s = ''.join(choice(c) for i in range(7)) + u = len(s) == len(set(s)) + try: + tot = 0 + for i in range(8): + check = nP(s, i) + tot += check + assert len(list(multiset_permutations(s, i))) == check + if u: + assert nP(len(s), i) == check + assert nP(s) == tot + except AssertionError: + print(s, i, 'failed perm test') + raise ValueError() + + for i in range(100): + s = ''.join(choice(c) for i in range(7)) + u = len(s) == len(set(s)) + try: + tot = 0 + for i in range(8): + check = nC(s, i) + tot += check + assert len(list(multiset_combinations(s, i))) == check + if u: + assert nC(len(s), i) == check + assert nC(s) == tot + if u: + assert nC(len(s)) == tot + except AssertionError: + print(s, i, 'failed combo test') + raise ValueError() + + for i in range(1, 10): + tot = 0 + for j in range(1, i + 2): + check = nT(i, j) + assert check.is_Integer + tot += check + assert sum(1 for p in partitions(i, j, size=True) if p[0] == j) == check + assert nT(i) == tot + + for i in range(1, 10): + tot = 0 + for j in range(1, i + 2): + check = nT(range(i), j) + tot += check + assert len(list(multiset_partitions(list(range(i)), j))) == check + assert nT(range(i)) == tot + + for i in range(100): + s = ''.join(choice(c) for i in range(7)) + u = len(s) == len(set(s)) + try: + tot = 0 + for i in range(1, 8): + check = nT(s, i) + tot += check + assert len(list(multiset_partitions(s, i))) == check + if u: + assert nT(range(len(s)), i) == check + if u: + assert nT(range(len(s))) == tot + assert nT(s) == tot + except AssertionError: + print(s, i, 'failed partition test') + raise ValueError() + + # tests for Stirling numbers of the first kind that are not tested in the + # above + assert [stirling(9, i, kind=1) for i in range(11)] == [ + 0, 40320, 109584, 118124, 67284, 22449, 4536, 546, 36, 1, 0] + perms = list(permutations(range(4))) + assert [sum(1 for p in perms if Permutation(p).cycles == i) + for i in range(5)] == [0, 6, 11, 6, 1] == [ + stirling(4, i, kind=1) for i in range(5)] + # http://oeis.org/A008275 + assert [stirling(n, k, signed=1) + for n in range(10) for k in range(1, n + 1)] == [ + 1, -1, + 1, 2, -3, + 1, -6, 11, -6, + 1, 24, -50, 35, -10, + 1, -120, 274, -225, 85, -15, + 1, 720, -1764, 1624, -735, 175, -21, + 1, -5040, 13068, -13132, 6769, -1960, 322, -28, + 1, 40320, -109584, 118124, -67284, 22449, -4536, 546, -36, 1] + # https://en.wikipedia.org/wiki/Stirling_numbers_of_the_first_kind + assert [stirling(n, k, kind=1) + for n in range(10) for k in range(n+1)] == [ + 1, + 0, 1, + 0, 1, 1, + 0, 2, 3, 1, + 0, 6, 11, 6, 1, + 0, 24, 50, 35, 10, 1, + 0, 120, 274, 225, 85, 15, 1, + 0, 720, 1764, 1624, 735, 175, 21, 1, + 0, 5040, 13068, 13132, 6769, 1960, 322, 28, 1, + 0, 40320, 109584, 118124, 67284, 22449, 4536, 546, 36, 1] + # https://en.wikipedia.org/wiki/Stirling_numbers_of_the_second_kind + assert [stirling(n, k, kind=2) + for n in range(10) for k in range(n+1)] == [ + 1, + 0, 1, + 0, 1, 1, + 0, 1, 3, 1, + 0, 1, 7, 6, 1, + 0, 1, 15, 25, 10, 1, + 0, 1, 31, 90, 65, 15, 1, + 0, 1, 63, 301, 350, 140, 21, 1, + 0, 1, 127, 966, 1701, 1050, 266, 28, 1, + 0, 1, 255, 3025, 7770, 6951, 2646, 462, 36, 1] + assert stirling(3, 4, kind=1) == stirling(3, 4, kind=1) == 0 + raises(ValueError, lambda: stirling(-2, 2)) + + # Assertion that the return type is SymPy Integer. + assert isinstance(_stirling1(6, 3), Integer) + assert isinstance(_stirling2(6, 3), Integer) + + def delta(p): + if len(p) == 1: + return oo + return min(abs(i[0] - i[1]) for i in subsets(p, 2)) + parts = multiset_partitions(range(5), 3) + d = 2 + assert (sum(1 for p in parts if all(delta(i) >= d for i in p)) == + stirling(5, 3, d=d) == 7) + + # other coverage tests + assert nC('abb', 2) == nC('aab', 2) == 2 + assert nP(3, 3, replacement=True) == nP('aabc', 3, replacement=True) == 27 + assert nP(3, 4) == 0 + assert nP('aabc', 5) == 0 + assert nC(4, 2, replacement=True) == nC('abcdd', 2, replacement=True) == \ + len(list(multiset_combinations('aabbccdd', 2))) == 10 + assert nC('abcdd') == sum(nC('abcdd', i) for i in range(6)) == 24 + assert nC(list('abcdd'), 4) == 4 + assert nT('aaaa') == nT(4) == len(list(partitions(4))) == 5 + assert nT('aaab') == len(list(multiset_partitions('aaab'))) == 7 + assert nC('aabb'*3, 3) == 4 # aaa, bbb, abb, baa + assert dict(_AOP_product((4,1,1,1))) == { + 0: 1, 1: 4, 2: 7, 3: 8, 4: 8, 5: 7, 6: 4, 7: 1} + # the following was the first t that showed a problem in a previous form of + # the function, so it's not as random as it may appear + t = (3, 9, 4, 6, 6, 5, 5, 2, 10, 4) + assert sum(_AOP_product(t)[i] for i in range(55)) == 58212000 + raises(ValueError, lambda: _multiset_histogram({1:'a'})) + + +def test_PR_14617(): + from sympy.functions.combinatorial.numbers import nT + for n in (0, []): + for k in (-1, 0, 1): + if k == 0: + assert nT(n, k) == 1 + else: + assert nT(n, k) == 0 + + +def test_issue_8496(): + n = Symbol("n") + k = Symbol("k") + + raises(TypeError, lambda: catalan(n, k)) + + +def test_issue_8601(): + n = Symbol('n', integer=True, negative=True) + + assert catalan(n - 1) is S.Zero + assert catalan(Rational(-1, 2)) is S.ComplexInfinity + assert catalan(-S.One) == Rational(-1, 2) + c1 = catalan(-5.6).evalf() + assert str(c1) == '6.93334070531408e-5' + c2 = catalan(-35.4).evalf() + assert str(c2) == '-4.14189164517449e-24' + + +def test_motzkin(): + assert motzkin.is_motzkin(4) == True + assert motzkin.is_motzkin(9) == True + assert motzkin.is_motzkin(10) == False + assert motzkin.find_motzkin_numbers_in_range(10,200) == [21, 51, 127] + assert motzkin.find_motzkin_numbers_in_range(10,400) == [21, 51, 127, 323] + assert motzkin.find_motzkin_numbers_in_range(10,1600) == [21, 51, 127, 323, 835] + assert motzkin.find_first_n_motzkins(5) == [1, 1, 2, 4, 9] + assert motzkin.find_first_n_motzkins(7) == [1, 1, 2, 4, 9, 21, 51] + assert motzkin.find_first_n_motzkins(10) == [1, 1, 2, 4, 9, 21, 51, 127, 323, 835] + raises(ValueError, lambda: motzkin.eval(77.58)) + raises(ValueError, lambda: motzkin.eval(-8)) + raises(ValueError, lambda: motzkin.find_motzkin_numbers_in_range(-2,7)) + raises(ValueError, lambda: motzkin.find_motzkin_numbers_in_range(13,7)) + raises(ValueError, lambda: motzkin.find_first_n_motzkins(112.8)) + + +def test_nD_derangements(): + from sympy.utilities.iterables import (partitions, multiset, + multiset_derangements, multiset_permutations) + from sympy.functions.combinatorial.numbers import nD + + got = [] + for i in partitions(8, k=4): + s = [] + it = 0 + for k, v in i.items(): + for i in range(v): + s.extend([it]*k) + it += 1 + ms = multiset(s) + c1 = sum(1 for i in multiset_permutations(s) if + all(i != j for i, j in zip(i, s))) + assert c1 == nD(ms) == nD(ms, 0) == nD(ms, 1) + v = [tuple(i) for i in multiset_derangements(s)] + c2 = len(v) + assert c2 == len(set(v)) + assert c1 == c2 + got.append(c1) + assert got == [1, 4, 6, 12, 24, 24, 61, 126, 315, 780, 297, 772, + 2033, 5430, 14833] + + assert nD('1112233456', brute=True) == nD('1112233456') == 16356 + assert nD('') == nD([]) == nD({}) == 0 + assert nD({1: 0}) == 0 + raises(ValueError, lambda: nD({1: -1})) + assert nD('112') == 0 + assert nD(i='112') == 0 + assert [nD(n=i) for i in range(6)] == [0, 0, 1, 2, 9, 44] + assert nD((i for i in range(4))) == nD('0123') == 9 + assert nD(m=(i for i in range(4))) == 3 + assert nD(m={0: 1, 1: 1, 2: 1, 3: 1}) == 3 + assert nD(m=[0, 1, 2, 3]) == 3 + raises(TypeError, lambda: nD(m=0)) + raises(TypeError, lambda: nD(-1)) + assert nD({-1: 1, -2: 1}) == 1 + assert nD(m={0: 3}) == 0 + raises(ValueError, lambda: nD(i='123', n=3)) + raises(ValueError, lambda: nD(i='123', m=(1,2))) + raises(ValueError, lambda: nD(n=0, m=(1,2))) + raises(ValueError, lambda: nD({1: -1})) + raises(ValueError, lambda: nD(m={-1: 1, 2: 1})) + raises(ValueError, lambda: nD(m={1: -1, 2: 1})) + raises(ValueError, lambda: nD(m=[-1, 2])) + raises(TypeError, lambda: nD({1: x})) + raises(TypeError, lambda: nD(m={1: x})) + raises(TypeError, lambda: nD(m={x: 1})) + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert not carmichael.is_carmichael(3) + with warns_deprecated_sympy(): + assert carmichael.find_carmichael_numbers_in_range(10, 20) == [] + with warns_deprecated_sympy(): + assert carmichael.find_first_n_carmichaels(1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a862bb68c3870db8bd81b74d3a4f1f73c687f010 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/beta_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/beta_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..754f95dbecb3a5ce18932a4f4c1f897faa4f61fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/beta_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/bsplines.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/bsplines.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f3aad24dd2e522f4ad5910734238a8d5e86cf0a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/bsplines.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/delta_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/delta_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73898ce7fa392f6cf3a548afe6c1c5965db58d7e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/delta_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/elliptic_integrals.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/elliptic_integrals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb2467033126ed6a58d84d1211a12bb26e4d168b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/elliptic_integrals.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/gamma_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/gamma_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aacf64ee9d87bdbde79c297b8b96ae54bac2014 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/gamma_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/hyper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/hyper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c2b83c02fc4d6df07c60eef70dcf75b4ba82863 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/hyper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/mathieu_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/mathieu_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a5188ea5c4b4f5e63ca282a8e30d21dcf184d47 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/mathieu_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/polynomials.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/polynomials.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d81afd5b71d6c211f2f23f9e52e1ed3ecd85acc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/polynomials.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/singularity_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/singularity_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18c8920b6017dc0d99ce9167f4560cd3270068b7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/singularity_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/spherical_harmonics.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/spherical_harmonics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd1cea51f32088ec117026607f8e19b638095556 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/spherical_harmonics.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/tensor_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/tensor_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d1e78be322758cd75b84fc86098a0e19e4b2cdb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/tensor_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/zeta_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/zeta_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3586896f990888fe81da9853460bc02422b0337 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/__pycache__/zeta_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7894a8deb3a4d9392d04c7425261b16d82017d0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/__pycache__/bench_special.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/__pycache__/bench_special.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e361292a18da018726134897d969e995d1dddca0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/__pycache__/bench_special.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/bench_special.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/bench_special.py new file mode 100644 index 0000000000000000000000000000000000000000..25d7280c2cf31dcbff08065a78847ed03e0ebb05 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/benchmarks/bench_special.py @@ -0,0 +1,8 @@ +from sympy.core.symbol import symbols +from sympy.functions.special.spherical_harmonics import Ynm + +x, y = symbols('x,y') + + +def timeit_Ynm_xy(): + Ynm(1, 1, x, y) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60567b803edcb76a24bae16597128800a62959ed Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_bessel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_bessel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516250a64161ce2386fa247f885a336fbd9c5da2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_bessel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_beta_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_beta_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f430b83ec5b4ea85cc5adbed719f2f455ebe3e69 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_beta_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_bsplines.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_bsplines.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e56ac5109ef0a65c928eb3eadb57653e417a1b5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_bsplines.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_delta_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_delta_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f66e0ce0766f54e39f2925ab1e352646573e671 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_delta_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_elliptic_integrals.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_elliptic_integrals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e480d23feb9c5073e8e7a995fee63bfec21eb66a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_elliptic_integrals.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_gamma_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_gamma_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31e128578ee2171cbfad1b44558d7e33eabcc35b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_gamma_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_hyper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_hyper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cac0ae97eee392c485edc09f2f1933c3e8c2e651 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_hyper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_mathieu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_mathieu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72c76405511f8f1af2018048ff50fc6076f5d625 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_mathieu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_singularity_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_singularity_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ea0c845f744671b7a0dd79d396ad4ef9fef99b3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_singularity_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_spec_polynomials.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_spec_polynomials.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..950b4f3454dcdccf6aed57c35531323c372079d3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_spec_polynomials.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_spherical_harmonics.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_spherical_harmonics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5c64a8ffbaa8bc28b4278d8fd7734f4ffff8f27 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_spherical_harmonics.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_tensor_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_tensor_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e868e0c07370e384726b780b020afd81ccadceea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_tensor_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_zeta_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_zeta_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..958f3bf531365448efccede5ccdb823e660c1777 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/__pycache__/test_zeta_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_beta_functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_beta_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..b34cb2febf9e2746d869cd878525d2794535aea5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_beta_functions.py @@ -0,0 +1,89 @@ +from sympy.core.function import (diff, expand_func) +from sympy.core.numbers import I, Rational, pi +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.combinatorial.numbers import catalan +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized) +from sympy.functions.special.gamma_functions import gamma, polygamma +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import Integral +from sympy.core.function import ArgumentIndexError +from sympy.core.expr import unchanged +from sympy.testing.pytest import raises + + +def test_beta(): + x, y = symbols('x y') + t = Dummy('t') + + assert unchanged(beta, x, y) + assert unchanged(beta, x, x) + + assert beta(5, -3).is_real == True + assert beta(3, y).is_real is None + + assert expand_func(beta(x, y)) == gamma(x)*gamma(y)/gamma(x + y) + assert expand_func(beta(x, y) - beta(y, x)) == 0 # Symmetric + assert expand_func(beta(x, y)) == expand_func(beta(x, y + 1) + beta(x + 1, y)).simplify() + + assert diff(beta(x, y), x) == beta(x, y)*(polygamma(0, x) - polygamma(0, x + y)) + assert diff(beta(x, y), y) == beta(x, y)*(polygamma(0, y) - polygamma(0, x + y)) + + assert conjugate(beta(x, y)) == beta(conjugate(x), conjugate(y)) + + raises(ArgumentIndexError, lambda: beta(x, y).fdiff(3)) + + assert beta(x, y).rewrite(gamma) == gamma(x)*gamma(y)/gamma(x + y) + assert beta(x).rewrite(gamma) == gamma(x)**2/gamma(2*x) + assert beta(x, y).rewrite(Integral).dummy_eq(Integral(t**(x - 1) * (1 - t)**(y - 1), (t, 0, 1))) + assert beta(Rational(-19, 10), Rational(-1, 10)) == S.Zero + assert beta(Rational(-19, 10), Rational(-9, 10)) == \ + 800*2**(S(4)/5)*sqrt(pi)*gamma(S.One/10)/(171*gamma(-S(7)/5)) + assert beta(Rational(19, 10), Rational(29, 10)) == 100/(551*catalan(Rational(19, 10))) + assert beta(1, 0) == S.ComplexInfinity + assert beta(0, 1) == S.ComplexInfinity + assert beta(2, 3) == S.One/12 + assert unchanged(beta, x, x + 1) + assert unchanged(beta, x, 1) + assert unchanged(beta, 1, y) + assert beta(x, x + 1).doit() == 1/(x*(x+1)*catalan(x)) + assert beta(1, y).doit() == 1/y + assert beta(x, 1).doit() == 1/x + assert beta(Rational(-19, 10), Rational(-1, 10), evaluate=False).doit() == S.Zero + assert beta(2) == beta(2, 2) + assert beta(x, evaluate=False) != beta(x, x) + assert beta(x, evaluate=False).doit() == beta(x, x) + + +def test_betainc(): + a, b, x1, x2 = symbols('a b x1 x2') + + assert unchanged(betainc, a, b, x1, x2) + assert unchanged(betainc, a, b, 0, x1) + + assert betainc(1, 2, 0, -5).is_real == True + assert betainc(1, 2, 0, x2).is_real is None + assert conjugate(betainc(I, 2, 3 - I, 1 + 4*I)) == betainc(-I, 2, 3 + I, 1 - 4*I) + + assert betainc(a, b, 0, 1).rewrite(Integral).dummy_eq(beta(a, b).rewrite(Integral)) + assert betainc(1, 2, 0, x2).rewrite(hyper) == x2*hyper((1, -1), (2,), x2) + + assert betainc(1, 2, 3, 3).evalf() == 0 + + +def test_betainc_regularized(): + a, b, x1, x2 = symbols('a b x1 x2') + + assert unchanged(betainc_regularized, a, b, x1, x2) + assert unchanged(betainc_regularized, a, b, 0, x1) + + assert betainc_regularized(3, 5, 0, -1).is_real == True + assert betainc_regularized(3, 5, 0, x2).is_real is None + assert conjugate(betainc_regularized(3*I, 1, 2 + I, 1 + 2*I)) == betainc_regularized(-3*I, 1, 2 - I, 1 - 2*I) + + assert betainc_regularized(a, b, 0, 1).rewrite(Integral) == 1 + assert betainc_regularized(1, 2, x1, x2).rewrite(hyper) == 2*x2*hyper((1, -1), (2,), x2) - 2*x1*hyper((1, -1), (2,), x1) + + assert betainc_regularized(4, 1, 5, 5).evalf() == 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_bsplines.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_bsplines.py new file mode 100644 index 0000000000000000000000000000000000000000..136831b96ba16c95edba12ecd47b6f1566b68427 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_bsplines.py @@ -0,0 +1,167 @@ +from sympy.functions import bspline_basis_set, interpolating_spline +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import And +from sympy.sets.sets import Interval +from sympy.testing.pytest import slow + +x, y = symbols('x,y') + + +def test_basic_degree_0(): + d = 0 + knots = range(5) + splines = bspline_basis_set(d, knots, x) + for i in range(len(splines)): + assert splines[i] == Piecewise((1, Interval(i, i + 1).contains(x)), + (0, True)) + + +def test_basic_degree_1(): + d = 1 + knots = range(5) + splines = bspline_basis_set(d, knots, x) + assert splines[0] == Piecewise((x, Interval(0, 1).contains(x)), + (2 - x, Interval(1, 2).contains(x)), + (0, True)) + assert splines[1] == Piecewise((-1 + x, Interval(1, 2).contains(x)), + (3 - x, Interval(2, 3).contains(x)), + (0, True)) + assert splines[2] == Piecewise((-2 + x, Interval(2, 3).contains(x)), + (4 - x, Interval(3, 4).contains(x)), + (0, True)) + + +def test_basic_degree_2(): + d = 2 + knots = range(5) + splines = bspline_basis_set(d, knots, x) + b0 = Piecewise((x**2/2, Interval(0, 1).contains(x)), + (Rational(-3, 2) + 3*x - x**2, Interval(1, 2).contains(x)), + (Rational(9, 2) - 3*x + x**2/2, Interval(2, 3).contains(x)), + (0, True)) + b1 = Piecewise((S.Half - x + x**2/2, Interval(1, 2).contains(x)), + (Rational(-11, 2) + 5*x - x**2, Interval(2, 3).contains(x)), + (8 - 4*x + x**2/2, Interval(3, 4).contains(x)), + (0, True)) + assert splines[0] == b0 + assert splines[1] == b1 + + +def test_basic_degree_3(): + d = 3 + knots = range(5) + splines = bspline_basis_set(d, knots, x) + b0 = Piecewise( + (x**3/6, Interval(0, 1).contains(x)), + (Rational(2, 3) - 2*x + 2*x**2 - x**3/2, Interval(1, 2).contains(x)), + (Rational(-22, 3) + 10*x - 4*x**2 + x**3/2, Interval(2, 3).contains(x)), + (Rational(32, 3) - 8*x + 2*x**2 - x**3/6, Interval(3, 4).contains(x)), + (0, True) + ) + assert splines[0] == b0 + + +def test_repeated_degree_1(): + d = 1 + knots = [0, 0, 1, 2, 2, 3, 4, 4] + splines = bspline_basis_set(d, knots, x) + assert splines[0] == Piecewise((1 - x, Interval(0, 1).contains(x)), + (0, True)) + assert splines[1] == Piecewise((x, Interval(0, 1).contains(x)), + (2 - x, Interval(1, 2).contains(x)), + (0, True)) + assert splines[2] == Piecewise((-1 + x, Interval(1, 2).contains(x)), + (0, True)) + assert splines[3] == Piecewise((3 - x, Interval(2, 3).contains(x)), + (0, True)) + assert splines[4] == Piecewise((-2 + x, Interval(2, 3).contains(x)), + (4 - x, Interval(3, 4).contains(x)), + (0, True)) + assert splines[5] == Piecewise((-3 + x, Interval(3, 4).contains(x)), + (0, True)) + + +def test_repeated_degree_2(): + d = 2 + knots = [0, 0, 1, 2, 2, 3, 4, 4] + splines = bspline_basis_set(d, knots, x) + + assert splines[0] == Piecewise(((-3*x**2/2 + 2*x), And(x <= 1, x >= 0)), + (x**2/2 - 2*x + 2, And(x <= 2, x >= 1)), + (0, True)) + assert splines[1] == Piecewise((x**2/2, And(x <= 1, x >= 0)), + (-3*x**2/2 + 4*x - 2, And(x <= 2, x >= 1)), + (0, True)) + assert splines[2] == Piecewise((x**2 - 2*x + 1, And(x <= 2, x >= 1)), + (x**2 - 6*x + 9, And(x <= 3, x >= 2)), + (0, True)) + assert splines[3] == Piecewise((-3*x**2/2 + 8*x - 10, And(x <= 3, x >= 2)), + (x**2/2 - 4*x + 8, And(x <= 4, x >= 3)), + (0, True)) + assert splines[4] == Piecewise((x**2/2 - 2*x + 2, And(x <= 3, x >= 2)), + (-3*x**2/2 + 10*x - 16, And(x <= 4, x >= 3)), + (0, True)) + +# Tests for interpolating_spline + + +def test_10_points_degree_1(): + d = 1 + X = [-5, 2, 3, 4, 7, 9, 10, 30, 31, 34] + Y = [-10, -2, 2, 4, 7, 6, 20, 45, 19, 25] + spline = interpolating_spline(d, x, X, Y) + + assert spline == Piecewise((x*Rational(8, 7) - Rational(30, 7), (x >= -5) & (x <= 2)), (4*x - 10, (x >= 2) & (x <= 3)), + (2*x - 4, (x >= 3) & (x <= 4)), (x, (x >= 4) & (x <= 7)), + (-x/2 + Rational(21, 2), (x >= 7) & (x <= 9)), (14*x - 120, (x >= 9) & (x <= 10)), + (x*Rational(5, 4) + Rational(15, 2), (x >= 10) & (x <= 30)), (-26*x + 825, (x >= 30) & (x <= 31)), + (2*x - 43, (x >= 31) & (x <= 34))) + + +def test_3_points_degree_2(): + d = 2 + X = [-3, 10, 19] + Y = [3, -4, 30] + spline = interpolating_spline(d, x, X, Y) + + assert spline == Piecewise((505*x**2/2574 - x*Rational(4921, 2574) - Rational(1931, 429), (x >= -3) & (x <= 19))) + + +def test_5_points_degree_2(): + d = 2 + X = [-3, 2, 4, 5, 10] + Y = [-1, 2, 5, 10, 14] + spline = interpolating_spline(d, x, X, Y) + + assert spline == Piecewise((4*x**2/329 + x*Rational(1007, 1645) + Rational(1196, 1645), (x >= -3) & (x <= 3)), + (2701*x**2/1645 - x*Rational(15079, 1645) + Rational(5065, 329), (x >= 3) & (x <= Rational(9, 2))), + (-1319*x**2/1645 + x*Rational(21101, 1645) - Rational(11216, 329), (x >= Rational(9, 2)) & (x <= 10))) + + +@slow +def test_6_points_degree_3(): + d = 3 + X = [-1, 0, 2, 3, 9, 12] + Y = [-4, 3, 3, 7, 9, 20] + spline = interpolating_spline(d, x, X, Y) + + assert spline == Piecewise((6058*x**3/5301 - 18427*x**2/5301 + x*Rational(12622, 5301) + 3, (x >= -1) & (x <= 2)), + (-8327*x**3/5301 + 67883*x**2/5301 - x*Rational(159998, 5301) + Rational(43661, 1767), (x >= 2) & (x <= 3)), + (5414*x**3/47709 - 1386*x**2/589 + x*Rational(4267, 279) - Rational(12232, 589), (x >= 3) & (x <= 12))) + + +def test_issue_19262(): + Delta = symbols('Delta', positive=True) + knots = [i*Delta for i in range(4)] + basis = bspline_basis_set(1, knots, x) + y = symbols('y', nonnegative=True) + basis2 = bspline_basis_set(1, knots, y) + assert basis[0].subs(x, y) == basis2[0] + assert interpolating_spline(1, x, + [Delta*i for i in [1, 2, 4, 7]], [3, 6, 5, 7] + ) == Piecewise((3*x/Delta, (Delta <= x) & (x <= 2*Delta)), + (7 - x/(2*Delta), (x >= 2*Delta) & (x <= 4*Delta)), + (Rational(7, 3) + 2*x/(3*Delta), (x >= 4*Delta) & (x <= 7*Delta))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_delta_functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_delta_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a39d9e352143cf878cf69fa42f454f58be65c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_delta_functions.py @@ -0,0 +1,165 @@ +from sympy.core.numbers import (I, nan, oo, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (adjoint, conjugate, sign, transpose) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.simplify.simplify import signsimp + + +from sympy.testing.pytest import raises + +from sympy.core.expr import unchanged + +from sympy.core.function import ArgumentIndexError + + +x, y = symbols('x y') +i = symbols('t', nonzero=True) +j = symbols('j', positive=True) +k = symbols('k', negative=True) + +def test_DiracDelta(): + assert DiracDelta(1) == 0 + assert DiracDelta(5.1) == 0 + assert DiracDelta(-pi) == 0 + assert DiracDelta(5, 7) == 0 + assert DiracDelta(x, 0) == DiracDelta(x) + assert DiracDelta(i) == 0 + assert DiracDelta(j) == 0 + assert DiracDelta(k) == 0 + assert DiracDelta(nan) is nan + assert DiracDelta(0).func is DiracDelta + assert DiracDelta(x).func is DiracDelta + # FIXME: this is generally undefined @ x=0 + # But then limit(Delta(c)*Heaviside(x),x,-oo) + # need's to be implemented. + # assert 0*DiracDelta(x) == 0 + + assert adjoint(DiracDelta(x)) == DiracDelta(x) + assert adjoint(DiracDelta(x - y)) == DiracDelta(x - y) + assert conjugate(DiracDelta(x)) == DiracDelta(x) + assert conjugate(DiracDelta(x - y)) == DiracDelta(x - y) + assert transpose(DiracDelta(x)) == DiracDelta(x) + assert transpose(DiracDelta(x - y)) == DiracDelta(x - y) + + assert DiracDelta(x).diff(x) == DiracDelta(x, 1) + assert DiracDelta(x, 1).diff(x) == DiracDelta(x, 2) + + assert DiracDelta(x).is_simple(x) is True + assert DiracDelta(3*x).is_simple(x) is True + assert DiracDelta(x**2).is_simple(x) is False + assert DiracDelta(sqrt(x)).is_simple(x) is False + assert DiracDelta(x).is_simple(y) is False + + assert DiracDelta(x*y).expand(diracdelta=True, wrt=x) == DiracDelta(x)/abs(y) + assert DiracDelta(x*y).expand(diracdelta=True, wrt=y) == DiracDelta(y)/abs(x) + assert DiracDelta(x**2*y).expand(diracdelta=True, wrt=x) == DiracDelta(x**2*y) + assert DiracDelta(y).expand(diracdelta=True, wrt=x) == DiracDelta(y) + assert DiracDelta((x - 1)*(x - 2)*(x - 3)).expand(diracdelta=True, wrt=x) == ( + DiracDelta(x - 3)/2 + DiracDelta(x - 2) + DiracDelta(x - 1)/2) + + assert DiracDelta(2*x) != DiracDelta(x) # scaling property + assert DiracDelta(x) == DiracDelta(-x) # even function + assert DiracDelta(-x, 2) == DiracDelta(x, 2) + assert DiracDelta(-x, 1) == -DiracDelta(x, 1) # odd deriv is odd + assert DiracDelta(-oo*x) == DiracDelta(oo*x) + assert DiracDelta(x - y) != DiracDelta(y - x) + assert signsimp(DiracDelta(x - y) - DiracDelta(y - x)) == 0 + + assert DiracDelta(x*y).expand(diracdelta=True, wrt=x) == DiracDelta(x)/abs(y) + assert DiracDelta(x*y).expand(diracdelta=True, wrt=y) == DiracDelta(y)/abs(x) + assert DiracDelta(x**2*y).expand(diracdelta=True, wrt=x) == DiracDelta(x**2*y) + assert DiracDelta(y).expand(diracdelta=True, wrt=x) == DiracDelta(y) + assert DiracDelta((x - 1)*(x - 2)*(x - 3)).expand(diracdelta=True) == ( + DiracDelta(x - 3)/2 + DiracDelta(x - 2) + DiracDelta(x - 1)/2) + + raises(ArgumentIndexError, lambda: DiracDelta(x).fdiff(2)) + raises(ValueError, lambda: DiracDelta(x, -1)) + raises(ValueError, lambda: DiracDelta(I)) + raises(ValueError, lambda: DiracDelta(2 + 3*I)) + + +def test_heaviside(): + assert Heaviside(-5) == 0 + assert Heaviside(1) == 1 + assert Heaviside(0) == S.Half + + assert Heaviside(0, x) == x + assert unchanged(Heaviside,x, nan) + assert Heaviside(0, nan) == nan + + h0 = Heaviside(x, 0) + h12 = Heaviside(x, S.Half) + h1 = Heaviside(x, 1) + + assert h0.args == h0.pargs == (x, 0) + assert h1.args == h1.pargs == (x, 1) + assert h12.args == (x, S.Half) + assert h12.pargs == (x,) # default 1/2 suppressed + + assert adjoint(Heaviside(x)) == Heaviside(x) + assert adjoint(Heaviside(x - y)) == Heaviside(x - y) + assert conjugate(Heaviside(x)) == Heaviside(x) + assert conjugate(Heaviside(x - y)) == Heaviside(x - y) + assert transpose(Heaviside(x)) == Heaviside(x) + assert transpose(Heaviside(x - y)) == Heaviside(x - y) + + assert Heaviside(x).diff(x) == DiracDelta(x) + assert Heaviside(x + I).is_Function is True + assert Heaviside(I*x).is_Function is True + + raises(ArgumentIndexError, lambda: Heaviside(x).fdiff(2)) + raises(ValueError, lambda: Heaviside(I)) + raises(ValueError, lambda: Heaviside(2 + 3*I)) + + +def test_rewrite(): + x, y = Symbol('x', real=True), Symbol('y') + assert Heaviside(x).rewrite(Piecewise) == ( + Piecewise((0, x < 0), (Heaviside(0), Eq(x, 0)), (1, True))) + assert Heaviside(y).rewrite(Piecewise) == ( + Piecewise((0, y < 0), (Heaviside(0), Eq(y, 0)), (1, True))) + assert Heaviside(x, y).rewrite(Piecewise) == ( + Piecewise((0, x < 0), (y, Eq(x, 0)), (1, True))) + assert Heaviside(x, 0).rewrite(Piecewise) == ( + Piecewise((0, x <= 0), (1, True))) + assert Heaviside(x, 1).rewrite(Piecewise) == ( + Piecewise((0, x < 0), (1, True))) + assert Heaviside(x, nan).rewrite(Piecewise) == ( + Piecewise((0, x < 0), (nan, Eq(x, 0)), (1, True))) + + assert Heaviside(x).rewrite(sign) == \ + Heaviside(x, H0=Heaviside(0)).rewrite(sign) == \ + Piecewise( + (sign(x)/2 + S(1)/2, Eq(Heaviside(0), S(1)/2)), + (Piecewise( + (sign(x)/2 + S(1)/2, Ne(x, 0)), (Heaviside(0), True)), True) + ) + + assert Heaviside(y).rewrite(sign) == Heaviside(y) + assert Heaviside(x, S.Half).rewrite(sign) == (sign(x)+1)/2 + assert Heaviside(x, y).rewrite(sign) == \ + Piecewise( + (sign(x)/2 + S(1)/2, Eq(y, S(1)/2)), + (Piecewise( + (sign(x)/2 + S(1)/2, Ne(x, 0)), (y, True)), True) + ) + + assert DiracDelta(y).rewrite(Piecewise) == Piecewise((DiracDelta(0), Eq(y, 0)), (0, True)) + assert DiracDelta(y, 1).rewrite(Piecewise) == DiracDelta(y, 1) + assert DiracDelta(x - 5).rewrite(Piecewise) == ( + Piecewise((DiracDelta(0), Eq(x - 5, 0)), (0, True))) + + assert (x*DiracDelta(x - 10)).rewrite(SingularityFunction) == x*SingularityFunction(x, 10, -1) + assert 5*x*y*DiracDelta(y, 1).rewrite(SingularityFunction) == 5*x*y*SingularityFunction(y, 0, -2) + assert DiracDelta(0).rewrite(SingularityFunction) == SingularityFunction(0, 0, -1) + assert DiracDelta(0, 1).rewrite(SingularityFunction) == SingularityFunction(0, 0, -2) + + assert Heaviside(x).rewrite(SingularityFunction) == SingularityFunction(x, 0, 0) + assert 5*x*y*Heaviside(y + 1).rewrite(SingularityFunction) == 5*x*y*SingularityFunction(y, -1, 0) + assert ((x - 3)**3*Heaviside(x - 3)).rewrite(SingularityFunction) == (x - 3)**3*SingularityFunction(x, 3, 0) + assert Heaviside(0).rewrite(SingularityFunction) == S.Half diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_hyper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_hyper.py new file mode 100644 index 0000000000000000000000000000000000000000..f1be5b5f0db158ff76173e180ed8d88bd59461b9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_hyper.py @@ -0,0 +1,403 @@ +from sympy.core.containers import Tuple +from sympy.core.function import Derivative +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import (appellf1, hyper, meijerg) +from sympy.series.order import O +from sympy.abc import x, z, k +from sympy.series.limits import limit +from sympy.testing.pytest import raises, slow +from sympy.core.random import ( + random_complex_number as randcplx, + verify_numerically as tn, + test_derivative_numerically as td) + + +def test_TupleParametersBase(): + # test that our implementation of the chain rule works + p = hyper((), (), z**2) + assert p.diff(z) == p*2*z + + +def test_hyper(): + raises(TypeError, lambda: hyper(1, 2, z)) + + assert hyper((2, 1), (1,), z) == hyper(Tuple(1, 2), Tuple(1), z) + assert hyper((2, 1, 2), (1, 2, 1, 3), z) == hyper((2,), (1, 3), z) + u = hyper((2, 1, 2), (1, 2, 1, 3), z, evaluate=False) + assert u.ap == Tuple(1, 2, 2) + assert u.bq == Tuple(1, 1, 2, 3) + + h = hyper((1, 2), (3, 4, 5), z) + assert h.ap == Tuple(1, 2) + assert h.bq == Tuple(3, 4, 5) + assert h.argument == z + assert h.is_commutative is True + h = hyper((2, 1), (4, 3, 5), z) + assert h.ap == Tuple(1, 2) + assert h.bq == Tuple(3, 4, 5) + assert h.argument == z + assert h.is_commutative is True + + # just a few checks to make sure that all arguments go where they should + assert tn(hyper(Tuple(), Tuple(), z), exp(z), z) + assert tn(z*hyper((1, 1), Tuple(2), -z), log(1 + z), z) + + # differentiation + h = hyper( + (randcplx(), randcplx(), randcplx()), (randcplx(), randcplx()), z) + assert td(h, z) + + a1, a2, b1, b2, b3 = symbols('a1:3, b1:4') + assert hyper((a1, a2), (b1, b2, b3), z).diff(z) == \ + a1*a2/(b1*b2*b3) * hyper((a1 + 1, a2 + 1), (b1 + 1, b2 + 1, b3 + 1), z) + + # differentiation wrt parameters is not supported + assert hyper([z], [], z).diff(z) == Derivative(hyper([z], [], z), z) + + # hyper is unbranched wrt parameters + from sympy.functions.elementary.complexes import polar_lift + assert hyper([polar_lift(z)], [polar_lift(k)], polar_lift(x)) == \ + hyper([z], [k], polar_lift(x)) + + # hyper does not automatically evaluate anyway, but the test is to make + # sure that the evaluate keyword is accepted + assert hyper((1, 2), (1,), z, evaluate=False).func is hyper + + +def test_expand_func(): + # evaluation at 1 of Gauss' hypergeometric function: + from sympy.abc import a, b, c + from sympy.core.function import expand_func + a1, b1, c1 = randcplx(), randcplx(), randcplx() + 5 + assert expand_func(hyper([a, b], [c], 1)) == \ + gamma(c)*gamma(-a - b + c)/(gamma(-a + c)*gamma(-b + c)) + assert abs(expand_func(hyper([a1, b1], [c1], 1)).n() + - hyper([a1, b1], [c1], 1).n()) < 1e-10 + + # hyperexpand wrapper for hyper: + assert expand_func(hyper([], [], z)) == exp(z) + assert expand_func(hyper([1, 2, 3], [], z)) == hyper([1, 2, 3], [], z) + assert expand_func(meijerg([[1, 1], []], [[1], [0]], z)) == log(z + 1) + assert expand_func(meijerg([[1, 1], []], [[], []], z)) == \ + meijerg([[1, 1], []], [[], []], z) + + +def replace_dummy(expr, sym): + from sympy.core.symbol import Dummy + dum = expr.atoms(Dummy) + if not dum: + return expr + assert len(dum) == 1 + return expr.xreplace({dum.pop(): sym}) + + +def test_hyper_rewrite_sum(): + from sympy.concrete.summations import Sum + from sympy.core.symbol import Dummy + from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial) + _k = Dummy("k") + assert replace_dummy(hyper((1, 2), (1, 3), x).rewrite(Sum), _k) == \ + Sum(x**_k / factorial(_k) * RisingFactorial(2, _k) / + RisingFactorial(3, _k), (_k, 0, oo)) + + assert hyper((1, 2, 3), (-1, 3), z).rewrite(Sum) == \ + hyper((1, 2, 3), (-1, 3), z) + + +def test_radius_of_convergence(): + assert hyper((1, 2), [3], z).radius_of_convergence == 1 + assert hyper((1, 2), [3, 4], z).radius_of_convergence is oo + assert hyper((1, 2, 3), [4], z).radius_of_convergence == 0 + assert hyper((0, 1, 2), [4], z).radius_of_convergence is oo + assert hyper((-1, 1, 2), [-4], z).radius_of_convergence == 0 + assert hyper((-1, -2, 2), [-1], z).radius_of_convergence is oo + assert hyper((-1, 2), [-1, -2], z).radius_of_convergence == 0 + assert hyper([-1, 1, 3], [-2, 2], z).radius_of_convergence == 1 + assert hyper([-1, 1], [-2, 2], z).radius_of_convergence is oo + assert hyper([-1, 1, 3], [-2], z).radius_of_convergence == 0 + assert hyper((-1, 2, 3, 4), [], z).radius_of_convergence is oo + + assert hyper([1, 1], [3], 1).convergence_statement == True + assert hyper([1, 1], [2], 1).convergence_statement == False + assert hyper([1, 1], [2], -1).convergence_statement == True + assert hyper([1, 1], [1], -1).convergence_statement == False + + +def test_meijer(): + raises(TypeError, lambda: meijerg(1, z)) + raises(TypeError, lambda: meijerg(((1,), (2,)), (3,), (4,), z)) + + assert meijerg(((1, 2), (3,)), ((4,), (5,)), z) == \ + meijerg(Tuple(1, 2), Tuple(3), Tuple(4), Tuple(5), z) + + g = meijerg((1, 2), (3, 4, 5), (6, 7, 8, 9), (10, 11, 12, 13, 14), z) + assert g.an == Tuple(1, 2) + assert g.ap == Tuple(1, 2, 3, 4, 5) + assert g.aother == Tuple(3, 4, 5) + assert g.bm == Tuple(6, 7, 8, 9) + assert g.bq == Tuple(6, 7, 8, 9, 10, 11, 12, 13, 14) + assert g.bother == Tuple(10, 11, 12, 13, 14) + assert g.argument == z + assert g.nu == 75 + assert g.delta == -1 + assert g.is_commutative is True + assert g.is_number is False + #issue 13071 + assert meijerg([[],[]], [[S.Half],[0]], 1).is_number is True + + assert meijerg([1, 2], [3], [4], [5], z).delta == S.Half + + # just a few checks to make sure that all arguments go where they should + assert tn(meijerg(Tuple(), Tuple(), Tuple(0), Tuple(), -z), exp(z), z) + assert tn(sqrt(pi)*meijerg(Tuple(), Tuple(), + Tuple(0), Tuple(S.Half), z**2/4), cos(z), z) + assert tn(meijerg(Tuple(1, 1), Tuple(), Tuple(1), Tuple(0), z), + log(1 + z), z) + + # test exceptions + raises(ValueError, lambda: meijerg(((3, 1), (2,)), ((oo,), (2, 0)), x)) + raises(ValueError, lambda: meijerg(((3, 1), (2,)), ((1,), (2, 0)), x)) + + # differentiation + g = meijerg((randcplx(),), (randcplx() + 2*I,), Tuple(), + (randcplx(), randcplx()), z) + assert td(g, z) + + g = meijerg(Tuple(), (randcplx(),), Tuple(), + (randcplx(), randcplx()), z) + assert td(g, z) + + g = meijerg(Tuple(), Tuple(), Tuple(randcplx()), + Tuple(randcplx(), randcplx()), z) + assert td(g, z) + + a1, a2, b1, b2, c1, c2, d1, d2 = symbols('a1:3, b1:3, c1:3, d1:3') + assert meijerg((a1, a2), (b1, b2), (c1, c2), (d1, d2), z).diff(z) == \ + (meijerg((a1 - 1, a2), (b1, b2), (c1, c2), (d1, d2), z) + + (a1 - 1)*meijerg((a1, a2), (b1, b2), (c1, c2), (d1, d2), z))/z + + assert meijerg([z, z], [], [], [], z).diff(z) == \ + Derivative(meijerg([z, z], [], [], [], z), z) + + # meijerg is unbranched wrt parameters + from sympy.functions.elementary.complexes import polar_lift as pl + assert meijerg([pl(a1)], [pl(a2)], [pl(b1)], [pl(b2)], pl(z)) == \ + meijerg([a1], [a2], [b1], [b2], pl(z)) + + # integrand + from sympy.abc import a, b, c, d, s + assert meijerg([a], [b], [c], [d], z).integrand(s) == \ + z**s*gamma(c - s)*gamma(-a + s + 1)/(gamma(b - s)*gamma(-d + s + 1)) + + +def test_meijerg_derivative(): + assert meijerg([], [1, 1], [0, 0, x], [], z).diff(x) == \ + log(z)*meijerg([], [1, 1], [0, 0, x], [], z) \ + + 2*meijerg([], [1, 1, 1], [0, 0, x, 0], [], z) + + y = randcplx() + a = 5 # mpmath chokes with non-real numbers, and Mod1 with floats + assert td(meijerg([x], [], [], [], y), x) + assert td(meijerg([x**2], [], [], [], y), x) + assert td(meijerg([], [x], [], [], y), x) + assert td(meijerg([], [], [x], [], y), x) + assert td(meijerg([], [], [], [x], y), x) + assert td(meijerg([x], [a], [a + 1], [], y), x) + assert td(meijerg([x], [a + 1], [a], [], y), x) + assert td(meijerg([x, a], [], [], [a + 1], y), x) + assert td(meijerg([x, a + 1], [], [], [a], y), x) + b = Rational(3, 2) + assert td(meijerg([a + 2], [b], [b - 3, x], [a], y), x) + + +def test_meijerg_period(): + assert meijerg([], [1], [0], [], x).get_period() == 2*pi + assert meijerg([1], [], [], [0], x).get_period() == 2*pi + assert meijerg([], [], [0], [], x).get_period() == 2*pi # exp(x) + assert meijerg( + [], [], [0], [S.Half], x).get_period() == 2*pi # cos(sqrt(x)) + assert meijerg( + [], [], [S.Half], [0], x).get_period() == 4*pi # sin(sqrt(x)) + assert meijerg([1, 1], [], [1], [0], x).get_period() is oo # log(1 + x) + + +def test_hyper_unpolarify(): + from sympy.functions.elementary.exponential import exp_polar + a = exp_polar(2*pi*I)*x + b = x + assert hyper([], [], a).argument == b + assert hyper([0], [], a).argument == a + assert hyper([0], [0], a).argument == b + assert hyper([0, 1], [0], a).argument == a + assert hyper([0, 1], [0], exp_polar(2*pi*I)).argument == 1 + + +@slow +def test_hyperrep(): + from sympy.functions.special.hyper import (HyperRep, HyperRep_atanh, + HyperRep_power1, HyperRep_power2, HyperRep_log1, HyperRep_asin1, + HyperRep_asin2, HyperRep_sqrts1, HyperRep_sqrts2, HyperRep_log2, + HyperRep_cosasin, HyperRep_sinasin) + # First test the base class works. + from sympy.functions.elementary.exponential import exp_polar + from sympy.functions.elementary.piecewise import Piecewise + a, b, c, d, z = symbols('a b c d z') + + class myrep(HyperRep): + @classmethod + def _expr_small(cls, x): + return a + + @classmethod + def _expr_small_minus(cls, x): + return b + + @classmethod + def _expr_big(cls, x, n): + return c*n + + @classmethod + def _expr_big_minus(cls, x, n): + return d*n + assert myrep(z).rewrite('nonrep') == Piecewise((0, abs(z) > 1), (a, True)) + assert myrep(exp_polar(I*pi)*z).rewrite('nonrep') == \ + Piecewise((0, abs(z) > 1), (b, True)) + assert myrep(exp_polar(2*I*pi)*z).rewrite('nonrep') == \ + Piecewise((c, abs(z) > 1), (a, True)) + assert myrep(exp_polar(3*I*pi)*z).rewrite('nonrep') == \ + Piecewise((d, abs(z) > 1), (b, True)) + assert myrep(exp_polar(4*I*pi)*z).rewrite('nonrep') == \ + Piecewise((2*c, abs(z) > 1), (a, True)) + assert myrep(exp_polar(5*I*pi)*z).rewrite('nonrep') == \ + Piecewise((2*d, abs(z) > 1), (b, True)) + assert myrep(z).rewrite('nonrepsmall') == a + assert myrep(exp_polar(I*pi)*z).rewrite('nonrepsmall') == b + + def t(func, hyp, z): + """ Test that func is a valid representation of hyp. """ + # First test that func agrees with hyp for small z + if not tn(func.rewrite('nonrepsmall'), hyp, z, + a=Rational(-1, 2), b=Rational(-1, 2), c=S.Half, d=S.Half): + return False + # Next check that the two small representations agree. + if not tn( + func.rewrite('nonrepsmall').subs( + z, exp_polar(I*pi)*z).replace(exp_polar, exp), + func.subs(z, exp_polar(I*pi)*z).rewrite('nonrepsmall'), + z, a=Rational(-1, 2), b=Rational(-1, 2), c=S.Half, d=S.Half): + return False + # Next check continuity along exp_polar(I*pi)*t + expr = func.subs(z, exp_polar(I*pi)*z).rewrite('nonrep') + if abs(expr.subs(z, 1 + 1e-15).n() - expr.subs(z, 1 - 1e-15).n()) > 1e-10: + return False + # Finally check continuity of the big reps. + + def dosubs(func, a, b): + rv = func.subs(z, exp_polar(a)*z).rewrite('nonrep') + return rv.subs(z, exp_polar(b)*z).replace(exp_polar, exp) + for n in [0, 1, 2, 3, 4, -1, -2, -3, -4]: + expr1 = dosubs(func, 2*I*pi*n, I*pi/2) + expr2 = dosubs(func, 2*I*pi*n + I*pi, -I*pi/2) + if not tn(expr1, expr2, z): + return False + expr1 = dosubs(func, 2*I*pi*(n + 1), -I*pi/2) + expr2 = dosubs(func, 2*I*pi*n + I*pi, I*pi/2) + if not tn(expr1, expr2, z): + return False + return True + + # Now test the various representatives. + a = Rational(1, 3) + assert t(HyperRep_atanh(z), hyper([S.Half, 1], [Rational(3, 2)], z), z) + assert t(HyperRep_power1(a, z), hyper([-a], [], z), z) + assert t(HyperRep_power2(a, z), hyper([a, a - S.Half], [2*a], z), z) + assert t(HyperRep_log1(z), -z*hyper([1, 1], [2], z), z) + assert t(HyperRep_asin1(z), hyper([S.Half, S.Half], [Rational(3, 2)], z), z) + assert t(HyperRep_asin2(z), hyper([1, 1], [Rational(3, 2)], z), z) + assert t(HyperRep_sqrts1(a, z), hyper([-a, S.Half - a], [S.Half], z), z) + assert t(HyperRep_sqrts2(a, z), + -2*z/(2*a + 1)*hyper([-a - S.Half, -a], [S.Half], z).diff(z), z) + assert t(HyperRep_log2(z), -z/4*hyper([Rational(3, 2), 1, 1], [2, 2], z), z) + assert t(HyperRep_cosasin(a, z), hyper([-a, a], [S.Half], z), z) + assert t(HyperRep_sinasin(a, z), 2*a*z*hyper([1 - a, 1 + a], [Rational(3, 2)], z), z) + + +@slow +def test_meijerg_eval(): + from sympy.functions.elementary.exponential import exp_polar + from sympy.functions.special.bessel import besseli + from sympy.abc import l + a = randcplx() + arg = x*exp_polar(k*pi*I) + expr1 = pi*meijerg([[], [(a + 1)/2]], [[a/2], [-a/2, (a + 1)/2]], arg**2/4) + expr2 = besseli(a, arg) + + # Test that the two expressions agree for all arguments. + for x_ in [0.5, 1.5]: + for k_ in [0.0, 0.1, 0.3, 0.5, 0.8, 1, 5.751, 15.3]: + assert abs((expr1 - expr2).n(subs={x: x_, k: k_})) < 1e-10 + assert abs((expr1 - expr2).n(subs={x: x_, k: -k_})) < 1e-10 + + # Test continuity independently + eps = 1e-13 + expr2 = expr1.subs(k, l) + for x_ in [0.5, 1.5]: + for k_ in [0.5, Rational(1, 3), 0.25, 0.75, Rational(2, 3), 1.0, 1.5]: + assert abs((expr1 - expr2).n( + subs={x: x_, k: k_ + eps, l: k_ - eps})) < 1e-10 + assert abs((expr1 - expr2).n( + subs={x: x_, k: -k_ + eps, l: -k_ - eps})) < 1e-10 + + expr = (meijerg(((0.5,), ()), ((0.5, 0, 0.5), ()), exp_polar(-I*pi)/4) + + meijerg(((0.5,), ()), ((0.5, 0, 0.5), ()), exp_polar(I*pi)/4)) \ + /(2*sqrt(pi)) + assert (expr - pi/exp(1)).n(chop=True) == 0 + + +def test_limits(): + k, x = symbols('k, x') + assert hyper((1,), (Rational(4, 3), Rational(5, 3)), k**2).series(k) == \ + 1 + 9*k**2/20 + 81*k**4/1120 + O(k**6) # issue 6350 + + # https://github.com/sympy/sympy/issues/11465 + assert limit(1/hyper((1, ), (1, ), x), x, 0) == 1 + + +def test_appellf1(): + a, b1, b2, c, x, y = symbols('a b1 b2 c x y') + assert appellf1(a, b2, b1, c, y, x) == appellf1(a, b1, b2, c, x, y) + assert appellf1(a, b1, b1, c, y, x) == appellf1(a, b1, b1, c, x, y) + assert appellf1(a, b1, b2, c, S.Zero, S.Zero) is S.One + + f = appellf1(a, b1, b2, c, S.Zero, S.Zero, evaluate=False) + assert f.func is appellf1 + assert f.doit() is S.One + + +def test_derivative_appellf1(): + from sympy.core.function import diff + a, b1, b2, c, x, y, z = symbols('a b1 b2 c x y z') + assert diff(appellf1(a, b1, b2, c, x, y), x) == a*b1*appellf1(a + 1, b2, b1 + 1, c + 1, y, x)/c + assert diff(appellf1(a, b1, b2, c, x, y), y) == a*b2*appellf1(a + 1, b1, b2 + 1, c + 1, x, y)/c + assert diff(appellf1(a, b1, b2, c, x, y), z) == 0 + assert diff(appellf1(a, b1, b2, c, x, y), a) == Derivative(appellf1(a, b1, b2, c, x, y), a) + + +def test_eval_nseries(): + a1, b1, a2, b2 = symbols('a1 b1 a2 b2') + assert hyper((1,2), (1,2,3), x**2)._eval_nseries(x, 7, None) == \ + 1 + x**2/3 + x**4/24 + x**6/360 + O(x**7) + assert exp(x)._eval_nseries(x,7,None) == \ + hyper((a1, b1), (a1, b1), x)._eval_nseries(x, 7, None) + assert hyper((a1, a2), (b1, b2), x)._eval_nseries(z, 7, None) ==\ + hyper((a1, a2), (b1, b2), x) + O(z**7) + assert hyper((-S(1)/2, S(1)/2), (1,), 4*x/(x + 1)).nseries(x) == \ + 1 - x + x**2/4 - 3*x**3/4 - 15*x**4/64 - 93*x**5/64 + O(x**6) + assert (pi/2*hyper((-S(1)/2, S(1)/2), (1,), 4*x/(x + 1))).nseries(x) == \ + pi/2 - pi*x/2 + pi*x**2/8 - 3*pi*x**3/8 - 15*pi*x**4/128 - 93*pi*x**5/128 + O(x**6) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_spec_polynomials.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_spec_polynomials.py new file mode 100644 index 0000000000000000000000000000000000000000..584ad3cf97df8b9d92da9fc7805ab4296f40671c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_spec_polynomials.py @@ -0,0 +1,475 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import (Derivative, diff) +from sympy.core.numbers import (Rational, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.combinatorial.factorials import (RisingFactorial, binomial, factorial) +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.functions.special.polynomials import (assoc_laguerre, assoc_legendre, chebyshevt, chebyshevt_root, chebyshevu, chebyshevu_root, gegenbauer, hermite, hermite_prob, jacobi, jacobi_normalized, laguerre, legendre) +from sympy.polys.orthopolys import laguerre_poly +from sympy.polys.polyroots import roots + +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.testing.pytest import raises + + +x = Symbol('x') + + +def test_jacobi(): + n = Symbol("n") + a = Symbol("a") + b = Symbol("b") + + assert jacobi(0, a, b, x) == 1 + assert jacobi(1, a, b, x) == a/2 - b/2 + x*(a/2 + b/2 + 1) + + assert jacobi(n, a, a, x) == RisingFactorial( + a + 1, n)*gegenbauer(n, a + S.Half, x)/RisingFactorial(2*a + 1, n) + assert jacobi(n, a, -a, x) == ((-1)**a*(-x + 1)**(-a/2)*(x + 1)**(a/2)*assoc_legendre(n, a, x)* + factorial(-a + n)*gamma(a + n + 1)/(factorial(a + n)*gamma(n + 1))) + assert jacobi(n, -b, b, x) == ((-x + 1)**(b/2)*(x + 1)**(-b/2)*assoc_legendre(n, b, x)* + gamma(-b + n + 1)/gamma(n + 1)) + assert jacobi(n, 0, 0, x) == legendre(n, x) + assert jacobi(n, S.Half, S.Half, x) == RisingFactorial( + Rational(3, 2), n)*chebyshevu(n, x)/factorial(n + 1) + assert jacobi(n, Rational(-1, 2), Rational(-1, 2), x) == RisingFactorial( + S.Half, n)*chebyshevt(n, x)/factorial(n) + + X = jacobi(n, a, b, x) + assert isinstance(X, jacobi) + + assert jacobi(n, a, b, -x) == (-1)**n*jacobi(n, b, a, x) + assert jacobi(n, a, b, 0) == 2**(-n)*gamma(a + n + 1)*hyper( + (-b - n, -n), (a + 1,), -1)/(factorial(n)*gamma(a + 1)) + assert jacobi(n, a, b, 1) == RisingFactorial(a + 1, n)/factorial(n) + + m = Symbol("m", positive=True) + assert jacobi(m, a, b, oo) == oo*RisingFactorial(a + b + m + 1, m) + assert unchanged(jacobi, n, a, b, oo) + + assert conjugate(jacobi(m, a, b, x)) == \ + jacobi(m, conjugate(a), conjugate(b), conjugate(x)) + + _k = Dummy('k') + assert diff(jacobi(n, a, b, x), n) == Derivative(jacobi(n, a, b, x), n) + assert diff(jacobi(n, a, b, x), a).dummy_eq(Sum((jacobi(n, a, b, x) + + (2*_k + a + b + 1)*RisingFactorial(_k + b + 1, -_k + n)*jacobi(_k, a, + b, x)/((-_k + n)*RisingFactorial(_k + a + b + 1, -_k + n)))/(_k + a + + b + n + 1), (_k, 0, n - 1))) + assert diff(jacobi(n, a, b, x), b).dummy_eq(Sum(((-1)**(-_k + n)*(2*_k + + a + b + 1)*RisingFactorial(_k + a + 1, -_k + n)*jacobi(_k, a, b, x)/ + ((-_k + n)*RisingFactorial(_k + a + b + 1, -_k + n)) + jacobi(n, a, + b, x))/(_k + a + b + n + 1), (_k, 0, n - 1))) + assert diff(jacobi(n, a, b, x), x) == \ + (a/2 + b/2 + n/2 + S.Half)*jacobi(n - 1, a + 1, b + 1, x) + + assert jacobi_normalized(n, a, b, x) == \ + (jacobi(n, a, b, x)/sqrt(2**(a + b + 1)*gamma(a + n + 1)*gamma(b + n + 1) + /((a + b + 2*n + 1)*factorial(n)*gamma(a + b + n + 1)))) + + raises(ValueError, lambda: jacobi(-2.1, a, b, x)) + raises(ValueError, lambda: jacobi(Dummy(positive=True, integer=True), 1, 2, oo)) + + assert jacobi(n, a, b, x).rewrite(Sum).dummy_eq(Sum((S.Half - x/2) + **_k*RisingFactorial(-n, _k)*RisingFactorial(_k + a + 1, -_k + n)* + RisingFactorial(a + b + n + 1, _k)/factorial(_k), (_k, 0, n))/factorial(n)) + assert jacobi(n, a, b, x).rewrite("polynomial").dummy_eq(Sum((S.Half - x/2) + **_k*RisingFactorial(-n, _k)*RisingFactorial(_k + a + 1, -_k + n)* + RisingFactorial(a + b + n + 1, _k)/factorial(_k), (_k, 0, n))/factorial(n)) + raises(ArgumentIndexError, lambda: jacobi(n, a, b, x).fdiff(5)) + + +def test_gegenbauer(): + n = Symbol("n") + a = Symbol("a") + + assert gegenbauer(0, a, x) == 1 + assert gegenbauer(1, a, x) == 2*a*x + assert gegenbauer(2, a, x) == -a + x**2*(2*a**2 + 2*a) + assert gegenbauer(3, a, x) == \ + x**3*(4*a**3/3 + 4*a**2 + a*Rational(8, 3)) + x*(-2*a**2 - 2*a) + + assert gegenbauer(-1, a, x) == 0 + assert gegenbauer(n, S.Half, x) == legendre(n, x) + assert gegenbauer(n, 1, x) == chebyshevu(n, x) + assert gegenbauer(n, -1, x) == 0 + + X = gegenbauer(n, a, x) + assert isinstance(X, gegenbauer) + + assert gegenbauer(n, a, -x) == (-1)**n*gegenbauer(n, a, x) + assert gegenbauer(n, a, 0) == 2**n*sqrt(pi) * \ + gamma(a + n/2)/(gamma(a)*gamma(-n/2 + S.Half)*gamma(n + 1)) + assert gegenbauer(n, a, 1) == gamma(2*a + n)/(gamma(2*a)*gamma(n + 1)) + + assert gegenbauer(n, Rational(3, 4), -1) is zoo + assert gegenbauer(n, Rational(1, 4), -1) == (sqrt(2)*cos(pi*(n + S.One/4))* + gamma(n + S.Half)/(sqrt(pi)*gamma(n + 1))) + + m = Symbol("m", positive=True) + assert gegenbauer(m, a, oo) == oo*RisingFactorial(a, m) + assert unchanged(gegenbauer, n, a, oo) + + assert conjugate(gegenbauer(n, a, x)) == gegenbauer(n, conjugate(a), conjugate(x)) + + _k = Dummy('k') + + assert diff(gegenbauer(n, a, x), n) == Derivative(gegenbauer(n, a, x), n) + assert diff(gegenbauer(n, a, x), a).dummy_eq(Sum((2*(-1)**(-_k + n) + 2)* + (_k + a)*gegenbauer(_k, a, x)/((-_k + n)*(_k + 2*a + n)) + ((2*_k + + 2)/((_k + 2*a)*(2*_k + 2*a + 1)) + 2/(_k + 2*a + n))*gegenbauer(n, a + , x), (_k, 0, n - 1))) + assert diff(gegenbauer(n, a, x), x) == 2*a*gegenbauer(n - 1, a + 1, x) + + assert gegenbauer(n, a, x).rewrite(Sum).dummy_eq( + Sum((-1)**_k*(2*x)**(-2*_k + n)*RisingFactorial(a, -_k + n) + /(factorial(_k)*factorial(-2*_k + n)), (_k, 0, floor(n/2)))) + assert gegenbauer(n, a, x).rewrite("polynomial").dummy_eq( + Sum((-1)**_k*(2*x)**(-2*_k + n)*RisingFactorial(a, -_k + n) + /(factorial(_k)*factorial(-2*_k + n)), (_k, 0, floor(n/2)))) + + raises(ArgumentIndexError, lambda: gegenbauer(n, a, x).fdiff(4)) + + +def test_legendre(): + assert legendre(0, x) == 1 + assert legendre(1, x) == x + assert legendre(2, x) == ((3*x**2 - 1)/2).expand() + assert legendre(3, x) == ((5*x**3 - 3*x)/2).expand() + assert legendre(4, x) == ((35*x**4 - 30*x**2 + 3)/8).expand() + assert legendre(5, x) == ((63*x**5 - 70*x**3 + 15*x)/8).expand() + assert legendre(6, x) == ((231*x**6 - 315*x**4 + 105*x**2 - 5)/16).expand() + + assert legendre(10, -1) == 1 + assert legendre(11, -1) == -1 + assert legendre(10, 1) == 1 + assert legendre(11, 1) == 1 + assert legendre(10, 0) != 0 + assert legendre(11, 0) == 0 + + assert legendre(-1, x) == 1 + k = Symbol('k') + assert legendre(5 - k, x).subs(k, 2) == ((5*x**3 - 3*x)/2).expand() + + assert roots(legendre(4, x), x) == { + sqrt(Rational(3, 7) - Rational(2, 35)*sqrt(30)): 1, + -sqrt(Rational(3, 7) - Rational(2, 35)*sqrt(30)): 1, + sqrt(Rational(3, 7) + Rational(2, 35)*sqrt(30)): 1, + -sqrt(Rational(3, 7) + Rational(2, 35)*sqrt(30)): 1, + } + + n = Symbol("n") + + X = legendre(n, x) + assert isinstance(X, legendre) + assert unchanged(legendre, n, x) + + assert legendre(n, 0) == sqrt(pi)/(gamma(S.Half - n/2)*gamma(n/2 + 1)) + assert legendre(n, 1) == 1 + assert legendre(n, oo) is oo + assert legendre(-n, x) == legendre(n - 1, x) + assert legendre(n, -x) == (-1)**n*legendre(n, x) + assert unchanged(legendre, -n + k, x) + + assert conjugate(legendre(n, x)) == legendre(n, conjugate(x)) + + assert diff(legendre(n, x), x) == \ + n*(x*legendre(n, x) - legendre(n - 1, x))/(x**2 - 1) + assert diff(legendre(n, x), n) == Derivative(legendre(n, x), n) + + _k = Dummy('k') + assert legendre(n, x).rewrite(Sum).dummy_eq(Sum((-1)**_k*(S.Half - + x/2)**_k*(x/2 + S.Half)**(-_k + n)*binomial(n, _k)**2, (_k, 0, n))) + assert legendre(n, x).rewrite("polynomial").dummy_eq(Sum((-1)**_k*(S.Half - + x/2)**_k*(x/2 + S.Half)**(-_k + n)*binomial(n, _k)**2, (_k, 0, n))) + raises(ArgumentIndexError, lambda: legendre(n, x).fdiff(1)) + raises(ArgumentIndexError, lambda: legendre(n, x).fdiff(3)) + + +def test_assoc_legendre(): + Plm = assoc_legendre + Q = sqrt(1 - x**2) + + assert Plm(0, 0, x) == 1 + assert Plm(1, 0, x) == x + assert Plm(1, 1, x) == -Q + assert Plm(2, 0, x) == (3*x**2 - 1)/2 + assert Plm(2, 1, x) == -3*x*Q + assert Plm(2, 2, x) == 3*Q**2 + assert Plm(3, 0, x) == (5*x**3 - 3*x)/2 + assert Plm(3, 1, x).expand() == (( 3*(1 - 5*x**2)/2 ).expand() * Q).expand() + assert Plm(3, 2, x) == 15*x * Q**2 + assert Plm(3, 3, x) == -15 * Q**3 + + # negative m + assert Plm(1, -1, x) == -Plm(1, 1, x)/2 + assert Plm(2, -2, x) == Plm(2, 2, x)/24 + assert Plm(2, -1, x) == -Plm(2, 1, x)/6 + assert Plm(3, -3, x) == -Plm(3, 3, x)/720 + assert Plm(3, -2, x) == Plm(3, 2, x)/120 + assert Plm(3, -1, x) == -Plm(3, 1, x)/12 + + n = Symbol("n") + m = Symbol("m") + X = Plm(n, m, x) + assert isinstance(X, assoc_legendre) + + assert Plm(n, 0, x) == legendre(n, x) + assert Plm(n, m, 0) == 2**m*sqrt(pi)/(gamma(-m/2 - n/2 + + S.Half)*gamma(-m/2 + n/2 + 1)) + + assert diff(Plm(m, n, x), x) == (m*x*assoc_legendre(m, n, x) - + (m + n)*assoc_legendre(m - 1, n, x))/(x**2 - 1) + + _k = Dummy('k') + assert Plm(m, n, x).rewrite(Sum).dummy_eq( + (1 - x**2)**(n/2)*Sum((-1)**_k*2**(-m)*x**(-2*_k + m - n)*factorial + (-2*_k + 2*m)/(factorial(_k)*factorial(-_k + m)*factorial(-2*_k + m + - n)), (_k, 0, floor(m/2 - n/2)))) + assert Plm(m, n, x).rewrite("polynomial").dummy_eq( + (1 - x**2)**(n/2)*Sum((-1)**_k*2**(-m)*x**(-2*_k + m - n)*factorial + (-2*_k + 2*m)/(factorial(_k)*factorial(-_k + m)*factorial(-2*_k + m + - n)), (_k, 0, floor(m/2 - n/2)))) + assert conjugate(assoc_legendre(n, m, x)) == \ + assoc_legendre(n, conjugate(m), conjugate(x)) + raises(ValueError, lambda: Plm(0, 1, x)) + raises(ValueError, lambda: Plm(-1, 1, x)) + raises(ArgumentIndexError, lambda: Plm(n, m, x).fdiff(1)) + raises(ArgumentIndexError, lambda: Plm(n, m, x).fdiff(2)) + raises(ArgumentIndexError, lambda: Plm(n, m, x).fdiff(4)) + + +def test_chebyshev(): + assert chebyshevt(0, x) == 1 + assert chebyshevt(1, x) == x + assert chebyshevt(2, x) == 2*x**2 - 1 + assert chebyshevt(3, x) == 4*x**3 - 3*x + + for n in range(1, 4): + for k in range(n): + z = chebyshevt_root(n, k) + assert chebyshevt(n, z) == 0 + raises(ValueError, lambda: chebyshevt_root(n, n)) + + for n in range(1, 4): + for k in range(n): + z = chebyshevu_root(n, k) + assert chebyshevu(n, z) == 0 + raises(ValueError, lambda: chebyshevu_root(n, n)) + + n = Symbol("n") + X = chebyshevt(n, x) + assert isinstance(X, chebyshevt) + assert unchanged(chebyshevt, n, x) + assert chebyshevt(n, -x) == (-1)**n*chebyshevt(n, x) + assert chebyshevt(-n, x) == chebyshevt(n, x) + + assert chebyshevt(n, 0) == cos(pi*n/2) + assert chebyshevt(n, 1) == 1 + assert chebyshevt(n, oo) is oo + + assert conjugate(chebyshevt(n, x)) == chebyshevt(n, conjugate(x)) + + assert diff(chebyshevt(n, x), x) == n*chebyshevu(n - 1, x) + + X = chebyshevu(n, x) + assert isinstance(X, chebyshevu) + + y = Symbol('y') + assert chebyshevu(n, -x) == (-1)**n*chebyshevu(n, x) + assert chebyshevu(-n, x) == -chebyshevu(n - 2, x) + assert unchanged(chebyshevu, -n + y, x) + + assert chebyshevu(n, 0) == cos(pi*n/2) + assert chebyshevu(n, 1) == n + 1 + assert chebyshevu(n, oo) is oo + + assert conjugate(chebyshevu(n, x)) == chebyshevu(n, conjugate(x)) + + assert diff(chebyshevu(n, x), x) == \ + (-x*chebyshevu(n, x) + (n + 1)*chebyshevt(n + 1, x))/(x**2 - 1) + + _k = Dummy('k') + assert chebyshevt(n, x).rewrite(Sum).dummy_eq(Sum(x**(-2*_k + n) + *(x**2 - 1)**_k*binomial(n, 2*_k), (_k, 0, floor(n/2)))) + assert chebyshevt(n, x).rewrite("polynomial").dummy_eq(Sum(x**(-2*_k + n) + *(x**2 - 1)**_k*binomial(n, 2*_k), (_k, 0, floor(n/2)))) + assert chebyshevu(n, x).rewrite(Sum).dummy_eq(Sum((-1)**_k*(2*x) + **(-2*_k + n)*factorial(-_k + n)/(factorial(_k)* + factorial(-2*_k + n)), (_k, 0, floor(n/2)))) + assert chebyshevu(n, x).rewrite("polynomial").dummy_eq(Sum((-1)**_k*(2*x) + **(-2*_k + n)*factorial(-_k + n)/(factorial(_k)* + factorial(-2*_k + n)), (_k, 0, floor(n/2)))) + raises(ArgumentIndexError, lambda: chebyshevt(n, x).fdiff(1)) + raises(ArgumentIndexError, lambda: chebyshevt(n, x).fdiff(3)) + raises(ArgumentIndexError, lambda: chebyshevu(n, x).fdiff(1)) + raises(ArgumentIndexError, lambda: chebyshevu(n, x).fdiff(3)) + + +def test_hermite(): + assert hermite(0, x) == 1 + assert hermite(1, x) == 2*x + assert hermite(2, x) == 4*x**2 - 2 + assert hermite(3, x) == 8*x**3 - 12*x + assert hermite(4, x) == 16*x**4 - 48*x**2 + 12 + assert hermite(6, x) == 64*x**6 - 480*x**4 + 720*x**2 - 120 + + n = Symbol("n") + assert unchanged(hermite, n, x) + assert hermite(n, -x) == (-1)**n*hermite(n, x) + assert unchanged(hermite, -n, x) + + assert hermite(n, 0) == 2**n*sqrt(pi)/gamma(S.Half - n/2) + assert hermite(n, oo) is oo + + assert conjugate(hermite(n, x)) == hermite(n, conjugate(x)) + + _k = Dummy('k') + assert hermite(n, x).rewrite(Sum).dummy_eq(factorial(n)*Sum((-1) + **_k*(2*x)**(-2*_k + n)/(factorial(_k)*factorial(-2*_k + n)), (_k, + 0, floor(n/2)))) + assert hermite(n, x).rewrite("polynomial").dummy_eq(factorial(n)*Sum((-1) + **_k*(2*x)**(-2*_k + n)/(factorial(_k)*factorial(-2*_k + n)), (_k, + 0, floor(n/2)))) + + assert diff(hermite(n, x), x) == 2*n*hermite(n - 1, x) + assert diff(hermite(n, x), n) == Derivative(hermite(n, x), n) + raises(ArgumentIndexError, lambda: hermite(n, x).fdiff(3)) + + assert hermite(n, x).rewrite(hermite_prob) == \ + sqrt(2)**n * hermite_prob(n, x*sqrt(2)) + + +def test_hermite_prob(): + assert hermite_prob(0, x) == 1 + assert hermite_prob(1, x) == x + assert hermite_prob(2, x) == x**2 - 1 + assert hermite_prob(3, x) == x**3 - 3*x + assert hermite_prob(4, x) == x**4 - 6*x**2 + 3 + assert hermite_prob(6, x) == x**6 - 15*x**4 + 45*x**2 - 15 + + n = Symbol("n") + assert unchanged(hermite_prob, n, x) + assert hermite_prob(n, -x) == (-1)**n*hermite_prob(n, x) + assert unchanged(hermite_prob, -n, x) + + assert hermite_prob(n, 0) == sqrt(pi)/gamma(S.Half - n/2) + assert hermite_prob(n, oo) is oo + + assert conjugate(hermite_prob(n, x)) == hermite_prob(n, conjugate(x)) + + _k = Dummy('k') + assert hermite_prob(n, x).rewrite(Sum).dummy_eq(factorial(n) * + Sum((-S.Half)**_k * x**(n-2*_k) / (factorial(_k) * factorial(n-2*_k)), + (_k, 0, floor(n/2)))) + assert hermite_prob(n, x).rewrite("polynomial").dummy_eq(factorial(n) * + Sum((-S.Half)**_k * x**(n-2*_k) / (factorial(_k) * factorial(n-2*_k)), + (_k, 0, floor(n/2)))) + + assert diff(hermite_prob(n, x), x) == n*hermite_prob(n-1, x) + assert diff(hermite_prob(n, x), n) == Derivative(hermite_prob(n, x), n) + raises(ArgumentIndexError, lambda: hermite_prob(n, x).fdiff(3)) + + assert hermite_prob(n, x).rewrite(hermite) == \ + sqrt(2)**(-n) * hermite(n, x/sqrt(2)) + + +def test_laguerre(): + n = Symbol("n") + m = Symbol("m", negative=True) + + # Laguerre polynomials: + assert laguerre(0, x) == 1 + assert laguerre(1, x) == -x + 1 + assert laguerre(2, x) == x**2/2 - 2*x + 1 + assert laguerre(3, x) == -x**3/6 + 3*x**2/2 - 3*x + 1 + assert laguerre(-2, x) == (x + 1)*exp(x) + + X = laguerre(n, x) + assert isinstance(X, laguerre) + + assert laguerre(n, 0) == 1 + assert laguerre(n, oo) == (-1)**n*oo + assert laguerre(n, -oo) is oo + + assert conjugate(laguerre(n, x)) == laguerre(n, conjugate(x)) + + _k = Dummy('k') + + assert laguerre(n, x).rewrite(Sum).dummy_eq( + Sum(x**_k*RisingFactorial(-n, _k)/factorial(_k)**2, (_k, 0, n))) + assert laguerre(n, x).rewrite("polynomial").dummy_eq( + Sum(x**_k*RisingFactorial(-n, _k)/factorial(_k)**2, (_k, 0, n))) + assert laguerre(m, x).rewrite(Sum).dummy_eq( + exp(x)*Sum((-x)**_k*RisingFactorial(m + 1, _k)/factorial(_k)**2, + (_k, 0, -m - 1))) + assert laguerre(m, x).rewrite("polynomial").dummy_eq( + exp(x)*Sum((-x)**_k*RisingFactorial(m + 1, _k)/factorial(_k)**2, + (_k, 0, -m - 1))) + + assert diff(laguerre(n, x), x) == -assoc_laguerre(n - 1, 1, x) + + k = Symbol('k') + assert laguerre(-n, x) == exp(x)*laguerre(n - 1, -x) + assert laguerre(-3, x) == exp(x)*laguerre(2, -x) + assert unchanged(laguerre, -n + k, x) + + raises(ValueError, lambda: laguerre(-2.1, x)) + raises(ValueError, lambda: laguerre(Rational(5, 2), x)) + raises(ArgumentIndexError, lambda: laguerre(n, x).fdiff(1)) + raises(ArgumentIndexError, lambda: laguerre(n, x).fdiff(3)) + + +def test_assoc_laguerre(): + n = Symbol("n") + m = Symbol("m") + alpha = Symbol("alpha") + + # generalized Laguerre polynomials: + assert assoc_laguerre(0, alpha, x) == 1 + assert assoc_laguerre(1, alpha, x) == -x + alpha + 1 + assert assoc_laguerre(2, alpha, x).expand() == \ + (x**2/2 - (alpha + 2)*x + (alpha + 2)*(alpha + 1)/2).expand() + assert assoc_laguerre(3, alpha, x).expand() == \ + (-x**3/6 + (alpha + 3)*x**2/2 - (alpha + 2)*(alpha + 3)*x/2 + + (alpha + 1)*(alpha + 2)*(alpha + 3)/6).expand() + + # Test the lowest 10 polynomials with laguerre_poly, to make sure it works: + for i in range(10): + assert assoc_laguerre(i, 0, x).expand() == laguerre_poly(i, x) + + X = assoc_laguerre(n, m, x) + assert isinstance(X, assoc_laguerre) + + assert assoc_laguerre(n, 0, x) == laguerre(n, x) + assert assoc_laguerre(n, alpha, 0) == binomial(alpha + n, alpha) + p = Symbol("p", positive=True) + assert assoc_laguerre(p, alpha, oo) == (-1)**p*oo + assert assoc_laguerre(p, alpha, -oo) is oo + + assert diff(assoc_laguerre(n, alpha, x), x) == \ + -assoc_laguerre(n - 1, alpha + 1, x) + _k = Dummy('k') + assert diff(assoc_laguerre(n, alpha, x), alpha).dummy_eq( + Sum(assoc_laguerre(_k, alpha, x)/(-alpha + n), (_k, 0, n - 1))) + + assert conjugate(assoc_laguerre(n, alpha, x)) == \ + assoc_laguerre(n, conjugate(alpha), conjugate(x)) + + assert assoc_laguerre(n, alpha, x).rewrite(Sum).dummy_eq( + gamma(alpha + n + 1)*Sum(x**_k*RisingFactorial(-n, _k)/ + (factorial(_k)*gamma(_k + alpha + 1)), (_k, 0, n))/factorial(n)) + assert assoc_laguerre(n, alpha, x).rewrite("polynomial").dummy_eq( + gamma(alpha + n + 1)*Sum(x**_k*RisingFactorial(-n, _k)/ + (factorial(_k)*gamma(_k + alpha + 1)), (_k, 0, n))/factorial(n)) + raises(ValueError, lambda: assoc_laguerre(-2.1, alpha, x)) + raises(ArgumentIndexError, lambda: assoc_laguerre(n, alpha, x).fdiff(1)) + raises(ArgumentIndexError, lambda: assoc_laguerre(n, alpha, x).fdiff(4)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_spherical_harmonics.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_spherical_harmonics.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0d4ffebabb62c13d3fc2996e8ba23866467720 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/functions/special/tests/test_spherical_harmonics.py @@ -0,0 +1,66 @@ +from sympy.core.function import diff +from sympy.core.numbers import (I, pi) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, cot, sin) +from sympy.functions.special.spherical_harmonics import Ynm, Znm, Ynm_c + + +def test_Ynm(): + # https://en.wikipedia.org/wiki/Spherical_harmonics + th, ph = Symbol("theta", real=True), Symbol("phi", real=True) + from sympy.abc import n,m + + assert Ynm(0, 0, th, ph).expand(func=True) == 1/(2*sqrt(pi)) + assert Ynm(1, -1, th, ph) == -exp(-2*I*ph)*Ynm(1, 1, th, ph) + assert Ynm(1, -1, th, ph).expand(func=True) == sqrt(6)*sin(th)*exp(-I*ph)/(4*sqrt(pi)) + assert Ynm(1, 0, th, ph).expand(func=True) == sqrt(3)*cos(th)/(2*sqrt(pi)) + assert Ynm(1, 1, th, ph).expand(func=True) == -sqrt(6)*sin(th)*exp(I*ph)/(4*sqrt(pi)) + assert Ynm(2, 0, th, ph).expand(func=True) == 3*sqrt(5)*cos(th)**2/(4*sqrt(pi)) - sqrt(5)/(4*sqrt(pi)) + assert Ynm(2, 1, th, ph).expand(func=True) == -sqrt(30)*sin(th)*exp(I*ph)*cos(th)/(4*sqrt(pi)) + assert Ynm(2, -2, th, ph).expand(func=True) == (-sqrt(30)*exp(-2*I*ph)*cos(th)**2/(8*sqrt(pi)) + + sqrt(30)*exp(-2*I*ph)/(8*sqrt(pi))) + assert Ynm(2, 2, th, ph).expand(func=True) == (-sqrt(30)*exp(2*I*ph)*cos(th)**2/(8*sqrt(pi)) + + sqrt(30)*exp(2*I*ph)/(8*sqrt(pi))) + + assert diff(Ynm(n, m, th, ph), th) == (m*cot(th)*Ynm(n, m, th, ph) + + sqrt((-m + n)*(m + n + 1))*exp(-I*ph)*Ynm(n, m + 1, th, ph)) + assert diff(Ynm(n, m, th, ph), ph) == I*m*Ynm(n, m, th, ph) + + assert conjugate(Ynm(n, m, th, ph)) == (-1)**(2*m)*exp(-2*I*m*ph)*Ynm(n, m, th, ph) + + assert Ynm(n, m, -th, ph) == Ynm(n, m, th, ph) + assert Ynm(n, m, th, -ph) == exp(-2*I*m*ph)*Ynm(n, m, th, ph) + assert Ynm(n, -m, th, ph) == (-1)**m*exp(-2*I*m*ph)*Ynm(n, m, th, ph) + + +def test_Ynm_c(): + th, ph = Symbol("theta", real=True), Symbol("phi", real=True) + from sympy.abc import n,m + + assert Ynm_c(n, m, th, ph) == (-1)**(2*m)*exp(-2*I*m*ph)*Ynm(n, m, th, ph) + + +def test_Znm(): + # https://en.wikipedia.org/wiki/Solid_harmonics#List_of_lowest_functions + th, ph = Symbol("theta", real=True), Symbol("phi", real=True) + + assert Znm(0, 0, th, ph) == Ynm(0, 0, th, ph) + assert Znm(1, -1, th, ph) == (-sqrt(2)*I*(Ynm(1, 1, th, ph) + - exp(-2*I*ph)*Ynm(1, 1, th, ph))/2) + assert Znm(1, 0, th, ph) == Ynm(1, 0, th, ph) + assert Znm(1, 1, th, ph) == (sqrt(2)*(Ynm(1, 1, th, ph) + + exp(-2*I*ph)*Ynm(1, 1, th, ph))/2) + assert Znm(0, 0, th, ph).expand(func=True) == 1/(2*sqrt(pi)) + assert Znm(1, -1, th, ph).expand(func=True) == (sqrt(3)*I*sin(th)*exp(I*ph)/(4*sqrt(pi)) + - sqrt(3)*I*sin(th)*exp(-I*ph)/(4*sqrt(pi))) + assert Znm(1, 0, th, ph).expand(func=True) == sqrt(3)*cos(th)/(2*sqrt(pi)) + assert Znm(1, 1, th, ph).expand(func=True) == (-sqrt(3)*sin(th)*exp(I*ph)/(4*sqrt(pi)) + - sqrt(3)*sin(th)*exp(-I*ph)/(4*sqrt(pi))) + assert Znm(2, -1, th, ph).expand(func=True) == (sqrt(15)*I*sin(th)*exp(I*ph)*cos(th)/(4*sqrt(pi)) + - sqrt(15)*I*sin(th)*exp(-I*ph)*cos(th)/(4*sqrt(pi))) + assert Znm(2, 0, th, ph).expand(func=True) == 3*sqrt(5)*cos(th)**2/(4*sqrt(pi)) - sqrt(5)/(4*sqrt(pi)) + assert Znm(2, 1, th, ph).expand(func=True) == (-sqrt(15)*sin(th)*exp(I*ph)*cos(th)/(4*sqrt(pi)) + - sqrt(15)*sin(th)*exp(-I*ph)*cos(th)/(4*sqrt(pi))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75230c27811f99fd3121da5ff7f62e052ae26bea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/curve.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/curve.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f47799b10192ac535db9871f6b8a3b4625f4394b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/curve.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/ellipse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/ellipse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516b8bafca9e42e8cad1031870d729bf894a06f7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/ellipse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/entity.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/entity.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a8a28e5bad9ba353b94ea582a81d848b41dffb5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/entity.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/exceptions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/exceptions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6cd71e8fff32e4c0ff6d4b034d81b75fe21cb5c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/exceptions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/line.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/line.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da5a90d54e70f1aefb61a8d65d511511c252a0a2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/line.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/parabola.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/parabola.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38b2412b7f9ee273838477661879a4255af31bef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/parabola.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/plane.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/plane.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..858f1440602dde06b6348c0eb6ce82307007f63e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/plane.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/point.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/point.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f67496ed349886f896cc71fc8f0564a39a4112e3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/point.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/util.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95d97dd5308fff0035107d0de7d503a6589b66cc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/__pycache__/util.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b05c01856b6a9eaefd9f5801ad507093f1ef628c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_curve.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_curve.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..258414f6fedc50ff48e9ff1a0b32c0da1fbbd49a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_curve.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_ellipse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_ellipse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c37cd57c10d695aa933fa24b65ce7b568d7dd6e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_ellipse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_entity.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_entity.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b6090763ef9970cc90b3c955bd259e8ca1587b6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_entity.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_geometrysets.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_geometrysets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faffea07e2103a152d089bfa1894802429120c14 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_geometrysets.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_line.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_line.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ac6bf3a1b92042bc643ab22f9f58d9e069ad467 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_line.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_parabola.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_parabola.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94423a324fa4e036f04cbc720e32c4b7cb439b67 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_parabola.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_plane.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_plane.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..869389151dd9baa736e480d870374c1695f2b6b3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_plane.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_point.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_point.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f251d7f78416f7cfe20353110493d777d855ab5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_point.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_polygon.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_polygon.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d16d7345982e3903504c6138767ee477fde7b89 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_polygon.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_util.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab8dcac205026a9419d2042054133d48f4643ad8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/__pycache__/test_util.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_curve.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..50aa80273a1d8eb9e414a8d591571f3127352dad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_curve.py @@ -0,0 +1,120 @@ +from sympy.core.containers import Tuple +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.hyperbolic import asinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.geometry import Curve, Line, Point, Ellipse, Ray, Segment, Circle, Polygon, RegularPolygon +from sympy.testing.pytest import raises, slow + + +def test_curve(): + x = Symbol('x', real=True) + s = Symbol('s') + z = Symbol('z') + + # this curve is independent of the indicated parameter + c = Curve([2*s, s**2], (z, 0, 2)) + + assert c.parameter == z + assert c.functions == (2*s, s**2) + assert c.arbitrary_point() == Point(2*s, s**2) + assert c.arbitrary_point(z) == Point(2*s, s**2) + + # this is how it is normally used + c = Curve([2*s, s**2], (s, 0, 2)) + + assert c.parameter == s + assert c.functions == (2*s, s**2) + t = Symbol('t') + # the t returned as assumptions + assert c.arbitrary_point() != Point(2*t, t**2) + t = Symbol('t', real=True) + # now t has the same assumptions so the test passes + assert c.arbitrary_point() == Point(2*t, t**2) + assert c.arbitrary_point(z) == Point(2*z, z**2) + assert c.arbitrary_point(c.parameter) == Point(2*s, s**2) + assert c.arbitrary_point(None) == Point(2*s, s**2) + assert c.plot_interval() == [t, 0, 2] + assert c.plot_interval(z) == [z, 0, 2] + + assert Curve([x, x], (x, 0, 1)).rotate(pi/2) == Curve([-x, x], (x, 0, 1)) + assert Curve([x, x], (x, 0, 1)).rotate(pi/2, (1, 2)).scale(2, 3).translate( + 1, 3).arbitrary_point(s) == \ + Line((0, 0), (1, 1)).rotate(pi/2, (1, 2)).scale(2, 3).translate( + 1, 3).arbitrary_point(s) == \ + Point(-2*s + 7, 3*s + 6) + + raises(ValueError, lambda: Curve((s), (s, 1, 2))) + raises(ValueError, lambda: Curve((x, x * 2), (1, x))) + + raises(ValueError, lambda: Curve((s, s + t), (s, 1, 2)).arbitrary_point()) + raises(ValueError, lambda: Curve((s, s + t), (t, 1, 2)).arbitrary_point(s)) + + +@slow +def test_free_symbols(): + a, b, c, d, e, f, s = symbols('a:f,s') + assert Point(a, b).free_symbols == {a, b} + assert Line((a, b), (c, d)).free_symbols == {a, b, c, d} + assert Ray((a, b), (c, d)).free_symbols == {a, b, c, d} + assert Ray((a, b), angle=c).free_symbols == {a, b, c} + assert Segment((a, b), (c, d)).free_symbols == {a, b, c, d} + assert Line((a, b), slope=c).free_symbols == {a, b, c} + assert Curve((a*s, b*s), (s, c, d)).free_symbols == {a, b, c, d} + assert Ellipse((a, b), c, d).free_symbols == {a, b, c, d} + assert Ellipse((a, b), c, eccentricity=d).free_symbols == \ + {a, b, c, d} + assert Ellipse((a, b), vradius=c, eccentricity=d).free_symbols == \ + {a, b, c, d} + assert Circle((a, b), c).free_symbols == {a, b, c} + assert Circle((a, b), (c, d), (e, f)).free_symbols == \ + {e, d, c, b, f, a} + assert Polygon((a, b), (c, d), (e, f)).free_symbols == \ + {e, b, d, f, a, c} + assert RegularPolygon((a, b), c, d, e).free_symbols == {e, a, b, c, d} + + +def test_transform(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + c = Curve((x, x**2), (x, 0, 1)) + cout = Curve((2*x - 4, 3*x**2 - 10), (x, 0, 1)) + pts = [Point(0, 0), Point(S.Half, Rational(1, 4)), Point(1, 1)] + pts_out = [Point(-4, -10), Point(-3, Rational(-37, 4)), Point(-2, -7)] + + assert c.scale(2, 3, (4, 5)) == cout + assert [c.subs(x, xi/2) for xi in Tuple(0, 1, 2)] == pts + assert [cout.subs(x, xi/2) for xi in Tuple(0, 1, 2)] == pts_out + assert Curve((x + y, 3*x), (x, 0, 1)).subs(y, S.Half) == \ + Curve((x + S.Half, 3*x), (x, 0, 1)) + assert Curve((x, 3*x), (x, 0, 1)).translate(4, 5) == \ + Curve((x + 4, 3*x + 5), (x, 0, 1)) + + +def test_length(): + t = Symbol('t', real=True) + + c1 = Curve((t, 0), (t, 0, 1)) + assert c1.length == 1 + + c2 = Curve((t, t), (t, 0, 1)) + assert c2.length == sqrt(2) + + c3 = Curve((t ** 2, t), (t, 2, 5)) + assert c3.length == -sqrt(17) - asinh(4) / 4 + asinh(10) / 4 + 5 * sqrt(101) / 2 + + +def test_parameter_value(): + t = Symbol('t') + C = Curve([2*t, t**2], (t, 0, 2)) + assert C.parameter_value((2, 1), t) == {t: 1} + raises(ValueError, lambda: C.parameter_value((2, 0), t)) + + +def test_issue_17997(): + t, s = symbols('t s') + c = Curve((t, t**2), (t, 0, 10)) + p = Curve([2*s, s**2], (s, 0, 2)) + assert c(2) == Point(2, 4) + assert p(1) == Point(2, 1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_ellipse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_ellipse.py new file mode 100644 index 0000000000000000000000000000000000000000..a79eba8c35771bda9f0980aca68d937f8e625c0a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_ellipse.py @@ -0,0 +1,613 @@ +from sympy.core import expand +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sec +from sympy.geometry.line import Segment2D +from sympy.geometry.point import Point2D +from sympy.geometry import (Circle, Ellipse, GeometryError, Line, Point, + Polygon, Ray, RegularPolygon, Segment, + Triangle, intersection) +from sympy.testing.pytest import raises, slow +from sympy.integrals.integrals import integrate +from sympy.functions.special.elliptic_integrals import elliptic_e +from sympy.functions.elementary.miscellaneous import Max + + +def test_ellipse_equation_using_slope(): + from sympy.abc import x, y + + e1 = Ellipse(Point(1, 0), 3, 2) + assert str(e1.equation(_slope=1)) == str((-x + y + 1)**2/8 + (x + y - 1)**2/18 - 1) + + e2 = Ellipse(Point(0, 0), 4, 1) + assert str(e2.equation(_slope=1)) == str((-x + y)**2/2 + (x + y)**2/32 - 1) + + e3 = Ellipse(Point(1, 5), 6, 2) + assert str(e3.equation(_slope=2)) == str((-2*x + y - 3)**2/20 + (x + 2*y - 11)**2/180 - 1) + + +def test_object_from_equation(): + from sympy.abc import x, y, a, b, c, d, e + assert Circle(x**2 + y**2 + 3*x + 4*y - 8) == Circle(Point2D(S(-3) / 2, -2), sqrt(57) / 2) + assert Circle(x**2 + y**2 + 6*x + 8*y + 25) == Circle(Point2D(-3, -4), 0) + assert Circle(a**2 + b**2 + 6*a + 8*b + 25, x='a', y='b') == Circle(Point2D(-3, -4), 0) + assert Circle(x**2 + y**2 - 25) == Circle(Point2D(0, 0), 5) + assert Circle(x**2 + y**2) == Circle(Point2D(0, 0), 0) + assert Circle(a**2 + b**2, x='a', y='b') == Circle(Point2D(0, 0), 0) + assert Circle(x**2 + y**2 + 6*x + 8) == Circle(Point2D(-3, 0), 1) + assert Circle(x**2 + y**2 + 6*y + 8) == Circle(Point2D(0, -3), 1) + assert Circle((x - 1)**2 + y**2 - 9) == Circle(Point2D(1, 0), 3) + assert Circle(6*(x**2) + 6*(y**2) + 6*x + 8*y - 25) == Circle(Point2D(Rational(-1, 2), Rational(-2, 3)), 5*sqrt(7)/6) + assert Circle(Eq(a**2 + b**2, 25), x='a', y=b) == Circle(Point2D(0, 0), 5) + raises(GeometryError, lambda: Circle(x**2 + y**2 + 3*x + 4*y + 26)) + raises(GeometryError, lambda: Circle(x**2 + y**2 + 25)) + raises(GeometryError, lambda: Circle(a**2 + b**2 + 25, x='a', y='b')) + raises(GeometryError, lambda: Circle(x**2 + 6*y + 8)) + raises(GeometryError, lambda: Circle(6*(x ** 2) + 4*(y**2) + 6*x + 8*y + 25)) + raises(ValueError, lambda: Circle(a**2 + b**2 + 3*a + 4*b - 8)) + # .equation() adds 'real=True' assumption; '==' would fail if assumptions differed + x, y = symbols('x y', real=True) + eq = a*x**2 + a*y**2 + c*x + d*y + e + assert expand(Circle(eq).equation()*a) == eq + + +@slow +def test_ellipse_geom(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + t = Symbol('t', real=True) + y1 = Symbol('y1', real=True) + half = S.Half + p1 = Point(0, 0) + p2 = Point(1, 1) + p4 = Point(0, 1) + + e1 = Ellipse(p1, 1, 1) + e2 = Ellipse(p2, half, 1) + e3 = Ellipse(p1, y1, y1) + c1 = Circle(p1, 1) + c2 = Circle(p2, 1) + c3 = Circle(Point(sqrt(2), sqrt(2)), 1) + l1 = Line(p1, p2) + + # Test creation with three points + cen, rad = Point(3*half, 2), 5*half + assert Circle(Point(0, 0), Point(3, 0), Point(0, 4)) == Circle(cen, rad) + assert Circle(Point(0, 0), Point(1, 1), Point(2, 2)) == Segment2D(Point2D(0, 0), Point2D(2, 2)) + + raises(ValueError, lambda: Ellipse(None, None, None, 1)) + raises(ValueError, lambda: Ellipse()) + raises(GeometryError, lambda: Circle(Point(0, 0))) + raises(GeometryError, lambda: Circle(Symbol('x')*Symbol('y'))) + + # Basic Stuff + assert Ellipse(None, 1, 1).center == Point(0, 0) + assert e1 == c1 + assert e1 != e2 + assert e1 != l1 + assert p4 in e1 + assert e1 in e1 + assert e2 in e2 + assert 1 not in e2 + assert p2 not in e2 + assert e1.area == pi + assert e2.area == pi/2 + assert e3.area == pi*y1*abs(y1) + assert c1.area == e1.area + assert c1.circumference == e1.circumference + assert e3.circumference == 2*pi*y1 + assert e1.plot_interval() == e2.plot_interval() == [t, -pi, pi] + assert e1.plot_interval(x) == e2.plot_interval(x) == [x, -pi, pi] + + assert c1.minor == 1 + assert c1.major == 1 + assert c1.hradius == 1 + assert c1.vradius == 1 + + assert Ellipse((1, 1), 0, 0) == Point(1, 1) + assert Ellipse((1, 1), 1, 0) == Segment(Point(0, 1), Point(2, 1)) + assert Ellipse((1, 1), 0, 1) == Segment(Point(1, 0), Point(1, 2)) + + # Private Functions + assert hash(c1) == hash(Circle(Point(1, 0), Point(0, 1), Point(0, -1))) + assert c1 in e1 + assert (Line(p1, p2) in e1) is False + assert e1.__cmp__(e1) == 0 + assert e1.__cmp__(Point(0, 0)) > 0 + + # Encloses + assert e1.encloses(Segment(Point(-0.5, -0.5), Point(0.5, 0.5))) is True + assert e1.encloses(Line(p1, p2)) is False + assert e1.encloses(Ray(p1, p2)) is False + assert e1.encloses(e1) is False + assert e1.encloses( + Polygon(Point(-0.5, -0.5), Point(-0.5, 0.5), Point(0.5, 0.5))) is True + assert e1.encloses(RegularPolygon(p1, 0.5, 3)) is True + assert e1.encloses(RegularPolygon(p1, 5, 3)) is False + assert e1.encloses(RegularPolygon(p2, 5, 3)) is False + + assert e2.arbitrary_point() in e2 + raises(ValueError, lambda: Ellipse(Point(x, y), 1, 1).arbitrary_point(parameter='x')) + + # Foci + f1, f2 = Point(sqrt(12), 0), Point(-sqrt(12), 0) + ef = Ellipse(Point(0, 0), 4, 2) + assert ef.foci in [(f1, f2), (f2, f1)] + + # Tangents + v = sqrt(2) / 2 + p1_1 = Point(v, v) + p1_2 = p2 + Point(half, 0) + p1_3 = p2 + Point(0, 1) + assert e1.tangent_lines(p4) == c1.tangent_lines(p4) + assert e2.tangent_lines(p1_2) == [Line(Point(Rational(3, 2), 1), Point(Rational(3, 2), S.Half))] + assert e2.tangent_lines(p1_3) == [Line(Point(1, 2), Point(Rational(5, 4), 2))] + assert c1.tangent_lines(p1_1) != [Line(p1_1, Point(0, sqrt(2)))] + assert c1.tangent_lines(p1) == [] + assert e2.is_tangent(Line(p1_2, p2 + Point(half, 1))) + assert e2.is_tangent(Line(p1_3, p2 + Point(half, 1))) + assert c1.is_tangent(Line(p1_1, Point(0, sqrt(2)))) + assert e1.is_tangent(Line(Point(0, 0), Point(1, 1))) is False + assert c1.is_tangent(e1) is True + assert c1.is_tangent(Ellipse(Point(2, 0), 1, 1)) is True + assert c1.is_tangent( + Polygon(Point(1, 1), Point(1, -1), Point(2, 0))) is False + assert c1.is_tangent( + Polygon(Point(1, 1), Point(1, 0), Point(2, 0))) is False + assert Circle(Point(5, 5), 3).is_tangent(Circle(Point(0, 5), 1)) is False + + assert Ellipse(Point(5, 5), 2, 1).tangent_lines(Point(0, 0)) == \ + [Line(Point(0, 0), Point(Rational(77, 25), Rational(132, 25))), + Line(Point(0, 0), Point(Rational(33, 5), Rational(22, 5)))] + assert Ellipse(Point(5, 5), 2, 1).tangent_lines(Point(3, 4)) == \ + [Line(Point(3, 4), Point(4, 4)), Line(Point(3, 4), Point(3, 5))] + assert Circle(Point(5, 5), 2).tangent_lines(Point(3, 3)) == \ + [Line(Point(3, 3), Point(4, 3)), Line(Point(3, 3), Point(3, 4))] + assert Circle(Point(5, 5), 2).tangent_lines(Point(5 - 2*sqrt(2), 5)) == \ + [Line(Point(5 - 2*sqrt(2), 5), Point(5 - sqrt(2), 5 - sqrt(2))), + Line(Point(5 - 2*sqrt(2), 5), Point(5 - sqrt(2), 5 + sqrt(2))), ] + assert Circle(Point(5, 5), 5).tangent_lines(Point(4, 0)) == \ + [Line(Point(4, 0), Point(Rational(40, 13), Rational(5, 13))), + Line(Point(4, 0), Point(5, 0))] + assert Circle(Point(5, 5), 5).tangent_lines(Point(0, 6)) == \ + [Line(Point(0, 6), Point(0, 7)), + Line(Point(0, 6), Point(Rational(5, 13), Rational(90, 13)))] + + # for numerical calculations, we shouldn't demand exact equality, + # so only test up to the desired precision + def lines_close(l1, l2, prec): + """ tests whether l1 and 12 are within 10**(-prec) + of each other """ + return abs(l1.p1 - l2.p1) < 10**(-prec) and abs(l1.p2 - l2.p2) < 10**(-prec) + def line_list_close(ll1, ll2, prec): + return all(lines_close(l1, l2, prec) for l1, l2 in zip(ll1, ll2)) + + e = Ellipse(Point(0, 0), 2, 1) + assert e.normal_lines(Point(0, 0)) == \ + [Line(Point(0, 0), Point(0, 1)), Line(Point(0, 0), Point(1, 0))] + assert e.normal_lines(Point(1, 0)) == \ + [Line(Point(0, 0), Point(1, 0))] + assert e.normal_lines((0, 1)) == \ + [Line(Point(0, 0), Point(0, 1))] + assert line_list_close(e.normal_lines(Point(1, 1), 2), [ + Line(Point(Rational(-51, 26), Rational(-1, 5)), Point(Rational(-25, 26), Rational(17, 83))), + Line(Point(Rational(28, 29), Rational(-7, 8)), Point(Rational(57, 29), Rational(-9, 2)))], 2) + # test the failure of Poly.intervals and checks a point on the boundary + p = Point(sqrt(3), S.Half) + assert p in e + assert line_list_close(e.normal_lines(p, 2), [ + Line(Point(Rational(-341, 171), Rational(-1, 13)), Point(Rational(-170, 171), Rational(5, 64))), + Line(Point(Rational(26, 15), Rational(-1, 2)), Point(Rational(41, 15), Rational(-43, 26)))], 2) + # be sure to use the slope that isn't undefined on boundary + e = Ellipse((0, 0), 2, 2*sqrt(3)/3) + assert line_list_close(e.normal_lines((1, 1), 2), [ + Line(Point(Rational(-64, 33), Rational(-20, 71)), Point(Rational(-31, 33), Rational(2, 13))), + Line(Point(1, -1), Point(2, -4))], 2) + # general ellipse fails except under certain conditions + e = Ellipse((0, 0), x, 1) + assert e.normal_lines((x + 1, 0)) == [Line(Point(0, 0), Point(1, 0))] + raises(NotImplementedError, lambda: e.normal_lines((x + 1, 1))) + # Properties + major = 3 + minor = 1 + e4 = Ellipse(p2, minor, major) + assert e4.focus_distance == sqrt(major**2 - minor**2) + ecc = e4.focus_distance / major + assert e4.eccentricity == ecc + assert e4.periapsis == major*(1 - ecc) + assert e4.apoapsis == major*(1 + ecc) + assert e4.semilatus_rectum == major*(1 - ecc ** 2) + # independent of orientation + e4 = Ellipse(p2, major, minor) + assert e4.focus_distance == sqrt(major**2 - minor**2) + ecc = e4.focus_distance / major + assert e4.eccentricity == ecc + assert e4.periapsis == major*(1 - ecc) + assert e4.apoapsis == major*(1 + ecc) + + # Intersection + l1 = Line(Point(1, -5), Point(1, 5)) + l2 = Line(Point(-5, -1), Point(5, -1)) + l3 = Line(Point(-1, -1), Point(1, 1)) + l4 = Line(Point(-10, 0), Point(0, 10)) + pts_c1_l3 = [Point(sqrt(2)/2, sqrt(2)/2), Point(-sqrt(2)/2, -sqrt(2)/2)] + + assert intersection(e2, l4) == [] + assert intersection(c1, Point(1, 0)) == [Point(1, 0)] + assert intersection(c1, l1) == [Point(1, 0)] + assert intersection(c1, l2) == [Point(0, -1)] + assert intersection(c1, l3) in [pts_c1_l3, [pts_c1_l3[1], pts_c1_l3[0]]] + assert intersection(c1, c2) == [Point(0, 1), Point(1, 0)] + assert intersection(c1, c3) == [Point(sqrt(2)/2, sqrt(2)/2)] + assert e1.intersection(l1) == [Point(1, 0)] + assert e2.intersection(l4) == [] + assert e1.intersection(Circle(Point(0, 2), 1)) == [Point(0, 1)] + assert e1.intersection(Circle(Point(5, 0), 1)) == [] + assert e1.intersection(Ellipse(Point(2, 0), 1, 1)) == [Point(1, 0)] + assert e1.intersection(Ellipse(Point(5, 0), 1, 1)) == [] + assert e1.intersection(Point(2, 0)) == [] + assert e1.intersection(e1) == e1 + assert intersection(Ellipse(Point(0, 0), 2, 1), Ellipse(Point(3, 0), 1, 2)) == [Point(2, 0)] + assert intersection(Circle(Point(0, 0), 2), Circle(Point(3, 0), 1)) == [Point(2, 0)] + assert intersection(Circle(Point(0, 0), 2), Circle(Point(7, 0), 1)) == [] + assert intersection(Ellipse(Point(0, 0), 5, 17), Ellipse(Point(4, 0), 1, 0.2) + ) == [Point(5.0, 0, evaluate=False)] + assert intersection(Ellipse(Point(0, 0), 5, 17), Ellipse(Point(4, 0), 0.999, 0.2)) == [] + assert Circle((0, 0), S.Half).intersection( + Triangle((-1, 0), (1, 0), (0, 1))) == [ + Point(Rational(-1, 2), 0), Point(S.Half, 0)] + raises(TypeError, lambda: intersection(e2, Line((0, 0, 0), (0, 0, 1)))) + raises(TypeError, lambda: intersection(e2, Rational(12))) + raises(TypeError, lambda: Ellipse.intersection(e2, 1)) + # some special case intersections + csmall = Circle(p1, 3) + cbig = Circle(p1, 5) + cout = Circle(Point(5, 5), 1) + # one circle inside of another + assert csmall.intersection(cbig) == [] + # separate circles + assert csmall.intersection(cout) == [] + # coincident circles + assert csmall.intersection(csmall) == csmall + + v = sqrt(2) + t1 = Triangle(Point(0, v), Point(0, -v), Point(v, 0)) + points = intersection(t1, c1) + assert len(points) == 4 + assert Point(0, 1) in points + assert Point(0, -1) in points + assert Point(v/2, v/2) in points + assert Point(v/2, -v/2) in points + + circ = Circle(Point(0, 0), 5) + elip = Ellipse(Point(0, 0), 5, 20) + assert intersection(circ, elip) in \ + [[Point(5, 0), Point(-5, 0)], [Point(-5, 0), Point(5, 0)]] + assert elip.tangent_lines(Point(0, 0)) == [] + elip = Ellipse(Point(0, 0), 3, 2) + assert elip.tangent_lines(Point(3, 0)) == \ + [Line(Point(3, 0), Point(3, -12))] + + e1 = Ellipse(Point(0, 0), 5, 10) + e2 = Ellipse(Point(2, 1), 4, 8) + a = Rational(53, 17) + c = 2*sqrt(3991)/17 + ans = [Point(a - c/8, a/2 + c), Point(a + c/8, a/2 - c)] + assert e1.intersection(e2) == ans + e2 = Ellipse(Point(x, y), 4, 8) + c = sqrt(3991) + ans = [Point(-c/68 + a, c*Rational(2, 17) + a/2), Point(c/68 + a, c*Rational(-2, 17) + a/2)] + assert [p.subs({x: 2, y:1}) for p in e1.intersection(e2)] == ans + + # Combinations of above + assert e3.is_tangent(e3.tangent_lines(p1 + Point(y1, 0))[0]) + + e = Ellipse((1, 2), 3, 2) + assert e.tangent_lines(Point(10, 0)) == \ + [Line(Point(10, 0), Point(1, 0)), + Line(Point(10, 0), Point(Rational(14, 5), Rational(18, 5)))] + + # encloses_point + e = Ellipse((0, 0), 1, 2) + assert e.encloses_point(e.center) + assert e.encloses_point(e.center + Point(0, e.vradius - Rational(1, 10))) + assert e.encloses_point(e.center + Point(e.hradius - Rational(1, 10), 0)) + assert e.encloses_point(e.center + Point(e.hradius, 0)) is False + assert e.encloses_point( + e.center + Point(e.hradius + Rational(1, 10), 0)) is False + e = Ellipse((0, 0), 2, 1) + assert e.encloses_point(e.center) + assert e.encloses_point(e.center + Point(0, e.vradius - Rational(1, 10))) + assert e.encloses_point(e.center + Point(e.hradius - Rational(1, 10), 0)) + assert e.encloses_point(e.center + Point(e.hradius, 0)) is False + assert e.encloses_point( + e.center + Point(e.hradius + Rational(1, 10), 0)) is False + assert c1.encloses_point(Point(1, 0)) is False + assert c1.encloses_point(Point(0.3, 0.4)) is True + + assert e.scale(2, 3) == Ellipse((0, 0), 4, 3) + assert e.scale(3, 6) == Ellipse((0, 0), 6, 6) + assert e.rotate(pi) == e + assert e.rotate(pi, (1, 2)) == Ellipse(Point(2, 4), 2, 1) + raises(NotImplementedError, lambda: e.rotate(pi/3)) + + # Circle rotation tests (Issue #11743) + # Link - https://github.com/sympy/sympy/issues/11743 + cir = Circle(Point(1, 0), 1) + assert cir.rotate(pi/2) == Circle(Point(0, 1), 1) + assert cir.rotate(pi/3) == Circle(Point(S.Half, sqrt(3)/2), 1) + assert cir.rotate(pi/3, Point(1, 0)) == Circle(Point(1, 0), 1) + assert cir.rotate(pi/3, Point(0, 1)) == Circle(Point(S.Half + sqrt(3)/2, S.Half + sqrt(3)/2), 1) + + +def test_construction(): + e1 = Ellipse(hradius=2, vradius=1, eccentricity=None) + assert e1.eccentricity == sqrt(3)/2 + + e2 = Ellipse(hradius=2, vradius=None, eccentricity=sqrt(3)/2) + assert e2.vradius == 1 + + e3 = Ellipse(hradius=None, vradius=1, eccentricity=sqrt(3)/2) + assert e3.hradius == 2 + + # filter(None, iterator) filters out anything falsey, including 0 + # eccentricity would be filtered out in this case and the constructor would throw an error + e4 = Ellipse(Point(0, 0), hradius=1, eccentricity=0) + assert e4.vradius == 1 + + #tests for eccentricity > 1 + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity = S(3)/2)) + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity=sec(5))) + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity=S.Pi-S(2))) + + #tests for eccentricity = 1 + #if vradius is not defined + assert Ellipse(None, 1, None, 1).length == 2 + #if hradius is not defined + raises(GeometryError, lambda: Ellipse(None, None, 1, eccentricity = 1)) + + #tests for eccentricity < 0 + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity = -3)) + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity = -0.5)) + +def test_ellipse_random_point(): + y1 = Symbol('y1', real=True) + e3 = Ellipse(Point(0, 0), y1, y1) + rx, ry = Symbol('rx'), Symbol('ry') + for ind in range(0, 5): + r = e3.random_point() + # substitution should give zero*y1**2 + assert e3.equation(rx, ry).subs(zip((rx, ry), r.args)).equals(0) + # test for the case with seed + r = e3.random_point(seed=1) + assert e3.equation(rx, ry).subs(zip((rx, ry), r.args)).equals(0) + + +def test_repr(): + assert repr(Circle((0, 1), 2)) == 'Circle(Point2D(0, 1), 2)' + + +def test_transform(): + c = Circle((1, 1), 2) + assert c.scale(-1) == Circle((-1, 1), 2) + assert c.scale(y=-1) == Circle((1, -1), 2) + assert c.scale(2) == Ellipse((2, 1), 4, 2) + + assert Ellipse((0, 0), 2, 3).scale(2, 3, (4, 5)) == \ + Ellipse(Point(-4, -10), 4, 9) + assert Circle((0, 0), 2).scale(2, 3, (4, 5)) == \ + Ellipse(Point(-4, -10), 4, 6) + assert Ellipse((0, 0), 2, 3).scale(3, 3, (4, 5)) == \ + Ellipse(Point(-8, -10), 6, 9) + assert Circle((0, 0), 2).scale(3, 3, (4, 5)) == \ + Circle(Point(-8, -10), 6) + assert Circle(Point(-8, -10), 6).scale(Rational(1, 3), Rational(1, 3), (4, 5)) == \ + Circle((0, 0), 2) + assert Circle((0, 0), 2).translate(4, 5) == \ + Circle((4, 5), 2) + assert Circle((0, 0), 2).scale(3, 3) == \ + Circle((0, 0), 6) + + +def test_bounds(): + e1 = Ellipse(Point(0, 0), 3, 5) + e2 = Ellipse(Point(2, -2), 7, 7) + c1 = Circle(Point(2, -2), 7) + c2 = Circle(Point(-2, 0), Point(0, 2), Point(2, 0)) + assert e1.bounds == (-3, -5, 3, 5) + assert e2.bounds == (-5, -9, 9, 5) + assert c1.bounds == (-5, -9, 9, 5) + assert c2.bounds == (-2, -2, 2, 2) + + +def test_reflect(): + b = Symbol('b') + m = Symbol('m') + l = Line((0, b), slope=m) + t1 = Triangle((0, 0), (1, 0), (2, 3)) + assert t1.area == -t1.reflect(l).area + e = Ellipse((1, 0), 1, 2) + assert e.area == -e.reflect(Line((1, 0), slope=0)).area + assert e.area == -e.reflect(Line((1, 0), slope=oo)).area + raises(NotImplementedError, lambda: e.reflect(Line((1, 0), slope=m))) + assert Circle((0, 1), 1).reflect(Line((0, 0), (1, 1))) == Circle(Point2D(1, 0), -1) + + +def test_is_tangent(): + e1 = Ellipse(Point(0, 0), 3, 5) + c1 = Circle(Point(2, -2), 7) + assert e1.is_tangent(Point(0, 0)) is False + assert e1.is_tangent(Point(3, 0)) is False + assert e1.is_tangent(e1) is True + assert e1.is_tangent(Ellipse((0, 0), 1, 2)) is False + assert e1.is_tangent(Ellipse((0, 0), 3, 2)) is True + assert c1.is_tangent(Ellipse((2, -2), 7, 1)) is True + assert c1.is_tangent(Circle((11, -2), 2)) is True + assert c1.is_tangent(Circle((7, -2), 2)) is True + assert c1.is_tangent(Ray((-5, -2), (-15, -20))) is False + assert c1.is_tangent(Ray((-3, -2), (-15, -20))) is False + assert c1.is_tangent(Ray((-3, -22), (15, 20))) is False + assert c1.is_tangent(Ray((9, 20), (9, -20))) is True + assert c1.is_tangent(Ray((2, 5), (9, 5))) is True + assert c1.is_tangent(Segment((2, 5), (9, 5))) is True + assert e1.is_tangent(Segment((2, 2), (-7, 7))) is False + assert e1.is_tangent(Segment((0, 0), (1, 2))) is False + assert c1.is_tangent(Segment((0, 0), (-5, -2))) is False + assert e1.is_tangent(Segment((3, 0), (12, 12))) is False + assert e1.is_tangent(Segment((12, 12), (3, 0))) is False + assert e1.is_tangent(Segment((-3, 0), (3, 0))) is False + assert e1.is_tangent(Segment((-3, 5), (3, 5))) is True + assert e1.is_tangent(Line((10, 0), (10, 10))) is False + assert e1.is_tangent(Line((0, 0), (1, 1))) is False + assert e1.is_tangent(Line((-3, 0), (-2.99, -0.001))) is False + assert e1.is_tangent(Line((-3, 0), (-3, 1))) is True + assert e1.is_tangent(Polygon((0, 0), (5, 5), (5, -5))) is False + assert e1.is_tangent(Polygon((-100, -50), (-40, -334), (-70, -52))) is False + assert e1.is_tangent(Polygon((-3, 0), (3, 0), (0, 1))) is False + assert e1.is_tangent(Polygon((-3, 0), (3, 0), (0, 5))) is False + assert e1.is_tangent(Polygon((-3, 0), (0, -5), (3, 0), (0, 5))) is False + assert e1.is_tangent(Polygon((-3, -5), (-3, 5), (3, 5), (3, -5))) is True + assert c1.is_tangent(Polygon((-3, -5), (-3, 5), (3, 5), (3, -5))) is False + assert e1.is_tangent(Polygon((0, 0), (3, 0), (7, 7), (0, 5))) is False + assert e1.is_tangent(Polygon((3, 12), (3, -12), (6, 5))) is False + assert e1.is_tangent(Polygon((3, 12), (3, -12), (0, -5), (0, 5))) is False + assert e1.is_tangent(Polygon((3, 0), (5, 7), (6, -5))) is False + assert c1.is_tangent(Segment((0, 0), (-5, -2))) is False + assert e1.is_tangent(Segment((-3, 0), (3, 0))) is False + assert e1.is_tangent(Segment((-3, 5), (3, 5))) is True + assert e1.is_tangent(Polygon((0, 0), (5, 5), (5, -5))) is False + assert e1.is_tangent(Polygon((-100, -50), (-40, -334), (-70, -52))) is False + assert e1.is_tangent(Polygon((-3, -5), (-3, 5), (3, 5), (3, -5))) is True + assert c1.is_tangent(Polygon((-3, -5), (-3, 5), (3, 5), (3, -5))) is False + assert e1.is_tangent(Polygon((3, 12), (3, -12), (0, -5), (0, 5))) is False + assert e1.is_tangent(Polygon((3, 0), (5, 7), (6, -5))) is False + raises(TypeError, lambda: e1.is_tangent(Point(0, 0, 0))) + raises(TypeError, lambda: e1.is_tangent(Rational(5))) + + +def test_parameter_value(): + t = Symbol('t') + e = Ellipse(Point(0, 0), 3, 5) + assert e.parameter_value((3, 0), t) == {t: 0} + raises(ValueError, lambda: e.parameter_value((4, 0), t)) + + +@slow +def test_second_moment_of_area(): + x, y = symbols('x, y') + e = Ellipse(Point(0, 0), 5, 4) + I_yy = 2*4*integrate(sqrt(25 - x**2)*x**2, (x, -5, 5))/5 + I_xx = 2*5*integrate(sqrt(16 - y**2)*y**2, (y, -4, 4))/4 + Y = 3*sqrt(1 - x**2/5**2) + I_xy = integrate(integrate(y, (y, -Y, Y))*x, (x, -5, 5)) + assert I_yy == e.second_moment_of_area()[1] + assert I_xx == e.second_moment_of_area()[0] + assert I_xy == e.second_moment_of_area()[2] + #checking for other point + t1 = e.second_moment_of_area(Point(6,5)) + t2 = (580*pi, 845*pi, 600*pi) + assert t1==t2 + + +def test_section_modulus_and_polar_second_moment_of_area(): + d = Symbol('d', positive=True) + c = Circle((3, 7), 8) + assert c.polar_second_moment_of_area() == 2048*pi + assert c.section_modulus() == (128*pi, 128*pi) + c = Circle((2, 9), d/2) + assert c.polar_second_moment_of_area() == pi*d**3*Abs(d)/64 + pi*d*Abs(d)**3/64 + assert c.section_modulus() == (pi*d**3/S(32), pi*d**3/S(32)) + + a, b = symbols('a, b', positive=True) + e = Ellipse((4, 6), a, b) + assert e.section_modulus() == (pi*a*b**2/S(4), pi*a**2*b/S(4)) + assert e.polar_second_moment_of_area() == pi*a**3*b/S(4) + pi*a*b**3/S(4) + e = e.rotate(pi/2) # no change in polar and section modulus + assert e.section_modulus() == (pi*a**2*b/S(4), pi*a*b**2/S(4)) + assert e.polar_second_moment_of_area() == pi*a**3*b/S(4) + pi*a*b**3/S(4) + + e = Ellipse((a, b), 2, 6) + assert e.section_modulus() == (18*pi, 6*pi) + assert e.polar_second_moment_of_area() == 120*pi + + e = Ellipse(Point(0, 0), 2, 2) + assert e.section_modulus() == (2*pi, 2*pi) + assert e.section_modulus(Point(2, 2)) == (2*pi, 2*pi) + assert e.section_modulus((2, 2)) == (2*pi, 2*pi) + + +def test_circumference(): + M = Symbol('M') + m = Symbol('m') + assert Ellipse(Point(0, 0), M, m).circumference == 4 * M * elliptic_e((M ** 2 - m ** 2) / M**2) + + assert Ellipse(Point(0, 0), 5, 4).circumference == 20 * elliptic_e(S(9) / 25) + + # circle + assert Ellipse(None, 1, None, 0).circumference == 2*pi + + # test numerically + assert abs(Ellipse(None, hradius=5, vradius=3).circumference.evalf(16) - 25.52699886339813) < 1e-10 + + +def test_issue_15259(): + assert Circle((1, 2), 0) == Point(1, 2) + + +def test_issue_15797_equals(): + Ri = 0.024127189424130748 + Ci = (0.0864931002830291, 0.0819863295239654) + A = Point(0, 0.0578591400998346) + c = Circle(Ci, Ri) # evaluated + assert c.is_tangent(c.tangent_lines(A)[0]) == True + assert c.center.x.is_Rational + assert c.center.y.is_Rational + assert c.radius.is_Rational + u = Circle(Ci, Ri, evaluate=False) # unevaluated + assert u.center.x.is_Float + assert u.center.y.is_Float + assert u.radius.is_Float + + +def test_auxiliary_circle(): + x, y, a, b = symbols('x y a b') + e = Ellipse((x, y), a, b) + # the general result + assert e.auxiliary_circle() == Circle((x, y), Max(a, b)) + # a special case where Ellipse is a Circle + assert Circle((3, 4), 8).auxiliary_circle() == Circle((3, 4), 8) + + +def test_director_circle(): + x, y, a, b = symbols('x y a b') + e = Ellipse((x, y), a, b) + # the general result + assert e.director_circle() == Circle((x, y), sqrt(a**2 + b**2)) + # a special case where Ellipse is a Circle + assert Circle((3, 4), 8).director_circle() == Circle((3, 4), 8*sqrt(2)) + + +def test_evolute(): + #ellipse centered at h,k + x, y, h, k = symbols('x y h k',real = True) + a, b = symbols('a b') + e = Ellipse(Point(h, k), a, b) + t1 = (e.hradius*(x - e.center.x))**Rational(2, 3) + t2 = (e.vradius*(y - e.center.y))**Rational(2, 3) + E = t1 + t2 - (e.hradius**2 - e.vradius**2)**Rational(2, 3) + assert e.evolute() == E + #Numerical Example + e = Ellipse(Point(1, 1), 6, 3) + t1 = (6*(x - 1))**Rational(2, 3) + t2 = (3*(y - 1))**Rational(2, 3) + E = t1 + t2 - (27)**Rational(2, 3) + assert e.evolute() == E + + +def test_svg(): + e1 = Ellipse(Point(1, 0), 3, 2) + assert e1._svg(2, "#FFAAFF") == '' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_entity.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..0d440fd5dbd193c7c490b45a706fab2703e247ec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_entity.py @@ -0,0 +1,120 @@ +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.geometry import (Circle, Ellipse, Point, Line, Parabola, + Polygon, Ray, RegularPolygon, Segment, Triangle, Plane, Curve) +from sympy.geometry.entity import scale, GeometryEntity +from sympy.testing.pytest import raises + + +def test_entity(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + + assert GeometryEntity(x, y) in GeometryEntity(x, y) + raises(NotImplementedError, lambda: Point(0, 0) in GeometryEntity(x, y)) + + assert GeometryEntity(x, y) == GeometryEntity(x, y) + assert GeometryEntity(x, y).equals(GeometryEntity(x, y)) + + c = Circle((0, 0), 5) + assert GeometryEntity.encloses(c, Point(0, 0)) + assert GeometryEntity.encloses(c, Segment((0, 0), (1, 1))) + assert GeometryEntity.encloses(c, Line((0, 0), (1, 1))) is False + assert GeometryEntity.encloses(c, Circle((0, 0), 4)) + assert GeometryEntity.encloses(c, Polygon(Point(0, 0), Point(1, 0), Point(0, 1))) + assert GeometryEntity.encloses(c, RegularPolygon(Point(8, 8), 1, 3)) is False + + +def test_svg(): + a = Symbol('a') + b = Symbol('b') + d = Symbol('d') + + entity = Circle(Point(a, b), d) + assert entity._repr_svg_() is None + + entity = Circle(Point(0, 0), S.Infinity) + assert entity._repr_svg_() is None + + +def test_subs(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + p = Point(x, 2) + q = Point(1, 1) + r = Point(3, 4) + for o in [p, + Segment(p, q), + Ray(p, q), + Line(p, q), + Triangle(p, q, r), + RegularPolygon(p, 3, 6), + Polygon(p, q, r, Point(5, 4)), + Circle(p, 3), + Ellipse(p, 3, 4)]: + assert 'y' in str(o.subs(x, y)) + assert p.subs({x: 1}) == Point(1, 2) + assert Point(1, 2).subs(Point(1, 2), Point(3, 4)) == Point(3, 4) + assert Point(1, 2).subs((1, 2), Point(3, 4)) == Point(3, 4) + assert Point(1, 2).subs(Point(1, 2), Point(3, 4)) == Point(3, 4) + assert Point(1, 2).subs({(1, 2)}) == Point(2, 2) + raises(ValueError, lambda: Point(1, 2).subs(1)) + raises(TypeError, lambda: Point(1, 1).subs((Point(1, 1), Point(1, + 2)), 1, 2)) + + +def test_transform(): + assert scale(1, 2, (3, 4)).tolist() == \ + [[1, 0, 0], [0, 2, 0], [0, -4, 1]] + + +def test_reflect_entity_overrides(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + b = Symbol('b') + m = Symbol('m') + l = Line((0, b), slope=m) + p = Point(x, y) + r = p.reflect(l) + c = Circle((x, y), 3) + cr = c.reflect(l) + assert cr == Circle(r, -3) + assert c.area == -cr.area + + pent = RegularPolygon((1, 2), 1, 5) + slope = S.ComplexInfinity + while slope is S.ComplexInfinity: + slope = Rational(*(x._random()/2).as_real_imag()) + l = Line(pent.vertices[1], slope=slope) + rpent = pent.reflect(l) + assert rpent.center == pent.center.reflect(l) + rvert = [i.reflect(l) for i in pent.vertices] + for v in rpent.vertices: + for i in range(len(rvert)): + ri = rvert[i] + if ri.equals(v): + rvert.remove(ri) + break + assert not rvert + assert pent.area.equals(-rpent.area) + + +def test_geometry_EvalfMixin(): + x = pi + t = Symbol('t') + for g in [ + Point(x, x), + Plane(Point(0, x, 0), (0, 0, x)), + Curve((x*t, x), (t, 0, x)), + Ellipse((x, x), x, -x), + Circle((x, x), x), + Line((0, x), (x, 0)), + Segment((0, x), (x, 0)), + Ray((0, x), (x, 0)), + Parabola((0, x), Line((-x, 0), (x, 0))), + Polygon((0, 0), (0, x), (x, 0), (x, x)), + RegularPolygon((0, x), x, 4, x), + Triangle((0, 0), (x, 0), (x, x)), + ]: + assert str(g).replace('pi', '3.1') == str(g.n(2)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_geometrysets.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_geometrysets.py new file mode 100644 index 0000000000000000000000000000000000000000..c52898b3c9ba4e9db80c244db3aebf88db2cc8b4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_geometrysets.py @@ -0,0 +1,38 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.geometry import Circle, Line, Point, Polygon, Segment +from sympy.sets import FiniteSet, Union, Intersection, EmptySet + + +def test_booleans(): + """ test basic unions and intersections """ + half = S.Half + + p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + p5, p6, p7 = map(Point, [(3, 2), (1, -1), (0, 2)]) + l1 = Line(Point(0,0), Point(1,1)) + l2 = Line(Point(half, half), Point(5,5)) + l3 = Line(p2, p3) + l4 = Line(p3, p4) + poly1 = Polygon(p1, p2, p3, p4) + poly2 = Polygon(p5, p6, p7) + poly3 = Polygon(p1, p2, p5) + assert Union(l1, l2).equals(l1) + assert Intersection(l1, l2).equals(l1) + assert Intersection(l1, l4) == FiniteSet(Point(1,1)) + assert Intersection(Union(l1, l4), l3) == FiniteSet(Point(Rational(-1, 3), Rational(-1, 3)), Point(5, 1)) + assert Intersection(l1, FiniteSet(Point(7,-7))) == EmptySet + assert Intersection(Circle(Point(0,0), 3), Line(p1,p2)) == FiniteSet(Point(-3,0), Point(3,0)) + assert Intersection(l1, FiniteSet(p1)) == FiniteSet(p1) + assert Union(l1, FiniteSet(p1)) == l1 + + fs = FiniteSet(Point(Rational(1, 3), 1), Point(Rational(2, 3), 0), Point(Rational(9, 5), Rational(1, 5)), Point(Rational(7, 3), 1)) + # test the intersection of polygons + assert Intersection(poly1, poly2) == fs + # make sure if we union polygons with subsets, the subsets go away + assert Union(poly1, poly2, fs) == Union(poly1, poly2) + # make sure that if we union with a FiniteSet that isn't a subset, + # that the points in the intersection stop being listed + assert Union(poly1, FiniteSet(Point(0,0), Point(3,5))) == Union(poly1, FiniteSet(Point(3,5))) + # intersect two polygons that share an edge + assert Intersection(poly1, poly3) == Union(FiniteSet(Point(Rational(3, 2), 1), Point(2, 1)), Segment(Point(0, 0), Point(1, 0))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_line.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_line.py new file mode 100644 index 0000000000000000000000000000000000000000..5158ec05ab414020fbbe2681a2658454dd15b6eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_line.py @@ -0,0 +1,861 @@ +from sympy.core.numbers import (Float, Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, cos, sin) +from sympy.sets import EmptySet +from sympy.simplify.simplify import simplify +from sympy.functions.elementary.trigonometric import tan +from sympy.geometry import (Circle, GeometryError, Line, Point, Ray, + Segment, Triangle, intersection, Point3D, Line3D, Ray3D, Segment3D, + Point2D, Line2D, Plane) +from sympy.geometry.line import Undecidable +from sympy.geometry.polygon import _asa as asa +from sympy.utilities.iterables import cartes +from sympy.testing.pytest import raises, warns + + +x = Symbol('x', real=True) +y = Symbol('y', real=True) +z = Symbol('z', real=True) +k = Symbol('k', real=True) +x1 = Symbol('x1', real=True) +y1 = Symbol('y1', real=True) +t = Symbol('t', real=True) +a, b = symbols('a,b', real=True) +m = symbols('m', real=True) + + +def test_object_from_equation(): + from sympy.abc import x, y, a, b + assert Line(3*x + y + 18) == Line2D(Point2D(0, -18), Point2D(1, -21)) + assert Line(3*x + 5 * y + 1) == Line2D( + Point2D(0, Rational(-1, 5)), Point2D(1, Rational(-4, 5))) + assert Line(3*a + b + 18, x="a", y="b") == Line2D( + Point2D(0, -18), Point2D(1, -21)) + assert Line(3*x + y) == Line2D(Point2D(0, 0), Point2D(1, -3)) + assert Line(x + y) == Line2D(Point2D(0, 0), Point2D(1, -1)) + assert Line(Eq(3*a + b, -18), x="a", y=b) == Line2D( + Point2D(0, -18), Point2D(1, -21)) + # issue 22361 + assert Line(x - 1) == Line2D(Point2D(1, 0), Point2D(1, 1)) + assert Line(2*x - 2, y=x) == Line2D(Point2D(0, 1), Point2D(1, 1)) + assert Line(y) == Line2D(Point2D(0, 0), Point2D(1, 0)) + assert Line(2*y, x=y) == Line2D(Point2D(0, 0), Point2D(0, 1)) + assert Line(y, x=y) == Line2D(Point2D(0, 0), Point2D(0, 1)) + raises(ValueError, lambda: Line(x / y)) + raises(ValueError, lambda: Line(a / b, x='a', y='b')) + raises(ValueError, lambda: Line(y / x)) + raises(ValueError, lambda: Line(b / a, x='a', y='b')) + raises(ValueError, lambda: Line((x + 1)**2 + y)) + + +def feq(a, b): + """Test if two floating point values are 'equal'.""" + t_float = Float("1.0E-10") + return -t_float < a - b < t_float + + +def test_angle_between(): + a = Point(1, 2, 3, 4) + b = a.orthogonal_direction + o = a.origin + assert feq(Line.angle_between(Line(Point(0, 0), Point(1, 1)), + Line(Point(0, 0), Point(5, 0))).evalf(), pi.evalf() / 4) + assert Line(a, o).angle_between(Line(b, o)) == pi / 2 + z = Point3D(0, 0, 0) + assert Line3D.angle_between(Line3D(z, Point3D(1, 1, 1)), + Line3D(z, Point3D(5, 0, 0))) == acos(sqrt(3) / 3) + # direction of points is used to determine angle + assert Line3D.angle_between(Line3D(z, Point3D(1, 1, 1)), + Line3D(Point3D(5, 0, 0), z)) == acos(-sqrt(3) / 3) + + +def test_closing_angle(): + a = Ray((0, 0), angle=0) + b = Ray((1, 2), angle=pi/2) + assert a.closing_angle(b) == -pi/2 + assert b.closing_angle(a) == pi/2 + assert a.closing_angle(a) == 0 + + +def test_smallest_angle(): + a = Line(Point(1, 1), Point(1, 2)) + b = Line(Point(1, 1),Point(2, 3)) + assert a.smallest_angle_between(b) == acos(2*sqrt(5)/5) + + +def test_svg(): + a = Line(Point(1, 1),Point(1, 2)) + assert a._svg() == '' + a = Segment(Point(1, 0),Point(1, 1)) + assert a._svg() == '' + a = Ray(Point(2, 3), Point(3, 5)) + assert a._svg() == '' + + +def test_arbitrary_point(): + l1 = Line3D(Point3D(0, 0, 0), Point3D(1, 1, 1)) + l2 = Line(Point(x1, x1), Point(y1, y1)) + assert l2.arbitrary_point() in l2 + assert Ray((1, 1), angle=pi / 4).arbitrary_point() == \ + Point(t + 1, t + 1) + assert Segment((1, 1), (2, 3)).arbitrary_point() == Point(1 + t, 1 + 2 * t) + assert l1.perpendicular_segment(l1.arbitrary_point()) == l1.arbitrary_point() + assert Ray3D((1, 1, 1), direction_ratio=[1, 2, 3]).arbitrary_point() == \ + Point3D(t + 1, 2 * t + 1, 3 * t + 1) + assert Segment3D(Point3D(0, 0, 0), Point3D(1, 1, 1)).midpoint == \ + Point3D(S.Half, S.Half, S.Half) + assert Segment3D(Point3D(x1, x1, x1), Point3D(y1, y1, y1)).length == sqrt(3) * sqrt((x1 - y1) ** 2) + assert Segment3D((1, 1, 1), (2, 3, 4)).arbitrary_point() == \ + Point3D(t + 1, 2 * t + 1, 3 * t + 1) + raises(ValueError, (lambda: Line((x, 1), (2, 3)).arbitrary_point(x))) + + +def test_are_concurrent_2d(): + l1 = Line(Point(0, 0), Point(1, 1)) + l2 = Line(Point(x1, x1), Point(x1, 1 + x1)) + assert Line.are_concurrent(l1) is False + assert Line.are_concurrent(l1, l2) + assert Line.are_concurrent(l1, l1, l1, l2) + assert Line.are_concurrent(l1, l2, Line(Point(5, x1), Point(Rational(-3, 5), x1))) + assert Line.are_concurrent(l1, Line(Point(0, 0), Point(-x1, x1)), l2) is False + + +def test_are_concurrent_3d(): + p1 = Point3D(0, 0, 0) + l1 = Line(p1, Point3D(1, 1, 1)) + parallel_1 = Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)) + parallel_2 = Line3D(Point3D(0, 1, 0), Point3D(1, 1, 0)) + assert Line3D.are_concurrent(l1) is False + assert Line3D.are_concurrent(l1, Line(Point3D(x1, x1, x1), Point3D(y1, y1, y1))) is False + assert Line3D.are_concurrent(l1, Line3D(p1, Point3D(x1, x1, x1)), + Line(Point3D(x1, x1, x1), Point3D(x1, 1 + x1, 1))) is True + assert Line3D.are_concurrent(parallel_1, parallel_2) is False + + +def test_arguments(): + """Functions accepting `Point` objects in `geometry` + should also accept tuples, lists, and generators and + automatically convert them to points.""" + from sympy.utilities.iterables import subsets + + singles2d = ((1, 2), [1, 3], Point(1, 5)) + doubles2d = subsets(singles2d, 2) + l2d = Line(Point2D(1, 2), Point2D(2, 3)) + singles3d = ((1, 2, 3), [1, 2, 4], Point(1, 2, 6)) + doubles3d = subsets(singles3d, 2) + l3d = Line(Point3D(1, 2, 3), Point3D(1, 1, 2)) + singles4d = ((1, 2, 3, 4), [1, 2, 3, 5], Point(1, 2, 3, 7)) + doubles4d = subsets(singles4d, 2) + l4d = Line(Point(1, 2, 3, 4), Point(2, 2, 2, 2)) + # test 2D + test_single = ['contains', 'distance', 'equals', 'parallel_line', 'perpendicular_line', 'perpendicular_segment', + 'projection', 'intersection'] + for p in doubles2d: + Line2D(*p) + for func in test_single: + for p in singles2d: + getattr(l2d, func)(p) + # test 3D + for p in doubles3d: + Line3D(*p) + for func in test_single: + for p in singles3d: + getattr(l3d, func)(p) + # test 4D + for p in doubles4d: + Line(*p) + for func in test_single: + for p in singles4d: + getattr(l4d, func)(p) + + +def test_basic_properties_2d(): + p1 = Point(0, 0) + p2 = Point(1, 1) + p10 = Point(2000, 2000) + p_r3 = Ray(p1, p2).random_point() + p_r4 = Ray(p2, p1).random_point() + + l1 = Line(p1, p2) + l3 = Line(Point(x1, x1), Point(x1, 1 + x1)) + l4 = Line(p1, Point(1, 0)) + + r1 = Ray(p1, Point(0, 1)) + r2 = Ray(Point(0, 1), p1) + + s1 = Segment(p1, p10) + p_s1 = s1.random_point() + + assert Line((1, 1), slope=1) == Line((1, 1), (2, 2)) + assert Line((1, 1), slope=oo) == Line((1, 1), (1, 2)) + assert Line((1, 1), slope=oo).bounds == (1, 1, 1, 2) + assert Line((1, 1), slope=-oo) == Line((1, 1), (1, 2)) + assert Line(p1, p2).scale(2, 1) == Line(p1, Point(2, 1)) + assert Line(p1, p2) == Line(p1, p2) + assert Line(p1, p2) != Line(p2, p1) + assert l1 != Line(Point(x1, x1), Point(y1, y1)) + assert l1 != l3 + assert Line(p1, p10) != Line(p10, p1) + assert Line(p1, p10) != p1 + assert p1 in l1 # is p1 on the line l1? + assert p1 not in l3 + assert s1 in Line(p1, p10) + assert Ray(Point(0, 0), Point(0, 1)) in Ray(Point(0, 0), Point(0, 2)) + assert Ray(Point(0, 0), Point(0, 2)) in Ray(Point(0, 0), Point(0, 1)) + assert Ray(Point(0, 0), Point(0, 2)).xdirection == S.Zero + assert Ray(Point(0, 0), Point(1, 2)).xdirection == S.Infinity + assert Ray(Point(0, 0), Point(-1, 2)).xdirection == S.NegativeInfinity + assert Ray(Point(0, 0), Point(2, 0)).ydirection == S.Zero + assert Ray(Point(0, 0), Point(2, 2)).ydirection == S.Infinity + assert Ray(Point(0, 0), Point(2, -2)).ydirection == S.NegativeInfinity + assert (r1 in s1) is False + assert Segment(p1, p2) in s1 + assert Ray(Point(x1, x1), Point(x1, 1 + x1)) != Ray(p1, Point(-1, 5)) + assert Segment(p1, p2).midpoint == Point(S.Half, S.Half) + assert Segment(p1, Point(-x1, x1)).length == sqrt(2 * (x1 ** 2)) + + assert l1.slope == 1 + assert l3.slope is oo + assert l4.slope == 0 + assert Line(p1, Point(0, 1)).slope is oo + assert Line(r1.source, r1.random_point()).slope == r1.slope + assert Line(r2.source, r2.random_point()).slope == r2.slope + assert Segment(Point(0, -1), Segment(p1, Point(0, 1)).random_point()).slope == Segment(p1, Point(0, 1)).slope + + assert l4.coefficients == (0, 1, 0) + assert Line((-x, x), (-x + 1, x - 1)).coefficients == (1, 1, 0) + assert Line(p1, Point(0, 1)).coefficients == (1, 0, 0) + # issue 7963 + r = Ray((0, 0), angle=x) + assert r.subs(x, 3 * pi / 4) == Ray((0, 0), (-1, 1)) + assert r.subs(x, 5 * pi / 4) == Ray((0, 0), (-1, -1)) + assert r.subs(x, -pi / 4) == Ray((0, 0), (1, -1)) + assert r.subs(x, pi / 2) == Ray((0, 0), (0, 1)) + assert r.subs(x, -pi / 2) == Ray((0, 0), (0, -1)) + + for ind in range(0, 5): + assert l3.random_point() in l3 + + assert p_r3.x >= p1.x and p_r3.y >= p1.y + assert p_r4.x <= p2.x and p_r4.y <= p2.y + assert p1.x <= p_s1.x <= p10.x and p1.y <= p_s1.y <= p10.y + assert hash(s1) != hash(Segment(p10, p1)) + + assert s1.plot_interval() == [t, 0, 1] + assert Line(p1, p10).plot_interval() == [t, -5, 5] + assert Ray((0, 0), angle=pi / 4).plot_interval() == [t, 0, 10] + + +def test_basic_properties_3d(): + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + p3 = Point3D(x1, x1, x1) + p5 = Point3D(x1, 1 + x1, 1) + + l1 = Line3D(p1, p2) + l3 = Line3D(p3, p5) + + r1 = Ray3D(p1, Point3D(-1, 5, 0)) + r3 = Ray3D(p1, p2) + + s1 = Segment3D(p1, p2) + + assert Line3D((1, 1, 1), direction_ratio=[2, 3, 4]) == Line3D(Point3D(1, 1, 1), Point3D(3, 4, 5)) + assert Line3D((1, 1, 1), direction_ratio=[1, 5, 7]) == Line3D(Point3D(1, 1, 1), Point3D(2, 6, 8)) + assert Line3D((1, 1, 1), direction_ratio=[1, 2, 3]) == Line3D(Point3D(1, 1, 1), Point3D(2, 3, 4)) + assert Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).direction_cosine == [1, 0, 0] + assert Line3D(Line3D(p1, Point3D(0, 1, 0))) == Line3D(p1, Point3D(0, 1, 0)) + assert Ray3D(Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0))) == Ray3D(p1, Point3D(1, 0, 0)) + assert Line3D(p1, p2) != Line3D(p2, p1) + assert l1 != l3 + assert l1 != Line3D(p3, Point3D(y1, y1, y1)) + assert r3 != r1 + assert Ray3D(Point3D(0, 0, 0), Point3D(1, 1, 1)) in Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)) + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)) in Ray3D(Point3D(0, 0, 0), Point3D(1, 1, 1)) + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)).xdirection == S.Infinity + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)).ydirection == S.Infinity + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)).zdirection == S.Infinity + assert Ray3D(Point3D(0, 0, 0), Point3D(-2, 2, 2)).xdirection == S.NegativeInfinity + assert Ray3D(Point3D(0, 0, 0), Point3D(2, -2, 2)).ydirection == S.NegativeInfinity + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, -2)).zdirection == S.NegativeInfinity + assert Ray3D(Point3D(0, 0, 0), Point3D(0, 2, 2)).xdirection == S.Zero + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 0, 2)).ydirection == S.Zero + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 0)).zdirection == S.Zero + assert p1 in l1 + assert p1 not in l3 + + assert l1.direction_ratio == [1, 1, 1] + + assert s1.midpoint == Point3D(S.Half, S.Half, S.Half) + # Test zdirection + assert Ray3D(p1, Point3D(0, 0, -1)).zdirection is S.NegativeInfinity + + +def test_contains(): + p1 = Point(0, 0) + + r = Ray(p1, Point(4, 4)) + r1 = Ray3D(p1, Point3D(0, 0, -1)) + r2 = Ray3D(p1, Point3D(0, 1, 0)) + r3 = Ray3D(p1, Point3D(0, 0, 1)) + + l = Line(Point(0, 1), Point(3, 4)) + # Segment contains + assert Point(0, (a + b) / 2) in Segment((0, a), (0, b)) + assert Point((a + b) / 2, 0) in Segment((a, 0), (b, 0)) + assert Point3D(0, 1, 0) in Segment3D((0, 1, 0), (0, 1, 0)) + assert Point3D(1, 0, 0) in Segment3D((1, 0, 0), (1, 0, 0)) + assert Segment3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).contains([]) is True + assert Segment3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).contains( + Segment3D(Point3D(2, 2, 2), Point3D(3, 2, 2))) is False + # Line contains + assert l.contains(Point(0, 1)) is True + assert l.contains((0, 1)) is True + assert l.contains((0, 0)) is False + # Ray contains + assert r.contains(p1) is True + assert r.contains((1, 1)) is True + assert r.contains((1, 3)) is False + assert r.contains(Segment((1, 1), (2, 2))) is True + assert r.contains(Segment((1, 2), (2, 5))) is False + assert r.contains(Ray((2, 2), (3, 3))) is True + assert r.contains(Ray((2, 2), (3, 5))) is False + assert r1.contains(Segment3D(p1, Point3D(0, 0, -10))) is True + assert r1.contains(Segment3D(Point3D(1, 1, 1), Point3D(2, 2, 2))) is False + assert r2.contains(Point3D(0, 0, 0)) is True + assert r3.contains(Point3D(0, 0, 0)) is True + assert Ray3D(Point3D(1, 1, 1), Point3D(1, 0, 0)).contains([]) is False + assert Line3D((0, 0, 0), (x, y, z)).contains((2 * x, 2 * y, 2 * z)) + with warns(UserWarning, test_stacklevel=False): + assert Line3D(p1, Point3D(0, 1, 0)).contains(Point(1.0, 1.0)) is False + + with warns(UserWarning, test_stacklevel=False): + assert r3.contains(Point(1.0, 1.0)) is False + + +def test_contains_nonreal_symbols(): + u, v, w, z = symbols('u, v, w, z') + l = Segment(Point(u, w), Point(v, z)) + p = Point(u*Rational(2, 3) + v/3, w*Rational(2, 3) + z/3) + assert l.contains(p) + + +def test_distance_2d(): + p1 = Point(0, 0) + p2 = Point(1, 1) + half = S.Half + + s1 = Segment(Point(0, 0), Point(1, 1)) + s2 = Segment(Point(half, half), Point(1, 0)) + + r = Ray(p1, p2) + + assert s1.distance(Point(0, 0)) == 0 + assert s1.distance((0, 0)) == 0 + assert s2.distance(Point(0, 0)) == 2 ** half / 2 + assert s2.distance(Point(Rational(3) / 2, Rational(3) / 2)) == 2 ** half + assert Line(p1, p2).distance(Point(-1, 1)) == sqrt(2) + assert Line(p1, p2).distance(Point(1, -1)) == sqrt(2) + assert Line(p1, p2).distance(Point(2, 2)) == 0 + assert Line(p1, p2).distance((-1, 1)) == sqrt(2) + assert Line((0, 0), (0, 1)).distance(p1) == 0 + assert Line((0, 0), (0, 1)).distance(p2) == 1 + assert Line((0, 0), (1, 0)).distance(p1) == 0 + assert Line((0, 0), (1, 0)).distance(p2) == 1 + assert r.distance(Point(-1, -1)) == sqrt(2) + assert r.distance(Point(1, 1)) == 0 + assert r.distance(Point(-1, 1)) == sqrt(2) + assert Ray((1, 1), (2, 2)).distance(Point(1.5, 3)) == 3 * sqrt(2) / 4 + assert r.distance((1, 1)) == 0 + + +def test_dimension_normalization(): + with warns(UserWarning, test_stacklevel=False): + assert Ray((1, 1), (2, 1, 2)) == Ray((1, 1, 0), (2, 1, 2)) + + +def test_distance_3d(): + p1, p2 = Point3D(0, 0, 0), Point3D(1, 1, 1) + p3 = Point3D(Rational(3) / 2, Rational(3) / 2, Rational(3) / 2) + + s1 = Segment3D(Point3D(0, 0, 0), Point3D(1, 1, 1)) + s2 = Segment3D(Point3D(S.Half, S.Half, S.Half), Point3D(1, 0, 1)) + + r = Ray3D(p1, p2) + + assert s1.distance(p1) == 0 + assert s2.distance(p1) == sqrt(3) / 2 + assert s2.distance(p3) == 2 * sqrt(6) / 3 + assert s1.distance((0, 0, 0)) == 0 + assert s2.distance((0, 0, 0)) == sqrt(3) / 2 + assert s1.distance(p1) == 0 + assert s2.distance(p1) == sqrt(3) / 2 + assert s2.distance(p3) == 2 * sqrt(6) / 3 + assert s1.distance((0, 0, 0)) == 0 + assert s2.distance((0, 0, 0)) == sqrt(3) / 2 + # Line to point + assert Line3D(p1, p2).distance(Point3D(-1, 1, 1)) == 2 * sqrt(6) / 3 + assert Line3D(p1, p2).distance(Point3D(1, -1, 1)) == 2 * sqrt(6) / 3 + assert Line3D(p1, p2).distance(Point3D(2, 2, 2)) == 0 + assert Line3D(p1, p2).distance((2, 2, 2)) == 0 + assert Line3D(p1, p2).distance((1, -1, 1)) == 2 * sqrt(6) / 3 + assert Line3D((0, 0, 0), (0, 1, 0)).distance(p1) == 0 + assert Line3D((0, 0, 0), (0, 1, 0)).distance(p2) == sqrt(2) + assert Line3D((0, 0, 0), (1, 0, 0)).distance(p1) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(p2) == sqrt(2) + # Line to line + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Line3D((0, 0, 0), (0, 1, 2))) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Line3D((0, 0, 0), (1, 0, 0))) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Line3D((10, 0, 0), (10, 1, 2))) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Line3D((0, 1, 0), (0, 1, 1))) == 1 + # Line to plane + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Plane((2, 0, 0), (0, 0, 1))) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Plane((0, 1, 0), (0, 1, 0))) == 1 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Plane((1, 1, 3), (1, 0, 0))) == 0 + # Ray to point + assert r.distance(Point3D(-1, -1, -1)) == sqrt(3) + assert r.distance(Point3D(1, 1, 1)) == 0 + assert r.distance((-1, -1, -1)) == sqrt(3) + assert r.distance((1, 1, 1)) == 0 + assert Ray3D((0, 0, 0), (1, 1, 2)).distance((-1, -1, 2)) == 4 * sqrt(3) / 3 + assert Ray3D((1, 1, 1), (2, 2, 2)).distance(Point3D(1.5, -3, -1)) == Rational(9) / 2 + assert Ray3D((1, 1, 1), (2, 2, 2)).distance(Point3D(1.5, 3, 1)) == sqrt(78) / 6 + + +def test_equals(): + p1 = Point(0, 0) + p2 = Point(1, 1) + + l1 = Line(p1, p2) + l2 = Line((0, 5), slope=m) + l3 = Line(Point(x1, x1), Point(x1, 1 + x1)) + + assert l1.perpendicular_line(p1.args).equals(Line(Point(0, 0), Point(1, -1))) + assert l1.perpendicular_line(p1).equals(Line(Point(0, 0), Point(1, -1))) + assert Line(Point(x1, x1), Point(y1, y1)).parallel_line(Point(-x1, x1)). \ + equals(Line(Point(-x1, x1), Point(-y1, 2 * x1 - y1))) + assert l3.parallel_line(p1.args).equals(Line(Point(0, 0), Point(0, -1))) + assert l3.parallel_line(p1).equals(Line(Point(0, 0), Point(0, -1))) + assert (l2.distance(Point(2, 3)) - 2 * abs(m + 1) / sqrt(m ** 2 + 1)).equals(0) + assert Line3D(p1, Point3D(0, 1, 0)).equals(Point(1.0, 1.0)) is False + assert Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).equals(Line3D(Point3D(-5, 0, 0), Point3D(-1, 0, 0))) is True + assert Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).equals(Line3D(p1, Point3D(0, 1, 0))) is False + assert Ray3D(p1, Point3D(0, 0, -1)).equals(Point(1.0, 1.0)) is False + assert Ray3D(p1, Point3D(0, 0, -1)).equals(Ray3D(p1, Point3D(0, 0, -1))) is True + assert Line3D((0, 0), (t, t)).perpendicular_line(Point(0, 1, 0)).equals( + Line3D(Point3D(0, 1, 0), Point3D(S.Half, S.Half, 0))) + assert Line3D((0, 0), (t, t)).perpendicular_segment(Point(0, 1, 0)).equals(Segment3D((0, 1), (S.Half, S.Half))) + assert Line3D(p1, Point3D(0, 1, 0)).equals(Point(1.0, 1.0)) is False + + +def test_equation(): + p1 = Point(0, 0) + p2 = Point(1, 1) + l1 = Line(p1, p2) + l3 = Line(Point(x1, x1), Point(x1, 1 + x1)) + + assert simplify(l1.equation()) in (x - y, y - x) + assert simplify(l3.equation()) in (x - x1, x1 - x) + assert simplify(l1.equation()) in (x - y, y - x) + assert simplify(l3.equation()) in (x - x1, x1 - x) + + assert Line(p1, Point(1, 0)).equation(x=x, y=y) == y + assert Line(p1, Point(0, 1)).equation() == x + assert Line(Point(2, 0), Point(2, 1)).equation() == x - 2 + assert Line(p2, Point(2, 1)).equation() == y - 1 + + assert Line3D(Point(x1, x1, x1), Point(y1, y1, y1) + ).equation() == (-x + y, -x + z) + assert Line3D(Point(1, 2, 3), Point(2, 3, 4) + ).equation() == (-x + y - 1, -x + z - 2) + assert Line3D(Point(1, 2, 3), Point(1, 3, 4) + ).equation() == (x - 1, -y + z - 1) + assert Line3D(Point(1, 2, 3), Point(2, 2, 4) + ).equation() == (y - 2, -x + z - 2) + assert Line3D(Point(1, 2, 3), Point(2, 3, 3) + ).equation() == (-x + y - 1, z - 3) + assert Line3D(Point(1, 2, 3), Point(1, 2, 4) + ).equation() == (x - 1, y - 2) + assert Line3D(Point(1, 2, 3), Point(1, 3, 3) + ).equation() == (x - 1, z - 3) + assert Line3D(Point(1, 2, 3), Point(2, 2, 3) + ).equation() == (y - 2, z - 3) + + +def test_intersection_2d(): + p1 = Point(0, 0) + p2 = Point(1, 1) + p3 = Point(x1, x1) + p4 = Point(y1, y1) + + l1 = Line(p1, p2) + l3 = Line(Point(0, 0), Point(3, 4)) + + r1 = Ray(Point(1, 1), Point(2, 2)) + r2 = Ray(Point(0, 0), Point(3, 4)) + r4 = Ray(p1, p2) + r6 = Ray(Point(0, 1), Point(1, 2)) + r7 = Ray(Point(0.5, 0.5), Point(1, 1)) + + s1 = Segment(p1, p2) + s2 = Segment(Point(0.25, 0.25), Point(0.5, 0.5)) + s3 = Segment(Point(0, 0), Point(3, 4)) + + assert intersection(l1, p1) == [p1] + assert intersection(l1, Point(x1, 1 + x1)) == [] + assert intersection(l1, Line(p3, p4)) in [[l1], [Line(p3, p4)]] + assert intersection(l1, l1.parallel_line(Point(x1, 1 + x1))) == [] + assert intersection(l3, l3) == [l3] + assert intersection(l3, r2) == [r2] + assert intersection(l3, s3) == [s3] + assert intersection(s3, l3) == [s3] + assert intersection(Segment(Point(-10, 10), Point(10, 10)), Segment(Point(-5, -5), Point(-5, 5))) == [] + assert intersection(r2, l3) == [r2] + assert intersection(r1, Ray(Point(2, 2), Point(0, 0))) == [Segment(Point(1, 1), Point(2, 2))] + assert intersection(r1, Ray(Point(1, 1), Point(-1, -1))) == [Point(1, 1)] + assert intersection(r1, Segment(Point(0, 0), Point(2, 2))) == [Segment(Point(1, 1), Point(2, 2))] + + assert r4.intersection(s2) == [s2] + assert r4.intersection(Segment(Point(2, 3), Point(3, 4))) == [] + assert r4.intersection(Segment(Point(-1, -1), Point(0.5, 0.5))) == [Segment(p1, Point(0.5, 0.5))] + assert r4.intersection(Ray(p2, p1)) == [s1] + assert Ray(p2, p1).intersection(r6) == [] + assert r4.intersection(r7) == r7.intersection(r4) == [r7] + assert Ray3D((0, 0), (3, 0)).intersection(Ray3D((1, 0), (3, 0))) == [Ray3D((1, 0), (3, 0))] + assert Ray3D((1, 0), (3, 0)).intersection(Ray3D((0, 0), (3, 0))) == [Ray3D((1, 0), (3, 0))] + assert Ray(Point(0, 0), Point(0, 4)).intersection(Ray(Point(0, 1), Point(0, -1))) == \ + [Segment(Point(0, 0), Point(0, 1))] + + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((1, 0), (2, 0))) == [Segment3D((1, 0), (2, 0))] + assert Segment3D((1, 0), (2, 0)).intersection( + Segment3D((0, 0), (3, 0))) == [Segment3D((1, 0), (2, 0))] + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((3, 0), (4, 0))) == [Point3D((3, 0))] + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((2, 0), (5, 0))) == [Segment3D((2, 0), (3, 0))] + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((-2, 0), (1, 0))) == [Segment3D((0, 0), (1, 0))] + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((-2, 0), (0, 0))) == [Point3D(0, 0)] + assert s1.intersection(Segment(Point(1, 1), Point(2, 2))) == [Point(1, 1)] + assert s1.intersection(Segment(Point(0.5, 0.5), Point(1.5, 1.5))) == [Segment(Point(0.5, 0.5), p2)] + assert s1.intersection(Segment(Point(4, 4), Point(5, 5))) == [] + assert s1.intersection(Segment(Point(-1, -1), p1)) == [p1] + assert s1.intersection(Segment(Point(-1, -1), Point(0.5, 0.5))) == [Segment(p1, Point(0.5, 0.5))] + assert s1.intersection(Line(Point(1, 0), Point(2, 1))) == [] + assert s1.intersection(s2) == [s2] + assert s2.intersection(s1) == [s2] + + assert asa(120, 8, 52) == \ + Triangle( + Point(0, 0), + Point(8, 0), + Point(-4 * cos(19 * pi / 90) / sin(2 * pi / 45), + 4 * sqrt(3) * cos(19 * pi / 90) / sin(2 * pi / 45))) + assert Line((0, 0), (1, 1)).intersection(Ray((1, 0), (1, 2))) == [Point(1, 1)] + assert Line((0, 0), (1, 1)).intersection(Segment((1, 0), (1, 2))) == [Point(1, 1)] + assert Ray((0, 0), (1, 1)).intersection(Ray((1, 0), (1, 2))) == [Point(1, 1)] + assert Ray((0, 0), (1, 1)).intersection(Segment((1, 0), (1, 2))) == [Point(1, 1)] + assert Ray((0, 0), (10, 10)).contains(Segment((1, 1), (2, 2))) is True + assert Segment((1, 1), (2, 2)) in Line((0, 0), (10, 10)) + assert s1.intersection(Ray((1, 1), (4, 4))) == [Point(1, 1)] + + # This test is disabled because it hangs after rref changes which simplify + # intermediate results and return a different representation from when the + # test was written. + # # 16628 - this should be fast + # p0 = Point2D(Rational(249, 5), Rational(497999, 10000)) + # p1 = Point2D((-58977084786*sqrt(405639795226) + 2030690077184193 + + # 20112207807*sqrt(630547164901) + 99600*sqrt(255775022850776494562626)) + # /(2000*sqrt(255775022850776494562626) + 1991998000*sqrt(405639795226) + # + 1991998000*sqrt(630547164901) + 1622561172902000), + # (-498000*sqrt(255775022850776494562626) - 995999*sqrt(630547164901) + + # 90004251917891999 + + # 496005510002*sqrt(405639795226))/(10000*sqrt(255775022850776494562626) + # + 9959990000*sqrt(405639795226) + 9959990000*sqrt(630547164901) + + # 8112805864510000)) + # p2 = Point2D(Rational(497, 10), Rational(-497, 10)) + # p3 = Point2D(Rational(-497, 10), Rational(-497, 10)) + # l = Line(p0, p1) + # s = Segment(p2, p3) + # n = (-52673223862*sqrt(405639795226) - 15764156209307469 - + # 9803028531*sqrt(630547164901) + + # 33200*sqrt(255775022850776494562626)) + # d = sqrt(405639795226) + 315274080450 + 498000*sqrt( + # 630547164901) + sqrt(255775022850776494562626) + # assert intersection(l, s) == [ + # Point2D(n/d*Rational(3, 2000), Rational(-497, 10))] + + +def test_line_intersection(): + # see also test_issue_11238 in test_matrices.py + x0 = tan(pi*Rational(13, 45)) + x1 = sqrt(3) + x2 = x0**2 + x, y = [8*x0/(x0 + x1), (24*x0 - 8*x1*x2)/(x2 - 3)] + assert Line(Point(0, 0), Point(1, -sqrt(3))).contains(Point(x, y)) is True + + +def test_intersection_3d(): + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + + l1 = Line3D(p1, p2) + l2 = Line3D(Point3D(0, 0, 0), Point3D(3, 4, 0)) + + r1 = Ray3D(Point3D(1, 1, 1), Point3D(2, 2, 2)) + r2 = Ray3D(Point3D(0, 0, 0), Point3D(3, 4, 0)) + + s1 = Segment3D(Point3D(0, 0, 0), Point3D(3, 4, 0)) + + assert intersection(l1, p1) == [p1] + assert intersection(l1, Point3D(x1, 1 + x1, 1)) == [] + assert intersection(l1, l1.parallel_line(p1)) == [Line3D(Point3D(0, 0, 0), Point3D(1, 1, 1))] + assert intersection(l2, r2) == [r2] + assert intersection(l2, s1) == [s1] + assert intersection(r2, l2) == [r2] + assert intersection(r1, Ray3D(Point3D(1, 1, 1), Point3D(-1, -1, -1))) == [Point3D(1, 1, 1)] + assert intersection(r1, Segment3D(Point3D(0, 0, 0), Point3D(2, 2, 2))) == [ + Segment3D(Point3D(1, 1, 1), Point3D(2, 2, 2))] + assert intersection(Ray3D(Point3D(1, 0, 0), Point3D(-1, 0, 0)), Ray3D(Point3D(0, 1, 0), Point3D(0, -1, 0))) \ + == [Point3D(0, 0, 0)] + assert intersection(r1, Ray3D(Point3D(2, 2, 2), Point3D(0, 0, 0))) == \ + [Segment3D(Point3D(1, 1, 1), Point3D(2, 2, 2))] + assert intersection(s1, r2) == [s1] + + assert Line3D(Point3D(4, 0, 1), Point3D(0, 4, 1)).intersection(Line3D(Point3D(0, 0, 1), Point3D(4, 4, 1))) == \ + [Point3D(2, 2, 1)] + assert Line3D((0, 1, 2), (0, 2, 3)).intersection(Line3D((0, 1, 2), (0, 1, 1))) == [Point3D(0, 1, 2)] + assert Line3D((0, 0), (t, t)).intersection(Line3D((0, 1), (t, t))) == \ + [Point3D(t, t)] + + assert Ray3D(Point3D(0, 0, 0), Point3D(0, 4, 0)).intersection(Ray3D(Point3D(0, 1, 1), Point3D(0, -1, 1))) == [] + + +def test_is_parallel(): + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + p3 = Point3D(x1, x1, x1) + + l2 = Line(Point(x1, x1), Point(y1, y1)) + l2_1 = Line(Point(x1, x1), Point(x1, 1 + x1)) + + assert Line.is_parallel(Line(Point(0, 0), Point(1, 1)), l2) + assert Line.is_parallel(l2, Line(Point(x1, x1), Point(x1, 1 + x1))) is False + assert Line.is_parallel(l2, l2.parallel_line(Point(-x1, x1))) + assert Line.is_parallel(l2_1, l2_1.parallel_line(Point(0, 0))) + assert Line3D(p1, p2).is_parallel(Line3D(p1, p2)) # same as in 2D + assert Line3D(Point3D(4, 0, 1), Point3D(0, 4, 1)).is_parallel(Line3D(Point3D(0, 0, 1), Point3D(4, 4, 1))) is False + assert Line3D(p1, p2).parallel_line(p3) == Line3D(Point3D(x1, x1, x1), + Point3D(x1 + 1, x1 + 1, x1 + 1)) + assert Line3D(p1, p2).parallel_line(p3.args) == \ + Line3D(Point3D(x1, x1, x1), Point3D(x1 + 1, x1 + 1, x1 + 1)) + assert Line3D(Point3D(4, 0, 1), Point3D(0, 4, 1)).is_parallel(Line3D(Point3D(0, 0, 1), Point3D(4, 4, 1))) is False + + +def test_is_perpendicular(): + p1 = Point(0, 0) + p2 = Point(1, 1) + + l1 = Line(p1, p2) + l2 = Line(Point(x1, x1), Point(y1, y1)) + l1_1 = Line(p1, Point(-x1, x1)) + # 2D + assert Line.is_perpendicular(l1, l1_1) + assert Line.is_perpendicular(l1, l2) is False + p = l1.random_point() + assert l1.perpendicular_segment(p) == p + # 3D + assert Line3D.is_perpendicular(Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)), + Line3D(Point3D(0, 0, 0), Point3D(0, 1, 0))) is True + assert Line3D.is_perpendicular(Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)), + Line3D(Point3D(0, 1, 0), Point3D(1, 1, 0))) is False + assert Line3D.is_perpendicular(Line3D(Point3D(0, 0, 0), Point3D(1, 1, 1)), + Line3D(Point3D(x1, x1, x1), Point3D(y1, y1, y1))) is False + + +def test_is_similar(): + p1 = Point(2000, 2000) + p2 = p1.scale(2, 2) + + r1 = Ray3D(Point3D(1, 1, 1), Point3D(1, 0, 0)) + r2 = Ray(Point(0, 0), Point(0, 1)) + + s1 = Segment(Point(0, 0), p1) + + assert s1.is_similar(Segment(p1, p2)) + assert s1.is_similar(r2) is False + assert r1.is_similar(Line3D(Point3D(1, 1, 1), Point3D(1, 0, 0))) is True + assert r1.is_similar(Line3D(Point3D(0, 0, 0), Point3D(0, 1, 0))) is False + + +def test_length(): + s2 = Segment3D(Point3D(x1, x1, x1), Point3D(y1, y1, y1)) + assert Line(Point(0, 0), Point(1, 1)).length is oo + assert s2.length == sqrt(3) * sqrt((x1 - y1) ** 2) + assert Line3D(Point3D(0, 0, 0), Point3D(1, 1, 1)).length is oo + + +def test_projection(): + p1 = Point(0, 0) + p2 = Point3D(0, 0, 0) + p3 = Point(-x1, x1) + + l1 = Line(p1, Point(1, 1)) + l2 = Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)) + l3 = Line3D(p2, Point3D(1, 1, 1)) + + r1 = Ray(Point(1, 1), Point(2, 2)) + + s1 = Segment(Point2D(0, 0), Point2D(0, 1)) + s2 = Segment(Point2D(1, 0), Point2D(2, 1/2)) + + assert Line(Point(x1, x1), Point(y1, y1)).projection(Point(y1, y1)) == Point(y1, y1) + assert Line(Point(x1, x1), Point(x1, 1 + x1)).projection(Point(1, 1)) == Point(x1, 1) + assert Segment(Point(-2, 2), Point(0, 4)).projection(r1) == Segment(Point(-1, 3), Point(0, 4)) + assert Segment(Point(0, 4), Point(-2, 2)).projection(r1) == Segment(Point(0, 4), Point(-1, 3)) + assert s2.projection(s1) == EmptySet + assert l1.projection(p3) == p1 + assert l1.projection(Ray(p1, Point(-1, 5))) == Ray(Point(0, 0), Point(2, 2)) + assert l1.projection(Ray(p1, Point(-1, 1))) == p1 + assert r1.projection(Ray(Point(1, 1), Point(-1, -1))) == Point(1, 1) + assert r1.projection(Ray(Point(0, 4), Point(-1, -5))) == Segment(Point(1, 1), Point(2, 2)) + assert r1.projection(Segment(Point(-1, 5), Point(-5, -10))) == Segment(Point(1, 1), Point(2, 2)) + assert r1.projection(Ray(Point(1, 1), Point(-1, -1))) == Point(1, 1) + assert r1.projection(Ray(Point(0, 4), Point(-1, -5))) == Segment(Point(1, 1), Point(2, 2)) + assert r1.projection(Segment(Point(-1, 5), Point(-5, -10))) == Segment(Point(1, 1), Point(2, 2)) + + assert l3.projection(Ray3D(p2, Point3D(-1, 5, 0))) == Ray3D(Point3D(0, 0, 0), Point3D(Rational(4, 3), Rational(4, 3), Rational(4, 3))) + assert l3.projection(Ray3D(p2, Point3D(-1, 1, 1))) == Ray3D(Point3D(0, 0, 0), Point3D(Rational(1, 3), Rational(1, 3), Rational(1, 3))) + assert l2.projection(Point3D(5, 5, 0)) == Point3D(5, 0) + assert l2.projection(Line3D(Point3D(0, 1, 0), Point3D(1, 1, 0))).equals(l2) + + +def test_perpendicular_line(): + # 3d - requires a particular orthogonal to be selected + p1, p2, p3 = Point(0, 0, 0), Point(2, 3, 4), Point(-2, 2, 0) + l = Line(p1, p2) + p = l.perpendicular_line(p3) + assert p.p1 == p3 + assert p.p2 in l + # 2d - does not require special selection + p1, p2, p3 = Point(0, 0), Point(2, 3), Point(-2, 2) + l = Line(p1, p2) + p = l.perpendicular_line(p3) + assert p.p1 == p3 + # p is directed from l to p3 + assert p.direction.unit == (p3 - l.projection(p3)).unit + + +def test_perpendicular_bisector(): + s1 = Segment(Point(0, 0), Point(1, 1)) + aline = Line(Point(S.Half, S.Half), Point(Rational(3, 2), Rational(-1, 2))) + on_line = Segment(Point(S.Half, S.Half), Point(Rational(3, 2), Rational(-1, 2))).midpoint + + assert s1.perpendicular_bisector().equals(aline) + assert s1.perpendicular_bisector(on_line).equals(Segment(s1.midpoint, on_line)) + assert s1.perpendicular_bisector(on_line + (1, 0)).equals(aline) + + +def test_raises(): + d, e = symbols('a,b', real=True) + s = Segment((d, 0), (e, 0)) + + raises(TypeError, lambda: Line((1, 1), 1)) + raises(ValueError, lambda: Line(Point(0, 0), Point(0, 0))) + raises(Undecidable, lambda: Point(2 * d, 0) in s) + raises(ValueError, lambda: Ray3D(Point(1.0, 1.0))) + raises(ValueError, lambda: Line3D(Point3D(0, 0, 0), Point3D(0, 0, 0))) + raises(TypeError, lambda: Line3D((1, 1), 1)) + raises(ValueError, lambda: Line3D(Point3D(0, 0, 0))) + raises(TypeError, lambda: Ray((1, 1), 1)) + raises(GeometryError, lambda: Line(Point(0, 0), Point(1, 0)) + .projection(Circle(Point(0, 0), 1))) + + +def test_ray_generation(): + assert Ray((1, 1), angle=pi / 4) == Ray((1, 1), (2, 2)) + assert Ray((1, 1), angle=pi / 2) == Ray((1, 1), (1, 2)) + assert Ray((1, 1), angle=-pi / 2) == Ray((1, 1), (1, 0)) + assert Ray((1, 1), angle=-3 * pi / 2) == Ray((1, 1), (1, 2)) + assert Ray((1, 1), angle=5 * pi / 2) == Ray((1, 1), (1, 2)) + assert Ray((1, 1), angle=5.0 * pi / 2) == Ray((1, 1), (1, 2)) + assert Ray((1, 1), angle=pi) == Ray((1, 1), (0, 1)) + assert Ray((1, 1), angle=3.0 * pi) == Ray((1, 1), (0, 1)) + assert Ray((1, 1), angle=4.0 * pi) == Ray((1, 1), (2, 1)) + assert Ray((1, 1), angle=0) == Ray((1, 1), (2, 1)) + assert Ray((1, 1), angle=4.05 * pi) == Ray(Point(1, 1), + Point(2, -sqrt(5) * sqrt(2 * sqrt(5) + 10) / 4 - sqrt( + 2 * sqrt(5) + 10) / 4 + 2 + sqrt(5))) + assert Ray((1, 1), angle=4.02 * pi) == Ray(Point(1, 1), + Point(2, 1 + tan(4.02 * pi))) + assert Ray((1, 1), angle=5) == Ray((1, 1), (2, 1 + tan(5))) + + assert Ray3D((1, 1, 1), direction_ratio=[4, 4, 4]) == Ray3D(Point3D(1, 1, 1), Point3D(5, 5, 5)) + assert Ray3D((1, 1, 1), direction_ratio=[1, 2, 3]) == Ray3D(Point3D(1, 1, 1), Point3D(2, 3, 4)) + assert Ray3D((1, 1, 1), direction_ratio=[1, 1, 1]) == Ray3D(Point3D(1, 1, 1), Point3D(2, 2, 2)) + + +def test_issue_7814(): + circle = Circle(Point(x, 0), y) + line = Line(Point(k, z), slope=0) + _s = sqrt((y - z)*(y + z)) + assert line.intersection(circle) == [Point2D(x + _s, z), Point2D(x - _s, z)] + + +def test_issue_2941(): + def _check(): + for f, g in cartes(*[(Line, Ray, Segment)] * 2): + l1 = f(a, b) + l2 = g(c, d) + assert l1.intersection(l2) == l2.intersection(l1) + # intersect at end point + c, d = (-2, -2), (-2, 0) + a, b = (0, 0), (1, 1) + _check() + # midline intersection + c, d = (-2, -3), (-2, 0) + _check() + + +def test_parameter_value(): + t = Symbol('t') + p1, p2 = Point(0, 1), Point(5, 6) + l = Line(p1, p2) + assert l.parameter_value((5, 6), t) == {t: 1} + raises(ValueError, lambda: l.parameter_value((0, 0), t)) + + +def test_bisectors(): + r1 = Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)) + r2 = Line3D(Point3D(0, 0, 0), Point3D(0, 1, 0)) + bisections = r1.bisectors(r2) + assert bisections == [Line3D(Point3D(0, 0, 0), Point3D(1, 1, 0)), + Line3D(Point3D(0, 0, 0), Point3D(1, -1, 0))] + ans = [Line3D(Point3D(0, 0, 0), Point3D(1, 0, 1)), + Line3D(Point3D(0, 0, 0), Point3D(-1, 0, 1))] + l1 = (0, 0, 0), (0, 0, 1) + l2 = (0, 0), (1, 0) + for a, b in cartes((Line, Segment, Ray), repeat=2): + assert a(*l1).bisectors(b(*l2)) == ans + + +def test_issue_8615(): + a = Line3D(Point3D(6, 5, 0), Point3D(6, -6, 0)) + b = Line3D(Point3D(6, -1, 19/10), Point3D(6, -1, 0)) + assert a.intersection(b) == [Point3D(6, -1, 0)] + + +def test_issue_12598(): + r1 = Ray(Point(0, 1), Point(0.98, 0.79).n(2)) + r2 = Ray(Point(0, 0), Point(0.71, 0.71).n(2)) + assert str(r1.intersection(r2)[0]) == 'Point2D(0.82, 0.82)' + l1 = Line((0, 0), (1, 1)) + l2 = Segment((-1, 1), (0, -1)).n(2) + assert str(l1.intersection(l2)[0]) == 'Point2D(-0.33, -0.33)' + l2 = Segment((-1, 1), (-1/2, 1/2)).n(2) + assert not l1.intersection(l2) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_parabola.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_parabola.py new file mode 100644 index 0000000000000000000000000000000000000000..2a683f26619952d93475aca9ebd3d47cfb3657a6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_parabola.py @@ -0,0 +1,143 @@ +from sympy.core.numbers import (Rational, oo) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.geometry.ellipse import (Circle, Ellipse) +from sympy.geometry.line import (Line, Ray2D, Segment2D) +from sympy.geometry.parabola import Parabola +from sympy.geometry.point import (Point, Point2D) +from sympy.testing.pytest import raises + +from sympy.abc import x, y + +def test_parabola_geom(): + a, b = symbols('a b') + p1 = Point(0, 0) + p2 = Point(3, 7) + p3 = Point(0, 4) + p4 = Point(6, 0) + p5 = Point(a, a) + d1 = Line(Point(4, 0), Point(4, 9)) + d2 = Line(Point(7, 6), Point(3, 6)) + d3 = Line(Point(4, 0), slope=oo) + d4 = Line(Point(7, 6), slope=0) + d5 = Line(Point(b, a), slope=oo) + d6 = Line(Point(a, b), slope=0) + + half = S.Half + + pa1 = Parabola(None, d2) + pa2 = Parabola(directrix=d1) + pa3 = Parabola(p1, d1) + pa4 = Parabola(p2, d2) + pa5 = Parabola(p2, d4) + pa6 = Parabola(p3, d2) + pa7 = Parabola(p2, d1) + pa8 = Parabola(p4, d1) + pa9 = Parabola(p4, d3) + pa10 = Parabola(p5, d5) + pa11 = Parabola(p5, d6) + d = Line(Point(3, 7), Point(2, 9)) + pa12 = Parabola(Point(7, 8), d) + pa12r = Parabola(Point(7, 8).reflect(d), d) + + raises(ValueError, lambda: + Parabola(Point(7, 8, 9), Line(Point(6, 7), Point(7, 7)))) + raises(ValueError, lambda: + Parabola(Point(0, 2), Line(Point(7, 2), Point(6, 2)))) + raises(ValueError, lambda: Parabola(Point(7, 8), Point(3, 8))) + + # Basic Stuff + assert pa1.focus == Point(0, 0) + assert pa1.ambient_dimension == S(2) + assert pa2 == pa3 + assert pa4 != pa7 + assert pa6 != pa7 + assert pa6.focus == Point2D(0, 4) + assert pa6.focal_length == 1 + assert pa6.p_parameter == -1 + assert pa6.vertex == Point2D(0, 5) + assert pa6.eccentricity == 1 + assert pa7.focus == Point2D(3, 7) + assert pa7.focal_length == half + assert pa7.p_parameter == -half + assert pa7.vertex == Point2D(7*half, 7) + assert pa4.focal_length == half + assert pa4.p_parameter == half + assert pa4.vertex == Point2D(3, 13*half) + assert pa8.focal_length == 1 + assert pa8.p_parameter == 1 + assert pa8.vertex == Point2D(5, 0) + assert pa4.focal_length == pa5.focal_length + assert pa4.p_parameter == pa5.p_parameter + assert pa4.vertex == pa5.vertex + assert pa4.equation() == pa5.equation() + assert pa8.focal_length == pa9.focal_length + assert pa8.p_parameter == pa9.p_parameter + assert pa8.vertex == pa9.vertex + assert pa8.equation() == pa9.equation() + assert pa10.focal_length == pa11.focal_length == sqrt((a - b) ** 2) / 2 # if a, b real == abs(a - b)/2 + assert pa11.vertex == Point(*pa10.vertex[::-1]) == Point(a, + a - sqrt((a - b)**2)*sign(a - b)/2) # change axis x->y, y->x on pa10 + aos = pa12.axis_of_symmetry + assert aos == Line(Point(7, 8), Point(5, 7)) + assert pa12.directrix == Line(Point(3, 7), Point(2, 9)) + assert pa12.directrix.angle_between(aos) == S.Pi/2 + assert pa12.eccentricity == 1 + assert pa12.equation(x, y) == (x - 7)**2 + (y - 8)**2 - (-2*x - y + 13)**2/5 + assert pa12.focal_length == 9*sqrt(5)/10 + assert pa12.focus == Point(7, 8) + assert pa12.p_parameter == 9*sqrt(5)/10 + assert pa12.vertex == Point2D(S(26)/5, S(71)/10) + assert pa12r.focal_length == 9*sqrt(5)/10 + assert pa12r.focus == Point(-S(1)/5, S(22)/5) + assert pa12r.p_parameter == -9*sqrt(5)/10 + assert pa12r.vertex == Point(S(8)/5, S(53)/10) + + +def test_parabola_intersection(): + l1 = Line(Point(1, -2), Point(-1,-2)) + l2 = Line(Point(1, 2), Point(-1,2)) + l3 = Line(Point(1, 0), Point(-1,0)) + + p1 = Point(0,0) + p2 = Point(0, -2) + p3 = Point(120, -12) + parabola1 = Parabola(p1, l1) + + # parabola with parabola + assert parabola1.intersection(parabola1) == [parabola1] + assert parabola1.intersection(Parabola(p1, l2)) == [Point2D(-2, 0), Point2D(2, 0)] + assert parabola1.intersection(Parabola(p2, l3)) == [Point2D(0, -1)] + assert parabola1.intersection(Parabola(Point(16, 0), l1)) == [Point2D(8, 15)] + assert parabola1.intersection(Parabola(Point(0, 16), l1)) == [Point2D(-6, 8), Point2D(6, 8)] + assert parabola1.intersection(Parabola(p3, l3)) == [] + # parabola with point + assert parabola1.intersection(p1) == [] + assert parabola1.intersection(Point2D(0, -1)) == [Point2D(0, -1)] + assert parabola1.intersection(Point2D(4, 3)) == [Point2D(4, 3)] + # parabola with line + assert parabola1.intersection(Line(Point2D(-7, 3), Point(12, 3))) == [Point2D(-4, 3), Point2D(4, 3)] + assert parabola1.intersection(Line(Point(-4, -1), Point(4, -1))) == [Point(0, -1)] + assert parabola1.intersection(Line(Point(2, 0), Point(0, -2))) == [Point2D(2, 0)] + raises(TypeError, lambda: parabola1.intersection(Line(Point(0, 0, 0), Point(1, 1, 1)))) + # parabola with segment + assert parabola1.intersection(Segment2D((-4, -5), (4, 3))) == [Point2D(0, -1), Point2D(4, 3)] + assert parabola1.intersection(Segment2D((0, -5), (0, 6))) == [Point2D(0, -1)] + assert parabola1.intersection(Segment2D((-12, -65), (14, -68))) == [] + # parabola with ray + assert parabola1.intersection(Ray2D((-4, -5), (4, 3))) == [Point2D(0, -1), Point2D(4, 3)] + assert parabola1.intersection(Ray2D((0, 7), (1, 14))) == [Point2D(14 + 2*sqrt(57), 105 + 14*sqrt(57))] + assert parabola1.intersection(Ray2D((0, 7), (0, 14))) == [] + # parabola with ellipse/circle + assert parabola1.intersection(Circle(p1, 2)) == [Point2D(-2, 0), Point2D(2, 0)] + assert parabola1.intersection(Circle(p2, 1)) == [Point2D(0, -1)] + assert parabola1.intersection(Ellipse(p2, 2, 1)) == [Point2D(0, -1)] + assert parabola1.intersection(Ellipse(Point(0, 19), 5, 7)) == [] + assert parabola1.intersection(Ellipse((0, 3), 12, 4)) == [ + Point2D(0, -1), + Point2D(-4*sqrt(17)/3, Rational(59, 9)), + Point2D(4*sqrt(17)/3, Rational(59, 9))] + # parabola with unsupported type + raises(TypeError, lambda: parabola1.intersection(2)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_plane.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_plane.py new file mode 100644 index 0000000000000000000000000000000000000000..1010fce5c3bc68348eacee13f29c1d7588f17e39 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_plane.py @@ -0,0 +1,268 @@ +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (asin, cos, sin) +from sympy.geometry import Line, Point, Ray, Segment, Point3D, Line3D, Ray3D, Segment3D, Plane, Circle +from sympy.geometry.util import are_coplanar +from sympy.testing.pytest import raises + + +def test_plane(): + x, y, z, u, v = symbols('x y z u v', real=True) + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + p3 = Point3D(1, 2, 3) + pl3 = Plane(p1, p2, p3) + pl4 = Plane(p1, normal_vector=(1, 1, 1)) + pl4b = Plane(p1, p2) + pl5 = Plane(p3, normal_vector=(1, 2, 3)) + pl6 = Plane(Point3D(2, 3, 7), normal_vector=(2, 2, 2)) + pl7 = Plane(Point3D(1, -5, -6), normal_vector=(1, -2, 1)) + pl8 = Plane(p1, normal_vector=(0, 0, 1)) + pl9 = Plane(p1, normal_vector=(0, 12, 0)) + pl10 = Plane(p1, normal_vector=(-2, 0, 0)) + pl11 = Plane(p2, normal_vector=(0, 0, 1)) + l1 = Line3D(Point3D(5, 0, 0), Point3D(1, -1, 1)) + l2 = Line3D(Point3D(0, -2, 0), Point3D(3, 1, 1)) + l3 = Line3D(Point3D(0, -1, 0), Point3D(5, -1, 9)) + + raises(ValueError, lambda: Plane(p1, p1, p1)) + + assert Plane(p1, p2, p3) != Plane(p1, p3, p2) + assert Plane(p1, p2, p3).is_coplanar(Plane(p1, p3, p2)) + assert Plane(p1, p2, p3).is_coplanar(p1) + assert Plane(p1, p2, p3).is_coplanar(Circle(p1, 1)) is False + assert Plane(p1, normal_vector=(0, 0, 1)).is_coplanar(Circle(p1, 1)) + + assert pl3 == Plane(Point3D(0, 0, 0), normal_vector=(1, -2, 1)) + assert pl3 != pl4 + assert pl4 == pl4b + assert pl5 == Plane(Point3D(1, 2, 3), normal_vector=(1, 2, 3)) + + assert pl5.equation(x, y, z) == x + 2*y + 3*z - 14 + assert pl3.equation(x, y, z) == x - 2*y + z + + assert pl3.p1 == p1 + assert pl4.p1 == p1 + assert pl5.p1 == p3 + + assert pl4.normal_vector == (1, 1, 1) + assert pl5.normal_vector == (1, 2, 3) + + assert p1 in pl3 + assert p1 in pl4 + assert p3 in pl5 + + assert pl3.projection(Point(0, 0)) == p1 + p = pl3.projection(Point3D(1, 1, 0)) + assert p == Point3D(Rational(7, 6), Rational(2, 3), Rational(1, 6)) + assert p in pl3 + + l = pl3.projection_line(Line(Point(0, 0), Point(1, 1))) + assert l == Line3D(Point3D(0, 0, 0), Point3D(Rational(7, 6), Rational(2, 3), Rational(1, 6))) + assert l in pl3 + # get a segment that does not intersect the plane which is also + # parallel to pl3's normal veector + t = Dummy() + r = pl3.random_point() + a = pl3.perpendicular_line(r).arbitrary_point(t) + s = Segment3D(a.subs(t, 1), a.subs(t, 2)) + assert s.p1 not in pl3 and s.p2 not in pl3 + assert pl3.projection_line(s).equals(r) + assert pl3.projection_line(Segment(Point(1, 0), Point(1, 1))) == \ + Segment3D(Point3D(Rational(5, 6), Rational(1, 3), Rational(-1, 6)), Point3D(Rational(7, 6), Rational(2, 3), Rational(1, 6))) + assert pl6.projection_line(Ray(Point(1, 0), Point(1, 1))) == \ + Ray3D(Point3D(Rational(14, 3), Rational(11, 3), Rational(11, 3)), Point3D(Rational(13, 3), Rational(13, 3), Rational(10, 3))) + assert pl3.perpendicular_line(r.args) == pl3.perpendicular_line(r) + + assert pl3.is_parallel(pl6) is False + assert pl4.is_parallel(pl6) + assert pl3.is_parallel(Line(p1, p2)) + assert pl6.is_parallel(l1) is False + + assert pl3.is_perpendicular(pl6) + assert pl4.is_perpendicular(pl7) + assert pl6.is_perpendicular(pl7) + assert pl6.is_perpendicular(pl4) is False + assert pl6.is_perpendicular(l1) is False + assert pl6.is_perpendicular(Line((0, 0, 0), (1, 1, 1))) + assert pl6.is_perpendicular((1, 1)) is False + + assert pl6.distance(pl6.arbitrary_point(u, v)) == 0 + assert pl7.distance(pl7.arbitrary_point(u, v)) == 0 + assert pl6.distance(pl6.arbitrary_point(t)) == 0 + assert pl7.distance(pl7.arbitrary_point(t)) == 0 + assert pl6.p1.distance(pl6.arbitrary_point(t)).simplify() == 1 + assert pl7.p1.distance(pl7.arbitrary_point(t)).simplify() == 1 + assert pl3.arbitrary_point(t) == Point3D(-sqrt(30)*sin(t)/30 + \ + 2*sqrt(5)*cos(t)/5, sqrt(30)*sin(t)/15 + sqrt(5)*cos(t)/5, sqrt(30)*sin(t)/6) + assert pl3.arbitrary_point(u, v) == Point3D(2*u - v, u + 2*v, 5*v) + + assert pl7.distance(Point3D(1, 3, 5)) == 5*sqrt(6)/6 + assert pl6.distance(Point3D(0, 0, 0)) == 4*sqrt(3) + assert pl6.distance(pl6.p1) == 0 + assert pl7.distance(pl6) == 0 + assert pl7.distance(l1) == 0 + assert pl6.distance(Segment3D(Point3D(2, 3, 1), Point3D(1, 3, 4))) == \ + pl6.distance(Point3D(1, 3, 4)) == 4*sqrt(3)/3 + assert pl6.distance(Segment3D(Point3D(1, 3, 4), Point3D(0, 3, 7))) == \ + pl6.distance(Point3D(0, 3, 7)) == 2*sqrt(3)/3 + assert pl6.distance(Segment3D(Point3D(0, 3, 7), Point3D(-1, 3, 10))) == 0 + assert pl6.distance(Segment3D(Point3D(-1, 3, 10), Point3D(-2, 3, 13))) == 0 + assert pl6.distance(Segment3D(Point3D(-2, 3, 13), Point3D(-3, 3, 16))) == \ + pl6.distance(Point3D(-2, 3, 13)) == 2*sqrt(3)/3 + assert pl6.distance(Plane(Point3D(5, 5, 5), normal_vector=(8, 8, 8))) == sqrt(3) + assert pl6.distance(Ray3D(Point3D(1, 3, 4), direction_ratio=[1, 0, -3])) == 4*sqrt(3)/3 + assert pl6.distance(Ray3D(Point3D(2, 3, 1), direction_ratio=[-1, 0, 3])) == 0 + + + assert pl6.angle_between(pl3) == pi/2 + assert pl6.angle_between(pl6) == 0 + assert pl6.angle_between(pl4) == 0 + assert pl7.angle_between(Line3D(Point3D(2, 3, 5), Point3D(2, 4, 6))) == \ + -asin(sqrt(3)/6) + assert pl6.angle_between(Ray3D(Point3D(2, 4, 1), Point3D(6, 5, 3))) == \ + asin(sqrt(7)/3) + assert pl7.angle_between(Segment3D(Point3D(5, 6, 1), Point3D(1, 2, 4))) == \ + asin(7*sqrt(246)/246) + + assert are_coplanar(l1, l2, l3) is False + assert are_coplanar(l1) is False + assert are_coplanar(Point3D(2, 7, 2), Point3D(0, 0, 2), + Point3D(1, 1, 2), Point3D(1, 2, 2)) + assert are_coplanar(Plane(p1, p2, p3), Plane(p1, p3, p2)) + assert Plane.are_concurrent(pl3, pl4, pl5) is False + assert Plane.are_concurrent(pl6) is False + raises(ValueError, lambda: Plane.are_concurrent(Point3D(0, 0, 0))) + raises(ValueError, lambda: Plane((1, 2, 3), normal_vector=(0, 0, 0))) + + assert pl3.parallel_plane(Point3D(1, 2, 5)) == Plane(Point3D(1, 2, 5), \ + normal_vector=(1, -2, 1)) + + # perpendicular_plane + p = Plane((0, 0, 0), (1, 0, 0)) + # default + assert p.perpendicular_plane() == Plane(Point3D(0, 0, 0), (0, 1, 0)) + # 1 pt + assert p.perpendicular_plane(Point3D(1, 0, 1)) == \ + Plane(Point3D(1, 0, 1), (0, 1, 0)) + # pts as tuples + assert p.perpendicular_plane((1, 0, 1), (1, 1, 1)) == \ + Plane(Point3D(1, 0, 1), (0, 0, -1)) + # more than two planes + raises(ValueError, lambda: p.perpendicular_plane((1, 0, 1), (1, 1, 1), (1, 1, 0))) + + a, b = Point3D(0, 0, 0), Point3D(0, 1, 0) + Z = (0, 0, 1) + p = Plane(a, normal_vector=Z) + # case 4 + assert p.perpendicular_plane(a, b) == Plane(a, (1, 0, 0)) + n = Point3D(*Z) + # case 1 + assert p.perpendicular_plane(a, n) == Plane(a, (-1, 0, 0)) + # case 2 + assert Plane(a, normal_vector=b.args).perpendicular_plane(a, a + b) == \ + Plane(Point3D(0, 0, 0), (1, 0, 0)) + # case 1&3 + assert Plane(b, normal_vector=Z).perpendicular_plane(b, b + n) == \ + Plane(Point3D(0, 1, 0), (-1, 0, 0)) + # case 2&3 + assert Plane(b, normal_vector=b.args).perpendicular_plane(n, n + b) == \ + Plane(Point3D(0, 0, 1), (1, 0, 0)) + + p = Plane(a, normal_vector=(0, 0, 1)) + assert p.perpendicular_plane() == Plane(a, normal_vector=(1, 0, 0)) + + assert pl6.intersection(pl6) == [pl6] + assert pl4.intersection(pl4.p1) == [pl4.p1] + assert pl3.intersection(pl6) == [ + Line3D(Point3D(8, 4, 0), Point3D(2, 4, 6))] + assert pl3.intersection(Line3D(Point3D(1,2,4), Point3D(4,4,2))) == [ + Point3D(2, Rational(8, 3), Rational(10, 3))] + assert pl3.intersection(Plane(Point3D(6, 0, 0), normal_vector=(2, -5, 3)) + ) == [Line3D(Point3D(-24, -12, 0), Point3D(-25, -13, -1))] + assert pl6.intersection(Ray3D(Point3D(2, 3, 1), Point3D(1, 3, 4))) == [ + Point3D(-1, 3, 10)] + assert pl6.intersection(Segment3D(Point3D(2, 3, 1), Point3D(1, 3, 4))) == [] + assert pl7.intersection(Line(Point(2, 3), Point(4, 2))) == [ + Point3D(Rational(13, 2), Rational(3, 4), 0)] + r = Ray(Point(2, 3), Point(4, 2)) + assert Plane((1,2,0), normal_vector=(0,0,1)).intersection(r) == [ + Ray3D(Point(2, 3), Point(4, 2))] + assert pl9.intersection(pl8) == [Line3D(Point3D(0, 0, 0), Point3D(12, 0, 0))] + assert pl10.intersection(pl11) == [Line3D(Point3D(0, 0, 1), Point3D(0, 2, 1))] + assert pl4.intersection(pl8) == [Line3D(Point3D(0, 0, 0), Point3D(1, -1, 0))] + assert pl11.intersection(pl8) == [] + assert pl9.intersection(pl11) == [Line3D(Point3D(0, 0, 1), Point3D(12, 0, 1))] + assert pl9.intersection(pl4) == [Line3D(Point3D(0, 0, 0), Point3D(12, 0, -12))] + assert pl3.random_point() in pl3 + assert pl3.random_point(seed=1) in pl3 + + # test geometrical entity using equals + assert pl4.intersection(pl4.p1)[0].equals(pl4.p1) + assert pl3.intersection(pl6)[0].equals(Line3D(Point3D(8, 4, 0), Point3D(2, 4, 6))) + pl8 = Plane((1, 2, 0), normal_vector=(0, 0, 1)) + assert pl8.intersection(Line3D(p1, (1, 12, 0)))[0].equals(Line((0, 0, 0), (0.1, 1.2, 0))) + assert pl8.intersection(Ray3D(p1, (1, 12, 0)))[0].equals(Ray((0, 0, 0), (1, 12, 0))) + assert pl8.intersection(Segment3D(p1, (21, 1, 0)))[0].equals(Segment3D(p1, (21, 1, 0))) + assert pl8.intersection(Plane(p1, normal_vector=(0, 0, 112)))[0].equals(pl8) + assert pl8.intersection(Plane(p1, normal_vector=(0, 12, 0)))[0].equals( + Line3D(p1, direction_ratio=(112 * pi, 0, 0))) + assert pl8.intersection(Plane(p1, normal_vector=(11, 0, 1)))[0].equals( + Line3D(p1, direction_ratio=(0, -11, 0))) + assert pl8.intersection(Plane(p1, normal_vector=(1, 0, 11)))[0].equals( + Line3D(p1, direction_ratio=(0, 11, 0))) + assert pl8.intersection(Plane(p1, normal_vector=(-1, -1, -11)))[0].equals( + Line3D(p1, direction_ratio=(1, -1, 0))) + assert pl3.random_point() in pl3 + assert len(pl8.intersection(Ray3D(Point3D(0, 2, 3), Point3D(1, 0, 3)))) == 0 + # check if two plane are equals + assert pl6.intersection(pl6)[0].equals(pl6) + assert pl8.equals(Plane(p1, normal_vector=(0, 12, 0))) is False + assert pl8.equals(pl8) + assert pl8.equals(Plane(p1, normal_vector=(0, 0, -12))) + assert pl8.equals(Plane(p1, normal_vector=(0, 0, -12*sqrt(3)))) + assert pl8.equals(p1) is False + + # issue 8570 + l2 = Line3D(Point3D(Rational(50000004459633, 5000000000000), + Rational(-891926590718643, 1000000000000000), + Rational(231800966893633, 100000000000000)), + Point3D(Rational(50000004459633, 50000000000000), + Rational(-222981647679771, 250000000000000), + Rational(231800966893633, 100000000000000))) + + p2 = Plane(Point3D(Rational(402775636372767, 100000000000000), + Rational(-97224357654973, 100000000000000), + Rational(216793600814789, 100000000000000)), + (-S('9.00000087501922'), -S('4.81170658872543e-13'), + S('0.0'))) + + assert str([i.n(2) for i in p2.intersection(l2)]) == \ + '[Point3D(4.0, -0.89, 2.3)]' + + +def test_dimension_normalization(): + A = Plane(Point3D(1, 1, 2), normal_vector=(1, 1, 1)) + b = Point(1, 1) + assert A.projection(b) == Point(Rational(5, 3), Rational(5, 3), Rational(2, 3)) + + a, b = Point(0, 0), Point3D(0, 1) + Z = (0, 0, 1) + p = Plane(a, normal_vector=Z) + assert p.perpendicular_plane(a, b) == Plane(Point3D(0, 0, 0), (1, 0, 0)) + assert Plane((1, 2, 1), (2, 1, 0), (3, 1, 2) + ).intersection((2, 1)) == [Point(2, 1, 0)] + + +def test_parameter_value(): + t, u, v = symbols("t, u v") + p1, p2, p3 = Point(0, 0, 0), Point(0, 0, 1), Point(0, 1, 0) + p = Plane(p1, p2, p3) + assert p.parameter_value((0, -3, 2), t) == {t: asin(2*sqrt(13)/13)} + assert p.parameter_value((0, -3, 2), u, v) == {u: 3, v: 2} + assert p.parameter_value(p1, t) == p1 + raises(ValueError, lambda: p.parameter_value((1, 0, 0), t)) + raises(ValueError, lambda: p.parameter_value(Line(Point(0, 0), Point(1, 1)), t)) + raises(ValueError, lambda: p.parameter_value((0, -3, 2), t, 1)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_point.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_point.py new file mode 100644 index 0000000000000000000000000000000000000000..1f2b2768eb3fba2009f702351de1aac3ed6e71d4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_point.py @@ -0,0 +1,481 @@ +from sympy.core.basic import Basic +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.parameters import evaluate +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.geometry import Line, Point, Point2D, Point3D, Line3D, Plane +from sympy.geometry.entity import rotate, scale, translate, GeometryEntity +from sympy.matrices import Matrix +from sympy.utilities.iterables import subsets, permutations, cartes +from sympy.utilities.misc import Undecidable +from sympy.testing.pytest import raises, warns + + +def test_point(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + x1 = Symbol('x1', real=True) + x2 = Symbol('x2', real=True) + y1 = Symbol('y1', real=True) + y2 = Symbol('y2', real=True) + half = S.Half + p1 = Point(x1, x2) + p2 = Point(y1, y2) + p3 = Point(0, 0) + p4 = Point(1, 1) + p5 = Point(0, 1) + line = Line(Point(1, 0), slope=1) + + assert p1 in p1 + assert p1 not in p2 + assert p2.y == y2 + assert (p3 + p4) == p4 + assert (p2 - p1) == Point(y1 - x1, y2 - x2) + assert -p2 == Point(-y1, -y2) + raises(TypeError, lambda: Point(1)) + raises(ValueError, lambda: Point([1])) + raises(ValueError, lambda: Point(3, I)) + raises(ValueError, lambda: Point(2*I, I)) + raises(ValueError, lambda: Point(3 + I, I)) + + assert Point(34.05, sqrt(3)) == Point(Rational(681, 20), sqrt(3)) + assert Point.midpoint(p3, p4) == Point(half, half) + assert Point.midpoint(p1, p4) == Point(half + half*x1, half + half*x2) + assert Point.midpoint(p2, p2) == p2 + assert p2.midpoint(p2) == p2 + assert p1.origin == Point(0, 0) + + assert Point.distance(p3, p4) == sqrt(2) + assert Point.distance(p1, p1) == 0 + assert Point.distance(p3, p2) == sqrt(p2.x**2 + p2.y**2) + raises(TypeError, lambda: Point.distance(p1, 0)) + raises(TypeError, lambda: Point.distance(p1, GeometryEntity())) + + # distance should be symmetric + assert p1.distance(line) == line.distance(p1) + assert p4.distance(line) == line.distance(p4) + + assert Point.taxicab_distance(p4, p3) == 2 + + assert Point.canberra_distance(p4, p5) == 1 + raises(ValueError, lambda: Point.canberra_distance(p3, p3)) + + p1_1 = Point(x1, x1) + p1_2 = Point(y2, y2) + p1_3 = Point(x1 + 1, x1) + assert Point.is_collinear(p3) + + with warns(UserWarning, test_stacklevel=False): + assert Point.is_collinear(p3, Point(p3, dim=4)) + assert p3.is_collinear() + assert Point.is_collinear(p3, p4) + assert Point.is_collinear(p3, p4, p1_1, p1_2) + assert Point.is_collinear(p3, p4, p1_1, p1_3) is False + assert Point.is_collinear(p3, p3, p4, p5) is False + + raises(TypeError, lambda: Point.is_collinear(line)) + raises(TypeError, lambda: p1_1.is_collinear(line)) + + assert p3.intersection(Point(0, 0)) == [p3] + assert p3.intersection(p4) == [] + assert p3.intersection(line) == [] + with warns(UserWarning, test_stacklevel=False): + assert Point.intersection(Point(0, 0, 0), Point(0, 0)) == [Point(0, 0, 0)] + + x_pos = Symbol('x', positive=True) + p2_1 = Point(x_pos, 0) + p2_2 = Point(0, x_pos) + p2_3 = Point(-x_pos, 0) + p2_4 = Point(0, -x_pos) + p2_5 = Point(x_pos, 5) + assert Point.is_concyclic(p2_1) + assert Point.is_concyclic(p2_1, p2_2) + assert Point.is_concyclic(p2_1, p2_2, p2_3, p2_4) + for pts in permutations((p2_1, p2_2, p2_3, p2_5)): + assert Point.is_concyclic(*pts) is False + assert Point.is_concyclic(p4, p4 * 2, p4 * 3) is False + assert Point(0, 0).is_concyclic((1, 1), (2, 2), (2, 1)) is False + assert Point.is_concyclic(Point(0, 0, 0, 0), Point(1, 0, 0, 0), Point(1, 1, 0, 0), Point(1, 1, 1, 0)) is False + + assert p1.is_scalar_multiple(p1) + assert p1.is_scalar_multiple(2*p1) + assert not p1.is_scalar_multiple(p2) + assert Point.is_scalar_multiple(Point(1, 1), (-1, -1)) + assert Point.is_scalar_multiple(Point(0, 0), (0, -1)) + # test when is_scalar_multiple can't be determined + raises(Undecidable, lambda: Point.is_scalar_multiple(Point(sympify("x1%y1"), sympify("x2%y2")), Point(0, 1))) + + assert Point(0, 1).orthogonal_direction == Point(1, 0) + assert Point(1, 0).orthogonal_direction == Point(0, 1) + + assert p1.is_zero is None + assert p3.is_zero + assert p4.is_zero is False + assert p1.is_nonzero is None + assert p3.is_nonzero is False + assert p4.is_nonzero + + assert p4.scale(2, 3) == Point(2, 3) + assert p3.scale(2, 3) == p3 + + assert p4.rotate(pi, Point(0.5, 0.5)) == p3 + assert p1.__radd__(p2) == p1.midpoint(p2).scale(2, 2) + assert (-p3).__rsub__(p4) == p3.midpoint(p4).scale(2, 2) + + assert p4 * 5 == Point(5, 5) + assert p4 / 5 == Point(0.2, 0.2) + assert 5 * p4 == Point(5, 5) + + raises(ValueError, lambda: Point(0, 0) + 10) + + # Point differences should be simplified + assert Point(x*(x - 1), y) - Point(x**2 - x, y + 1) == Point(0, -1) + + a, b = S.Half, Rational(1, 3) + assert Point(a, b).evalf(2) == \ + Point(a.n(2), b.n(2), evaluate=False) + raises(ValueError, lambda: Point(1, 2) + 1) + + # test project + assert Point.project((0, 1), (1, 0)) == Point(0, 0) + assert Point.project((1, 1), (1, 0)) == Point(1, 0) + raises(ValueError, lambda: Point.project(p1, Point(0, 0))) + + # test transformations + p = Point(1, 0) + assert p.rotate(pi/2) == Point(0, 1) + assert p.rotate(pi/2, p) == p + p = Point(1, 1) + assert p.scale(2, 3) == Point(2, 3) + assert p.translate(1, 2) == Point(2, 3) + assert p.translate(1) == Point(2, 1) + assert p.translate(y=1) == Point(1, 2) + assert p.translate(*p.args) == Point(2, 2) + + # Check invalid input for transform + raises(ValueError, lambda: p3.transform(p3)) + raises(ValueError, lambda: p.transform(Matrix([[1, 0], [0, 1]]))) + + # test __contains__ + assert 0 in Point(0, 0, 0, 0) + assert 1 not in Point(0, 0, 0, 0) + + # test affine_rank + assert Point.affine_rank() == -1 + + +def test_point3D(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + x1 = Symbol('x1', real=True) + x2 = Symbol('x2', real=True) + x3 = Symbol('x3', real=True) + y1 = Symbol('y1', real=True) + y2 = Symbol('y2', real=True) + y3 = Symbol('y3', real=True) + half = S.Half + p1 = Point3D(x1, x2, x3) + p2 = Point3D(y1, y2, y3) + p3 = Point3D(0, 0, 0) + p4 = Point3D(1, 1, 1) + p5 = Point3D(0, 1, 2) + + assert p1 in p1 + assert p1 not in p2 + assert p2.y == y2 + assert (p3 + p4) == p4 + assert (p2 - p1) == Point3D(y1 - x1, y2 - x2, y3 - x3) + assert -p2 == Point3D(-y1, -y2, -y3) + + assert Point(34.05, sqrt(3)) == Point(Rational(681, 20), sqrt(3)) + assert Point3D.midpoint(p3, p4) == Point3D(half, half, half) + assert Point3D.midpoint(p1, p4) == Point3D(half + half*x1, half + half*x2, + half + half*x3) + assert Point3D.midpoint(p2, p2) == p2 + assert p2.midpoint(p2) == p2 + + assert Point3D.distance(p3, p4) == sqrt(3) + assert Point3D.distance(p1, p1) == 0 + assert Point3D.distance(p3, p2) == sqrt(p2.x**2 + p2.y**2 + p2.z**2) + + p1_1 = Point3D(x1, x1, x1) + p1_2 = Point3D(y2, y2, y2) + p1_3 = Point3D(x1 + 1, x1, x1) + Point3D.are_collinear(p3) + assert Point3D.are_collinear(p3, p4) + assert Point3D.are_collinear(p3, p4, p1_1, p1_2) + assert Point3D.are_collinear(p3, p4, p1_1, p1_3) is False + assert Point3D.are_collinear(p3, p3, p4, p5) is False + + assert p3.intersection(Point3D(0, 0, 0)) == [p3] + assert p3.intersection(p4) == [] + + + assert p4 * 5 == Point3D(5, 5, 5) + assert p4 / 5 == Point3D(0.2, 0.2, 0.2) + assert 5 * p4 == Point3D(5, 5, 5) + + raises(ValueError, lambda: Point3D(0, 0, 0) + 10) + + # Test coordinate properties + assert p1.coordinates == (x1, x2, x3) + assert p2.coordinates == (y1, y2, y3) + assert p3.coordinates == (0, 0, 0) + assert p4.coordinates == (1, 1, 1) + assert p5.coordinates == (0, 1, 2) + assert p5.x == 0 + assert p5.y == 1 + assert p5.z == 2 + + # Point differences should be simplified + assert Point3D(x*(x - 1), y, 2) - Point3D(x**2 - x, y + 1, 1) == \ + Point3D(0, -1, 1) + + a, b, c = S.Half, Rational(1, 3), Rational(1, 4) + assert Point3D(a, b, c).evalf(2) == \ + Point(a.n(2), b.n(2), c.n(2), evaluate=False) + raises(ValueError, lambda: Point3D(1, 2, 3) + 1) + + # test transformations + p = Point3D(1, 1, 1) + assert p.scale(2, 3) == Point3D(2, 3, 1) + assert p.translate(1, 2) == Point3D(2, 3, 1) + assert p.translate(1) == Point3D(2, 1, 1) + assert p.translate(z=1) == Point3D(1, 1, 2) + assert p.translate(*p.args) == Point3D(2, 2, 2) + + # Test __new__ + assert Point3D(0.1, 0.2, evaluate=False, on_morph='ignore').args[0].is_Float + + # Test length property returns correctly + assert p.length == 0 + assert p1_1.length == 0 + assert p1_2.length == 0 + + # Test are_colinear type error + raises(TypeError, lambda: Point3D.are_collinear(p, x)) + + # Test are_coplanar + assert Point.are_coplanar() + assert Point.are_coplanar((1, 2, 0), (1, 2, 0), (1, 3, 0)) + assert Point.are_coplanar((1, 2, 0), (1, 2, 3)) + with warns(UserWarning, test_stacklevel=False): + raises(ValueError, lambda: Point2D.are_coplanar((1, 2), (1, 2, 3))) + assert Point3D.are_coplanar((1, 2, 0), (1, 2, 3)) + assert Point.are_coplanar((0, 0, 0), (1, 1, 0), (1, 1, 1), (1, 2, 1)) is False + planar2 = Point3D(1, -1, 1) + planar3 = Point3D(-1, 1, 1) + assert Point3D.are_coplanar(p, planar2, planar3) == True + assert Point3D.are_coplanar(p, planar2, planar3, p3) == False + assert Point.are_coplanar(p, planar2) + planar2 = Point3D(1, 1, 2) + planar3 = Point3D(1, 1, 3) + assert Point3D.are_coplanar(p, planar2, planar3) # line, not plane + plane = Plane((1, 2, 1), (2, 1, 0), (3, 1, 2)) + assert Point.are_coplanar(*[plane.projection(((-1)**i, i)) for i in range(4)]) + + # all 2D points are coplanar + assert Point.are_coplanar(Point(x, y), Point(x, x + y), Point(y, x + 2)) is True + + # Test Intersection + assert planar2.intersection(Line3D(p, planar3)) == [Point3D(1, 1, 2)] + + # Test Scale + assert planar2.scale(1, 1, 1) == planar2 + assert planar2.scale(2, 2, 2, planar3) == Point3D(1, 1, 1) + assert planar2.scale(1, 1, 1, p3) == planar2 + + # Test Transform + identity = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + assert p.transform(identity) == p + trans = Matrix([[1, 0, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1], [0, 0, 0, 1]]) + assert p.transform(trans) == Point3D(2, 2, 2) + raises(ValueError, lambda: p.transform(p)) + raises(ValueError, lambda: p.transform(Matrix([[1, 0], [0, 1]]))) + + # Test Equals + assert p.equals(x1) == False + + # Test __sub__ + p_4d = Point(0, 0, 0, 1) + with warns(UserWarning, test_stacklevel=False): + assert p - p_4d == Point(1, 1, 1, -1) + p_4d3d = Point(0, 0, 1, 0) + with warns(UserWarning, test_stacklevel=False): + assert p - p_4d3d == Point(1, 1, 0, 0) + + +def test_Point2D(): + + # Test Distance + p1 = Point2D(1, 5) + p2 = Point2D(4, 2.5) + p3 = (6, 3) + assert p1.distance(p2) == sqrt(61)/2 + assert p2.distance(p3) == sqrt(17)/2 + + # Test coordinates + assert p1.x == 1 + assert p1.y == 5 + assert p2.x == 4 + assert p2.y == S(5)/2 + assert p1.coordinates == (1, 5) + assert p2.coordinates == (4, S(5)/2) + + # test bounds + assert p1.bounds == (1, 5, 1, 5) + +def test_issue_9214(): + p1 = Point3D(4, -2, 6) + p2 = Point3D(1, 2, 3) + p3 = Point3D(7, 2, 3) + + assert Point3D.are_collinear(p1, p2, p3) is False + + +def test_issue_11617(): + p1 = Point3D(1,0,2) + p2 = Point2D(2,0) + + with warns(UserWarning, test_stacklevel=False): + assert p1.distance(p2) == sqrt(5) + + +def test_transform(): + p = Point(1, 1) + assert p.transform(rotate(pi/2)) == Point(-1, 1) + assert p.transform(scale(3, 2)) == Point(3, 2) + assert p.transform(translate(1, 2)) == Point(2, 3) + assert Point(1, 1).scale(2, 3, (4, 5)) == \ + Point(-2, -7) + assert Point(1, 1).translate(4, 5) == \ + Point(5, 6) + + +def test_concyclic_doctest_bug(): + p1, p2 = Point(-1, 0), Point(1, 0) + p3, p4 = Point(0, 1), Point(-1, 2) + assert Point.is_concyclic(p1, p2, p3) + assert not Point.is_concyclic(p1, p2, p3, p4) + + +def test_arguments(): + """Functions accepting `Point` objects in `geometry` + should also accept tuples and lists and + automatically convert them to points.""" + + singles2d = ((1,2), [1,2], Point(1,2)) + singles2d2 = ((1,3), [1,3], Point(1,3)) + doubles2d = cartes(singles2d, singles2d2) + p2d = Point2D(1,2) + singles3d = ((1,2,3), [1,2,3], Point(1,2,3)) + doubles3d = subsets(singles3d, 2) + p3d = Point3D(1,2,3) + singles4d = ((1,2,3,4), [1,2,3,4], Point(1,2,3,4)) + doubles4d = subsets(singles4d, 2) + p4d = Point(1,2,3,4) + + # test 2D + test_single = ['distance', 'is_scalar_multiple', 'taxicab_distance', 'midpoint', 'intersection', 'dot', 'equals', '__add__', '__sub__'] + test_double = ['is_concyclic', 'is_collinear'] + for p in singles2d: + Point2D(p) + for func in test_single: + for p in singles2d: + getattr(p2d, func)(p) + for func in test_double: + for p in doubles2d: + getattr(p2d, func)(*p) + + # test 3D + test_double = ['is_collinear'] + for p in singles3d: + Point3D(p) + for func in test_single: + for p in singles3d: + getattr(p3d, func)(p) + for func in test_double: + for p in doubles3d: + getattr(p3d, func)(*p) + + # test 4D + test_double = ['is_collinear'] + for p in singles4d: + Point(p) + for func in test_single: + for p in singles4d: + getattr(p4d, func)(p) + for func in test_double: + for p in doubles4d: + getattr(p4d, func)(*p) + + # test evaluate=False for ops + x = Symbol('x') + a = Point(0, 1) + assert a + (0.1, x) == Point(0.1, 1 + x, evaluate=False) + a = Point(0, 1) + assert a/10.0 == Point(0, 0.1, evaluate=False) + a = Point(0, 1) + assert a*10.0 == Point(0, 10.0, evaluate=False) + + # test evaluate=False when changing dimensions + u = Point(.1, .2, evaluate=False) + u4 = Point(u, dim=4, on_morph='ignore') + assert u4.args == (.1, .2, 0, 0) + assert all(i.is_Float for i in u4.args[:2]) + # and even when *not* changing dimensions + assert all(i.is_Float for i in Point(u).args) + + # never raise error if creating an origin + assert Point(dim=3, on_morph='error') + + # raise error with unmatched dimension + raises(ValueError, lambda: Point(1, 1, dim=3, on_morph='error')) + # test unknown on_morph + raises(ValueError, lambda: Point(1, 1, dim=3, on_morph='unknown')) + # test invalid expressions + raises(TypeError, lambda: Point(Basic(), Basic())) + +def test_unit(): + assert Point(1, 1).unit == Point(sqrt(2)/2, sqrt(2)/2) + + +def test_dot(): + raises(TypeError, lambda: Point(1, 2).dot(Line((0, 0), (1, 1)))) + + +def test__normalize_dimension(): + assert Point._normalize_dimension(Point(1, 2), Point(3, 4)) == [ + Point(1, 2), Point(3, 4)] + assert Point._normalize_dimension( + Point(1, 2), Point(3, 4, 0), on_morph='ignore') == [ + Point(1, 2, 0), Point(3, 4, 0)] + + +def test_issue_22684(): + # Used to give an error + with evaluate(False): + Point(1, 2) + + +def test_direction_cosine(): + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + + assert p1.direction_cosine(Point3D(1, 0, 0)) == [1, 0, 0] + assert p1.direction_cosine(Point3D(0, 1, 0)) == [0, 1, 0] + assert p1.direction_cosine(Point3D(0, 0, pi)) == [0, 0, 1] + + assert p1.direction_cosine(Point3D(5, 0, 0)) == [1, 0, 0] + assert p1.direction_cosine(Point3D(0, sqrt(3), 0)) == [0, 1, 0] + assert p1.direction_cosine(Point3D(0, 0, 5)) == [0, 0, 1] + + assert p1.direction_cosine(Point3D(2.4, 2.4, 0)) == [sqrt(2)/2, sqrt(2)/2, 0] + assert p1.direction_cosine(Point3D(1, 1, 1)) == [sqrt(3) / 3, sqrt(3) / 3, sqrt(3) / 3] + assert p1.direction_cosine(Point3D(-12, 0 -15)) == [-4*sqrt(41)/41, -5*sqrt(41)/41, 0] + + assert p2.direction_cosine(Point3D(0, 0, 0)) == [-sqrt(3) / 3, -sqrt(3) / 3, -sqrt(3) / 3] + assert p2.direction_cosine(Point3D(1, 1, 12)) == [0, 0, 1] + assert p2.direction_cosine(Point3D(12, 1, 12)) == [sqrt(2) / 2, 0, sqrt(2) / 2] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_polygon.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_polygon.py new file mode 100644 index 0000000000000000000000000000000000000000..520023349f363bdb12146465305c2a5650c80934 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_polygon.py @@ -0,0 +1,676 @@ +from sympy.core.numbers import (Float, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, cos, sin) +from sympy.functions.elementary.trigonometric import tan +from sympy.geometry import (Circle, Ellipse, GeometryError, Point, Point2D, + Polygon, Ray, RegularPolygon, Segment, Triangle, + are_similar, convex_hull, intersection, Line, Ray2D) +from sympy.testing.pytest import raises, slow, warns +from sympy.core.random import verify_numerically +from sympy.geometry.polygon import rad, deg +from sympy.integrals.integrals import integrate +from sympy.utilities.iterables import rotate_left + + +def feq(a, b): + """Test if two floating point values are 'equal'.""" + t_float = Float("1.0E-10") + return -t_float < a - b < t_float + +@slow +def test_polygon(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + q = Symbol('q', real=True) + u = Symbol('u', real=True) + v = Symbol('v', real=True) + w = Symbol('w', real=True) + x1 = Symbol('x1', real=True) + half = S.Half + a, b, c = Point(0, 0), Point(2, 0), Point(3, 3) + t = Triangle(a, b, c) + assert Polygon(Point(0, 0)) == Point(0, 0) + assert Polygon(a, Point(1, 0), b, c) == t + assert Polygon(Point(1, 0), b, c, a) == t + assert Polygon(b, c, a, Point(1, 0)) == t + # 2 "remove folded" tests + assert Polygon(a, Point(3, 0), b, c) == t + assert Polygon(a, b, Point(3, -1), b, c) == t + # remove multiple collinear points + assert Polygon(Point(-4, 15), Point(-11, 15), Point(-15, 15), + Point(-15, 33/5), Point(-15, -87/10), Point(-15, -15), + Point(-42/5, -15), Point(-2, -15), Point(7, -15), Point(15, -15), + Point(15, -3), Point(15, 10), Point(15, 15)) == \ + Polygon(Point(-15, -15), Point(15, -15), Point(15, 15), Point(-15, 15)) + + p1 = Polygon( + Point(0, 0), Point(3, -1), + Point(6, 0), Point(4, 5), + Point(2, 3), Point(0, 3)) + p2 = Polygon( + Point(6, 0), Point(3, -1), + Point(0, 0), Point(0, 3), + Point(2, 3), Point(4, 5)) + p3 = Polygon( + Point(0, 0), Point(3, 0), + Point(5, 2), Point(4, 4)) + p4 = Polygon( + Point(0, 0), Point(4, 4), + Point(5, 2), Point(3, 0)) + p5 = Polygon( + Point(0, 0), Point(4, 4), + Point(0, 4)) + p6 = Polygon( + Point(-11, 1), Point(-9, 6.6), + Point(-4, -3), Point(-8.4, -8.7)) + p7 = Polygon( + Point(x, y), Point(q, u), + Point(v, w)) + p8 = Polygon( + Point(x, y), Point(v, w), + Point(q, u)) + p9 = Polygon( + Point(0, 0), Point(4, 4), + Point(3, 0), Point(5, 2)) + p10 = Polygon( + Point(0, 2), Point(2, 2), + Point(0, 0), Point(2, 0)) + p11 = Polygon(Point(0, 0), 1, n=3) + p12 = Polygon(Point(0, 0), 1, 0, n=3) + p13 = Polygon( + Point(0, 0),Point(8, 8), + Point(23, 20),Point(0, 20)) + p14 = Polygon(*rotate_left(p13.args, 1)) + + + r = Ray(Point(-9, 6.6), Point(-9, 5.5)) + # + # General polygon + # + assert p1 == p2 + assert len(p1.args) == 6 + assert len(p1.sides) == 6 + assert p1.perimeter == 5 + 2*sqrt(10) + sqrt(29) + sqrt(8) + assert p1.area == 22 + assert not p1.is_convex() + assert Polygon((-1, 1), (2, -1), (2, 1), (-1, -1), (3, 0) + ).is_convex() is False + # ensure convex for both CW and CCW point specification + assert p3.is_convex() + assert p4.is_convex() + dict5 = p5.angles + assert dict5[Point(0, 0)] == pi / 4 + assert dict5[Point(0, 4)] == pi / 2 + assert p5.encloses_point(Point(x, y)) is None + assert p5.encloses_point(Point(1, 3)) + assert p5.encloses_point(Point(0, 0)) is False + assert p5.encloses_point(Point(4, 0)) is False + assert p1.encloses(Circle(Point(2.5, 2.5), 5)) is False + assert p1.encloses(Ellipse(Point(2.5, 2), 5, 6)) is False + assert p5.plot_interval('x') == [x, 0, 1] + assert p5.distance( + Polygon(Point(10, 10), Point(14, 14), Point(10, 14))) == 6 * sqrt(2) + assert p5.distance( + Polygon(Point(1, 8), Point(5, 8), Point(8, 12), Point(1, 12))) == 4 + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output"): + Polygon(Point(0, 0), Point(1, 0), Point(1, 1)).distance( + Polygon(Point(0, 0), Point(0, 1), Point(1, 1))) + assert hash(p5) == hash(Polygon(Point(0, 0), Point(4, 4), Point(0, 4))) + assert hash(p1) == hash(p2) + assert hash(p7) == hash(p8) + assert hash(p3) != hash(p9) + assert p5 == Polygon(Point(4, 4), Point(0, 4), Point(0, 0)) + assert Polygon(Point(4, 4), Point(0, 4), Point(0, 0)) in p5 + assert p5 != Point(0, 4) + assert Point(0, 1) in p5 + assert p5.arbitrary_point('t').subs(Symbol('t', real=True), 0) == \ + Point(0, 0) + raises(ValueError, lambda: Polygon( + Point(x, 0), Point(0, y), Point(x, y)).arbitrary_point('x')) + assert p6.intersection(r) == [Point(-9, Rational(-84, 13)), Point(-9, Rational(33, 5))] + assert p10.area == 0 + assert p11 == RegularPolygon(Point(0, 0), 1, 3, 0) + assert p11 == p12 + assert p11.vertices[0] == Point(1, 0) + assert p11.args[0] == Point(0, 0) + p11.spin(pi/2) + assert p11.vertices[0] == Point(0, 1) + # + # Regular polygon + # + p1 = RegularPolygon(Point(0, 0), 10, 5) + p2 = RegularPolygon(Point(0, 0), 5, 5) + raises(GeometryError, lambda: RegularPolygon(Point(0, 0), Point(0, + 1), Point(1, 1))) + raises(GeometryError, lambda: RegularPolygon(Point(0, 0), 1, 2)) + raises(ValueError, lambda: RegularPolygon(Point(0, 0), 1, 2.5)) + + assert p1 != p2 + assert p1.interior_angle == pi*Rational(3, 5) + assert p1.exterior_angle == pi*Rational(2, 5) + assert p2.apothem == 5*cos(pi/5) + assert p2.circumcenter == p1.circumcenter == Point(0, 0) + assert p1.circumradius == p1.radius == 10 + assert p2.circumcircle == Circle(Point(0, 0), 5) + assert p2.incircle == Circle(Point(0, 0), p2.apothem) + assert p2.inradius == p2.apothem == (5 * (1 + sqrt(5)) / 4) + p2.spin(pi / 10) + dict1 = p2.angles + assert dict1[Point(0, 5)] == 3 * pi / 5 + assert p1.is_convex() + assert p1.rotation == 0 + assert p1.encloses_point(Point(0, 0)) + assert p1.encloses_point(Point(11, 0)) is False + assert p2.encloses_point(Point(0, 4.9)) + p1.spin(pi/3) + assert p1.rotation == pi/3 + assert p1.vertices[0] == Point(5, 5*sqrt(3)) + for var in p1.args: + if isinstance(var, Point): + assert var == Point(0, 0) + else: + assert var in (5, 10, pi / 3) + assert p1 != Point(0, 0) + assert p1 != p5 + + # while spin works in place (notice that rotation is 2pi/3 below) + # rotate returns a new object + p1_old = p1 + assert p1.rotate(pi/3) == RegularPolygon(Point(0, 0), 10, 5, pi*Rational(2, 3)) + assert p1 == p1_old + + assert p1.area == (-250*sqrt(5) + 1250)/(4*tan(pi/5)) + assert p1.length == 20*sqrt(-sqrt(5)/8 + Rational(5, 8)) + assert p1.scale(2, 2) == \ + RegularPolygon(p1.center, p1.radius*2, p1._n, p1.rotation) + assert RegularPolygon((0, 0), 1, 4).scale(2, 3) == \ + Polygon(Point(2, 0), Point(0, 3), Point(-2, 0), Point(0, -3)) + + assert repr(p1) == str(p1) + + # + # Angles + # + angles = p4.angles + assert feq(angles[Point(0, 0)].evalf(), Float("0.7853981633974483")) + assert feq(angles[Point(4, 4)].evalf(), Float("1.2490457723982544")) + assert feq(angles[Point(5, 2)].evalf(), Float("1.8925468811915388")) + assert feq(angles[Point(3, 0)].evalf(), Float("2.3561944901923449")) + + angles = p3.angles + assert feq(angles[Point(0, 0)].evalf(), Float("0.7853981633974483")) + assert feq(angles[Point(4, 4)].evalf(), Float("1.2490457723982544")) + assert feq(angles[Point(5, 2)].evalf(), Float("1.8925468811915388")) + assert feq(angles[Point(3, 0)].evalf(), Float("2.3561944901923449")) + + # https://github.com/sympy/sympy/issues/24885 + interior_angles_sum = sum(p13.angles.values()) + assert feq(interior_angles_sum, (len(p13.angles) - 2)*pi ) + interior_angles_sum = sum(p14.angles.values()) + assert feq(interior_angles_sum, (len(p14.angles) - 2)*pi ) + + # + # Triangle + # + p1 = Point(0, 0) + p2 = Point(5, 0) + p3 = Point(0, 5) + t1 = Triangle(p1, p2, p3) + t2 = Triangle(p1, p2, Point(Rational(5, 2), sqrt(Rational(75, 4)))) + t3 = Triangle(p1, Point(x1, 0), Point(0, x1)) + s1 = t1.sides + assert Triangle(p1, p2, p1) == Polygon(p1, p2, p1) == Segment(p1, p2) + raises(GeometryError, lambda: Triangle(Point(0, 0))) + + # Basic stuff + assert Triangle(p1, p1, p1) == p1 + assert Triangle(p2, p2*2, p2*3) == Segment(p2, p2*3) + assert t1.area == Rational(25, 2) + assert t1.is_right() + assert t2.is_right() is False + assert t3.is_right() + assert p1 in t1 + assert t1.sides[0] in t1 + assert Segment((0, 0), (1, 0)) in t1 + assert Point(5, 5) not in t2 + assert t1.is_convex() + assert feq(t1.angles[p1].evalf(), pi.evalf()/2) + + assert t1.is_equilateral() is False + assert t2.is_equilateral() + assert t3.is_equilateral() is False + assert are_similar(t1, t2) is False + assert are_similar(t1, t3) + assert are_similar(t2, t3) is False + assert t1.is_similar(Point(0, 0)) is False + assert t1.is_similar(t2) is False + + # Bisectors + bisectors = t1.bisectors() + assert bisectors[p1] == Segment( + p1, Point(Rational(5, 2), Rational(5, 2))) + assert t2.bisectors()[p2] == Segment( + Point(5, 0), Point(Rational(5, 4), 5*sqrt(3)/4)) + p4 = Point(0, x1) + assert t3.bisectors()[p4] == Segment(p4, Point(x1*(sqrt(2) - 1), 0)) + ic = (250 - 125*sqrt(2))/50 + assert t1.incenter == Point(ic, ic) + + # Inradius + assert t1.inradius == t1.incircle.radius == 5 - 5*sqrt(2)/2 + assert t2.inradius == t2.incircle.radius == 5*sqrt(3)/6 + assert t3.inradius == t3.incircle.radius == x1**2/((2 + sqrt(2))*Abs(x1)) + + # Exradius + assert t1.exradii[t1.sides[2]] == 5*sqrt(2)/2 + + # Excenters + assert t1.excenters[t1.sides[2]] == Point2D(25*sqrt(2), -5*sqrt(2)/2) + + # Circumcircle + assert t1.circumcircle.center == Point(2.5, 2.5) + + # Medians + Centroid + m = t1.medians + assert t1.centroid == Point(Rational(5, 3), Rational(5, 3)) + assert m[p1] == Segment(p1, Point(Rational(5, 2), Rational(5, 2))) + assert t3.medians[p1] == Segment(p1, Point(x1/2, x1/2)) + assert intersection(m[p1], m[p2], m[p3]) == [t1.centroid] + assert t1.medial == Triangle(Point(2.5, 0), Point(0, 2.5), Point(2.5, 2.5)) + + # Nine-point circle + assert t1.nine_point_circle == Circle(Point(2.5, 0), + Point(0, 2.5), Point(2.5, 2.5)) + assert t1.nine_point_circle == Circle(Point(0, 0), + Point(0, 2.5), Point(2.5, 2.5)) + + # Perpendicular + altitudes = t1.altitudes + assert altitudes[p1] == Segment(p1, Point(Rational(5, 2), Rational(5, 2))) + assert altitudes[p2].equals(s1[0]) + assert altitudes[p3] == s1[2] + assert t1.orthocenter == p1 + t = S('''Triangle( + Point(100080156402737/5000000000000, 79782624633431/500000000000), + Point(39223884078253/2000000000000, 156345163124289/1000000000000), + Point(31241359188437/1250000000000, 338338270939941/1000000000000000))''') + assert t.orthocenter == S('''Point(-780660869050599840216997''' + '''79471538701955848721853/80368430960602242240789074233100000000000000,''' + '''20151573611150265741278060334545897615974257/16073686192120448448157''' + '''8148466200000000000)''') + + # Ensure + assert len(intersection(*bisectors.values())) == 1 + assert len(intersection(*altitudes.values())) == 1 + assert len(intersection(*m.values())) == 1 + + # Distance + p1 = Polygon( + Point(0, 0), Point(1, 0), + Point(1, 1), Point(0, 1)) + p2 = Polygon( + Point(0, Rational(5)/4), Point(1, Rational(5)/4), + Point(1, Rational(9)/4), Point(0, Rational(9)/4)) + p3 = Polygon( + Point(1, 2), Point(2, 2), + Point(2, 1)) + p4 = Polygon( + Point(1, 1), Point(Rational(6)/5, 1), + Point(1, Rational(6)/5)) + pt1 = Point(half, half) + pt2 = Point(1, 1) + + '''Polygon to Point''' + assert p1.distance(pt1) == half + assert p1.distance(pt2) == 0 + assert p2.distance(pt1) == Rational(3)/4 + assert p3.distance(pt2) == sqrt(2)/2 + + '''Polygon to Polygon''' + # p1.distance(p2) emits a warning + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output"): + assert p1.distance(p2) == half/2 + + assert p1.distance(p3) == sqrt(2)/2 + + # p3.distance(p4) emits a warning + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output"): + assert p3.distance(p4) == (sqrt(2)/2 - sqrt(Rational(2)/25)/2) + + +def test_convex_hull(): + p = [Point(-5, -1), Point(-2, 1), Point(-2, -1), Point(-1, -3), \ + Point(0, 0), Point(1, 1), Point(2, 2), Point(2, -1), Point(3, 1), \ + Point(4, -1), Point(6, 2)] + ch = Polygon(p[0], p[3], p[9], p[10], p[6], p[1]) + #test handling of duplicate points + p.append(p[3]) + + #more than 3 collinear points + another_p = [Point(-45, -85), Point(-45, 85), Point(-45, 26), \ + Point(-45, -24)] + ch2 = Segment(another_p[0], another_p[1]) + + assert convex_hull(*another_p) == ch2 + assert convex_hull(*p) == ch + assert convex_hull(p[0]) == p[0] + assert convex_hull(p[0], p[1]) == Segment(p[0], p[1]) + + # no unique points + assert convex_hull(*[p[-1]]*3) == p[-1] + + # collection of items + assert convex_hull(*[Point(0, 0), \ + Segment(Point(1, 0), Point(1, 1)), \ + RegularPolygon(Point(2, 0), 2, 4)]) == \ + Polygon(Point(0, 0), Point(2, -2), Point(4, 0), Point(2, 2)) + + +def test_encloses(): + # square with a dimpled left side + s = Polygon(Point(0, 0), Point(1, 0), Point(1, 1), Point(0, 1), \ + Point(S.Half, S.Half)) + # the following is True if the polygon isn't treated as closing on itself + assert s.encloses(Point(0, S.Half)) is False + assert s.encloses(Point(S.Half, S.Half)) is False # it's a vertex + assert s.encloses(Point(Rational(3, 4), S.Half)) is True + + +def test_triangle_kwargs(): + assert Triangle(sss=(3, 4, 5)) == \ + Triangle(Point(0, 0), Point(3, 0), Point(3, 4)) + assert Triangle(asa=(30, 2, 30)) == \ + Triangle(Point(0, 0), Point(2, 0), Point(1, sqrt(3)/3)) + assert Triangle(sas=(1, 45, 2)) == \ + Triangle(Point(0, 0), Point(2, 0), Point(sqrt(2)/2, sqrt(2)/2)) + assert Triangle(sss=(1, 2, 5)) is None + assert deg(rad(180)) == 180 + + +def test_transform(): + pts = [Point(0, 0), Point(S.Half, Rational(1, 4)), Point(1, 1)] + pts_out = [Point(-4, -10), Point(-3, Rational(-37, 4)), Point(-2, -7)] + assert Triangle(*pts).scale(2, 3, (4, 5)) == Triangle(*pts_out) + assert RegularPolygon((0, 0), 1, 4).scale(2, 3, (4, 5)) == \ + Polygon(Point(-2, -10), Point(-4, -7), Point(-6, -10), Point(-4, -13)) + # Checks for symmetric scaling + assert RegularPolygon((0, 0), 1, 4).scale(2, 2) == \ + RegularPolygon(Point2D(0, 0), 2, 4, 0) + +def test_reflect(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + b = Symbol('b') + m = Symbol('m') + l = Line((0, b), slope=m) + p = Point(x, y) + r = p.reflect(l) + dp = l.perpendicular_segment(p).length + dr = l.perpendicular_segment(r).length + + assert verify_numerically(dp, dr) + + assert Polygon((1, 0), (2, 0), (2, 2)).reflect(Line((3, 0), slope=oo)) \ + == Triangle(Point(5, 0), Point(4, 0), Point(4, 2)) + assert Polygon((1, 0), (2, 0), (2, 2)).reflect(Line((0, 3), slope=oo)) \ + == Triangle(Point(-1, 0), Point(-2, 0), Point(-2, 2)) + assert Polygon((1, 0), (2, 0), (2, 2)).reflect(Line((0, 3), slope=0)) \ + == Triangle(Point(1, 6), Point(2, 6), Point(2, 4)) + assert Polygon((1, 0), (2, 0), (2, 2)).reflect(Line((3, 0), slope=0)) \ + == Triangle(Point(1, 0), Point(2, 0), Point(2, -2)) + +def test_bisectors(): + p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + p = Polygon(Point(0, 0), Point(2, 0), Point(1, 1), Point(0, 3)) + q = Polygon(Point(1, 0), Point(2, 0), Point(3, 3), Point(-1, 5)) + poly = Polygon(Point(3, 4), Point(0, 0), Point(8, 7), Point(-1, 1), Point(19, -19)) + t = Triangle(p1, p2, p3) + assert t.bisectors()[p2] == Segment(Point(1, 0), Point(0, sqrt(2) - 1)) + assert p.bisectors()[Point2D(0, 3)] == Ray2D(Point2D(0, 3), \ + Point2D(sin(acos(2*sqrt(5)/5)/2), 3 - cos(acos(2*sqrt(5)/5)/2))) + assert q.bisectors()[Point2D(-1, 5)] == \ + Ray2D(Point2D(-1, 5), Point2D(-1 + sqrt(29)*(5*sin(acos(9*sqrt(145)/145)/2) + \ + 2*cos(acos(9*sqrt(145)/145)/2))/29, sqrt(29)*(-5*cos(acos(9*sqrt(145)/145)/2) + \ + 2*sin(acos(9*sqrt(145)/145)/2))/29 + 5)) + assert poly.bisectors()[Point2D(-1, 1)] == Ray2D(Point2D(-1, 1), \ + Point2D(-1 + sin(acos(sqrt(26)/26)/2 + pi/4), 1 - sin(-acos(sqrt(26)/26)/2 + pi/4))) + +def test_incenter(): + assert Triangle(Point(0, 0), Point(1, 0), Point(0, 1)).incenter \ + == Point(1 - sqrt(2)/2, 1 - sqrt(2)/2) + +def test_inradius(): + assert Triangle(Point(0, 0), Point(4, 0), Point(0, 3)).inradius == 1 + +def test_incircle(): + assert Triangle(Point(0, 0), Point(2, 0), Point(0, 2)).incircle \ + == Circle(Point(2 - sqrt(2), 2 - sqrt(2)), 2 - sqrt(2)) + +def test_exradii(): + t = Triangle(Point(0, 0), Point(6, 0), Point(0, 2)) + assert t.exradii[t.sides[2]] == (-2 + sqrt(10)) + +def test_medians(): + t = Triangle(Point(0, 0), Point(1, 0), Point(0, 1)) + assert t.medians[Point(0, 0)] == Segment(Point(0, 0), Point(S.Half, S.Half)) + +def test_medial(): + assert Triangle(Point(0, 0), Point(1, 0), Point(0, 1)).medial \ + == Triangle(Point(S.Half, 0), Point(S.Half, S.Half), Point(0, S.Half)) + +def test_nine_point_circle(): + assert Triangle(Point(0, 0), Point(1, 0), Point(0, 1)).nine_point_circle \ + == Circle(Point2D(Rational(1, 4), Rational(1, 4)), sqrt(2)/4) + +def test_eulerline(): + assert Triangle(Point(0, 0), Point(1, 0), Point(0, 1)).eulerline \ + == Line(Point2D(0, 0), Point2D(S.Half, S.Half)) + assert Triangle(Point(0, 0), Point(10, 0), Point(5, 5*sqrt(3))).eulerline \ + == Point2D(5, 5*sqrt(3)/3) + assert Triangle(Point(4, -6), Point(4, -1), Point(-3, 3)).eulerline \ + == Line(Point2D(Rational(64, 7), 3), Point2D(Rational(-29, 14), Rational(-7, 2))) + +def test_intersection(): + poly1 = Triangle(Point(0, 0), Point(1, 0), Point(0, 1)) + poly2 = Polygon(Point(0, 1), Point(-5, 0), + Point(0, -4), Point(0, Rational(1, 5)), + Point(S.Half, -0.1), Point(1, 0), Point(0, 1)) + + assert poly1.intersection(poly2) == [Point2D(Rational(1, 3), 0), + Segment(Point(0, Rational(1, 5)), Point(0, 0)), + Segment(Point(1, 0), Point(0, 1))] + assert poly2.intersection(poly1) == [Point(Rational(1, 3), 0), + Segment(Point(0, 0), Point(0, Rational(1, 5))), + Segment(Point(1, 0), Point(0, 1))] + assert poly1.intersection(Point(0, 0)) == [Point(0, 0)] + assert poly1.intersection(Point(-12, -43)) == [] + assert poly2.intersection(Line((-12, 0), (12, 0))) == [Point(-5, 0), + Point(0, 0), Point(Rational(1, 3), 0), Point(1, 0)] + assert poly2.intersection(Line((-12, 12), (12, 12))) == [] + assert poly2.intersection(Ray((-3, 4), (1, 0))) == [Segment(Point(1, 0), + Point(0, 1))] + assert poly2.intersection(Circle((0, -1), 1)) == [Point(0, -2), + Point(0, 0)] + assert poly1.intersection(poly1) == [Segment(Point(0, 0), Point(1, 0)), + Segment(Point(0, 1), Point(0, 0)), Segment(Point(1, 0), Point(0, 1))] + assert poly2.intersection(poly2) == [Segment(Point(-5, 0), Point(0, -4)), + Segment(Point(0, -4), Point(0, Rational(1, 5))), + Segment(Point(0, Rational(1, 5)), Point(S.Half, Rational(-1, 10))), + Segment(Point(0, 1), Point(-5, 0)), + Segment(Point(S.Half, Rational(-1, 10)), Point(1, 0)), + Segment(Point(1, 0), Point(0, 1))] + assert poly2.intersection(Triangle(Point(0, 1), Point(1, 0), Point(-1, 1))) \ + == [Point(Rational(-5, 7), Rational(6, 7)), Segment(Point2D(0, 1), Point(1, 0))] + assert poly1.intersection(RegularPolygon((-12, -15), 3, 3)) == [] + + +def test_parameter_value(): + t = Symbol('t') + sq = Polygon((0, 0), (0, 1), (1, 1), (1, 0)) + assert sq.parameter_value((0.5, 1), t) == {t: Rational(3, 8)} + q = Polygon((0, 0), (2, 1), (2, 4), (4, 0)) + assert q.parameter_value((4, 0), t) == {t: -6 + 3*sqrt(5)} # ~= 0.708 + + raises(ValueError, lambda: sq.parameter_value((5, 6), t)) + raises(ValueError, lambda: sq.parameter_value(Circle(Point(0, 0), 1), t)) + + +def test_issue_12966(): + poly = Polygon(Point(0, 0), Point(0, 10), Point(5, 10), Point(5, 5), + Point(10, 5), Point(10, 0)) + t = Symbol('t') + pt = poly.arbitrary_point(t) + DELTA = 5/poly.perimeter + assert [pt.subs(t, DELTA*i) for i in range(int(1/DELTA))] == [ + Point(0, 0), Point(0, 5), Point(0, 10), Point(5, 10), + Point(5, 5), Point(10, 5), Point(10, 0), Point(5, 0)] + + +def test_second_moment_of_area(): + x, y = symbols('x, y') + # triangle + p1, p2, p3 = [(0, 0), (4, 0), (0, 2)] + p = (0, 0) + # equation of hypotenuse + eq_y = (1-x/4)*2 + I_yy = integrate((x**2) * (integrate(1, (y, 0, eq_y))), (x, 0, 4)) + I_xx = integrate(1 * (integrate(y**2, (y, 0, eq_y))), (x, 0, 4)) + I_xy = integrate(x * (integrate(y, (y, 0, eq_y))), (x, 0, 4)) + + triangle = Polygon(p1, p2, p3) + + assert (I_xx - triangle.second_moment_of_area(p)[0]) == 0 + assert (I_yy - triangle.second_moment_of_area(p)[1]) == 0 + assert (I_xy - triangle.second_moment_of_area(p)[2]) == 0 + + # rectangle + p1, p2, p3, p4=[(0, 0), (4, 0), (4, 2), (0, 2)] + I_yy = integrate((x**2) * integrate(1, (y, 0, 2)), (x, 0, 4)) + I_xx = integrate(1 * integrate(y**2, (y, 0, 2)), (x, 0, 4)) + I_xy = integrate(x * integrate(y, (y, 0, 2)), (x, 0, 4)) + + rectangle = Polygon(p1, p2, p3, p4) + + assert (I_xx - rectangle.second_moment_of_area(p)[0]) == 0 + assert (I_yy - rectangle.second_moment_of_area(p)[1]) == 0 + assert (I_xy - rectangle.second_moment_of_area(p)[2]) == 0 + + + r = RegularPolygon(Point(0, 0), 5, 3) + assert r.second_moment_of_area() == (1875*sqrt(3)/S(32), 1875*sqrt(3)/S(32), 0) + + +def test_first_moment(): + a, b = symbols('a, b', positive=True) + # rectangle + p1 = Polygon((0, 0), (a, 0), (a, b), (0, b)) + assert p1.first_moment_of_area() == (a*b**2/8, a**2*b/8) + assert p1.first_moment_of_area((a/3, b/4)) == (-3*a*b**2/32, -a**2*b/9) + + p1 = Polygon((0, 0), (40, 0), (40, 30), (0, 30)) + assert p1.first_moment_of_area() == (4500, 6000) + + # triangle + p2 = Polygon((0, 0), (a, 0), (a/2, b)) + assert p2.first_moment_of_area() == (4*a*b**2/81, a**2*b/24) + assert p2.first_moment_of_area((a/8, b/6)) == (-25*a*b**2/648, -5*a**2*b/768) + + p2 = Polygon((0, 0), (12, 0), (12, 30)) + assert p2.first_moment_of_area() == (S(1600)/3, -S(640)/3) + + +def test_section_modulus_and_polar_second_moment_of_area(): + a, b = symbols('a, b', positive=True) + x, y = symbols('x, y') + rectangle = Polygon((0, b), (0, 0), (a, 0), (a, b)) + assert rectangle.section_modulus(Point(x, y)) == (a*b**3/12/(-b/2 + y), a**3*b/12/(-a/2 + x)) + assert rectangle.polar_second_moment_of_area() == a**3*b/12 + a*b**3/12 + + convex = RegularPolygon((0, 0), 1, 6) + assert convex.section_modulus() == (Rational(5, 8), sqrt(3)*Rational(5, 16)) + assert convex.polar_second_moment_of_area() == 5*sqrt(3)/S(8) + + concave = Polygon((0, 0), (1, 8), (3, 4), (4, 6), (7, 1)) + assert concave.section_modulus() == (Rational(-6371, 429), Rational(-9778, 519)) + assert concave.polar_second_moment_of_area() == Rational(-38669, 252) + + +def test_cut_section(): + # concave polygon + p = Polygon((-1, -1), (1, Rational(5, 2)), (2, 1), (3, Rational(5, 2)), (4, 2), (5, 3), (-1, 3)) + l = Line((0, 0), (Rational(9, 2), 3)) + p1 = p.cut_section(l)[0] + p2 = p.cut_section(l)[1] + assert p1 == Polygon( + Point2D(Rational(-9, 13), Rational(-6, 13)), Point2D(1, Rational(5, 2)), Point2D(Rational(24, 13), Rational(16, 13)), + Point2D(Rational(12, 5), Rational(8, 5)), Point2D(3, Rational(5, 2)), Point2D(Rational(24, 7), Rational(16, 7)), + Point2D(Rational(9, 2), 3), Point2D(-1, 3), Point2D(-1, Rational(-2, 3))) + assert p2 == Polygon(Point2D(-1, -1), Point2D(Rational(-9, 13), Rational(-6, 13)), Point2D(Rational(24, 13), Rational(16, 13)), + Point2D(2, 1), Point2D(Rational(12, 5), Rational(8, 5)), Point2D(Rational(24, 7), Rational(16, 7)), Point2D(4, 2), Point2D(5, 3), + Point2D(Rational(9, 2), 3), Point2D(-1, Rational(-2, 3))) + + # convex polygon + p = RegularPolygon(Point2D(0, 0), 6, 6) + s = p.cut_section(Line((0, 0), slope=1)) + assert s[0] == Polygon(Point2D(-3*sqrt(3) + 9, -3*sqrt(3) + 9), Point2D(3, 3*sqrt(3)), + Point2D(-3, 3*sqrt(3)), Point2D(-6, 0), Point2D(-9 + 3*sqrt(3), -9 + 3*sqrt(3))) + assert s[1] == Polygon(Point2D(6, 0), Point2D(-3*sqrt(3) + 9, -3*sqrt(3) + 9), + Point2D(-9 + 3*sqrt(3), -9 + 3*sqrt(3)), Point2D(-3, -3*sqrt(3)), Point2D(3, -3*sqrt(3))) + + # case where line does not intersects but coincides with the edge of polygon + a, b = 20, 10 + t1, t2, t3, t4 = [(0, b), (0, 0), (a, 0), (a, b)] + p = Polygon(t1, t2, t3, t4) + p1, p2 = p.cut_section(Line((0, b), slope=0)) + assert p1 == None + assert p2 == Polygon(Point2D(0, 10), Point2D(0, 0), Point2D(20, 0), Point2D(20, 10)) + + p3, p4 = p.cut_section(Line((0, 0), slope=0)) + assert p3 == Polygon(Point2D(0, 10), Point2D(0, 0), Point2D(20, 0), Point2D(20, 10)) + assert p4 == None + + # case where the line does not intersect with a polygon at all + raises(ValueError, lambda: p.cut_section(Line((0, a), slope=0))) + +def test_type_of_triangle(): + # Isoceles triangle + p1 = Polygon(Point(0, 0), Point(5, 0), Point(2, 4)) + assert p1.is_isosceles() == True + assert p1.is_scalene() == False + assert p1.is_equilateral() == False + + # Scalene triangle + p2 = Polygon (Point(0, 0), Point(0, 2), Point(4, 0)) + assert p2.is_isosceles() == False + assert p2.is_scalene() == True + assert p2.is_equilateral() == False + + # Equilateral triangle + p3 = Polygon(Point(0, 0), Point(6, 0), Point(3, sqrt(27))) + assert p3.is_isosceles() == True + assert p3.is_scalene() == False + assert p3.is_equilateral() == True + +def test_do_poly_distance(): + # Non-intersecting polygons + square1 = Polygon (Point(0, 0), Point(0, 1), Point(1, 1), Point(1, 0)) + triangle1 = Polygon(Point(1, 2), Point(2, 2), Point(2, 1)) + assert square1._do_poly_distance(triangle1) == sqrt(2)/2 + + # Polygons which sides intersect + square2 = Polygon(Point(1, 0), Point(2, 0), Point(2, 1), Point(1, 1)) + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output", test_stacklevel=False): + assert square1._do_poly_distance(square2) == 0 + + # Polygons which bodies intersect + triangle2 = Polygon(Point(0, -1), Point(2, -1), Point(S.Half, S.Half)) + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output", test_stacklevel=False): + assert triangle2._do_poly_distance(square1) == 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_util.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..da52a795a9383c6438ca06303e8ae6506dccdc65 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/geometry/tests/test_util.py @@ -0,0 +1,170 @@ +import pytest +from sympy.core.numbers import Float +from sympy.core.function import (Derivative, Function) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions import exp, cos, sin, tan, cosh, sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.geometry import Point, Point2D, Line, Polygon, Segment, convex_hull,\ + intersection, centroid, Point3D, Line3D, Ray, Ellipse +from sympy.geometry.util import idiff, closest_points, farthest_points, _ordered_points, are_coplanar +from sympy.solvers.solvers import solve +from sympy.testing.pytest import raises + + +def test_idiff(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + t = Symbol('t', real=True) + f = Function('f') + g = Function('g') + # the use of idiff in ellipse also provides coverage + circ = x**2 + y**2 - 4 + ans = -3*x*(x**2/y**2 + 1)/y**3 + assert ans == idiff(circ, y, x, 3), idiff(circ, y, x, 3) + assert ans == idiff(circ, [y], x, 3) + assert idiff(circ, y, x, 3) == ans + explicit = 12*x/sqrt(-x**2 + 4)**5 + assert ans.subs(y, solve(circ, y)[0]).equals(explicit) + assert True in [sol.diff(x, 3).equals(explicit) for sol in solve(circ, y)] + assert idiff(x + t + y, [y, t], x) == -Derivative(t, x) - 1 + assert idiff(f(x) * exp(f(x)) - x * exp(x), f(x), x) == (x + 1)*exp(x)*exp(-f(x))/(f(x) + 1) + assert idiff(f(x) - y * exp(x), [f(x), y], x) == (y + Derivative(y, x))*exp(x) + assert idiff(f(x) - y * exp(x), [y, f(x)], x) == -y + Derivative(f(x), x)*exp(-x) + assert idiff(f(x) - g(x), [f(x), g(x)], x) == Derivative(g(x), x) + # this should be fast + fxy = y - (-10*(-sin(x) + 1/x)**2 + tan(x)**2 + 2*cosh(x/10)) + assert idiff(fxy, y, x) == -20*sin(x)*cos(x) + 2*tan(x)**3 + \ + 2*tan(x) + sinh(x/10)/5 + 20*cos(x)/x - 20*sin(x)/x**2 + 20/x**3 + + +def test_intersection(): + assert intersection(Point(0, 0)) == [] + raises(TypeError, lambda: intersection(Point(0, 0), 3)) + assert intersection( + Segment((0, 0), (2, 0)), + Segment((-1, 0), (1, 0)), + Line((0, 0), (0, 1)), pairwise=True) == [ + Point(0, 0), Segment((0, 0), (1, 0))] + assert intersection( + Line((0, 0), (0, 1)), + Segment((0, 0), (2, 0)), + Segment((-1, 0), (1, 0)), pairwise=True) == [ + Point(0, 0), Segment((0, 0), (1, 0))] + assert intersection( + Line((0, 0), (0, 1)), + Segment((0, 0), (2, 0)), + Segment((-1, 0), (1, 0)), + Line((0, 0), slope=1), pairwise=True) == [ + Point(0, 0), Segment((0, 0), (1, 0))] + R = 4.0 + c = intersection( + Ray(Point2D(0.001, -1), + Point2D(0.0008, -1.7)), + Ellipse(center=Point2D(0, 0), hradius=R, vradius=2.0), pairwise=True)[0].coordinates + assert c == pytest.approx( + Point2D(0.000714285723396502, -1.99999996811224, evaluate=False).coordinates) + # check this is responds to a lower precision parameter + R = Float(4, 5) + c2 = intersection( + Ray(Point2D(0.001, -1), + Point2D(0.0008, -1.7)), + Ellipse(center=Point2D(0, 0), hradius=R, vradius=2.0), pairwise=True)[0].coordinates + assert c2 == pytest.approx( + Point2D(0.000714285723396502, -1.99999996811224, evaluate=False).coordinates) + assert c[0]._prec == 53 + assert c2[0]._prec == 20 + + +def test_convex_hull(): + raises(TypeError, lambda: convex_hull(Point(0, 0), 3)) + points = [(1, -1), (1, -2), (3, -1), (-5, -2), (15, -4)] + assert convex_hull(*points, **{"polygon": False}) == ( + [Point2D(-5, -2), Point2D(1, -1), Point2D(3, -1), Point2D(15, -4)], + [Point2D(-5, -2), Point2D(15, -4)]) + + +def test_centroid(): + p = Polygon((0, 0), (10, 0), (10, 10)) + q = p.translate(0, 20) + assert centroid(p, q) == Point(20, 40)/3 + p = Segment((0, 0), (2, 0)) + q = Segment((0, 0), (2, 2)) + assert centroid(p, q) == Point(1, -sqrt(2) + 2) + assert centroid(Point(0, 0), Point(2, 0)) == Point(2, 0)/2 + assert centroid(Point(0, 0), Point(0, 0), Point(2, 0)) == Point(2, 0)/3 + + +def test_farthest_points_closest_points(): + from sympy.core.random import randint + from sympy.utilities.iterables import subsets + + for how in (min, max): + if how == min: + func = closest_points + else: + func = farthest_points + + raises(ValueError, lambda: func(Point2D(0, 0), Point2D(0, 0))) + + # 3rd pt dx is close and pt is closer to 1st pt + p1 = [Point2D(0, 0), Point2D(3, 0), Point2D(1, 1)] + # 3rd pt dx is close and pt is closer to 2nd pt + p2 = [Point2D(0, 0), Point2D(3, 0), Point2D(2, 1)] + # 3rd pt dx is close and but pt is not closer + p3 = [Point2D(0, 0), Point2D(3, 0), Point2D(1, 10)] + # 3rd pt dx is not closer and it's closer to 2nd pt + p4 = [Point2D(0, 0), Point2D(3, 0), Point2D(4, 0)] + # 3rd pt dx is not closer and it's closer to 1st pt + p5 = [Point2D(0, 0), Point2D(3, 0), Point2D(-1, 0)] + # duplicate point doesn't affect outcome + dup = [Point2D(0, 0), Point2D(3, 0), Point2D(3, 0), Point2D(-1, 0)] + # symbolic + x = Symbol('x', positive=True) + s = [Point2D(a) for a in ((x, 1), (x + 3, 2), (x + 2, 2))] + + for points in (p1, p2, p3, p4, p5, dup, s): + d = how(i.distance(j) for i, j in subsets(set(points), 2)) + ans = a, b = list(func(*points))[0] + assert a.distance(b) == d + assert ans == _ordered_points(ans) + + # if the following ever fails, the above tests were not sufficient + # and the logical error in the routine should be fixed + points = set() + while len(points) != 7: + points.add(Point2D(randint(1, 100), randint(1, 100))) + points = list(points) + d = how(i.distance(j) for i, j in subsets(points, 2)) + ans = a, b = list(func(*points))[0] + assert a.distance(b) == d + assert ans == _ordered_points(ans) + + # equidistant points + a, b, c = ( + Point2D(0, 0), Point2D(1, 0), Point2D(S.Half, sqrt(3)/2)) + ans = {_ordered_points((i, j)) + for i, j in subsets((a, b, c), 2)} + assert closest_points(b, c, a) == ans + assert farthest_points(b, c, a) == ans + + # unique to farthest + points = [(1, 1), (1, 2), (3, 1), (-5, 2), (15, 4)] + assert farthest_points(*points) == { + (Point2D(-5, 2), Point2D(15, 4))} + points = [(1, -1), (1, -2), (3, -1), (-5, -2), (15, -4)] + assert farthest_points(*points) == { + (Point2D(-5, -2), Point2D(15, -4))} + assert farthest_points((1, 1), (0, 0)) == { + (Point2D(0, 0), Point2D(1, 1))} + raises(ValueError, lambda: farthest_points((1, 1))) + + +def test_are_coplanar(): + a = Line3D(Point3D(5, 0, 0), Point3D(1, -1, 1)) + b = Line3D(Point3D(0, -2, 0), Point3D(3, 1, 1)) + c = Line3D(Point3D(0, -1, 0), Point3D(5, -1, 9)) + d = Line(Point2D(0, 3), Point2D(1, 5)) + + assert are_coplanar(a, b, c) == False + assert are_coplanar(a, d) == False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..728ebc9baeaa03262f51a6a19ca5c3979cb4d582 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/printing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/printing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8da499a5cc9f693c0fc2a660dec3901be1cc4b8b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/printing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/session.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/session.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b48a983e1875be5f8a9e81055e0dc751d1c582e3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/__pycache__/session.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72f2bcab865c879b9ae4208547a5c6146e1d7e0b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/test_interactive.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/test_interactive.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94aba33a68e282d45a08379a4ca19cd696c81fce Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/test_interactive.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/test_ipython.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/test_ipython.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2831ef276d6975ecb04723d3fc6218f2d7f3f63 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/interactive/tests/__pycache__/test_ipython.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d023d86f2c6f0c64d7ac460c50eedc355e78b21f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__init__.py @@ -0,0 +1,3 @@ +from sympy.liealgebras.cartan_type import CartanType + +__all__ = ['CartanType'] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ab938e07a7cbe8d56c469a386cfe7fad0afe6b6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/cartan_matrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/cartan_matrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f45536d69d9d8a28f4f501b23f77fff9745c6d54 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/cartan_matrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/cartan_type.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/cartan_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5282decf167a6f2338c877a38e00c5da256b6a45 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/cartan_type.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/dynkin_diagram.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/dynkin_diagram.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d73aa81189bee7e7b66de4559269df2b49c4a0d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/dynkin_diagram.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/root_system.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/root_system.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b96fbc8558ae934aebcd9dae9cd3a53502e3cb8b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/root_system.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_a.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_a.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e7178ea4d8dadca61da9f993ba51c53e02a8c62 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_a.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_b.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_b.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b576cf27346798e0779091482924d7ce93c26c3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_b.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_c.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_c.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..113b7e4e7e91970650f57f46acf05876968a7bd3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_c.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_d.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_d.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4b95f7b530d46914b575b28bf35f836c9ae0905 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_d.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_e.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_e.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..337ffa0868b4174d187819be635b1339ae05bff0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_e.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_f.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_f.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6318b39e2836ce9c146ba1fbe2eb5be7c017aba0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_f.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_g.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_g.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f90acc4a947fbc3874046a6b142c4d53c1621fa2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/type_g.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/weyl_group.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/weyl_group.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45a400af3202490f3b43b2640ce2733dfe6fd2ec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/__pycache__/weyl_group.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/cartan_matrix.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/cartan_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..2d29b37bc9a1a26790ee88b5902951afe4fc4560 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/cartan_matrix.py @@ -0,0 +1,25 @@ +from .cartan_type import CartanType + +def CartanMatrix(ct): + """Access the Cartan matrix of a specific Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_matrix import CartanMatrix + >>> CartanMatrix("A2") + Matrix([ + [ 2, -1], + [-1, 2]]) + + >>> CartanMatrix(['C', 3]) + Matrix([ + [ 2, -1, 0], + [-1, 2, -1], + [ 0, -2, 2]]) + + This method works by returning the Cartan matrix + which corresponds to Cartan type t. + """ + + return CartanType(ct).cartan_matrix() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/cartan_type.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/cartan_type.py new file mode 100644 index 0000000000000000000000000000000000000000..16bb152469238ea912a30c2d0f8210d6f729bdb1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/cartan_type.py @@ -0,0 +1,73 @@ +from sympy.core import Atom, Basic + + +class CartanType_generator(): + """ + Constructor for actually creating things + """ + def __call__(self, *args): + c = args[0] + if isinstance(c, list): + letter, n = c[0], int(c[1]) + elif isinstance(c, str): + letter, n = c[0], int(c[1:]) + else: + raise TypeError("Argument must be a string (e.g. 'A3') or a list (e.g. ['A', 3])") + + if n < 0: + raise ValueError("Lie algebra rank cannot be negative") + if letter == "A": + from . import type_a + return type_a.TypeA(n) + if letter == "B": + from . import type_b + return type_b.TypeB(n) + + if letter == "C": + from . import type_c + return type_c.TypeC(n) + + if letter == "D": + from . import type_d + return type_d.TypeD(n) + + if letter == "E": + if n >= 6 and n <= 8: + from . import type_e + return type_e.TypeE(n) + + if letter == "F": + if n == 4: + from . import type_f + return type_f.TypeF(n) + + if letter == "G": + if n == 2: + from . import type_g + return type_g.TypeG(n) + +CartanType = CartanType_generator() + + +class Standard_Cartan(Atom): + """ + Concrete base class for Cartan types such as A4, etc + """ + + def __new__(cls, series, n): + obj = Basic.__new__(cls) + obj.n = n + obj.series = series + return obj + + def rank(self): + """ + Returns the rank of the Lie algebra + """ + return self.n + + def series(self): + """ + Returns the type of the Lie algebra + """ + return self.series diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/dynkin_diagram.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/dynkin_diagram.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9e2dac4d54490b803eeaf9637cb9b66b01f058 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/dynkin_diagram.py @@ -0,0 +1,24 @@ +from .cartan_type import CartanType + + +def DynkinDiagram(t): + """Display the Dynkin diagram of a given Lie algebra + + Works by generating the CartanType for the input, t, and then returning the + Dynkin diagram method from the individual classes. + + Examples + ======== + + >>> from sympy.liealgebras.dynkin_diagram import DynkinDiagram + >>> print(DynkinDiagram("A3")) + 0---0---0 + 1 2 3 + + >>> print(DynkinDiagram("B4")) + 0---0---0=>=0 + 1 2 3 4 + + """ + + return CartanType(t).dynkin_diagram() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/root_system.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/root_system.py new file mode 100644 index 0000000000000000000000000000000000000000..36eb24605e78bbdc669736910d89be5606df1389 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/root_system.py @@ -0,0 +1,196 @@ +from .cartan_type import CartanType +from sympy.core.basic import Atom + +class RootSystem(Atom): + """Represent the root system of a simple Lie algebra + + Every simple Lie algebra has a unique root system. To find the root + system, we first consider the Cartan subalgebra of g, which is the maximal + abelian subalgebra, and consider the adjoint action of g on this + subalgebra. There is a root system associated with this action. Now, a + root system over a vector space V is a set of finite vectors Phi (called + roots), which satisfy: + + 1. The roots span V + 2. The only scalar multiples of x in Phi are x and -x + 3. For every x in Phi, the set Phi is closed under reflection + through the hyperplane perpendicular to x. + 4. If x and y are roots in Phi, then the projection of y onto + the line through x is a half-integral multiple of x. + + Now, there is a subset of Phi, which we will call Delta, such that: + 1. Delta is a basis of V + 2. Each root x in Phi can be written x = sum k_y y for y in Delta + + The elements of Delta are called the simple roots. + Therefore, we see that the simple roots span the root space of a given + simple Lie algebra. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Root_system + .. [2] Lie Algebras and Representation Theory - Humphreys + + """ + + def __new__(cls, cartantype): + """Create a new RootSystem object + + This method assigns an attribute called cartan_type to each instance of + a RootSystem object. When an instance of RootSystem is called, it + needs an argument, which should be an instance of a simple Lie algebra. + We then take the CartanType of this argument and set it as the + cartan_type attribute of the RootSystem instance. + + """ + obj = Atom.__new__(cls) + obj.cartan_type = CartanType(cartantype) + return obj + + def simple_roots(self): + """Generate the simple roots of the Lie algebra + + The rank of the Lie algebra determines the number of simple roots that + it has. This method obtains the rank of the Lie algebra, and then uses + the simple_root method from the Lie algebra classes to generate all the + simple roots. + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> roots = c.simple_roots() + >>> roots + {1: [1, -1, 0, 0], 2: [0, 1, -1, 0], 3: [0, 0, 1, -1]} + + """ + n = self.cartan_type.rank() + roots = {i: self.cartan_type.simple_root(i) for i in range(1, n+1)} + return roots + + + def all_roots(self): + """Generate all the roots of a given root system + + The result is a dictionary where the keys are integer numbers. It + generates the roots by getting the dictionary of all positive roots + from the bases classes, and then taking each root, and multiplying it + by -1 and adding it to the dictionary. In this way all the negative + roots are generated. + + """ + alpha = self.cartan_type.positive_roots() + keys = list(alpha.keys()) + k = max(keys) + for val in keys: + k += 1 + root = alpha[val] + newroot = [-x for x in root] + alpha[k] = newroot + return alpha + + def root_space(self): + """Return the span of the simple roots + + The root space is the vector space spanned by the simple roots, i.e. it + is a vector space with a distinguished basis, the simple roots. This + method returns a string that represents the root space as the span of + the simple roots, alpha[1],...., alpha[n]. + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> c.root_space() + 'alpha[1] + alpha[2] + alpha[3]' + + """ + n = self.cartan_type.rank() + rs = " + ".join("alpha["+str(i) +"]" for i in range(1, n+1)) + return rs + + def add_simple_roots(self, root1, root2): + """Add two simple roots together + + The function takes as input two integers, root1 and root2. It then + uses these integers as keys in the dictionary of simple roots, and gets + the corresponding simple roots, and then adds them together. + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> newroot = c.add_simple_roots(1, 2) + >>> newroot + [1, 0, -1, 0] + + """ + + alpha = self.simple_roots() + if root1 > len(alpha) or root2 > len(alpha): + raise ValueError("You've used a root that doesn't exist!") + a1 = alpha[root1] + a2 = alpha[root2] + newroot = [_a1 + _a2 for _a1, _a2 in zip(a1, a2)] + return newroot + + def add_as_roots(self, root1, root2): + """Add two roots together if and only if their sum is also a root + + It takes as input two vectors which should be roots. It then computes + their sum and checks if it is in the list of all possible roots. If it + is, it returns the sum. Otherwise it returns a string saying that the + sum is not a root. + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> c.add_as_roots([1, 0, -1, 0], [0, 0, 1, -1]) + [1, 0, 0, -1] + >>> c.add_as_roots([1, -1, 0, 0], [0, 0, -1, 1]) + 'The sum of these two roots is not a root' + + """ + alpha = self.all_roots() + newroot = [r1 + r2 for r1, r2 in zip(root1, root2)] + if newroot in alpha.values(): + return newroot + else: + return "The sum of these two roots is not a root" + + + def cartan_matrix(self): + """Cartan matrix of Lie algebra associated with this root system + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0], + [-1, 2, -1], + [ 0, -1, 2]]) + """ + return self.cartan_type.cartan_matrix() + + def dynkin_diagram(self): + """Dynkin diagram of the Lie algebra associated with this root system + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> print(c.dynkin_diagram()) + 0---0---0 + 1 2 3 + """ + return self.cartan_type.dynkin_diagram() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c83bcf49b54573fa2b3825c696ec4985e840fb4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_cartan_matrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_cartan_matrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..329d83b53fa26633a00c2592e325916bc4061c44 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_cartan_matrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_cartan_type.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_cartan_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfc25261f3f142a5c56efd09bf40fef0f2c2c14f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_cartan_type.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_dynkin_diagram.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_dynkin_diagram.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebf5e76c08339e526ce963979abd6866bce10909 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_dynkin_diagram.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_root_system.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_root_system.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..447f1cb06619348126e6457f3b1685d7f3a48852 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_root_system.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_A.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_A.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b0967126405b3259011e76628911f5d46d6f30a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_A.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_B.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_B.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..344311074e0113c657f02dfdd8d602babe21804c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_B.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_C.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_C.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93b7cb3aa0688380d4acf2915deb25ae55b96bcd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_C.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_D.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_D.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c6c755f7698f561fb6a8a1ed7e6b028e62612f0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_D.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_E.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_E.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee08158e3bde283b18825878822ccbae85bb059a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_E.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_F.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_F.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7b82ff4087bfaded2f5dd24f6029c067a281f62 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_F.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_G.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_G.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..182f1fcf06c0aa9b16cc104562d511429c9f88a7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_type_G.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_weyl_group.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_weyl_group.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bb6e4e55c0742a7b4a352de6f0054408eafb9e5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/__pycache__/test_weyl_group.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_cartan_matrix.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_cartan_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..98b1793dee63e0e87c610768554a8388dfd641a1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_cartan_matrix.py @@ -0,0 +1,10 @@ +from sympy.liealgebras.cartan_matrix import CartanMatrix +from sympy.matrices import Matrix + +def test_CartanMatrix(): + c = CartanMatrix("A3") + m = Matrix(3, 3, [2, -1, 0, -1, 2, -1, 0, -1, 2]) + assert c == m + a = CartanMatrix(["G",2]) + mt = Matrix(2, 2, [2, -1, -3, 2]) + assert a == mt diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_cartan_type.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_cartan_type.py new file mode 100644 index 0000000000000000000000000000000000000000..257eeca41d0f5f2eb240cc270f76d452848ed405 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_cartan_type.py @@ -0,0 +1,12 @@ +from sympy.liealgebras.cartan_type import CartanType, Standard_Cartan + +def test_Standard_Cartan(): + c = CartanType("A4") + assert c.rank() == 4 + assert c.series == "A" + m = Standard_Cartan("A", 2) + assert m.rank() == 2 + assert m.series == "A" + b = CartanType("B12") + assert b.rank() == 12 + assert b.series == "B" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_dynkin_diagram.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_dynkin_diagram.py new file mode 100644 index 0000000000000000000000000000000000000000..ad2ee4c162945c437ecf83d75c7fef9455c9464a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_dynkin_diagram.py @@ -0,0 +1,9 @@ +from sympy.liealgebras.dynkin_diagram import DynkinDiagram + +def test_DynkinDiagram(): + c = DynkinDiagram("A3") + diag = "0---0---0\n1 2 3" + assert c == diag + ct = DynkinDiagram(["B", 3]) + diag2 = "0---0=>=0\n1 2 3" + assert ct == diag2 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_root_system.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_root_system.py new file mode 100644 index 0000000000000000000000000000000000000000..42110da5a1c59a7e6b2e537ee13746bfce361579 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_root_system.py @@ -0,0 +1,18 @@ +from sympy.liealgebras.root_system import RootSystem +from sympy.liealgebras.type_a import TypeA +from sympy.matrices import Matrix + +def test_root_system(): + c = RootSystem("A3") + assert c.cartan_type == TypeA(3) + assert c.simple_roots() == {1: [1, -1, 0, 0], 2: [0, 1, -1, 0], 3: [0, 0, 1, -1]} + assert c.root_space() == "alpha[1] + alpha[2] + alpha[3]" + assert c.cartan_matrix() == Matrix([[ 2, -1, 0], [-1, 2, -1], [ 0, -1, 2]]) + assert c.dynkin_diagram() == "0---0---0\n1 2 3" + assert c.add_simple_roots(1, 2) == [1, 0, -1, 0] + assert c.all_roots() == {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], + 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], 5: [0, 1, 0, -1], + 6: [0, 0, 1, -1], 7: [-1, 1, 0, 0], 8: [-1, 0, 1, 0], + 9: [-1, 0, 0, 1], 10: [0, -1, 1, 0], + 11: [0, -1, 0, 1], 12: [0, 0, -1, 1]} + assert c.add_as_roots([1, 0, -1, 0], [0, 0, 1, -1]) == [1, 0, 0, -1] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_A.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_A.py new file mode 100644 index 0000000000000000000000000000000000000000..85d6f451ee167cf6db17ab20e59efab86ac0b691 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_A.py @@ -0,0 +1,17 @@ +from sympy.liealgebras.cartan_type import CartanType +from sympy.matrices import Matrix + +def test_type_A(): + c = CartanType("A3") + m = Matrix(3, 3, [2, -1, 0, -1, 2, -1, 0, -1, 2]) + assert m == c.cartan_matrix() + assert c.basis() == 8 + assert c.roots() == 12 + assert c.dimension() == 4 + assert c.simple_root(1) == [1, -1, 0, 0] + assert c.highest_root() == [1, 0, 0, -1] + assert c.lie_algebra() == "su(4)" + diag = "0---0---0\n1 2 3" + assert c.dynkin_diagram() == diag + assert c.positive_roots() == {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], + 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_B.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_B.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2a9011f96bc647e48d39e16cf10703a99d86b3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_B.py @@ -0,0 +1,17 @@ +from sympy.liealgebras.cartan_type import CartanType +from sympy.matrices import Matrix + +def test_type_B(): + c = CartanType("B3") + m = Matrix(3, 3, [2, -1, 0, -1, 2, -2, 0, -1, 2]) + assert m == c.cartan_matrix() + assert c.dimension() == 3 + assert c.roots() == 18 + assert c.simple_root(3) == [0, 0, 1] + assert c.basis() == 3 + assert c.lie_algebra() == "so(6)" + diag = "0---0=>=0\n1 2 3" + assert c.dynkin_diagram() == diag + assert c.positive_roots() == {1: [1, -1, 0], 2: [1, 1, 0], 3: [1, 0, -1], + 4: [1, 0, 1], 5: [0, 1, -1], 6: [0, 1, 1], 7: [1, 0, 0], + 8: [0, 1, 0], 9: [0, 0, 1]} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_C.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_C.py new file mode 100644 index 0000000000000000000000000000000000000000..8154c201e6c50adb7c74458b240ed98b9a0dd123 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_C.py @@ -0,0 +1,22 @@ +from sympy.liealgebras.cartan_type import CartanType +from sympy.matrices import Matrix + +def test_type_C(): + c = CartanType("C4") + m = Matrix(4, 4, [2, -1, 0, 0, -1, 2, -1, 0, 0, -1, 2, -1, 0, 0, -2, 2]) + assert c.cartan_matrix() == m + assert c.dimension() == 4 + assert c.simple_root(4) == [0, 0, 0, 2] + assert c.roots() == 32 + assert c.basis() == 36 + assert c.lie_algebra() == "sp(8)" + t = CartanType(['C', 3]) + assert t.dimension() == 3 + diag = "0---0---0=<=0\n1 2 3 4" + assert c.dynkin_diagram() == diag + assert c.positive_roots() == {1: [1, -1, 0, 0], 2: [1, 1, 0, 0], + 3: [1, 0, -1, 0], 4: [1, 0, 1, 0], 5: [1, 0, 0, -1], + 6: [1, 0, 0, 1], 7: [0, 1, -1, 0], 8: [0, 1, 1, 0], + 9: [0, 1, 0, -1], 10: [0, 1, 0, 1], 11: [0, 0, 1, -1], + 12: [0, 0, 1, 1], 13: [2, 0, 0, 0], 14: [0, 2, 0, 0], 15: [0, 0, 2, 0], + 16: [0, 0, 0, 2]} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_D.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_D.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf6a34cb5be475cc30042e95bf8eae2376a2223 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_D.py @@ -0,0 +1,19 @@ +from sympy.liealgebras.cartan_type import CartanType +from sympy.matrices import Matrix + + + +def test_type_D(): + c = CartanType("D4") + m = Matrix(4, 4, [2, -1, 0, 0, -1, 2, -1, -1, 0, -1, 2, 0, 0, -1, 0, 2]) + assert c.cartan_matrix() == m + assert c.basis() == 6 + assert c.lie_algebra() == "so(8)" + assert c.roots() == 24 + assert c.simple_root(3) == [0, 0, 1, -1] + diag = " 3\n 0\n |\n |\n0---0---0\n1 2 4" + assert diag == c.dynkin_diagram() + assert c.positive_roots() == {1: [1, -1, 0, 0], 2: [1, 1, 0, 0], + 3: [1, 0, -1, 0], 4: [1, 0, 1, 0], 5: [1, 0, 0, -1], 6: [1, 0, 0, 1], + 7: [0, 1, -1, 0], 8: [0, 1, 1, 0], 9: [0, 1, 0, -1], 10: [0, 1, 0, 1], + 11: [0, 0, 1, -1], 12: [0, 0, 1, 1]} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_E.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_E.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb08342f41ede3390f34e9b297864eda16bedc7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_E.py @@ -0,0 +1,22 @@ +from sympy.liealgebras.cartan_type import CartanType +from sympy.matrices import Matrix +from sympy.core.backend import Rational + +def test_type_E(): + c = CartanType("E6") + m = Matrix(6, 6, [2, 0, -1, 0, 0, 0, 0, 2, 0, -1, 0, 0, + -1, 0, 2, -1, 0, 0, 0, -1, -1, 2, -1, 0, 0, 0, 0, + -1, 2, -1, 0, 0, 0, 0, -1, 2]) + assert c.cartan_matrix() == m + assert c.dimension() == 8 + assert c.simple_root(6) == [0, 0, 0, -1, 1, 0, 0, 0] + assert c.roots() == 72 + assert c.basis() == 78 + diag = " "*8 + "2\n" + " "*8 + "0\n" + " "*8 + "|\n" + " "*8 + "|\n" + diag += "---".join("0" for i in range(1, 6))+"\n" + diag += "1 " + " ".join(str(i) for i in range(3, 7)) + assert c.dynkin_diagram() == diag + posroots = c.positive_roots() + assert posroots[8] == [1, 0, 0, 0, 1, 0, 0, 0] + assert posroots[21] == [Rational(1,2),Rational(1,2),Rational(1,2),Rational(1,2), + Rational(1,2),Rational(-1,2),Rational(-1,2),Rational(1,2)] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_F.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_F.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb58223d0b5886e6044108c9c5cc3bbf371dd14 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_F.py @@ -0,0 +1,24 @@ +from sympy.liealgebras.cartan_type import CartanType +from sympy.matrices import Matrix +from sympy.core.backend import S + +def test_type_F(): + c = CartanType("F4") + m = Matrix(4, 4, [2, -1, 0, 0, -1, 2, -2, 0, 0, -1, 2, -1, 0, 0, -1, 2]) + assert c.cartan_matrix() == m + assert c.dimension() == 4 + assert c.simple_root(1) == [1, -1, 0, 0] + assert c.simple_root(2) == [0, 1, -1, 0] + assert c.simple_root(3) == [0, 0, 0, 1] + assert c.simple_root(4) == [-S.Half, -S.Half, -S.Half, -S.Half] + assert c.roots() == 48 + assert c.basis() == 52 + diag = "0---0=>=0---0\n" + " ".join(str(i) for i in range(1, 5)) + assert c.dynkin_diagram() == diag + assert c.positive_roots() == {1: [1, -1, 0, 0], 2: [1, 1, 0, 0], 3: [1, 0, -1, 0], + 4: [1, 0, 1, 0], 5: [1, 0, 0, -1], 6: [1, 0, 0, 1], 7: [0, 1, -1, 0], + 8: [0, 1, 1, 0], 9: [0, 1, 0, -1], 10: [0, 1, 0, 1], 11: [0, 0, 1, -1], + 12: [0, 0, 1, 1], 13: [1, 0, 0, 0], 14: [0, 1, 0, 0], 15: [0, 0, 1, 0], + 16: [0, 0, 0, 1], 17: [S.Half, S.Half, S.Half, S.Half], 18: [S.Half, -S.Half, S.Half, S.Half], + 19: [S.Half, S.Half, -S.Half, S.Half], 20: [S.Half, S.Half, S.Half, -S.Half], 21: [S.Half, S.Half, -S.Half, -S.Half], + 22: [S.Half, -S.Half, S.Half, -S.Half], 23: [S.Half, -S.Half, -S.Half, S.Half], 24: [S.Half, -S.Half, -S.Half, -S.Half]} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_G.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_G.py new file mode 100644 index 0000000000000000000000000000000000000000..c427eeb85bad8fc77d17a1563a7b796d4e0f217f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_type_G.py @@ -0,0 +1,16 @@ +# coding=utf-8 +from sympy.liealgebras.cartan_type import CartanType +from sympy.matrices import Matrix + +def test_type_G(): + c = CartanType("G2") + m = Matrix(2, 2, [2, -1, -3, 2]) + assert c.cartan_matrix() == m + assert c.simple_root(2) == [1, -2, 1] + assert c.basis() == 14 + assert c.roots() == 12 + assert c.dimension() == 3 + diag = "0≡<≡0\n1 2" + assert diag == c.dynkin_diagram() + assert c.positive_roots() == {1: [0, 1, -1], 2: [1, -2, 1], 3: [1, -1, 0], + 4: [1, 0, 1], 5: [1, 1, -2], 6: [2, -1, -1]} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_weyl_group.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_weyl_group.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e57246fdcb5a431d8bbd65f1f60e0254a9cdf0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/tests/test_weyl_group.py @@ -0,0 +1,35 @@ +from sympy.liealgebras.weyl_group import WeylGroup +from sympy.matrices import Matrix + +def test_weyl_group(): + c = WeylGroup("A3") + assert c.matrix_form('r1*r2') == Matrix([[0, 0, 1, 0], [1, 0, 0, 0], + [0, 1, 0, 0], [0, 0, 0, 1]]) + assert c.generators() == ['r1', 'r2', 'r3'] + assert c.group_order() == 24.0 + assert c.group_name() == "S4: the symmetric group acting on 4 elements." + assert c.coxeter_diagram() == "0---0---0\n1 2 3" + assert c.element_order('r1*r2*r3') == 4 + assert c.element_order('r1*r3*r2*r3') == 3 + d = WeylGroup("B5") + assert d.group_order() == 3840 + assert d.element_order('r1*r2*r4*r5') == 12 + assert d.matrix_form('r2*r3') == Matrix([[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]]) + assert d.element_order('r1*r2*r1*r3*r5') == 6 + e = WeylGroup("D5") + assert e.element_order('r2*r3*r5') == 4 + assert e.matrix_form('r2*r3*r5') == Matrix([[1, 0, 0, 0, 0], [0, 0, 0, 0, -1], + [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, -1, 0]]) + f = WeylGroup("G2") + assert f.element_order('r1*r2*r1*r2') == 3 + assert f.element_order('r2*r1*r1*r2') == 1 + + assert f.matrix_form('r1*r2*r1*r2') == Matrix([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) + g = WeylGroup("F4") + assert g.matrix_form('r2*r3') == Matrix([[1, 0, 0, 0], [0, 1, 0, 0], + [0, 0, 0, -1], [0, 0, 1, 0]]) + + assert g.element_order('r2*r3') == 4 + h = WeylGroup("E6") + assert h.group_order() == 51840 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_a.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_a.py new file mode 100644 index 0000000000000000000000000000000000000000..96dc615366ae20d668d651620ac088f15751c50e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_a.py @@ -0,0 +1,164 @@ +from sympy.liealgebras.cartan_type import Standard_Cartan +from sympy.core.backend import eye + + +class TypeA(Standard_Cartan): + """ + This class contains the information about + the A series of simple Lie algebras. + ==== + """ + + def __new__(cls, n): + if n < 1: + raise ValueError("n cannot be less than 1") + return Standard_Cartan.__new__(cls, "A", n) + + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A4") + >>> c.dimension() + 5 + """ + return self.n+1 + + + def basic_root(self, i, j): + """ + This is a method just to generate roots + with a 1 iin the ith position and a -1 + in the jth position. + + """ + + n = self.n + root = [0]*(n+1) + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """ + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + In A_n the ith simple root is the root which has a 1 + in the ith position, a -1 in the (i+1)th position, + and zeroes elsewhere. + + This method returns the ith simple root for the A series. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A4") + >>> c.simple_root(1) + [1, -1, 0, 0, 0] + + """ + + return self.basic_root(i-1, i) + + def positive_roots(self): + """ + This method generates all the positive roots of + A_n. This is half of all of the roots of A_n; + by multiplying all the positive roots by -1 we + get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n): + for j in range(i+1, n+1): + k += 1 + posroots[k] = self.basic_root(i, j) + return posroots + + def highest_root(self): + """ + Returns the highest weight root for A_n + """ + + return self.basic_root(0, self.n) + + def roots(self): + """ + Returns the total number of roots for A_n + """ + n = self.n + return n*(n+1) + + def cartan_matrix(self): + """ + Returns the Cartan matrix for A_n. + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('A4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -1], + [ 0, 0, -1, 2]]) + + """ + + n = self.n + m = 2 * eye(n) + for i in range(1, n - 1): + m[i, i+1] = -1 + m[i, i-1] = -1 + m[0,1] = -1 + m[n-1, n-2] = -1 + return m + + def basis(self): + """ + Returns the number of independent generators of A_n + """ + n = self.n + return n**2 - 1 + + def lie_algebra(self): + """ + Returns the Lie algebra associated with A_n + """ + n = self.n + return "su(" + str(n + 1) + ")" + + def dynkin_diagram(self): + n = self.n + diag = "---".join("0" for i in range(1, n+1)) + "\n" + diag += " ".join(str(i) for i in range(1, n+1)) + return diag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_b.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_b.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ee85502261f4702769067c64021521a2bc1725 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_b.py @@ -0,0 +1,170 @@ +from .cartan_type import Standard_Cartan +from sympy.core.backend import eye + +class TypeB(Standard_Cartan): + + def __new__(cls, n): + if n < 2: + raise ValueError("n cannot be less than 2") + return Standard_Cartan.__new__(cls, "B", n) + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("B3") + >>> c.dimension() + 3 + """ + + return self.n + + def basic_root(self, i, j): + """ + This is a method just to generate roots + with a 1 iin the ith position and a -1 + in the jth position. + + """ + root = [0]*self.n + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """ + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + In B_n the first n-1 simple roots are the same as the + roots in A_(n-1) (a 1 in the ith position, a -1 in + the (i+1)th position, and zeroes elsewhere). The n-th + simple root is the root with a 1 in the nth position + and zeroes elsewhere. + + This method returns the ith simple root for the B series. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("B3") + >>> c.simple_root(2) + [0, 1, -1] + + """ + n = self.n + if i < n: + return self.basic_root(i-1, i) + else: + root = [0]*self.n + root[n-1] = 1 + return root + + def positive_roots(self): + """ + This method generates all the positive roots of + A_n. This is half of all of the roots of B_n; + by multiplying all the positive roots by -1 we + get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n-1): + for j in range(i+1, n): + k += 1 + posroots[k] = self.basic_root(i, j) + k += 1 + root = self.basic_root(i, j) + root[j] = 1 + posroots[k] = root + + for i in range(0, n): + k += 1 + root = [0]*n + root[i] = 1 + posroots[k] = root + + return posroots + + def roots(self): + """ + Returns the total number of roots for B_n" + """ + + n = self.n + return 2*(n**2) + + def cartan_matrix(self): + """ + Returns the Cartan matrix for B_n. + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('B4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -2], + [ 0, 0, -1, 2]]) + + """ + + n = self.n + m = 2* eye(n) + for i in range(1, n - 1): + m[i, i+1] = -1 + m[i, i-1] = -1 + m[0, 1] = -1 + m[n-2, n-1] = -2 + m[n-1, n-2] = -1 + return m + + def basis(self): + """ + Returns the number of independent generators of B_n + """ + + n = self.n + return (n**2 - n)/2 + + def lie_algebra(self): + """ + Returns the Lie algebra associated with B_n + """ + + n = self.n + return "so(" + str(2*n) + ")" + + def dynkin_diagram(self): + n = self.n + diag = "---".join("0" for i in range(1, n)) + "=>=0\n" + diag += " ".join(str(i) for i in range(1, n+1)) + return diag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_c.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_c.py new file mode 100644 index 0000000000000000000000000000000000000000..615bb900b5ba9613fd02e43f476d34eef0d5d35c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_c.py @@ -0,0 +1,169 @@ +from .cartan_type import Standard_Cartan +from sympy.core.backend import eye + +class TypeC(Standard_Cartan): + + def __new__(cls, n): + if n < 3: + raise ValueError("n cannot be less than 3") + return Standard_Cartan.__new__(cls, "C", n) + + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("C3") + >>> c.dimension() + 3 + """ + n = self.n + return n + + def basic_root(self, i, j): + """Generate roots with 1 in ith position and a -1 in jth position + """ + n = self.n + root = [0]*n + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """The ith simple root for the C series + + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + In C_n, the first n-1 simple roots are the same as + the roots in A_(n-1) (a 1 in the ith position, a -1 + in the (i+1)th position, and zeroes elsewhere). The + nth simple root is the root in which there is a 2 in + the nth position and zeroes elsewhere. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("C3") + >>> c.simple_root(2) + [0, 1, -1] + + """ + + n = self.n + if i < n: + return self.basic_root(i-1,i) + else: + root = [0]*self.n + root[n-1] = 2 + return root + + + def positive_roots(self): + """Generates all the positive roots of A_n + + This is half of all of the roots of C_n; by multiplying all the + positive roots by -1 we get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n-1): + for j in range(i+1, n): + k += 1 + posroots[k] = self.basic_root(i, j) + k += 1 + root = self.basic_root(i, j) + root[j] = 1 + posroots[k] = root + + for i in range(0, n): + k += 1 + root = [0]*n + root[i] = 2 + posroots[k] = root + + return posroots + + def roots(self): + """ + Returns the total number of roots for C_n" + """ + + n = self.n + return 2*(n**2) + + def cartan_matrix(self): + """The Cartan matrix for C_n + + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('C4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -1], + [ 0, 0, -2, 2]]) + + """ + + n = self.n + m = 2 * eye(n) + for i in range(1, n - 1): + m[i, i+1] = -1 + m[i, i-1] = -1 + m[0,1] = -1 + m[n-1, n-2] = -2 + return m + + + def basis(self): + """ + Returns the number of independent generators of C_n + """ + + n = self.n + return n*(2*n + 1) + + def lie_algebra(self): + """ + Returns the Lie algebra associated with C_n" + """ + + n = self.n + return "sp(" + str(2*n) + ")" + + def dynkin_diagram(self): + n = self.n + diag = "---".join("0" for i in range(1, n)) + "=<=0\n" + diag += " ".join(str(i) for i in range(1, n+1)) + return diag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_d.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_d.py new file mode 100644 index 0000000000000000000000000000000000000000..9450d76e906c79e23db0ce223ed0de03d71c1199 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_d.py @@ -0,0 +1,173 @@ +from .cartan_type import Standard_Cartan +from sympy.core.backend import eye + +class TypeD(Standard_Cartan): + + def __new__(cls, n): + if n < 3: + raise ValueError("n cannot be less than 3") + return Standard_Cartan.__new__(cls, "D", n) + + + def dimension(self): + """Dmension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("D4") + >>> c.dimension() + 4 + """ + + return self.n + + def basic_root(self, i, j): + """ + This is a method just to generate roots + with a 1 iin the ith position and a -1 + in the jth position. + + """ + + n = self.n + root = [0]*n + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """ + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + In D_n, the first n-1 simple roots are the same as + the roots in A_(n-1) (a 1 in the ith position, a -1 + in the (i+1)th position, and zeroes elsewhere). + The nth simple root is the root in which there 1s in + the nth and (n-1)th positions, and zeroes elsewhere. + + This method returns the ith simple root for the D series. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("D4") + >>> c.simple_root(2) + [0, 1, -1, 0] + + """ + + n = self.n + if i < n: + return self.basic_root(i-1, i) + else: + root = [0]*n + root[n-2] = 1 + root[n-1] = 1 + return root + + + def positive_roots(self): + """ + This method generates all the positive roots of + A_n. This is half of all of the roots of D_n + by multiplying all the positive roots by -1 we + get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n-1): + for j in range(i+1, n): + k += 1 + posroots[k] = self.basic_root(i, j) + k += 1 + root = self.basic_root(i, j) + root[j] = 1 + posroots[k] = root + return posroots + + def roots(self): + """ + Returns the total number of roots for D_n" + """ + + n = self.n + return 2*n*(n-1) + + def cartan_matrix(self): + """ + Returns the Cartan matrix for D_n. + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('D4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, -1], + [ 0, -1, 2, 0], + [ 0, -1, 0, 2]]) + + """ + + n = self.n + m = 2*eye(n) + for i in range(1, n - 2): + m[i,i+1] = -1 + m[i,i-1] = -1 + m[n-2, n-3] = -1 + m[n-3, n-1] = -1 + m[n-1, n-3] = -1 + m[0, 1] = -1 + return m + + def basis(self): + """ + Returns the number of independent generators of D_n + """ + n = self.n + return n*(n-1)/2 + + def lie_algebra(self): + """ + Returns the Lie algebra associated with D_n" + """ + + n = self.n + return "so(" + str(2*n) + ")" + + def dynkin_diagram(self): + n = self.n + diag = " "*4*(n-3) + str(n-1) + "\n" + diag += " "*4*(n-3) + "0\n" + diag += " "*4*(n-3) +"|\n" + diag += " "*4*(n-3) + "|\n" + diag += "---".join("0" for i in range(1,n)) + "\n" + diag += " ".join(str(i) for i in range(1, n-1)) + " "+str(n) + return diag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_e.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_e.py new file mode 100644 index 0000000000000000000000000000000000000000..3db9a820d31bff31acc58ba1592a1b10f8be53db --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_e.py @@ -0,0 +1,275 @@ +import itertools + +from .cartan_type import Standard_Cartan +from sympy.core.backend import eye, Rational +from sympy.core.singleton import S + +class TypeE(Standard_Cartan): + + def __new__(cls, n): + if n < 6 or n > 8: + raise ValueError("Invalid value of n") + return Standard_Cartan.__new__(cls, "E", n) + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("E6") + >>> c.dimension() + 8 + """ + + return 8 + + def basic_root(self, i, j): + """ + This is a method just to generate roots + with a -1 in the ith position and a 1 + in the jth position. + + """ + + root = [0]*8 + root[i] = -1 + root[j] = 1 + return root + + def simple_root(self, i): + """ + Every Lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + This method returns the ith simple root for E_n. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("E6") + >>> c.simple_root(2) + [1, 1, 0, 0, 0, 0, 0, 0] + """ + n = self.n + if i == 1: + root = [-0.5]*8 + root[0] = 0.5 + root[7] = 0.5 + return root + elif i == 2: + root = [0]*8 + root[1] = 1 + root[0] = 1 + return root + else: + if i in (7, 8) and n == 6: + raise ValueError("E6 only has six simple roots!") + if i == 8 and n == 7: + raise ValueError("E7 only has seven simple roots!") + + return self.basic_root(i - 3, i - 2) + + def positive_roots(self): + """ + This method generates all the positive roots of + A_n. This is half of all of the roots of E_n; + by multiplying all the positive roots by -1 we + get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + """ + n = self.n + neghalf = Rational(-1, 2) + poshalf = S.Half + if n == 6: + posroots = {} + k = 0 + for i in range(n-1): + for j in range(i+1, n-1): + k += 1 + root = self.basic_root(i, j) + posroots[k] = root + k += 1 + root = self.basic_root(i, j) + root[i] = 1 + posroots[k] = root + + root = [poshalf, poshalf, poshalf, poshalf, poshalf, + neghalf, neghalf, poshalf] + for a, b, c, d, e in itertools.product( + range(2), range(2), range(2), range(2), range(2)): + if (a + b + c + d + e)%2 == 0: + k += 1 + if a == 1: + root[0] = neghalf + if b == 1: + root[1] = neghalf + if c == 1: + root[2] = neghalf + if d == 1: + root[3] = neghalf + if e == 1: + root[4] = neghalf + posroots[k] = root[:] + return posroots + if n == 7: + posroots = {} + k = 0 + for i in range(n-1): + for j in range(i+1, n-1): + k += 1 + root = self.basic_root(i, j) + posroots[k] = root + k += 1 + root = self.basic_root(i, j) + root[i] = 1 + posroots[k] = root + + k += 1 + posroots[k] = [0, 0, 0, 0, 0, 1, 1, 0] + root = [poshalf, poshalf, poshalf, poshalf, poshalf, + neghalf, neghalf, poshalf] + for a, b, c, d, e, f in itertools.product( + range(2), range(2), range(2), range(2), range(2), range(2)): + if (a + b + c + d + e + f)%2 == 0: + k += 1 + if a == 1: + root[0] = neghalf + if b == 1: + root[1] = neghalf + if c == 1: + root[2] = neghalf + if d == 1: + root[3] = neghalf + if e == 1: + root[4] = neghalf + if f == 1: + root[5] = poshalf + posroots[k] = root[:] + return posroots + if n == 8: + posroots = {} + k = 0 + for i in range(n): + for j in range(i+1, n): + k += 1 + root = self.basic_root(i, j) + posroots[k] = root + k += 1 + root = self.basic_root(i, j) + root[i] = 1 + posroots[k] = root + + root = [poshalf, poshalf, poshalf, poshalf, poshalf, + neghalf, neghalf, poshalf] + for a, b, c, d, e, f, g in itertools.product( + range(2), range(2), range(2), range(2), range(2), + range(2), range(2)): + if (a + b + c + d + e + f + g)%2 == 0: + k += 1 + if a == 1: + root[0] = neghalf + if b == 1: + root[1] = neghalf + if c == 1: + root[2] = neghalf + if d == 1: + root[3] = neghalf + if e == 1: + root[4] = neghalf + if f == 1: + root[5] = poshalf + if g == 1: + root[6] = poshalf + posroots[k] = root[:] + return posroots + + + + def roots(self): + """ + Returns the total number of roots of E_n + """ + + n = self.n + if n == 6: + return 72 + if n == 7: + return 126 + if n == 8: + return 240 + + + def cartan_matrix(self): + """ + Returns the Cartan matrix for G_2 + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('A4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -1], + [ 0, 0, -1, 2]]) + + + """ + + n = self.n + m = 2*eye(n) + for i in range(3, n - 1): + m[i, i+1] = -1 + m[i, i-1] = -1 + m[0, 2] = m[2, 0] = -1 + m[1, 3] = m[3, 1] = -1 + m[2, 3] = -1 + m[n-1, n-2] = -1 + return m + + + def basis(self): + """ + Returns the number of independent generators of E_n + """ + + n = self.n + if n == 6: + return 78 + if n == 7: + return 133 + if n == 8: + return 248 + + def dynkin_diagram(self): + n = self.n + diag = " "*8 + str(2) + "\n" + diag += " "*8 + "0\n" + diag += " "*8 + "|\n" + diag += " "*8 + "|\n" + diag += "---".join("0" for i in range(1, n)) + "\n" + diag += "1 " + " ".join(str(i) for i in range(3, n+1)) + return diag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_f.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_f.py new file mode 100644 index 0000000000000000000000000000000000000000..f04da557870f2cd21818cf69c454ef598e2ab65a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_f.py @@ -0,0 +1,162 @@ +from .cartan_type import Standard_Cartan +from sympy.core.backend import Matrix, Rational + + +class TypeF(Standard_Cartan): + + def __new__(cls, n): + if n != 4: + raise ValueError("n should be 4") + return Standard_Cartan.__new__(cls, "F", 4) + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("F4") + >>> c.dimension() + 4 + """ + + return 4 + + + def basic_root(self, i, j): + """Generate roots with 1 in ith position and -1 in jth position + + """ + + n = self.n + root = [0]*n + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """The ith simple root of F_4 + + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("F4") + >>> c.simple_root(3) + [0, 0, 0, 1] + + """ + + if i < 3: + return self.basic_root(i-1, i) + if i == 3: + root = [0]*4 + root[3] = 1 + return root + if i == 4: + root = [Rational(-1, 2)]*4 + return root + + def positive_roots(self): + """Generate all the positive roots of A_n + + This is half of all of the roots of F_4; by multiplying all the + positive roots by -1 we get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n-1): + for j in range(i+1, n): + k += 1 + posroots[k] = self.basic_root(i, j) + k += 1 + root = self.basic_root(i, j) + root[j] = 1 + posroots[k] = root + + for i in range(0, n): + k += 1 + root = [0]*n + root[i] = 1 + posroots[k] = root + + k += 1 + root = [Rational(1, 2)]*n + posroots[k] = root + for i in range(1, 4): + k += 1 + root = [Rational(1, 2)]*n + root[i] = Rational(-1, 2) + posroots[k] = root + + posroots[k+1] = [Rational(1, 2), Rational(1, 2), Rational(-1, 2), Rational(-1, 2)] + posroots[k+2] = [Rational(1, 2), Rational(-1, 2), Rational(1, 2), Rational(-1, 2)] + posroots[k+3] = [Rational(1, 2), Rational(-1, 2), Rational(-1, 2), Rational(1, 2)] + posroots[k+4] = [Rational(1, 2), Rational(-1, 2), Rational(-1, 2), Rational(-1, 2)] + + return posroots + + + def roots(self): + """ + Returns the total number of roots for F_4 + """ + return 48 + + def cartan_matrix(self): + """The Cartan matrix for F_4 + + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('A4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -1], + [ 0, 0, -1, 2]]) + """ + + m = Matrix( 4, 4, [2, -1, 0, 0, -1, 2, -2, 0, 0, + -1, 2, -1, 0, 0, -1, 2]) + return m + + def basis(self): + """ + Returns the number of independent generators of F_4 + """ + return 52 + + def dynkin_diagram(self): + diag = "0---0=>=0---0\n" + diag += " ".join(str(i) for i in range(1, 5)) + return diag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_g.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_g.py new file mode 100644 index 0000000000000000000000000000000000000000..014409cf5ed966b53c596b14e0073e89ceee05b6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/type_g.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- + +from .cartan_type import Standard_Cartan +from sympy.core.backend import Matrix + +class TypeG(Standard_Cartan): + + def __new__(cls, n): + if n != 2: + raise ValueError("n should be 2") + return Standard_Cartan.__new__(cls, "G", 2) + + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("G2") + >>> c.dimension() + 3 + """ + return 3 + + def simple_root(self, i): + """The ith simple root of G_2 + + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("G2") + >>> c.simple_root(1) + [0, 1, -1] + + """ + if i == 1: + return [0, 1, -1] + else: + return [1, -2, 1] + + def positive_roots(self): + """Generate all the positive roots of A_n + + This is half of all of the roots of A_n; by multiplying all the + positive roots by -1 we get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + + """ + + roots = {1: [0, 1, -1], 2: [1, -2, 1], 3: [1, -1, 0], 4: [1, 0, 1], + 5: [1, 1, -2], 6: [2, -1, -1]} + return roots + + def roots(self): + """ + Returns the total number of roots of G_2" + """ + return 12 + + def cartan_matrix(self): + """The Cartan matrix for G_2 + + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("G2") + >>> c.cartan_matrix() + Matrix([ + [ 2, -1], + [-3, 2]]) + + """ + + m = Matrix( 2, 2, [2, -1, -3, 2]) + return m + + def basis(self): + """ + Returns the number of independent generators of G_2 + """ + return 14 + + def dynkin_diagram(self): + diag = "0≡<≡0\n1 2" + return diag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/weyl_group.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/weyl_group.py new file mode 100644 index 0000000000000000000000000000000000000000..15ff70b6f1fc4649268a38ee13e1f717a1c9f5fa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/liealgebras/weyl_group.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- + +from .cartan_type import CartanType +from mpmath import fac +from sympy.core.backend import Matrix, eye, Rational, igcd +from sympy.core.basic import Atom + +class WeylGroup(Atom): + + """ + For each semisimple Lie group, we have a Weyl group. It is a subgroup of + the isometry group of the root system. Specifically, it's the subgroup + that is generated by reflections through the hyperplanes orthogonal to + the roots. Therefore, Weyl groups are reflection groups, and so a Weyl + group is a finite Coxeter group. + + """ + + def __new__(cls, cartantype): + obj = Atom.__new__(cls) + obj.cartan_type = CartanType(cartantype) + return obj + + def generators(self): + """ + This method creates the generating reflections of the Weyl group for + a given Lie algebra. For a Lie algebra of rank n, there are n + different generating reflections. This function returns them as + a list. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> c = WeylGroup("F4") + >>> c.generators() + ['r1', 'r2', 'r3', 'r4'] + """ + n = self.cartan_type.rank() + generators = [] + for i in range(1, n+1): + reflection = "r"+str(i) + generators.append(reflection) + return generators + + def group_order(self): + """ + This method returns the order of the Weyl group. + For types A, B, C, D, and E the order depends on + the rank of the Lie algebra. For types F and G, + the order is fixed. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> c = WeylGroup("D4") + >>> c.group_order() + 192.0 + """ + n = self.cartan_type.rank() + if self.cartan_type.series == "A": + return fac(n+1) + + if self.cartan_type.series in ("B", "C"): + return fac(n)*(2**n) + + if self.cartan_type.series == "D": + return fac(n)*(2**(n-1)) + + if self.cartan_type.series == "E": + if n == 6: + return 51840 + if n == 7: + return 2903040 + if n == 8: + return 696729600 + if self.cartan_type.series == "F": + return 1152 + + if self.cartan_type.series == "G": + return 12 + + def group_name(self): + """ + This method returns some general information about the Weyl group for + a given Lie algebra. It returns the name of the group and the elements + it acts on, if relevant. + """ + n = self.cartan_type.rank() + if self.cartan_type.series == "A": + return "S"+str(n+1) + ": the symmetric group acting on " + str(n+1) + " elements." + + if self.cartan_type.series in ("B", "C"): + return "The hyperoctahedral group acting on " + str(2*n) + " elements." + + if self.cartan_type.series == "D": + return "The symmetry group of the " + str(n) + "-dimensional demihypercube." + + if self.cartan_type.series == "E": + if n == 6: + return "The symmetry group of the 6-polytope." + + if n == 7: + return "The symmetry group of the 7-polytope." + + if n == 8: + return "The symmetry group of the 8-polytope." + + if self.cartan_type.series == "F": + return "The symmetry group of the 24-cell, or icositetrachoron." + + if self.cartan_type.series == "G": + return "D6, the dihedral group of order 12, and symmetry group of the hexagon." + + def element_order(self, weylelt): + """ + This method returns the order of a given Weyl group element, which should + be specified by the user in the form of products of the generating + reflections, i.e. of the form r1*r2 etc. + + For types A-F, this method current works by taking the matrix form of + the specified element, and then finding what power of the matrix is the + identity. It then returns this power. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> b = WeylGroup("B4") + >>> b.element_order('r1*r4*r2') + 4 + """ + n = self.cartan_type.rank() + if self.cartan_type.series == "A": + a = self.matrix_form(weylelt) + order = 1 + while a != eye(n+1): + a *= self.matrix_form(weylelt) + order += 1 + return order + + if self.cartan_type.series == "D": + a = self.matrix_form(weylelt) + order = 1 + while a != eye(n): + a *= self.matrix_form(weylelt) + order += 1 + return order + + if self.cartan_type.series == "E": + a = self.matrix_form(weylelt) + order = 1 + while a != eye(8): + a *= self.matrix_form(weylelt) + order += 1 + return order + + if self.cartan_type.series == "G": + elts = list(weylelt) + reflections = elts[1::3] + m = self.delete_doubles(reflections) + while self.delete_doubles(m) != m: + m = self.delete_doubles(m) + reflections = m + if len(reflections) % 2 == 1: + return 2 + + elif len(reflections) == 0: + return 1 + + else: + if len(reflections) == 1: + return 2 + else: + m = len(reflections) // 2 + lcm = (6 * m)/ igcd(m, 6) + order = lcm / m + return order + + + if self.cartan_type.series == 'F': + a = self.matrix_form(weylelt) + order = 1 + while a != eye(4): + a *= self.matrix_form(weylelt) + order += 1 + return order + + + if self.cartan_type.series in ("B", "C"): + a = self.matrix_form(weylelt) + order = 1 + while a != eye(n): + a *= self.matrix_form(weylelt) + order += 1 + return order + + def delete_doubles(self, reflections): + """ + This is a helper method for determining the order of an element in the + Weyl group of G2. It takes a Weyl element and if repeated simple reflections + in it, it deletes them. + """ + counter = 0 + copy = list(reflections) + for elt in copy: + if counter < len(copy)-1: + if copy[counter + 1] == elt: + del copy[counter] + del copy[counter] + counter += 1 + + + return copy + + + def matrix_form(self, weylelt): + """ + This method takes input from the user in the form of products of the + generating reflections, and returns the matrix corresponding to the + element of the Weyl group. Since each element of the Weyl group is + a reflection of some type, there is a corresponding matrix representation. + This method uses the standard representation for all the generating + reflections. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> f = WeylGroup("F4") + >>> f.matrix_form('r2*r3') + Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, -1], + [0, 0, 1, 0]]) + + """ + elts = list(weylelt) + reflections = elts[1::3] + n = self.cartan_type.rank() + if self.cartan_type.series == 'A': + matrixform = eye(n+1) + for elt in reflections: + a = int(elt) + mat = eye(n+1) + mat[a-1, a-1] = 0 + mat[a-1, a] = 1 + mat[a, a-1] = 1 + mat[a, a] = 0 + matrixform *= mat + return matrixform + + if self.cartan_type.series == 'D': + matrixform = eye(n) + for elt in reflections: + a = int(elt) + mat = eye(n) + if a < n: + mat[a-1, a-1] = 0 + mat[a-1, a] = 1 + mat[a, a-1] = 1 + mat[a, a] = 0 + matrixform *= mat + else: + mat[n-2, n-1] = -1 + mat[n-2, n-2] = 0 + mat[n-1, n-2] = -1 + mat[n-1, n-1] = 0 + matrixform *= mat + return matrixform + + if self.cartan_type.series == 'G': + matrixform = eye(3) + for elt in reflections: + a = int(elt) + if a == 1: + gen1 = Matrix([[1, 0, 0], [0, 0, 1], [0, 1, 0]]) + matrixform *= gen1 + else: + gen2 = Matrix([[Rational(2, 3), Rational(2, 3), Rational(-1, 3)], + [Rational(2, 3), Rational(-1, 3), Rational(2, 3)], + [Rational(-1, 3), Rational(2, 3), Rational(2, 3)]]) + matrixform *= gen2 + return matrixform + + if self.cartan_type.series == 'F': + matrixform = eye(4) + for elt in reflections: + a = int(elt) + if a == 1: + mat = Matrix([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) + matrixform *= mat + elif a == 2: + mat = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + matrixform *= mat + elif a == 3: + mat = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]]) + matrixform *= mat + else: + + mat = Matrix([[Rational(1, 2), Rational(1, 2), Rational(1, 2), Rational(1, 2)], + [Rational(1, 2), Rational(1, 2), Rational(-1, 2), Rational(-1, 2)], + [Rational(1, 2), Rational(-1, 2), Rational(1, 2), Rational(-1, 2)], + [Rational(1, 2), Rational(-1, 2), Rational(-1, 2), Rational(1, 2)]]) + matrixform *= mat + return matrixform + + if self.cartan_type.series == 'E': + matrixform = eye(8) + for elt in reflections: + a = int(elt) + if a == 1: + mat = Matrix([[Rational(3, 4), Rational(1, 4), Rational(1, 4), Rational(1, 4), + Rational(1, 4), Rational(1, 4), Rational(1, 4), Rational(-1, 4)], + [Rational(1, 4), Rational(3, 4), Rational(-1, 4), Rational(-1, 4), + Rational(-1, 4), Rational(-1, 4), Rational(1, 4), Rational(-1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(3, 4), Rational(-1, 4), + Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(3, 4), + Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), + Rational(3, 4), Rational(-1, 4), Rational(-1, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), + Rational(-1, 4), Rational(3, 4), Rational(-1, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), + Rational(-1, 4), Rational(-1, 4), Rational(-3, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), + Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), Rational(3, 4)]]) + matrixform *= mat + elif a == 2: + mat = eye(8) + mat[0, 0] = 0 + mat[0, 1] = -1 + mat[1, 0] = -1 + mat[1, 1] = 0 + matrixform *= mat + else: + mat = eye(8) + mat[a-3, a-3] = 0 + mat[a-3, a-2] = 1 + mat[a-2, a-3] = 1 + mat[a-2, a-2] = 0 + matrixform *= mat + return matrixform + + + if self.cartan_type.series in ("B", "C"): + matrixform = eye(n) + for elt in reflections: + a = int(elt) + mat = eye(n) + if a == 1: + mat[0, 0] = -1 + matrixform *= mat + else: + mat[a - 2, a - 2] = 0 + mat[a-2, a-1] = 1 + mat[a - 1, a - 2] = 1 + mat[a -1, a - 1] = 0 + matrixform *= mat + return matrixform + + + + def coxeter_diagram(self): + """ + This method returns the Coxeter diagram corresponding to a Weyl group. + The Coxeter diagram can be obtained from a Lie algebra's Dynkin diagram + by deleting all arrows; the Coxeter diagram is the undirected graph. + The vertices of the Coxeter diagram represent the generating reflections + of the Weyl group, $s_i$. An edge is drawn between $s_i$ and $s_j$ if the order + $m(i, j)$ of $s_is_j$ is greater than two. If there is one edge, the order + $m(i, j)$ is 3. If there are two edges, the order $m(i, j)$ is 4, and if there + are three edges, the order $m(i, j)$ is 6. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> c = WeylGroup("B3") + >>> print(c.coxeter_diagram()) + 0---0===0 + 1 2 3 + """ + n = self.cartan_type.rank() + if self.cartan_type.series in ("A", "D", "E"): + return self.cartan_type.dynkin_diagram() + + if self.cartan_type.series in ("B", "C"): + diag = "---".join("0" for i in range(1, n)) + "===0\n" + diag += " ".join(str(i) for i in range(1, n+1)) + return diag + + if self.cartan_type.series == "F": + diag = "0---0===0---0\n" + diag += " ".join(str(i) for i in range(1, 5)) + return diag + + if self.cartan_type.series == "G": + diag = "0≡≡≡0\n1 2" + return diag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb26903a384e9df3a0f02a92c488c5442cee1486 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__init__.py @@ -0,0 +1,12 @@ +from .boolalg import (to_cnf, to_dnf, to_nnf, And, Or, Not, Xor, Nand, Nor, Implies, + Equivalent, ITE, POSform, SOPform, simplify_logic, bool_map, true, false, + gateinputcount) +from .inference import satisfiable + +__all__ = [ + 'to_cnf', 'to_dnf', 'to_nnf', 'And', 'Or', 'Not', 'Xor', 'Nand', 'Nor', + 'Implies', 'Equivalent', 'ITE', 'POSform', 'SOPform', 'simplify_logic', + 'bool_map', 'true', 'false', 'gateinputcount', + + 'satisfiable', +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29405a332fbcd08028051092398178eb3414f15e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__pycache__/inference.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__pycache__/inference.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..535b8da455539b3c2e31e86cfca8aab4e5112e16 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/__pycache__/inference.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53522a7591d80105f84c9a2a2fd6a25c15f324ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/dpll.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/dpll.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..610cc13a28d9cf4e848e51e3b0bea20c14bd69e2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/dpll.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/dpll2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/dpll2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0520157d493400a71990a8d6dc6037c13af25697 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/dpll2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/lra_theory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/lra_theory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bb7b0bb636fc826702fda84bb36ea9386034ac3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/lra_theory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/minisat22_wrapper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/minisat22_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa7b7b7dae9672cc9f80d73fa730e27b18df3650 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/minisat22_wrapper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/pycosat_wrapper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/pycosat_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81db1d70f828bc15ef530155ceb9c5eb81cd940a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/pycosat_wrapper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/z3_wrapper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/z3_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0deff2b5bd94fe02457408b0939a1ad0538c184a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/__pycache__/z3_wrapper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/dpll.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/dpll.py new file mode 100644 index 0000000000000000000000000000000000000000..40e6802f7626c982a9a6cd7146baea3ac6b8b6e0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/dpll.py @@ -0,0 +1,308 @@ +"""Implementation of DPLL algorithm + +Further improvements: eliminate calls to pl_true, implement branching rules, +efficient unit propagation. + +References: + - https://en.wikipedia.org/wiki/DPLL_algorithm + - https://www.researchgate.net/publication/242384772_Implementations_of_the_DPLL_Algorithm +""" + +from sympy.core.sorting import default_sort_key +from sympy.logic.boolalg import Or, Not, conjuncts, disjuncts, to_cnf, \ + to_int_repr, _find_predicates +from sympy.assumptions.cnf import CNF +from sympy.logic.inference import pl_true, literal_symbol + + +def dpll_satisfiable(expr): + """ + Check satisfiability of a propositional sentence. + It returns a model rather than True when it succeeds + + >>> from sympy.abc import A, B + >>> from sympy.logic.algorithms.dpll import dpll_satisfiable + >>> dpll_satisfiable(A & ~B) + {A: True, B: False} + >>> dpll_satisfiable(A & ~A) + False + + """ + if not isinstance(expr, CNF): + clauses = conjuncts(to_cnf(expr)) + else: + clauses = expr.clauses + if False in clauses: + return False + symbols = sorted(_find_predicates(expr), key=default_sort_key) + symbols_int_repr = set(range(1, len(symbols) + 1)) + clauses_int_repr = to_int_repr(clauses, symbols) + result = dpll_int_repr(clauses_int_repr, symbols_int_repr, {}) + if not result: + return result + output = {} + for key in result: + output.update({symbols[key - 1]: result[key]}) + return output + + +def dpll(clauses, symbols, model): + """ + Compute satisfiability in a partial model. + Clauses is an array of conjuncts. + + >>> from sympy.abc import A, B, D + >>> from sympy.logic.algorithms.dpll import dpll + >>> dpll([A, B, D], [A, B], {D: False}) + False + + """ + # compute DP kernel + P, value = find_unit_clause(clauses, model) + while P: + model.update({P: value}) + symbols.remove(P) + if not value: + P = ~P + clauses = unit_propagate(clauses, P) + P, value = find_unit_clause(clauses, model) + P, value = find_pure_symbol(symbols, clauses) + while P: + model.update({P: value}) + symbols.remove(P) + if not value: + P = ~P + clauses = unit_propagate(clauses, P) + P, value = find_pure_symbol(symbols, clauses) + # end DP kernel + unknown_clauses = [] + for c in clauses: + val = pl_true(c, model) + if val is False: + return False + if val is not True: + unknown_clauses.append(c) + if not unknown_clauses: + return model + if not clauses: + return model + P = symbols.pop() + model_copy = model.copy() + model.update({P: True}) + model_copy.update({P: False}) + symbols_copy = symbols[:] + return (dpll(unit_propagate(unknown_clauses, P), symbols, model) or + dpll(unit_propagate(unknown_clauses, Not(P)), symbols_copy, model_copy)) + + +def dpll_int_repr(clauses, symbols, model): + """ + Compute satisfiability in a partial model. + Arguments are expected to be in integer representation + + >>> from sympy.logic.algorithms.dpll import dpll_int_repr + >>> dpll_int_repr([{1}, {2}, {3}], {1, 2}, {3: False}) + False + + """ + # compute DP kernel + P, value = find_unit_clause_int_repr(clauses, model) + while P: + model.update({P: value}) + symbols.remove(P) + if not value: + P = -P + clauses = unit_propagate_int_repr(clauses, P) + P, value = find_unit_clause_int_repr(clauses, model) + P, value = find_pure_symbol_int_repr(symbols, clauses) + while P: + model.update({P: value}) + symbols.remove(P) + if not value: + P = -P + clauses = unit_propagate_int_repr(clauses, P) + P, value = find_pure_symbol_int_repr(symbols, clauses) + # end DP kernel + unknown_clauses = [] + for c in clauses: + val = pl_true_int_repr(c, model) + if val is False: + return False + if val is not True: + unknown_clauses.append(c) + if not unknown_clauses: + return model + P = symbols.pop() + model_copy = model.copy() + model.update({P: True}) + model_copy.update({P: False}) + symbols_copy = symbols.copy() + return (dpll_int_repr(unit_propagate_int_repr(unknown_clauses, P), symbols, model) or + dpll_int_repr(unit_propagate_int_repr(unknown_clauses, -P), symbols_copy, model_copy)) + +### helper methods for DPLL + + +def pl_true_int_repr(clause, model={}): + """ + Lightweight version of pl_true. + Argument clause represents the set of args of an Or clause. This is used + inside dpll_int_repr, it is not meant to be used directly. + + >>> from sympy.logic.algorithms.dpll import pl_true_int_repr + >>> pl_true_int_repr({1, 2}, {1: False}) + >>> pl_true_int_repr({1, 2}, {1: False, 2: False}) + False + + """ + result = False + for lit in clause: + if lit < 0: + p = model.get(-lit) + if p is not None: + p = not p + else: + p = model.get(lit) + if p is True: + return True + elif p is None: + result = None + return result + + +def unit_propagate(clauses, symbol): + """ + Returns an equivalent set of clauses + If a set of clauses contains the unit clause l, the other clauses are + simplified by the application of the two following rules: + + 1. every clause containing l is removed + 2. in every clause that contains ~l this literal is deleted + + Arguments are expected to be in CNF. + + >>> from sympy.abc import A, B, D + >>> from sympy.logic.algorithms.dpll import unit_propagate + >>> unit_propagate([A | B, D | ~B, B], B) + [D, B] + + """ + output = [] + for c in clauses: + if c.func != Or: + output.append(c) + continue + for arg in c.args: + if arg == ~symbol: + output.append(Or(*[x for x in c.args if x != ~symbol])) + break + if arg == symbol: + break + else: + output.append(c) + return output + + +def unit_propagate_int_repr(clauses, s): + """ + Same as unit_propagate, but arguments are expected to be in integer + representation + + >>> from sympy.logic.algorithms.dpll import unit_propagate_int_repr + >>> unit_propagate_int_repr([{1, 2}, {3, -2}, {2}], 2) + [{3}] + + """ + negated = {-s} + return [clause - negated for clause in clauses if s not in clause] + + +def find_pure_symbol(symbols, unknown_clauses): + """ + Find a symbol and its value if it appears only as a positive literal + (or only as a negative) in clauses. + + >>> from sympy.abc import A, B, D + >>> from sympy.logic.algorithms.dpll import find_pure_symbol + >>> find_pure_symbol([A, B, D], [A|~B,~B|~D,D|A]) + (A, True) + + """ + for sym in symbols: + found_pos, found_neg = False, False + for c in unknown_clauses: + if not found_pos and sym in disjuncts(c): + found_pos = True + if not found_neg and Not(sym) in disjuncts(c): + found_neg = True + if found_pos != found_neg: + return sym, found_pos + return None, None + + +def find_pure_symbol_int_repr(symbols, unknown_clauses): + """ + Same as find_pure_symbol, but arguments are expected + to be in integer representation + + >>> from sympy.logic.algorithms.dpll import find_pure_symbol_int_repr + >>> find_pure_symbol_int_repr({1,2,3}, + ... [{1, -2}, {-2, -3}, {3, 1}]) + (1, True) + + """ + all_symbols = set().union(*unknown_clauses) + found_pos = all_symbols.intersection(symbols) + found_neg = all_symbols.intersection([-s for s in symbols]) + for p in found_pos: + if -p not in found_neg: + return p, True + for p in found_neg: + if -p not in found_pos: + return -p, False + return None, None + + +def find_unit_clause(clauses, model): + """ + A unit clause has only 1 variable that is not bound in the model. + + >>> from sympy.abc import A, B, D + >>> from sympy.logic.algorithms.dpll import find_unit_clause + >>> find_unit_clause([A | B | D, B | ~D, A | ~B], {A:True}) + (B, False) + + """ + for clause in clauses: + num_not_in_model = 0 + for literal in disjuncts(clause): + sym = literal_symbol(literal) + if sym not in model: + num_not_in_model += 1 + P, value = sym, not isinstance(literal, Not) + if num_not_in_model == 1: + return P, value + return None, None + + +def find_unit_clause_int_repr(clauses, model): + """ + Same as find_unit_clause, but arguments are expected to be in + integer representation. + + >>> from sympy.logic.algorithms.dpll import find_unit_clause_int_repr + >>> find_unit_clause_int_repr([{1, 2, 3}, + ... {2, -3}, {1, -2}], {1: True}) + (2, False) + + """ + bound = set(model) | {-sym for sym in model} + for clause in clauses: + unbound = clause - bound + if len(unbound) == 1: + p = unbound.pop() + if p < 0: + return -p, False + else: + return p, True + return None, None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/dpll2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/dpll2.py new file mode 100644 index 0000000000000000000000000000000000000000..4f18c81189d6be565dc9b7caa3f0bf48e978bb56 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/dpll2.py @@ -0,0 +1,688 @@ +"""Implementation of DPLL algorithm + +Features: + - Clause learning + - Watch literal scheme + - VSIDS heuristic + +References: + - https://en.wikipedia.org/wiki/DPLL_algorithm +""" + +from collections import defaultdict +from heapq import heappush, heappop + +from sympy.core.sorting import ordered +from sympy.assumptions.cnf import EncodedCNF + +from sympy.logic.algorithms.lra_theory import LRASolver + + +def dpll_satisfiable(expr, all_models=False, use_lra_theory=False): + """ + Check satisfiability of a propositional sentence. + It returns a model rather than True when it succeeds. + Returns a generator of all models if all_models is True. + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.algorithms.dpll2 import dpll_satisfiable + >>> dpll_satisfiable(A & ~B) + {A: True, B: False} + >>> dpll_satisfiable(A & ~A) + False + + """ + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + if use_lra_theory: + lra, immediate_conflicts = LRASolver.from_encoded_cnf(expr) + else: + lra = None + immediate_conflicts = [] + solver = SATSolver(expr.data + immediate_conflicts, expr.variables, set(), expr.symbols, lra_theory=lra) + models = solver._find_model() + + if all_models: + return _all_models(models) + + try: + return next(models) + except StopIteration: + return False + + # Uncomment to confirm the solution is valid (hitting set for the clauses) + #else: + #for cls in clauses_int_repr: + #assert solver.var_settings.intersection(cls) + + +def _all_models(models): + satisfiable = False + try: + while True: + yield next(models) + satisfiable = True + except StopIteration: + if not satisfiable: + yield False + + +class SATSolver: + """ + Class for representing a SAT solver capable of + finding a model to a boolean theory in conjunctive + normal form. + """ + + def __init__(self, clauses, variables, var_settings, symbols=None, + heuristic='vsids', clause_learning='none', INTERVAL=500, + lra_theory = None): + + self.var_settings = var_settings + self.heuristic = heuristic + self.is_unsatisfied = False + self._unit_prop_queue = [] + self.update_functions = [] + self.INTERVAL = INTERVAL + + if symbols is None: + self.symbols = list(ordered(variables)) + else: + self.symbols = symbols + + self._initialize_variables(variables) + self._initialize_clauses(clauses) + + if 'vsids' == heuristic: + self._vsids_init() + self.heur_calculate = self._vsids_calculate + self.heur_lit_assigned = self._vsids_lit_assigned + self.heur_lit_unset = self._vsids_lit_unset + self.heur_clause_added = self._vsids_clause_added + + # Note: Uncomment this if/when clause learning is enabled + #self.update_functions.append(self._vsids_decay) + + else: + raise NotImplementedError + + if 'simple' == clause_learning: + self.add_learned_clause = self._simple_add_learned_clause + self.compute_conflict = self._simple_compute_conflict + self.update_functions.append(self._simple_clean_clauses) + elif 'none' == clause_learning: + self.add_learned_clause = lambda x: None + self.compute_conflict = lambda: None + else: + raise NotImplementedError + + # Create the base level + self.levels = [Level(0)] + self._current_level.varsettings = var_settings + + # Keep stats + self.num_decisions = 0 + self.num_learned_clauses = 0 + self.original_num_clauses = len(self.clauses) + + self.lra = lra_theory + + def _initialize_variables(self, variables): + """Set up the variable data structures needed.""" + self.sentinels = defaultdict(set) + self.occurrence_count = defaultdict(int) + self.variable_set = [False] * (len(variables) + 1) + + def _initialize_clauses(self, clauses): + """Set up the clause data structures needed. + + For each clause, the following changes are made: + - Unit clauses are queued for propagation right away. + - Non-unit clauses have their first and last literals set as sentinels. + - The number of clauses a literal appears in is computed. + """ + self.clauses = [list(clause) for clause in clauses] + + for i, clause in enumerate(self.clauses): + + # Handle the unit clauses + if 1 == len(clause): + self._unit_prop_queue.append(clause[0]) + continue + + self.sentinels[clause[0]].add(i) + self.sentinels[clause[-1]].add(i) + + for lit in clause: + self.occurrence_count[lit] += 1 + + def _find_model(self): + """ + Main DPLL loop. Returns a generator of models. + + Variables are chosen successively, and assigned to be either + True or False. If a solution is not found with this setting, + the opposite is chosen and the search continues. The solver + halts when every variable has a setting. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> list(l._find_model()) + [{1: True, 2: False, 3: False}, {1: True, 2: True, 3: True}] + + >>> from sympy.abc import A, B, C + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set(), [A, B, C]) + >>> list(l._find_model()) + [{A: True, B: False, C: False}, {A: True, B: True, C: True}] + + """ + + # We use this variable to keep track of if we should flip a + # variable setting in successive rounds + flip_var = False + + # Check if unit prop says the theory is unsat right off the bat + self._simplify() + if self.is_unsatisfied: + return + + # While the theory still has clauses remaining + while True: + # Perform cleanup / fixup at regular intervals + if self.num_decisions % self.INTERVAL == 0: + for func in self.update_functions: + func() + + if flip_var: + # We have just backtracked and we are trying to opposite literal + flip_var = False + lit = self._current_level.decision + + else: + # Pick a literal to set + lit = self.heur_calculate() + self.num_decisions += 1 + + # Stopping condition for a satisfying theory + if 0 == lit: + + # check if assignment satisfies lra theory + if self.lra: + for enc_var in self.var_settings: + res = self.lra.assert_lit(enc_var) + if res is not None: + break + res = self.lra.check() + self.lra.reset_bounds() + else: + res = None + if res is None or res[0]: + yield {self.symbols[abs(lit) - 1]: + lit > 0 for lit in self.var_settings} + else: + self._simple_add_learned_clause(res[1]) + + # backtrack until we unassign one of the literals causing the conflict + while not any(-lit in res[1] for lit in self._current_level.var_settings): + self._undo() + + while self._current_level.flipped: + self._undo() + if len(self.levels) == 1: + return + flip_lit = -self._current_level.decision + self._undo() + self.levels.append(Level(flip_lit, flipped=True)) + flip_var = True + continue + + # Start the new decision level + self.levels.append(Level(lit)) + + # Assign the literal, updating the clauses it satisfies + self._assign_literal(lit) + + # _simplify the theory + self._simplify() + + # Check if we've made the theory unsat + if self.is_unsatisfied: + + self.is_unsatisfied = False + + # We unroll all of the decisions until we can flip a literal + while self._current_level.flipped: + self._undo() + + # If we've unrolled all the way, the theory is unsat + if 1 == len(self.levels): + return + + # Detect and add a learned clause + self.add_learned_clause(self.compute_conflict()) + + # Try the opposite setting of the most recent decision + flip_lit = -self._current_level.decision + self._undo() + self.levels.append(Level(flip_lit, flipped=True)) + flip_var = True + + ######################## + # Helper Methods # + ######################## + @property + def _current_level(self): + """The current decision level data structure + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{1}, {2}], {1, 2}, set()) + >>> next(l._find_model()) + {1: True, 2: True} + >>> l._current_level.decision + 0 + >>> l._current_level.flipped + False + >>> l._current_level.var_settings + {1, 2} + + """ + return self.levels[-1] + + def _clause_sat(self, cls): + """Check if a clause is satisfied by the current variable setting. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{1}, {-1}], {1}, set()) + >>> try: + ... next(l._find_model()) + ... except StopIteration: + ... pass + >>> l._clause_sat(0) + False + >>> l._clause_sat(1) + True + + """ + for lit in self.clauses[cls]: + if lit in self.var_settings: + return True + return False + + def _is_sentinel(self, lit, cls): + """Check if a literal is a sentinel of a given clause. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l._is_sentinel(2, 3) + True + >>> l._is_sentinel(-3, 1) + False + + """ + return cls in self.sentinels[lit] + + def _assign_literal(self, lit): + """Make a literal assignment. + + The literal assignment must be recorded as part of the current + decision level. Additionally, if the literal is marked as a + sentinel of any clause, then a new sentinel must be chosen. If + this is not possible, then unit propagation is triggered and + another literal is added to the queue to be set in the future. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l.var_settings + {-3, -2, 1} + + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l._assign_literal(-1) + >>> try: + ... next(l._find_model()) + ... except StopIteration: + ... pass + >>> l.var_settings + {-1} + + """ + self.var_settings.add(lit) + self._current_level.var_settings.add(lit) + self.variable_set[abs(lit)] = True + self.heur_lit_assigned(lit) + + sentinel_list = list(self.sentinels[-lit]) + + for cls in sentinel_list: + if not self._clause_sat(cls): + other_sentinel = None + for newlit in self.clauses[cls]: + if newlit != -lit: + if self._is_sentinel(newlit, cls): + other_sentinel = newlit + elif not self.variable_set[abs(newlit)]: + self.sentinels[-lit].remove(cls) + self.sentinels[newlit].add(cls) + other_sentinel = None + break + + # Check if no sentinel update exists + if other_sentinel: + self._unit_prop_queue.append(other_sentinel) + + def _undo(self): + """ + _undo the changes of the most recent decision level. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> level = l._current_level + >>> level.decision, level.var_settings, level.flipped + (-3, {-3, -2}, False) + >>> l._undo() + >>> level = l._current_level + >>> level.decision, level.var_settings, level.flipped + (0, {1}, False) + + """ + # Undo the variable settings + for lit in self._current_level.var_settings: + self.var_settings.remove(lit) + self.heur_lit_unset(lit) + self.variable_set[abs(lit)] = False + + # Pop the level off the stack + self.levels.pop() + + ######################### + # Propagation # + ######################### + """ + Propagation methods should attempt to soundly simplify the boolean + theory, and return True if any simplification occurred and False + otherwise. + """ + def _simplify(self): + """Iterate over the various forms of propagation to simplify the theory. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l.variable_set + [False, False, False, False] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}} + + >>> l._simplify() + + >>> l.variable_set + [False, True, False, False] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, -1: set(), 2: {0, 3}, + ...3: {2, 4}} + + """ + changed = True + while changed: + changed = False + changed |= self._unit_prop() + changed |= self._pure_literal() + + def _unit_prop(self): + """Perform unit propagation on the current theory.""" + result = len(self._unit_prop_queue) > 0 + while self._unit_prop_queue: + next_lit = self._unit_prop_queue.pop() + if -next_lit in self.var_settings: + self.is_unsatisfied = True + self._unit_prop_queue = [] + return False + else: + self._assign_literal(next_lit) + + return result + + def _pure_literal(self): + """Look for pure literals and assign them when found.""" + return False + + ######################### + # Heuristics # + ######################### + def _vsids_init(self): + """Initialize the data structures needed for the VSIDS heuristic.""" + self.lit_heap = [] + self.lit_scores = {} + + for var in range(1, len(self.variable_set)): + self.lit_scores[var] = float(-self.occurrence_count[var]) + self.lit_scores[-var] = float(-self.occurrence_count[-var]) + heappush(self.lit_heap, (self.lit_scores[var], var)) + heappush(self.lit_heap, (self.lit_scores[-var], -var)) + + def _vsids_decay(self): + """Decay the VSIDS scores for every literal. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.lit_scores + {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0} + + >>> l._vsids_decay() + + >>> l.lit_scores + {-3: -1.0, -2: -1.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -1.0} + + """ + # We divide every literal score by 2 for a decay factor + # Note: This doesn't change the heap property + for lit in self.lit_scores.keys(): + self.lit_scores[lit] /= 2.0 + + def _vsids_calculate(self): + """ + VSIDS Heuristic Calculation + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.lit_heap + [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)] + + >>> l._vsids_calculate() + -3 + + >>> l.lit_heap + [(-2.0, -2), (-2.0, 2), (0.0, -1), (0.0, 1), (-2.0, 3)] + + """ + if len(self.lit_heap) == 0: + return 0 + + # Clean out the front of the heap as long the variables are set + while self.variable_set[abs(self.lit_heap[0][1])]: + heappop(self.lit_heap) + if len(self.lit_heap) == 0: + return 0 + + return heappop(self.lit_heap)[1] + + def _vsids_lit_assigned(self, lit): + """Handle the assignment of a literal for the VSIDS heuristic.""" + pass + + def _vsids_lit_unset(self, lit): + """Handle the unsetting of a literal for the VSIDS heuristic. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l.lit_heap + [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)] + + >>> l._vsids_lit_unset(2) + + >>> l.lit_heap + [(-2.0, -3), (-2.0, -2), (-2.0, -2), (-2.0, 2), (-2.0, 3), (0.0, -1), + ...(-2.0, 2), (0.0, 1)] + + """ + var = abs(lit) + heappush(self.lit_heap, (self.lit_scores[var], var)) + heappush(self.lit_heap, (self.lit_scores[-var], -var)) + + def _vsids_clause_added(self, cls): + """Handle the addition of a new clause for the VSIDS heuristic. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.num_learned_clauses + 0 + >>> l.lit_scores + {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0} + + >>> l._vsids_clause_added({2, -3}) + + >>> l.num_learned_clauses + 1 + >>> l.lit_scores + {-3: -1.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -2.0} + + """ + self.num_learned_clauses += 1 + for lit in cls: + self.lit_scores[lit] += 1 + + ######################## + # Clause Learning # + ######################## + def _simple_add_learned_clause(self, cls): + """Add a new clause to the theory. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.num_learned_clauses + 0 + >>> l.clauses + [[2, -3], [1], [3, -3], [2, -2], [3, -2]] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}} + + >>> l._simple_add_learned_clause([3]) + + >>> l.clauses + [[2, -3], [1], [3, -3], [2, -2], [3, -2], [3]] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4, 5}} + + """ + cls_num = len(self.clauses) + self.clauses.append(cls) + + for lit in cls: + self.occurrence_count[lit] += 1 + + self.sentinels[cls[0]].add(cls_num) + self.sentinels[cls[-1]].add(cls_num) + + self.heur_clause_added(cls) + + def _simple_compute_conflict(self): + """ Build a clause representing the fact that at least one decision made + so far is wrong. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l._simple_compute_conflict() + [3] + + """ + return [-(level.decision) for level in self.levels[1:]] + + def _simple_clean_clauses(self): + """Clean up learned clauses.""" + pass + + +class Level: + """ + Represents a single level in the DPLL algorithm, and contains + enough information for a sound backtracking procedure. + """ + + def __init__(self, decision, flipped=False): + self.decision = decision + self.var_settings = set() + self.flipped = flipped diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/lra_theory.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/lra_theory.py new file mode 100644 index 0000000000000000000000000000000000000000..1690760d36003aed6866f593120c05a5b8f92c83 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/lra_theory.py @@ -0,0 +1,912 @@ +"""Implements "A Fast Linear-Arithmetic Solver for DPLL(T)" + +The LRASolver class defined in this file can be used +in conjunction with a SAT solver to check the +satisfiability of formulas involving inequalities. + +Here's an example of how that would work: + + Suppose you want to check the satisfiability of + the following formula: + + >>> from sympy.core.relational import Eq + >>> from sympy.abc import x, y + >>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & (~Eq(y, 1) | Eq(1, 2)) + + First a preprocessing step should be done on f. During preprocessing, + f should be checked for any predicates such as `Q.prime` that can't be + handled. Also unequality like `~Eq(y, 1)` should be split. + + I should mention that the paper says to split both equalities and + unequality, but this implementation only requires that unequality + be split. + + >>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & ((y < 1) | (y > 1) | Eq(1, 2)) + + Then an LRASolver instance needs to be initialized with this formula. + + >>> from sympy.assumptions.cnf import CNF, EncodedCNF + >>> from sympy.assumptions.ask import Q + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> cnf = CNF.from_prop(f) + >>> enc = EncodedCNF() + >>> enc.add_from_cnf(cnf) + >>> lra, conflicts = LRASolver.from_encoded_cnf(enc) + + Any immediate one-lital conflicts clauses will be detected here. + In this example, `~Eq(1, 2)` is one such conflict clause. We'll + want to add it to `f` so that the SAT solver is forced to + assign Eq(1, 2) to False. + + >>> f = f & ~Eq(1, 2) + + Now that the one-literal conflict clauses have been added + and an lra object has been initialized, we can pass `f` + to a SAT solver. The SAT solver will give us a satisfying + assignment such as: + + (1 = 2): False + (y = 1): True + (y < 1): True + (y > 1): True + (x = 0): True + (x < 0): True + (x > 0): True + + Next you would pass this assignment to the LRASolver + which will be able to determine that this particular + assignment is satisfiable or not. + + Note that since EncodedCNF is inherently non-deterministic, + the int each predicate is encoded as is not consistent. As a + result, the code below likely does not reflect the assignment + given above. + + >>> lra.assert_lit(-1) #doctest: +SKIP + >>> lra.assert_lit(2) #doctest: +SKIP + >>> lra.assert_lit(3) #doctest: +SKIP + >>> lra.assert_lit(4) #doctest: +SKIP + >>> lra.assert_lit(5) #doctest: +SKIP + >>> lra.assert_lit(6) #doctest: +SKIP + >>> lra.assert_lit(7) #doctest: +SKIP + >>> is_sat, conflict_or_assignment = lra.check() + + As the particular assignment suggested is not satisfiable, + the LRASolver will return unsat and a conflict clause when + given that assignment. The conflict clause will always be + minimal, but there can be multiple minimal conflict clauses. + One possible conflict clause could be `~(x < 0) | ~(x > 0)`. + + We would then add whatever conflict clause is given to + `f` to prevent the SAT solver from coming up with an + assignment with the same conflicting literals. In this case, + the conflict clause `~(x < 0) | ~(x > 0)` would prevent + any assignment where both (x < 0) and (x > 0) were both + true. + + The SAT solver would then find another assignment + and we would check that assignment with the LRASolver + and so on. Eventually either a satisfying assignment + that the SAT solver and LRASolver agreed on would be found + or enough conflict clauses would be added so that the + boolean formula was unsatisfiable. + + +This implementation is based on [1]_, which includes a +detailed explanation of the algorithm and pseudocode +for the most important functions. + +[1]_ also explains how backtracking and theory propagation +could be implemented to speed up the current implementation, +but these are not currently implemented. + +TODO: + - Handle non-rational real numbers + - Handle positive and negative infinity + - Implement backtracking and theory proposition + - Simplify matrix by removing unused variables using Gaussian elimination + +References +========== + +.. [1] Dutertre, B., de Moura, L.: + A Fast Linear-Arithmetic Solver for DPLL(T) + https://link.springer.com/chapter/10.1007/11817963_11 +""" +from sympy.solvers.solveset import linear_eq_to_matrix +from sympy.matrices.dense import eye +from sympy.assumptions import Predicate +from sympy.assumptions.assume import AppliedPredicate +from sympy.assumptions.ask import Q +from sympy.core import Dummy +from sympy.core.mul import Mul +from sympy.core.add import Add +from sympy.core.relational import Eq, Ne +from sympy.core.sympify import sympify +from sympy.core.singleton import S +from sympy.core.numbers import Rational, oo +from sympy.matrices.dense import Matrix + +class UnhandledInput(Exception): + """ + Raised while creating an LRASolver if non-linearity + or non-rational numbers are present. + """ + +# predicates that LRASolver understands and makes use of +ALLOWED_PRED = {Q.eq, Q.gt, Q.lt, Q.le, Q.ge} + +# if true ~Q.gt(x, y) implies Q.le(x, y) +HANDLE_NEGATION = True + +class LRASolver(): + """ + Linear Arithmetic Solver for DPLL(T) implemented with an algorithm based on + the Dual Simplex method. Uses Bland's pivoting rule to avoid cycling. + + References + ========== + + .. [1] Dutertre, B., de Moura, L.: + A Fast Linear-Arithmetic Solver for DPLL(T) + https://link.springer.com/chapter/10.1007/11817963_11 + """ + + def __init__(self, A, slack_variables, nonslack_variables, enc_to_boundary, s_subs, testing_mode): + """ + Use the "from_encoded_cnf" method to create a new LRASolver. + """ + self.run_checks = testing_mode + self.s_subs = s_subs # used only for test_lra_theory.test_random_problems + + if any(not isinstance(a, Rational) for a in A): + raise UnhandledInput("Non-rational numbers are not handled") + if any(not isinstance(b.bound, Rational) for b in enc_to_boundary.values()): + raise UnhandledInput("Non-rational numbers are not handled") + m, n = len(slack_variables), len(slack_variables)+len(nonslack_variables) + if m != 0: + assert A.shape == (m, n) + if self.run_checks: + assert A[:, n-m:] == -eye(m) + + self.enc_to_boundary = enc_to_boundary # mapping of int to Boundary objects + self.boundary_to_enc = {value: key for key, value in enc_to_boundary.items()} + self.A = A + self.slack = slack_variables + self.nonslack = nonslack_variables + self.all_var = nonslack_variables + slack_variables + + self.slack_set = set(slack_variables) + + self.is_sat = True # While True, all constraints asserted so far are satisfiable + self.result = None # always one of: (True, assignment), (False, conflict clause), None + + @staticmethod + def from_encoded_cnf(encoded_cnf, testing_mode=False): + """ + Creates an LRASolver from an EncodedCNF object + and a list of conflict clauses for propositions + that can be simplified to True or False. + + Parameters + ========== + + encoded_cnf : EncodedCNF + + testing_mode : bool + Setting testing_mode to True enables some slow assert statements + and sorting to reduce nonterministic behavior. + + Returns + ======= + + (lra, conflicts) + + lra : LRASolver + + conflicts : list + Contains a one-literal conflict clause for each proposition + that can be simplified to True or False. + + Example + ======= + + >>> from sympy.core.relational import Eq + >>> from sympy.assumptions.cnf import CNF, EncodedCNF + >>> from sympy.assumptions.ask import Q + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> from sympy.abc import x, y, z + >>> phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6)) + >>> phi = phi & (Eq(x + y, 2) | (x + 2 * y - z > 4)) + >>> phi = phi & Q.gt(2, 1) + >>> cnf = CNF.from_prop(phi) + >>> enc = EncodedCNF() + >>> enc.from_cnf(cnf) + >>> lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + >>> lra #doctest: +SKIP + + >>> conflicts #doctest: +SKIP + [[4]] + """ + # This function has three main jobs: + # - raise errors if the input formula is not handled + # - preprocesses the formula into a matrix and single variable constraints + # - create one-literal conflict clauses from predicates that are always True + # or always False such as Q.gt(3, 2) + # + # See the preprocessing section of "A Fast Linear-Arithmetic Solver for DPLL(T)" + # for an explanation of how the formula is converted into a matrix + # and a set of single variable constraints. + + encoding = {} # maps int to boundary + A = [] + + basic = [] + s_count = 0 + s_subs = {} + nonbasic = [] + + if testing_mode: + # sort to reduce nondeterminism + encoded_cnf_items = sorted(encoded_cnf.encoding.items(), key=lambda x: str(x)) + else: + encoded_cnf_items = encoded_cnf.encoding.items() + + empty_var = Dummy() + var_to_lra_var = {} + conflicts = [] + + for prop, enc in encoded_cnf_items: + if isinstance(prop, Predicate): + prop = prop(empty_var) + if not isinstance(prop, AppliedPredicate): + if prop == True: + conflicts.append([enc]) + continue + if prop == False: + conflicts.append([-enc]) + continue + + raise ValueError(f"Unhandled Predicate: {prop}") + + assert prop.function in ALLOWED_PRED + if prop.lhs == S.NaN or prop.rhs == S.NaN: + raise ValueError(f"{prop} contains nan") + if prop.lhs.is_imaginary or prop.rhs.is_imaginary: + raise UnhandledInput(f"{prop} contains an imaginary component") + if prop.lhs == oo or prop.rhs == oo: + raise UnhandledInput(f"{prop} contains infinity") + + prop = _eval_binrel(prop) # simplify variable-less quantities to True / False if possible + if prop == True: + conflicts.append([enc]) + continue + elif prop == False: + conflicts.append([-enc]) + continue + elif prop is None: + raise UnhandledInput(f"{prop} could not be simplified") + + expr = prop.lhs - prop.rhs + if prop.function in [Q.ge, Q.gt]: + expr = -expr + + # expr should be less than (or equal to) 0 + # otherwise prop is False + if prop.function in [Q.le, Q.ge]: + bool = (expr <= 0) + elif prop.function in [Q.lt, Q.gt]: + bool = (expr < 0) + else: + assert prop.function == Q.eq + bool = Eq(expr, 0) + + if bool == True: + conflicts.append([enc]) + continue + elif bool == False: + conflicts.append([-enc]) + continue + + + vars, const = _sep_const_terms(expr) # example: (2x + 3y + 2) --> (2x + 3y), (2) + vars, var_coeff = _sep_const_coeff(vars) # examples: (2x) --> (x, 2); (2x + 3y) --> (2x + 3y), (1) + const = const / var_coeff + + terms = _list_terms(vars) # example: (2x + 3y) --> [2x, 3y] + for term in terms: + term, _ = _sep_const_coeff(term) + assert len(term.free_symbols) > 0 + if term not in var_to_lra_var: + var_to_lra_var[term] = LRAVariable(term) + nonbasic.append(term) + + if len(terms) > 1: + if vars not in s_subs: + s_count += 1 + d = Dummy(f"s{s_count}") + var_to_lra_var[d] = LRAVariable(d) + basic.append(d) + s_subs[vars] = d + A.append(vars - d) + var = s_subs[vars] + else: + var = terms[0] + + assert var_coeff != 0 + + equality = prop.function == Q.eq + upper = var_coeff > 0 if not equality else None + strict = prop.function in [Q.gt, Q.lt] + b = Boundary(var_to_lra_var[var], -const, upper, equality, strict) + encoding[enc] = b + + fs = [v.free_symbols for v in nonbasic + basic] + assert all(len(syms) > 0 for syms in fs) + fs_count = sum(len(syms) for syms in fs) + if len(fs) > 0 and len(set.union(*fs)) < fs_count: + raise UnhandledInput("Nonlinearity is not handled") + + A, _ = linear_eq_to_matrix(A, nonbasic + basic) + nonbasic = [var_to_lra_var[nb] for nb in nonbasic] + basic = [var_to_lra_var[b] for b in basic] + for idx, var in enumerate(nonbasic + basic): + var.col_idx = idx + + return LRASolver(A, basic, nonbasic, encoding, s_subs, testing_mode), conflicts + + def reset_bounds(self): + """ + Resets the state of the LRASolver to before + anything was asserted. + """ + self.result = None + for var in self.all_var: + var.lower = LRARational(-float("inf"), 0) + var.lower_from_eq = False + var.lower_from_neg = False + var.upper = LRARational(float("inf"), 0) + var.upper_from_eq= False + var.lower_from_neg = False + var.assign = LRARational(0, 0) + + def assert_lit(self, enc_constraint): + """ + Assert a literal representing a constraint + and update the internal state accordingly. + + Note that due to peculiarities of this implementation + asserting ~(x > 0) will assert (x <= 0) but asserting + ~Eq(x, 0) will not do anything. + + Parameters + ========== + + enc_constraint : int + A mapping of encodings to constraints + can be found in `self.enc_to_boundary`. + + Returns + ======= + + None or (False, explanation) + + explanation : set of ints + A conflict clause that "explains" why + the literals asserted so far are unsatisfiable. + """ + if abs(enc_constraint) not in self.enc_to_boundary: + return None + + if not HANDLE_NEGATION and enc_constraint < 0: + return None + + boundary = self.enc_to_boundary[abs(enc_constraint)] + sym, c, negated = boundary.var, boundary.bound, enc_constraint < 0 + + if boundary.equality and negated: + return None # negated equality is not handled and should only appear in conflict clauses + + upper = boundary.upper != negated + if boundary.strict != negated: + delta = -1 if upper else 1 + c = LRARational(c, delta) + else: + c = LRARational(c, 0) + + if boundary.equality: + res1 = self._assert_lower(sym, c, from_equality=True, from_neg=negated) + if res1 and res1[0] == False: + res = res1 + else: + res2 = self._assert_upper(sym, c, from_equality=True, from_neg=negated) + res = res2 + elif upper: + res = self._assert_upper(sym, c, from_neg=negated) + else: + res = self._assert_lower(sym, c, from_neg=negated) + + if self.is_sat and sym not in self.slack_set: + self.is_sat = res is None + else: + self.is_sat = False + + return res + + def _assert_upper(self, xi, ci, from_equality=False, from_neg=False): + """ + Adjusts the upper bound on variable xi if the new upper bound is + more limiting. The assignment of variable xi is adjusted to be + within the new bound if needed. + + Also calls `self._update` to update the assignment for slack variables + to keep all equalities satisfied. + """ + if self.result: + assert self.result[0] != False + self.result = None + if ci >= xi.upper: + return None + if ci < xi.lower: + assert (xi.lower[1] >= 0) is True + assert (ci[1] <= 0) is True + + lit1, neg1 = Boundary.from_lower(xi) + + lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=True, equality=from_equality) + if from_neg: + lit2 = lit2.get_negated() + neg2 = -1 if from_neg else 1 + + conflict = [-neg1*self.boundary_to_enc[lit1], -neg2*self.boundary_to_enc[lit2]] + self.result = False, conflict + return self.result + xi.upper = ci + xi.upper_from_eq = from_equality + xi.upper_from_neg = from_neg + if xi in self.nonslack and xi.assign > ci: + self._update(xi, ci) + + if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + M = self.A + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10 ** (-10) for val in M * X) + + return None + + def _assert_lower(self, xi, ci, from_equality=False, from_neg=False): + """ + Adjusts the lower bound on variable xi if the new lower bound is + more limiting. The assignment of variable xi is adjusted to be + within the new bound if needed. + + Also calls `self._update` to update the assignment for slack variables + to keep all equalities satisfied. + """ + if self.result: + assert self.result[0] != False + self.result = None + if ci <= xi.lower: + return None + if ci > xi.upper: + assert (xi.upper[1] <= 0) is True + assert (ci[1] >= 0) is True + + lit1, neg1 = Boundary.from_upper(xi) + + lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=False, equality=from_equality) + if from_neg: + lit2 = lit2.get_negated() + neg2 = -1 if from_neg else 1 + + conflict = [-neg1*self.boundary_to_enc[lit1],-neg2*self.boundary_to_enc[lit2]] + self.result = False, conflict + return self.result + xi.lower = ci + xi.lower_from_eq = from_equality + xi.lower_from_neg = from_neg + if xi in self.nonslack and xi.assign < ci: + self._update(xi, ci) + + if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + M = self.A + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10 ** (-10) for val in M * X) + + return None + + def _update(self, xi, v): + """ + Updates all slack variables that have equations that contain + variable xi so that they stay satisfied given xi is equal to v. + """ + i = xi.col_idx + for j, b in enumerate(self.slack): + aji = self.A[j, i] + b.assign = b.assign + (v - xi.assign)*aji + xi.assign = v + + def check(self): + """ + Searches for an assignment that satisfies all constraints + or determines that no such assignment exists and gives + a minimal conflict clause that "explains" why the + constraints are unsatisfiable. + + Returns + ======= + + (True, assignment) or (False, explanation) + + assignment : dict of LRAVariables to values + Assigned values are tuples that represent a rational number + plus some infinatesimal delta. + + explanation : set of ints + """ + if self.is_sat: + return True, {var: var.assign for var in self.all_var} + if self.result: + return self.result + + from sympy.matrices.dense import Matrix + M = self.A.copy() + basic = {s: i for i, s in enumerate(self.slack)} # contains the row index associated with each basic variable + nonbasic = set(self.nonslack) + while True: + if self.run_checks: + # nonbasic variables must always be within bounds + assert all(((nb.assign >= nb.lower) == True) and ((nb.assign <= nb.upper) == True) for nb in nonbasic) + + # assignments for x must always satisfy Ax = 0 + # probably have to turn this off when dealing with strict ineq + if all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10**(-10) for val in M*X) + + # check upper and lower match this format: + # x <= rat + delta iff x < rat + # x >= rat - delta iff x > rat + # this wouldn't make sense: + # x <= rat - delta + # x >= rat + delta + assert all(x.upper[1] <= 0 for x in self.all_var) + assert all(x.lower[1] >= 0 for x in self.all_var) + + cand = [b for b in basic if b.assign < b.lower or b.assign > b.upper] + + if len(cand) == 0: + return True, {var: var.assign for var in self.all_var} + + xi = min(cand, key=lambda v: v.col_idx) # Bland's rule + i = basic[xi] + + if xi.assign < xi.lower: + cand = [nb for nb in nonbasic + if (M[i, nb.col_idx] > 0 and nb.assign < nb.upper) + or (M[i, nb.col_idx] < 0 and nb.assign > nb.lower)] + if len(cand) == 0: + N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0] + N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0] + + conflict = [] + conflict += [Boundary.from_upper(nb) for nb in N_plus] + conflict += [Boundary.from_lower(nb) for nb in N_minus] + conflict.append(Boundary.from_lower(xi)) + conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict] + return False, conflict + xj = min(cand, key=str) + M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.lower) + + if xi.assign > xi.upper: + cand = [nb for nb in nonbasic + if (M[i, nb.col_idx] < 0 and nb.assign < nb.upper) + or (M[i, nb.col_idx] > 0 and nb.assign > nb.lower)] + + if len(cand) == 0: + N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0] + N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0] + + conflict = [] + conflict += [Boundary.from_upper(nb) for nb in N_minus] + conflict += [Boundary.from_lower(nb) for nb in N_plus] + conflict.append(Boundary.from_upper(xi)) + + conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict] + return False, conflict + xj = min(cand, key=lambda v: v.col_idx) + M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.upper) + + def _pivot_and_update(self, M, basic, nonbasic, xi, xj, v): + """ + Pivots basic variable xi with nonbasic variable xj, + and sets value of xi to v and adjusts the values of all basic variables + to keep equations satisfied. + """ + i, j = basic[xi], xj.col_idx + assert M[i, j] != 0 + theta = (v - xi.assign)*(1/M[i, j]) + xi.assign = v + xj.assign = xj.assign + theta + for xk in basic: + if xk != xi: + k = basic[xk] + akj = M[k, j] + xk.assign = xk.assign + theta*akj + # pivot + basic[xj] = basic[xi] + del basic[xi] + nonbasic.add(xi) + nonbasic.remove(xj) + return self._pivot(M, i, j) + + @staticmethod + def _pivot(M, i, j): + """ + Performs a pivot operation about entry i, j of M by performing + a series of row operations on a copy of M and returning the result. + The original M is left unmodified. + + Conceptually, M represents a system of equations and pivoting + can be thought of as rearranging equation i to be in terms of + variable j and then substituting in the rest of the equations + to get rid of other occurances of variable j. + + Example + ======= + + >>> from sympy.matrices.dense import Matrix + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> from sympy import var + >>> Matrix(3, 3, var('a:i')) + Matrix([ + [a, b, c], + [d, e, f], + [g, h, i]]) + + This matrix is equivalent to: + 0 = a*x + b*y + c*z + 0 = d*x + e*y + f*z + 0 = g*x + h*y + i*z + + >>> LRASolver._pivot(_, 1, 0) + Matrix([ + [ 0, -a*e/d + b, -a*f/d + c], + [-1, -e/d, -f/d], + [ 0, h - e*g/d, i - f*g/d]]) + + We rearrange equation 1 in terms of variable 0 (x) + and substitute to remove x from the other equations. + + 0 = 0 + (-a*e/d + b)*y + (-a*f/d + c)*z + 0 = -x + (-e/d)*y + (-f/d)*z + 0 = 0 + (h - e*g/d)*y + (i - f*g/d)*z + """ + _, _, Mij = M[i, :], M[:, j], M[i, j] + if Mij == 0: + raise ZeroDivisionError("Tried to pivot about zero-valued entry.") + A = M.copy() + A[i, :] = -A[i, :]/Mij + for row in range(M.shape[0]): + if row != i: + A[row, :] = A[row, :] + A[row, j] * A[i, :] + + return A + + +def _sep_const_coeff(expr): + """ + Example + ======= + + >>> from sympy.logic.algorithms.lra_theory import _sep_const_coeff + >>> from sympy.abc import x, y + >>> _sep_const_coeff(2*x) + (x, 2) + >>> _sep_const_coeff(2*x + 3*y) + (2*x + 3*y, 1) + """ + if isinstance(expr, Add): + return expr, sympify(1) + + if isinstance(expr, Mul): + coeffs = expr.args + else: + coeffs = [expr] + + var, const = [], [] + for c in coeffs: + c = sympify(c) + if len(c.free_symbols)==0: + const.append(c) + else: + var.append(c) + return Mul(*var), Mul(*const) + + +def _list_terms(expr): + if not isinstance(expr, Add): + return [expr] + + return expr.args + + +def _sep_const_terms(expr): + """ + Example + ======= + + >>> from sympy.logic.algorithms.lra_theory import _sep_const_terms + >>> from sympy.abc import x, y + >>> _sep_const_terms(2*x + 3*y + 2) + (2*x + 3*y, 2) + """ + if isinstance(expr, Add): + terms = expr.args + else: + terms = [expr] + + var, const = [], [] + for t in terms: + if len(t.free_symbols) == 0: + const.append(t) + else: + var.append(t) + return sum(var), sum(const) + + +def _eval_binrel(binrel): + """ + Simplify binary relation to True / False if possible. + """ + if not (len(binrel.lhs.free_symbols) == 0 and len(binrel.rhs.free_symbols) == 0): + return binrel + if binrel.function == Q.lt: + res = binrel.lhs < binrel.rhs + elif binrel.function == Q.gt: + res = binrel.lhs > binrel.rhs + elif binrel.function == Q.le: + res = binrel.lhs <= binrel.rhs + elif binrel.function == Q.ge: + res = binrel.lhs >= binrel.rhs + elif binrel.function == Q.eq: + res = Eq(binrel.lhs, binrel.rhs) + elif binrel.function == Q.ne: + res = Ne(binrel.lhs, binrel.rhs) + + if res == True or res == False: + return res + else: + return None + + +class Boundary: + """ + Represents an upper or lower bound or an equality between a symbol + and some constant. + """ + def __init__(self, var, const, upper, equality, strict=None): + if not equality in [True, False]: + assert equality in [True, False] + + + self.var = var + if isinstance(const, tuple): + s = const[1] != 0 + if strict: + assert s == strict + self.bound = const[0] + self.strict = s + else: + self.bound = const + self.strict = strict + self.upper = upper if not equality else None + self.equality = equality + self.strict = strict + assert self.strict is not None + + @staticmethod + def from_upper(var): + neg = -1 if var.upper_from_neg else 1 + b = Boundary(var, var.upper[0], True, var.upper_from_eq, var.upper[1] != 0) + if neg < 0: + b = b.get_negated() + return b, neg + + @staticmethod + def from_lower(var): + neg = -1 if var.lower_from_neg else 1 + b = Boundary(var, var.lower[0], False, var.lower_from_eq, var.lower[1] != 0) + if neg < 0: + b = b.get_negated() + return b, neg + + def get_negated(self): + return Boundary(self.var, self.bound, not self.upper, self.equality, not self.strict) + + def get_inequality(self): + if self.equality: + return Eq(self.var.var, self.bound) + elif self.upper and self.strict: + return self.var.var < self.bound + elif not self.upper and self.strict: + return self.var.var > self.bound + elif self.upper: + return self.var.var <= self.bound + else: + return self.var.var >= self.bound + + def __repr__(self): + return repr("Boundary(" + repr(self.get_inequality()) + ")") + + def __eq__(self, other): + other = (other.var, other.bound, other.strict, other.upper, other.equality) + return (self.var, self.bound, self.strict, self.upper, self.equality) == other + + def __hash__(self): + return hash((self.var, self.bound, self.strict, self.upper, self.equality)) + + +class LRARational(): + """ + Represents a rational plus or minus some amount + of arbitrary small deltas. + """ + def __init__(self, rational, delta): + self.value = (rational, delta) + + def __lt__(self, other): + return self.value < other.value + + def __le__(self, other): + return self.value <= other.value + + def __eq__(self, other): + return self.value == other.value + + def __add__(self, other): + return LRARational(self.value[0] + other.value[0], self.value[1] + other.value[1]) + + def __sub__(self, other): + return LRARational(self.value[0] - other.value[0], self.value[1] - other.value[1]) + + def __mul__(self, other): + assert not isinstance(other, LRARational) + return LRARational(self.value[0] * other, self.value[1] * other) + + def __getitem__(self, index): + return self.value[index] + + def __repr__(self): + return repr(self.value) + + +class LRAVariable(): + """ + Object to keep track of upper and lower bounds + on `self.var`. + """ + def __init__(self, var): + self.upper = LRARational(float("inf"), 0) + self.upper_from_eq = False + self.upper_from_neg = False + self.lower = LRARational(-float("inf"), 0) + self.lower_from_eq = False + self.lower_from_neg = False + self.assign = LRARational(0,0) + self.var = var + self.col_idx = None + + def __repr__(self): + return repr(self.var) + + def __eq__(self, other): + if not isinstance(other, LRAVariable): + return False + return other.var == self.var + + def __hash__(self): + return hash(self.var) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/minisat22_wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/minisat22_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5c1f8f14f04309f7cb8197cc05d01a3c108545 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/minisat22_wrapper.py @@ -0,0 +1,46 @@ +from sympy.assumptions.cnf import EncodedCNF + +def minisat22_satisfiable(expr, all_models=False, minimal=False): + + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + from pysat.solvers import Minisat22 + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + r = Minisat22(expr.data) + + if minimal: + r.set_phases([-(i+1) for i in range(r.nof_vars())]) + + if not r.solve(): + return False + + if not all_models: + return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r.get_model()} + + else: + # Make solutions SymPy compatible by creating a generator + def _gen(results): + satisfiable = False + while results.solve(): + sol = results.get_model() + yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol} + if minimal: + results.add_clause([-i for i in sol if i>0]) + else: + results.add_clause([-i for i in sol]) + satisfiable = True + if not satisfiable: + yield False + raise StopIteration + + + return _gen(r) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/pycosat_wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/pycosat_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..5ff498b7e3f6b73d95e9b949598ef32df4ecf226 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/pycosat_wrapper.py @@ -0,0 +1,41 @@ +from sympy.assumptions.cnf import EncodedCNF + + +def pycosat_satisfiable(expr, all_models=False): + import pycosat + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + if not all_models: + r = pycosat.solve(expr.data) + result = (r != "UNSAT") + if not result: + return result + return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r} + else: + r = pycosat.itersolve(expr.data) + result = (r != "UNSAT") + if not result: + return result + + # Make solutions SymPy compatible by creating a generator + def _gen(results): + satisfiable = False + try: + while True: + sol = next(results) + yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol} + satisfiable = True + except StopIteration: + if not satisfiable: + yield False + + return _gen(r) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/z3_wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/z3_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..fe44f713a2edfd5286c0f81b737212146766b11b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/algorithms/z3_wrapper.py @@ -0,0 +1,115 @@ +from sympy.printing.smtlib import smtlib_code +from sympy.assumptions.assume import AppliedPredicate +from sympy.assumptions.cnf import EncodedCNF +from sympy.assumptions.ask import Q + +from sympy.core import Add, Mul +from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import Pow +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import And, Or, Xor, Implies +from sympy.logic.boolalg import Not, ITE +from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate +from sympy.external import import_module + +def z3_satisfiable(expr, all_models=False): + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + z3 = import_module("z3") + if z3 is None: + raise ImportError("z3 is not installed") + + s = encoded_cnf_to_z3_solver(expr, z3) + + res = str(s.check()) + if res == "unsat": + return False + elif res == "sat": + return z3_model_to_sympy_model(s.model(), expr) + else: + return None + + +def z3_model_to_sympy_model(z3_model, enc_cnf): + rev_enc = {value : key for key, value in enc_cnf.encoding.items()} + return {rev_enc[int(var.name()[1:])] : bool(z3_model[var]) for var in z3_model} + + +def clause_to_assertion(clause): + clause_strings = [f"d{abs(lit)}" if lit > 0 else f"(not d{abs(lit)})" for lit in clause] + return "(assert (or " + " ".join(clause_strings) + "))" + + +def encoded_cnf_to_z3_solver(enc_cnf, z3): + def dummify_bool(pred): + return False + assert isinstance(pred, AppliedPredicate) + + if pred.function in [Q.positive, Q.negative, Q.zero]: + return pred + else: + return False + + s = z3.Solver() + + declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables] + assertions = [clause_to_assertion(clause) for clause in enc_cnf.data] + + symbols = set() + for pred, enc in enc_cnf.encoding.items(): + if not isinstance(pred, AppliedPredicate): + continue + if pred.function not in (Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq, Q.positive, Q.negative, Q.extended_negative, Q.extended_positive, Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive, Q.extended_nonzero, Q.extended_nonnegative, Q.extended_nonpositive): + continue + + pred_str = smtlib_code(pred, auto_declare=False, auto_assert=False, known_functions=known_functions) + + symbols |= pred.free_symbols + pred = pred_str + clause = f"(implies d{enc} {pred})" + assertion = "(assert " + clause + ")" + assertions.append(assertion) + + for sym in symbols: + declarations.append(f"(declare-const {sym} Real)") + + declarations = "\n".join(declarations) + assertions = "\n".join(assertions) + s.from_string(declarations) + s.from_string(assertions) + + return s + + +known_functions = { + Add: '+', + Mul: '*', + + Equality: '=', + LessThan: '<=', + GreaterThan: '>=', + StrictLessThan: '<', + StrictGreaterThan: '>', + + EqualityPredicate(): '=', + LessThanPredicate(): '<=', + GreaterThanPredicate(): '>=', + StrictLessThanPredicate(): '<', + StrictGreaterThanPredicate(): '>', + + Abs: 'abs', + Min: 'min', + Max: 'max', + Pow: '^', + + And: 'and', + Or: 'or', + Xor: 'xor', + Not: 'not', + ITE: 'ite', + Implies: '=>', + } diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/boolalg.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/boolalg.py new file mode 100644 index 0000000000000000000000000000000000000000..8e11a9b6361ac5d7e355d5d4fb176d8df443e07e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/boolalg.py @@ -0,0 +1,3587 @@ +""" +Boolean algebra module for SymPy +""" + +from __future__ import annotations +from typing import TYPE_CHECKING, overload, Any +from collections.abc import Iterable, Mapping + +from collections import defaultdict +from itertools import chain, combinations, product, permutations +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.decorators import sympify_method_args, sympify_return +from sympy.core.function import Application, Derivative +from sympy.core.kind import BooleanKind, NumberKind +from sympy.core.numbers import Number +from sympy.core.operations import LatticeOp +from sympy.core.singleton import Singleton, S +from sympy.core.sorting import ordered +from sympy.core.sympify import _sympy_converter, _sympify, sympify +from sympy.utilities.iterables import sift, ibin +from sympy.utilities.misc import filldedent + + +def as_Boolean(e): + """Like ``bool``, return the Boolean value of an expression, e, + which can be any instance of :py:class:`~.Boolean` or ``bool``. + + Examples + ======== + + >>> from sympy import true, false, nan + >>> from sympy.logic.boolalg import as_Boolean + >>> from sympy.abc import x + >>> as_Boolean(0) is false + True + >>> as_Boolean(1) is true + True + >>> as_Boolean(x) + x + >>> as_Boolean(2) + Traceback (most recent call last): + ... + TypeError: expecting bool or Boolean, not `2`. + >>> as_Boolean(nan) + Traceback (most recent call last): + ... + TypeError: expecting bool or Boolean, not `nan`. + + """ + from sympy.core.symbol import Symbol + if e == True: + return true + if e == False: + return false + if isinstance(e, Symbol): + z = e.is_zero + if z is None: + return e + return false if z else true + if isinstance(e, Boolean): + return e + raise TypeError('expecting bool or Boolean, not `%s`.' % e) + + +@sympify_method_args +class Boolean(Basic): + """A Boolean object is an object for which logic operations make sense.""" + + __slots__ = () + + kind = BooleanKind + + if TYPE_CHECKING: + + def __new__(cls, *args: Basic | complex) -> Boolean: + ... + + @overload # type: ignore + def subs(self, arg1: Mapping[Basic | complex, Boolean | complex], arg2: None=None) -> Boolean: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Boolean | complex]], arg2: None=None, **kwargs: Any) -> Boolean: ... + @overload + def subs(self, arg1: Boolean | complex, arg2: Boolean | complex) -> Boolean: ... + @overload + def subs(self, arg1: Mapping[Basic | complex, Basic | complex], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Basic | complex]], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Basic | complex, arg2: Basic | complex, **kwargs: Any) -> Basic: ... + + def subs(self, arg1: Mapping[Basic | complex, Basic | complex] | Basic | complex, # type: ignore + arg2: Basic | complex | None = None, **kwargs: Any) -> Basic: + ... + + def simplify(self, **kwargs) -> Boolean: + ... + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __and__(self, other): + return And(self, other) + + __rand__ = __and__ + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __or__(self, other): + return Or(self, other) + + __ror__ = __or__ + + def __invert__(self): + """Overloading for ~""" + return Not(self) + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __rshift__(self, other): + return Implies(self, other) + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __lshift__(self, other): + return Implies(other, self) + + __rrshift__ = __lshift__ + __rlshift__ = __rshift__ + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __xor__(self, other): + return Xor(self, other) + + __rxor__ = __xor__ + + def equals(self, other): + """ + Returns ``True`` if the given formulas have the same truth table. + For two formulas to be equal they must have the same literals. + + Examples + ======== + + >>> from sympy.abc import A, B, C + >>> from sympy import And, Or, Not + >>> (A >> B).equals(~B >> ~A) + True + >>> Not(And(A, B, C)).equals(And(Not(A), Not(B), Not(C))) + False + >>> Not(And(A, Not(A))).equals(Or(B, Not(B))) + False + + """ + from sympy.logic.inference import satisfiable + from sympy.core.relational import Relational + + if self.has(Relational) or other.has(Relational): + raise NotImplementedError('handling of relationals') + return self.atoms() == other.atoms() and \ + not satisfiable(Not(Equivalent(self, other))) + + def to_nnf(self, simplify=True): + # override where necessary + return self + + def as_set(self): + """ + Rewrites Boolean expression in terms of real sets. + + Examples + ======== + + >>> from sympy import Symbol, Eq, Or, And + >>> x = Symbol('x', real=True) + >>> Eq(x, 0).as_set() + {0} + >>> (x > 0).as_set() + Interval.open(0, oo) + >>> And(-2 < x, x < 2).as_set() + Interval.open(-2, 2) + >>> Or(x < -2, 2 < x).as_set() + Union(Interval.open(-oo, -2), Interval.open(2, oo)) + + """ + from sympy.calculus.util import periodicity + from sympy.core.relational import Relational + + free = self.free_symbols + if len(free) == 1: + x = free.pop() + if x.kind is NumberKind: + reps = {} + for r in self.atoms(Relational): + if periodicity(r, x) not in (0, None): + s = r._eval_as_set() + if s in (S.EmptySet, S.UniversalSet, S.Reals): + reps[r] = s.as_relational(x) + continue + raise NotImplementedError(filldedent(''' + as_set is not implemented for relationals + with periodic solutions + ''')) + new = self.subs(reps) + if new.func != self.func: + return new.as_set() # restart with new obj + else: + return new._eval_as_set() + + return self._eval_as_set() + else: + raise NotImplementedError("Sorry, as_set has not yet been" + " implemented for multivariate" + " expressions") + + @property + def binary_symbols(self): + from sympy.core.relational import Eq, Ne + return set().union(*[i.binary_symbols for i in self.args + if i.is_Boolean or i.is_Symbol + or isinstance(i, (Eq, Ne))]) + + def _eval_refine(self, assumptions): + from sympy.assumptions import ask + ret = ask(self, assumptions) + if ret is True: + return true + elif ret is False: + return false + return None + + +class BooleanAtom(Boolean): + """ + Base class of :py:class:`~.BooleanTrue` and :py:class:`~.BooleanFalse`. + """ + is_Boolean = True + is_Atom = True + _op_priority = 11 # higher than Expr + + def simplify(self, *a, **kw): + return self + + def expand(self, *a, **kw): + return self + + @property + def canonical(self): + return self + + def _noop(self, other=None): + raise TypeError('BooleanAtom not allowed in this context.') + + __add__ = _noop + __radd__ = _noop + __sub__ = _noop + __rsub__ = _noop + __mul__ = _noop + __rmul__ = _noop + __pow__ = _noop + __rpow__ = _noop + __truediv__ = _noop + __rtruediv__ = _noop + __mod__ = _noop + __rmod__ = _noop + _eval_power = _noop + + def __lt__(self, other): + raise TypeError(filldedent(''' + A Boolean argument can only be used in + Eq and Ne; all other relationals expect + real expressions. + ''')) + + __le__ = __lt__ + __gt__ = __lt__ + __ge__ = __lt__ + # \\\ + + def _eval_simplify(self, **kwargs): + return self + + +class BooleanTrue(BooleanAtom, metaclass=Singleton): + """ + SymPy version of ``True``, a singleton that can be accessed via ``S.true``. + + This is the SymPy version of ``True``, for use in the logic module. The + primary advantage of using ``true`` instead of ``True`` is that shorthand Boolean + operations like ``~`` and ``>>`` will work as expected on this class, whereas with + True they act bitwise on 1. Functions in the logic module will return this + class when they evaluate to true. + + Notes + ===== + + There is liable to be some confusion as to when ``True`` should + be used and when ``S.true`` should be used in various contexts + throughout SymPy. An important thing to remember is that + ``sympify(True)`` returns ``S.true``. This means that for the most + part, you can just use ``True`` and it will automatically be converted + to ``S.true`` when necessary, similar to how you can generally use 1 + instead of ``S.One``. + + The rule of thumb is: + + "If the boolean in question can be replaced by an arbitrary symbolic + ``Boolean``, like ``Or(x, y)`` or ``x > 1``, use ``S.true``. + Otherwise, use ``True``" + + In other words, use ``S.true`` only on those contexts where the + boolean is being used as a symbolic representation of truth. + For example, if the object ends up in the ``.args`` of any expression, + then it must necessarily be ``S.true`` instead of ``True``, as + elements of ``.args`` must be ``Basic``. On the other hand, + ``==`` is not a symbolic operation in SymPy, since it always returns + ``True`` or ``False``, and does so in terms of structural equality + rather than mathematical, so it should return ``True``. The assumptions + system should use ``True`` and ``False``. Aside from not satisfying + the above rule of thumb, the assumptions system uses a three-valued logic + (``True``, ``False``, ``None``), whereas ``S.true`` and ``S.false`` + represent a two-valued logic. When in doubt, use ``True``. + + "``S.true == True is True``." + + While "``S.true is True``" is ``False``, "``S.true == True``" + is ``True``, so if there is any doubt over whether a function or + expression will return ``S.true`` or ``True``, just use ``==`` + instead of ``is`` to do the comparison, and it will work in either + case. Finally, for boolean flags, it's better to just use ``if x`` + instead of ``if x is True``. To quote PEP 8: + + Do not compare boolean values to ``True`` or ``False`` + using ``==``. + + * Yes: ``if greeting:`` + * No: ``if greeting == True:`` + * Worse: ``if greeting is True:`` + + Examples + ======== + + >>> from sympy import sympify, true, false, Or + >>> sympify(True) + True + >>> _ is True, _ is true + (False, True) + + >>> Or(true, false) + True + >>> _ is true + True + + Python operators give a boolean result for true but a + bitwise result for True + + >>> ~true, ~True # doctest: +SKIP + (False, -2) + >>> true >> true, True >> True + (True, 0) + + See Also + ======== + + sympy.logic.boolalg.BooleanFalse + + """ + def __bool__(self): + return True + + def __hash__(self): + return hash(True) + + def __eq__(self, other): + if other is True: + return True + if other is False: + return False + return super().__eq__(other) + + @property + def negated(self): + return false + + def as_set(self): + """ + Rewrite logic operators and relationals in terms of real sets. + + Examples + ======== + + >>> from sympy import true + >>> true.as_set() + UniversalSet + + """ + return S.UniversalSet + + +class BooleanFalse(BooleanAtom, metaclass=Singleton): + """ + SymPy version of ``False``, a singleton that can be accessed via ``S.false``. + + This is the SymPy version of ``False``, for use in the logic module. The + primary advantage of using ``false`` instead of ``False`` is that shorthand + Boolean operations like ``~`` and ``>>`` will work as expected on this class, + whereas with ``False`` they act bitwise on 0. Functions in the logic module + will return this class when they evaluate to false. + + Notes + ====== + + See the notes section in :py:class:`sympy.logic.boolalg.BooleanTrue` + + Examples + ======== + + >>> from sympy import sympify, true, false, Or + >>> sympify(False) + False + >>> _ is False, _ is false + (False, True) + + >>> Or(true, false) + True + >>> _ is true + True + + Python operators give a boolean result for false but a + bitwise result for False + + >>> ~false, ~False # doctest: +SKIP + (True, -1) + >>> false >> false, False >> False + (True, 0) + + See Also + ======== + + sympy.logic.boolalg.BooleanTrue + + """ + def __bool__(self): + return False + + def __hash__(self): + return hash(False) + + def __eq__(self, other): + if other is True: + return False + if other is False: + return True + return super().__eq__(other) + + @property + def negated(self): + return true + + def as_set(self): + """ + Rewrite logic operators and relationals in terms of real sets. + + Examples + ======== + + >>> from sympy import false + >>> false.as_set() + EmptySet + """ + return S.EmptySet + + +true = BooleanTrue() +false = BooleanFalse() +# We want S.true and S.false to work, rather than S.BooleanTrue and +# S.BooleanFalse, but making the class and instance names the same causes some +# major issues (like the inability to import the class directly from this +# file). +S.true = true +S.false = false + +_sympy_converter[bool] = lambda x: true if x else false + + +class BooleanFunction(Application, Boolean): + """Boolean function is a function that lives in a boolean space + It is used as base class for :py:class:`~.And`, :py:class:`~.Or`, + :py:class:`~.Not`, etc. + """ + is_Boolean = True + + def _eval_simplify(self, **kwargs): + rv = simplify_univariate(self) + if not isinstance(rv, BooleanFunction): + return rv.simplify(**kwargs) + rv = rv.func(*[a.simplify(**kwargs) for a in rv.args]) + return simplify_logic(rv) + + def simplify(self, **kwargs): + from sympy.simplify.simplify import simplify + return simplify(self, **kwargs) + + def __lt__(self, other): + raise TypeError(filldedent(''' + A Boolean argument can only be used in + Eq and Ne; all other relationals expect + real expressions. + ''')) + __le__ = __lt__ + __ge__ = __lt__ + __gt__ = __lt__ + + @classmethod + def binary_check_and_simplify(self, *args): + return [as_Boolean(i) for i in args] + + def to_nnf(self, simplify=True): + return self._to_nnf(*self.args, simplify=simplify) + + def to_anf(self, deep=True): + return self._to_anf(*self.args, deep=deep) + + @classmethod + def _to_nnf(cls, *args, **kwargs): + simplify = kwargs.get('simplify', True) + argset = set() + for arg in args: + if not is_literal(arg): + arg = arg.to_nnf(simplify) + if simplify: + if isinstance(arg, cls): + arg = arg.args + else: + arg = (arg,) + for a in arg: + if Not(a) in argset: + return cls.zero + argset.add(a) + else: + argset.add(arg) + return cls(*argset) + + @classmethod + def _to_anf(cls, *args, **kwargs): + deep = kwargs.get('deep', True) + new_args = [] + for arg in args: + if deep: + if not is_literal(arg) or isinstance(arg, Not): + arg = arg.to_anf(deep=deep) + new_args.append(arg) + return cls(*new_args, remove_true=False) + + # the diff method below is copied from Expr class + def diff(self, *symbols, **assumptions): + assumptions.setdefault("evaluate", True) + return Derivative(self, *symbols, **assumptions) + + def _eval_derivative(self, x): + if x in self.binary_symbols: + from sympy.core.relational import Eq + from sympy.functions.elementary.piecewise import Piecewise + return Piecewise( + (0, Eq(self.subs(x, 0), self.subs(x, 1))), + (1, True)) + elif x in self.free_symbols: + # not implemented, see https://www.encyclopediaofmath.org/ + # index.php/Boolean_differential_calculus + pass + else: + return S.Zero + + +class And(LatticeOp, BooleanFunction): + """ + Logical AND function. + + It evaluates its arguments in order, returning false immediately + when an argument is false and true if they are all true. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import And + >>> x & y + x & y + + Notes + ===== + + The ``&`` operator is provided as a convenience, but note that its use + here is different from its normal use in Python, which is bitwise + and. Hence, ``And(a, b)`` and ``a & b`` will produce different results if + ``a`` and ``b`` are integers. + + >>> And(x, y).subs(x, 1) + y + + """ + zero = false + identity = true + + nargs = None + + if TYPE_CHECKING: + + def __new__(cls, *args: Boolean | bool) -> Boolean: # type: ignore + ... + + @property + def args(self) -> tuple[Boolean, ...]: + ... + + @classmethod + def _new_args_filter(cls, args): + args = BooleanFunction.binary_check_and_simplify(*args) + args = LatticeOp._new_args_filter(args, And) + newargs = [] + rel = set() + for x in ordered(args): + if x.is_Relational: + c = x.canonical + if c in rel: + continue + elif c.negated.canonical in rel: + return [false] + else: + rel.add(c) + newargs.append(x) + return newargs + + def _eval_subs(self, old, new): + args = [] + bad = None + for i in self.args: + try: + i = i.subs(old, new) + except TypeError: + # store TypeError + if bad is None: + bad = i + continue + if i == False: + return false + elif i != True: + args.append(i) + if bad is not None: + # let it raise + bad.subs(old, new) + # If old is And, replace the parts of the arguments with new if all + # are there + if isinstance(old, And): + old_set = set(old.args) + if old_set.issubset(args): + args = set(args) - old_set + args.add(new) + + return self.func(*args) + + def _eval_simplify(self, **kwargs): + from sympy.core.relational import Equality, Relational + from sympy.solvers.solveset import linear_coeffs + # standard simplify + rv = super()._eval_simplify(**kwargs) + if not isinstance(rv, And): + return rv + + # simplify args that are equalities involving + # symbols so x == 0 & x == y -> x==0 & y == 0 + Rel, nonRel = sift(rv.args, lambda i: isinstance(i, Relational), + binary=True) + if not Rel: + return rv + eqs, other = sift(Rel, lambda i: isinstance(i, Equality), binary=True) + + measure = kwargs['measure'] + if eqs: + ratio = kwargs['ratio'] + reps = {} + sifted = {} + # group by length of free symbols + sifted = sift(ordered([ + (i.free_symbols, i) for i in eqs]), + lambda x: len(x[0])) + eqs = [] + nonlineqs = [] + while 1 in sifted: + for free, e in sifted.pop(1): + x = free.pop() + if (e.lhs != x or x in e.rhs.free_symbols) and x not in reps: + try: + m, b = linear_coeffs( + Add(e.lhs, -e.rhs, evaluate=False), x) + enew = e.func(x, -b/m) + if measure(enew) <= ratio*measure(e): + e = enew + else: + eqs.append(e) + continue + except ValueError: + pass + if x in reps: + eqs.append(e.subs(x, reps[x])) + elif e.lhs == x and x not in e.rhs.free_symbols: + reps[x] = e.rhs + eqs.append(e) + else: + # x is not yet identified, but may be later + nonlineqs.append(e) + resifted = defaultdict(list) + for k in sifted: + for f, e in sifted[k]: + e = e.xreplace(reps) + f = e.free_symbols + resifted[len(f)].append((f, e)) + sifted = resifted + for k in sifted: + eqs.extend([e for f, e in sifted[k]]) + nonlineqs = [ei.subs(reps) for ei in nonlineqs] + other = [ei.subs(reps) for ei in other] + rv = rv.func(*([i.canonical for i in (eqs + nonlineqs + other)] + nonRel)) + patterns = _simplify_patterns_and() + threeterm_patterns = _simplify_patterns_and3() + return _apply_patternbased_simplification(rv, patterns, + measure, false, + threeterm_patterns=threeterm_patterns) + + def _eval_as_set(self): + from sympy.sets.sets import Intersection + return Intersection(*[arg.as_set() for arg in self.args]) + + def _eval_rewrite_as_Nor(self, *args, **kwargs): + return Nor(*[Not(arg) for arg in self.args]) + + def to_anf(self, deep=True): + if deep: + result = And._to_anf(*self.args, deep=deep) + return distribute_xor_over_and(result) + return self + + +class Or(LatticeOp, BooleanFunction): + """ + Logical OR function + + It evaluates its arguments in order, returning true immediately + when an argument is true, and false if they are all false. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import Or + >>> x | y + x | y + + Notes + ===== + + The ``|`` operator is provided as a convenience, but note that its use + here is different from its normal use in Python, which is bitwise + or. Hence, ``Or(a, b)`` and ``a | b`` will return different things if + ``a`` and ``b`` are integers. + + >>> Or(x, y).subs(x, 0) + y + + """ + zero = true + identity = false + + if TYPE_CHECKING: + + def __new__(cls, *args: Boolean | bool) -> Boolean: # type: ignore + ... + + @property + def args(self) -> tuple[Boolean, ...]: + ... + + @classmethod + def _new_args_filter(cls, args): + newargs = [] + rel = [] + args = BooleanFunction.binary_check_and_simplify(*args) + for x in args: + if x.is_Relational: + c = x.canonical + if c in rel: + continue + nc = c.negated.canonical + if any(r == nc for r in rel): + return [true] + rel.append(c) + newargs.append(x) + return LatticeOp._new_args_filter(newargs, Or) + + def _eval_subs(self, old, new): + args = [] + bad = None + for i in self.args: + try: + i = i.subs(old, new) + except TypeError: + # store TypeError + if bad is None: + bad = i + continue + if i == True: + return true + elif i != False: + args.append(i) + if bad is not None: + # let it raise + bad.subs(old, new) + # If old is Or, replace the parts of the arguments with new if all + # are there + if isinstance(old, Or): + old_set = set(old.args) + if old_set.issubset(args): + args = set(args) - old_set + args.add(new) + + return self.func(*args) + + def _eval_as_set(self): + from sympy.sets.sets import Union + return Union(*[arg.as_set() for arg in self.args]) + + def _eval_rewrite_as_Nand(self, *args, **kwargs): + return Nand(*[Not(arg) for arg in self.args]) + + def _eval_simplify(self, **kwargs): + from sympy.core.relational import Le, Ge, Eq + lege = self.atoms(Le, Ge) + if lege: + reps = {i: self.func( + Eq(i.lhs, i.rhs), i.strict) for i in lege} + return self.xreplace(reps)._eval_simplify(**kwargs) + # standard simplify + rv = super()._eval_simplify(**kwargs) + if not isinstance(rv, Or): + return rv + patterns = _simplify_patterns_or() + return _apply_patternbased_simplification(rv, patterns, + kwargs['measure'], true) + + def to_anf(self, deep=True): + args = range(1, len(self.args) + 1) + args = (combinations(self.args, j) for j in args) + args = chain.from_iterable(args) # powerset + args = (And(*arg) for arg in args) + args = (to_anf(x, deep=deep) if deep else x for x in args) + return Xor(*list(args), remove_true=False) + + +class Not(BooleanFunction): + """ + Logical Not function (negation) + + + Returns ``true`` if the statement is ``false`` or ``False``. + Returns ``false`` if the statement is ``true`` or ``True``. + + Examples + ======== + + >>> from sympy import Not, And, Or + >>> from sympy.abc import x, A, B + >>> Not(True) + False + >>> Not(False) + True + >>> Not(And(True, False)) + True + >>> Not(Or(True, False)) + False + >>> Not(And(And(True, x), Or(x, False))) + ~x + >>> ~x + ~x + >>> Not(And(Or(A, B), Or(~A, ~B))) + ~((A | B) & (~A | ~B)) + + Notes + ===== + + - The ``~`` operator is provided as a convenience, but note that its use + here is different from its normal use in Python, which is bitwise + not. In particular, ``~a`` and ``Not(a)`` will be different if ``a`` is + an integer. Furthermore, since bools in Python subclass from ``int``, + ``~True`` is the same as ``~1`` which is ``-2``, which has a boolean + value of True. To avoid this issue, use the SymPy boolean types + ``true`` and ``false``. + + - As of Python 3.12, the bitwise not operator ``~`` used on a + Python ``bool`` is deprecated and will emit a warning. + + >>> from sympy import true + >>> ~True # doctest: +SKIP + -2 + >>> ~true + False + + """ + + is_Not = True + + @classmethod + def eval(cls, arg): + if isinstance(arg, Number) or arg in (True, False): + return false if arg else true + if arg.is_Not: + return arg.args[0] + # Simplify Relational objects. + if arg.is_Relational: + return arg.negated + + def _eval_as_set(self): + """ + Rewrite logic operators and relationals in terms of real sets. + + Examples + ======== + + >>> from sympy import Not, Symbol + >>> x = Symbol('x') + >>> Not(x > 0).as_set() + Interval(-oo, 0) + """ + return self.args[0].as_set().complement(S.Reals) + + def to_nnf(self, simplify=True): + if is_literal(self): + return self + + expr = self.args[0] + + func, args = expr.func, expr.args + + if func == And: + return Or._to_nnf(*[Not(arg) for arg in args], simplify=simplify) + + if func == Or: + return And._to_nnf(*[Not(arg) for arg in args], simplify=simplify) + + if func == Implies: + a, b = args + return And._to_nnf(a, Not(b), simplify=simplify) + + if func == Equivalent: + return And._to_nnf(Or(*args), Or(*[Not(arg) for arg in args]), + simplify=simplify) + + if func == Xor: + result = [] + for i in range(1, len(args)+1, 2): + for neg in combinations(args, i): + clause = [Not(s) if s in neg else s for s in args] + result.append(Or(*clause)) + return And._to_nnf(*result, simplify=simplify) + + if func == ITE: + a, b, c = args + return And._to_nnf(Or(a, Not(c)), Or(Not(a), Not(b)), simplify=simplify) + + raise ValueError("Illegal operator %s in expression" % func) + + def to_anf(self, deep=True): + return Xor._to_anf(true, self.args[0], deep=deep) + + +class Xor(BooleanFunction): + """ + Logical XOR (exclusive OR) function. + + + Returns True if an odd number of the arguments are True and the rest are + False. + + Returns False if an even number of the arguments are True and the rest are + False. + + Examples + ======== + + >>> from sympy.logic.boolalg import Xor + >>> from sympy import symbols + >>> x, y = symbols('x y') + >>> Xor(True, False) + True + >>> Xor(True, True) + False + >>> Xor(True, False, True, True, False) + True + >>> Xor(True, False, True, False) + False + >>> x ^ y + x ^ y + + Notes + ===== + + The ``^`` operator is provided as a convenience, but note that its use + here is different from its normal use in Python, which is bitwise xor. In + particular, ``a ^ b`` and ``Xor(a, b)`` will be different if ``a`` and + ``b`` are integers. + + >>> Xor(x, y).subs(y, 0) + x + + """ + def __new__(cls, *args, remove_true=True, **kwargs): + argset = set() + obj = super().__new__(cls, *args, **kwargs) + for arg in obj._args: + if isinstance(arg, Number) or arg in (True, False): + if arg: + arg = true + else: + continue + if isinstance(arg, Xor): + for a in arg.args: + argset.remove(a) if a in argset else argset.add(a) + elif arg in argset: + argset.remove(arg) + else: + argset.add(arg) + rel = [(r, r.canonical, r.negated.canonical) + for r in argset if r.is_Relational] + odd = False # is number of complimentary pairs odd? start 0 -> False + remove = [] + for i, (r, c, nc) in enumerate(rel): + for j in range(i + 1, len(rel)): + rj, cj = rel[j][:2] + if cj == nc: + odd = not odd + break + elif cj == c: + break + else: + continue + remove.append((r, rj)) + if odd: + argset.remove(true) if true in argset else argset.add(true) + for a, b in remove: + argset.remove(a) + argset.remove(b) + if len(argset) == 0: + return false + elif len(argset) == 1: + return argset.pop() + elif True in argset and remove_true: + argset.remove(True) + return Not(Xor(*argset)) + else: + obj._args = tuple(ordered(argset)) + obj._argset = frozenset(argset) + return obj + + # XXX: This should be cached on the object rather than using cacheit + # Maybe it can be computed in __new__? + @property # type: ignore + @cacheit + def args(self): + return tuple(ordered(self._argset)) + + def to_nnf(self, simplify=True): + args = [] + for i in range(0, len(self.args)+1, 2): + for neg in combinations(self.args, i): + clause = [Not(s) if s in neg else s for s in self.args] + args.append(Or(*clause)) + return And._to_nnf(*args, simplify=simplify) + + def _eval_rewrite_as_Or(self, *args, **kwargs): + a = self.args + return Or(*[_convert_to_varsSOP(x, self.args) + for x in _get_odd_parity_terms(len(a))]) + + def _eval_rewrite_as_And(self, *args, **kwargs): + a = self.args + return And(*[_convert_to_varsPOS(x, self.args) + for x in _get_even_parity_terms(len(a))]) + + def _eval_simplify(self, **kwargs): + # as standard simplify uses simplify_logic which writes things as + # And and Or, we only simplify the partial expressions before using + # patterns + rv = self.func(*[a.simplify(**kwargs) for a in self.args]) + rv = rv.to_anf() + if not isinstance(rv, Xor): # This shouldn't really happen here + return rv + patterns = _simplify_patterns_xor() + return _apply_patternbased_simplification(rv, patterns, + kwargs['measure'], None) + + def _eval_subs(self, old, new): + # If old is Xor, replace the parts of the arguments with new if all + # are there + if isinstance(old, Xor): + old_set = set(old.args) + if old_set.issubset(self.args): + args = set(self.args) - old_set + args.add(new) + return self.func(*args) + + +class Nand(BooleanFunction): + """ + Logical NAND function. + + It evaluates its arguments in order, giving True immediately if any + of them are False, and False if they are all True. + + Returns True if any of the arguments are False + Returns False if all arguments are True + + Examples + ======== + + >>> from sympy.logic.boolalg import Nand + >>> from sympy import symbols + >>> x, y = symbols('x y') + >>> Nand(False, True) + True + >>> Nand(True, True) + False + >>> Nand(x, y) + ~(x & y) + + """ + @classmethod + def eval(cls, *args): + return Not(And(*args)) + + +class Nor(BooleanFunction): + """ + Logical NOR function. + + It evaluates its arguments in order, giving False immediately if any + of them are True, and True if they are all False. + + Returns False if any argument is True + Returns True if all arguments are False + + Examples + ======== + + >>> from sympy.logic.boolalg import Nor + >>> from sympy import symbols + >>> x, y = symbols('x y') + + >>> Nor(True, False) + False + >>> Nor(True, True) + False + >>> Nor(False, True) + False + >>> Nor(False, False) + True + >>> Nor(x, y) + ~(x | y) + + """ + @classmethod + def eval(cls, *args): + return Not(Or(*args)) + + +class Xnor(BooleanFunction): + """ + Logical XNOR function. + + Returns False if an odd number of the arguments are True and the rest are + False. + + Returns True if an even number of the arguments are True and the rest are + False. + + Examples + ======== + + >>> from sympy.logic.boolalg import Xnor + >>> from sympy import symbols + >>> x, y = symbols('x y') + >>> Xnor(True, False) + False + >>> Xnor(True, True) + True + >>> Xnor(True, False, True, True, False) + False + >>> Xnor(True, False, True, False) + True + + """ + @classmethod + def eval(cls, *args): + return Not(Xor(*args)) + + +class Implies(BooleanFunction): + r""" + Logical implication. + + A implies B is equivalent to if A then B. Mathematically, it is written + as `A \Rightarrow B` and is equivalent to `\neg A \vee B` or ``~A | B``. + + Accepts two Boolean arguments; A and B. + Returns False if A is True and B is False + Returns True otherwise. + + Examples + ======== + + >>> from sympy.logic.boolalg import Implies + >>> from sympy import symbols + >>> x, y = symbols('x y') + + >>> Implies(True, False) + False + >>> Implies(False, False) + True + >>> Implies(True, True) + True + >>> Implies(False, True) + True + >>> x >> y + Implies(x, y) + >>> y << x + Implies(x, y) + + Notes + ===== + + The ``>>`` and ``<<`` operators are provided as a convenience, but note + that their use here is different from their normal use in Python, which is + bit shifts. Hence, ``Implies(a, b)`` and ``a >> b`` will return different + things if ``a`` and ``b`` are integers. In particular, since Python + considers ``True`` and ``False`` to be integers, ``True >> True`` will be + the same as ``1 >> 1``, i.e., 0, which has a truth value of False. To + avoid this issue, use the SymPy objects ``true`` and ``false``. + + >>> from sympy import true, false + >>> True >> False + 1 + >>> true >> false + False + + """ + @classmethod + def eval(cls, *args): + try: + newargs = [] + for x in args: + if isinstance(x, Number) or x in (0, 1): + newargs.append(bool(x)) + else: + newargs.append(x) + A, B = newargs + except ValueError: + raise ValueError( + "%d operand(s) used for an Implies " + "(pairs are required): %s" % (len(args), str(args))) + if A in (True, False) or B in (True, False): + return Or(Not(A), B) + elif A == B: + return true + elif A.is_Relational and B.is_Relational: + if A.canonical == B.canonical: + return true + if A.negated.canonical == B.canonical: + return B + else: + return Basic.__new__(cls, *args) + + def to_nnf(self, simplify=True): + a, b = self.args + return Or._to_nnf(Not(a), b, simplify=simplify) + + def to_anf(self, deep=True): + a, b = self.args + return Xor._to_anf(true, a, And(a, b), deep=deep) + + +class Equivalent(BooleanFunction): + """ + Equivalence relation. + + ``Equivalent(A, B)`` is True iff A and B are both True or both False. + + Returns True if all of the arguments are logically equivalent. + Returns False otherwise. + + For two arguments, this is equivalent to :py:class:`~.Xnor`. + + Examples + ======== + + >>> from sympy.logic.boolalg import Equivalent, And + >>> from sympy.abc import x + >>> Equivalent(False, False, False) + True + >>> Equivalent(True, False, False) + False + >>> Equivalent(x, And(x, True)) + True + + """ + def __new__(cls, *args, **options): + from sympy.core.relational import Relational + args = [_sympify(arg) for arg in args] + + argset = set(args) + for x in args: + if isinstance(x, Number) or x in [True, False]: # Includes 0, 1 + argset.discard(x) + argset.add(bool(x)) + rel = [] + for r in argset: + if isinstance(r, Relational): + rel.append((r, r.canonical, r.negated.canonical)) + remove = [] + for i, (r, c, nc) in enumerate(rel): + for j in range(i + 1, len(rel)): + rj, cj = rel[j][:2] + if cj == nc: + return false + elif cj == c: + remove.append((r, rj)) + break + for a, b in remove: + argset.remove(a) + argset.remove(b) + argset.add(True) + if len(argset) <= 1: + return true + if True in argset: + argset.discard(True) + return And(*argset) + if False in argset: + argset.discard(False) + return And(*[Not(arg) for arg in argset]) + _args = frozenset(argset) + obj = super().__new__(cls, _args) + obj._argset = _args + return obj + + # XXX: This should be cached on the object rather than using cacheit + # Maybe it can be computed in __new__? + @property # type: ignore + @cacheit + def args(self): + return tuple(ordered(self._argset)) + + def to_nnf(self, simplify=True): + args = [] + for a, b in zip(self.args, self.args[1:]): + args.append(Or(Not(a), b)) + args.append(Or(Not(self.args[-1]), self.args[0])) + return And._to_nnf(*args, simplify=simplify) + + def to_anf(self, deep=True): + a = And(*self.args) + b = And(*[to_anf(Not(arg), deep=False) for arg in self.args]) + b = distribute_xor_over_and(b) + return Xor._to_anf(a, b, deep=deep) + + +class ITE(BooleanFunction): + """ + If-then-else clause. + + ``ITE(A, B, C)`` evaluates and returns the result of B if A is true + else it returns the result of C. All args must be Booleans. + + From a logic gate perspective, ITE corresponds to a 2-to-1 multiplexer, + where A is the select signal. + + Examples + ======== + + >>> from sympy.logic.boolalg import ITE, And, Xor, Or + >>> from sympy.abc import x, y, z + >>> ITE(True, False, True) + False + >>> ITE(Or(True, False), And(True, True), Xor(True, True)) + True + >>> ITE(x, y, z) + ITE(x, y, z) + >>> ITE(True, x, y) + x + >>> ITE(False, x, y) + y + >>> ITE(x, y, y) + y + + Trying to use non-Boolean args will generate a TypeError: + + >>> ITE(True, [], ()) + Traceback (most recent call last): + ... + TypeError: expecting bool, Boolean or ITE, not `[]` + + """ + def __new__(cls, *args, **kwargs): + from sympy.core.relational import Eq, Ne + if len(args) != 3: + raise ValueError('expecting exactly 3 args') + a, b, c = args + # check use of binary symbols + if isinstance(a, (Eq, Ne)): + # in this context, we can evaluate the Eq/Ne + # if one arg is a binary symbol and the other + # is true/false + b, c = map(as_Boolean, (b, c)) + bin_syms = set().union(*[i.binary_symbols for i in (b, c)]) + if len(set(a.args) - bin_syms) == 1: + # one arg is a binary_symbols + _a = a + if a.lhs is true: + a = a.rhs + elif a.rhs is true: + a = a.lhs + elif a.lhs is false: + a = Not(a.rhs) + elif a.rhs is false: + a = Not(a.lhs) + else: + # binary can only equal True or False + a = false + if isinstance(_a, Ne): + a = Not(a) + else: + a, b, c = BooleanFunction.binary_check_and_simplify( + a, b, c) + rv = None + if kwargs.get('evaluate', True): + rv = cls.eval(a, b, c) + if rv is None: + rv = BooleanFunction.__new__(cls, a, b, c, evaluate=False) + return rv + + @classmethod + def eval(cls, *args): + from sympy.core.relational import Eq, Ne + # do the args give a singular result? + a, b, c = args + if isinstance(a, (Ne, Eq)): + _a = a + if true in a.args: + a = a.lhs if a.rhs is true else a.rhs + elif false in a.args: + a = Not(a.lhs) if a.rhs is false else Not(a.rhs) + else: + _a = None + if _a is not None and isinstance(_a, Ne): + a = Not(a) + if a is true: + return b + if a is false: + return c + if b == c: + return b + else: + # or maybe the results allow the answer to be expressed + # in terms of the condition + if b is true and c is false: + return a + if b is false and c is true: + return Not(a) + if [a, b, c] != args: + return cls(a, b, c, evaluate=False) + + def to_nnf(self, simplify=True): + a, b, c = self.args + return And._to_nnf(Or(Not(a), b), Or(a, c), simplify=simplify) + + def _eval_as_set(self): + return self.to_nnf().as_set() + + def _eval_rewrite_as_Piecewise(self, *args, **kwargs): + from sympy.functions.elementary.piecewise import Piecewise + return Piecewise((args[1], args[0]), (args[2], True)) + + +class Exclusive(BooleanFunction): + """ + True if only one or no argument is true. + + ``Exclusive(A, B, C)`` is equivalent to ``~(A & B) & ~(A & C) & ~(B & C)``. + + For two arguments, this is equivalent to :py:class:`~.Xor`. + + Examples + ======== + + >>> from sympy.logic.boolalg import Exclusive + >>> Exclusive(False, False, False) + True + >>> Exclusive(False, True, False) + True + >>> Exclusive(False, True, True) + False + + """ + @classmethod + def eval(cls, *args): + and_args = [] + for a, b in combinations(args, 2): + and_args.append(Not(And(a, b))) + return And(*and_args) + + +# end class definitions. Some useful methods + + +def conjuncts(expr): + """Return a list of the conjuncts in ``expr``. + + Examples + ======== + + >>> from sympy.logic.boolalg import conjuncts + >>> from sympy.abc import A, B + >>> conjuncts(A & B) + frozenset({A, B}) + >>> conjuncts(A | B) + frozenset({A | B}) + + """ + return And.make_args(expr) + + +def disjuncts(expr): + """Return a list of the disjuncts in ``expr``. + + Examples + ======== + + >>> from sympy.logic.boolalg import disjuncts + >>> from sympy.abc import A, B + >>> disjuncts(A | B) + frozenset({A, B}) + >>> disjuncts(A & B) + frozenset({A & B}) + + """ + return Or.make_args(expr) + + +def distribute_and_over_or(expr): + """ + Given a sentence ``expr`` consisting of conjunctions and disjunctions + of literals, return an equivalent sentence in CNF. + + Examples + ======== + + >>> from sympy.logic.boolalg import distribute_and_over_or, And, Or, Not + >>> from sympy.abc import A, B, C + >>> distribute_and_over_or(Or(A, And(Not(B), Not(C)))) + (A | ~B) & (A | ~C) + + """ + return _distribute((expr, And, Or)) + + +def distribute_or_over_and(expr): + """ + Given a sentence ``expr`` consisting of conjunctions and disjunctions + of literals, return an equivalent sentence in DNF. + + Note that the output is NOT simplified. + + Examples + ======== + + >>> from sympy.logic.boolalg import distribute_or_over_and, And, Or, Not + >>> from sympy.abc import A, B, C + >>> distribute_or_over_and(And(Or(Not(A), B), C)) + (B & C) | (C & ~A) + + """ + return _distribute((expr, Or, And)) + + +def distribute_xor_over_and(expr): + """ + Given a sentence ``expr`` consisting of conjunction and + exclusive disjunctions of literals, return an + equivalent exclusive disjunction. + + Note that the output is NOT simplified. + + Examples + ======== + + >>> from sympy.logic.boolalg import distribute_xor_over_and, And, Xor, Not + >>> from sympy.abc import A, B, C + >>> distribute_xor_over_and(And(Xor(Not(A), B), C)) + (B & C) ^ (C & ~A) + """ + return _distribute((expr, Xor, And)) + + +def _distribute(info): + """ + Distributes ``info[1]`` over ``info[2]`` with respect to ``info[0]``. + """ + if isinstance(info[0], info[2]): + for arg in info[0].args: + if isinstance(arg, info[1]): + conj = arg + break + else: + return info[0] + rest = info[2](*[a for a in info[0].args if a is not conj]) + return info[1](*list(map(_distribute, + [(info[2](c, rest), info[1], info[2]) + for c in conj.args])), remove_true=False) + elif isinstance(info[0], info[1]): + return info[1](*list(map(_distribute, + [(x, info[1], info[2]) + for x in info[0].args])), + remove_true=False) + else: + return info[0] + + +def to_anf(expr, deep=True): + r""" + Converts expr to Algebraic Normal Form (ANF). + + ANF is a canonical normal form, which means that two + equivalent formulas will convert to the same ANF. + + A logical expression is in ANF if it has the form + + .. math:: 1 \oplus a \oplus b \oplus ab \oplus abc + + i.e. it can be: + - purely true, + - purely false, + - conjunction of variables, + - exclusive disjunction. + + The exclusive disjunction can only contain true, variables + or conjunction of variables. No negations are permitted. + + If ``deep`` is ``False``, arguments of the boolean + expression are considered variables, i.e. only the + top-level expression is converted to ANF. + + Examples + ======== + >>> from sympy.logic.boolalg import And, Or, Not, Implies, Equivalent + >>> from sympy.logic.boolalg import to_anf + >>> from sympy.abc import A, B, C + >>> to_anf(Not(A)) + A ^ True + >>> to_anf(And(Or(A, B), Not(C))) + A ^ B ^ (A & B) ^ (A & C) ^ (B & C) ^ (A & B & C) + >>> to_anf(Implies(Not(A), Equivalent(B, C)), deep=False) + True ^ ~A ^ (~A & (Equivalent(B, C))) + + """ + expr = sympify(expr) + + if is_anf(expr): + return expr + return expr.to_anf(deep=deep) + + +def to_nnf(expr, simplify=True): + """ + Converts ``expr`` to Negation Normal Form (NNF). + + A logical expression is in NNF if it + contains only :py:class:`~.And`, :py:class:`~.Or` and :py:class:`~.Not`, + and :py:class:`~.Not` is applied only to literals. + If ``simplify`` is ``True``, the result contains no redundant clauses. + + Examples + ======== + + >>> from sympy.abc import A, B, C, D + >>> from sympy.logic.boolalg import Not, Equivalent, to_nnf + >>> to_nnf(Not((~A & ~B) | (C & D))) + (A | B) & (~C | ~D) + >>> to_nnf(Equivalent(A >> B, B >> A)) + (A | ~B | (A & ~B)) & (B | ~A | (B & ~A)) + + """ + if is_nnf(expr, simplify): + return expr + return expr.to_nnf(simplify) + + +def to_cnf(expr, simplify=False, force=False): + """ + Convert a propositional logical sentence ``expr`` to conjunctive normal + form: ``((A | ~B | ...) & (B | C | ...) & ...)``. + If ``simplify`` is ``True``, ``expr`` is evaluated to its simplest CNF + form using the Quine-McCluskey algorithm; this may take a long + time. If there are more than 8 variables the ``force`` flag must be set + to ``True`` to simplify (default is ``False``). + + Examples + ======== + + >>> from sympy.logic.boolalg import to_cnf + >>> from sympy.abc import A, B, D + >>> to_cnf(~(A | B) | D) + (D | ~A) & (D | ~B) + >>> to_cnf((A | B) & (A | ~A), True) + A | B + + """ + expr = sympify(expr) + if not isinstance(expr, BooleanFunction): + return expr + + if simplify: + if not force and len(_find_predicates(expr)) > 8: + raise ValueError(filldedent(''' + To simplify a logical expression with more + than 8 variables may take a long time and requires + the use of `force=True`.''')) + return simplify_logic(expr, 'cnf', True, force=force) + + # Don't convert unless we have to + if is_cnf(expr): + return expr + + expr = eliminate_implications(expr) + res = distribute_and_over_or(expr) + + return res + + +def to_dnf(expr, simplify=False, force=False): + """ + Convert a propositional logical sentence ``expr`` to disjunctive normal + form: ``((A & ~B & ...) | (B & C & ...) | ...)``. + If ``simplify`` is ``True``, ``expr`` is evaluated to its simplest DNF form using + the Quine-McCluskey algorithm; this may take a long + time. If there are more than 8 variables, the ``force`` flag must be set to + ``True`` to simplify (default is ``False``). + + Examples + ======== + + >>> from sympy.logic.boolalg import to_dnf + >>> from sympy.abc import A, B, C + >>> to_dnf(B & (A | C)) + (A & B) | (B & C) + >>> to_dnf((A & B) | (A & ~B) | (B & C) | (~B & C), True) + A | C + + """ + expr = sympify(expr) + if not isinstance(expr, BooleanFunction): + return expr + + if simplify: + if not force and len(_find_predicates(expr)) > 8: + raise ValueError(filldedent(''' + To simplify a logical expression with more + than 8 variables may take a long time and requires + the use of `force=True`.''')) + return simplify_logic(expr, 'dnf', True, force=force) + + # Don't convert unless we have to + if is_dnf(expr): + return expr + + expr = eliminate_implications(expr) + return distribute_or_over_and(expr) + + +def is_anf(expr): + r""" + Checks if ``expr`` is in Algebraic Normal Form (ANF). + + A logical expression is in ANF if it has the form + + .. math:: 1 \oplus a \oplus b \oplus ab \oplus abc + + i.e. it is purely true, purely false, conjunction of + variables or exclusive disjunction. The exclusive + disjunction can only contain true, variables or + conjunction of variables. No negations are permitted. + + Examples + ======== + + >>> from sympy.logic.boolalg import And, Not, Xor, true, is_anf + >>> from sympy.abc import A, B, C + >>> is_anf(true) + True + >>> is_anf(A) + True + >>> is_anf(And(A, B, C)) + True + >>> is_anf(Xor(A, Not(B))) + False + + """ + expr = sympify(expr) + + if is_literal(expr) and not isinstance(expr, Not): + return True + + if isinstance(expr, And): + for arg in expr.args: + if not arg.is_Symbol: + return False + return True + + elif isinstance(expr, Xor): + for arg in expr.args: + if isinstance(arg, And): + for a in arg.args: + if not a.is_Symbol: + return False + elif is_literal(arg): + if isinstance(arg, Not): + return False + else: + return False + return True + + else: + return False + + +def is_nnf(expr, simplified=True): + """ + Checks if ``expr`` is in Negation Normal Form (NNF). + + A logical expression is in NNF if it + contains only :py:class:`~.And`, :py:class:`~.Or` and :py:class:`~.Not`, + and :py:class:`~.Not` is applied only to literals. + If ``simplified`` is ``True``, checks if result contains no redundant clauses. + + Examples + ======== + + >>> from sympy.abc import A, B, C + >>> from sympy.logic.boolalg import Not, is_nnf + >>> is_nnf(A & B | ~C) + True + >>> is_nnf((A | ~A) & (B | C)) + False + >>> is_nnf((A | ~A) & (B | C), False) + True + >>> is_nnf(Not(A & B) | C) + False + >>> is_nnf((A >> B) & (B >> A)) + False + + """ + + expr = sympify(expr) + if is_literal(expr): + return True + + stack = [expr] + + while stack: + expr = stack.pop() + if expr.func in (And, Or): + if simplified: + args = expr.args + for arg in args: + if Not(arg) in args: + return False + stack.extend(expr.args) + + elif not is_literal(expr): + return False + + return True + + +def is_cnf(expr): + """ + Test whether or not an expression is in conjunctive normal form. + + Examples + ======== + + >>> from sympy.logic.boolalg import is_cnf + >>> from sympy.abc import A, B, C + >>> is_cnf(A | B | C) + True + >>> is_cnf(A & B & C) + True + >>> is_cnf((A & B) | C) + False + + """ + return _is_form(expr, And, Or) + + +def is_dnf(expr): + """ + Test whether or not an expression is in disjunctive normal form. + + Examples + ======== + + >>> from sympy.logic.boolalg import is_dnf + >>> from sympy.abc import A, B, C + >>> is_dnf(A | B | C) + True + >>> is_dnf(A & B & C) + True + >>> is_dnf((A & B) | C) + True + >>> is_dnf(A & (B | C)) + False + + """ + return _is_form(expr, Or, And) + + +def _is_form(expr, function1, function2): + """ + Test whether or not an expression is of the required form. + + """ + expr = sympify(expr) + + vals = function1.make_args(expr) if isinstance(expr, function1) else [expr] + for lit in vals: + if isinstance(lit, function2): + vals2 = function2.make_args(lit) if isinstance(lit, function2) else [lit] + for l in vals2: + if is_literal(l) is False: + return False + elif is_literal(lit) is False: + return False + + return True + + +def eliminate_implications(expr): + """ + Change :py:class:`~.Implies` and :py:class:`~.Equivalent` into + :py:class:`~.And`, :py:class:`~.Or`, and :py:class:`~.Not`. + That is, return an expression that is equivalent to ``expr``, but has only + ``&``, ``|``, and ``~`` as logical + operators. + + Examples + ======== + + >>> from sympy.logic.boolalg import Implies, Equivalent, \ + eliminate_implications + >>> from sympy.abc import A, B, C + >>> eliminate_implications(Implies(A, B)) + B | ~A + >>> eliminate_implications(Equivalent(A, B)) + (A | ~B) & (B | ~A) + >>> eliminate_implications(Equivalent(A, B, C)) + (A | ~C) & (B | ~A) & (C | ~B) + + """ + return to_nnf(expr, simplify=False) + + +def is_literal(expr): + """ + Returns True if expr is a literal, else False. + + Examples + ======== + + >>> from sympy import Or, Q + >>> from sympy.abc import A, B + >>> from sympy.logic.boolalg import is_literal + >>> is_literal(A) + True + >>> is_literal(~A) + True + >>> is_literal(Q.zero(A)) + True + >>> is_literal(A + B) + True + >>> is_literal(Or(A, B)) + False + + """ + from sympy.assumptions import AppliedPredicate + + if isinstance(expr, Not): + return is_literal(expr.args[0]) + elif expr in (True, False) or isinstance(expr, AppliedPredicate) or expr.is_Atom: + return True + elif not isinstance(expr, BooleanFunction) and all( + (isinstance(expr, AppliedPredicate) or a.is_Atom) for a in expr.args): + return True + return False + + +def to_int_repr(clauses, symbols): + """ + Takes clauses in CNF format and puts them into an integer representation. + + Examples + ======== + + >>> from sympy.logic.boolalg import to_int_repr + >>> from sympy.abc import x, y + >>> to_int_repr([x | y, y], [x, y]) == [{1, 2}, {2}] + True + + """ + + # Convert the symbol list into a dict + symbols = dict(zip(symbols, range(1, len(symbols) + 1))) + + def append_symbol(arg, symbols): + if isinstance(arg, Not): + return -symbols[arg.args[0]] + else: + return symbols[arg] + + return [{append_symbol(arg, symbols) for arg in Or.make_args(c)} + for c in clauses] + + +def term_to_integer(term): + """ + Return an integer corresponding to the base-2 digits given by *term*. + + Parameters + ========== + + term : a string or list of ones and zeros + + Examples + ======== + + >>> from sympy.logic.boolalg import term_to_integer + >>> term_to_integer([1, 0, 0]) + 4 + >>> term_to_integer('100') + 4 + + """ + + return int(''.join(list(map(str, list(term)))), 2) + + +integer_to_term = ibin # XXX could delete? + + +def truth_table(expr, variables, input=True): + """ + Return a generator of all possible configurations of the input variables, + and the result of the boolean expression for those values. + + Parameters + ========== + + expr : Boolean expression + + variables : list of variables + + input : bool (default ``True``) + Indicates whether to return the input combinations. + + Examples + ======== + + >>> from sympy.logic.boolalg import truth_table + >>> from sympy.abc import x,y + >>> table = truth_table(x >> y, [x, y]) + >>> for t in table: + ... print('{0} -> {1}'.format(*t)) + [0, 0] -> True + [0, 1] -> True + [1, 0] -> False + [1, 1] -> True + + >>> table = truth_table(x | y, [x, y]) + >>> list(table) + [([0, 0], False), ([0, 1], True), ([1, 0], True), ([1, 1], True)] + + If ``input`` is ``False``, ``truth_table`` returns only a list of truth values. + In this case, the corresponding input values of variables can be + deduced from the index of a given output. + + >>> from sympy.utilities.iterables import ibin + >>> vars = [y, x] + >>> values = truth_table(x >> y, vars, input=False) + >>> values = list(values) + >>> values + [True, False, True, True] + + >>> for i, value in enumerate(values): + ... print('{0} -> {1}'.format(list(zip( + ... vars, ibin(i, len(vars)))), value)) + [(y, 0), (x, 0)] -> True + [(y, 0), (x, 1)] -> False + [(y, 1), (x, 0)] -> True + [(y, 1), (x, 1)] -> True + + """ + variables = [sympify(v) for v in variables] + + expr = sympify(expr) + if not isinstance(expr, BooleanFunction) and not is_literal(expr): + return + + table = product((0, 1), repeat=len(variables)) + for term in table: + value = expr.xreplace(dict(zip(variables, term))) + + if input: + yield list(term), value + else: + yield value + + +def _check_pair(minterm1, minterm2): + """ + Checks if a pair of minterms differs by only one bit. If yes, returns + index, else returns `-1`. + """ + # Early termination seems to be faster than list comprehension, + # at least for large examples. + index = -1 + for x, i in enumerate(minterm1): # zip(minterm1, minterm2) is slower + if i != minterm2[x]: + if index == -1: + index = x + else: + return -1 + return index + + +def _convert_to_varsSOP(minterm, variables): + """ + Converts a term in the expansion of a function from binary to its + variable form (for SOP). + """ + temp = [variables[n] if val == 1 else Not(variables[n]) + for n, val in enumerate(minterm) if val != 3] + return And(*temp) + + +def _convert_to_varsPOS(maxterm, variables): + """ + Converts a term in the expansion of a function from binary to its + variable form (for POS). + """ + temp = [variables[n] if val == 0 else Not(variables[n]) + for n, val in enumerate(maxterm) if val != 3] + return Or(*temp) + + +def _convert_to_varsANF(term, variables): + """ + Converts a term in the expansion of a function from binary to its + variable form (for ANF). + + Parameters + ========== + + term : list of 1's and 0's (complementation pattern) + variables : list of variables + + """ + temp = [variables[n] for n, t in enumerate(term) if t == 1] + + if not temp: + return true + + return And(*temp) + + +def _get_odd_parity_terms(n): + """ + Returns a list of lists, with all possible combinations of n zeros and ones + with an odd number of ones. + """ + return [e for e in [ibin(i, n) for i in range(2**n)] if sum(e) % 2 == 1] + + +def _get_even_parity_terms(n): + """ + Returns a list of lists, with all possible combinations of n zeros and ones + with an even number of ones. + """ + return [e for e in [ibin(i, n) for i in range(2**n)] if sum(e) % 2 == 0] + + +def _simplified_pairs(terms): + """ + Reduces a set of minterms, if possible, to a simplified set of minterms + with one less variable in the terms using QM method. + """ + if not terms: + return [] + + simplified_terms = [] + todo = list(range(len(terms))) + + # Count number of ones as _check_pair can only potentially match if there + # is at most a difference of a single one + termdict = defaultdict(list) + for n, term in enumerate(terms): + ones = sum(1 for t in term if t == 1) + termdict[ones].append(n) + + variables = len(terms[0]) + for k in range(variables): + for i in termdict[k]: + for j in termdict[k+1]: + index = _check_pair(terms[i], terms[j]) + if index != -1: + # Mark terms handled + todo[i] = todo[j] = None + # Copy old term + newterm = terms[i][:] + # Set differing position to don't care + newterm[index] = 3 + # Add if not already there + if newterm not in simplified_terms: + simplified_terms.append(newterm) + + if simplified_terms: + # Further simplifications only among the new terms + simplified_terms = _simplified_pairs(simplified_terms) + + # Add remaining, non-simplified, terms + simplified_terms.extend([terms[i] for i in todo if i is not None]) + return simplified_terms + + +def _rem_redundancy(l1, terms): + """ + After the truth table has been sufficiently simplified, use the prime + implicant table method to recognize and eliminate redundant pairs, + and return the essential arguments. + """ + + if not terms: + return [] + + nterms = len(terms) + nl1 = len(l1) + + # Create dominating matrix + dommatrix = [[0]*nl1 for n in range(nterms)] + colcount = [0]*nl1 + rowcount = [0]*nterms + for primei, prime in enumerate(l1): + for termi, term in enumerate(terms): + # Check prime implicant covering term + if all(t == 3 or t == mt for t, mt in zip(prime, term)): + dommatrix[termi][primei] = 1 + colcount[primei] += 1 + rowcount[termi] += 1 + + # Keep track if anything changed + anythingchanged = True + # Then, go again + while anythingchanged: + anythingchanged = False + + for rowi in range(nterms): + # Still non-dominated? + if rowcount[rowi]: + row = dommatrix[rowi] + for row2i in range(nterms): + # Still non-dominated? + if rowi != row2i and rowcount[rowi] and (rowcount[rowi] <= rowcount[row2i]): + row2 = dommatrix[row2i] + if all(row2[n] >= row[n] for n in range(nl1)): + # row2 dominating row, remove row2 + rowcount[row2i] = 0 + anythingchanged = True + for primei, prime in enumerate(row2): + if prime: + # Make corresponding entry 0 + dommatrix[row2i][primei] = 0 + colcount[primei] -= 1 + + colcache = {} + + for coli in range(nl1): + # Still non-dominated? + if colcount[coli]: + if coli in colcache: + col = colcache[coli] + else: + col = [dommatrix[i][coli] for i in range(nterms)] + colcache[coli] = col + for col2i in range(nl1): + # Still non-dominated? + if coli != col2i and colcount[col2i] and (colcount[coli] >= colcount[col2i]): + if col2i in colcache: + col2 = colcache[col2i] + else: + col2 = [dommatrix[i][col2i] for i in range(nterms)] + colcache[col2i] = col2 + if all(col[n] >= col2[n] for n in range(nterms)): + # col dominating col2, remove col2 + colcount[col2i] = 0 + anythingchanged = True + for termi, term in enumerate(col2): + if term and dommatrix[termi][col2i]: + # Make corresponding entry 0 + dommatrix[termi][col2i] = 0 + rowcount[termi] -= 1 + + if not anythingchanged: + # Heuristically select the prime implicant covering most terms + maxterms = 0 + bestcolidx = -1 + for coli in range(nl1): + s = colcount[coli] + if s > maxterms: + bestcolidx = coli + maxterms = s + + # In case we found a prime implicant covering at least two terms + if bestcolidx != -1 and maxterms > 1: + for primei, prime in enumerate(l1): + if primei != bestcolidx: + for termi, term in enumerate(colcache[bestcolidx]): + if term and dommatrix[termi][primei]: + # Make corresponding entry 0 + dommatrix[termi][primei] = 0 + anythingchanged = True + rowcount[termi] -= 1 + colcount[primei] -= 1 + + return [l1[i] for i in range(nl1) if colcount[i]] + + +def _input_to_binlist(inputlist, variables): + binlist = [] + bits = len(variables) + for val in inputlist: + if isinstance(val, int): + binlist.append(ibin(val, bits)) + elif isinstance(val, dict): + nonspecvars = list(variables) + for key in val.keys(): + nonspecvars.remove(key) + for t in product((0, 1), repeat=len(nonspecvars)): + d = dict(zip(nonspecvars, t)) + d.update(val) + binlist.append([d[v] for v in variables]) + elif isinstance(val, (list, tuple)): + if len(val) != bits: + raise ValueError("Each term must contain {bits} bits as there are" + "\n{bits} variables (or be an integer)." + "".format(bits=bits)) + binlist.append(list(val)) + else: + raise TypeError("A term list can only contain lists," + " ints or dicts.") + return binlist + + +def SOPform(variables, minterms, dontcares=None): + """ + The SOPform function uses simplified_pairs and a redundant group- + eliminating algorithm to convert the list of all input combos that + generate '1' (the minterms) into the smallest sum-of-products form. + + The variables must be given as the first argument. + + Return a logical :py:class:`~.Or` function (i.e., the "sum of products" or + "SOP" form) that gives the desired outcome. If there are inputs that can + be ignored, pass them as a list, too. + + The result will be one of the (perhaps many) functions that satisfy + the conditions. + + Examples + ======== + + >>> from sympy.logic import SOPform + >>> from sympy import symbols + >>> w, x, y, z = symbols('w x y z') + >>> minterms = [[0, 0, 0, 1], [0, 0, 1, 1], + ... [0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1]] + >>> dontcares = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]] + >>> SOPform([w, x, y, z], minterms, dontcares) + (y & z) | (~w & ~x) + + The terms can also be represented as integers: + + >>> minterms = [1, 3, 7, 11, 15] + >>> dontcares = [0, 2, 5] + >>> SOPform([w, x, y, z], minterms, dontcares) + (y & z) | (~w & ~x) + + They can also be specified using dicts, which does not have to be fully + specified: + + >>> minterms = [{w: 0, x: 1}, {y: 1, z: 1, x: 0}] + >>> SOPform([w, x, y, z], minterms) + (x & ~w) | (y & z & ~x) + + Or a combination: + + >>> minterms = [4, 7, 11, [1, 1, 1, 1]] + >>> dontcares = [{w : 0, x : 0, y: 0}, 5] + >>> SOPform([w, x, y, z], minterms, dontcares) + (w & y & z) | (~w & ~y) | (x & z & ~w) + + See also + ======== + + POSform + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Quine-McCluskey_algorithm + .. [2] https://en.wikipedia.org/wiki/Don%27t-care_term + + """ + if not minterms: + return false + + variables = tuple(map(sympify, variables)) + + + minterms = _input_to_binlist(minterms, variables) + dontcares = _input_to_binlist((dontcares or []), variables) + for d in dontcares: + if d in minterms: + raise ValueError('%s in minterms is also in dontcares' % d) + + return _sop_form(variables, minterms, dontcares) + + +def _sop_form(variables, minterms, dontcares): + new = _simplified_pairs(minterms + dontcares) + essential = _rem_redundancy(new, minterms) + return Or(*[_convert_to_varsSOP(x, variables) for x in essential]) + + +def POSform(variables, minterms, dontcares=None): + """ + The POSform function uses simplified_pairs and a redundant-group + eliminating algorithm to convert the list of all input combinations + that generate '1' (the minterms) into the smallest product-of-sums form. + + The variables must be given as the first argument. + + Return a logical :py:class:`~.And` function (i.e., the "product of sums" + or "POS" form) that gives the desired outcome. If there are inputs that can + be ignored, pass them as a list, too. + + The result will be one of the (perhaps many) functions that satisfy + the conditions. + + Examples + ======== + + >>> from sympy.logic import POSform + >>> from sympy import symbols + >>> w, x, y, z = symbols('w x y z') + >>> minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], + ... [1, 0, 1, 1], [1, 1, 1, 1]] + >>> dontcares = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]] + >>> POSform([w, x, y, z], minterms, dontcares) + z & (y | ~w) + + The terms can also be represented as integers: + + >>> minterms = [1, 3, 7, 11, 15] + >>> dontcares = [0, 2, 5] + >>> POSform([w, x, y, z], minterms, dontcares) + z & (y | ~w) + + They can also be specified using dicts, which does not have to be fully + specified: + + >>> minterms = [{w: 0, x: 1}, {y: 1, z: 1, x: 0}] + >>> POSform([w, x, y, z], minterms) + (x | y) & (x | z) & (~w | ~x) + + Or a combination: + + >>> minterms = [4, 7, 11, [1, 1, 1, 1]] + >>> dontcares = [{w : 0, x : 0, y: 0}, 5] + >>> POSform([w, x, y, z], minterms, dontcares) + (w | x) & (y | ~w) & (z | ~y) + + See also + ======== + + SOPform + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Quine-McCluskey_algorithm + .. [2] https://en.wikipedia.org/wiki/Don%27t-care_term + + """ + if not minterms: + return false + + variables = tuple(map(sympify, variables)) + minterms = _input_to_binlist(minterms, variables) + dontcares = _input_to_binlist((dontcares or []), variables) + for d in dontcares: + if d in minterms: + raise ValueError('%s in minterms is also in dontcares' % d) + + maxterms = [] + for t in product((0, 1), repeat=len(variables)): + t = list(t) + if (t not in minterms) and (t not in dontcares): + maxterms.append(t) + + new = _simplified_pairs(maxterms + dontcares) + essential = _rem_redundancy(new, maxterms) + return And(*[_convert_to_varsPOS(x, variables) for x in essential]) + + +def ANFform(variables, truthvalues): + """ + The ANFform function converts the list of truth values to + Algebraic Normal Form (ANF). + + The variables must be given as the first argument. + + Return True, False, logical :py:class:`~.And` function (i.e., the + "Zhegalkin monomial") or logical :py:class:`~.Xor` function (i.e., + the "Zhegalkin polynomial"). When True and False + are represented by 1 and 0, respectively, then + :py:class:`~.And` is multiplication and :py:class:`~.Xor` is addition. + + Formally a "Zhegalkin monomial" is the product (logical + And) of a finite set of distinct variables, including + the empty set whose product is denoted 1 (True). + A "Zhegalkin polynomial" is the sum (logical Xor) of a + set of Zhegalkin monomials, with the empty set denoted + by 0 (False). + + Parameters + ========== + + variables : list of variables + truthvalues : list of 1's and 0's (result column of truth table) + + Examples + ======== + >>> from sympy.logic.boolalg import ANFform + >>> from sympy.abc import x, y + >>> ANFform([x], [1, 0]) + x ^ True + >>> ANFform([x, y], [0, 1, 1, 1]) + x ^ y ^ (x & y) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Zhegalkin_polynomial + + """ + + n_vars = len(variables) + n_values = len(truthvalues) + + if n_values != 2 ** n_vars: + raise ValueError("The number of truth values must be equal to 2^%d, " + "got %d" % (n_vars, n_values)) + + variables = tuple(map(sympify, variables)) + + coeffs = anf_coeffs(truthvalues) + terms = [] + + for i, t in enumerate(product((0, 1), repeat=n_vars)): + if coeffs[i] == 1: + terms.append(t) + + return Xor(*[_convert_to_varsANF(x, variables) for x in terms], + remove_true=False) + + +def anf_coeffs(truthvalues): + """ + Convert a list of truth values of some boolean expression + to the list of coefficients of the polynomial mod 2 (exclusive + disjunction) representing the boolean expression in ANF + (i.e., the "Zhegalkin polynomial"). + + There are `2^n` possible Zhegalkin monomials in `n` variables, since + each monomial is fully specified by the presence or absence of + each variable. + + We can enumerate all the monomials. For example, boolean + function with four variables ``(a, b, c, d)`` can contain + up to `2^4 = 16` monomials. The 13-th monomial is the + product ``a & b & d``, because 13 in binary is 1, 1, 0, 1. + + A given monomial's presence or absence in a polynomial corresponds + to that monomial's coefficient being 1 or 0 respectively. + + Examples + ======== + >>> from sympy.logic.boolalg import anf_coeffs, bool_monomial, Xor + >>> from sympy.abc import a, b, c + >>> truthvalues = [0, 1, 1, 0, 0, 1, 0, 1] + >>> coeffs = anf_coeffs(truthvalues) + >>> coeffs + [0, 1, 1, 0, 0, 0, 1, 0] + >>> polynomial = Xor(*[ + ... bool_monomial(k, [a, b, c]) + ... for k, coeff in enumerate(coeffs) if coeff == 1 + ... ]) + >>> polynomial + b ^ c ^ (a & b) + + """ + + s = '{:b}'.format(len(truthvalues)) + n = len(s) - 1 + + if len(truthvalues) != 2**n: + raise ValueError("The number of truth values must be a power of two, " + "got %d" % len(truthvalues)) + + coeffs = [[v] for v in truthvalues] + + for i in range(n): + tmp = [] + for j in range(2 ** (n-i-1)): + tmp.append(coeffs[2*j] + + list(map(lambda x, y: x^y, coeffs[2*j], coeffs[2*j+1]))) + coeffs = tmp + + return coeffs[0] + + +def bool_minterm(k, variables): + """ + Return the k-th minterm. + + Minterms are numbered by a binary encoding of the complementation + pattern of the variables. This convention assigns the value 1 to + the direct form and 0 to the complemented form. + + Parameters + ========== + + k : int or list of 1's and 0's (complementation pattern) + variables : list of variables + + Examples + ======== + + >>> from sympy.logic.boolalg import bool_minterm + >>> from sympy.abc import x, y, z + >>> bool_minterm([1, 0, 1], [x, y, z]) + x & z & ~y + >>> bool_minterm(6, [x, y, z]) + x & y & ~z + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Canonical_normal_form#Indexing_minterms + + """ + if isinstance(k, int): + k = ibin(k, len(variables)) + variables = tuple(map(sympify, variables)) + return _convert_to_varsSOP(k, variables) + + +def bool_maxterm(k, variables): + """ + Return the k-th maxterm. + + Each maxterm is assigned an index based on the opposite + conventional binary encoding used for minterms. The maxterm + convention assigns the value 0 to the direct form and 1 to + the complemented form. + + Parameters + ========== + + k : int or list of 1's and 0's (complementation pattern) + variables : list of variables + + Examples + ======== + >>> from sympy.logic.boolalg import bool_maxterm + >>> from sympy.abc import x, y, z + >>> bool_maxterm([1, 0, 1], [x, y, z]) + y | ~x | ~z + >>> bool_maxterm(6, [x, y, z]) + z | ~x | ~y + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Canonical_normal_form#Indexing_maxterms + + """ + if isinstance(k, int): + k = ibin(k, len(variables)) + variables = tuple(map(sympify, variables)) + return _convert_to_varsPOS(k, variables) + + +def bool_monomial(k, variables): + """ + Return the k-th monomial. + + Monomials are numbered by a binary encoding of the presence and + absences of the variables. This convention assigns the value + 1 to the presence of variable and 0 to the absence of variable. + + Each boolean function can be uniquely represented by a + Zhegalkin Polynomial (Algebraic Normal Form). The Zhegalkin + Polynomial of the boolean function with `n` variables can contain + up to `2^n` monomials. We can enumerate all the monomials. + Each monomial is fully specified by the presence or absence + of each variable. + + For example, boolean function with four variables ``(a, b, c, d)`` + can contain up to `2^4 = 16` monomials. The 13-th monomial is the + product ``a & b & d``, because 13 in binary is 1, 1, 0, 1. + + Parameters + ========== + + k : int or list of 1's and 0's + variables : list of variables + + Examples + ======== + >>> from sympy.logic.boolalg import bool_monomial + >>> from sympy.abc import x, y, z + >>> bool_monomial([1, 0, 1], [x, y, z]) + x & z + >>> bool_monomial(6, [x, y, z]) + x & y + + """ + if isinstance(k, int): + k = ibin(k, len(variables)) + variables = tuple(map(sympify, variables)) + return _convert_to_varsANF(k, variables) + + +def _find_predicates(expr): + """Helper to find logical predicates in BooleanFunctions. + + A logical predicate is defined here as anything within a BooleanFunction + that is not a BooleanFunction itself. + + """ + if not isinstance(expr, BooleanFunction): + return {expr} + return set().union(*(map(_find_predicates, expr.args))) + + +def simplify_logic(expr, form=None, deep=True, force=False, dontcare=None): + """ + This function simplifies a boolean function to its simplified version + in SOP or POS form. The return type is an :py:class:`~.Or` or + :py:class:`~.And` object in SymPy. + + Parameters + ========== + + expr : Boolean + + form : string (``'cnf'`` or ``'dnf'``) or ``None`` (default). + If ``'cnf'`` or ``'dnf'``, the simplest expression in the corresponding + normal form is returned; if ``None``, the answer is returned + according to the form with fewest args (in CNF by default). + + deep : bool (default ``True``) + Indicates whether to recursively simplify any + non-boolean functions contained within the input. + + force : bool (default ``False``) + As the simplifications require exponential time in the number + of variables, there is by default a limit on expressions with + 8 variables. When the expression has more than 8 variables + only symbolical simplification (controlled by ``deep``) is + made. By setting ``force`` to ``True``, this limit is removed. Be + aware that this can lead to very long simplification times. + + dontcare : Boolean + Optimize expression under the assumption that inputs where this + expression is true are don't care. This is useful in e.g. Piecewise + conditions, where later conditions do not need to consider inputs that + are converted by previous conditions. For example, if a previous + condition is ``And(A, B)``, the simplification of expr can be made + with don't cares for ``And(A, B)``. + + Examples + ======== + + >>> from sympy.logic import simplify_logic + >>> from sympy.abc import x, y, z + >>> b = (~x & ~y & ~z) | ( ~x & ~y & z) + >>> simplify_logic(b) + ~x & ~y + >>> simplify_logic(x | y, dontcare=y) + x + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Don%27t-care_term + + """ + + if form not in (None, 'cnf', 'dnf'): + raise ValueError("form can be cnf or dnf only") + expr = sympify(expr) + # check for quick exit if form is given: right form and all args are + # literal and do not involve Not + if form: + form_ok = False + if form == 'cnf': + form_ok = is_cnf(expr) + elif form == 'dnf': + form_ok = is_dnf(expr) + + if form_ok and all(is_literal(a) + for a in expr.args): + return expr + from sympy.core.relational import Relational + if deep: + variables = expr.atoms(Relational) + from sympy.simplify.simplify import simplify + s = tuple(map(simplify, variables)) + expr = expr.xreplace(dict(zip(variables, s))) + if not isinstance(expr, BooleanFunction): + return expr + # Replace Relationals with Dummys to possibly + # reduce the number of variables + repl = {} + undo = {} + from sympy.core.symbol import Dummy + variables = expr.atoms(Relational) + if dontcare is not None: + dontcare = sympify(dontcare) + variables.update(dontcare.atoms(Relational)) + while variables: + var = variables.pop() + if var.is_Relational: + d = Dummy() + undo[d] = var + repl[var] = d + nvar = var.negated + if nvar in variables: + repl[nvar] = Not(d) + variables.remove(nvar) + + expr = expr.xreplace(repl) + + if dontcare is not None: + dontcare = dontcare.xreplace(repl) + + # Get new variables after replacing + variables = _find_predicates(expr) + if not force and len(variables) > 8: + return expr.xreplace(undo) + if dontcare is not None: + # Add variables from dontcare + dcvariables = _find_predicates(dontcare) + variables.update(dcvariables) + # if too many restore to variables only + if not force and len(variables) > 8: + variables = _find_predicates(expr) + dontcare = None + # group into constants and variable values + c, v = sift(ordered(variables), lambda x: x in (True, False), binary=True) + variables = c + v + # standardize constants to be 1 or 0 in keeping with truthtable + c = [1 if i == True else 0 for i in c] + truthtable = _get_truthtable(v, expr, c) + if dontcare is not None: + dctruthtable = _get_truthtable(v, dontcare, c) + truthtable = [t for t in truthtable if t not in dctruthtable] + else: + dctruthtable = [] + big = len(truthtable) >= (2 ** (len(variables) - 1)) + if form == 'dnf' or form is None and big: + return _sop_form(variables, truthtable, dctruthtable).xreplace(undo) + return POSform(variables, truthtable, dctruthtable).xreplace(undo) + + +def _get_truthtable(variables, expr, const): + """ Return a list of all combinations leading to a True result for ``expr``. + """ + _variables = variables.copy() + def _get_tt(inputs): + if _variables: + v = _variables.pop() + tab = [[i[0].xreplace({v: false}), [0] + i[1]] for i in inputs if i[0] is not false] + tab.extend([[i[0].xreplace({v: true}), [1] + i[1]] for i in inputs if i[0] is not false]) + return _get_tt(tab) + return inputs + res = [const + k[1] for k in _get_tt([[expr, []]]) if k[0]] + if res == [[]]: + return [] + else: + return res + + +def _finger(eq): + """ + Assign a 5-item fingerprint to each symbol in the equation: + [ + # of times it appeared as a Symbol; + # of times it appeared as a Not(symbol); + # of times it appeared as a Symbol in an And or Or; + # of times it appeared as a Not(Symbol) in an And or Or; + a sorted tuple of tuples, (i, j, k), where i is the number of arguments + in an And or Or with which it appeared as a Symbol, and j is + the number of arguments that were Not(Symbol); k is the number + of times that (i, j) was seen. + ] + + Examples + ======== + + >>> from sympy.logic.boolalg import _finger as finger + >>> from sympy import And, Or, Not, Xor, to_cnf, symbols + >>> from sympy.abc import a, b, x, y + >>> eq = Or(And(Not(y), a), And(Not(y), b), And(x, y)) + >>> dict(finger(eq)) + {(0, 0, 1, 0, ((2, 0, 1),)): [x], + (0, 0, 1, 0, ((2, 1, 1),)): [a, b], + (0, 0, 1, 2, ((2, 0, 1),)): [y]} + >>> dict(finger(x & ~y)) + {(0, 1, 0, 0, ()): [y], (1, 0, 0, 0, ()): [x]} + + In the following, the (5, 2, 6) means that there were 6 Or + functions in which a symbol appeared as itself amongst 5 arguments in + which there were also 2 negated symbols, e.g. ``(a0 | a1 | a2 | ~a3 | ~a4)`` + is counted once for a0, a1 and a2. + + >>> dict(finger(to_cnf(Xor(*symbols('a:5'))))) + {(0, 0, 8, 8, ((5, 0, 1), (5, 2, 6), (5, 4, 1))): [a0, a1, a2, a3, a4]} + + The equation must not have more than one level of nesting: + + >>> dict(finger(And(Or(x, y), y))) + {(0, 0, 1, 0, ((2, 0, 1),)): [x], (1, 0, 1, 0, ((2, 0, 1),)): [y]} + >>> dict(finger(And(Or(x, And(a, x)), y))) + Traceback (most recent call last): + ... + NotImplementedError: unexpected level of nesting + + So y and x have unique fingerprints, but a and b do not. + """ + f = eq.free_symbols + d = dict(list(zip(f, [[0]*4 + [defaultdict(int)] for fi in f]))) + for a in eq.args: + if a.is_Symbol: + d[a][0] += 1 + elif a.is_Not: + d[a.args[0]][1] += 1 + else: + o = len(a.args), sum(isinstance(ai, Not) for ai in a.args) + for ai in a.args: + if ai.is_Symbol: + d[ai][2] += 1 + d[ai][-1][o] += 1 + elif ai.is_Not: + d[ai.args[0]][3] += 1 + else: + raise NotImplementedError('unexpected level of nesting') + inv = defaultdict(list) + for k, v in ordered(iter(d.items())): + v[-1] = tuple(sorted([i + (j,) for i, j in v[-1].items()])) + inv[tuple(v)].append(k) + return inv + + +def bool_map(bool1, bool2): + """ + Return the simplified version of *bool1*, and the mapping of variables + that makes the two expressions *bool1* and *bool2* represent the same + logical behaviour for some correspondence between the variables + of each. + If more than one mappings of this sort exist, one of them + is returned. + + For example, ``And(x, y)`` is logically equivalent to ``And(a, b)`` for + the mapping ``{x: a, y: b}`` or ``{x: b, y: a}``. + If no such mapping exists, return ``False``. + + Examples + ======== + + >>> from sympy import SOPform, bool_map, Or, And, Not, Xor + >>> from sympy.abc import w, x, y, z, a, b, c, d + >>> function1 = SOPform([x, z, y],[[1, 0, 1], [0, 0, 1]]) + >>> function2 = SOPform([a, b, c],[[1, 0, 1], [1, 0, 0]]) + >>> bool_map(function1, function2) + (y & ~z, {y: a, z: b}) + + The results are not necessarily unique, but they are canonical. Here, + ``(w, z)`` could be ``(a, d)`` or ``(d, a)``: + + >>> eq = Or(And(Not(y), w), And(Not(y), z), And(x, y)) + >>> eq2 = Or(And(Not(c), a), And(Not(c), d), And(b, c)) + >>> bool_map(eq, eq2) + ((x & y) | (w & ~y) | (z & ~y), {w: a, x: b, y: c, z: d}) + >>> eq = And(Xor(a, b), c, And(c,d)) + >>> bool_map(eq, eq.subs(c, x)) + (c & d & (a | b) & (~a | ~b), {a: a, b: b, c: d, d: x}) + + """ + + def match(function1, function2): + """Return the mapping that equates variables between two + simplified boolean expressions if possible. + + By "simplified" we mean that a function has been denested + and is either an And (or an Or) whose arguments are either + symbols (x), negated symbols (Not(x)), or Or (or an And) whose + arguments are only symbols or negated symbols. For example, + ``And(x, Not(y), Or(w, Not(z)))``. + + Basic.match is not robust enough (see issue 4835) so this is + a workaround that is valid for simplified boolean expressions + """ + + # do some quick checks + if function1.__class__ != function2.__class__: + return None # maybe simplification makes them the same? + if len(function1.args) != len(function2.args): + return None # maybe simplification makes them the same? + if function1.is_Symbol: + return {function1: function2} + + # get the fingerprint dictionaries + f1 = _finger(function1) + f2 = _finger(function2) + + # more quick checks + if len(f1) != len(f2): + return False + + # assemble the match dictionary if possible + matchdict = {} + for k in f1.keys(): + if k not in f2: + return False + if len(f1[k]) != len(f2[k]): + return False + for i, x in enumerate(f1[k]): + matchdict[x] = f2[k][i] + return matchdict + + a = simplify_logic(bool1) + b = simplify_logic(bool2) + m = match(a, b) + if m: + return a, m + return m + + +def _apply_patternbased_simplification(rv, patterns, measure, + dominatingvalue, + replacementvalue=None, + threeterm_patterns=None): + """ + Replace patterns of Relational + + Parameters + ========== + + rv : Expr + Boolean expression + + patterns : tuple + Tuple of tuples, with (pattern to simplify, simplified pattern) with + two terms. + + measure : function + Simplification measure. + + dominatingvalue : Boolean or ``None`` + The dominating value for the function of consideration. + For example, for :py:class:`~.And` ``S.false`` is dominating. + As soon as one expression is ``S.false`` in :py:class:`~.And`, + the whole expression is ``S.false``. + + replacementvalue : Boolean or ``None``, optional + The resulting value for the whole expression if one argument + evaluates to ``dominatingvalue``. + For example, for :py:class:`~.Nand` ``S.false`` is dominating, but + in this case the resulting value is ``S.true``. Default is ``None``. + If ``replacementvalue`` is ``None`` and ``dominatingvalue`` is not + ``None``, ``replacementvalue = dominatingvalue``. + + threeterm_patterns : tuple, optional + Tuple of tuples, with (pattern to simplify, simplified pattern) with + three terms. + + """ + from sympy.core.relational import Relational, _canonical + + if replacementvalue is None and dominatingvalue is not None: + replacementvalue = dominatingvalue + # Use replacement patterns for Relationals + Rel, nonRel = sift(rv.args, lambda i: isinstance(i, Relational), + binary=True) + if len(Rel) <= 1: + return rv + Rel, nonRealRel = sift(Rel, lambda i: not any(s.is_real is False + for s in i.free_symbols), + binary=True) + Rel = [i.canonical for i in Rel] + + if threeterm_patterns and len(Rel) >= 3: + Rel = _apply_patternbased_threeterm_simplification(Rel, + threeterm_patterns, rv.func, dominatingvalue, + replacementvalue, measure) + + Rel = _apply_patternbased_twoterm_simplification(Rel, patterns, + rv.func, dominatingvalue, replacementvalue, measure) + + rv = rv.func(*([_canonical(i) for i in ordered(Rel)] + + nonRel + nonRealRel)) + return rv + + +def _apply_patternbased_twoterm_simplification(Rel, patterns, func, + dominatingvalue, + replacementvalue, + measure): + """ Apply pattern-based two-term simplification.""" + from sympy.functions.elementary.miscellaneous import Min, Max + from sympy.core.relational import Ge, Gt, _Inequality + changed = True + while changed and len(Rel) >= 2: + changed = False + # Use only < or <= + Rel = [r.reversed if isinstance(r, (Ge, Gt)) else r for r in Rel] + # Sort based on ordered + Rel = list(ordered(Rel)) + # Eq and Ne must be tested reversed as well + rtmp = [(r, ) if isinstance(r, _Inequality) else (r, r.reversed) for r in Rel] + # Create a list of possible replacements + results = [] + # Try all combinations of possibly reversed relational + for ((i, pi), (j, pj)) in combinations(enumerate(rtmp), 2): + for pattern, simp in patterns: + res = [] + for p1, p2 in product(pi, pj): + # use SymPy matching + oldexpr = Tuple(p1, p2) + tmpres = oldexpr.match(pattern) + if tmpres: + res.append((tmpres, oldexpr)) + + if res: + for tmpres, oldexpr in res: + # we have a matching, compute replacement + np = simp.xreplace(tmpres) + if np == dominatingvalue: + # if dominatingvalue, the whole expression + # will be replacementvalue + return [replacementvalue] + # add replacement + if not isinstance(np, ITE) and not np.has(Min, Max): + # We only want to use ITE and Min/Max replacements if + # they simplify to a relational + costsaving = measure(func(*oldexpr.args)) - measure(np) + if costsaving > 0: + results.append((costsaving, ([i, j], np))) + if results: + # Sort results based on complexity + results = sorted(results, + key=lambda pair: pair[0], reverse=True) + # Replace the one providing most simplification + replacement = results[0][1] + idx, newrel = replacement + idx.sort() + # Remove the old relationals + for index in reversed(idx): + del Rel[index] + if dominatingvalue is None or newrel != Not(dominatingvalue): + # Insert the new one (no need to insert a value that will + # not affect the result) + if newrel.func == func: + for a in newrel.args: + Rel.append(a) + else: + Rel.append(newrel) + # We did change something so try again + changed = True + return Rel + + +def _apply_patternbased_threeterm_simplification(Rel, patterns, func, + dominatingvalue, + replacementvalue, + measure): + """ Apply pattern-based three-term simplification.""" + from sympy.functions.elementary.miscellaneous import Min, Max + from sympy.core.relational import Le, Lt, _Inequality + changed = True + while changed and len(Rel) >= 3: + changed = False + # Use only > or >= + Rel = [r.reversed if isinstance(r, (Le, Lt)) else r for r in Rel] + # Sort based on ordered + Rel = list(ordered(Rel)) + # Create a list of possible replacements + results = [] + # Eq and Ne must be tested reversed as well + rtmp = [(r, ) if isinstance(r, _Inequality) else (r, r.reversed) for r in Rel] + # Try all combinations of possibly reversed relational + for ((i, pi), (j, pj), (k, pk)) in permutations(enumerate(rtmp), 3): + for pattern, simp in patterns: + res = [] + for p1, p2, p3 in product(pi, pj, pk): + # use SymPy matching + oldexpr = Tuple(p1, p2, p3) + tmpres = oldexpr.match(pattern) + if tmpres: + res.append((tmpres, oldexpr)) + + if res: + for tmpres, oldexpr in res: + # we have a matching, compute replacement + np = simp.xreplace(tmpres) + if np == dominatingvalue: + # if dominatingvalue, the whole expression + # will be replacementvalue + return [replacementvalue] + # add replacement + if not isinstance(np, ITE) and not np.has(Min, Max): + # We only want to use ITE and Min/Max replacements if + # they simplify to a relational + costsaving = measure(func(*oldexpr.args)) - measure(np) + if costsaving > 0: + results.append((costsaving, ([i, j, k], np))) + if results: + # Sort results based on complexity + results = sorted(results, + key=lambda pair: pair[0], reverse=True) + # Replace the one providing most simplification + replacement = results[0][1] + idx, newrel = replacement + idx.sort() + # Remove the old relationals + for index in reversed(idx): + del Rel[index] + if dominatingvalue is None or newrel != Not(dominatingvalue): + # Insert the new one (no need to insert a value that will + # not affect the result) + if newrel.func == func: + for a in newrel.args: + Rel.append(a) + else: + Rel.append(newrel) + # We did change something so try again + changed = True + return Rel + + +@cacheit +def _simplify_patterns_and(): + """ Two-term patterns for And.""" + + from sympy.core import Wild + from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt + from sympy.functions.elementary.complexes import Abs + from sympy.functions.elementary.miscellaneous import Min, Max + a = Wild('a') + b = Wild('b') + c = Wild('c') + # Relationals patterns should be in alphabetical order + # (pattern1, pattern2, simplified) + # Do not use Ge, Gt + _matchers_and = ((Tuple(Eq(a, b), Lt(a, b)), false), + #(Tuple(Eq(a, b), Lt(b, a)), S.false), + #(Tuple(Le(b, a), Lt(a, b)), S.false), + #(Tuple(Lt(b, a), Le(a, b)), S.false), + (Tuple(Lt(b, a), Lt(a, b)), false), + (Tuple(Eq(a, b), Le(b, a)), Eq(a, b)), + #(Tuple(Eq(a, b), Le(a, b)), Eq(a, b)), + #(Tuple(Le(b, a), Lt(b, a)), Gt(a, b)), + (Tuple(Le(b, a), Le(a, b)), Eq(a, b)), + #(Tuple(Le(b, a), Ne(a, b)), Gt(a, b)), + #(Tuple(Lt(b, a), Ne(a, b)), Gt(a, b)), + (Tuple(Le(a, b), Lt(a, b)), Lt(a, b)), + (Tuple(Le(a, b), Ne(a, b)), Lt(a, b)), + (Tuple(Lt(a, b), Ne(a, b)), Lt(a, b)), + # Sign + (Tuple(Eq(a, b), Eq(a, -b)), And(Eq(a, S.Zero), Eq(b, S.Zero))), + # Min/Max/ITE + (Tuple(Le(b, a), Le(c, a)), Ge(a, Max(b, c))), + (Tuple(Le(b, a), Lt(c, a)), ITE(b > c, Ge(a, b), Gt(a, c))), + (Tuple(Lt(b, a), Lt(c, a)), Gt(a, Max(b, c))), + (Tuple(Le(a, b), Le(a, c)), Le(a, Min(b, c))), + (Tuple(Le(a, b), Lt(a, c)), ITE(b < c, Le(a, b), Lt(a, c))), + (Tuple(Lt(a, b), Lt(a, c)), Lt(a, Min(b, c))), + (Tuple(Le(a, b), Le(c, a)), ITE(Eq(b, c), Eq(a, b), ITE(b < c, false, And(Le(a, b), Ge(a, c))))), + (Tuple(Le(c, a), Le(a, b)), ITE(Eq(b, c), Eq(a, b), ITE(b < c, false, And(Le(a, b), Ge(a, c))))), + (Tuple(Lt(a, b), Lt(c, a)), ITE(b < c, false, And(Lt(a, b), Gt(a, c)))), + (Tuple(Lt(c, a), Lt(a, b)), ITE(b < c, false, And(Lt(a, b), Gt(a, c)))), + (Tuple(Le(a, b), Lt(c, a)), ITE(b <= c, false, And(Le(a, b), Gt(a, c)))), + (Tuple(Le(c, a), Lt(a, b)), ITE(b <= c, false, And(Lt(a, b), Ge(a, c)))), + (Tuple(Eq(a, b), Eq(a, c)), ITE(Eq(b, c), Eq(a, b), false)), + (Tuple(Lt(a, b), Lt(-b, a)), ITE(b > 0, Lt(Abs(a), b), false)), + (Tuple(Le(a, b), Le(-b, a)), ITE(b >= 0, Le(Abs(a), b), false)), + ) + return _matchers_and + + +@cacheit +def _simplify_patterns_and3(): + """ Three-term patterns for And.""" + + from sympy.core import Wild + from sympy.core.relational import Eq, Ge, Gt + + a = Wild('a') + b = Wild('b') + c = Wild('c') + # Relationals patterns should be in alphabetical order + # (pattern1, pattern2, pattern3, simplified) + # Do not use Le, Lt + _matchers_and = ((Tuple(Ge(a, b), Ge(b, c), Gt(c, a)), false), + (Tuple(Ge(a, b), Gt(b, c), Gt(c, a)), false), + (Tuple(Gt(a, b), Gt(b, c), Gt(c, a)), false), + # (Tuple(Ge(c, a), Gt(a, b), Gt(b, c)), S.false), + # Lower bound relations + # Commented out combinations that does not simplify + (Tuple(Ge(a, b), Ge(a, c), Ge(b, c)), And(Ge(a, b), Ge(b, c))), + (Tuple(Ge(a, b), Ge(a, c), Gt(b, c)), And(Ge(a, b), Gt(b, c))), + # (Tuple(Ge(a, b), Gt(a, c), Ge(b, c)), And(Ge(a, b), Ge(b, c))), + (Tuple(Ge(a, b), Gt(a, c), Gt(b, c)), And(Ge(a, b), Gt(b, c))), + # (Tuple(Gt(a, b), Ge(a, c), Ge(b, c)), And(Gt(a, b), Ge(b, c))), + (Tuple(Ge(a, c), Gt(a, b), Gt(b, c)), And(Gt(a, b), Gt(b, c))), + (Tuple(Ge(b, c), Gt(a, b), Gt(a, c)), And(Gt(a, b), Ge(b, c))), + (Tuple(Gt(a, b), Gt(a, c), Gt(b, c)), And(Gt(a, b), Gt(b, c))), + # Upper bound relations + # Commented out combinations that does not simplify + (Tuple(Ge(b, a), Ge(c, a), Ge(b, c)), And(Ge(c, a), Ge(b, c))), + (Tuple(Ge(b, a), Ge(c, a), Gt(b, c)), And(Ge(c, a), Gt(b, c))), + # (Tuple(Ge(b, a), Gt(c, a), Ge(b, c)), And(Gt(c, a), Ge(b, c))), + (Tuple(Ge(b, a), Gt(c, a), Gt(b, c)), And(Gt(c, a), Gt(b, c))), + # (Tuple(Gt(b, a), Ge(c, a), Ge(b, c)), And(Ge(c, a), Ge(b, c))), + (Tuple(Ge(c, a), Gt(b, a), Gt(b, c)), And(Ge(c, a), Gt(b, c))), + (Tuple(Ge(b, c), Gt(b, a), Gt(c, a)), And(Gt(c, a), Ge(b, c))), + (Tuple(Gt(b, a), Gt(c, a), Gt(b, c)), And(Gt(c, a), Gt(b, c))), + # Circular relation + (Tuple(Ge(a, b), Ge(b, c), Ge(c, a)), And(Eq(a, b), Eq(b, c))), + ) + return _matchers_and + + +@cacheit +def _simplify_patterns_or(): + """ Two-term patterns for Or.""" + + from sympy.core import Wild + from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt + from sympy.functions.elementary.complexes import Abs + from sympy.functions.elementary.miscellaneous import Min, Max + a = Wild('a') + b = Wild('b') + c = Wild('c') + # Relationals patterns should be in alphabetical order + # (pattern1, pattern2, simplified) + # Do not use Ge, Gt + _matchers_or = ((Tuple(Le(b, a), Le(a, b)), true), + #(Tuple(Le(b, a), Lt(a, b)), true), + (Tuple(Le(b, a), Ne(a, b)), true), + #(Tuple(Le(a, b), Lt(b, a)), true), + #(Tuple(Le(a, b), Ne(a, b)), true), + #(Tuple(Eq(a, b), Le(b, a)), Ge(a, b)), + #(Tuple(Eq(a, b), Lt(b, a)), Ge(a, b)), + (Tuple(Eq(a, b), Le(a, b)), Le(a, b)), + (Tuple(Eq(a, b), Lt(a, b)), Le(a, b)), + #(Tuple(Le(b, a), Lt(b, a)), Ge(a, b)), + (Tuple(Lt(b, a), Lt(a, b)), Ne(a, b)), + (Tuple(Lt(b, a), Ne(a, b)), Ne(a, b)), + (Tuple(Le(a, b), Lt(a, b)), Le(a, b)), + #(Tuple(Lt(a, b), Ne(a, b)), Ne(a, b)), + (Tuple(Eq(a, b), Ne(a, c)), ITE(Eq(b, c), true, Ne(a, c))), + (Tuple(Ne(a, b), Ne(a, c)), ITE(Eq(b, c), Ne(a, b), true)), + # Min/Max/ITE + (Tuple(Le(b, a), Le(c, a)), Ge(a, Min(b, c))), + #(Tuple(Ge(b, a), Ge(c, a)), Ge(Min(b, c), a)), + (Tuple(Le(b, a), Lt(c, a)), ITE(b > c, Lt(c, a), Le(b, a))), + (Tuple(Lt(b, a), Lt(c, a)), Gt(a, Min(b, c))), + #(Tuple(Gt(b, a), Gt(c, a)), Gt(Min(b, c), a)), + (Tuple(Le(a, b), Le(a, c)), Le(a, Max(b, c))), + #(Tuple(Le(b, a), Le(c, a)), Le(Max(b, c), a)), + (Tuple(Le(a, b), Lt(a, c)), ITE(b >= c, Le(a, b), Lt(a, c))), + (Tuple(Lt(a, b), Lt(a, c)), Lt(a, Max(b, c))), + #(Tuple(Lt(b, a), Lt(c, a)), Lt(Max(b, c), a)), + (Tuple(Le(a, b), Le(c, a)), ITE(b >= c, true, Or(Le(a, b), Ge(a, c)))), + (Tuple(Le(c, a), Le(a, b)), ITE(b >= c, true, Or(Le(a, b), Ge(a, c)))), + (Tuple(Lt(a, b), Lt(c, a)), ITE(b > c, true, Or(Lt(a, b), Gt(a, c)))), + (Tuple(Lt(c, a), Lt(a, b)), ITE(b > c, true, Or(Lt(a, b), Gt(a, c)))), + (Tuple(Le(a, b), Lt(c, a)), ITE(b >= c, true, Or(Le(a, b), Gt(a, c)))), + (Tuple(Le(c, a), Lt(a, b)), ITE(b >= c, true, Or(Lt(a, b), Ge(a, c)))), + (Tuple(Lt(b, a), Lt(a, -b)), ITE(b >= 0, Gt(Abs(a), b), true)), + (Tuple(Le(b, a), Le(a, -b)), ITE(b > 0, Ge(Abs(a), b), true)), + ) + return _matchers_or + + +@cacheit +def _simplify_patterns_xor(): + """ Two-term patterns for Xor.""" + + from sympy.functions.elementary.miscellaneous import Min, Max + from sympy.core import Wild + from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt + a = Wild('a') + b = Wild('b') + c = Wild('c') + # Relationals patterns should be in alphabetical order + # (pattern1, pattern2, simplified) + # Do not use Ge, Gt + _matchers_xor = (#(Tuple(Le(b, a), Lt(a, b)), true), + #(Tuple(Lt(b, a), Le(a, b)), true), + #(Tuple(Eq(a, b), Le(b, a)), Gt(a, b)), + #(Tuple(Eq(a, b), Lt(b, a)), Ge(a, b)), + (Tuple(Eq(a, b), Le(a, b)), Lt(a, b)), + (Tuple(Eq(a, b), Lt(a, b)), Le(a, b)), + (Tuple(Le(a, b), Lt(a, b)), Eq(a, b)), + (Tuple(Le(a, b), Le(b, a)), Ne(a, b)), + (Tuple(Le(b, a), Ne(a, b)), Le(a, b)), + # (Tuple(Lt(b, a), Lt(a, b)), Ne(a, b)), + (Tuple(Lt(b, a), Ne(a, b)), Lt(a, b)), + # (Tuple(Le(a, b), Lt(a, b)), Eq(a, b)), + # (Tuple(Le(a, b), Ne(a, b)), Ge(a, b)), + # (Tuple(Lt(a, b), Ne(a, b)), Gt(a, b)), + # Min/Max/ITE + (Tuple(Le(b, a), Le(c, a)), + And(Ge(a, Min(b, c)), Lt(a, Max(b, c)))), + (Tuple(Le(b, a), Lt(c, a)), + ITE(b > c, And(Gt(a, c), Lt(a, b)), + And(Ge(a, b), Le(a, c)))), + (Tuple(Lt(b, a), Lt(c, a)), + And(Gt(a, Min(b, c)), Le(a, Max(b, c)))), + (Tuple(Le(a, b), Le(a, c)), + And(Le(a, Max(b, c)), Gt(a, Min(b, c)))), + (Tuple(Le(a, b), Lt(a, c)), + ITE(b < c, And(Lt(a, c), Gt(a, b)), + And(Le(a, b), Ge(a, c)))), + (Tuple(Lt(a, b), Lt(a, c)), + And(Lt(a, Max(b, c)), Ge(a, Min(b, c)))), + ) + return _matchers_xor + + +def simplify_univariate(expr): + """return a simplified version of univariate boolean expression, else ``expr``""" + from sympy.functions.elementary.piecewise import Piecewise + from sympy.core.relational import Eq, Ne + if not isinstance(expr, BooleanFunction): + return expr + if expr.atoms(Eq, Ne): + return expr + c = expr + free = c.free_symbols + if len(free) != 1: + return c + x = free.pop() + ok, i = Piecewise((0, c), evaluate=False + )._intervals(x, err_on_Eq=True) + if not ok: + return c + if not i: + return false + args = [] + for a, b, _, _ in i: + if a is S.NegativeInfinity: + if b is S.Infinity: + c = true + else: + if c.subs(x, b) == True: + c = (x <= b) + else: + c = (x < b) + else: + incl_a = (c.subs(x, a) == True) + incl_b = (c.subs(x, b) == True) + if incl_a and incl_b: + if b.is_infinite: + c = (x >= a) + else: + c = And(a <= x, x <= b) + elif incl_a: + c = And(a <= x, x < b) + elif incl_b: + if b.is_infinite: + c = (x > a) + else: + c = And(a < x, x <= b) + else: + c = And(a < x, x < b) + args.append(c) + return Or(*args) + + +# Classes corresponding to logic gates +# Used in gateinputcount method +BooleanGates = (And, Or, Xor, Nand, Nor, Not, Xnor, ITE) + +def gateinputcount(expr): + """ + Return the total number of inputs for the logic gates realizing the + Boolean expression. + + Returns + ======= + + int + Number of gate inputs + + Note + ==== + + Not all Boolean functions count as gate here, only those that are + considered to be standard gates. These are: :py:class:`~.And`, + :py:class:`~.Or`, :py:class:`~.Xor`, :py:class:`~.Not`, and + :py:class:`~.ITE` (multiplexer). :py:class:`~.Nand`, :py:class:`~.Nor`, + and :py:class:`~.Xnor` will be evaluated to ``Not(And())`` etc. + + Examples + ======== + + >>> from sympy.logic import And, Or, Nand, Not, gateinputcount + >>> from sympy.abc import x, y, z + >>> expr = And(x, y) + >>> gateinputcount(expr) + 2 + >>> gateinputcount(Or(expr, z)) + 4 + + Note that ``Nand`` is automatically evaluated to ``Not(And())`` so + + >>> gateinputcount(Nand(x, y, z)) + 4 + >>> gateinputcount(Not(And(x, y, z))) + 4 + + Although this can be avoided by using ``evaluate=False`` + + >>> gateinputcount(Nand(x, y, z, evaluate=False)) + 3 + + Also note that a comparison will count as a Boolean variable: + + >>> gateinputcount(And(x > z, y >= 2)) + 2 + + As will a symbol: + >>> gateinputcount(x) + 0 + + """ + if not isinstance(expr, Boolean): + raise TypeError("Expression must be Boolean") + if isinstance(expr, BooleanGates): + return len(expr.args) + sum(gateinputcount(x) for x in expr.args) + return 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/inference.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c3798231c09ae351ea7e7026d622b834fea3e3fa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/inference.py @@ -0,0 +1,340 @@ +"""Inference in propositional logic""" + +from sympy.logic.boolalg import And, Not, conjuncts, to_cnf, BooleanFunction +from sympy.core.sorting import ordered +from sympy.core.sympify import sympify +from sympy.external.importtools import import_module + + +def literal_symbol(literal): + """ + The symbol in this literal (without the negation). + + Examples + ======== + + >>> from sympy.abc import A + >>> from sympy.logic.inference import literal_symbol + >>> literal_symbol(A) + A + >>> literal_symbol(~A) + A + + """ + + if literal is True or literal is False: + return literal + elif literal.is_Symbol: + return literal + elif literal.is_Not: + return literal_symbol(literal.args[0]) + else: + raise ValueError("Argument must be a boolean literal.") + + +def satisfiable(expr, algorithm=None, all_models=False, minimal=False, use_lra_theory=False): + """ + Check satisfiability of a propositional sentence. + Returns a model when it succeeds. + Returns {true: true} for trivially true expressions. + + On setting all_models to True, if given expr is satisfiable then + returns a generator of models. However, if expr is unsatisfiable + then returns a generator containing the single element False. + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.inference import satisfiable + >>> satisfiable(A & ~B) + {A: True, B: False} + >>> satisfiable(A & ~A) + False + >>> satisfiable(True) + {True: True} + >>> next(satisfiable(A & ~A, all_models=True)) + False + >>> models = satisfiable((A >> B) & B, all_models=True) + >>> next(models) + {A: False, B: True} + >>> next(models) + {A: True, B: True} + >>> def use_models(models): + ... for model in models: + ... if model: + ... # Do something with the model. + ... print(model) + ... else: + ... # Given expr is unsatisfiable. + ... print("UNSAT") + >>> use_models(satisfiable(A >> ~A, all_models=True)) + {A: False} + >>> use_models(satisfiable(A ^ A, all_models=True)) + UNSAT + + """ + if use_lra_theory: + if algorithm is not None and algorithm != "dpll2": + raise ValueError(f"Currently only dpll2 can handle using lra theory. {algorithm} is not handled.") + algorithm = "dpll2" + + if algorithm is None or algorithm == "pycosat": + pycosat = import_module('pycosat') + if pycosat is not None: + algorithm = "pycosat" + else: + if algorithm == "pycosat": + raise ImportError("pycosat module is not present") + # Silently fall back to dpll2 if pycosat + # is not installed + algorithm = "dpll2" + + if algorithm=="minisat22": + pysat = import_module('pysat') + if pysat is None: + algorithm = "dpll2" + + if algorithm=="z3": + z3 = import_module('z3') + if z3 is None: + algorithm = "dpll2" + + if algorithm == "dpll": + from sympy.logic.algorithms.dpll import dpll_satisfiable + return dpll_satisfiable(expr) + elif algorithm == "dpll2": + from sympy.logic.algorithms.dpll2 import dpll_satisfiable + return dpll_satisfiable(expr, all_models, use_lra_theory=use_lra_theory) + elif algorithm == "pycosat": + from sympy.logic.algorithms.pycosat_wrapper import pycosat_satisfiable + return pycosat_satisfiable(expr, all_models) + elif algorithm == "minisat22": + from sympy.logic.algorithms.minisat22_wrapper import minisat22_satisfiable + return minisat22_satisfiable(expr, all_models, minimal) + elif algorithm == "z3": + from sympy.logic.algorithms.z3_wrapper import z3_satisfiable + return z3_satisfiable(expr, all_models) + + raise NotImplementedError + + +def valid(expr): + """ + Check validity of a propositional sentence. + A valid propositional sentence is True under every assignment. + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.inference import valid + >>> valid(A | ~A) + True + >>> valid(A | B) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Validity + + """ + return not satisfiable(Not(expr)) + + +def pl_true(expr, model=None, deep=False): + """ + Returns whether the given assignment is a model or not. + + If the assignment does not specify the value for every proposition, + this may return None to indicate 'not obvious'. + + Parameters + ========== + + model : dict, optional, default: {} + Mapping of symbols to boolean values to indicate assignment. + deep: boolean, optional, default: False + Gives the value of the expression under partial assignments + correctly. May still return None to indicate 'not obvious'. + + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.inference import pl_true + >>> pl_true( A & B, {A: True, B: True}) + True + >>> pl_true(A & B, {A: False}) + False + >>> pl_true(A & B, {A: True}) + >>> pl_true(A & B, {A: True}, deep=True) + >>> pl_true(A >> (B >> A)) + >>> pl_true(A >> (B >> A), deep=True) + True + >>> pl_true(A & ~A) + >>> pl_true(A & ~A, deep=True) + False + >>> pl_true(A & B & (~A | ~B), {A: True}) + >>> pl_true(A & B & (~A | ~B), {A: True}, deep=True) + False + + """ + + from sympy.core.symbol import Symbol + + boolean = (True, False) + + def _validate(expr): + if isinstance(expr, Symbol) or expr in boolean: + return True + if not isinstance(expr, BooleanFunction): + return False + return all(_validate(arg) for arg in expr.args) + + if expr in boolean: + return expr + expr = sympify(expr) + if not _validate(expr): + raise ValueError("%s is not a valid boolean expression" % expr) + if not model: + model = {} + model = {k: v for k, v in model.items() if v in boolean} + result = expr.subs(model) + if result in boolean: + return bool(result) + if deep: + model = dict.fromkeys(result.atoms(), True) + if pl_true(result, model): + if valid(result): + return True + else: + if not satisfiable(result): + return False + return None + + +def entails(expr, formula_set=None): + """ + Check whether the given expr_set entail an expr. + If formula_set is empty then it returns the validity of expr. + + Examples + ======== + + >>> from sympy.abc import A, B, C + >>> from sympy.logic.inference import entails + >>> entails(A, [A >> B, B >> C]) + False + >>> entails(C, [A >> B, B >> C, A]) + True + >>> entails(A >> B) + False + >>> entails(A >> (B >> A)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logical_consequence + + """ + if formula_set: + formula_set = list(formula_set) + else: + formula_set = [] + formula_set.append(Not(expr)) + return not satisfiable(And(*formula_set)) + + +class KB: + """Base class for all knowledge bases""" + def __init__(self, sentence=None): + self.clauses_ = set() + if sentence: + self.tell(sentence) + + def tell(self, sentence): + raise NotImplementedError + + def ask(self, query): + raise NotImplementedError + + def retract(self, sentence): + raise NotImplementedError + + @property + def clauses(self): + return list(ordered(self.clauses_)) + + +class PropKB(KB): + """A KB for Propositional Logic. Inefficient, with no indexing.""" + + def tell(self, sentence): + """Add the sentence's clauses to the KB + + Examples + ======== + + >>> from sympy.logic.inference import PropKB + >>> from sympy.abc import x, y + >>> l = PropKB() + >>> l.clauses + [] + + >>> l.tell(x | y) + >>> l.clauses + [x | y] + + >>> l.tell(y) + >>> l.clauses + [y, x | y] + + """ + for c in conjuncts(to_cnf(sentence)): + self.clauses_.add(c) + + def ask(self, query): + """Checks if the query is true given the set of clauses. + + Examples + ======== + + >>> from sympy.logic.inference import PropKB + >>> from sympy.abc import x, y + >>> l = PropKB() + >>> l.tell(x & ~y) + >>> l.ask(x) + True + >>> l.ask(y) + False + + """ + return entails(query, self.clauses_) + + def retract(self, sentence): + """Remove the sentence's clauses from the KB + + Examples + ======== + + >>> from sympy.logic.inference import PropKB + >>> from sympy.abc import x, y + >>> l = PropKB() + >>> l.clauses + [] + + >>> l.tell(x | y) + >>> l.clauses + [x | y] + + >>> l.retract(x | y) + >>> l.clauses + [] + + """ + for c in conjuncts(to_cnf(sentence)): + self.clauses_.discard(c) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1ed4d80ea06359d596fd394eaec4ae3cb1d0842 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_dimacs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_dimacs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47f67d488fb4f51a3f3467055f258a3e58f335b9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_dimacs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_inference.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_inference.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24a29b26656f1cf37a1de8f442dd06e7a39ad330 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_inference.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_lra_theory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_lra_theory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..639582c9f8309de92afccb6446f08bd2a3c623e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/__pycache__/test_lra_theory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_boolalg.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_boolalg.py new file mode 100644 index 0000000000000000000000000000000000000000..88cdd647fdcc723faee328f71df96030841a3edb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_boolalg.py @@ -0,0 +1,1367 @@ +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine +from sympy.core.numbers import oo +from sympy.core.relational import Equality, Eq, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions import Piecewise +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.sets.sets import Interval, Union +from sympy.sets.contains import Contains +from sympy.simplify.simplify import simplify +from sympy.logic.boolalg import ( + And, Boolean, Equivalent, ITE, Implies, Nand, Nor, Not, Or, + POSform, SOPform, Xor, Xnor, conjuncts, disjuncts, + distribute_or_over_and, distribute_and_over_or, + eliminate_implications, is_nnf, is_cnf, is_dnf, simplify_logic, + to_nnf, to_cnf, to_dnf, to_int_repr, bool_map, true, false, + BooleanAtom, is_literal, term_to_integer, + truth_table, as_Boolean, to_anf, is_anf, distribute_xor_over_and, + anf_coeffs, ANFform, bool_minterm, bool_maxterm, bool_monomial, + _check_pair, _convert_to_varsSOP, _convert_to_varsPOS, Exclusive, + gateinputcount) +from sympy.assumptions.cnf import CNF + +from sympy.testing.pytest import raises, XFAIL, slow + +from itertools import combinations, permutations, product + +A, B, C, D = symbols('A:D') +a, b, c, d, e, w, x, y, z = symbols('a:e w:z') + + +def test_overloading(): + """Test that |, & are overloaded as expected""" + + assert A & B == And(A, B) + assert A | B == Or(A, B) + assert (A & B) | C == Or(And(A, B), C) + assert A >> B == Implies(A, B) + assert A << B == Implies(B, A) + assert ~A == Not(A) + assert A ^ B == Xor(A, B) + + +def test_And(): + assert And() is true + assert And(A) == A + assert And(True) is true + assert And(False) is false + assert And(True, True) is true + assert And(True, False) is false + assert And(False, False) is false + assert And(True, A) == A + assert And(False, A) is false + assert And(True, True, True) is true + assert And(True, True, A) == A + assert And(True, False, A) is false + assert And(1, A) == A + raises(TypeError, lambda: And(2, A)) + assert And(A < 1, A >= 1) is false + e = A > 1 + assert And(e, e.canonical) == e.canonical + g, l, ge, le = A > B, B < A, A >= B, B <= A + assert And(g, l, ge, le) == And(ge, g) + assert {And(*i) for i in permutations((l, g, le, ge))} == {And(ge, g)} + assert And(And(Eq(a, 0), Eq(b, 0)), And(Ne(a, 0), Eq(c, 0))) is false + + +def test_Or(): + assert Or() is false + assert Or(A) == A + assert Or(True) is true + assert Or(False) is false + assert Or(True, True) is true + assert Or(True, False) is true + assert Or(False, False) is false + assert Or(True, A) is true + assert Or(False, A) == A + assert Or(True, False, False) is true + assert Or(True, False, A) is true + assert Or(False, False, A) == A + assert Or(1, A) is true + raises(TypeError, lambda: Or(2, A)) + assert Or(A < 1, A >= 1) is true + e = A > 1 + assert Or(e, e.canonical) == e + g, l, ge, le = A > B, B < A, A >= B, B <= A + assert Or(g, l, ge, le) == Or(g, ge) + + +def test_Xor(): + assert Xor() is false + assert Xor(A) == A + assert Xor(A, A) is false + assert Xor(True, A, A) is true + assert Xor(A, A, A, A, A) == A + assert Xor(True, False, False, A, B) == ~Xor(A, B) + assert Xor(True) is true + assert Xor(False) is false + assert Xor(True, True) is false + assert Xor(True, False) is true + assert Xor(False, False) is false + assert Xor(True, A) == ~A + assert Xor(False, A) == A + assert Xor(True, False, False) is true + assert Xor(True, False, A) == ~A + assert Xor(False, False, A) == A + assert isinstance(Xor(A, B), Xor) + assert Xor(A, B, Xor(C, D)) == Xor(A, B, C, D) + assert Xor(A, B, Xor(B, C)) == Xor(A, C) + assert Xor(A < 1, A >= 1, B) == Xor(0, 1, B) == Xor(1, 0, B) + e = A > 1 + assert Xor(e, e.canonical) == Xor(0, 0) == Xor(1, 1) + + +def test_rewrite_as_And(): + expr = x ^ y + assert expr.rewrite(And) == (x | y) & (~x | ~y) + + +def test_rewrite_as_Or(): + expr = x ^ y + assert expr.rewrite(Or) == (x & ~y) | (y & ~x) + + +def test_rewrite_as_Nand(): + expr = (y & z) | (z & ~w) + assert expr.rewrite(Nand) == ~(~(y & z) & ~(z & ~w)) + + +def test_rewrite_as_Nor(): + expr = z & (y | ~w) + assert expr.rewrite(Nor) == ~(~z | ~(y | ~w)) + + +def test_Not(): + raises(TypeError, lambda: Not(True, False)) + assert Not(True) is false + assert Not(False) is true + assert Not(0) is true + assert Not(1) is false + assert Not(2) is false + + +def test_Nand(): + assert Nand() is false + assert Nand(A) == ~A + assert Nand(True) is false + assert Nand(False) is true + assert Nand(True, True) is false + assert Nand(True, False) is true + assert Nand(False, False) is true + assert Nand(True, A) == ~A + assert Nand(False, A) is true + assert Nand(True, True, True) is false + assert Nand(True, True, A) == ~A + assert Nand(True, False, A) is true + + +def test_Nor(): + assert Nor() is true + assert Nor(A) == ~A + assert Nor(True) is false + assert Nor(False) is true + assert Nor(True, True) is false + assert Nor(True, False) is false + assert Nor(False, False) is true + assert Nor(True, A) is false + assert Nor(False, A) == ~A + assert Nor(True, True, True) is false + assert Nor(True, True, A) is false + assert Nor(True, False, A) is false + + +def test_Xnor(): + assert Xnor() is true + assert Xnor(A) == ~A + assert Xnor(A, A) is true + assert Xnor(True, A, A) is false + assert Xnor(A, A, A, A, A) == ~A + assert Xnor(True) is false + assert Xnor(False) is true + assert Xnor(True, True) is true + assert Xnor(True, False) is false + assert Xnor(False, False) is true + assert Xnor(True, A) == A + assert Xnor(False, A) == ~A + assert Xnor(True, False, False) is false + assert Xnor(True, False, A) == A + assert Xnor(False, False, A) == ~A + + +def test_Implies(): + raises(ValueError, lambda: Implies(A, B, C)) + assert Implies(True, True) is true + assert Implies(True, False) is false + assert Implies(False, True) is true + assert Implies(False, False) is true + assert Implies(0, A) is true + assert Implies(1, 1) is true + assert Implies(1, 0) is false + assert A >> B == B << A + assert (A < 1) >> (A >= 1) == (A >= 1) + assert (A < 1) >> (S.One > A) is true + assert A >> A is true + + +def test_Equivalent(): + assert Equivalent(A, B) == Equivalent(B, A) == Equivalent(A, B, A) + assert Equivalent() is true + assert Equivalent(A, A) == Equivalent(A) is true + assert Equivalent(True, True) == Equivalent(False, False) is true + assert Equivalent(True, False) == Equivalent(False, True) is false + assert Equivalent(A, True) == A + assert Equivalent(A, False) == Not(A) + assert Equivalent(A, B, True) == A & B + assert Equivalent(A, B, False) == ~A & ~B + assert Equivalent(1, A) == A + assert Equivalent(0, A) == Not(A) + assert Equivalent(A, Equivalent(B, C)) != Equivalent(Equivalent(A, B), C) + assert Equivalent(A < 1, A >= 1) is false + assert Equivalent(A < 1, A >= 1, 0) is false + assert Equivalent(A < 1, A >= 1, 1) is false + assert Equivalent(A < 1, S.One > A) == Equivalent(1, 1) == Equivalent(0, 0) + assert Equivalent(Equality(A, B), Equality(B, A)) is true + + +def test_Exclusive(): + assert Exclusive(False, False, False) is true + assert Exclusive(True, False, False) is true + assert Exclusive(True, True, False) is false + assert Exclusive(True, True, True) is false + + +def test_equals(): + assert Not(Or(A, B)).equals(And(Not(A), Not(B))) is True + assert Equivalent(A, B).equals((A >> B) & (B >> A)) is True + assert ((A | ~B) & (~A | B)).equals((~A & ~B) | (A & B)) is True + assert (A >> B).equals(~A >> ~B) is False + assert (A >> (B >> A)).equals(A >> (C >> A)) is False + raises(NotImplementedError, lambda: (A & B).equals(A > B)) + + +def test_simplification_boolalg(): + """ + Test working of simplification methods. + """ + set1 = [[0, 0, 1], [0, 1, 1], [1, 0, 0], [1, 1, 0]] + set2 = [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]] + assert SOPform([x, y, z], set1) == Or(And(Not(x), z), And(Not(z), x)) + assert Not(SOPform([x, y, z], set2)) == \ + Not(Or(And(Not(x), Not(z)), And(x, z))) + assert POSform([x, y, z], set1 + set2) is true + assert SOPform([x, y, z], set1 + set2) is true + assert SOPform([Dummy(), Dummy(), Dummy()], set1 + set2) is true + + minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], + [1, 1, 1, 1]] + dontcares = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, 3, 7, 11, 15] + dontcares = [0, 2, 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, [0, 0, 1, 1], 7, [1, 0, 1, 1], + [1, 1, 1, 1]] + dontcares = [0, [0, 0, 1, 0], 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, {y: 1, z: 1}] + dontcares = [0, [0, 0, 1, 0], 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [{y: 1, z: 1}, 1] + dontcares = [[0, 0, 0, 0]] + + minterms = [[0, 0, 0]] + raises(ValueError, lambda: SOPform([w, x, y, z], minterms)) + raises(ValueError, lambda: POSform([w, x, y, z], minterms)) + + raises(TypeError, lambda: POSform([w, x, y, z], ["abcdefg"])) + + # test simplification + ans = And(A, Or(B, C)) + assert simplify_logic(A & (B | C)) == ans + assert simplify_logic((A & B) | (A & C)) == ans + assert simplify_logic(Implies(A, B)) == Or(Not(A), B) + assert simplify_logic(Equivalent(A, B)) == \ + Or(And(A, B), And(Not(A), Not(B))) + assert simplify_logic(And(Equality(A, 2), C)) == And(Equality(A, 2), C) + assert simplify_logic(And(Equality(A, 2), A)) == And(Equality(A, 2), A) + assert simplify_logic(And(Equality(A, B), C)) == And(Equality(A, B), C) + assert simplify_logic(Or(And(Equality(A, 3), B), And(Equality(A, 3), C))) \ + == And(Equality(A, 3), Or(B, C)) + b = (~x & ~y & ~z) | (~x & ~y & z) + e = And(A, b) + assert simplify_logic(e) == A & ~x & ~y + raises(ValueError, lambda: simplify_logic(A & (B | C), form='blabla')) + assert simplify(Or(x <= y, And(x < y, z))) == (x <= y) + assert simplify(Or(x <= y, And(y > x, z))) == (x <= y) + assert simplify(Or(x >= y, And(y < x, z))) == (x >= y) + + # Check that expressions with nine variables or more are not simplified + # (without the force-flag) + a, b, c, d, e, f, g, h, j = symbols('a b c d e f g h j') + expr = a & b & c & d & e & f & g & h & j | \ + a & b & c & d & e & f & g & h & ~j + # This expression can be simplified to get rid of the j variables + assert simplify_logic(expr) == expr + + # Test dontcare + assert simplify_logic((a & b) | c | d, dontcare=(a & b)) == c | d + + # check input + ans = SOPform([x, y], [[1, 0]]) + assert SOPform([x, y], [[1, 0]]) == ans + assert POSform([x, y], [[1, 0]]) == ans + + raises(ValueError, lambda: SOPform([x], [[1]], [[1]])) + assert SOPform([x], [[1]], [[0]]) is true + assert SOPform([x], [[0]], [[1]]) is true + assert SOPform([x], [], []) is false + + raises(ValueError, lambda: POSform([x], [[1]], [[1]])) + assert POSform([x], [[1]], [[0]]) is true + assert POSform([x], [[0]], [[1]]) is true + assert POSform([x], [], []) is false + + # check working of simplify + assert simplify((A & B) | (A & C)) == And(A, Or(B, C)) + assert simplify(And(x, Not(x))) == False + assert simplify(Or(x, Not(x))) == True + assert simplify(And(Eq(x, 0), Eq(x, y))) == And(Eq(x, 0), Eq(y, 0)) + assert And(Eq(x - 1, 0), Eq(x, y)).simplify() == And(Eq(x, 1), Eq(y, 1)) + assert And(Ne(x - 1, 0), Ne(x, y)).simplify() == And(Ne(x, 1), Ne(x, y)) + assert And(Eq(x - 1, 0), Ne(x, y)).simplify() == And(Eq(x, 1), Ne(y, 1)) + assert And(Eq(x - 1, 0), Eq(x, z + y), Eq(y + x, 0)).simplify( + ) == And(Eq(x, 1), Eq(y, -1), Eq(z, 2)) + assert And(Eq(x - 1, 0), Eq(x + 2, 3)).simplify() == Eq(x, 1) + assert And(Ne(x - 1, 0), Ne(x + 2, 3)).simplify() == Ne(x, 1) + assert And(Eq(x - 1, 0), Eq(x + 2, 2)).simplify() == False + assert And(Ne(x - 1, 0), Ne(x + 2, 2)).simplify( + ) == And(Ne(x, 1), Ne(x, 0)) + assert simplify(Xor(x, ~x)) == True + + +def test_bool_map(): + """ + Test working of bool_map function. + """ + + minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], + [1, 1, 1, 1]] + assert bool_map(Not(Not(a)), a) == (a, {a: a}) + assert bool_map(SOPform([w, x, y, z], minterms), + POSform([w, x, y, z], minterms)) == \ + (And(Or(Not(w), y), Or(Not(x), y), z), {x: x, w: w, z: z, y: y}) + assert bool_map(SOPform([x, z, y], [[1, 0, 1]]), + SOPform([a, b, c], [[1, 0, 1]])) != False + function1 = SOPform([x, z, y], [[1, 0, 1], [0, 0, 1]]) + function2 = SOPform([a, b, c], [[1, 0, 1], [1, 0, 0]]) + assert bool_map(function1, function2) == \ + (function1, {y: a, z: b}) + assert bool_map(Xor(x, y), ~Xor(x, y)) == False + assert bool_map(And(x, y), Or(x, y)) is None + assert bool_map(And(x, y), And(x, y, z)) is None + # issue 16179 + assert bool_map(Xor(x, y, z), ~Xor(x, y, z)) == False + assert bool_map(Xor(a, x, y, z), ~Xor(a, x, y, z)) == False + + +def test_bool_symbol(): + """Test that mixing symbols with boolean values + works as expected""" + + assert And(A, True) == A + assert And(A, True, True) == A + assert And(A, False) is false + assert And(A, True, False) is false + assert Or(A, True) is true + assert Or(A, False) == A + + +def test_is_boolean(): + assert isinstance(True, Boolean) is False + assert isinstance(true, Boolean) is True + assert 1 == True + assert 1 != true + assert (1 == true) is False + assert 0 == False + assert 0 != false + assert (0 == false) is False + assert true.is_Boolean is True + assert (A & B).is_Boolean + assert (A | B).is_Boolean + assert (~A).is_Boolean + assert (A ^ B).is_Boolean + assert A.is_Boolean != isinstance(A, Boolean) + assert isinstance(A, Boolean) + + +def test_subs(): + assert (A & B).subs(A, True) == B + assert (A & B).subs(A, False) is false + assert (A & B).subs(B, True) == A + assert (A & B).subs(B, False) is false + assert (A & B).subs({A: True, B: True}) is true + assert (A | B).subs(A, True) is true + assert (A | B).subs(A, False) == B + assert (A | B).subs(B, True) is true + assert (A | B).subs(B, False) == A + assert (A | B).subs({A: True, B: True}) is true + + +""" +we test for axioms of boolean algebra +see https://en.wikipedia.org/wiki/Boolean_algebra_(structure) +""" + + +def test_commutative(): + """Test for commutativity of And and Or""" + A, B = map(Boolean, symbols('A,B')) + + assert A & B == B & A + assert A | B == B | A + + +def test_and_associativity(): + """Test for associativity of And""" + + assert (A & B) & C == A & (B & C) + + +def test_or_assicativity(): + assert ((A | B) | C) == (A | (B | C)) + + +def test_double_negation(): + a = Boolean() + assert ~(~a) == a + + +# test methods + +def test_eliminate_implications(): + assert eliminate_implications(Implies(A, B, evaluate=False)) == (~A) | B + assert eliminate_implications( + A >> (C >> Not(B))) == Or(Or(Not(B), Not(C)), Not(A)) + assert eliminate_implications(Equivalent(A, B, C, D)) == \ + (~A | B) & (~B | C) & (~C | D) & (~D | A) + + +def test_conjuncts(): + assert conjuncts(A & B & C) == {A, B, C} + assert conjuncts((A | B) & C) == {A | B, C} + assert conjuncts(A) == {A} + assert conjuncts(True) == {True} + assert conjuncts(False) == {False} + + +def test_disjuncts(): + assert disjuncts(A | B | C) == {A, B, C} + assert disjuncts((A | B) & C) == {(A | B) & C} + assert disjuncts(A) == {A} + assert disjuncts(True) == {True} + assert disjuncts(False) == {False} + + +def test_distribute(): + assert distribute_and_over_or(Or(And(A, B), C)) == And(Or(A, C), Or(B, C)) + assert distribute_or_over_and(And(A, Or(B, C))) == Or(And(A, B), And(A, C)) + assert distribute_xor_over_and(And(A, Xor(B, C))) == Xor(And(A, B), And(A, C)) + + +def test_to_anf(): + x, y, z = symbols('x,y,z') + assert to_anf(And(x, y)) == And(x, y) + assert to_anf(Or(x, y)) == Xor(x, y, And(x, y)) + assert to_anf(Or(Implies(x, y), And(x, y), y)) == \ + Xor(x, True, x & y, remove_true=False) + assert to_anf(Or(Nand(x, y), Nor(x, y), Xnor(x, y), Implies(x, y))) == True + assert to_anf(Or(x, Not(y), Nor(x, z), And(x, y), Nand(y, z))) == \ + Xor(True, And(y, z), And(x, y, z), remove_true=False) + assert to_anf(Xor(x, y)) == Xor(x, y) + assert to_anf(Not(x)) == Xor(x, True, remove_true=False) + assert to_anf(Nand(x, y)) == Xor(True, And(x, y), remove_true=False) + assert to_anf(Nor(x, y)) == Xor(x, y, True, And(x, y), remove_true=False) + assert to_anf(Implies(x, y)) == Xor(x, True, And(x, y), remove_true=False) + assert to_anf(Equivalent(x, y)) == Xor(x, y, True, remove_true=False) + assert to_anf(Nand(x | y, x >> y), deep=False) == \ + Xor(True, And(Or(x, y), Implies(x, y)), remove_true=False) + assert to_anf(Nor(x ^ y, x & y), deep=False) == \ + Xor(True, Or(Xor(x, y), And(x, y)), remove_true=False) + # issue 25218 + assert to_anf(x ^ ~(x ^ y ^ ~y)) == False + + +def test_to_nnf(): + assert to_nnf(true) is true + assert to_nnf(false) is false + assert to_nnf(A) == A + assert to_nnf(A | ~A | B) is true + assert to_nnf(A & ~A & B) is false + assert to_nnf(A >> B) == ~A | B + assert to_nnf(Equivalent(A, B, C)) == (~A | B) & (~B | C) & (~C | A) + assert to_nnf(A ^ B ^ C) == \ + (A | B | C) & (~A | ~B | C) & (A | ~B | ~C) & (~A | B | ~C) + assert to_nnf(ITE(A, B, C)) == (~A | B) & (A | C) + assert to_nnf(Not(A | B | C)) == ~A & ~B & ~C + assert to_nnf(Not(A & B & C)) == ~A | ~B | ~C + assert to_nnf(Not(A >> B)) == A & ~B + assert to_nnf(Not(Equivalent(A, B, C))) == And(Or(A, B, C), Or(~A, ~B, ~C)) + assert to_nnf(Not(A ^ B ^ C)) == \ + (~A | B | C) & (A | ~B | C) & (A | B | ~C) & (~A | ~B | ~C) + assert to_nnf(Not(ITE(A, B, C))) == (~A | ~B) & (A | ~C) + assert to_nnf((A >> B) ^ (B >> A)) == (A & ~B) | (~A & B) + assert to_nnf((A >> B) ^ (B >> A), False) == \ + (~A | ~B | A | B) & ((A & ~B) | (~A & B)) + assert ITE(A, 1, 0).to_nnf() == A + assert ITE(A, 0, 1).to_nnf() == ~A + # although ITE can hold non-Boolean, it will complain if + # an attempt is made to convert the ITE to Boolean nnf + raises(TypeError, lambda: ITE(A < 1, [1], B).to_nnf()) + + +def test_to_cnf(): + assert to_cnf(~(B | C)) == And(Not(B), Not(C)) + assert to_cnf((A & B) | C) == And(Or(A, C), Or(B, C)) + assert to_cnf(A >> B) == (~A) | B + assert to_cnf(A >> (B & C)) == (~A | B) & (~A | C) + assert to_cnf(A & (B | C) | ~A & (B | C), True) == B | C + assert to_cnf(A & B) == And(A, B) + + assert to_cnf(Equivalent(A, B)) == And(Or(A, Not(B)), Or(B, Not(A))) + assert to_cnf(Equivalent(A, B & C)) == \ + (~A | B) & (~A | C) & (~B | ~C | A) + assert to_cnf(Equivalent(A, B | C), True) == \ + And(Or(Not(B), A), Or(Not(C), A), Or(B, C, Not(A))) + assert to_cnf(A + 1) == A + 1 + + +def test_issue_18904(): + x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 = symbols('x1:16') + eq = ((x1 & x2 & x3 & x4 & x5 & x6 & x7 & x8 & x9) | + (x1 & x2 & x3 & x4 & x5 & x6 & x7 & x10 & x9) | + (x1 & x11 & x3 & x12 & x5 & x13 & x14 & x15 & x9)) + assert is_cnf(to_cnf(eq)) + raises(ValueError, lambda: to_cnf(eq, simplify=True)) + for f, t in zip((And, Or), (to_cnf, to_dnf)): + eq = f(x1, x2, x3, x4, x5, x6, x7, x8, x9) + raises(ValueError, lambda: to_cnf(eq, simplify=True)) + assert t(eq, simplify=True, force=True) == eq + + +def test_issue_9949(): + assert is_cnf(to_cnf((b > -5) | (a > 2) & (a < 4))) + + +def test_to_CNF(): + assert CNF.CNF_to_cnf(CNF.to_CNF(~(B | C))) == to_cnf(~(B | C)) + assert CNF.CNF_to_cnf(CNF.to_CNF((A & B) | C)) == to_cnf((A & B) | C) + assert CNF.CNF_to_cnf(CNF.to_CNF(A >> B)) == to_cnf(A >> B) + assert CNF.CNF_to_cnf(CNF.to_CNF(A >> (B & C))) == to_cnf(A >> (B & C)) + assert CNF.CNF_to_cnf(CNF.to_CNF(A & (B | C) | ~A & (B | C))) == to_cnf(A & (B | C) | ~A & (B | C)) + assert CNF.CNF_to_cnf(CNF.to_CNF(A & B)) == to_cnf(A & B) + + +def test_to_dnf(): + assert to_dnf(~(B | C)) == And(Not(B), Not(C)) + assert to_dnf(A & (B | C)) == Or(And(A, B), And(A, C)) + assert to_dnf(A >> B) == (~A) | B + assert to_dnf(A >> (B & C)) == (~A) | (B & C) + assert to_dnf(A | B) == A | B + + assert to_dnf(Equivalent(A, B), True) == \ + Or(And(A, B), And(Not(A), Not(B))) + assert to_dnf(Equivalent(A, B & C), True) == \ + Or(And(A, B, C), And(Not(A), Not(B)), And(Not(A), Not(C))) + assert to_dnf(A + 1) == A + 1 + + +def test_to_int_repr(): + x, y, z = map(Boolean, symbols('x,y,z')) + + def sorted_recursive(arg): + try: + return sorted(sorted_recursive(x) for x in arg) + except TypeError: # arg is not a sequence + return arg + + assert sorted_recursive(to_int_repr([x | y, z | x], [x, y, z])) == \ + sorted_recursive([[1, 2], [1, 3]]) + assert sorted_recursive(to_int_repr([x | y, z | ~x], [x, y, z])) == \ + sorted_recursive([[1, 2], [3, -1]]) + + +def test_is_anf(): + x, y = symbols('x,y') + assert is_anf(true) is True + assert is_anf(false) is True + assert is_anf(x) is True + assert is_anf(And(x, y)) is True + assert is_anf(Xor(x, y, And(x, y))) is True + assert is_anf(Xor(x, y, Or(x, y))) is False + assert is_anf(Xor(Not(x), y)) is False + + +def test_is_nnf(): + assert is_nnf(true) is True + assert is_nnf(A) is True + assert is_nnf(~A) is True + assert is_nnf(A & B) is True + assert is_nnf((A & B) | (~A & A) | (~B & B) | (~A & ~B), False) is True + assert is_nnf((A | B) & (~A | ~B)) is True + assert is_nnf(Not(Or(A, B))) is False + assert is_nnf(A ^ B) is False + assert is_nnf((A & B) | (~A & A) | (~B & B) | (~A & ~B), True) is False + + +def test_is_cnf(): + assert is_cnf(x) is True + assert is_cnf(x | y | z) is True + assert is_cnf(x & y & z) is True + assert is_cnf((x | y) & z) is True + assert is_cnf((x & y) | z) is False + assert is_cnf(~(x & y) | z) is False + + +def test_is_dnf(): + assert is_dnf(x) is True + assert is_dnf(x | y | z) is True + assert is_dnf(x & y & z) is True + assert is_dnf((x & y) | z) is True + assert is_dnf((x | y) & z) is False + assert is_dnf(~(x | y) & z) is False + + +def test_ITE(): + A, B, C = symbols('A:C') + assert ITE(True, False, True) is false + assert ITE(True, True, False) is true + assert ITE(False, True, False) is false + assert ITE(False, False, True) is true + assert isinstance(ITE(A, B, C), ITE) + + A = True + assert ITE(A, B, C) == B + A = False + assert ITE(A, B, C) == C + B = True + assert ITE(And(A, B), B, C) == C + assert ITE(Or(A, False), And(B, True), False) is false + assert ITE(x, A, B) == Not(x) + assert ITE(x, B, A) == x + assert ITE(1, x, y) == x + assert ITE(0, x, y) == y + raises(TypeError, lambda: ITE(2, x, y)) + raises(TypeError, lambda: ITE(1, [], y)) + raises(TypeError, lambda: ITE(1, (), y)) + raises(TypeError, lambda: ITE(1, y, [])) + assert ITE(1, 1, 1) is S.true + assert isinstance(ITE(1, 1, 1, evaluate=False), ITE) + + assert ITE(Eq(x, True), y, x) == ITE(x, y, x) + assert ITE(Eq(x, False), y, x) == ITE(~x, y, x) + assert ITE(Ne(x, True), y, x) == ITE(~x, y, x) + assert ITE(Ne(x, False), y, x) == ITE(x, y, x) + assert ITE(Eq(S.true, x), y, x) == ITE(x, y, x) + assert ITE(Eq(S.false, x), y, x) == ITE(~x, y, x) + assert ITE(Ne(S.true, x), y, x) == ITE(~x, y, x) + assert ITE(Ne(S.false, x), y, x) == ITE(x, y, x) + # 0 and 1 in the context are not treated as True/False + # so the equality must always be False since dissimilar + # objects cannot be equal + assert ITE(Eq(x, 0), y, x) == x + assert ITE(Eq(x, 1), y, x) == x + assert ITE(Ne(x, 0), y, x) == y + assert ITE(Ne(x, 1), y, x) == y + assert ITE(Eq(x, 0), y, z).subs(x, 0) == y + assert ITE(Eq(x, 0), y, z).subs(x, 1) == z + raises(ValueError, lambda: ITE(x > 1, y, x, z)) + + +def test_is_literal(): + assert is_literal(True) is True + assert is_literal(False) is True + assert is_literal(A) is True + assert is_literal(~A) is True + assert is_literal(Or(A, B)) is False + assert is_literal(Q.zero(A)) is True + assert is_literal(Not(Q.zero(A))) is True + assert is_literal(Or(A, B)) is False + assert is_literal(And(Q.zero(A), Q.zero(B))) is False + assert is_literal(x < 3) + assert not is_literal(x + y < 3) + + +def test_operators(): + # Mostly test __and__, __rand__, and so on + assert True & A == A & True == A + assert False & A == A & False == False + assert A & B == And(A, B) + assert True | A == A | True == True + assert False | A == A | False == A + assert A | B == Or(A, B) + assert ~A == Not(A) + assert True >> A == A << True == A + assert False >> A == A << False == True + assert A >> True == True << A == True + assert A >> False == False << A == ~A + assert A >> B == B << A == Implies(A, B) + assert True ^ A == A ^ True == ~A + assert False ^ A == A ^ False == A + assert A ^ B == Xor(A, B) + + +def test_true_false(): + assert true is S.true + assert false is S.false + assert true is not True + assert false is not False + assert true + assert not false + assert true == True + assert false == False + assert not (true == False) + assert not (false == True) + assert not (true == false) + + assert hash(true) == hash(True) + assert hash(false) == hash(False) + assert len({true, True}) == len({false, False}) == 1 + + assert isinstance(true, BooleanAtom) + assert isinstance(false, BooleanAtom) + # We don't want to subclass from bool, because bool subclasses from + # int. But operators like &, |, ^, <<, >>, and ~ act differently on 0 and + # 1 then we want them to on true and false. See the docstrings of the + # various And, Or, etc. functions for examples. + assert not isinstance(true, bool) + assert not isinstance(false, bool) + + # Note: using 'is' comparison is important here. We want these to return + # true and false, not True and False + + assert Not(true) is false + assert Not(True) is false + assert Not(false) is true + assert Not(False) is true + assert ~true is false + assert ~false is true + + for T, F in product((True, true), (False, false)): + assert And(T, F) is false + assert And(F, T) is false + assert And(F, F) is false + assert And(T, T) is true + assert And(T, x) == x + assert And(F, x) is false + if not (T is True and F is False): + assert T & F is false + assert F & T is false + if F is not False: + assert F & F is false + if T is not True: + assert T & T is true + + assert Or(T, F) is true + assert Or(F, T) is true + assert Or(F, F) is false + assert Or(T, T) is true + assert Or(T, x) is true + assert Or(F, x) == x + if not (T is True and F is False): + assert T | F is true + assert F | T is true + if F is not False: + assert F | F is false + if T is not True: + assert T | T is true + + assert Xor(T, F) is true + assert Xor(F, T) is true + assert Xor(F, F) is false + assert Xor(T, T) is false + assert Xor(T, x) == ~x + assert Xor(F, x) == x + if not (T is True and F is False): + assert T ^ F is true + assert F ^ T is true + if F is not False: + assert F ^ F is false + if T is not True: + assert T ^ T is false + + assert Nand(T, F) is true + assert Nand(F, T) is true + assert Nand(F, F) is true + assert Nand(T, T) is false + assert Nand(T, x) == ~x + assert Nand(F, x) is true + + assert Nor(T, F) is false + assert Nor(F, T) is false + assert Nor(F, F) is true + assert Nor(T, T) is false + assert Nor(T, x) is false + assert Nor(F, x) == ~x + + assert Implies(T, F) is false + assert Implies(F, T) is true + assert Implies(F, F) is true + assert Implies(T, T) is true + assert Implies(T, x) == x + assert Implies(F, x) is true + assert Implies(x, T) is true + assert Implies(x, F) == ~x + if not (T is True and F is False): + assert T >> F is false + assert F << T is false + assert F >> T is true + assert T << F is true + if F is not False: + assert F >> F is true + assert F << F is true + if T is not True: + assert T >> T is true + assert T << T is true + + assert Equivalent(T, F) is false + assert Equivalent(F, T) is false + assert Equivalent(F, F) is true + assert Equivalent(T, T) is true + assert Equivalent(T, x) == x + assert Equivalent(F, x) == ~x + assert Equivalent(x, T) == x + assert Equivalent(x, F) == ~x + + assert ITE(T, T, T) is true + assert ITE(T, T, F) is true + assert ITE(T, F, T) is false + assert ITE(T, F, F) is false + assert ITE(F, T, T) is true + assert ITE(F, T, F) is false + assert ITE(F, F, T) is true + assert ITE(F, F, F) is false + + assert all(i.simplify(1, 2) is i for i in (S.true, S.false)) + + +def test_bool_as_set(): + assert ITE(y <= 0, False, y >= 1).as_set() == Interval(1, oo) + assert And(x <= 2, x >= -2).as_set() == Interval(-2, 2) + assert Or(x >= 2, x <= -2).as_set() == Interval(-oo, -2) + Interval(2, oo) + assert Not(x > 2).as_set() == Interval(-oo, 2) + # issue 10240 + assert Not(And(x > 2, x < 3)).as_set() == \ + Union(Interval(-oo, 2), Interval(3, oo)) + assert true.as_set() == S.UniversalSet + assert false.as_set() is S.EmptySet + assert x.as_set() == S.UniversalSet + assert And(Or(x < 1, x > 3), x < 2).as_set() == Interval.open(-oo, 1) + assert And(x < 1, sin(x) < 3).as_set() == (x < 1).as_set() + raises(NotImplementedError, lambda: (sin(x) < 1).as_set()) + # watch for object morph in as_set + assert Eq(-1, cos(2 * x) ** 2 / sin(2 * x) ** 2).as_set() is S.EmptySet + + +@XFAIL +def test_multivariate_bool_as_set(): + x, y = symbols('x,y') + + assert And(x >= 0, y >= 0).as_set() == Interval(0, oo) * Interval(0, oo) + assert Or(x >= 0, y >= 0).as_set() == S.Reals * S.Reals - \ + Interval(-oo, 0, True, True) * Interval(-oo, 0, True, True) + + +def test_all_or_nothing(): + x = symbols('x', extended_real=True) + args = x >= -oo, x <= oo + v = And(*args) + if v.func is And: + assert len(v.args) == len(args) - args.count(S.true) + else: + assert v == True + v = Or(*args) + if v.func is Or: + assert len(v.args) == 2 + else: + assert v == True + + +def test_canonical_atoms(): + assert true.canonical == true + assert false.canonical == false + + +def test_negated_atoms(): + assert true.negated == false + assert false.negated == true + + +def test_issue_8777(): + assert And(x > 2, x < oo).as_set() == Interval(2, oo, left_open=True) + assert And(x >= 1, x < oo).as_set() == Interval(1, oo) + assert (x < oo).as_set() == Interval(-oo, oo) + assert (x > -oo).as_set() == Interval(-oo, oo) + + +def test_issue_8975(): + assert Or(And(-oo < x, x <= -2), And(2 <= x, x < oo)).as_set() == \ + Interval(-oo, -2) + Interval(2, oo) + + +def test_term_to_integer(): + assert term_to_integer([1, 0, 1, 0, 0, 1, 0]) == 82 + assert term_to_integer('0010101000111001') == 10809 + + +def test_issue_21971(): + a, b, c, d = symbols('a b c d') + f = a & b & c | a & c + assert f.subs(a & c, d) == b & d | d + assert f.subs(a & b & c, d) == a & c | d + + f = (a | b | c) & (a | c) + assert f.subs(a | c, d) == (b | d) & d + assert f.subs(a | b | c, d) == (a | c) & d + + f = (a ^ b ^ c) & (a ^ c) + assert f.subs(a ^ c, d) == (b ^ d) & d + assert f.subs(a ^ b ^ c, d) == (a ^ c) & d + + +def test_truth_table(): + assert list(truth_table(And(x, y), [x, y], input=False)) == \ + [False, False, False, True] + assert list(truth_table(x | y, [x, y], input=False)) == \ + [False, True, True, True] + assert list(truth_table(x >> y, [x, y], input=False)) == \ + [True, True, False, True] + assert list(truth_table(And(x, y), [x, y])) == \ + [([0, 0], False), ([0, 1], False), ([1, 0], False), ([1, 1], True)] + + +def test_issue_8571(): + for t in (S.true, S.false): + raises(TypeError, lambda: +t) + raises(TypeError, lambda: -t) + raises(TypeError, lambda: abs(t)) + # use int(bool(t)) to get 0 or 1 + raises(TypeError, lambda: int(t)) + + for o in [S.Zero, S.One, x]: + for _ in range(2): + raises(TypeError, lambda: o + t) + raises(TypeError, lambda: o - t) + raises(TypeError, lambda: o % t) + raises(TypeError, lambda: o * t) + raises(TypeError, lambda: o / t) + raises(TypeError, lambda: o ** t) + o, t = t, o # do again in reversed order + + +def test_expand_relational(): + n = symbols('n', negative=True) + p, q = symbols('p q', positive=True) + r = ((n + q * (-n / q + 1)) / (q * (-n / q + 1)) < 0) + assert r is not S.false + assert r.expand() is S.false + assert (q > 0).expand() is S.true + + +def test_issue_12717(): + assert S.true.is_Atom == True + assert S.false.is_Atom == True + + +def test_as_Boolean(): + nz = symbols('nz', nonzero=True) + assert all(as_Boolean(i) is S.true for i in (True, S.true, 1, nz)) + z = symbols('z', zero=True) + assert all(as_Boolean(i) is S.false for i in (False, S.false, 0, z)) + assert all(as_Boolean(i) == i for i in (x, x < 0)) + for i in (2, S(2), x + 1, []): + raises(TypeError, lambda: as_Boolean(i)) + + +def test_binary_symbols(): + assert ITE(x < 1, y, z).binary_symbols == {y, z} + for f in (Eq, Ne): + assert f(x, 1).binary_symbols == set() + assert f(x, True).binary_symbols == {x} + assert f(x, False).binary_symbols == {x} + assert S.true.binary_symbols == set() + assert S.false.binary_symbols == set() + assert x.binary_symbols == {x} + assert And(x, Eq(y, False), Eq(z, 1)).binary_symbols == {x, y} + assert Q.prime(x).binary_symbols == set() + assert Q.lt(x, 1).binary_symbols == set() + assert Q.is_true(x).binary_symbols == {x} + assert Q.eq(x, True).binary_symbols == {x} + assert Q.prime(x).binary_symbols == set() + + +def test_BooleanFunction_diff(): + assert And(x, y).diff(x) == Piecewise((0, Eq(y, False)), (1, True)) + + +def test_issue_14700(): + A, B, C, D, E, F, G, H = symbols('A B C D E F G H') + q = ((B & D & H & ~F) | (B & H & ~C & ~D) | (B & H & ~C & ~F) | + (B & H & ~D & ~G) | (B & H & ~F & ~G) | (C & G & ~B & ~D) | + (C & G & ~D & ~H) | (C & G & ~F & ~H) | (D & F & H & ~B) | + (D & F & ~G & ~H) | (B & D & F & ~C & ~H) | (D & E & F & ~B & ~C) | + (D & F & ~A & ~B & ~C) | (D & F & ~A & ~C & ~H) | + (A & B & D & F & ~E & ~H)) + soldnf = ((B & D & H & ~F) | (D & F & H & ~B) | (B & H & ~C & ~D) | + (B & H & ~D & ~G) | (C & G & ~B & ~D) | (C & G & ~D & ~H) | + (C & G & ~F & ~H) | (D & F & ~G & ~H) | (D & E & F & ~C & ~H) | + (D & F & ~A & ~C & ~H) | (A & B & D & F & ~E & ~H)) + solcnf = ((B | C | D) & (B | D | G) & (C | D | H) & (C | F | H) & + (D | G | H) & (F | G | H) & (B | F | ~D | ~H) & + (~B | ~D | ~F | ~H) & (D | ~B | ~C | ~G | ~H) & + (A | H | ~C | ~D | ~F | ~G) & (H | ~C | ~D | ~E | ~F | ~G) & + (B | E | H | ~A | ~D | ~F | ~G)) + assert simplify_logic(q, "dnf") == soldnf + assert simplify_logic(q, "cnf") == solcnf + + minterms = [[0, 1, 0, 0], [0, 1, 0, 1], [0, 1, 1, 0], [0, 1, 1, 1], + [0, 0, 1, 1], [1, 0, 1, 1]] + dontcares = [[1, 0, 0, 0], [1, 0, 0, 1], [1, 1, 0, 0], [1, 1, 0, 1]] + assert SOPform([w, x, y, z], minterms) == (x & ~w) | (y & z & ~x) + # Should not be more complicated with don't cares + assert SOPform([w, x, y, z], minterms, dontcares) == \ + (x & ~w) | (y & z & ~x) + + +def test_issue_25115(): + cond = Contains(x, S.Integers) + # Previously this raised an exception: + assert simplify_logic(cond) == cond + + +def test_relational_simplification(): + w, x, y, z = symbols('w x y z', real=True) + d, e = symbols('d e', real=False) + # Test all combinations or sign and order + assert Or(x >= y, x < y).simplify() == S.true + assert Or(x >= y, y > x).simplify() == S.true + assert Or(x >= y, -x > -y).simplify() == S.true + assert Or(x >= y, -y < -x).simplify() == S.true + assert Or(-x <= -y, x < y).simplify() == S.true + assert Or(-x <= -y, -x > -y).simplify() == S.true + assert Or(-x <= -y, y > x).simplify() == S.true + assert Or(-x <= -y, -y < -x).simplify() == S.true + assert Or(y <= x, x < y).simplify() == S.true + assert Or(y <= x, y > x).simplify() == S.true + assert Or(y <= x, -x > -y).simplify() == S.true + assert Or(y <= x, -y < -x).simplify() == S.true + assert Or(-y >= -x, x < y).simplify() == S.true + assert Or(-y >= -x, y > x).simplify() == S.true + assert Or(-y >= -x, -x > -y).simplify() == S.true + assert Or(-y >= -x, -y < -x).simplify() == S.true + + assert Or(x < y, x >= y).simplify() == S.true + assert Or(y > x, x >= y).simplify() == S.true + assert Or(-x > -y, x >= y).simplify() == S.true + assert Or(-y < -x, x >= y).simplify() == S.true + assert Or(x < y, -x <= -y).simplify() == S.true + assert Or(-x > -y, -x <= -y).simplify() == S.true + assert Or(y > x, -x <= -y).simplify() == S.true + assert Or(-y < -x, -x <= -y).simplify() == S.true + assert Or(x < y, y <= x).simplify() == S.true + assert Or(y > x, y <= x).simplify() == S.true + assert Or(-x > -y, y <= x).simplify() == S.true + assert Or(-y < -x, y <= x).simplify() == S.true + assert Or(x < y, -y >= -x).simplify() == S.true + assert Or(y > x, -y >= -x).simplify() == S.true + assert Or(-x > -y, -y >= -x).simplify() == S.true + assert Or(-y < -x, -y >= -x).simplify() == S.true + + # Some other tests + assert Or(x >= y, w < z, x <= y).simplify() == S.true + assert And(x >= y, x < y).simplify() == S.false + assert Or(x >= y, Eq(y, x)).simplify() == (x >= y) + assert And(x >= y, Eq(y, x)).simplify() == Eq(x, y) + assert And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify() == \ + (Eq(x, y) & (x >= 1) & (y >= 5) & (y > z)) + assert Or(Eq(x, y), x >= y, w < y, z < y).simplify() == \ + (x >= y) | (y > z) | (w < y) + assert And(Eq(x, y), x >= y, w < y, y >= z, z < y).simplify() == \ + Eq(x, y) & (y > z) & (w < y) + # assert And(Eq(x, y), x >= y, w < y, y >= z, z < y).simplify(relational_minmax=True) == \ + # And(Eq(x, y), y > Max(w, z)) + # assert Or(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify(relational_minmax=True) == \ + # (Eq(x, y) | (x >= 1) | (y > Min(2, z))) + assert And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify() == \ + (Eq(x, y) & (x >= 1) & (y >= 5) & (y > z)) + assert (Eq(x, y) & Eq(d, e) & (x >= y) & (d >= e)).simplify() == \ + (Eq(x, y) & Eq(d, e) & (d >= e)) + assert And(Eq(x, y), Eq(x, -y)).simplify() == And(Eq(x, 0), Eq(y, 0)) + assert Xor(x >= y, x <= y).simplify() == Ne(x, y) + assert And(x > 1, x < -1, Eq(x, y)).simplify() == S.false + # From #16690 + assert And(x >= y, Eq(y, 0)).simplify() == And(x >= 0, Eq(y, 0)) + assert Or(Ne(x, 1), Ne(x, 2)).simplify() == S.true + assert And(Eq(x, 1), Ne(2, x)).simplify() == Eq(x, 1) + assert Or(Eq(x, 1), Ne(2, x)).simplify() == Ne(x, 2) + + +def test_issue_8373(): + x = symbols('x', real=True) + assert Or(x < 1, x > -1).simplify() == S.true + assert Or(x < 1, x >= 1).simplify() == S.true + assert And(x < 1, x >= 1).simplify() == S.false + assert Or(x <= 1, x >= 1).simplify() == S.true + + +def test_issue_7950(): + x = symbols('x', real=True) + assert And(Eq(x, 1), Eq(x, 2)).simplify() == S.false + + +@slow +def test_relational_simplification_numerically(): + def test_simplification_numerically_function(original, simplified): + symb = original.free_symbols + n = len(symb) + valuelist = list(set(combinations(list(range(-(n - 1), n)) * n, n))) + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.subs(sublist) + simplifiedvalue = simplified.subs(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for {}" \ + "".format(original, simplified, sublist) + + w, x, y, z = symbols('w x y z', real=True) + d, e = symbols('d e', real=False) + + expressions = (And(Eq(x, y), x >= y, w < y, y >= z, z < y), + And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y), + Or(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y), + And(x >= y, Eq(y, x)), + Or(And(Eq(x, y), x >= y, w < y, Or(y >= z, z < y)), + And(Eq(x, y), x >= 1, 2 < y, y >= -1, z < y)), + (Eq(x, y) & Eq(d, e) & (x >= y) & (d >= e)), + ) + + for expression in expressions: + test_simplification_numerically_function(expression, + expression.simplify()) + + +def test_relational_simplification_patterns_numerically(): + from sympy.core import Wild + from sympy.logic.boolalg import _simplify_patterns_and, \ + _simplify_patterns_or, _simplify_patterns_xor + a = Wild('a') + b = Wild('b') + c = Wild('c') + symb = [a, b, c] + patternlists = [[And, _simplify_patterns_and()], + [Or, _simplify_patterns_or()], + [Xor, _simplify_patterns_xor()]] + valuelist = list(set(combinations(list(range(-2, 3)) * 3, 3))) + # Skip combinations of +/-2 and 0, except for all 0 + valuelist = [v for v in valuelist if any(w % 2 for w in v) or not any(v)] + for func, patternlist in patternlists: + for pattern in patternlist: + original = func(*pattern[0].args) + simplified = pattern[1] + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.xreplace(sublist) + simplifiedvalue = simplified.xreplace(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for" \ + "{}".format(pattern[0], simplified, sublist) + + +def test_issue_16803(): + n = symbols('n') + # No simplification done, but should not raise an exception + assert ((n > 3) | (n < 0) | ((n > 0) & (n < 3))).simplify() == \ + (n > 3) | (n < 0) | ((n > 0) & (n < 3)) + + +def test_issue_17530(): + r = {x: oo, y: oo} + assert Or(x + y > 0, x - y < 0).subs(r) + assert not And(x + y < 0, x - y < 0).subs(r) + raises(TypeError, lambda: Or(x + y < 0, x - y < 0).subs(r)) + raises(TypeError, lambda: And(x + y > 0, x - y < 0).subs(r)) + raises(TypeError, lambda: And(x + y > 0, x - y < 0).subs(r)) + + +def test_anf_coeffs(): + assert anf_coeffs([1, 0]) == [1, 1] + assert anf_coeffs([0, 0, 0, 1]) == [0, 0, 0, 1] + assert anf_coeffs([0, 1, 1, 1]) == [0, 1, 1, 1] + assert anf_coeffs([1, 1, 1, 0]) == [1, 0, 0, 1] + assert anf_coeffs([1, 0, 0, 0]) == [1, 1, 1, 1] + assert anf_coeffs([1, 0, 0, 1]) == [1, 1, 1, 0] + assert anf_coeffs([1, 1, 0, 1]) == [1, 0, 1, 1] + + +def test_ANFform(): + x, y = symbols('x,y') + assert ANFform([x], [1, 1]) == True + assert ANFform([x], [0, 0]) == False + assert ANFform([x], [1, 0]) == Xor(x, True, remove_true=False) + assert ANFform([x, y], [1, 1, 1, 0]) == \ + Xor(True, And(x, y), remove_true=False) + + +def test_bool_minterm(): + x, y = symbols('x,y') + assert bool_minterm(3, [x, y]) == And(x, y) + assert bool_minterm([1, 0], [x, y]) == And(Not(y), x) + + +def test_bool_maxterm(): + x, y = symbols('x,y') + assert bool_maxterm(2, [x, y]) == Or(Not(x), y) + assert bool_maxterm([0, 1], [x, y]) == Or(Not(y), x) + + +def test_bool_monomial(): + x, y = symbols('x,y') + assert bool_monomial(1, [x, y]) == y + assert bool_monomial([1, 1], [x, y]) == And(x, y) + + +def test_check_pair(): + assert _check_pair([0, 1, 0], [0, 1, 1]) == 2 + assert _check_pair([0, 1, 0], [1, 1, 1]) == -1 + + +def test_issue_19114(): + expr = (B & C) | (A & ~C) | (~A & ~B) + # Expression is minimal, but there are multiple minimal forms possible + res1 = (A & B) | (C & ~A) | (~B & ~C) + result = to_dnf(expr, simplify=True) + assert result in (expr, res1) + + +def test_issue_20870(): + result = SOPform([a, b, c, d], [1, 2, 3, 4, 5, 6, 8, 9, 11, 12, 14, 15]) + expected = ((d & ~b) | (a & b & c) | (a & ~c & ~d) | + (b & ~a & ~c) | (c & ~a & ~d)) + assert result == expected + + +def test_convert_to_varsSOP(): + assert _convert_to_varsSOP([0, 1, 0], [x, y, z]) == And(Not(x), y, Not(z)) + assert _convert_to_varsSOP([3, 1, 0], [x, y, z]) == And(y, Not(z)) + + +def test_convert_to_varsPOS(): + assert _convert_to_varsPOS([0, 1, 0], [x, y, z]) == Or(x, Not(y), z) + assert _convert_to_varsPOS([3, 1, 0], [x, y, z]) == Or(Not(y), z) + + +def test_gateinputcount(): + a, b, c, d, e = symbols('a:e') + assert gateinputcount(And(a, b)) == 2 + assert gateinputcount(a | b & c & d ^ (e | a)) == 9 + assert gateinputcount(And(a, True)) == 0 + raises(TypeError, lambda: gateinputcount(a * b)) + + +def test_refine(): + # relational + assert not refine(x < 0, ~(x < 0)) + assert refine(x < 0, (x < 0)) + assert refine(x < 0, (0 > x)) is S.true + assert refine(x < 0, (y < 0)) == (x < 0) + assert not refine(x <= 0, ~(x <= 0)) + assert refine(x <= 0, (x <= 0)) + assert refine(x <= 0, (0 >= x)) is S.true + assert refine(x <= 0, (y <= 0)) == (x <= 0) + assert not refine(x > 0, ~(x > 0)) + assert refine(x > 0, (x > 0)) + assert refine(x > 0, (0 < x)) is S.true + assert refine(x > 0, (y > 0)) == (x > 0) + assert not refine(x >= 0, ~(x >= 0)) + assert refine(x >= 0, (x >= 0)) + assert refine(x >= 0, (0 <= x)) is S.true + assert refine(x >= 0, (y >= 0)) == (x >= 0) + assert not refine(Eq(x, 0), ~(Eq(x, 0))) + assert refine(Eq(x, 0), (Eq(x, 0))) + assert refine(Eq(x, 0), (Eq(0, x))) is S.true + assert refine(Eq(x, 0), (Eq(y, 0))) == Eq(x, 0) + assert not refine(Ne(x, 0), ~(Ne(x, 0))) + assert refine(Ne(x, 0), (Ne(0, x))) is S.true + assert refine(Ne(x, 0), (Ne(x, 0))) + assert refine(Ne(x, 0), (Ne(y, 0))) == (Ne(x, 0)) + + # boolean functions + assert refine(And(x > 0, y > 0), (x > 0)) == (y > 0) + assert refine(And(x > 0, y > 0), (x > 0) & (y > 0)) is S.true + + # predicates + assert refine(Q.positive(x), Q.positive(x)) is S.true + assert refine(Q.positive(x), Q.negative(x)) is S.false + assert refine(Q.positive(x), Q.real(x)) == Q.positive(x) + + +def test_relational_threeterm_simplification_patterns_numerically(): + from sympy.core import Wild + from sympy.logic.boolalg import _simplify_patterns_and3 + a = Wild('a') + b = Wild('b') + c = Wild('c') + symb = [a, b, c] + patternlists = [[And, _simplify_patterns_and3()]] + valuelist = list(set(combinations(list(range(-2, 3)) * 3, 3))) + # Skip combinations of +/-2 and 0, except for all 0 + valuelist = [v for v in valuelist if any(w % 2 for w in v) or not any(v)] + for func, patternlist in patternlists: + for pattern in patternlist: + original = func(*pattern[0].args) + simplified = pattern[1] + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.xreplace(sublist) + simplifiedvalue = simplified.xreplace(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for" \ + "{}".format(pattern[0], simplified, sublist) + + +def test_issue_25451(): + x = Or(And(a, c), Eq(a, b)) + assert isinstance(x, Or) + assert set(x.args) == {And(a, c), Eq(a, b)} + + +def test_issue_26985(): + a, b, c, d = symbols('a b c d') + + # Expression before applying to_anf + x = Xor(c, And(a, b), And(a, c)) + y = Xor(a, b, And(a, c)) + + # Applying to_anf + result = Xor(Xor(d, And(x, y)), And(x, y)) + result_anf = to_anf(Xor(to_anf(Xor(d, And(x, y))), And(x, y))) + + assert result_anf == d + assert result == d diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_dimacs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_dimacs.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9a51a39d33fb807688614cb5809b621ce21a2c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_dimacs.py @@ -0,0 +1,234 @@ +"""Various tests on satisfiability using dimacs cnf file syntax +You can find lots of cnf files in +ftp://dimacs.rutgers.edu/pub/challenge/satisfiability/benchmarks/cnf/ +""" + +from sympy.logic.utilities.dimacs import load +from sympy.logic.algorithms.dpll import dpll_satisfiable + + +def test_f1(): + assert bool(dpll_satisfiable(load(f1))) + + +def test_f2(): + assert bool(dpll_satisfiable(load(f2))) + + +def test_f3(): + assert bool(dpll_satisfiable(load(f3))) + + +def test_f4(): + assert not bool(dpll_satisfiable(load(f4))) + + +def test_f5(): + assert bool(dpll_satisfiable(load(f5))) + +f1 = """c simple example +c Resolution: SATISFIABLE +c +p cnf 3 2 +1 -3 0 +2 3 -1 0 +""" + + +f2 = """c an example from Quinn's text, 16 variables and 18 clauses. +c Resolution: SATISFIABLE +c +p cnf 16 18 + 1 2 0 + -2 -4 0 + 3 4 0 + -4 -5 0 + 5 -6 0 + 6 -7 0 + 6 7 0 + 7 -16 0 + 8 -9 0 + -8 -14 0 + 9 10 0 + 9 -10 0 +-10 -11 0 + 10 12 0 + 11 12 0 + 13 14 0 + 14 -15 0 + 15 16 0 +""" + +f3 = """c +p cnf 6 9 +-1 0 +-3 0 +2 -1 0 +2 -4 0 +5 -4 0 +-1 -3 0 +-4 -6 0 +1 3 -2 0 +4 6 -2 -5 0 +""" + +f4 = """c +c file: hole6.cnf [http://people.sc.fsu.edu/~jburkardt/data/cnf/hole6.cnf] +c +c SOURCE: John Hooker (jh38+@andrew.cmu.edu) +c +c DESCRIPTION: Pigeon hole problem of placing n (for file 'holen.cnf') pigeons +c in n+1 holes without placing 2 pigeons in the same hole +c +c NOTE: Part of the collection at the Forschungsinstitut fuer +c anwendungsorientierte Wissensverarbeitung in Ulm Germany. +c +c NOTE: Not satisfiable +c +p cnf 42 133 +-1 -7 0 +-1 -13 0 +-1 -19 0 +-1 -25 0 +-1 -31 0 +-1 -37 0 +-7 -13 0 +-7 -19 0 +-7 -25 0 +-7 -31 0 +-7 -37 0 +-13 -19 0 +-13 -25 0 +-13 -31 0 +-13 -37 0 +-19 -25 0 +-19 -31 0 +-19 -37 0 +-25 -31 0 +-25 -37 0 +-31 -37 0 +-2 -8 0 +-2 -14 0 +-2 -20 0 +-2 -26 0 +-2 -32 0 +-2 -38 0 +-8 -14 0 +-8 -20 0 +-8 -26 0 +-8 -32 0 +-8 -38 0 +-14 -20 0 +-14 -26 0 +-14 -32 0 +-14 -38 0 +-20 -26 0 +-20 -32 0 +-20 -38 0 +-26 -32 0 +-26 -38 0 +-32 -38 0 +-3 -9 0 +-3 -15 0 +-3 -21 0 +-3 -27 0 +-3 -33 0 +-3 -39 0 +-9 -15 0 +-9 -21 0 +-9 -27 0 +-9 -33 0 +-9 -39 0 +-15 -21 0 +-15 -27 0 +-15 -33 0 +-15 -39 0 +-21 -27 0 +-21 -33 0 +-21 -39 0 +-27 -33 0 +-27 -39 0 +-33 -39 0 +-4 -10 0 +-4 -16 0 +-4 -22 0 +-4 -28 0 +-4 -34 0 +-4 -40 0 +-10 -16 0 +-10 -22 0 +-10 -28 0 +-10 -34 0 +-10 -40 0 +-16 -22 0 +-16 -28 0 +-16 -34 0 +-16 -40 0 +-22 -28 0 +-22 -34 0 +-22 -40 0 +-28 -34 0 +-28 -40 0 +-34 -40 0 +-5 -11 0 +-5 -17 0 +-5 -23 0 +-5 -29 0 +-5 -35 0 +-5 -41 0 +-11 -17 0 +-11 -23 0 +-11 -29 0 +-11 -35 0 +-11 -41 0 +-17 -23 0 +-17 -29 0 +-17 -35 0 +-17 -41 0 +-23 -29 0 +-23 -35 0 +-23 -41 0 +-29 -35 0 +-29 -41 0 +-35 -41 0 +-6 -12 0 +-6 -18 0 +-6 -24 0 +-6 -30 0 +-6 -36 0 +-6 -42 0 +-12 -18 0 +-12 -24 0 +-12 -30 0 +-12 -36 0 +-12 -42 0 +-18 -24 0 +-18 -30 0 +-18 -36 0 +-18 -42 0 +-24 -30 0 +-24 -36 0 +-24 -42 0 +-30 -36 0 +-30 -42 0 +-36 -42 0 + 6 5 4 3 2 1 0 + 12 11 10 9 8 7 0 + 18 17 16 15 14 13 0 + 24 23 22 21 20 19 0 + 30 29 28 27 26 25 0 + 36 35 34 33 32 31 0 + 42 41 40 39 38 37 0 +""" + +f5 = """c simple example requiring variable selection +c +c NOTE: Satisfiable +c +p cnf 5 5 +1 2 3 0 +1 -2 3 0 +4 5 -3 0 +1 -4 -3 0 +-1 -5 0 +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_inference.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ff37b1b104f6f106ec5df7809fd34959bce35917 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_inference.py @@ -0,0 +1,396 @@ +"""For more tests on satisfiability, see test_dimacs""" + +from sympy.assumptions.ask import Q +from sympy.core.symbol import symbols +from sympy.core.relational import Unequality +from sympy.logic.boolalg import And, Or, Implies, Equivalent, true, false +from sympy.logic.inference import literal_symbol, \ + pl_true, satisfiable, valid, entails, PropKB +from sympy.logic.algorithms.dpll import dpll, dpll_satisfiable, \ + find_pure_symbol, find_unit_clause, unit_propagate, \ + find_pure_symbol_int_repr, find_unit_clause_int_repr, \ + unit_propagate_int_repr +from sympy.logic.algorithms.dpll2 import dpll_satisfiable as dpll2_satisfiable + +from sympy.logic.algorithms.z3_wrapper import z3_satisfiable +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.logic.tests.test_lra_theory import make_random_problem +from sympy.core.random import randint + +from sympy.testing.pytest import raises, skip +from sympy.external import import_module + + +def test_literal(): + A, B = symbols('A,B') + assert literal_symbol(True) is True + assert literal_symbol(False) is False + assert literal_symbol(A) is A + assert literal_symbol(~A) is A + + +def test_find_pure_symbol(): + A, B, C = symbols('A,B,C') + assert find_pure_symbol([A], [A]) == (A, True) + assert find_pure_symbol([A, B], [~A | B, ~B | A]) == (None, None) + assert find_pure_symbol([A, B, C], [ A | ~B, ~B | ~C, C | A]) == (A, True) + assert find_pure_symbol([A, B, C], [~A | B, B | ~C, C | A]) == (B, True) + assert find_pure_symbol([A, B, C], [~A | ~B, ~B | ~C, C | A]) == (B, False) + assert find_pure_symbol( + [A, B, C], [~A | B, ~B | ~C, C | A]) == (None, None) + + +def test_find_pure_symbol_int_repr(): + assert find_pure_symbol_int_repr([1], [{1}]) == (1, True) + assert find_pure_symbol_int_repr([1, 2], + [{-1, 2}, {-2, 1}]) == (None, None) + assert find_pure_symbol_int_repr([1, 2, 3], + [{1, -2}, {-2, -3}, {3, 1}]) == (1, True) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, 2}, {2, -3}, {3, 1}]) == (2, True) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, -2}, {-2, -3}, {3, 1}]) == (2, False) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, 2}, {-2, -3}, {3, 1}]) == (None, None) + + +def test_unit_clause(): + A, B, C = symbols('A,B,C') + assert find_unit_clause([A], {}) == (A, True) + assert find_unit_clause([A, ~A], {}) == (A, True) # Wrong ?? + assert find_unit_clause([A | B], {A: True}) == (B, True) + assert find_unit_clause([A | B], {B: True}) == (A, True) + assert find_unit_clause( + [A | B | C, B | ~C, A | ~B], {A: True}) == (B, False) + assert find_unit_clause([A | B | C, B | ~C, A | B], {A: True}) == (B, True) + assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True) + + +def test_unit_clause_int_repr(): + assert find_unit_clause_int_repr(map(set, [[1]]), {}) == (1, True) + assert find_unit_clause_int_repr(map(set, [[1], [-1]]), {}) == (1, True) + assert find_unit_clause_int_repr([{1, 2}], {1: True}) == (2, True) + assert find_unit_clause_int_repr([{1, 2}], {2: True}) == (1, True) + assert find_unit_clause_int_repr(map(set, + [[1, 2, 3], [2, -3], [1, -2]]), {1: True}) == (2, False) + assert find_unit_clause_int_repr(map(set, + [[1, 2, 3], [3, -3], [1, 2]]), {1: True}) == (2, True) + + A, B, C = symbols('A,B,C') + assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True) + + +def test_unit_propagate(): + A, B, C = symbols('A,B,C') + assert unit_propagate([A | B], A) == [] + assert unit_propagate([A | B, ~A | C, ~C | B, A], A) == [C, ~C | B, A] + + +def test_unit_propagate_int_repr(): + assert unit_propagate_int_repr([{1, 2}], 1) == [] + assert unit_propagate_int_repr(map(set, + [[1, 2], [-1, 3], [-3, 2], [1]]), 1) == [{3}, {-3, 2}] + + +def test_dpll(): + """This is also tested in test_dimacs""" + A, B, C = symbols('A,B,C') + assert dpll([A | B], [A, B], {A: True, B: True}) == {A: True, B: True} + + +def test_dpll_satisfiable(): + A, B, C = symbols('A,B,C') + assert dpll_satisfiable( A & ~A ) is False + assert dpll_satisfiable( A & ~B ) == {A: True, B: False} + assert dpll_satisfiable( + A | B ) in ({A: True}, {B: True}, {A: True, B: True}) + assert dpll_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert dpll_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False}, + {A: True, C: True}, {B: True, C: True}) + assert dpll_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert dpll_satisfiable( (A | B) & (A >> B) ) == {B: True} + assert dpll_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert dpll_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + + +def test_dpll2_satisfiable(): + A, B, C = symbols('A,B,C') + assert dpll2_satisfiable( A & ~A ) is False + assert dpll2_satisfiable( A & ~B ) == {A: True, B: False} + assert dpll2_satisfiable( + A | B ) in ({A: True}, {B: True}, {A: True, B: True}) + assert dpll2_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert dpll2_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}) + assert dpll2_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert dpll2_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert dpll2_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert dpll2_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + + +def test_minisat22_satisfiable(): + A, B, C = symbols('A,B,C') + minisat22_satisfiable = lambda expr: satisfiable(expr, algorithm="minisat22") + assert minisat22_satisfiable( A & ~A ) is False + assert minisat22_satisfiable( A & ~B ) == {A: True, B: False} + assert minisat22_satisfiable( + A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False}) + assert minisat22_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False}) + assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + +def test_minisat22_minimal_satisfiable(): + A, B, C = symbols('A,B,C') + minisat22_satisfiable = lambda expr, minimal=True: satisfiable(expr, algorithm="minisat22", minimal=True) + assert minisat22_satisfiable( A & ~A ) is False + assert minisat22_satisfiable( A & ~B ) == {A: True, B: False} + assert minisat22_satisfiable( + A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False}) + assert minisat22_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False}) + assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + g = satisfiable((A | B | C),algorithm="minisat22",minimal=True,all_models=True) + sol = next(g) + first_solution = {key for key, value in sol.items() if value} + sol=next(g) + second_solution = {key for key, value in sol.items() if value} + sol=next(g) + third_solution = {key for key, value in sol.items() if value} + assert not first_solution <= second_solution + assert not second_solution <= third_solution + assert not first_solution <= third_solution + +def test_satisfiable(): + A, B, C = symbols('A,B,C') + assert satisfiable(A & (A >> B) & ~B) is False + + +def test_valid(): + A, B, C = symbols('A,B,C') + assert valid(A >> (B >> A)) is True + assert valid((A >> (B >> C)) >> ((A >> B) >> (A >> C))) is True + assert valid((~B >> ~A) >> (A >> B)) is True + assert valid(A | B | C) is False + assert valid(A >> B) is False + + +def test_pl_true(): + A, B, C = symbols('A,B,C') + assert pl_true(True) is True + assert pl_true( A & B, {A: True, B: True}) is True + assert pl_true( A | B, {A: True}) is True + assert pl_true( A | B, {B: True}) is True + assert pl_true( A | B, {A: None, B: True}) is True + assert pl_true( A >> B, {A: False}) is True + assert pl_true( A | B | ~C, {A: False, B: True, C: True}) is True + assert pl_true(Equivalent(A, B), {A: False, B: False}) is True + + # test for false + assert pl_true(False) is False + assert pl_true( A & B, {A: False, B: False}) is False + assert pl_true( A & B, {A: False}) is False + assert pl_true( A & B, {B: False}) is False + assert pl_true( A | B, {A: False, B: False}) is False + + #test for None + assert pl_true(B, {B: None}) is None + assert pl_true( A & B, {A: True, B: None}) is None + assert pl_true( A >> B, {A: True, B: None}) is None + assert pl_true(Equivalent(A, B), {A: None}) is None + assert pl_true(Equivalent(A, B), {A: True, B: None}) is None + + # Test for deep + assert pl_true(A | B, {A: False}, deep=True) is None + assert pl_true(~A & ~B, {A: False}, deep=True) is None + assert pl_true(A | B, {A: False, B: False}, deep=True) is False + assert pl_true(A & B & (~A | ~B), {A: True}, deep=True) is False + assert pl_true((C >> A) >> (B >> A), {C: True}, deep=True) is True + + +def test_pl_true_wrong_input(): + from sympy.core.numbers import pi + raises(ValueError, lambda: pl_true('John Cleese')) + raises(ValueError, lambda: pl_true(42 + pi + pi ** 2)) + raises(ValueError, lambda: pl_true(42)) + + +def test_entails(): + A, B, C = symbols('A, B, C') + assert entails(A, [A >> B, ~B]) is False + assert entails(B, [Equivalent(A, B), A]) is True + assert entails((A >> B) >> (~A >> ~B)) is False + assert entails((A >> B) >> (~B >> ~A)) is True + + +def test_PropKB(): + A, B, C = symbols('A,B,C') + kb = PropKB() + assert kb.ask(A >> B) is False + assert kb.ask(A >> (B >> A)) is True + kb.tell(A >> B) + kb.tell(B >> C) + assert kb.ask(A) is False + assert kb.ask(B) is False + assert kb.ask(C) is False + assert kb.ask(~A) is False + assert kb.ask(~B) is False + assert kb.ask(~C) is False + assert kb.ask(A >> C) is True + kb.tell(A) + assert kb.ask(A) is True + assert kb.ask(B) is True + assert kb.ask(C) is True + assert kb.ask(~C) is False + kb.retract(A) + assert kb.ask(C) is False + + +def test_propKB_tolerant(): + """"tolerant to bad input""" + kb = PropKB() + A, B, C = symbols('A,B,C') + assert kb.ask(B) is False + +def test_satisfiable_non_symbols(): + x, y = symbols('x y') + assumptions = Q.zero(x*y) + facts = Implies(Q.zero(x*y), Q.zero(x) | Q.zero(y)) + query = ~Q.zero(x) & ~Q.zero(y) + refutations = [ + {Q.zero(x): True, Q.zero(x*y): True}, + {Q.zero(y): True, Q.zero(x*y): True}, + {Q.zero(x): True, Q.zero(y): True, Q.zero(x*y): True}, + {Q.zero(x): True, Q.zero(y): False, Q.zero(x*y): True}, + {Q.zero(x): False, Q.zero(y): True, Q.zero(x*y): True}] + assert not satisfiable(And(assumptions, facts, query), algorithm='dpll') + assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll') in refutations + assert not satisfiable(And(assumptions, facts, query), algorithm='dpll2') + assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll2') in refutations + +def test_satisfiable_bool(): + from sympy.core.singleton import S + assert satisfiable(true) == {true: true} + assert satisfiable(S.true) == {true: true} + assert satisfiable(false) is False + assert satisfiable(S.false) is False + + +def test_satisfiable_all_models(): + from sympy.abc import A, B + assert next(satisfiable(False, all_models=True)) is False + assert list(satisfiable((A >> ~A) & A, all_models=True)) == [False] + assert list(satisfiable(True, all_models=True)) == [{true: true}] + + models = [{A: True, B: False}, {A: False, B: True}] + result = satisfiable(A ^ B, all_models=True) + models.remove(next(result)) + models.remove(next(result)) + raises(StopIteration, lambda: next(result)) + assert not models + + assert list(satisfiable(Equivalent(A, B), all_models=True)) == \ + [{A: False, B: False}, {A: True, B: True}] + + models = [{A: False, B: False}, {A: False, B: True}, {A: True, B: True}] + for model in satisfiable(A >> B, all_models=True): + models.remove(model) + assert not models + + # This is a santiy test to check that only the required number + # of solutions are generated. The expr below has 2**100 - 1 models + # which would time out the test if all are generated at once. + from sympy.utilities.iterables import numbered_symbols + from sympy.logic.boolalg import Or + sym = numbered_symbols() + X = [next(sym) for i in range(100)] + result = satisfiable(Or(*X), all_models=True) + for i in range(10): + assert next(result) + + +def test_z3(): + z3 = import_module("z3") + + if not z3: + skip("z3 not installed.") + A, B, C = symbols('A,B,C') + x, y, z = symbols('x,y,z') + assert z3_satisfiable((x >= 2) & (x < 1)) is False + assert z3_satisfiable( A & ~A ) is False + + model = z3_satisfiable(A & (~A | B | C)) + assert bool(model) is True + assert model[A] is True + + # test nonlinear function + assert z3_satisfiable((x ** 2 >= 2) & (x < 1) & (x > -1)) is False + + +def test_z3_vs_lra_dpll2(): + z3 = import_module("z3") + if z3 is None: + skip("z3 not installed.") + + def boolean_formula_to_encoded_cnf(bf): + cnf = CNF.from_prop(bf) + enc = EncodedCNF() + enc.from_cnf(cnf) + return enc + + def make_random_cnf(num_clauses=5, num_constraints=10, num_var=2): + assert num_clauses <= num_constraints + constraints = make_random_problem(num_variables=num_var, num_constraints=num_constraints, rational=False) + clauses = [[cons] for cons in constraints[:num_clauses]] + for cons in constraints[num_clauses:]: + if isinstance(cons, Unequality): + cons = ~cons + i = randint(0, num_clauses-1) + clauses[i].append(cons) + + clauses = [Or(*clause) for clause in clauses] + cnf = And(*clauses) + return boolean_formula_to_encoded_cnf(cnf) + + lra_dpll2_satisfiable = lambda x: dpll2_satisfiable(x, use_lra_theory=True) + + for _ in range(50): + cnf = make_random_cnf(num_clauses=10, num_constraints=15, num_var=2) + + try: + z3_sat = z3_satisfiable(cnf) + except z3.z3types.Z3Exception: + continue + + lra_dpll2_sat = lra_dpll2_satisfiable(cnf) is not False + + assert z3_sat == lra_dpll2_sat + +def test_issue_27733(): + x, y = symbols('x,y') + clauses = [[1, -3, -2], [5, 7, -8, -6, -4], [-10, -9, 10, 11, -4], [-12, 13, 14], [-10, 9, -6, 11, -4], + [16, -15, 18, -19, -17], [11, -6, 10, -9], [9, 11, -10, -9], [2, -3, -1], [-13, 12], [-15, 3, -17], + [-16, -15, 19, -17], [-6, -9, 10, 11, -4], [20, -1, -2], [-23, -22, -21], [10, 11, -10, -9], + [9, 11, -4, -10], [24, -6, -4], [-14, 12], [-10, -9, 9, -6, 11], [25, -27, -26], [-15, 19, -18, -17], + [5, 8, -7, -6, -4], [-30, -29, 28], [12], [14]] + + encoding = {Q.gt(y, i): i for i in range(1, 31) if i != 11 and i != 12} + encoding[Q.gt(x, 0)] = 11 + encoding[Q.lt(x, 0)] = 12 + + cnf = EncodedCNF(clauses, encoding) + assert satisfiable(cnf, use_lra_theory=True) is False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_lra_theory.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_lra_theory.py new file mode 100644 index 0000000000000000000000000000000000000000..207a3c5ba2c1b16ee5323382deee0863a5dfb595 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/tests/test_lra_theory.py @@ -0,0 +1,440 @@ +from sympy.core.numbers import Rational, I, oo +from sympy.core.relational import Eq +from sympy.core.symbol import symbols +from sympy.core.singleton import S +from sympy.matrices.dense import Matrix +from sympy.matrices.dense import randMatrix +from sympy.assumptions.ask import Q +from sympy.logic.boolalg import And +from sympy.abc import x, y, z +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.functions.elementary.trigonometric import cos +from sympy.external import import_module + +from sympy.logic.algorithms.lra_theory import LRASolver, UnhandledInput, LRARational, HANDLE_NEGATION +from sympy.core.random import random, choice, randint +from sympy.core.sympify import sympify +from sympy.ntheory.generate import randprime +from sympy.core.relational import StrictLessThan, StrictGreaterThan +import itertools + +from sympy.testing.pytest import raises, XFAIL, skip + +def make_random_problem(num_variables=2, num_constraints=2, sparsity=.1, rational=True, + disable_strict = False, disable_nonstrict=False, disable_equality=False): + def rand(sparsity=sparsity): + if random() < sparsity: + return sympify(0) + if rational: + int1, int2 = [randprime(0, 50) for _ in range(2)] + return Rational(int1, int2) * choice([-1, 1]) + else: + return randint(1, 10) * choice([-1, 1]) + + variables = symbols('x1:%s' % (num_variables + 1)) + constraints = [] + for _ in range(num_constraints): + lhs, rhs = sum(rand() * x for x in variables), rand(sparsity=0) # sparsity=0 bc of bug with smtlib_code + options = [] + if not disable_equality: + options += [Eq(lhs, rhs)] + if not disable_nonstrict: + options += [lhs <= rhs, lhs >= rhs] + if not disable_strict: + options += [lhs < rhs, lhs > rhs] + + constraints.append(choice(options)) + + return constraints + +def check_if_satisfiable_with_z3(constraints): + from sympy.external.importtools import import_module + from sympy.printing.smtlib import smtlib_code + from sympy.logic.boolalg import And + boolean_formula = And(*constraints) + z3 = import_module("z3") + if z3: + smtlib_string = smtlib_code(boolean_formula) + s = z3.Solver() + s.from_string(smtlib_string) + res = str(s.check()) + if res == 'sat': + return True + elif res == 'unsat': + return False + else: + raise ValueError(f"z3 was not able to check the satisfiability of {boolean_formula}") + +def find_rational_assignment(constr, assignment, iter=20): + eps = sympify(1) + + for _ in range(iter): + assign = {key: val[0] + val[1]*eps for key, val in assignment.items()} + try: + for cons in constr: + assert cons.subs(assign) == True + return assign + except AssertionError: + eps = eps/2 + + return None + +def boolean_formula_to_encoded_cnf(bf): + cnf = CNF.from_prop(bf) + enc = EncodedCNF() + enc.from_cnf(cnf) + return enc + + +def test_from_encoded_cnf(): + s1, s2 = symbols("s1 s2") + + # Test preprocessing + # Example is from section 3 of paper. + phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6)) & (Eq(x + y, 2) | (x + 2 * y - z > 4)) + enc = boolean_formula_to_encoded_cnf(phi) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert lra.A.shape == (2, 5) + assert str(lra.slack) == '[_s1, _s2]' + assert str(lra.nonslack) == '[x, y, z]' + assert lra.A == Matrix([[ 1, 1, 0, -1, 0], + [-1, -2, 1, 0, -1]]) + assert {(str(b.var), b.bound, b.upper, b.equality, b.strict) for b in lra.enc_to_boundary.values()} == {('_s1', 2, None, True, False), + ('_s1', 2, True, False, False), + ('_s2', -4, True, False, True), + ('_s2', -6, True, False, False), + ('x', 0, False, False, False)} + + +def test_problem(): + from sympy.logic.algorithms.lra_theory import LRASolver + from sympy.assumptions.cnf import CNF, EncodedCNF + cons = [-2 * x - 2 * y >= 7, -9 * y >= 7, -6 * y >= 5] + cnf = CNF().from_prop(And(*cons)) + enc = EncodedCNF() + enc.from_cnf(cnf) + lra, _ = LRASolver.from_encoded_cnf(enc) + lra.assert_lit(1) + lra.assert_lit(2) + lra.assert_lit(3) + is_sat, assignment = lra.check() + assert is_sat is True + + +def test_random_problems(): + z3 = import_module("z3") + if z3 is None: + skip("z3 is not installed") + + special_cases = []; x1, x2, x3 = symbols("x1 x2 x3") + special_cases.append([x1 - 3 * x2 <= -5, 6 * x1 + 4 * x2 <= 0, -7 * x1 + 3 * x2 <= 3]) + special_cases.append([-3 * x1 >= 3, Eq(4 * x1, -1)]) + special_cases.append([-4 * x1 < 4, 6 * x1 <= -6]) + special_cases.append([-3 * x2 >= 7, 6 * x1 <= -5, -3 * x2 <= -4]) + special_cases.append([x + y >= 2, x + y <= 1]) + special_cases.append([x >= 0, x + y <= 2, x + 2 * y - z >= 6]) # from paper example + special_cases.append([-2 * x1 - 2 * x2 >= 7, -9 * x1 >= 7, -6 * x1 >= 5]) + special_cases.append([2 * x1 > -3, -9 * x1 < -6, 9 * x1 <= 6]) + special_cases.append([-2*x1 < -4, 9*x1 > -9]) + special_cases.append([-6*x1 >= -1, -8*x1 + x2 >= 5, -8*x1 + 7*x2 < 4, x1 > 7]) + special_cases.append([Eq(x1, 2), Eq(5*x1, -2), Eq(-7*x2, -6), Eq(9*x1 + 10*x2, 9)]) + special_cases.append([Eq(3*x1, 6), Eq(x1 - 8*x2, -9), Eq(-7*x1 + 5*x2, 3), Eq(3*x2, 7)]) + special_cases.append([-4*x1 < 4, 6*x1 <= -6]) + special_cases.append([-3*x1 + 8*x2 >= -8, -10*x2 > 9, 8*x1 - 4*x2 < 8, 10*x1 - 9*x2 >= -9]) + special_cases.append([x1 + 5*x2 >= -6, 9*x1 - 3*x2 >= -9, 6*x1 + 6*x2 < -10, -3*x1 + 3*x2 < -7]) + special_cases.append([-9*x1 < 7, -5*x1 - 7*x2 < -1, 3*x1 + 7*x2 > 1, -6*x1 - 6*x2 > 9]) + special_cases.append([9*x1 - 6*x2 >= -7, 9*x1 + 4*x2 < -8, -7*x2 <= 1, 10*x2 <= -7]) + + feasible_count = 0 + for i in range(50): + if i % 8 == 0: + constraints = make_random_problem(num_variables=1, num_constraints=2, rational=False) + elif i % 8 == 1: + constraints = make_random_problem(num_variables=2, num_constraints=4, rational=False, disable_equality=True, + disable_nonstrict=True) + elif i % 8 == 2: + constraints = make_random_problem(num_variables=2, num_constraints=4, rational=False, disable_strict=True) + elif i % 8 == 3: + constraints = make_random_problem(num_variables=3, num_constraints=12, rational=False) + else: + constraints = make_random_problem(num_variables=3, num_constraints=6, rational=False) + + if i < len(special_cases): + constraints = special_cases[i] + + if False in constraints or True in constraints: + continue + + phi = And(*constraints) + if phi == False: + continue + cnf = CNF.from_prop(phi); enc = EncodedCNF() + enc.from_cnf(cnf) + assert all(0 not in clause for clause in enc.data) + + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + s_subs = lra.s_subs + + lra.run_checks = True + s_subs_rev = {value: key for key, value in s_subs.items()} + lits = {lit for clause in enc.data for lit in clause} + + bounds = [(lra.enc_to_boundary[l], l) for l in lits if l in lra.enc_to_boundary] + bounds = sorted(bounds, key=lambda x: (str(x[0].var), x[0].bound, str(x[0].upper))) # to remove nondeterminism + + for b, l in bounds: + if lra.result and lra.result[0] == False: + break + lra.assert_lit(l) + + feasible = lra.check() + + if feasible[0] == True: + feasible_count += 1 + assert check_if_satisfiable_with_z3(constraints) is True + cons_funcs = [cons.func for cons in constraints] + assignment = feasible[1] + assignment = {key.var : value for key, value in assignment.items()} + if not (StrictLessThan in cons_funcs or StrictGreaterThan in cons_funcs): + assignment = {key: value[0] for key, value in assignment.items()} + for cons in constraints: + assert cons.subs(assignment) == True + + else: + rat_assignment = find_rational_assignment(constraints, assignment) + assert rat_assignment is not None + else: + assert check_if_satisfiable_with_z3(constraints) is False + + conflict = feasible[1] + assert len(conflict) >= 2 + conflict = {lra.enc_to_boundary[-l].get_inequality() for l in conflict} + conflict = {clause.subs(s_subs_rev) for clause in conflict} + assert check_if_satisfiable_with_z3(conflict) is False + + # check that conflict clause is probably minimal + for subset in itertools.combinations(conflict, len(conflict)-1): + assert check_if_satisfiable_with_z3(subset) is True + + +@XFAIL +def test_pos_neg_zero(): + bf = Q.positive(x) & Q.negative(x) & Q.zero(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.lt(x, -1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.zero(x) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.zero(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == True + + +@XFAIL +def test_pos_neg_infinite(): + bf = Q.positive_infinite(x) & Q.lt(x, 10000000) & Q.positive_infinite(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + + bf = Q.positive_infinite(x) & Q.gt(x, 10000000) & Q.positive_infinite(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == True + + bf = Q.positive_infinite(x) & Q.negative_infinite(x) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + +def test_binrel_evaluation(): + bf = Q.gt(3, 2) + enc = boolean_formula_to_encoded_cnf(bf) + lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert len(lra.enc_to_boundary) == 0 + assert conflicts == [[1]] + + bf = Q.lt(3, 2) + enc = boolean_formula_to_encoded_cnf(bf) + lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert len(lra.enc_to_boundary) == 0 + assert conflicts == [[-1]] + + +def test_negation(): + assert HANDLE_NEGATION is True + bf = Q.gt(x, 1) & ~Q.gt(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + assert sorted(lra.check()[1]) in [[-1, 2], [-2, 1]] + + bf = ~Q.gt(x, 1) & ~Q.lt(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == True + + bf = ~Q.gt(x, 0) & ~Q.lt(x, 1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = ~Q.gt(x, 0) & ~Q.le(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = ~Q.le(x+y, 2) & ~Q.ge(x-y, 2) & ~Q.ge(y, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + assert len(lra.check()[1]) == 3 + assert all(i > 0 for i in lra.check()[1]) + + +def test_unhandled_input(): + nan = S.NaN + bf = Q.gt(3, nan) & Q.gt(x, nan) + enc = boolean_formula_to_encoded_cnf(bf) + raises(ValueError, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, I) & Q.gt(x, I) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, float("inf")) & Q.gt(x, float("inf")) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, oo) & Q.gt(x, oo) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + # test non-linearity + bf = Q.gt(x**2 + x, 2) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(cos(x) + x, 2) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + +@XFAIL +def test_infinite_strict_inequalities(): + # Extensive testing of the interaction between strict inequalities + # and constraints containing infinity is needed because + # the paper's rule for strict inequalities don't work when + # infinite numbers are allowed. Using the paper's rules you + # can end up with situations where oo + delta > oo is considered + # True when oo + delta should be equal to oo. + # See https://math.stackexchange.com/questions/4757069/can-this-method-of-converting-strict-inequalities-to-equisatisfiable-nonstrict-i + bf = (-x - y >= -float("inf")) & (x > 0) & (y >= float("inf")) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in sorted(enc.encoding.values()): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == True + + +def test_pivot(): + for _ in range(10): + m = randMatrix(5) + rref = m.rref() + for _ in range(5): + i, j = randint(0, 4), randint(0, 4) + if m[i, j] != 0: + assert LRASolver._pivot(m, i, j).rref() == rref + + +def test_reset_bounds(): + bf = Q.ge(x, 1) & Q.lt(x, 1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + lra.reset_bounds() + assert lra.check()[0] == True + for var in lra.all_var: + assert var.upper == LRARational(float("inf"), 0) + assert var.upper_from_eq == False + assert var.upper_from_neg == False + assert var.lower == LRARational(-float("inf"), 0) + assert var.lower_from_eq == False + assert var.lower_from_neg == False + assert var.assign == LRARational(0, 0) + assert var.var is not None + assert var.col_idx is not None + + +def test_empty_cnf(): + cnf = CNF() + enc = EncodedCNF() + enc.from_cnf(cnf) + lra, conflict = LRASolver.from_encoded_cnf(enc) + assert len(conflict) == 0 + assert lra.check() == (True, {}) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3526c3e53d624bc70afe2df05f123c835781364c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__init__.py @@ -0,0 +1,3 @@ +from .dimacs import load_file + +__all__ = ['load_file'] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a387ae14bf1e6e935a6beddcf717a7fc77b9f5ef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__pycache__/dimacs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__pycache__/dimacs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81ec21e6641fada3a71acf60c0e67c32fab0616b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/__pycache__/dimacs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/dimacs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/dimacs.py new file mode 100644 index 0000000000000000000000000000000000000000..51302d8052c8ed8443239c1e21a2f063cf34e4ab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/logic/utilities/dimacs.py @@ -0,0 +1,69 @@ +"""For reading in DIMACS file format + +www.cs.ubc.ca/~hoos/SATLIB/Benchmarks/SAT/satformat.ps + +""" + +from sympy.core import Symbol +from sympy.logic.boolalg import And, Or +import re +from pathlib import Path + + +def load(s): + """Loads a boolean expression from a string. + + Examples + ======== + + >>> from sympy.logic.utilities.dimacs import load + >>> load('1') + cnf_1 + >>> load('1 2') + cnf_1 | cnf_2 + >>> load('1 \\n 2') + cnf_1 & cnf_2 + >>> load('1 2 \\n 3') + cnf_3 & (cnf_1 | cnf_2) + """ + clauses = [] + + lines = s.split('\n') + + pComment = re.compile(r'c.*') + pStats = re.compile(r'p\s*cnf\s*(\d*)\s*(\d*)') + + while len(lines) > 0: + line = lines.pop(0) + + # Only deal with lines that aren't comments + if not pComment.match(line): + m = pStats.match(line) + + if not m: + nums = line.rstrip('\n').split(' ') + list = [] + for lit in nums: + if lit != '': + if int(lit) == 0: + continue + num = abs(int(lit)) + sign = True + if int(lit) < 0: + sign = False + + if sign: + list.append(Symbol("cnf_%s" % num)) + else: + list.append(~Symbol("cnf_%s" % num)) + + if len(list) > 0: + clauses.append(Or(*list)) + + return And(*clauses) + + +def load_file(location): + """Loads a boolean expression from a file.""" + s = Path(location).read_text() + return load(s) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/benchmarks/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/benchmarks/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c40052e900b900819c3ea75537324003c679945 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/benchmarks/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/benchmarks/__pycache__/bench_matrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/benchmarks/__pycache__/bench_matrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9b29023bc2ea9594c43893f2f21b4c51a54239a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/benchmarks/__pycache__/bench_matrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94cb2a5c6a34f1550e24e005a066b2203152f1b7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/_shape.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/_shape.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d52d69c67faeeb6ede777476cf69084851ca6bf9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/_shape.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/adjoint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/adjoint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d472405803130c254b2b9feb4f4fca0e900413c6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/adjoint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/applyfunc.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/applyfunc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f8f5a9dc666e46f37c4c5f10cbfd5084947d087 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/applyfunc.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/blockmatrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/blockmatrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c94e1b02bdd1c3537983b04774aab2a65a765c5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/blockmatrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/companion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/companion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf7208c7f322e0d7672146af276a52a3fac3a64c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/companion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/determinant.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/determinant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2288356c807618fb69d2afafcad7b3d7cfeabe55 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/determinant.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/diagonal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/diagonal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3d23c844c2000ffa4bfa360e40a74fc372324b4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/diagonal.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/dotproduct.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/dotproduct.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccbe86c8dca5d1f18c27231a793ac2dc1154f526 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/dotproduct.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/factorizations.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/factorizations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fba400d4985f627807eeb618372e7c9de7fda38b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/factorizations.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/fourier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/fourier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d89ef2646040362d25aa01c29d0a361377a822a9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/fourier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/funcmatrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/funcmatrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..495e534cf37906afac80a0dc67c3478fc88cdf96 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/funcmatrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/hadamard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/hadamard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba741412856c8012dda7cbd58022f81df798ef41 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/hadamard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/inverse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/inverse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc1dc24cd09fc04f6f70cd8424fd859e6f4c818c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/inverse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/kronecker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/kronecker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58bf95dc820956c87b615883173fd1d5812e7653 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/kronecker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matadd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matadd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b82e3275b52b1670fa34d19b288caa40e329098 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matadd.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matexpr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matexpr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af0c5d91b04dfee4aab7ea79c2c523077cc0c485 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matexpr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matmul.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matmul.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fed3a113d1ecd7ee8144a32d5b2e19f5c4fdf955 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matmul.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matpow.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matpow.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d603cfb6ba1460d0553cf31760389c06caba6854 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/matpow.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/permutation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/permutation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05e17e6827bb05420698d78ec7e507250a01a5af Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/permutation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/sets.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/sets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bef3d509780073932d0e7c75895a80f87b3b985e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/sets.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/slice.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/slice.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..541364d97c3bb25aa16988fe83faed47aa45a173 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/slice.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/special.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/special.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c0e599cfba8a12789e49e9fc2ca4d7a51b81f56 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/special.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/trace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85e7a9acdf425658fb5a384956931e24a1b7dcf8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/trace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/transpose.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/transpose.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ee31d7a09a19d17b2885e9d6e8cad2479ef6202 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/__pycache__/transpose.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f36831c62a6fa3595885ef3c986595b440d3e30 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_adjoint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_adjoint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c73d87fb60b8f751b74b45eee5eb7ff2af14251 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_adjoint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_applyfunc.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_applyfunc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5941c7945c91d6e0cb728cb4ee8beb5ad568ff9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_applyfunc.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_blockmatrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_blockmatrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bc5cb7b9a53a03112eca7544ea583959265d2ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_blockmatrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_companion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_companion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07e6f32aa22ff731f925c7037931f789c7ec47f3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_companion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_derivatives.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_derivatives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91a6d33bcb0d98775c60df0a3a34dcc0d26f6600 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_derivatives.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_determinant.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_determinant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6d40b67877bf36633263680c76800d8dac3d8a3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_determinant.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_diagonal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_diagonal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf1c7594a0695185f7f1fb73a7914b8bb84a829b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_diagonal.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_dotproduct.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_dotproduct.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7025922afa8d23178f31fd5c49b29d829e2c9b8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_dotproduct.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_factorizations.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_factorizations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c9f75dc317e5975edc79ddf1963f7e7c4fbd613 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_factorizations.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_fourier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_fourier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..796d47405fde3dacfb1ccd5e62f19073d5613f82 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_fourier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_funcmatrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_funcmatrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aff0c01ba863a3c1e1444aebaf83c41caf13052d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_funcmatrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_hadamard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_hadamard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8418a9399a28fe41d5f848c7c6823b6bc63f4609 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_hadamard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_indexing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_indexing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccdcc58cd70caf0c2c2931f511bf0019ecc70694 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_indexing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_inverse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_inverse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00da9610b5d2c706bc542f1253a57136166ea266 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_inverse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_kronecker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_kronecker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8d0d6b7ed1f2259508bf3b182f7d7f9c9d9cfc6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_kronecker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matadd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matadd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..884b5299af44b818dd14ccc929c062566e7f0f68 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matadd.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matexpr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matexpr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e8ea1208ce935c990860bcd8e1ee4a1abc01eb3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matexpr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matmul.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matmul.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79eca50dde270741e715f466742323c5b6788506 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matmul.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matpow.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matpow.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e408598fc72bc18a6ad9b006cd748642953f64c9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matpow.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_permutation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_permutation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e808215f263eed0f9c61721c4066dfe853b31538 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_permutation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_sets.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_sets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5873e4ae1f0b2cbd0cec009b679cc901da3376a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_sets.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_slice.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_slice.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b3b9796bd5dda3a5e02c628fc7e3b883eaded0b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_slice.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_special.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_special.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd03f8c9291ee472e9ada54f4feccdf06558e1ce Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_special.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_trace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db0d6fcd9129ec81cca7f6b70c88989c9aca4db2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_trace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_transpose.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_transpose.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0fb8ed09f68819d9fa817baad888a627b3ba1ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/__pycache__/test_transpose.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_companion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_companion.py new file mode 100644 index 0000000000000000000000000000000000000000..edc592c29098eddce0c6352806aa73d5d889e999 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_companion.py @@ -0,0 +1,48 @@ +from sympy.core.expr import unchanged +from sympy.core.symbol import Symbol, symbols +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.matrices.expressions.companion import CompanionMatrix +from sympy.polys.polytools import Poly +from sympy.testing.pytest import raises + + +def test_creation(): + x = Symbol('x') + y = Symbol('y') + raises(ValueError, lambda: CompanionMatrix(1)) + raises(ValueError, lambda: CompanionMatrix(Poly([1], x))) + raises(ValueError, lambda: CompanionMatrix(Poly([2, 1], x))) + raises(ValueError, lambda: CompanionMatrix(Poly(x*y, [x, y]))) + assert unchanged(CompanionMatrix, Poly([1, 2, 3], x)) + + +def test_shape(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + assert CompanionMatrix(Poly([1, c0], x)).shape == (1, 1) + assert CompanionMatrix(Poly([1, c1, c0], x)).shape == (2, 2) + assert CompanionMatrix(Poly([1, c2, c1, c0], x)).shape == (3, 3) + + +def test_entry(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + A = CompanionMatrix(Poly([1, c2, c1, c0], x)) + assert A[0, 0] == 0 + assert A[1, 0] == 1 + assert A[1, 1] == 0 + assert A[2, 1] == 1 + assert A[0, 2] == -c0 + assert A[1, 2] == -c1 + assert A[2, 2] == -c2 + + +def test_as_explicit(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + assert CompanionMatrix(Poly([1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([-c0]) + assert CompanionMatrix(Poly([1, c1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([[0, -c0], [1, -c1]]) + assert CompanionMatrix(Poly([1, c2, c1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([[0, 0, -c0], [1, 0, -c1], [0, 1, -c2]]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..abf8ab8e935cbd3039f25f83d3603ac444e5a7bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py @@ -0,0 +1,35 @@ +from sympy.core.expr import unchanged +from sympy.core.mul import Mul +from sympy.matrices import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.dotproduct import DotProduct +from sympy.testing.pytest import raises + + +A = Matrix(3, 1, [1, 2, 3]) +B = Matrix(3, 1, [1, 3, 5]) +C = Matrix(4, 1, [1, 2, 4, 5]) +D = Matrix(2, 2, [1, 2, 3, 4]) + +def test_docproduct(): + assert DotProduct(A, B).doit() == 22 + assert DotProduct(A.T, B).doit() == 22 + assert DotProduct(A, B.T).doit() == 22 + assert DotProduct(A.T, B.T).doit() == 22 + + raises(TypeError, lambda: DotProduct(1, A)) + raises(TypeError, lambda: DotProduct(A, 1)) + raises(TypeError, lambda: DotProduct(A, D)) + raises(TypeError, lambda: DotProduct(D, A)) + + raises(TypeError, lambda: DotProduct(B, C).doit()) + +def test_dotproduct_symbolic(): + A = MatrixSymbol('A', 3, 1) + B = MatrixSymbol('B', 3, 1) + + dot = DotProduct(A, B) + assert dot.is_scalar == True + assert unchanged(Mul, 2, dot) + # XXX Fix forced evaluation for arithmetics with matrix expressions + assert dot * A == (A[0, 0]*B[0, 0] + A[1, 0]*B[1, 0] + A[2, 0]*B[2, 0])*A diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_factorizations.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_factorizations.py new file mode 100644 index 0000000000000000000000000000000000000000..a0319acabbb7409dfa5c24ceca39e25ff0240618 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_factorizations.py @@ -0,0 +1,29 @@ +from sympy.matrices.expressions.factorizations import lu, LofCholesky, qr, svd +from sympy.assumptions.ask import (Q, ask) +from sympy.core.symbol import Symbol +from sympy.matrices.expressions.matexpr import MatrixSymbol + +n = Symbol('n') +X = MatrixSymbol('X', n, n) + +def test_LU(): + L, U = lu(X) + assert L.shape == U.shape == X.shape + assert ask(Q.lower_triangular(L)) + assert ask(Q.upper_triangular(U)) + +def test_Cholesky(): + LofCholesky(X) + +def test_QR(): + Q_, R = qr(X) + assert Q_.shape == R.shape == X.shape + assert ask(Q.orthogonal(Q_)) + assert ask(Q.upper_triangular(R)) + +def test_svd(): + U, S, V = svd(X) + assert U.shape == S.shape == V.shape == X.shape + assert ask(Q.orthogonal(U)) + assert ask(Q.orthogonal(V)) + assert ask(Q.diagonal(S)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_inverse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcc7d4de2b2bee4c4922bda8bc48a52aa205961 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_inverse.py @@ -0,0 +1,69 @@ +from sympy.core import symbols, S +from sympy.matrices.expressions import MatrixSymbol, Inverse, MatPow, ZeroMatrix, OneMatrix +from sympy.matrices.exceptions import NonInvertibleMatrixError, NonSquareMatrixError +from sympy.matrices import eye, Identity +from sympy.testing.pytest import raises +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine + +n, m, l = symbols('n m l', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) + + +def test_inverse(): + assert Inverse(C).args == (C, S.NegativeOne) + assert Inverse(C).shape == (n, n) + assert Inverse(A*E).shape == (n, n) + assert Inverse(E*A).shape == (m, m) + assert Inverse(C).inverse() == C + assert Inverse(Inverse(C)).doit() == C + assert isinstance(Inverse(Inverse(C)), Inverse) + + assert Inverse(*Inverse(E*A).args) == Inverse(E*A) + + assert C.inverse().inverse() == C + + assert C.inverse()*C == Identity(C.rows) + + assert Identity(n).inverse() == Identity(n) + assert (3*Identity(n)).inverse() == Identity(n)/3 + + # Simplifies Muls if possible (i.e. submatrices are square) + assert (C*D).inverse() == D.I*C.I + # But still works when not possible + assert isinstance((A*E).inverse(), Inverse) + assert Inverse(C*D).doit(inv_expand=False) == Inverse(C*D) + + assert Inverse(eye(3)).doit() == eye(3) + assert Inverse(eye(3)).doit(deep=False) == eye(3) + + assert OneMatrix(1, 1).I == Identity(1) + assert isinstance(OneMatrix(n, n).I, Inverse) + +def test_inverse_non_invertible(): + raises(NonInvertibleMatrixError, lambda: ZeroMatrix(n, n).I) + raises(NonInvertibleMatrixError, lambda: OneMatrix(2, 2).I) + +def test_refine(): + assert refine(C.I, Q.orthogonal(C)) == C.T + + +def test_inverse_matpow_canonicalization(): + A = MatrixSymbol('A', 3, 3) + assert Inverse(MatPow(A, 3)).doit() == MatPow(Inverse(A), 3).doit() + + +def test_nonsquare_error(): + A = MatrixSymbol('A', 3, 4) + raises(NonSquareMatrixError, lambda: Inverse(A)) + + +def test_adjoint_trnaspose_conjugate(): + A = MatrixSymbol('A', n, n) + assert A.transpose().inverse() == A.inverse().transpose() + assert A.conjugate().inverse() == A.inverse().conjugate() + assert A.adjoint().inverse() == A.inverse().adjoint() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_matexpr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_matexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..f2319e8d8097c2ad3519eab783c4665623c55b80 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_matexpr.py @@ -0,0 +1,592 @@ +from sympy.concrete.summations import Sum +from sympy.core.exprtools import gcd_terms +from sympy.core.function import (diff, expand) +from sympy.core.relational import Eq +from sympy.core.symbol import (Dummy, Symbol, Str) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import zeros +from sympy.polys.polytools import factor + +from sympy.core import (S, symbols, Add, Mul, SympifyError, Rational, + Function) +from sympy.functions import sin, cos, tan, sqrt, cbrt, exp +from sympy.simplify import simplify +from sympy.matrices import (ImmutableMatrix, Inverse, MatAdd, MatMul, + MatPow, Matrix, MatrixExpr, MatrixSymbol, + SparseMatrix, Transpose, Adjoint, MatrixSet) +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.expressions.determinant import Determinant, det +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.matrices.expressions.special import ZeroMatrix, Identity +from sympy.testing.pytest import raises, XFAIL, skip +from importlib.metadata import version + +n, m, l, k, p = symbols('n m l k p', integer=True) +x = symbols('x') +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) +w = MatrixSymbol('w', n, 1) + + +def test_matrix_symbol_creation(): + assert MatrixSymbol('A', 2, 2) + assert MatrixSymbol('A', 0, 0) + raises(ValueError, lambda: MatrixSymbol('A', -1, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2.0, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2j, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2, -1)) + raises(ValueError, lambda: MatrixSymbol('A', 2, 2.0)) + raises(ValueError, lambda: MatrixSymbol('A', 2, 2j)) + + n = symbols('n') + assert MatrixSymbol('A', n, n) + n = symbols('n', integer=False) + raises(ValueError, lambda: MatrixSymbol('A', n, n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: MatrixSymbol('A', n, n)) + + +def test_matexpr_properties(): + assert A.shape == (n, m) + assert (A * B).shape == (n, l) + assert A[0, 1].indices == (0, 1) + assert A[0, 0].symbol == A + assert A[0, 0].symbol.name == 'A' + + +def test_matexpr(): + assert (x*A).shape == A.shape + assert (x*A).__class__ == MatMul + assert 2*A - A - A == ZeroMatrix(*A.shape) + assert (A*B).shape == (n, l) + + +def test_matexpr_subs(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + C = MatrixSymbol('C', m, l) + + assert A.subs(n, m).shape == (m, m) + assert (A*B).subs(B, C) == A*C + assert (A*B).subs(l, n).is_square + + W = MatrixSymbol("W", 3, 3) + X = MatrixSymbol("X", 2, 2) + Y = MatrixSymbol("Y", 1, 2) + Z = MatrixSymbol("Z", n, 2) + # no restrictions on Symbol replacement + assert X.subs(X, Y) == Y + # it might be better to just change the name + y = Str('y') + assert X.subs(Str("X"), y).args == (y, 2, 2) + # it's ok to introduce a wider matrix + assert X[1, 1].subs(X, W) == W[1, 1] + # but for a given MatrixExpression, only change + # name if indexing on the new shape is valid. + # Here, X is 2,2; Y is 1,2 and Y[1, 1] is out + # of range so an error is raised + raises(IndexError, lambda: X[1, 1].subs(X, Y)) + # here, [0, 1] is in range so the subs succeeds + assert X[0, 1].subs(X, Y) == Y[0, 1] + # and here the size of n will accept any index + # in the first position + assert W[2, 1].subs(W, Z) == Z[2, 1] + # but not in the second position + raises(IndexError, lambda: W[2, 2].subs(W, Z)) + # any matrix should raise if invalid + raises(IndexError, lambda: W[2, 2].subs(W, zeros(2))) + + A = SparseMatrix([[1, 2], [3, 4]]) + B = Matrix([[1, 2], [3, 4]]) + C, D = MatrixSymbol('C', 2, 2), MatrixSymbol('D', 2, 2) + + assert (C*D).subs({C: A, D: B}) == MatMul(A, B) + + +def test_addition(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, m) + + assert isinstance(A + B, MatAdd) + assert (A + B).shape == A.shape + assert isinstance(A - A + 2*B, MatMul) + + raises(TypeError, lambda: A + 1) + raises(TypeError, lambda: 5 + A) + raises(TypeError, lambda: 5 - A) + + assert A + ZeroMatrix(n, m) - A == ZeroMatrix(n, m) + raises(TypeError, lambda: ZeroMatrix(n, m) + S.Zero) + + +def test_multiplication(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + C = MatrixSymbol('C', n, n) + + assert (2*A*B).shape == (n, l) + assert (A*0*B) == ZeroMatrix(n, l) + assert (2*A).shape == A.shape + + assert A * ZeroMatrix(m, m) * B == ZeroMatrix(n, l) + + assert C * Identity(n) * C.I == Identity(n) + + assert B/2 == S.Half*B + raises(NotImplementedError, lambda: 2/B) + + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert Identity(n) * (A + B) == A + B + + assert A**2*A == A**3 + assert A**2*(A.I)**3 == A.I + assert A**3*(A.I)**2 == A + + +def test_MatPow(): + A = MatrixSymbol('A', n, n) + + AA = MatPow(A, 2) + assert AA.exp == 2 + assert AA.base == A + assert (A**n).exp == n + + assert A**0 == Identity(n) + assert A**1 == A + assert A**2 == AA + assert A**-1 == Inverse(A) + assert (A**-1)**-1 == A + assert (A**2)**3 == A**6 + assert A**S.Half == sqrt(A) + assert A**Rational(1, 3) == cbrt(A) + raises(NonSquareMatrixError, lambda: MatrixSymbol('B', 3, 2)**2) + + +def test_MatrixSymbol(): + n, m, t = symbols('n,m,t') + X = MatrixSymbol('X', n, m) + assert X.shape == (n, m) + raises(TypeError, lambda: MatrixSymbol('X', n, m)(t)) # issue 5855 + assert X.doit() == X + + +def test_dense_conversion(): + X = MatrixSymbol('X', 2, 2) + assert ImmutableMatrix(X) == ImmutableMatrix(2, 2, lambda i, j: X[i, j]) + assert Matrix(X) == Matrix(2, 2, lambda i, j: X[i, j]) + + +def test_free_symbols(): + assert (C*D).free_symbols == {C, D} + + +def test_zero_matmul(): + assert isinstance(S.Zero * MatrixSymbol('X', 2, 2), MatrixExpr) + + +def test_matadd_simplify(): + A = MatrixSymbol('A', 1, 1) + assert simplify(MatAdd(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))) == \ + MatAdd(A, Matrix([[1]])) + + +def test_matmul_simplify(): + A = MatrixSymbol('A', 1, 1) + assert simplify(MatMul(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))) == \ + MatMul(A, Matrix([[1]])) + + +def test_invariants(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + X = MatrixSymbol('X', n, n) + objs = [Identity(n), ZeroMatrix(m, n), A, MatMul(A, B), MatAdd(A, A), + Transpose(A), Adjoint(A), Inverse(X), MatPow(X, 2), MatPow(X, -1), + MatPow(X, 0)] + for obj in objs: + assert obj == obj.__class__(*obj.args) + + +def test_matexpr_indexing(): + A = MatrixSymbol('A', n, m) + A[1, 2] + A[l, k] + A[l + 1, k + 1] + A = MatrixSymbol('A', 2, 1) + for i in range(-2, 2): + for j in range(-1, 1): + A[i, j] + + +def test_single_indexing(): + A = MatrixSymbol('A', 2, 3) + assert A[1] == A[0, 1] + assert A[int(1)] == A[0, 1] + assert A[3] == A[1, 0] + assert list(A[:2, :2]) == [A[0, 0], A[0, 1], A[1, 0], A[1, 1]] + raises(IndexError, lambda: A[6]) + raises(IndexError, lambda: A[n]) + B = MatrixSymbol('B', n, m) + raises(IndexError, lambda: B[1]) + B = MatrixSymbol('B', n, 3) + assert B[3] == B[1, 0] + + +def test_MatrixElement_commutative(): + assert A[0, 1]*A[1, 0] == A[1, 0]*A[0, 1] + + +def test_MatrixSymbol_determinant(): + A = MatrixSymbol('A', 4, 4) + assert A.as_explicit().det() == A[0, 0]*A[1, 1]*A[2, 2]*A[3, 3] - \ + A[0, 0]*A[1, 1]*A[2, 3]*A[3, 2] - A[0, 0]*A[1, 2]*A[2, 1]*A[3, 3] + \ + A[0, 0]*A[1, 2]*A[2, 3]*A[3, 1] + A[0, 0]*A[1, 3]*A[2, 1]*A[3, 2] - \ + A[0, 0]*A[1, 3]*A[2, 2]*A[3, 1] - A[0, 1]*A[1, 0]*A[2, 2]*A[3, 3] + \ + A[0, 1]*A[1, 0]*A[2, 3]*A[3, 2] + A[0, 1]*A[1, 2]*A[2, 0]*A[3, 3] - \ + A[0, 1]*A[1, 2]*A[2, 3]*A[3, 0] - A[0, 1]*A[1, 3]*A[2, 0]*A[3, 2] + \ + A[0, 1]*A[1, 3]*A[2, 2]*A[3, 0] + A[0, 2]*A[1, 0]*A[2, 1]*A[3, 3] - \ + A[0, 2]*A[1, 0]*A[2, 3]*A[3, 1] - A[0, 2]*A[1, 1]*A[2, 0]*A[3, 3] + \ + A[0, 2]*A[1, 1]*A[2, 3]*A[3, 0] + A[0, 2]*A[1, 3]*A[2, 0]*A[3, 1] - \ + A[0, 2]*A[1, 3]*A[2, 1]*A[3, 0] - A[0, 3]*A[1, 0]*A[2, 1]*A[3, 2] + \ + A[0, 3]*A[1, 0]*A[2, 2]*A[3, 1] + A[0, 3]*A[1, 1]*A[2, 0]*A[3, 2] - \ + A[0, 3]*A[1, 1]*A[2, 2]*A[3, 0] - A[0, 3]*A[1, 2]*A[2, 0]*A[3, 1] + \ + A[0, 3]*A[1, 2]*A[2, 1]*A[3, 0] + + B = MatrixSymbol('B', 4, 4) + assert Determinant(A + B).doit() == det(A + B) == (A + B).det() + + +def test_MatrixElement_diff(): + assert (A[3, 0]*A[0, 0]).diff(A[0, 0]) == A[3, 0] + + +def test_MatrixElement_doit(): + u = MatrixSymbol('u', 2, 1) + v = ImmutableMatrix([3, 5]) + assert u[0, 0].subs(u, v).doit() == v[0, 0] + + +def test_identity_powers(): + M = Identity(n) + assert MatPow(M, 3).doit() == M**3 + assert M**n == M + assert MatPow(M, 0).doit() == M**2 + assert M**-2 == M + assert MatPow(M, -2).doit() == M**0 + N = Identity(3) + assert MatPow(N, 2).doit() == N**n + assert MatPow(N, 3).doit() == N + assert MatPow(N, -2).doit() == N**4 + assert MatPow(N, 2).doit() == N**0 + + +def test_Zero_power(): + z1 = ZeroMatrix(n, n) + assert z1**4 == z1 + raises(ValueError, lambda:z1**-2) + assert z1**0 == Identity(n) + assert MatPow(z1, 2).doit() == z1**2 + raises(ValueError, lambda:MatPow(z1, -2).doit()) + z2 = ZeroMatrix(3, 3) + assert MatPow(z2, 4).doit() == z2**4 + raises(ValueError, lambda:z2**-3) + assert z2**3 == MatPow(z2, 3).doit() + assert z2**0 == Identity(3) + raises(ValueError, lambda:MatPow(z2, -1).doit()) + + +def test_matrixelement_diff(): + dexpr = diff((D*w)[k,0], w[p,0]) + + assert w[k, p].diff(w[k, p]) == 1 + assert w[k, p].diff(w[0, 0]) == KroneckerDelta(0, k, (0, n-1))*KroneckerDelta(0, p, (0, 0)) + _i_1 = Dummy("_i_1") + assert dexpr.dummy_eq(Sum(KroneckerDelta(_i_1, p, (0, n-1))*D[k, _i_1], (_i_1, 0, n - 1))) + assert dexpr.doit() == D[k, p] + + +def test_MatrixElement_with_values(): + x, y, z, w = symbols("x y z w") + M = Matrix([[x, y], [z, w]]) + i, j = symbols("i, j") + Mij = M[i, j] + assert isinstance(Mij, MatrixElement) + Ms = SparseMatrix([[2, 3], [4, 5]]) + msij = Ms[i, j] + assert isinstance(msij, MatrixElement) + for oi, oj in [(0, 0), (0, 1), (1, 0), (1, 1)]: + assert Mij.subs({i: oi, j: oj}) == M[oi, oj] + assert msij.subs({i: oi, j: oj}) == Ms[oi, oj] + A = MatrixSymbol("A", 2, 2) + assert A[0, 0].subs(A, M) == x + assert A[i, j].subs(A, M) == M[i, j] + assert M[i, j].subs(M, A) == A[i, j] + + assert isinstance(M[3*i - 2, j], MatrixElement) + assert M[3*i - 2, j].subs({i: 1, j: 0}) == M[1, 0] + assert isinstance(M[i, 0], MatrixElement) + assert M[i, 0].subs(i, 0) == M[0, 0] + assert M[0, i].subs(i, 1) == M[0, 1] + + assert M[i, j].diff(x) == Matrix([[1, 0], [0, 0]])[i, j] + + raises(ValueError, lambda: M[i, 2]) + raises(ValueError, lambda: M[i, -1]) + raises(ValueError, lambda: M[2, i]) + raises(ValueError, lambda: M[-1, i]) + + +def test_inv(): + B = MatrixSymbol('B', 3, 3) + assert B.inv() == B**-1 + + # https://github.com/sympy/sympy/issues/19162 + X = MatrixSymbol('X', 1, 1).as_explicit() + assert X.inv() == Matrix([[1/X[0, 0]]]) + + X = MatrixSymbol('X', 2, 2).as_explicit() + detX = X[0, 0]*X[1, 1] - X[0, 1]*X[1, 0] + invX = Matrix([[ X[1, 1], -X[0, 1]], + [-X[1, 0], X[0, 0]]]) / detX + assert X.inv() == invX + + +@XFAIL +def test_factor_expand(): + A = MatrixSymbol("A", n, n) + B = MatrixSymbol("B", n, n) + expr1 = (A + B)*(C + D) + expr2 = A*C + B*C + A*D + B*D + assert expr1 != expr2 + assert expand(expr1) == expr2 + assert factor(expr2) == expr1 + + expr = B**(-1)*(A**(-1)*B**(-1) - A**(-1)*C*B**(-1))**(-1)*A**(-1) + I = Identity(n) + # Ideally we get the first, but we at least don't want a wrong answer + assert factor(expr) in [I - C, B**-1*(A**-1*(I - C)*B**-1)**-1*A**-1] + +def test_numpy_conversion(): + try: + from numpy import array, array_equal + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + A = MatrixSymbol('A', 2, 2) + np_array = array([[MatrixElement(A, 0, 0), MatrixElement(A, 0, 1)], + [MatrixElement(A, 1, 0), MatrixElement(A, 1, 1)]]) + assert array_equal(array(A), np_array) + assert array_equal(array(A, copy=True), np_array) + if(int(version('numpy').split('.')[0]) >= 2): #run this test only if numpy is new enough that copy variable is passed properly. + raises(TypeError, lambda: array(A, copy=False)) + +def test_issue_2749(): + A = MatrixSymbol("A", 5, 2) + assert (A.T * A).I.as_explicit() == Matrix([[(A.T * A).I[0, 0], (A.T * A).I[0, 1]], \ + [(A.T * A).I[1, 0], (A.T * A).I[1, 1]]]) + + +def test_issue_2750(): + x = MatrixSymbol('x', 1, 1) + assert (x.T*x).as_explicit()**-1 == Matrix([[x[0, 0]**(-2)]]) + + +def test_issue_7842(): + A = MatrixSymbol('A', 3, 1) + B = MatrixSymbol('B', 2, 1) + assert Eq(A, B) == False + assert Eq(A[1,0], B[1, 0]).func is Eq + A = ZeroMatrix(2, 3) + B = ZeroMatrix(2, 3) + assert Eq(A, B) == True + + +def test_issue_21195(): + t = symbols('t') + x = Function('x')(t) + dx = x.diff(t) + exp1 = cos(x) + cos(x)*dx + exp2 = sin(x) + tan(x)*(dx.diff(t)) + exp3 = sin(x)*sin(t)*(dx.diff(t)).diff(t) + A = Matrix([[exp1], [exp2], [exp3]]) + B = Matrix([[exp1.diff(x)], [exp2.diff(x)], [exp3.diff(x)]]) + assert A.diff(x) == B + + +def test_issue_24859(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 3, 2) + J = A*B + Jinv = Matrix(J).adjugate() + u = MatrixSymbol('u', 2, 3) + Jk = Jinv.subs(A, A + x*u) + + expected = B[0, 1]*u[1, 0] + B[1, 1]*u[1, 1] + B[2, 1]*u[1, 2] + assert Jk[0, 0].diff(x) == expected + assert diff(Jk[0, 0], x).doit() == expected + + +def test_MatMul_postprocessor(): + z = zeros(2) + z1 = ZeroMatrix(2, 2) + assert Mul(0, z) == Mul(z, 0) in [z, z1] + + M = Matrix([[1, 2], [3, 4]]) + Mx = Matrix([[x, 2*x], [3*x, 4*x]]) + assert Mul(x, M) == Mul(M, x) == Mx + + A = MatrixSymbol("A", 2, 2) + assert Mul(A, M) == MatMul(A, M) + assert Mul(M, A) == MatMul(M, A) + # Scalars should be absorbed into constant matrices + a = Mul(x, M, A) + b = Mul(M, x, A) + c = Mul(M, A, x) + assert a == b == c == MatMul(Mx, A) + a = Mul(x, A, M) + b = Mul(A, x, M) + c = Mul(A, M, x) + assert a == b == c == MatMul(A, Mx) + assert Mul(M, M) == M**2 + assert Mul(A, M, M) == MatMul(A, M**2) + assert Mul(M, M, A) == MatMul(M**2, A) + assert Mul(M, A, M) == MatMul(M, A, M) + + assert Mul(A, x, M, M, x) == MatMul(A, Mx**2) + + +@XFAIL +def test_MatAdd_postprocessor_xfail(): + # This is difficult to get working because of the way that Add processes + # its args. + z = zeros(2) + assert Add(z, S.NaN) == Add(S.NaN, z) + + +def test_MatAdd_postprocessor(): + # Some of these are nonsensical, but we do not raise errors for Add + # because that breaks algorithms that want to replace matrices with dummy + # symbols. + + z = zeros(2) + + assert Add(0, z) == Add(z, 0) == z + + a = Add(S.Infinity, z) + assert a == Add(z, S.Infinity) + assert isinstance(a, Add) + assert a.args == (S.Infinity, z) + + a = Add(S.ComplexInfinity, z) + assert a == Add(z, S.ComplexInfinity) + assert isinstance(a, Add) + assert a.args == (S.ComplexInfinity, z) + + a = Add(z, S.NaN) + # assert a == Add(S.NaN, z) # See the XFAIL above + assert isinstance(a, Add) + assert a.args == (S.NaN, z) + + M = Matrix([[1, 2], [3, 4]]) + a = Add(x, M) + assert a == Add(M, x) + assert isinstance(a, Add) + assert a.args == (x, M) + + A = MatrixSymbol("A", 2, 2) + assert Add(A, M) == Add(M, A) == A + M + + # Scalars should be absorbed into constant matrices (producing an error) + a = Add(x, M, A) + assert a == Add(M, x, A) == Add(M, A, x) == Add(x, A, M) == Add(A, x, M) == Add(A, M, x) + assert isinstance(a, Add) + assert a.args == (x, A + M) + + assert Add(M, M) == 2*M + assert Add(M, A, M) == Add(M, M, A) == Add(A, M, M) == A + 2*M + + a = Add(A, x, M, M, x) + assert isinstance(a, Add) + assert a.args == (2*x, A + 2*M) + + +def test_simplify_matrix_expressions(): + # Various simplification functions + assert type(gcd_terms(C*D + D*C)) == MatAdd + a = gcd_terms(2*C*D + 4*D*C) + assert type(a) == MatAdd + assert a.args == (2*C*D, 4*D*C) + + +def test_exp(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 2, 2) + expr1 = exp(A)*exp(B) + expr2 = exp(B)*exp(A) + assert expr1 != expr2 + assert expr1 - expr2 != 0 + assert not isinstance(expr1, exp) + assert not isinstance(expr2, exp) + + +def test_invalid_args(): + raises(SympifyError, lambda: MatrixSymbol(1, 2, 'A')) + + +def test_matrixsymbol_from_symbol(): + # The label should be preserved during doit and subs + A_label = Symbol('A', complex=True) + A = MatrixSymbol(A_label, 2, 2) + + A_1 = A.doit() + A_2 = A.subs(2, 3) + assert A_1.args == A.args + assert A_2.args[0] == A.args[0] + + +def test_as_explicit(): + Z = MatrixSymbol('Z', 2, 3) + assert Z.as_explicit() == ImmutableMatrix([ + [Z[0, 0], Z[0, 1], Z[0, 2]], + [Z[1, 0], Z[1, 1], Z[1, 2]], + ]) + raises(ValueError, lambda: A.as_explicit()) + + +def test_MatrixSet(): + M = MatrixSet(2, 2, set=S.Reals) + assert M.shape == (2, 2) + assert M.set == S.Reals + X = Matrix([[1, 2], [3, 4]]) + assert X in M + X = ZeroMatrix(2, 2) + assert X in M + raises(TypeError, lambda: A in M) + raises(TypeError, lambda: 1 in M) + M = MatrixSet(n, m, set=S.Reals) + assert A in M + raises(TypeError, lambda: C in M) + raises(TypeError, lambda: X in M) + M = MatrixSet(2, 2, set={1, 2, 3}) + X = Matrix([[1, 2], [3, 4]]) + Y = Matrix([[1, 2]]) + assert (X in M) == S.false + assert (Y in M) == S.false + raises(ValueError, lambda: MatrixSet(2, -2, S.Reals)) + raises(ValueError, lambda: MatrixSet(2.4, -1, S.Reals)) + raises(TypeError, lambda: MatrixSet(2, 2, (1, 2, 3))) + + +def test_matrixsymbol_solving(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 2, 2) + Z = ZeroMatrix(2, 2) + assert -(-A + B) - A + B == Z + assert (-(-A + B) - A + B).simplify() == Z + assert (-(-A + B) - A + B).expand() == Z + assert (-(-A + B) - A + B - Z).simplify() == Z + assert (-(-A + B) - A + B - Z).expand() == Z + assert (A*(A + B) + B*(A.T + B.T)).expand() == A**2 + A*B + B*A.T + B*B.T diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_slice.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_slice.py new file mode 100644 index 0000000000000000000000000000000000000000..36490719e26908b9e913ed99b7673d602647c492 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/matrices/expressions/tests/test_slice.py @@ -0,0 +1,65 @@ +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions import MatrixSymbol +from sympy.abc import a, b, c, d, k, l, m, n +from sympy.testing.pytest import raises, XFAIL +from sympy.functions.elementary.integers import floor +from sympy.assumptions import assuming, Q + + +X = MatrixSymbol('X', n, m) +Y = MatrixSymbol('Y', m, k) + +def test_shape(): + B = MatrixSlice(X, (a, b), (c, d)) + assert B.shape == (b - a, d - c) + +def test_entry(): + B = MatrixSlice(X, (a, b), (c, d)) + assert B[0,0] == X[a, c] + assert B[k,l] == X[a+k, c+l] + raises(IndexError, lambda : MatrixSlice(X, 1, (2, 5))[1, 0]) + + assert X[1::2, :][1, 3] == X[1+2, 3] + assert X[:, 1::2][3, 1] == X[3, 1+2] + +def test_on_diag(): + assert not MatrixSlice(X, (a, b), (c, d)).on_diag + assert MatrixSlice(X, (a, b), (a, b)).on_diag + +def test_inputs(): + assert MatrixSlice(X, 1, (2, 5)) == MatrixSlice(X, (1, 2), (2, 5)) + assert MatrixSlice(X, 1, (2, 5)).shape == (1, 3) + +def test_slicing(): + assert X[1:5, 2:4] == MatrixSlice(X, (1, 5), (2, 4)) + assert X[1, 2:4] == MatrixSlice(X, 1, (2, 4)) + assert X[1:5, :].shape == (4, X.shape[1]) + assert X[:, 1:5].shape == (X.shape[0], 4) + + assert X[::2, ::2].shape == (floor(n/2), floor(m/2)) + assert X[2, :] == MatrixSlice(X, 2, (0, m)) + assert X[k, :] == MatrixSlice(X, k, (0, m)) + +def test_exceptions(): + X = MatrixSymbol('x', 10, 20) + raises(IndexError, lambda: X[0:12, 2]) + raises(IndexError, lambda: X[0:9, 22]) + raises(IndexError, lambda: X[-1:5, 2]) + +@XFAIL +def test_symmetry(): + X = MatrixSymbol('x', 10, 10) + Y = X[:5, 5:] + with assuming(Q.symmetric(X)): + assert Y.T == X[5:, :5] + +def test_slice_of_slice(): + X = MatrixSymbol('x', 10, 10) + assert X[2, :][:, 3][0, 0] == X[2, 3] + assert X[:5, :5][:4, :4] == X[:4, :4] + assert X[1:5, 2:6][1:3, 2] == X[2:4, 4] + assert X[1:9:2, 2:6][1:3, 2] == X[3:7:2, 4] + +def test_negative_index(): + X = MatrixSymbol('x', 10, 10) + assert X[-1, :] == X[9, :] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19576c8935da455743d27f0a263caecca94f59f8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__init__.py @@ -0,0 +1,67 @@ +""" +Number theory module (primes, etc) +""" + +from .generate import nextprime, prevprime, prime, primepi, primerange, \ + randprime, Sieve, sieve, primorial, cycle_length, composite, compositepi +from .primetest import isprime, is_gaussian_prime, is_mersenne_prime +from .factor_ import divisors, proper_divisors, factorint, multiplicity, \ + multiplicity_in_factorial, perfect_power, factor_cache, pollard_pm1, \ + pollard_rho, primefactors, totient, \ + divisor_count, proper_divisor_count, divisor_sigma, factorrat, \ + reduced_totient, primenu, primeomega, mersenne_prime_exponent, \ + is_perfect, is_abundant, is_deficient, is_amicable, is_carmichael, \ + abundance, dra, drm + +from .partitions_ import npartitions +from .residue_ntheory import is_primitive_root, is_quad_residue, \ + legendre_symbol, jacobi_symbol, n_order, sqrt_mod, quadratic_residues, \ + primitive_root, nthroot_mod, is_nthpow_residue, sqrt_mod_iter, mobius, \ + discrete_log, quadratic_congruence, polynomial_congruence +from .multinomial import binomial_coefficients, binomial_coefficients_list, \ + multinomial_coefficients +from .continued_fraction import continued_fraction_periodic, \ + continued_fraction_iterator, continued_fraction_reduce, \ + continued_fraction_convergents, continued_fraction +from .digits import count_digits, digits, is_palindromic +from .egyptian_fraction import egyptian_fraction +from .ecm import ecm +from .qs import qs, qs_factor +__all__ = [ + 'nextprime', 'prevprime', 'prime', 'primepi', 'primerange', 'randprime', + 'Sieve', 'sieve', 'primorial', 'cycle_length', 'composite', 'compositepi', + + 'isprime', 'is_gaussian_prime', 'is_mersenne_prime', + + + 'divisors', 'proper_divisors', 'factorint', 'multiplicity', 'perfect_power', + 'pollard_pm1', 'factor_cache', 'pollard_rho', 'primefactors', 'totient', + 'divisor_count', 'proper_divisor_count', 'divisor_sigma', 'factorrat', + 'reduced_totient', 'primenu', 'primeomega', 'mersenne_prime_exponent', + 'is_perfect', 'is_abundant', 'is_deficient', 'is_amicable', + 'is_carmichael', 'abundance', 'dra', 'drm', 'multiplicity_in_factorial', + + 'npartitions', + + 'is_primitive_root', 'is_quad_residue', 'legendre_symbol', + 'jacobi_symbol', 'n_order', 'sqrt_mod', 'quadratic_residues', + 'primitive_root', 'nthroot_mod', 'is_nthpow_residue', 'sqrt_mod_iter', + 'mobius', 'discrete_log', 'quadratic_congruence', 'polynomial_congruence', + + 'binomial_coefficients', 'binomial_coefficients_list', + 'multinomial_coefficients', + + 'continued_fraction_periodic', 'continued_fraction_iterator', + 'continued_fraction_reduce', 'continued_fraction_convergents', + 'continued_fraction', + + 'digits', + 'count_digits', + 'is_palindromic', + + 'egyptian_fraction', + + 'ecm', + + 'qs', 'qs_factor', +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca6ddf05ef2a80370650675a29b9237e67d43e92 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/bbp_pi.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/bbp_pi.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad8b9b10c9b815cf4df27302b7accb75345a4a39 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/bbp_pi.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/continued_fraction.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/continued_fraction.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..705e957629d87c0bedabf825efeddb23e4a0ecd2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/continued_fraction.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/digits.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/digits.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb696e356a60b79ed031ddb1dc2ee4587c6fc5fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/digits.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/ecm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/ecm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9043a35b9b2009f205494c4d343d06a891ce864 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/ecm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/egyptian_fraction.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/egyptian_fraction.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..523e806ea9022408491d24c4e2d8d0e14ea3f84a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/egyptian_fraction.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/elliptic_curve.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/elliptic_curve.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bab51c3c83b513822297e2a1bf32b183adff9790 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/elliptic_curve.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/factor_.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/factor_.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a18f45fc6ea4bcc9e3928d13b33ff7748380c276 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/factor_.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/generate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/generate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..617d4fa7a965eb5ecf94393ff2d8e8c4f32e8cd7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/generate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/modular.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/modular.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f22406da925be1273041edc51b8800490809300f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/modular.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/multinomial.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/multinomial.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40d6191de7bacf9e855a19ce464a3677c20f31af Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/multinomial.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/partitions_.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/partitions_.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4996a02498c3b7f0fdd520882aaf61d3ed5bde Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/partitions_.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/primetest.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/primetest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6025f8a5600c2c2955f85f97d9e14ff020cf6d67 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/primetest.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/qs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/qs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5da9cacf5566a78c175c5fed7b749d72bd7b387c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/qs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/residue_ntheory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/residue_ntheory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e02007244b543c7c5a8fb3611b77b811d8333ec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/__pycache__/residue_ntheory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/bbp_pi.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/bbp_pi.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ff4b755d74d4e075ac7195f991c8182d175693 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/bbp_pi.py @@ -0,0 +1,190 @@ +''' +This implementation is a heavily modified fixed point implementation of +BBP_formula for calculating the nth position of pi. The original hosted +at: https://web.archive.org/web/20151116045029/http://en.literateprograms.org/Pi_with_the_BBP_formula_(Python) + +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sub-license, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +Modifications: + +1.Once the nth digit and desired number of digits is selected, the +number of digits of working precision is calculated to ensure that +the hexadecimal digits returned are accurate. This is calculated as + + int(math.log(start + prec)/math.log(16) + prec + 3) + --------------------------------------- -------- + / / + number of hex digits additional digits + +This was checked by the following code which completed without +errors (and dig are the digits included in the test_bbp.py file): + + for i in range(0,1000): + for j in range(1,1000): + a, b = pi_hex_digits(i, j), dig[i:i+j] + if a != b: + print('%s\n%s'%(a,b)) + +Deceasing the additional digits by 1 generated errors, so '3' is +the smallest additional precision needed to calculate the above +loop without errors. The following trailing 10 digits were also +checked to be accurate (and the times were slightly faster with +some of the constant modifications that were made): + + >> from time import time + >> t=time();pi_hex_digits(10**2-10 + 1, 10), time()-t + ('e90c6cc0ac', 0.0) + >> t=time();pi_hex_digits(10**4-10 + 1, 10), time()-t + ('26aab49ec6', 0.17100000381469727) + >> t=time();pi_hex_digits(10**5-10 + 1, 10), time()-t + ('a22673c1a5', 4.7109999656677246) + >> t=time();pi_hex_digits(10**6-10 + 1, 10), time()-t + ('9ffd342362', 59.985999822616577) + >> t=time();pi_hex_digits(10**7-10 + 1, 10), time()-t + ('c1a42e06a1', 689.51800012588501) + +2. The while loop to evaluate whether the series has converged quits +when the addition amount `dt` has dropped to zero. + +3. the formatting string to convert the decimal to hexadecimal is +calculated for the given precision. + +4. pi_hex_digits(n) changed to have coefficient to the formula in an +array (perhaps just a matter of preference). + +''' + +from sympy.utilities.misc import as_int + + +def _series(j, n, prec=14): + + # Left sum from the bbp algorithm + s = 0 + D = _dn(n, prec) + D4 = 4 * D + d = j + for k in range(n + 1): + s += (pow(16, n - k, d) << D4) // d + d += 8 + + # Right sum iterates to infinity for full precision, but we + # stop at the point where one iteration is beyond the precision + # specified. + + t = 0 + k = n + 1 + e = D4 - 4 # 4*(D + n - k) + d = 8 * k + j + while True: + dt = (1 << e) // d + if not dt: + break + t += dt + # k += 1 + e -= 4 + d += 8 + total = s + t + + return total + + +def pi_hex_digits(n, prec=14): + """Returns a string containing ``prec`` (default 14) digits + starting at the nth digit of pi in hex. Counting of digits + starts at 0 and the decimal is not counted, so for n = 0 the + returned value starts with 3; n = 1 corresponds to the first + digit past the decimal point (which in hex is 2). + + Parameters + ========== + + n : non-negative integer + prec : non-negative integer. default = 14 + + Returns + ======= + + str : Returns a string containing ``prec`` digits + starting at the nth digit of pi in hex. + If ``prec`` = 0, returns empty string. + + Raises + ====== + + ValueError + If ``n`` < 0 or ``prec`` < 0. + Or ``n`` or ``prec`` is not an integer. + + Examples + ======== + + >>> from sympy.ntheory.bbp_pi import pi_hex_digits + >>> pi_hex_digits(0) + '3243f6a8885a30' + >>> pi_hex_digits(0, 3) + '324' + + These are consistent with the following results + + >>> import math + >>> hex(int(math.pi * 2**((14-1)*4))) + '0x3243f6a8885a30' + >>> hex(int(math.pi * 2**((3-1)*4))) + '0x324' + + References + ========== + + .. [1] http://www.numberworld.org/digits/Pi/ + """ + n, prec = as_int(n), as_int(prec) + if n < 0: + raise ValueError('n cannot be negative') + if prec < 0: + raise ValueError('prec cannot be negative') + if prec == 0: + return '' + + # main of implementation arrays holding formulae coefficients + n -= 1 + a = [4, 2, 1, 1] + j = [1, 4, 5, 6] + + #formulae + D = _dn(n, prec) + x = + (a[0]*_series(j[0], n, prec) + - a[1]*_series(j[1], n, prec) + - a[2]*_series(j[2], n, prec) + - a[3]*_series(j[3], n, prec)) & (16**D - 1) + + s = ("%0" + "%ix" % prec) % (x // 16**(D - prec)) + return s + + +def _dn(n, prec): + # controller for n dependence on precision + # n = starting digit index + # prec = the number of total digits to compute + n += 1 # because we subtract 1 for _series + + # assert int(math.log(n + prec)/math.log(16)) ==\ + # ((n + prec).bit_length() - 1) // 4 + return ((n + prec).bit_length() - 1) // 4 + prec + 3 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/continued_fraction.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/continued_fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..62f8e2d729ada3414a87d6f0583e06bee2a2b220 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/continued_fraction.py @@ -0,0 +1,369 @@ +from __future__ import annotations +import itertools +from sympy.core.exprtools import factor_terms +from sympy.core.numbers import Integer, Rational +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import _sympify +from sympy.utilities.misc import as_int + + +def continued_fraction(a) -> list: + """Return the continued fraction representation of a Rational or + quadratic irrational. + + Examples + ======== + + >>> from sympy.ntheory.continued_fraction import continued_fraction + >>> from sympy import sqrt + >>> continued_fraction((1 + 2*sqrt(3))/5) + [0, 1, [8, 3, 34, 3]] + + See Also + ======== + continued_fraction_periodic, continued_fraction_reduce, continued_fraction_convergents + """ + e = _sympify(a) + if all(i.is_Rational for i in e.atoms()): + if e.is_Integer: + return continued_fraction_periodic(e, 1, 0) + elif e.is_Rational: + return continued_fraction_periodic(e.p, e.q, 0) + elif e.is_Pow and e.exp is S.Half and e.base.is_Integer: + return continued_fraction_periodic(0, 1, e.base) + elif e.is_Mul and len(e.args) == 2 and ( + e.args[0].is_Rational and + e.args[1].is_Pow and + e.args[1].base.is_Integer and + e.args[1].exp is S.Half): + a, b = e.args + return continued_fraction_periodic(0, a.q, b.base, a.p) + else: + # this should not have to work very hard- no + # simplification, cancel, etc... which should be + # done by the user. e.g. This is a fancy 1 but + # the user should simplify it first: + # sqrt(2)*(1 + sqrt(2))/(sqrt(2) + 2) + p, d = e.expand().as_numer_denom() + if d.is_Integer: + if p.is_Rational: + return continued_fraction_periodic(p, d) + # look for a + b*c + # with c = sqrt(s) + if p.is_Add and len(p.args) == 2: + a, bc = p.args + else: + a = S.Zero + bc = p + if a.is_Integer: + b = S.NaN + if bc.is_Mul and len(bc.args) == 2: + b, c = bc.args + elif bc.is_Pow: + b = Integer(1) + c = bc + if b.is_Integer and ( + c.is_Pow and c.exp is S.Half and + c.base.is_Integer): + # (a + b*sqrt(c))/d + c = c.base + return continued_fraction_periodic(a, d, c, b) + raise ValueError( + 'expecting a rational or quadratic irrational, not %s' % e) + + +def continued_fraction_periodic(p, q, d=0, s=1) -> list: + r""" + Find the periodic continued fraction expansion of a quadratic irrational. + + Compute the continued fraction expansion of a rational or a + quadratic irrational number, i.e. `\frac{p + s\sqrt{d}}{q}`, where + `p`, `q \ne 0` and `d \ge 0` are integers. + + Returns the continued fraction representation (canonical form) as + a list of integers, optionally ending (for quadratic irrationals) + with list of integers representing the repeating digits. + + Parameters + ========== + + p : int + the rational part of the number's numerator + q : int + the denominator of the number + d : int, optional + the irrational part (discriminator) of the number's numerator + s : int, optional + the coefficient of the irrational part + + Examples + ======== + + >>> from sympy.ntheory.continued_fraction import continued_fraction_periodic + >>> continued_fraction_periodic(3, 2, 7) + [2, [1, 4, 1, 1]] + + Golden ratio has the simplest continued fraction expansion: + + >>> continued_fraction_periodic(1, 2, 5) + [[1]] + + If the discriminator is zero or a perfect square then the number will be a + rational number: + + >>> continued_fraction_periodic(4, 3, 0) + [1, 3] + >>> continued_fraction_periodic(4, 3, 49) + [3, 1, 2] + + See Also + ======== + + continued_fraction_iterator, continued_fraction_reduce + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Periodic_continued_fraction + .. [2] K. Rosen. Elementary Number theory and its applications. + Addison-Wesley, 3 Sub edition, pages 379-381, January 1992. + + """ + from sympy.functions import sqrt, floor + + p, q, d, s = list(map(as_int, [p, q, d, s])) + + if d < 0: + raise ValueError("expected non-negative for `d` but got %s" % d) + + if q == 0: + raise ValueError("The denominator cannot be 0.") + + if not s: + d = 0 + + # check for rational case + sd = sqrt(d) + if sd.is_Integer: + return list(continued_fraction_iterator(Rational(p + s*sd, q))) + + # irrational case with sd != Integer + if q < 0: + p, q, s = -p, -q, -s + + n = (p + s*sd)/q + if n < 0: + w = floor(-n) + f = -n - w + one_f = continued_fraction(1 - f) # 1-f < 1 so cf is [0 ... [...]] + one_f[0] -= w + 1 + return one_f + + d *= s**2 + sd *= s + + if (d - p**2)%q: + d *= q**2 + sd *= q + p *= q + q *= q + + terms: list[int] = [] + pq = {} + + while (p, q) not in pq: + pq[(p, q)] = len(terms) + terms.append((p + sd)//q) + p = terms[-1]*q - p + q = (d - p**2)//q + + i = pq[(p, q)] + return terms[:i] + [terms[i:]] # type: ignore + + +def continued_fraction_reduce(cf): + """ + Reduce a continued fraction to a rational or quadratic irrational. + + Compute the rational or quadratic irrational number from its + terminating or periodic continued fraction expansion. The + continued fraction expansion (cf) should be supplied as a + terminating iterator supplying the terms of the expansion. For + terminating continued fractions, this is equivalent to + ``list(continued_fraction_convergents(cf))[-1]``, only a little more + efficient. If the expansion has a repeating part, a list of the + repeating terms should be returned as the last element from the + iterator. This is the format returned by + continued_fraction_periodic. + + For quadratic irrationals, returns the largest solution found, + which is generally the one sought, if the fraction is in canonical + form (all terms positive except possibly the first). + + Examples + ======== + + >>> from sympy.ntheory.continued_fraction import continued_fraction_reduce + >>> continued_fraction_reduce([1, 2, 3, 4, 5]) + 225/157 + >>> continued_fraction_reduce([-2, 1, 9, 7, 1, 2]) + -256/233 + >>> continued_fraction_reduce([2, 1, 2, 1, 1, 4, 1, 1, 6, 1, 1, 8]).n(10) + 2.718281835 + >>> continued_fraction_reduce([1, 4, 2, [3, 1]]) + (sqrt(21) + 287)/238 + >>> continued_fraction_reduce([[1]]) + (1 + sqrt(5))/2 + >>> from sympy.ntheory.continued_fraction import continued_fraction_periodic + >>> continued_fraction_reduce(continued_fraction_periodic(8, 5, 13)) + (sqrt(13) + 8)/5 + + See Also + ======== + + continued_fraction_periodic + + """ + from sympy.solvers import solve + + period = [] + x = Dummy('x') + + def untillist(cf): + for nxt in cf: + if isinstance(nxt, list): + period.extend(nxt) + yield x + break + yield nxt + + a = S.Zero + for a in continued_fraction_convergents(untillist(cf)): + pass + + if period: + y = Dummy('y') + solns = solve(continued_fraction_reduce(period + [y]) - y, y) + solns.sort() + pure = solns[-1] + rv = a.subs(x, pure).radsimp() + else: + rv = a + if rv.is_Add: + rv = factor_terms(rv) + if rv.is_Mul and rv.args[0] == -1: + rv = rv.func(*rv.args) + return rv + + +def continued_fraction_iterator(x): + """ + Return continued fraction expansion of x as iterator. + + Examples + ======== + + >>> from sympy import Rational, pi + >>> from sympy.ntheory.continued_fraction import continued_fraction_iterator + + >>> list(continued_fraction_iterator(Rational(3, 8))) + [0, 2, 1, 2] + >>> list(continued_fraction_iterator(Rational(-3, 8))) + [-1, 1, 1, 1, 2] + + >>> for i, v in enumerate(continued_fraction_iterator(pi)): + ... if i > 7: + ... break + ... print(v) + 3 + 7 + 15 + 1 + 292 + 1 + 1 + 1 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Continued_fraction + + """ + from sympy.functions import floor + while True: + i = floor(x) + yield i + x -= i + if not x: + break + x = 1/x + + +def continued_fraction_convergents(cf): + """ + Return an iterator over the convergents of a continued fraction (cf). + + The parameter should be in either of the following to forms: + - A list of partial quotients, possibly with the last element being a list + of repeating partial quotients, such as might be returned by + continued_fraction and continued_fraction_periodic. + - An iterable returning successive partial quotients of the continued + fraction, such as might be returned by continued_fraction_iterator. + + In computing the convergents, the continued fraction need not be strictly + in canonical form (all integers, all but the first positive). + Rational and negative elements may be present in the expansion. + + Examples + ======== + + >>> from sympy.core import pi + >>> from sympy import S + >>> from sympy.ntheory.continued_fraction import \ + continued_fraction_convergents, continued_fraction_iterator + + >>> list(continued_fraction_convergents([0, 2, 1, 2])) + [0, 1/2, 1/3, 3/8] + + >>> list(continued_fraction_convergents([1, S('1/2'), -7, S('1/4')])) + [1, 3, 19/5, 7] + + >>> it = continued_fraction_convergents(continued_fraction_iterator(pi)) + >>> for n in range(7): + ... print(next(it)) + 3 + 22/7 + 333/106 + 355/113 + 103993/33102 + 104348/33215 + 208341/66317 + + >>> it = continued_fraction_convergents([1, [1, 2]]) # sqrt(3) + >>> for n in range(7): + ... print(next(it)) + 1 + 2 + 5/3 + 7/4 + 19/11 + 26/15 + 71/41 + + See Also + ======== + + continued_fraction_iterator, continued_fraction, continued_fraction_periodic + + """ + if isinstance(cf, list) and isinstance(cf[-1], list): + cf = itertools.chain(cf[:-1], itertools.cycle(cf[-1])) + p_2, q_2 = S.Zero, S.One + p_1, q_1 = S.One, S.Zero + for a in cf: + p, q = a*p_1 + p_2, a*q_1 + q_2 + p_2, q_2 = p_1, q_1 + p_1, q_1 = p, q + yield p/q diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/digits.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/digits.py new file mode 100644 index 0000000000000000000000000000000000000000..a0414815871f6f888ccd2823546ab2b0c2c9f515 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/digits.py @@ -0,0 +1,150 @@ +from collections import defaultdict + +from sympy.utilities.iterables import multiset, is_palindromic as _palindromic +from sympy.utilities.misc import as_int + + +def digits(n, b=10, digits=None): + """ + Return a list of the digits of ``n`` in base ``b``. The first + element in the list is ``b`` (or ``-b`` if ``n`` is negative). + + Examples + ======== + + >>> from sympy.ntheory.digits import digits + >>> digits(35) + [10, 3, 5] + + If the number is negative, the negative sign will be placed on the + base (which is the first element in the returned list): + + >>> digits(-35) + [-10, 3, 5] + + Bases other than 10 (and greater than 1) can be selected with ``b``: + + >>> digits(27, b=2) + [2, 1, 1, 0, 1, 1] + + Use the ``digits`` keyword if a certain number of digits is desired: + + >>> digits(35, digits=4) + [10, 0, 0, 3, 5] + + Parameters + ========== + + n: integer + The number whose digits are returned. + + b: integer + The base in which digits are computed. + + digits: integer (or None for all digits) + The number of digits to be returned (padded with zeros, if + necessary). + + See Also + ======== + sympy.core.intfunc.num_digits, count_digits + """ + + b = as_int(b) + n = as_int(n) + if b < 2: + raise ValueError("b must be greater than 1") + else: + x, y = abs(n), [] + while x >= b: + x, r = divmod(x, b) + y.append(r) + y.append(x) + y.append(-b if n < 0 else b) + y.reverse() + ndig = len(y) - 1 + if digits is not None: + if ndig > digits: + raise ValueError( + "For %s, at least %s digits are needed." % (n, ndig)) + elif ndig < digits: + y[1:1] = [0]*(digits - ndig) + return y + + +def count_digits(n, b=10): + """ + Return a dictionary whose keys are the digits of ``n`` in the + given base, ``b``, with keys indicating the digits appearing in the + number and values indicating how many times that digit appeared. + + Examples + ======== + + >>> from sympy.ntheory import count_digits + + >>> count_digits(1111339) + {1: 4, 3: 2, 9: 1} + + The digits returned are always represented in base-10 + but the number itself can be entered in any format that is + understood by Python; the base of the number can also be + given if it is different than 10: + + >>> n = 0xFA; n + 250 + >>> count_digits(_) + {0: 1, 2: 1, 5: 1} + >>> count_digits(n, 16) + {10: 1, 15: 1} + + The default dictionary will return a 0 for any digit that did + not appear in the number. For example, which digits appear 7 + times in ``77!``: + + >>> from sympy import factorial + >>> c77 = count_digits(factorial(77)) + >>> [i for i in range(10) if c77[i] == 7] + [1, 3, 7, 9] + + See Also + ======== + sympy.core.intfunc.num_digits, digits + """ + rv = defaultdict(int, multiset(digits(n, b)).items()) + rv.pop(b) if b in rv else rv.pop(-b) # b or -b is there + return rv + + +def is_palindromic(n, b=10): + """return True if ``n`` is the same when read from left to right + or right to left in the given base, ``b``. + + Examples + ======== + + >>> from sympy.ntheory import is_palindromic + + >>> all(is_palindromic(i) for i in (-11, 1, 22, 121)) + True + + The second argument allows you to test numbers in other + bases. For example, 88 is palindromic in base-10 but not + in base-8: + + >>> is_palindromic(88, 8) + False + + On the other hand, a number can be palindromic in base-8 but + not in base-10: + + >>> 0o121, is_palindromic(0o121) + (81, False) + + Or it might be palindromic in both bases: + + >>> oct(121), is_palindromic(121, 8) and is_palindromic(121) + ('0o171', True) + + """ + return _palindromic(digits(n, b), 1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/ecm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/ecm.py new file mode 100644 index 0000000000000000000000000000000000000000..498c0c8fdf8478688465c4bae307818e9685b686 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/ecm.py @@ -0,0 +1,348 @@ +from math import log + +from sympy.core.random import _randint +from sympy.external.gmpy import gcd, invert, sqrt +from sympy.utilities.misc import as_int +from .generate import sieve, primerange +from .primetest import isprime + + +#----------------------------------------------------------------------------# +# # +# Lenstra's Elliptic Curve Factorization # +# # +#----------------------------------------------------------------------------# + + +class Point: + """Montgomery form of Points in an elliptic curve. + In this form, the addition and doubling of points + does not need any y-coordinate information thus + decreasing the number of operations. + Using Montgomery form we try to perform point addition + and doubling in least amount of multiplications. + + The elliptic curve used here is of the form + (E : b*y**2*z = x**3 + a*x**2*z + x*z**2). + The a_24 parameter is equal to (a + 2)/4. + + References + ========== + + .. [1] Kris Gaj, Soonhak Kwon, Patrick Baier, Paul Kohlbrenner, Hoang Le, Mohammed Khaleeluddin, Ramakrishna Bachimanchi, + Implementing the Elliptic Curve Method of Factoring in Reconfigurable Hardware, + Cryptographic Hardware and Embedded Systems - CHES 2006 (2006), pp. 119-133, + https://doi.org/10.1007/11894063_10 + https://www.hyperelliptic.org/tanja/SHARCS/talks06/Gaj.pdf + + """ + + def __init__(self, x_cord, z_cord, a_24, mod): + """ + Initial parameters for the Point class. + + Parameters + ========== + + x_cord : X coordinate of the Point + z_cord : Z coordinate of the Point + a_24 : Parameter of the elliptic curve in Montgomery form + mod : modulus + """ + self.x_cord = x_cord + self.z_cord = z_cord + self.a_24 = a_24 + self.mod = mod + + def __eq__(self, other): + """Two points are equal if X/Z of both points are equal + """ + if self.a_24 != other.a_24 or self.mod != other.mod: + return False + return self.x_cord * other.z_cord % self.mod ==\ + other.x_cord * self.z_cord % self.mod + + def add(self, Q, diff): + """ + Add two points self and Q where diff = self - Q. Moreover the assumption + is self.x_cord*Q.x_cord*(self.x_cord - Q.x_cord) != 0. This algorithm + requires 6 multiplications. Here the difference between the points + is already known and using this algorithm speeds up the addition + by reducing the number of multiplication required. Also in the + mont_ladder algorithm is constructed in a way so that the difference + between intermediate points is always equal to the initial point. + So, we always know what the difference between the point is. + + + Parameters + ========== + + Q : point on the curve in Montgomery form + diff : self - Q + + Examples + ======== + + >>> from sympy.ntheory.ecm import Point + >>> p1 = Point(11, 16, 7, 29) + >>> p2 = Point(13, 10, 7, 29) + >>> p3 = p2.add(p1, p1) + >>> p3.x_cord + 23 + >>> p3.z_cord + 17 + """ + u = (self.x_cord - self.z_cord)*(Q.x_cord + Q.z_cord) + v = (self.x_cord + self.z_cord)*(Q.x_cord - Q.z_cord) + add, subt = u + v, u - v + x_cord = diff.z_cord * add * add % self.mod + z_cord = diff.x_cord * subt * subt % self.mod + return Point(x_cord, z_cord, self.a_24, self.mod) + + def double(self): + """ + Doubles a point in an elliptic curve in Montgomery form. + This algorithm requires 5 multiplications. + + Examples + ======== + + >>> from sympy.ntheory.ecm import Point + >>> p1 = Point(11, 16, 7, 29) + >>> p2 = p1.double() + >>> p2.x_cord + 13 + >>> p2.z_cord + 10 + """ + u = pow(self.x_cord + self.z_cord, 2, self.mod) + v = pow(self.x_cord - self.z_cord, 2, self.mod) + diff = u - v + x_cord = u*v % self.mod + z_cord = diff*(v + self.a_24*diff) % self.mod + return Point(x_cord, z_cord, self.a_24, self.mod) + + def mont_ladder(self, k): + """ + Scalar multiplication of a point in Montgomery form + using Montgomery Ladder Algorithm. + A total of 11 multiplications are required in each step of this + algorithm. + + Parameters + ========== + + k : The positive integer multiplier + + Examples + ======== + + >>> from sympy.ntheory.ecm import Point + >>> p1 = Point(11, 16, 7, 29) + >>> p3 = p1.mont_ladder(3) + >>> p3.x_cord + 23 + >>> p3.z_cord + 17 + """ + Q = self + R = self.double() + for i in bin(k)[3:]: + if i == '1': + Q = R.add(Q, self) + R = R.double() + else: + R = Q.add(R, self) + Q = Q.double() + return Q + + +def _ecm_one_factor(n, B1=10000, B2=100000, max_curve=200, seed=None): + """Returns one factor of n using + Lenstra's 2 Stage Elliptic curve Factorization + with Suyama's Parameterization. Here Montgomery + arithmetic is used for fast computation of addition + and doubling of points in elliptic curve. + + Explanation + =========== + + This ECM method considers elliptic curves in Montgomery + form (E : b*y**2*z = x**3 + a*x**2*z + x*z**2) and involves + elliptic curve operations (mod N), where the elements in + Z are reduced (mod N). Since N is not a prime, E over FF(N) + is not really an elliptic curve but we can still do point additions + and doubling as if FF(N) was a field. + + Stage 1 : The basic algorithm involves taking a random point (P) on an + elliptic curve in FF(N). The compute k*P using Montgomery ladder algorithm. + Let q be an unknown factor of N. Then the order of the curve E, |E(FF(q))|, + might be a smooth number that divides k. Then we have k = l * |E(FF(q))| + for some l. For any point belonging to the curve E, |E(FF(q))|*P = O, + hence k*P = l*|E(FF(q))|*P. Thus kP.z_cord = 0 (mod q), and the unknownn + factor of N (q) can be recovered by taking gcd(kP.z_cord, N). + + Stage 2 : This is a continuation of Stage 1 if k*P != O. The idea utilize + the fact that even if kP != 0, the value of k might miss just one large + prime divisor of |E(FF(q))|. In this case we only need to compute the + scalar multiplication by p to get p*k*P = O. Here a second bound B2 + restrict the size of possible values of p. + + Parameters + ========== + + n : Number to be Factored. Assume that it is a composite number. + B1 : Stage 1 Bound. Must be an even number. + B2 : Stage 2 Bound. Must be an even number. + max_curve : Maximum number of curves generated + + Returns + ======= + + integer | None : a non-trivial divisor of ``n``. ``None`` if not found + + References + ========== + + .. [1] Carl Pomerance, Richard Crandall, Prime Numbers: A Computational Perspective, + 2nd Edition (2005), page 344, ISBN:978-0387252827 + """ + randint = _randint(seed) + + # When calculating T, if (B1 - 2*D) is negative, it cannot be calculated. + D = min(sqrt(B2), B1 // 2 - 1) + sieve.extend(D) + beta = [0] * D + S = [0] * D + k = 1 + for p in primerange(2, B1 + 1): + k *= pow(p, int(log(B1, p))) + + # Pre-calculate the prime numbers to be used in stage 2. + # Using the fact that the x-coordinates of point P and its + # inverse -P coincide, the number of primes to be checked + # in stage 2 can be reduced. + deltas_list = [] + for r in range(B1 + 2*D, B2 + 2*D, 4*D): + # d in deltas iff r+(2d+1) and/or r-(2d+1) is prime + deltas = {abs(q - r) >> 1 for q in primerange(r - 2*D, r + 2*D)} + deltas_list.append(list(deltas)) + + for _ in range(max_curve): + #Suyama's Parametrization + sigma = randint(6, n - 1) + u = (sigma**2 - 5) % n + v = (4*sigma) % n + u_3 = pow(u, 3, n) + + try: + # We use the elliptic curve y**2 = x**3 + a*x**2 + x + # where a = pow(v - u, 3, n)*(3*u + v)*invert(4*u_3*v, n) - 2 + # However, we do not declare a because it is more convenient + # to use a24 = (a + 2)*invert(4, n) in the calculation. + a24 = pow(v - u, 3, n)*(3*u + v)*invert(16*u_3*v, n) % n + except ZeroDivisionError: + #If the invert(16*u_3*v, n) doesn't exist (i.e., g != 1) + g = gcd(2*u_3*v, n) + #If g = n, try another curve + if g == n: + continue + return g + + Q = Point(u_3, pow(v, 3, n), a24, n) + Q = Q.mont_ladder(k) + g = gcd(Q.z_cord, n) + + #Stage 1 factor + if g != 1 and g != n: + return g + #Stage 1 failure. Q.z = 0, Try another curve + elif g == n: + continue + + #Stage 2 - Improved Standard Continuation + S[0] = Q + Q2 = Q.double() + S[1] = Q2.add(Q, Q) + beta[0] = (S[0].x_cord*S[0].z_cord) % n + beta[1] = (S[1].x_cord*S[1].z_cord) % n + for d in range(2, D): + S[d] = S[d - 1].add(Q2, S[d - 2]) + beta[d] = (S[d].x_cord*S[d].z_cord) % n + # i.e., S[i] = Q.mont_ladder(2*i + 1) + + g = 1 + W = Q.mont_ladder(4*D) + T = Q.mont_ladder(B1 - 2*D) + R = Q.mont_ladder(B1 + 2*D) + for deltas in deltas_list: + # R = Q.mont_ladder(r) where r in range(B1 + 2*D, B2 + 2*D, 4*D) + alpha = (R.x_cord*R.z_cord) % n + for delta in deltas: + # We want to calculate + # f = R.x_cord * S[delta].z_cord - S[delta].x_cord * R.z_cord + f = (R.x_cord - S[delta].x_cord)*\ + (R.z_cord + S[delta].z_cord) - alpha + beta[delta] + g = (g*f) % n + T, R = R, R.add(W, T) + g = gcd(n, g) + + #Stage 2 Factor found + if g != 1 and g != n: + return g + + +def ecm(n, B1=10000, B2=100000, max_curve=200, seed=1234): + """Performs factorization using Lenstra's Elliptic curve method. + + This function repeatedly calls ``_ecm_one_factor`` to compute the factors + of n. First all the small factors are taken out using trial division. + Then ``_ecm_one_factor`` is used to compute one factor at a time. + + Parameters + ========== + + n : Number to be Factored + B1 : Stage 1 Bound. Must be an even number. + B2 : Stage 2 Bound. Must be an even number. + max_curve : Maximum number of curves generated + seed : Initialize pseudorandom generator + + Examples + ======== + + >>> from sympy.ntheory import ecm + >>> ecm(25645121643901801) + {5394769, 4753701529} + >>> ecm(9804659461513846513) + {4641991, 2112166839943} + """ + from .factor_ import _perfect_power + n = as_int(n) + if B1 % 2 != 0 or B2 % 2 != 0: + raise ValueError("both bounds must be even") + TF_LIMIT = 100000 + factors = set() + for prime in sieve.primerange(2, TF_LIMIT): + if n % prime == 0: + factors.add(prime) + while(n % prime == 0): + n //= prime + + queue = [] + def check(m): + if isprime(m): + factors.add(m) + return + if result := _perfect_power(m, TF_LIMIT): + return check(result[0]) + queue.append(m) + check(n) + while queue: + n = queue.pop() + factor = _ecm_one_factor(n, B1, B2, max_curve, seed) + if factor is None: + raise ValueError("Increase the bounds") + check(factor) + check(n // factor) + return factors diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/egyptian_fraction.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/egyptian_fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..8a42540b372042f596808684fef8e3fc57935b74 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/egyptian_fraction.py @@ -0,0 +1,223 @@ +from sympy.core.containers import Tuple +from sympy.core.numbers import (Integer, Rational) +from sympy.core.singleton import S +import sympy.polys + +from math import gcd + + +def egyptian_fraction(r, algorithm="Greedy"): + """ + Return the list of denominators of an Egyptian fraction + expansion [1]_ of the said rational `r`. + + Parameters + ========== + + r : Rational or (p, q) + a positive rational number, ``p/q``. + algorithm : { "Greedy", "Graham Jewett", "Takenouchi", "Golomb" }, optional + Denotes the algorithm to be used (the default is "Greedy"). + + Examples + ======== + + >>> from sympy import Rational + >>> from sympy.ntheory.egyptian_fraction import egyptian_fraction + >>> egyptian_fraction(Rational(3, 7)) + [3, 11, 231] + >>> egyptian_fraction((3, 7), "Graham Jewett") + [7, 8, 9, 56, 57, 72, 3192] + >>> egyptian_fraction((3, 7), "Takenouchi") + [4, 7, 28] + >>> egyptian_fraction((3, 7), "Golomb") + [3, 15, 35] + >>> egyptian_fraction((11, 5), "Golomb") + [1, 2, 3, 4, 9, 234, 1118, 2580] + + See Also + ======== + + sympy.core.numbers.Rational + + Notes + ===== + + Currently the following algorithms are supported: + + 1) Greedy Algorithm + + Also called the Fibonacci-Sylvester algorithm [2]_. + At each step, extract the largest unit fraction less + than the target and replace the target with the remainder. + + It has some distinct properties: + + a) Given `p/q` in lowest terms, generates an expansion of maximum + length `p`. Even as the numerators get large, the number of + terms is seldom more than a handful. + + b) Uses minimal memory. + + c) The terms can blow up (standard examples of this are 5/121 and + 31/311). The denominator is at most squared at each step + (doubly-exponential growth) and typically exhibits + singly-exponential growth. + + 2) Graham Jewett Algorithm + + The algorithm suggested by the result of Graham and Jewett. + Note that this has a tendency to blow up: the length of the + resulting expansion is always ``2**(x/gcd(x, y)) - 1``. See [3]_. + + 3) Takenouchi Algorithm + + The algorithm suggested by Takenouchi (1921). + Differs from the Graham-Jewett algorithm only in the handling + of duplicates. See [3]_. + + 4) Golomb's Algorithm + + A method given by Golumb (1962), using modular arithmetic and + inverses. It yields the same results as a method using continued + fractions proposed by Bleicher (1972). See [4]_. + + If the given rational is greater than or equal to 1, a greedy algorithm + of summing the harmonic sequence 1/1 + 1/2 + 1/3 + ... is used, taking + all the unit fractions of this sequence until adding one more would be + greater than the given number. This list of denominators is prefixed + to the result from the requested algorithm used on the remainder. For + example, if r is 8/3, using the Greedy algorithm, we get [1, 2, 3, 4, + 5, 6, 7, 14, 420], where the beginning of the sequence, [1, 2, 3, 4, 5, + 6, 7] is part of the harmonic sequence summing to 363/140, leaving a + remainder of 31/420, which yields [14, 420] by the Greedy algorithm. + The result of egyptian_fraction(Rational(8, 3), "Golomb") is [1, 2, 3, + 4, 5, 6, 7, 14, 574, 2788, 6460, 11590, 33062, 113820], and so on. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Egyptian_fraction + .. [2] https://en.wikipedia.org/wiki/Greedy_algorithm_for_Egyptian_fractions + .. [3] https://www.ics.uci.edu/~eppstein/numth/egypt/conflict.html + .. [4] https://web.archive.org/web/20180413004012/https://ami.ektf.hu/uploads/papers/finalpdf/AMI_42_from129to134.pdf + + """ + + if not isinstance(r, Rational): + if isinstance(r, (Tuple, tuple)) and len(r) == 2: + r = Rational(*r) + else: + raise ValueError("Value must be a Rational or tuple of ints") + if r <= 0: + raise ValueError("Value must be positive") + + # common cases that all methods agree on + x, y = r.as_numer_denom() + if y == 1 and x == 2: + return [Integer(i) for i in [1, 2, 3, 6]] + if x == y + 1: + return [S.One, y] + + prefix, rem = egypt_harmonic(r) + if rem == 0: + return prefix + # work in Python ints + x, y = rem.p, rem.q + # assert x < y and gcd(x, y) = 1 + + if algorithm == "Greedy": + postfix = egypt_greedy(x, y) + elif algorithm == "Graham Jewett": + postfix = egypt_graham_jewett(x, y) + elif algorithm == "Takenouchi": + postfix = egypt_takenouchi(x, y) + elif algorithm == "Golomb": + postfix = egypt_golomb(x, y) + else: + raise ValueError("Entered invalid algorithm") + return prefix + [Integer(i) for i in postfix] + + +def egypt_greedy(x, y): + # assumes gcd(x, y) == 1 + if x == 1: + return [y] + else: + a = (-y) % x + b = y*(y//x + 1) + c = gcd(a, b) + if c > 1: + num, denom = a//c, b//c + else: + num, denom = a, b + return [y//x + 1] + egypt_greedy(num, denom) + + +def egypt_graham_jewett(x, y): + # assumes gcd(x, y) == 1 + l = [y] * x + + # l is now a list of integers whose reciprocals sum to x/y. + # we shall now proceed to manipulate the elements of l without + # changing the reciprocated sum until all elements are unique. + + while len(l) != len(set(l)): + l.sort() # so the list has duplicates. find a smallest pair + for i in range(len(l) - 1): + if l[i] == l[i + 1]: + break + # we have now identified a pair of identical + # elements: l[i] and l[i + 1]. + # now comes the application of the result of graham and jewett: + l[i + 1] = l[i] + 1 + # and we just iterate that until the list has no duplicates. + l.append(l[i]*(l[i] + 1)) + return sorted(l) + + +def egypt_takenouchi(x, y): + # assumes gcd(x, y) == 1 + # special cases for 3/y + if x == 3: + if y % 2 == 0: + return [y//2, y] + i = (y - 1)//2 + j = i + 1 + k = j + i + return [j, k, j*k] + l = [y] * x + while len(l) != len(set(l)): + l.sort() + for i in range(len(l) - 1): + if l[i] == l[i + 1]: + break + k = l[i] + if k % 2 == 0: + l[i] = l[i] // 2 + del l[i + 1] + else: + l[i], l[i + 1] = (k + 1)//2, k*(k + 1)//2 + return sorted(l) + + +def egypt_golomb(x, y): + # assumes x < y and gcd(x, y) == 1 + if x == 1: + return [y] + xp = sympy.polys.ZZ.invert(int(x), int(y)) + rv = [xp*y] + rv.extend(egypt_golomb((x*xp - 1)//y, xp)) + return sorted(rv) + + +def egypt_harmonic(r): + # assumes r is Rational + rv = [] + d = S.One + acc = S.Zero + while acc + 1/d <= r: + acc += 1/d + rv.append(d) + d += 1 + return (rv, r - acc) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/elliptic_curve.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/elliptic_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..c969470a6c19a3d17e637529b6615eeba326e84a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/elliptic_curve.py @@ -0,0 +1,397 @@ +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.polys.domains import FiniteField, QQ, RationalField, FF +from sympy.polys.polytools import Poly +from sympy.solvers.solvers import solve +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import as_int +from .factor_ import divisors +from .residue_ntheory import polynomial_congruence + + +class EllipticCurve: + """ + Create the following Elliptic Curve over domain. + + `y^{2} + a_{1} x y + a_{3} y = x^{3} + a_{2} x^{2} + a_{4} x + a_{6}` + + The default domain is ``QQ``. If no coefficient ``a1``, ``a2``, ``a3``, + is given then it creates a curve with the following form: + + `y^{2} = x^{3} + a_{4} x + a_{6}` + + Examples + ======== + + References + ========== + + .. [1] J. Silverman "A Friendly Introduction to Number Theory" Third Edition + .. [2] https://mathworld.wolfram.com/EllipticDiscriminant.html + .. [3] G. Hardy, E. Wright "An Introduction to the Theory of Numbers" Sixth Edition + + """ + + def __init__(self, a4, a6, a1=0, a2=0, a3=0, modulus=0): + if modulus == 0: + domain = QQ + else: + domain = FF(modulus) + a1, a2, a3, a4, a6 = map(domain.convert, (a1, a2, a3, a4, a6)) + self._domain = domain + self.modulus = modulus + # Calculate discriminant + b2 = a1**2 + 4 * a2 + b4 = 2 * a4 + a1 * a3 + b6 = a3**2 + 4 * a6 + b8 = a1**2 * a6 + 4 * a2 * a6 - a1 * a3 * a4 + a2 * a3**2 - a4**2 + self._b2, self._b4, self._b6, self._b8 = b2, b4, b6, b8 + self._discrim = -b2**2 * b8 - 8 * b4**3 - 27 * b6**2 + 9 * b2 * b4 * b6 + self._a1 = a1 + self._a2 = a2 + self._a3 = a3 + self._a4 = a4 + self._a6 = a6 + x, y, z = symbols('x y z') + self.x, self.y, self.z = x, y, z + self._poly = Poly(y**2*z + a1*x*y*z + a3*y*z**2 - x**3 - a2*x**2*z - a4*x*z**2 - a6*z**3, domain=domain) + if isinstance(self._domain, FiniteField): + self._rank = 0 + elif isinstance(self._domain, RationalField): + self._rank = None + + def __call__(self, x, y, z=1): + return EllipticCurvePoint(x, y, z, self) + + def __contains__(self, point): + if is_sequence(point): + if len(point) == 2: + z1 = 1 + else: + z1 = point[2] + x1, y1 = point[:2] + elif isinstance(point, EllipticCurvePoint): + x1, y1, z1 = point.x, point.y, point.z + else: + raise ValueError('Invalid point.') + if self.characteristic == 0 and z1 == 0: + return True + return self._poly.subs({self.x: x1, self.y: y1, self.z: z1}) == 0 + + def __repr__(self): + return self._poly.__repr__() + + def minimal(self): + """ + Return minimal Weierstrass equation. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + + >>> e1 = EllipticCurve(-10, -20, 0, -1, 1) + >>> e1.minimal() + Poly(-x**3 + 13392*x*z**2 + y**2*z + 1080432*z**3, x, y, z, domain='QQ') + + """ + char = self.characteristic + if char == 2: + return self + if char == 3: + return EllipticCurve(self._b4/2, self._b6/4, a2=self._b2/4, modulus=self.modulus) + c4 = self._b2**2 - 24*self._b4 + c6 = -self._b2**3 + 36*self._b2*self._b4 - 216*self._b6 + return EllipticCurve(-27*c4, -54*c6, modulus=self.modulus) + + def points(self): + """ + Return points of curve over Finite Field. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(1, 1, 1, 1, 1, modulus=5) + >>> e2.points() + {(0, 2), (1, 4), (2, 0), (2, 2), (3, 0), (3, 1), (4, 0)} + + """ + + char = self.characteristic + all_pt = set() + if char >= 1: + for i in range(char): + congruence_eq = self._poly.subs({self.x: i, self.z: 1}).expr + sol = polynomial_congruence(congruence_eq, char) + all_pt.update((i, num) for num in sol) + return all_pt + else: + raise ValueError("Infinitely many points") + + def points_x(self, x): + """Returns points on the curve for the given x-coordinate.""" + pt = [] + if self._domain == QQ: + for y in solve(self._poly.subs(self.x, x)): + pt.append((x, y)) + else: + congruence_eq = self._poly.subs({self.x: x, self.z: 1}).expr + for y in polynomial_congruence(congruence_eq, self.characteristic): + pt.append((x, y)) + return pt + + def torsion_points(self): + """ + Return torsion points of curve over Rational number. + + Return point objects those are finite order. + According to Nagell-Lutz theorem, torsion point p(x, y) + x and y are integers, either y = 0 or y**2 is divisor + of discriminent. According to Mazur's theorem, there are + at most 15 points in torsion collection. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(-43, 166) + >>> sorted(e2.torsion_points()) + [(-5, -16), (-5, 16), O, (3, -8), (3, 8), (11, -32), (11, 32)] + + """ + if self.characteristic > 0: + raise ValueError("No torsion point for Finite Field.") + l = [EllipticCurvePoint.point_at_infinity(self)] + for xx in solve(self._poly.subs({self.y: 0, self.z: 1})): + if xx.is_rational: + l.append(self(xx, 0)) + for i in divisors(self.discriminant, generator=True): + j = int(i**.5) + if j**2 == i: + for xx in solve(self._poly.subs({self.y: j, self.z: 1})): + if not xx.is_rational: + continue + p = self(xx, j) + if p.order() != oo: + l.extend([p, -p]) + return l + + @property + def characteristic(self): + """ + Return domain characteristic. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(-43, 166) + >>> e2.characteristic + 0 + + """ + return self._domain.characteristic() + + @property + def discriminant(self): + """ + Return curve discriminant. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(0, 17) + >>> e2.discriminant + -124848 + + """ + return int(self._discrim) + + @property + def is_singular(self): + """ + Return True if curve discriminant is equal to zero. + """ + return self.discriminant == 0 + + @property + def j_invariant(self): + """ + Return curve j-invariant. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e1 = EllipticCurve(-2, 0, 0, 1, 1) + >>> e1.j_invariant + 1404928/389 + + """ + c4 = self._b2**2 - 24*self._b4 + return self._domain.to_sympy(c4**3 / self._discrim) + + @property + def order(self): + """ + Number of points in Finite field. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(1, 0, modulus=19) + >>> e2.order + 19 + + """ + if self.characteristic == 0: + raise NotImplementedError("Still not implemented") + return len(self.points()) + + @property + def rank(self): + """ + Number of independent points of infinite order. + + For Finite field, it must be 0. + """ + if self._rank is not None: + return self._rank + raise NotImplementedError("Still not implemented") + + +class EllipticCurvePoint: + """ + Point of Elliptic Curve + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e1 = EllipticCurve(-17, 16) + >>> p1 = e1(0, -4, 1) + >>> p2 = e1(1, 0) + >>> p1 + p2 + (15, -56) + >>> e3 = EllipticCurve(-1, 9) + >>> e3(1, -3) * 3 + (664/169, 17811/2197) + >>> (e3(1, -3) * 3).order() + oo + >>> e2 = EllipticCurve(-2, 0, 0, 1, 1) + >>> p = e2(-1,1) + >>> q = e2(0, -1) + >>> p+q + (4, 8) + >>> p-q + (1, 0) + >>> 3*p-5*q + (328/361, -2800/6859) + """ + + @staticmethod + def point_at_infinity(curve): + return EllipticCurvePoint(0, 1, 0, curve) + + def __init__(self, x, y, z, curve): + dom = curve._domain.convert + self.x = dom(x) + self.y = dom(y) + self.z = dom(z) + self._curve = curve + self._domain = self._curve._domain + if not self._curve.__contains__(self): + raise ValueError("The curve does not contain this point") + + def __add__(self, p): + if self.z == 0: + return p + if p.z == 0: + return self + x1, y1 = self.x/self.z, self.y/self.z + x2, y2 = p.x/p.z, p.y/p.z + a1 = self._curve._a1 + a2 = self._curve._a2 + a3 = self._curve._a3 + a4 = self._curve._a4 + a6 = self._curve._a6 + if x1 != x2: + slope = (y1 - y2) / (x1 - x2) + yint = (y1 * x2 - y2 * x1) / (x2 - x1) + else: + if (y1 + y2) == 0: + return self.point_at_infinity(self._curve) + slope = (3 * x1**2 + 2*a2*x1 + a4 - a1*y1) / (a1 * x1 + a3 + 2 * y1) + yint = (-x1**3 + a4*x1 + 2*a6 - a3*y1) / (a1*x1 + a3 + 2*y1) + x3 = slope**2 + a1*slope - a2 - x1 - x2 + y3 = -(slope + a1) * x3 - yint - a3 + return self._curve(x3, y3, 1) + + def __lt__(self, other): + return (self.x, self.y, self.z) < (other.x, other.y, other.z) + + def __mul__(self, n): + n = as_int(n) + r = self.point_at_infinity(self._curve) + if n == 0: + return r + if n < 0: + return -self * -n + p = self + while n: + if n & 1: + r = r + p + n >>= 1 + p = p + p + return r + + def __rmul__(self, n): + return self * n + + def __neg__(self): + return EllipticCurvePoint(self.x, -self.y - self._curve._a1*self.x - self._curve._a3, self.z, self._curve) + + def __repr__(self): + if self.z == 0: + return 'O' + dom = self._curve._domain + try: + return '({}, {})'.format(dom.to_sympy(self.x), dom.to_sympy(self.y)) + except TypeError: + pass + return '({}, {})'.format(self.x, self.y) + + def __sub__(self, other): + return self.__add__(-other) + + def order(self): + """ + Return point order n where nP = 0. + + """ + if self.z == 0: + return 1 + if self.y == 0: # P = -P + return 2 + p = self * 2 + if p.y == -self.y: # 2P = -P + return 3 + i = 2 + if self._domain != QQ: + while int(p.x) == p.x and int(p.y) == p.y: + p = self + p + i += 1 + if p.z == 0: + return i + return oo + while p.x.numerator == p.x and p.y.numerator == p.y: + p = self + p + i += 1 + if i > 12: + return oo + if p.z == 0: + return i + return oo diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/factor_.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/factor_.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc6ac81c237f000e55014f5e170b27b41335786 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/factor_.py @@ -0,0 +1,2841 @@ +""" +Integer factorization +""" +from __future__ import annotations + +from bisect import bisect_left +from collections import defaultdict, OrderedDict +from collections.abc import MutableMapping +import math + +from sympy.core.containers import Dict +from sympy.core.mul import Mul +from sympy.core.numbers import Rational, Integer +from sympy.core.intfunc import num_digits +from sympy.core.power import Pow +from sympy.core.random import _randint +from sympy.core.singleton import S +from sympy.external.gmpy import (SYMPY_INTS, gcd, sqrt as isqrt, + sqrtrem, iroot, bit_scan1, remove) +from .primetest import isprime, MERSENNE_PRIME_EXPONENTS, is_mersenne_prime +from .generate import sieve, primerange, nextprime +from .digits import digits +from sympy.utilities.decorator import deprecated +from sympy.utilities.iterables import flatten +from sympy.utilities.misc import as_int, filldedent +from .ecm import _ecm_one_factor + + +def smoothness(n): + """ + Return the B-smooth and B-power smooth values of n. + + The smoothness of n is the largest prime factor of n; the power- + smoothness is the largest divisor raised to its multiplicity. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import smoothness + >>> smoothness(2**7*3**2) + (3, 128) + >>> smoothness(2**4*13) + (13, 16) + >>> smoothness(2) + (2, 2) + + See Also + ======== + + factorint, smoothness_p + """ + + if n == 1: + return (1, 1) # not prime, but otherwise this causes headaches + facs = factorint(n) + return max(facs), max(m**facs[m] for m in facs) + + +def smoothness_p(n, m=-1, power=0, visual=None): + """ + Return a list of [m, (p, (M, sm(p + m), psm(p + m)))...] + where: + + 1. p**M is the base-p divisor of n + 2. sm(p + m) is the smoothness of p + m (m = -1 by default) + 3. psm(p + m) is the power smoothness of p + m + + The list is sorted according to smoothness (default) or by power smoothness + if power=1. + + The smoothness of the numbers to the left (m = -1) or right (m = 1) of a + factor govern the results that are obtained from the p +/- 1 type factoring + methods. + + >>> from sympy.ntheory.factor_ import smoothness_p, factorint + >>> smoothness_p(10431, m=1) + (1, [(3, (2, 2, 4)), (19, (1, 5, 5)), (61, (1, 31, 31))]) + >>> smoothness_p(10431) + (-1, [(3, (2, 2, 2)), (19, (1, 3, 9)), (61, (1, 5, 5))]) + >>> smoothness_p(10431, power=1) + (-1, [(3, (2, 2, 2)), (61, (1, 5, 5)), (19, (1, 3, 9))]) + + If visual=True then an annotated string will be returned: + + >>> print(smoothness_p(21477639576571, visual=1)) + p**i=4410317**1 has p-1 B=1787, B-pow=1787 + p**i=4869863**1 has p-1 B=2434931, B-pow=2434931 + + This string can also be generated directly from a factorization dictionary + and vice versa: + + >>> factorint(17*9) + {3: 2, 17: 1} + >>> smoothness_p(_) + 'p**i=3**2 has p-1 B=2, B-pow=2\\np**i=17**1 has p-1 B=2, B-pow=16' + >>> smoothness_p(_) + {3: 2, 17: 1} + + The table of the output logic is: + + ====== ====== ======= ======= + | Visual + ------ ---------------------- + Input True False other + ====== ====== ======= ======= + dict str tuple str + str str tuple dict + tuple str tuple str + n str tuple tuple + mul str tuple tuple + ====== ====== ======= ======= + + See Also + ======== + + factorint, smoothness + """ + + # visual must be True, False or other (stored as None) + if visual in (1, 0): + visual = bool(visual) + elif visual not in (True, False): + visual = None + + if isinstance(n, str): + if visual: + return n + d = {} + for li in n.splitlines(): + k, v = [int(i) for i in + li.split('has')[0].split('=')[1].split('**')] + d[k] = v + if visual is not True and visual is not False: + return d + return smoothness_p(d, visual=False) + elif not isinstance(n, tuple): + facs = factorint(n, visual=False) + + if power: + k = -1 + else: + k = 1 + if isinstance(n, tuple): + rv = n + else: + rv = (m, sorted([(f, + tuple([M] + list(smoothness(f + m)))) + for f, M in list(facs.items())], + key=lambda x: (x[1][k], x[0]))) + + if visual is False or (visual is not True) and (type(n) in [int, Mul]): + return rv + lines = [] + for dat in rv[1]: + dat = flatten(dat) + dat.insert(2, m) + lines.append('p**i=%i**%i has p%+i B=%i, B-pow=%i' % tuple(dat)) + return '\n'.join(lines) + + +def multiplicity(p, n): + """ + Find the greatest integer m such that p**m divides n. + + Examples + ======== + + >>> from sympy import multiplicity, Rational + >>> [multiplicity(5, n) for n in [8, 5, 25, 125, 250]] + [0, 1, 2, 3, 3] + >>> multiplicity(3, Rational(1, 9)) + -2 + + Note: when checking for the multiplicity of a number in a + large factorial it is most efficient to send it as an unevaluated + factorial or to call ``multiplicity_in_factorial`` directly: + + >>> from sympy.ntheory import multiplicity_in_factorial + >>> from sympy import factorial + >>> p = factorial(25) + >>> n = 2**100 + >>> nfac = factorial(n, evaluate=False) + >>> multiplicity(p, nfac) + 52818775009509558395695966887 + >>> _ == multiplicity_in_factorial(p, n) + True + + See Also + ======== + + trailing + + """ + try: + p, n = as_int(p), as_int(n) + except ValueError: + from sympy.functions.combinatorial.factorials import factorial + if all(isinstance(i, (SYMPY_INTS, Rational)) for i in (p, n)): + p = Rational(p) + n = Rational(n) + if p.q == 1: + if n.p == 1: + return -multiplicity(p.p, n.q) + return multiplicity(p.p, n.p) - multiplicity(p.p, n.q) + elif p.p == 1: + return multiplicity(p.q, n.q) + else: + like = min( + multiplicity(p.p, n.p), + multiplicity(p.q, n.q)) + cross = min( + multiplicity(p.q, n.p), + multiplicity(p.p, n.q)) + return like - cross + elif (isinstance(p, (SYMPY_INTS, Integer)) and + isinstance(n, factorial) and + isinstance(n.args[0], Integer) and + n.args[0] >= 0): + return multiplicity_in_factorial(p, n.args[0]) + raise ValueError('expecting ints or fractions, got %s and %s' % (p, n)) + + if n == 0: + raise ValueError('no such integer exists: multiplicity of %s is not-defined' %(n)) + return remove(n, p)[1] + + +def multiplicity_in_factorial(p, n): + """return the largest integer ``m`` such that ``p**m`` divides ``n!`` + without calculating the factorial of ``n``. + + Parameters + ========== + + p : Integer + positive integer + n : Integer + non-negative integer + + Examples + ======== + + >>> from sympy.ntheory import multiplicity_in_factorial + >>> from sympy import factorial + + >>> multiplicity_in_factorial(2, 3) + 1 + + An instructive use of this is to tell how many trailing zeros + a given factorial has. For example, there are 6 in 25!: + + >>> factorial(25) + 15511210043330985984000000 + >>> multiplicity_in_factorial(10, 25) + 6 + + For large factorials, it is much faster/feasible to use + this function rather than computing the actual factorial: + + >>> multiplicity_in_factorial(factorial(25), 2**100) + 52818775009509558395695966887 + + See Also + ======== + + multiplicity + + """ + + p, n = as_int(p), as_int(n) + + if p <= 0: + raise ValueError('expecting positive integer got %s' % p ) + + if n < 0: + raise ValueError('expecting non-negative integer got %s' % n ) + + # keep only the largest of a given multiplicity since those + # of a given multiplicity will be goverened by the behavior + # of the largest factor + f = defaultdict(int) + for k, v in factorint(p).items(): + f[v] = max(k, f[v]) + # multiplicity of p in n! depends on multiplicity + # of prime `k` in p, so we floor divide by `v` + # and keep it if smaller than the multiplicity of p + # seen so far + return min((n + k - sum(digits(n, k)))//(k - 1)//v for v, k in f.items()) + + +def _perfect_power(n, next_p=2): + """ Return integers ``(b, e)`` such that ``n == b**e`` if ``n`` is a unique + perfect power with ``e > 1``, else ``False`` (e.g. 1 is not a perfect power). + + Explanation + =========== + + This is a low-level helper for ``perfect_power``, for internal use. + + Parameters + ========== + + n : int + assume that n is a nonnegative integer + next_p : int + Assume that n has no factor less than next_p. + i.e., all(n % p for p in range(2, next_p)) is True + + Examples + ======== + >>> from sympy.ntheory.factor_ import _perfect_power + >>> _perfect_power(16) + (2, 4) + >>> _perfect_power(17) + False + + """ + if n <= 3: + return False + + factors = {} + g = 0 + multi = 1 + + def done(n, factors, g, multi): + g = gcd(g, multi) + if g == 1: + return False + factors[n] = multi + return math.prod(p**(e//g) for p, e in factors.items()), g + + # If n is small, only trial factoring is faster + if n <= 1_000_000: + n = _factorint_small(factors, n, 1_000, 1_000, next_p)[0] + if n > 1: + return False + g = gcd(*factors.values()) + if g == 1: + return False + return math.prod(p**(e//g) for p, e in factors.items()), g + + # divide by 2 + if next_p < 3: + g = bit_scan1(n) + if g: + if g == 1: + return False + n >>= g + factors[2] = g + if n == 1: + return 2, g + else: + # If `m**g`, then we have found perfect power. + # Otherwise, there is no possibility of perfect power, especially if `g` is prime. + m, _exact = iroot(n, g) + if _exact: + return 2*m, g + elif isprime(g): + return False + next_p = 3 + + # square number? + while n & 7 == 1: # n % 8 == 1: + m, _exact = iroot(n, 2) + if _exact: + n = m + multi <<= 1 + else: + break + if n < next_p**3: + return done(n, factors, g, multi) + + # trial factoring + # Since the maximum value an exponent can take is `log_{next_p}(n)`, + # the number of exponents to be checked can be reduced by performing a trial factoring. + # The value of `tf_max` needs more consideration. + tf_max = n.bit_length()//27 + 24 + if next_p < tf_max: + for p in primerange(next_p, tf_max): + m, t = remove(n, p) + if t: + n = m + t *= multi + _g = gcd(g, t) + if _g == 1: + return False + factors[p] = t + if n == 1: + return math.prod(p**(e//_g) + for p, e in factors.items()), _g + elif g == 0 or _g < g: # If g is updated + g = _g + m, _exact = iroot(n**multi, g) + if _exact: + return m * math.prod(p**(e//g) + for p, e in factors.items()), g + elif isprime(g): + return False + next_p = tf_max + if n < next_p**3: + return done(n, factors, g, multi) + + # check iroot + if g: + # If g is non-zero, the exponent is a divisor of g. + # 2 can be omitted since it has already been checked. + prime_iter = sorted(factorint(g >> bit_scan1(g)).keys()) + else: + # The maximum possible value of the exponent is `log_{next_p}(n)`. + # To compensate for the presence of computational error, 2 is added. + prime_iter = primerange(3, int(math.log(n, next_p)) + 2) + logn = math.log2(n) + threshold = logn / 40 # Threshold for direct calculation + for p in prime_iter: + if threshold < p: + # If p is large, find the power root p directly without `iroot`. + while True: + b = pow(2, logn / p) + rb = int(b + 0.5) + if abs(rb - b) < 0.01 and rb**p == n: + n = rb + multi *= p + logn = math.log2(n) + else: + break + else: + while True: + m, _exact = iroot(n, p) + if _exact: + n = m + multi *= p + logn = math.log2(n) + else: + break + if n < next_p**(p + 2): + break + return done(n, factors, g, multi) + + +def perfect_power(n, candidates=None, big=True, factor=True): + """ + Return ``(b, e)`` such that ``n`` == ``b**e`` if ``n`` is a unique + perfect power with ``e > 1``, else ``False`` (e.g. 1 is not a + perfect power). A ValueError is raised if ``n`` is not Rational. + + By default, the base is recursively decomposed and the exponents + collected so the largest possible ``e`` is sought. If ``big=False`` + then the smallest possible ``e`` (thus prime) will be chosen. + + If ``factor=True`` then simultaneous factorization of ``n`` is + attempted since finding a factor indicates the only possible root + for ``n``. This is True by default since only a few small factors will + be tested in the course of searching for the perfect power. + + The use of ``candidates`` is primarily for internal use; if provided, + False will be returned if ``n`` cannot be written as a power with one + of the candidates as an exponent and factoring (beyond testing for + a factor of 2) will not be attempted. + + Examples + ======== + + >>> from sympy import perfect_power, Rational + >>> perfect_power(16) + (2, 4) + >>> perfect_power(16, big=False) + (4, 2) + + Negative numbers can only have odd perfect powers: + + >>> perfect_power(-4) + False + >>> perfect_power(-8) + (-2, 3) + + Rationals are also recognized: + + >>> perfect_power(Rational(1, 2)**3) + (1/2, 3) + >>> perfect_power(Rational(-3, 2)**3) + (-3/2, 3) + + Notes + ===== + + To know whether an integer is a perfect power of 2 use + + >>> is2pow = lambda n: bool(n and not n & (n - 1)) + >>> [(i, is2pow(i)) for i in range(5)] + [(0, False), (1, True), (2, True), (3, False), (4, True)] + + It is not necessary to provide ``candidates``. When provided + it will be assumed that they are ints. The first one that is + larger than the computed maximum possible exponent will signal + failure for the routine. + + >>> perfect_power(3**8, [9]) + False + >>> perfect_power(3**8, [2, 4, 8]) + (3, 8) + >>> perfect_power(3**8, [4, 8], big=False) + (9, 4) + + See Also + ======== + sympy.core.intfunc.integer_nthroot + sympy.ntheory.primetest.is_square + """ + # negative handling + if n < 0: + if candidates is None: + pp = perfect_power(-n, big=True, factor=factor) + if not pp: + return False + + b, e = pp + e2 = e & (-e) + b, e = b ** e2, e // e2 + + if e <= 1: + return False + + if big or isprime(e): + return -b, e + + for p in primerange(3, e + 1): + if e % p == 0: + return - b ** (e // p), p + + odd_candidates = {i for i in candidates if i % 2} + if not odd_candidates: + return False + + pp = perfect_power(-n, odd_candidates, big, factor) + if pp: + return -pp[0], pp[1] + + return False + + # non-integer handling + if isinstance(n, Rational) and not isinstance(n, Integer): + p, q = n.p, n.q + + if p == 1: + qq = perfect_power(q, candidates, big, factor) + return (S.One / qq[0], qq[1]) if qq is not False else False + + if not (pp:=perfect_power(p, factor=factor)): + return False + if not (qq:=perfect_power(q, factor=factor)): + return False + (num_base, num_exp), (den_base, den_exp) = pp, qq + + def compute_tuple(exponent): + """Helper to compute final result given an exponent""" + new_num = num_base ** (num_exp // exponent) + new_den = den_base ** (den_exp // exponent) + return n.func(new_num, new_den), exponent + + if candidates: + valid_candidates = [i for i in candidates + if num_exp % i == 0 and den_exp % i == 0] + if not valid_candidates: + return False + + e = max(valid_candidates) if big else min(valid_candidates) + return compute_tuple(e) + + g = math.gcd(num_exp, den_exp) + if g == 1: + return False + + if big: + return compute_tuple(g) + + e = next(p for p in primerange(2, g + 1) if g % p == 0) + return compute_tuple(e) + + if candidates is not None: + candidates = set(candidates) + + # positive integer handling + n = as_int(n) + + if candidates is None and big: + return _perfect_power(n) + + if n <= 3: + # no unique exponent for 0, 1 + # 2 and 3 have exponents of 1 + return False + logn = math.log2(n) + max_possible = int(logn) + 2 # only check values less than this + not_square = n % 10 in [2, 3, 7, 8] # squares cannot end in 2, 3, 7, 8 + min_possible = 2 + not_square + if not candidates: + candidates = primerange(min_possible, max_possible) + else: + candidates = sorted([i for i in candidates + if min_possible <= i < max_possible]) + if n%2 == 0: + e = bit_scan1(n) + candidates = [i for i in candidates if e%i == 0] + if big: + candidates = reversed(candidates) + for e in candidates: + r, ok = iroot(n, e) + if ok: + return int(r), e + return False + + def _factors(): + rv = 2 + n % 2 + while True: + yield rv + rv = nextprime(rv) + + for fac, e in zip(_factors(), candidates): + # see if there is a factor present + if factor and n % fac == 0: + # find what the potential power is + e = remove(n, fac)[1] + # if it's a trivial power we are done + if e == 1: + return False + + # maybe the e-th root of n is exact + r, exact = iroot(n, e) + if not exact: + # Having a factor, we know that e is the maximal + # possible value for a root of n. + # If n = fac**e*m can be written as a perfect + # power then see if m can be written as r**E where + # gcd(e, E) != 1 so n = (fac**(e//E)*r)**E + m = n//fac**e + rE = perfect_power(m, candidates=divisors(e, generator=True)) + if not rE: + return False + else: + r, E = rE + r, e = fac**(e//E)*r, E + if not big: + e0 = primefactors(e) + if e0[0] != e: + r, e = r**(e//e0[0]), e0[0] + return int(r), e + + # Weed out downright impossible candidates + if logn/e < 40: + b = 2.0**(logn/e) + if abs(int(b + 0.5) - b) > 0.01: + continue + + # now see if the plausible e makes a perfect power + r, exact = iroot(n, e) + if exact: + if big: + m = perfect_power(r, big=big, factor=factor) + if m: + r, e = m[0], e*m[1] + return int(r), e + + return False + + +class FactorCache(MutableMapping): + """ Provides a cache for prime factors. + ``factor_cache`` is pre-prepared as an instance of ``FactorCache``, + and ``factorint`` internally references it to speed up + the factorization of prime factors. + + While cache is automatically added during the execution of ``factorint``, + users can also manually add prime factors independently. + + >>> from sympy import factor_cache + >>> factor_cache[15] = 5 + + Furthermore, by customizing ``get_external``, + it is also possible to use external databases. + The following is an example using http://factordb.com . + + .. code-block:: python + + import requests + from sympy import factor_cache + + def get_external(self, n: int) -> list[int] | None: + res = requests.get("http://factordb.com/api", params={"query": str(n)}) + if res.status_code != requests.codes.ok: + return None + j = res.json() + if j.get("status") in ["FF", "P"]: + return list(int(p) for p, _ in j.get("factors")) + + factor_cache.get_external = get_external + + Be aware that writing this code will trigger internet access + to factordb.com when calling ``factorint``. + + """ + def __init__(self, maxsize: int | None = None): + self._cache: OrderedDict[int, int] = OrderedDict() + self.maxsize = maxsize + + def __len__(self) -> int: + return len(self._cache) + + def __contains__(self, n) -> bool: + return n in self._cache + + def __getitem__(self, n: int) -> int: + factor = self.get(n) + if factor is None: + raise KeyError(f"{n} does not exist.") + return factor + + def __setitem__(self, n: int, factor: int): + if not (1 < factor <= n and n % factor == 0 and isprime(factor)): + raise ValueError(f"{factor} is not a prime factor of {n}") + self._cache[n] = max(self._cache.get(n, 0), factor) + if self.maxsize is not None and len(self._cache) > self.maxsize: + self._cache.popitem(False) + + def __delitem__(self, n: int): + if n not in self._cache: + raise KeyError(f"{n} does not exist.") + del self._cache[n] + + def __iter__(self): + return self._cache.__iter__() + + def cache_clear(self) -> None: + """ Clear the cache """ + self._cache = OrderedDict() + + @property + def maxsize(self) -> int | None: + """ Returns the maximum cache size; if ``None``, it is unlimited. """ + return self._maxsize + + @maxsize.setter + def maxsize(self, value: int | None) -> None: + if value is not None and value <= 0: + raise ValueError("maxsize must be None or a non-negative integer.") + self._maxsize = value + if value is not None: + while len(self._cache) > value: + self._cache.popitem(False) + + def get(self, n: int, default=None): + """ Return the prime factor of ``n``. + If it does not exist in the cache, return the value of ``default``. + """ + if n <= sieve._list[-1]: + if sieve._list[bisect_left(sieve._list, n)] == n: + return n + if n in self._cache: + self._cache.move_to_end(n) + return self._cache[n] + if factors := self.get_external(n): + self.add(n, factors) + return self._cache[n] + return default + + def add(self, n: int, factors: list[int]) -> None: + for p in sorted(factors, reverse=True): + self[n] = p + n, _ = remove(n, p) + + def get_external(self, n: int) -> list[int] | None: + return None + + +factor_cache = FactorCache(maxsize=1000) + + +def pollard_rho(n, s=2, a=1, retries=5, seed=1234, max_steps=None, F=None): + r""" + Use Pollard's rho method to try to extract a nontrivial factor + of ``n``. The returned factor may be a composite number. If no + factor is found, ``None`` is returned. + + The algorithm generates pseudo-random values of x with a generator + function, replacing x with F(x). If F is not supplied then the + function x**2 + ``a`` is used. The first value supplied to F(x) is ``s``. + Upon failure (if ``retries`` is > 0) a new ``a`` and ``s`` will be + supplied; the ``a`` will be ignored if F was supplied. + + The sequence of numbers generated by such functions generally have a + a lead-up to some number and then loop around back to that number and + begin to repeat the sequence, e.g. 1, 2, 3, 4, 5, 3, 4, 5 -- this leader + and loop look a bit like the Greek letter rho, and thus the name, 'rho'. + + For a given function, very different leader-loop values can be obtained + so it is a good idea to allow for retries: + + >>> from sympy.ntheory.generate import cycle_length + >>> n = 16843009 + >>> F = lambda x:(2048*pow(x, 2, n) + 32767) % n + >>> for s in range(5): + ... print('loop length = %4i; leader length = %3i' % next(cycle_length(F, s))) + ... + loop length = 2489; leader length = 43 + loop length = 78; leader length = 121 + loop length = 1482; leader length = 100 + loop length = 1482; leader length = 286 + loop length = 1482; leader length = 101 + + Here is an explicit example where there is a three element leadup to + a sequence of 3 numbers (11, 14, 4) that then repeat: + + >>> x=2 + >>> for i in range(9): + ... print(x) + ... x=(x**2+12)%17 + ... + 2 + 16 + 13 + 11 + 14 + 4 + 11 + 14 + 4 + >>> next(cycle_length(lambda x: (x**2+12)%17, 2)) + (3, 3) + >>> list(cycle_length(lambda x: (x**2+12)%17, 2, values=True)) + [2, 16, 13, 11, 14, 4] + + Instead of checking the differences of all generated values for a gcd + with n, only the kth and 2*kth numbers are checked, e.g. 1st and 2nd, + 2nd and 4th, 3rd and 6th until it has been detected that the loop has been + traversed. Loops may be many thousands of steps long before rho finds a + factor or reports failure. If ``max_steps`` is specified, the iteration + is cancelled with a failure after the specified number of steps. + + Examples + ======== + + >>> from sympy import pollard_rho + >>> n=16843009 + >>> F=lambda x:(2048*pow(x,2,n) + 32767) % n + >>> pollard_rho(n, F=F) + 257 + + Use the default setting with a bad value of ``a`` and no retries: + + >>> pollard_rho(n, a=n-2, retries=0) + + If retries is > 0 then perhaps the problem will correct itself when + new values are generated for a: + + >>> pollard_rho(n, a=n-2, retries=1) + 257 + + References + ========== + + .. [1] Richard Crandall & Carl Pomerance (2005), "Prime Numbers: + A Computational Perspective", Springer, 2nd edition, 229-231 + + """ + n = int(n) + if n < 5: + raise ValueError('pollard_rho should receive n > 4') + randint = _randint(seed + retries) + V = s + for i in range(retries + 1): + U = V + if not F: + F = lambda x: (pow(x, 2, n) + a) % n + j = 0 + while 1: + if max_steps and (j > max_steps): + break + j += 1 + U = F(U) + V = F(F(V)) # V is 2x further along than U + g = gcd(U - V, n) + if g == 1: + continue + if g == n: + break + return int(g) + V = randint(0, n - 1) + a = randint(1, n - 3) # for x**2 + a, a%n should not be 0 or -2 + F = None + return None + + +def pollard_pm1(n, B=10, a=2, retries=0, seed=1234): + """ + Use Pollard's p-1 method to try to extract a nontrivial factor + of ``n``. Either a divisor (perhaps composite) or ``None`` is returned. + + The value of ``a`` is the base that is used in the test gcd(a**M - 1, n). + The default is 2. If ``retries`` > 0 then if no factor is found after the + first attempt, a new ``a`` will be generated randomly (using the ``seed``) + and the process repeated. + + Note: the value of M is lcm(1..B) = reduce(ilcm, range(2, B + 1)). + + A search is made for factors next to even numbers having a power smoothness + less than ``B``. Choosing a larger B increases the likelihood of finding a + larger factor but takes longer. Whether a factor of n is found or not + depends on ``a`` and the power smoothness of the even number just less than + the factor p (hence the name p - 1). + + Although some discussion of what constitutes a good ``a`` some + descriptions are hard to interpret. At the modular.math site referenced + below it is stated that if gcd(a**M - 1, n) = N then a**M % q**r is 1 + for every prime power divisor of N. But consider the following: + + >>> from sympy.ntheory.factor_ import smoothness_p, pollard_pm1 + >>> n=257*1009 + >>> smoothness_p(n) + (-1, [(257, (1, 2, 256)), (1009, (1, 7, 16))]) + + So we should (and can) find a root with B=16: + + >>> pollard_pm1(n, B=16, a=3) + 1009 + + If we attempt to increase B to 256 we find that it does not work: + + >>> pollard_pm1(n, B=256) + >>> + + But if the value of ``a`` is changed we find that only multiples of + 257 work, e.g.: + + >>> pollard_pm1(n, B=256, a=257) + 1009 + + Checking different ``a`` values shows that all the ones that did not + work had a gcd value not equal to ``n`` but equal to one of the + factors: + + >>> from sympy import ilcm, igcd, factorint, Pow + >>> M = 1 + >>> for i in range(2, 256): + ... M = ilcm(M, i) + ... + >>> set([igcd(pow(a, M, n) - 1, n) for a in range(2, 256) if + ... igcd(pow(a, M, n) - 1, n) != n]) + {1009} + + But does aM % d for every divisor of n give 1? + + >>> aM = pow(255, M, n) + >>> [(d, aM%Pow(*d.args)) for d in factorint(n, visual=True).args] + [(257**1, 1), (1009**1, 1)] + + No, only one of them. So perhaps the principle is that a root will + be found for a given value of B provided that: + + 1) the power smoothness of the p - 1 value next to the root + does not exceed B + 2) a**M % p != 1 for any of the divisors of n. + + By trying more than one ``a`` it is possible that one of them + will yield a factor. + + Examples + ======== + + With the default smoothness bound, this number cannot be cracked: + + >>> from sympy.ntheory import pollard_pm1 + >>> pollard_pm1(21477639576571) + + Increasing the smoothness bound helps: + + >>> pollard_pm1(21477639576571, B=2000) + 4410317 + + Looking at the smoothness of the factors of this number we find: + + >>> from sympy.ntheory.factor_ import smoothness_p, factorint + >>> print(smoothness_p(21477639576571, visual=1)) + p**i=4410317**1 has p-1 B=1787, B-pow=1787 + p**i=4869863**1 has p-1 B=2434931, B-pow=2434931 + + The B and B-pow are the same for the p - 1 factorizations of the divisors + because those factorizations had a very large prime factor: + + >>> factorint(4410317 - 1) + {2: 2, 617: 1, 1787: 1} + >>> factorint(4869863-1) + {2: 1, 2434931: 1} + + Note that until B reaches the B-pow value of 1787, the number is not cracked; + + >>> pollard_pm1(21477639576571, B=1786) + >>> pollard_pm1(21477639576571, B=1787) + 4410317 + + The B value has to do with the factors of the number next to the divisor, + not the divisors themselves. A worst case scenario is that the number next + to the factor p has a large prime divisisor or is a perfect power. If these + conditions apply then the power-smoothness will be about p/2 or p. The more + realistic is that there will be a large prime factor next to p requiring + a B value on the order of p/2. Although primes may have been searched for + up to this level, the p/2 is a factor of p - 1, something that we do not + know. The modular.math reference below states that 15% of numbers in the + range of 10**15 to 15**15 + 10**4 are 10**6 power smooth so a B of 10**6 + will fail 85% of the time in that range. From 10**8 to 10**8 + 10**3 the + percentages are nearly reversed...but in that range the simple trial + division is quite fast. + + References + ========== + + .. [1] Richard Crandall & Carl Pomerance (2005), "Prime Numbers: + A Computational Perspective", Springer, 2nd edition, 236-238 + .. [2] https://web.archive.org/web/20150716201437/http://modular.math.washington.edu/edu/2007/spring/ent/ent-html/node81.html + .. [3] https://www.cs.toronto.edu/~yuvalf/Factorization.pdf + """ + + n = int(n) + if n < 4 or B < 3: + raise ValueError('pollard_pm1 should receive n > 3 and B > 2') + randint = _randint(seed + B) + + # computing a**lcm(1,2,3,..B) % n for B > 2 + # it looks weird, but it's right: primes run [2, B] + # and the answer's not right until the loop is done. + for i in range(retries + 1): + aM = a + for p in sieve.primerange(2, B + 1): + e = int(math.log(B, p)) + aM = pow(aM, pow(p, e), n) + g = gcd(aM - 1, n) + if 1 < g < n: + return int(g) + + # get a new a: + # since the exponent, lcm(1..B), is even, if we allow 'a' to be 'n-1' + # then (n - 1)**even % n will be 1 which will give a g of 0 and 1 will + # give a zero, too, so we set the range as [2, n-2]. Some references + # say 'a' should be coprime to n, but either will detect factors. + a = randint(2, n - 2) + + +def _trial(factors, n, candidates, verbose=False): + """ + Helper function for integer factorization. Trial factors ``n` + against all integers given in the sequence ``candidates`` + and updates the dict ``factors`` in-place. Returns the reduced + value of ``n`` and a flag indicating whether any factors were found. + """ + if verbose: + factors0 = list(factors.keys()) + nfactors = len(factors) + for d in candidates: + if n % d == 0: + if n != d: + factor_cache[n] = d + n, m = remove(n // d, d) + factors[d] = m + 1 + if verbose: + for k in sorted(set(factors).difference(set(factors0))): + print(factor_msg % (k, factors[k])) + return int(n), len(factors) != nfactors + + +def _check_termination(factors, n, limit, use_trial, use_rho, use_pm1, + verbose, next_p): + """ + Helper function for integer factorization. Checks if ``n`` + is a prime or a perfect power, and in those cases updates the factorization. + """ + if verbose: + print('Check for termination') + if n == 1: + if verbose: + print(complete_msg) + return True + if n < next_p**2 or isprime(n): + factor_cache[n] = n + factors[int(n)] = 1 + if verbose: + print(complete_msg) + return True + + # since we've already been factoring there is no need to do + # simultaneous factoring with the power check + p = _perfect_power(n, next_p) + if not p: + return False + base, exp = p + if base < next_p**2 or isprime(base): + factor_cache[n] = base + factors[base] = exp + else: + facs = factorint(base, limit, use_trial, use_rho, use_pm1, + verbose=False) + for b, e in facs.items(): + if verbose: + print(factor_msg % (b, e)) + factors[b] = exp*e + if verbose: + print(complete_msg) + return True + + +trial_int_msg = "Trial division with ints [%i ... %i] and fail_max=%i" +trial_msg = "Trial division with primes [%i ... %i]" +rho_msg = "Pollard's rho with retries %i, max_steps %i and seed %i" +pm1_msg = "Pollard's p-1 with smoothness bound %i and seed %i" +ecm_msg = "Elliptic Curve with B1 bound %i, B2 bound %i, num_curves %i" +factor_msg = '\t%i ** %i' +fermat_msg = 'Close factors satisfying Fermat condition found.' +complete_msg = 'Factorization is complete.' + + +def _factorint_small(factors, n, limit, fail_max, next_p=2): + """ + Return the value of n and either a 0 (indicating that factorization up + to the limit was complete) or else the next near-prime that would have + been tested. + + Factoring stops if there are fail_max unsuccessful tests in a row. + + If factors of n were found they will be in the factors dictionary as + {factor: multiplicity} and the returned value of n will have had those + factors removed. The factors dictionary is modified in-place. + + """ + + def done(n, d): + """return n, d if the sqrt(n) was not reached yet, else + n, 0 indicating that factoring is done. + """ + if d*d <= n: + return n, d + return n, 0 + + limit2 = limit**2 + threshold2 = min(n, limit2) + + if next_p < 3: + if not n & 1: + m = bit_scan1(n) + factors[2] = m + n >>= m + threshold2 = min(n, limit2) + next_p = 3 + if threshold2 < 9: # next_p**2 = 9 + return done(n, next_p) + + if next_p < 5: + if not n % 3: + n //= 3 + m = 1 + while not n % 3: + n //= 3 + m += 1 + if m == 20: + n, mm = remove(n, 3) + m += mm + break + factors[3] = m + threshold2 = min(n, limit2) + next_p = 5 + if threshold2 < 25: # next_p**2 = 25 + return done(n, next_p) + + # Because of the order of checks, starting from `min_p = 6k+5`, + # useless checks are caused. + # We want to calculate + # next_p += [-1, -2, 3, 2, 1, 0][next_p % 6] + p6 = next_p % 6 + next_p += (-1 if p6 < 2 else 5) - p6 + + fails = 0 + while fails < fail_max: + # next_p % 6 == 5 + if n % next_p: + fails += 1 + else: + n //= next_p + m = 1 + while not n % next_p: + n //= next_p + m += 1 + if m == 20: + n, mm = remove(n, next_p) + m += mm + break + factors[next_p] = m + fails = 0 + threshold2 = min(n, limit2) + next_p += 2 + if threshold2 < next_p**2: + return done(n, next_p) + + # next_p % 6 == 1 + if n % next_p: + fails += 1 + else: + n //= next_p + m = 1 + while not n % next_p: + n //= next_p + m += 1 + if m == 20: + n, mm = remove(n, next_p) + m += mm + break + factors[next_p] = m + fails = 0 + threshold2 = min(n, limit2) + next_p += 4 + if threshold2 < next_p**2: + return done(n, next_p) + return done(n, next_p) + + +def factorint(n, limit=None, use_trial=True, use_rho=True, use_pm1=True, + use_ecm=True, verbose=False, visual=None, multiple=False): + r""" + Given a positive integer ``n``, ``factorint(n)`` returns a dict containing + the prime factors of ``n`` as keys and their respective multiplicities + as values. For example: + + >>> from sympy.ntheory import factorint + >>> factorint(2000) # 2000 = (2**4) * (5**3) + {2: 4, 5: 3} + >>> factorint(65537) # This number is prime + {65537: 1} + + For input less than 2, factorint behaves as follows: + + - ``factorint(1)`` returns the empty factorization, ``{}`` + - ``factorint(0)`` returns ``{0:1}`` + - ``factorint(-n)`` adds ``-1:1`` to the factors and then factors ``n`` + + Partial Factorization: + + If ``limit`` (> 3) is specified, the search is stopped after performing + trial division up to (and including) the limit (or taking a + corresponding number of rho/p-1 steps). This is useful if one has + a large number and only is interested in finding small factors (if + any). Note that setting a limit does not prevent larger factors + from being found early; it simply means that the largest factor may + be composite. Since checking for perfect power is relatively cheap, it is + done regardless of the limit setting. + + This number, for example, has two small factors and a huge + semi-prime factor that cannot be reduced easily: + + >>> from sympy.ntheory import isprime + >>> a = 1407633717262338957430697921446883 + >>> f = factorint(a, limit=10000) + >>> f == {991: 1, int(202916782076162456022877024859): 1, 7: 1} + True + >>> isprime(max(f)) + False + + This number has a small factor and a residual perfect power whose + base is greater than the limit: + + >>> factorint(3*101**7, limit=5) + {3: 1, 101: 7} + + List of Factors: + + If ``multiple`` is set to ``True`` then a list containing the + prime factors including multiplicities is returned. + + >>> factorint(24, multiple=True) + [2, 2, 2, 3] + + Visual Factorization: + + If ``visual`` is set to ``True``, then it will return a visual + factorization of the integer. For example: + + >>> from sympy import pprint + >>> pprint(factorint(4200, visual=True)) + 3 1 2 1 + 2 *3 *5 *7 + + Note that this is achieved by using the evaluate=False flag in Mul + and Pow. If you do other manipulations with an expression where + evaluate=False, it may evaluate. Therefore, you should use the + visual option only for visualization, and use the normal dictionary + returned by visual=False if you want to perform operations on the + factors. + + You can easily switch between the two forms by sending them back to + factorint: + + >>> from sympy import Mul + >>> regular = factorint(1764); regular + {2: 2, 3: 2, 7: 2} + >>> pprint(factorint(regular)) + 2 2 2 + 2 *3 *7 + + >>> visual = factorint(1764, visual=True); pprint(visual) + 2 2 2 + 2 *3 *7 + >>> print(factorint(visual)) + {2: 2, 3: 2, 7: 2} + + If you want to send a number to be factored in a partially factored form + you can do so with a dictionary or unevaluated expression: + + >>> factorint(factorint({4: 2, 12: 3})) # twice to toggle to dict form + {2: 10, 3: 3} + >>> factorint(Mul(4, 12, evaluate=False)) + {2: 4, 3: 1} + + The table of the output logic is: + + ====== ====== ======= ======= + Visual + ------ ---------------------- + Input True False other + ====== ====== ======= ======= + dict mul dict mul + n mul dict dict + mul mul dict dict + ====== ====== ======= ======= + + Notes + ===== + + Algorithm: + + The function switches between multiple algorithms. Trial division + quickly finds small factors (of the order 1-5 digits), and finds + all large factors if given enough time. The Pollard rho and p-1 + algorithms are used to find large factors ahead of time; they + will often find factors of the order of 10 digits within a few + seconds: + + >>> factors = factorint(12345678910111213141516) + >>> for base, exp in sorted(factors.items()): + ... print('%s %s' % (base, exp)) + ... + 2 2 + 2507191691 1 + 1231026625769 1 + + Any of these methods can optionally be disabled with the following + boolean parameters: + + - ``use_trial``: Toggle use of trial division + - ``use_rho``: Toggle use of Pollard's rho method + - ``use_pm1``: Toggle use of Pollard's p-1 method + + ``factorint`` also periodically checks if the remaining part is + a prime number or a perfect power, and in those cases stops. + + For unevaluated factorial, it uses Legendre's formula(theorem). + + + If ``verbose`` is set to ``True``, detailed progress is printed. + + See Also + ======== + + smoothness, smoothness_p, divisors + + """ + if isinstance(n, Dict): + n = dict(n) + if multiple: + fac = factorint(n, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose, visual=False, multiple=False) + factorlist = sum(([p] * fac[p] if fac[p] > 0 else [S.One/p]*(-fac[p]) + for p in sorted(fac)), []) + return factorlist + + factordict = {} + if visual and not isinstance(n, (Mul, dict)): + factordict = factorint(n, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose, visual=False) + elif isinstance(n, Mul): + factordict = {int(k): int(v) for k, v in + n.as_powers_dict().items()} + elif isinstance(n, dict): + factordict = n + if factordict and isinstance(n, (Mul, dict)): + # check it + for key in list(factordict.keys()): + if isprime(key): + continue + e = factordict.pop(key) + d = factorint(key, limit=limit, use_trial=use_trial, use_rho=use_rho, + use_pm1=use_pm1, verbose=verbose, visual=False) + for k, v in d.items(): + if k in factordict: + factordict[k] += v*e + else: + factordict[k] = v*e + if visual or (type(n) is dict and + visual is not True and + visual is not False): + if factordict == {}: + return S.One + if -1 in factordict: + factordict.pop(-1) + args = [S.NegativeOne] + else: + args = [] + args.extend([Pow(*i, evaluate=False) + for i in sorted(factordict.items())]) + return Mul(*args, evaluate=False) + elif isinstance(n, (dict, Mul)): + return factordict + + assert use_trial or use_rho or use_pm1 or use_ecm + + from sympy.functions.combinatorial.factorials import factorial + if isinstance(n, factorial): + x = as_int(n.args[0]) + if x >= 20: + factors = {} + m = 2 # to initialize the if condition below + for p in sieve.primerange(2, x + 1): + if m > 1: + m, q = 0, x // p + while q != 0: + m += q + q //= p + factors[p] = m + if factors and verbose: + for k in sorted(factors): + print(factor_msg % (k, factors[k])) + if verbose: + print(complete_msg) + return factors + else: + # if n < 20!, direct computation is faster + # since it uses a lookup table + n = n.func(x) + + n = as_int(n) + if limit: + limit = int(limit) + use_ecm = False + + # special cases + if n < 0: + factors = factorint( + -n, limit=limit, use_trial=use_trial, use_rho=use_rho, + use_pm1=use_pm1, verbose=verbose, visual=False) + factors[-1] = 1 + return factors + + if limit and limit < 2: + if n == 1: + return {} + return {n: 1} + elif n < 10: + # doing this we are assured of getting a limit > 2 + # when we have to compute it later + return [{0: 1}, {}, {2: 1}, {3: 1}, {2: 2}, {5: 1}, + {2: 1, 3: 1}, {7: 1}, {2: 3}, {3: 2}][n] + + factors = {} + + # do simplistic factorization + if verbose: + sn = str(n) + if len(sn) > 50: + print('Factoring %s' % sn[:5] + \ + '..(%i other digits)..' % (len(sn) - 10) + sn[-5:]) + else: + print('Factoring', n) + + # this is the preliminary factorization for small factors + # We want to guarantee that there are no small prime factors, + # so we run even if `use_trial` is False. + small = 2**15 + fail_max = 600 + small = min(small, limit or small) + if verbose: + print(trial_int_msg % (2, small, fail_max)) + n, next_p = _factorint_small(factors, n, small, fail_max) + if factors and verbose: + for k in sorted(factors): + print(factor_msg % (k, factors[k])) + if next_p == 0: + if n > 1: + factors[int(n)] = 1 + if verbose: + print(complete_msg) + return factors + # Check if it exists in the cache + while p := factor_cache.get(n): + n, e = remove(n, p) + factors[int(p)] = int(e) + # first check if the simplistic run didn't finish + # because of the limit and check for a perfect + # power before exiting + if limit and next_p > limit: + if verbose: + print('Exceeded limit:', limit) + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + if n > 1: + factors[int(n)] = 1 + return factors + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + + # continue with more advanced factorization methods + # ...do a Fermat test since it's so easy and we need the + # square root anyway. Finding 2 factors is easy if they are + # "close enough." This is the big root equivalent of dividing by + # 2, 3, 5. + sqrt_n = isqrt(n) + a = sqrt_n + 1 + # If `n % 4 == 1`, `a` must be odd for `a**2 - n` to be a square number. + if (n % 4 == 1) ^ (a & 1): + a += 1 + a2 = a**2 + b2 = a2 - n + for _ in range(3): + b, fermat = sqrtrem(b2) + if not fermat: + if verbose: + print(fermat_msg) + for r in [a - b, a + b]: + facs = factorint(r, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose) + for k, v in facs.items(): + factors[k] = factors.get(k, 0) + v + factor_cache.add(n, facs) + if verbose: + print(complete_msg) + return factors + b2 += (a + 1) << 2 # equiv to (a + 2)**2 - n + a += 2 + + # these are the limits for trial division which will + # be attempted in parallel with pollard methods + low, high = next_p, 2*next_p + + # add 1 to make sure limit is reached in primerange calls + _limit = (limit or sqrt_n) + 1 + iteration = 0 + while 1: + high_ = min(high, _limit) + + # Trial division + if use_trial: + if verbose: + print(trial_msg % (low, high_)) + ps = sieve.primerange(low, high_) + n, found_trial = _trial(factors, n, ps, verbose) + next_p = high_ + if found_trial and _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + else: + found_trial = False + + if high > _limit: + if verbose: + print('Exceeded limit:', _limit) + if n > 1: + factors[int(n)] = 1 + if verbose: + print(complete_msg) + return factors + + # Only used advanced methods when no small factors were found + if not found_trial: + # Pollard p-1 + if use_pm1: + if verbose: + print(pm1_msg % (low, high_)) + c = pollard_pm1(n, B=low, seed=high_) + if c: + if c < next_p**2 or isprime(c): + ps = [c] + else: + ps = factorint(c, limit=limit, + use_trial=use_trial, + use_rho=use_rho, + use_pm1=use_pm1, + use_ecm=use_ecm, + verbose=verbose) + n, _ = _trial(factors, n, ps, verbose=False) + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + + # Pollard rho + if use_rho: + if verbose: + print(rho_msg % (1, low, high_)) + c = pollard_rho(n, retries=1, max_steps=low, seed=high_) + if c: + if c < next_p**2 or isprime(c): + ps = [c] + else: + ps = factorint(c, limit=limit, + use_trial=use_trial, + use_rho=use_rho, + use_pm1=use_pm1, + use_ecm=use_ecm, + verbose=verbose) + n, _ = _trial(factors, n, ps, verbose=False) + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + # Use subexponential algorithms if use_ecm + # Use pollard algorithms for finding small factors for 3 iterations + # if after small factors the number of digits of n >= 25 then use ecm + iteration += 1 + if use_ecm and iteration >= 3 and num_digits(n) >= 24: + break + low, high = high, high*2 + + B1 = 10000 + B2 = 100*B1 + num_curves = 50 + while(1): + if verbose: + print(ecm_msg % (B1, B2, num_curves)) + factor = _ecm_one_factor(n, B1, B2, num_curves, seed=B1) + if factor: + if factor < next_p**2 or isprime(factor): + ps = [factor] + else: + ps = factorint(factor, limit=limit, + use_trial=use_trial, + use_rho=use_rho, + use_pm1=use_pm1, + use_ecm=use_ecm, + verbose=verbose) + n, _ = _trial(factors, n, ps, verbose=False) + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + B1 *= 5 + B2 = 100*B1 + num_curves *= 4 + + +def factorrat(rat, limit=None, use_trial=True, use_rho=True, use_pm1=True, + verbose=False, visual=None, multiple=False): + r""" + Given a Rational ``r``, ``factorrat(r)`` returns a dict containing + the prime factors of ``r`` as keys and their respective multiplicities + as values. For example: + + >>> from sympy import factorrat, S + >>> factorrat(S(8)/9) # 8/9 = (2**3) * (3**-2) + {2: 3, 3: -2} + >>> factorrat(S(-1)/987) # -1/789 = -1 * (3**-1) * (7**-1) * (47**-1) + {-1: 1, 3: -1, 7: -1, 47: -1} + + Please see the docstring for ``factorint`` for detailed explanations + and examples of the following keywords: + + - ``limit``: Integer limit up to which trial division is done + - ``use_trial``: Toggle use of trial division + - ``use_rho``: Toggle use of Pollard's rho method + - ``use_pm1``: Toggle use of Pollard's p-1 method + - ``verbose``: Toggle detailed printing of progress + - ``multiple``: Toggle returning a list of factors or dict + - ``visual``: Toggle product form of output + """ + if multiple: + fac = factorrat(rat, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose, visual=False, multiple=False) + factorlist = sum(([p] * fac[p] if fac[p] > 0 else [S.One/p]*(-fac[p]) + for p, _ in sorted(fac.items(), + key=lambda elem: elem[0] + if elem[1] > 0 + else 1/elem[0])), []) + return factorlist + + f = factorint(rat.p, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose).copy() + f = defaultdict(int, f) + for p, e in factorint(rat.q, limit=limit, + use_trial=use_trial, + use_rho=use_rho, + use_pm1=use_pm1, + verbose=verbose).items(): + f[p] += -e + + if len(f) > 1 and 1 in f: + del f[1] + if not visual: + return dict(f) + else: + if -1 in f: + f.pop(-1) + args = [S.NegativeOne] + else: + args = [] + args.extend([Pow(*i, evaluate=False) + for i in sorted(f.items())]) + return Mul(*args, evaluate=False) + + +def primefactors(n, limit=None, verbose=False, **kwargs): + """Return a sorted list of n's prime factors, ignoring multiplicity + and any composite factor that remains if the limit was set too low + for complete factorization. Unlike factorint(), primefactors() does + not return -1 or 0. + + Parameters + ========== + + n : integer + limit, verbose, **kwargs : + Additional keyword arguments to be passed to ``factorint``. + Since ``kwargs`` is new in version 1.13, + ``limit`` and ``verbose`` are retained for compatibility purposes. + + Returns + ======= + + list(int) : List of prime numbers dividing ``n`` + + Examples + ======== + + >>> from sympy.ntheory import primefactors, factorint, isprime + >>> primefactors(6) + [2, 3] + >>> primefactors(-5) + [5] + + >>> sorted(factorint(123456).items()) + [(2, 6), (3, 1), (643, 1)] + >>> primefactors(123456) + [2, 3, 643] + + >>> sorted(factorint(10000000001, limit=200).items()) + [(101, 1), (99009901, 1)] + >>> isprime(99009901) + False + >>> primefactors(10000000001, limit=300) + [101] + + See Also + ======== + + factorint, divisors + + """ + n = int(n) + kwargs.update({"visual": None, "multiple": False, + "limit": limit, "verbose": verbose}) + factors = sorted(factorint(n=n, **kwargs).keys()) + # We want to calculate + # s = [f for f in factors if isprime(f)] + s = [f for f in factors[:-1:] if f not in [-1, 0, 1]] + if factors and isprime(factors[-1]): + s += [factors[-1]] + return s + + +def _divisors(n, proper=False): + """Helper function for divisors which generates the divisors. + + Parameters + ========== + + n : int + a nonnegative integer + proper: bool + If `True`, returns the generator that outputs only the proper divisor (i.e., excluding n). + + """ + if n <= 1: + if not proper and n: + yield 1 + return + + factordict = factorint(n) + ps = sorted(factordict.keys()) + + def rec_gen(n=0): + if n == len(ps): + yield 1 + else: + pows = [1] + for _ in range(factordict[ps[n]]): + pows.append(pows[-1] * ps[n]) + yield from (p * q for q in rec_gen(n + 1) for p in pows) + + if proper: + yield from (p for p in rec_gen() if p != n) + else: + yield from rec_gen() + + +def divisors(n, generator=False, proper=False): + r""" + Return all divisors of n sorted from 1..n by default. + If generator is ``True`` an unordered generator is returned. + + The number of divisors of n can be quite large if there are many + prime factors (counting repeated factors). If only the number of + factors is desired use divisor_count(n). + + Examples + ======== + + >>> from sympy import divisors, divisor_count + >>> divisors(24) + [1, 2, 3, 4, 6, 8, 12, 24] + >>> divisor_count(24) + 8 + + >>> list(divisors(120, generator=True)) + [1, 2, 4, 8, 3, 6, 12, 24, 5, 10, 20, 40, 15, 30, 60, 120] + + Notes + ===== + + This is a slightly modified version of Tim Peters referenced at: + https://stackoverflow.com/questions/1010381/python-factorization + + See Also + ======== + + primefactors, factorint, divisor_count + """ + rv = _divisors(as_int(abs(n)), proper) + return rv if generator else sorted(rv) + + +def divisor_count(n, modulus=1, proper=False): + """ + Return the number of divisors of ``n``. If ``modulus`` is not 1 then only + those that are divisible by ``modulus`` are counted. If ``proper`` is True + then the divisor of ``n`` will not be counted. + + Examples + ======== + + >>> from sympy import divisor_count + >>> divisor_count(6) + 4 + >>> divisor_count(6, 2) + 2 + >>> divisor_count(6, proper=True) + 3 + + See Also + ======== + + factorint, divisors, totient, proper_divisor_count + + """ + + if not modulus: + return 0 + elif modulus != 1: + n, r = divmod(n, modulus) + if r: + return 0 + if n == 0: + return 0 + n = Mul(*[v + 1 for k, v in factorint(n).items() if k > 1]) + if n and proper: + n -= 1 + return n + + +def proper_divisors(n, generator=False): + """ + Return all divisors of n except n, sorted by default. + If generator is ``True`` an unordered generator is returned. + + Examples + ======== + + >>> from sympy import proper_divisors, proper_divisor_count + >>> proper_divisors(24) + [1, 2, 3, 4, 6, 8, 12] + >>> proper_divisor_count(24) + 7 + >>> list(proper_divisors(120, generator=True)) + [1, 2, 4, 8, 3, 6, 12, 24, 5, 10, 20, 40, 15, 30, 60] + + See Also + ======== + + factorint, divisors, proper_divisor_count + + """ + return divisors(n, generator=generator, proper=True) + + +def proper_divisor_count(n, modulus=1): + """ + Return the number of proper divisors of ``n``. + + Examples + ======== + + >>> from sympy import proper_divisor_count + >>> proper_divisor_count(6) + 3 + >>> proper_divisor_count(6, modulus=2) + 1 + + See Also + ======== + + divisors, proper_divisors, divisor_count + + """ + return divisor_count(n, modulus=modulus, proper=True) + + +def _udivisors(n): + """Helper function for udivisors which generates the unitary divisors. + + Parameters + ========== + + n : int + a nonnegative integer + + """ + if n <= 1: + if n == 1: + yield 1 + return + + factorpows = [p**e for p, e in factorint(n).items()] + # We want to calculate + # yield from (math.prod(s) for s in powersets(factorpows)) + for i in range(2**len(factorpows)): + d = 1 + for k in range(i.bit_length()): + if i & 1: + d *= factorpows[k] + i >>= 1 + yield d + + +def udivisors(n, generator=False): + r""" + Return all unitary divisors of n sorted from 1..n by default. + If generator is ``True`` an unordered generator is returned. + + The number of unitary divisors of n can be quite large if there are many + prime factors. If only the number of unitary divisors is desired use + udivisor_count(n). + + Examples + ======== + + >>> from sympy.ntheory.factor_ import udivisors, udivisor_count + >>> udivisors(15) + [1, 3, 5, 15] + >>> udivisor_count(15) + 4 + + >>> sorted(udivisors(120, generator=True)) + [1, 3, 5, 8, 15, 24, 40, 120] + + See Also + ======== + + primefactors, factorint, divisors, divisor_count, udivisor_count + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Unitary_divisor + .. [2] https://mathworld.wolfram.com/UnitaryDivisor.html + + """ + rv = _udivisors(as_int(abs(n))) + return rv if generator else sorted(rv) + + +def udivisor_count(n): + """ + Return the number of unitary divisors of ``n``. + + Parameters + ========== + + n : integer + + Examples + ======== + + >>> from sympy.ntheory.factor_ import udivisor_count + >>> udivisor_count(120) + 8 + + See Also + ======== + + factorint, divisors, udivisors, divisor_count, totient + + References + ========== + + .. [1] https://mathworld.wolfram.com/UnitaryDivisorFunction.html + + """ + + if n == 0: + return 0 + return 2**len([p for p in factorint(n) if p > 1]) + + +def _antidivisors(n): + """Helper function for antidivisors which generates the antidivisors. + + Parameters + ========== + + n : int + a nonnegative integer + + """ + if n <= 2: + return + for d in _divisors(n): + y = 2*d + if n > y and n % y: + yield y + for d in _divisors(2*n-1): + if n > d >= 2 and n % d: + yield d + for d in _divisors(2*n+1): + if n > d >= 2 and n % d: + yield d + + +def antidivisors(n, generator=False): + r""" + Return all antidivisors of n sorted from 1..n by default. + + Antidivisors [1]_ of n are numbers that do not divide n by the largest + possible margin. If generator is True an unordered generator is returned. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import antidivisors + >>> antidivisors(24) + [7, 16] + + >>> sorted(antidivisors(128, generator=True)) + [3, 5, 15, 17, 51, 85] + + See Also + ======== + + primefactors, factorint, divisors, divisor_count, antidivisor_count + + References + ========== + + .. [1] definition is described in https://oeis.org/A066272/a066272a.html + + """ + rv = _antidivisors(as_int(abs(n))) + return rv if generator else sorted(rv) + + +def antidivisor_count(n): + """ + Return the number of antidivisors [1]_ of ``n``. + + Parameters + ========== + + n : integer + + Examples + ======== + + >>> from sympy.ntheory.factor_ import antidivisor_count + >>> antidivisor_count(13) + 4 + >>> antidivisor_count(27) + 5 + + See Also + ======== + + factorint, divisors, antidivisors, divisor_count, totient + + References + ========== + + .. [1] formula from https://oeis.org/A066272 + + """ + + n = as_int(abs(n)) + if n <= 2: + return 0 + return divisor_count(2*n - 1) + divisor_count(2*n + 1) + \ + divisor_count(n) - divisor_count(n, 2) - 5 + +@deprecated("""\ +The `sympy.ntheory.factor_.totient` has been moved to `sympy.functions.combinatorial.numbers.totient`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def totient(n): + r""" + Calculate the Euler totient function phi(n) + + .. deprecated:: 1.13 + + The ``totient`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.totient` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + ``totient(n)`` or `\phi(n)` is the number of positive integers `\leq` n + that are relatively prime to n. + + Parameters + ========== + + n : integer + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import totient + >>> totient(1) + 1 + >>> totient(25) + 20 + >>> totient(45) == totient(5)*totient(9) + True + + See Also + ======== + + divisor_count + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler%27s_totient_function + .. [2] https://mathworld.wolfram.com/TotientFunction.html + + """ + from sympy.functions.combinatorial.numbers import totient as _totient + return _totient(n) + + +@deprecated("""\ +The `sympy.ntheory.factor_.reduced_totient` has been moved to `sympy.functions.combinatorial.numbers.reduced_totient`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def reduced_totient(n): + r""" + Calculate the Carmichael reduced totient function lambda(n) + + .. deprecated:: 1.13 + + The ``reduced_totient`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.reduced_totient` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + ``reduced_totient(n)`` or `\lambda(n)` is the smallest m > 0 such that + `k^m \equiv 1 \mod n` for all k relatively prime to n. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import reduced_totient + >>> reduced_totient(1) + 1 + >>> reduced_totient(8) + 2 + >>> reduced_totient(30) + 4 + + See Also + ======== + + totient + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Carmichael_function + .. [2] https://mathworld.wolfram.com/CarmichaelFunction.html + + """ + from sympy.functions.combinatorial.numbers import reduced_totient as _reduced_totient + return _reduced_totient(n) + + +@deprecated("""\ +The `sympy.ntheory.factor_.divisor_sigma` has been moved to `sympy.functions.combinatorial.numbers.divisor_sigma`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def divisor_sigma(n, k=1): + r""" + Calculate the divisor function `\sigma_k(n)` for positive integer n + + .. deprecated:: 1.13 + + The ``divisor_sigma`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.divisor_sigma` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + ``divisor_sigma(n, k)`` is equal to ``sum([x**k for x in divisors(n)])`` + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^\omega p_i^{m_i}, + + then + + .. math :: + \sigma_k(n) = \prod_{i=1}^\omega (1+p_i^k+p_i^{2k}+\cdots + + p_i^{m_ik}). + + Parameters + ========== + + n : integer + + k : integer, optional + power of divisors in the sum + + for k = 0, 1: + ``divisor_sigma(n, 0)`` is equal to ``divisor_count(n)`` + ``divisor_sigma(n, 1)`` is equal to ``sum(divisors(n))`` + + Default for k is 1. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import divisor_sigma + >>> divisor_sigma(18, 0) + 6 + >>> divisor_sigma(39, 1) + 56 + >>> divisor_sigma(12, 2) + 210 + >>> divisor_sigma(37) + 38 + + See Also + ======== + + divisor_count, totient, divisors, factorint + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Divisor_function + + """ + from sympy.functions.combinatorial.numbers import divisor_sigma as func_divisor_sigma + return func_divisor_sigma(n, k) + + +def _divisor_sigma(n:int, k:int=1) -> int: + r""" Calculate the divisor function `\sigma_k(n)` for positive integer n + + Parameters + ========== + + n : int + positive integer + k : int + nonnegative integer + + See Also + ======== + + sympy.functions.combinatorial.numbers.divisor_sigma + + """ + if k == 0: + return math.prod(e + 1 for e in factorint(n).values()) + return math.prod((p**(k*(e + 1)) - 1)//(p**k - 1) for p, e in factorint(n).items()) + + +def core(n, t=2): + r""" + Calculate core(n, t) = `core_t(n)` of a positive integer n + + ``core_2(n)`` is equal to the squarefree part of n + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^\omega p_i^{m_i}, + + then + + .. math :: + core_t(n) = \prod_{i=1}^\omega p_i^{m_i \mod t}. + + Parameters + ========== + + n : integer + + t : integer + core(n, t) calculates the t-th power free part of n + + ``core(n, 2)`` is the squarefree part of ``n`` + ``core(n, 3)`` is the cubefree part of ``n`` + + Default for t is 2. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import core + >>> core(24, 2) + 6 + >>> core(9424, 3) + 1178 + >>> core(379238) + 379238 + >>> core(15**11, 10) + 15 + + See Also + ======== + + factorint, sympy.solvers.diophantine.diophantine.square_factor + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Square-free_integer#Squarefree_core + + """ + + n = as_int(n) + t = as_int(t) + if n <= 0: + raise ValueError("n must be a positive integer") + elif t <= 1: + raise ValueError("t must be >= 2") + else: + y = 1 + for p, e in factorint(n).items(): + y *= p**(e % t) + return y + + +@deprecated("""\ +The `sympy.ntheory.factor_.udivisor_sigma` has been moved to `sympy.functions.combinatorial.numbers.udivisor_sigma`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def udivisor_sigma(n, k=1): + r""" + Calculate the unitary divisor function `\sigma_k^*(n)` for positive integer n + + .. deprecated:: 1.13 + + The ``udivisor_sigma`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.udivisor_sigma` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + ``udivisor_sigma(n, k)`` is equal to ``sum([x**k for x in udivisors(n)])`` + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^\omega p_i^{m_i}, + + then + + .. math :: + \sigma_k^*(n) = \prod_{i=1}^\omega (1+ p_i^{m_ik}). + + Parameters + ========== + + k : power of divisors in the sum + + for k = 0, 1: + ``udivisor_sigma(n, 0)`` is equal to ``udivisor_count(n)`` + ``udivisor_sigma(n, 1)`` is equal to ``sum(udivisors(n))`` + + Default for k is 1. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import udivisor_sigma + >>> udivisor_sigma(18, 0) + 4 + >>> udivisor_sigma(74, 1) + 114 + >>> udivisor_sigma(36, 3) + 47450 + >>> udivisor_sigma(111) + 152 + + See Also + ======== + + divisor_count, totient, divisors, udivisors, udivisor_count, divisor_sigma, + factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/UnitaryDivisorFunction.html + + """ + from sympy.functions.combinatorial.numbers import udivisor_sigma as _udivisor_sigma + return _udivisor_sigma(n, k) + + +@deprecated("""\ +The `sympy.ntheory.factor_.primenu` has been moved to `sympy.functions.combinatorial.numbers.primenu`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def primenu(n): + r""" + Calculate the number of distinct prime factors for a positive integer n. + + .. deprecated:: 1.13 + + The ``primenu`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.primenu` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^k p_i^{m_i}, + + then ``primenu(n)`` or `\nu(n)` is: + + .. math :: + \nu(n) = k. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import primenu + >>> primenu(1) + 0 + >>> primenu(30) + 3 + + See Also + ======== + + factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/PrimeFactor.html + + """ + from sympy.functions.combinatorial.numbers import primenu as _primenu + return _primenu(n) + + +@deprecated("""\ +The `sympy.ntheory.factor_.primeomega` has been moved to `sympy.functions.combinatorial.numbers.primeomega`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def primeomega(n): + r""" + Calculate the number of prime factors counting multiplicities for a + positive integer n. + + .. deprecated:: 1.13 + + The ``primeomega`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.primeomega` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^k p_i^{m_i}, + + then ``primeomega(n)`` or `\Omega(n)` is: + + .. math :: + \Omega(n) = \sum_{i=1}^k m_i. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import primeomega + >>> primeomega(1) + 0 + >>> primeomega(20) + 3 + + See Also + ======== + + factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/PrimeFactor.html + + """ + from sympy.functions.combinatorial.numbers import primeomega as _primeomega + return _primeomega(n) + + +def mersenne_prime_exponent(nth): + """Returns the exponent ``i`` for the nth Mersenne prime (which + has the form `2^i - 1`). + + Examples + ======== + + >>> from sympy.ntheory.factor_ import mersenne_prime_exponent + >>> mersenne_prime_exponent(1) + 2 + >>> mersenne_prime_exponent(20) + 4423 + """ + n = as_int(nth) + if n < 1: + raise ValueError("nth must be a positive integer; mersenne_prime_exponent(1) == 2") + if n > 51: + raise ValueError("There are only 51 perfect numbers; nth must be less than or equal to 51") + return MERSENNE_PRIME_EXPONENTS[n - 1] + + +def is_perfect(n): + """Returns True if ``n`` is a perfect number, else False. + + A perfect number is equal to the sum of its positive, proper divisors. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import divisor_sigma + >>> from sympy.ntheory.factor_ import is_perfect, divisors + >>> is_perfect(20) + False + >>> is_perfect(6) + True + >>> 6 == divisor_sigma(6) - 6 == sum(divisors(6)[:-1]) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/PerfectNumber.html + .. [2] https://en.wikipedia.org/wiki/Perfect_number + + """ + n = as_int(n) + if n < 1: + return False + if n % 2 == 0: + m = (n.bit_length() + 1) >> 1 + if (1 << (m - 1)) * ((1 << m) - 1) != n: + # Even perfect numbers must be of the form `2^{m-1}(2^m-1)` + return False + return m in MERSENNE_PRIME_EXPONENTS or is_mersenne_prime(2**m - 1) + + # n is an odd integer + if n < 10**2000: # https://www.lirmm.fr/~ochem/opn/ + return False + if n % 105 == 0: # not divis by 105 + return False + if all(n % m != r for m, r in [(12, 1), (468, 117), (324, 81)]): + return False + # there are many criteria that the factor structure of n + # must meet; since we will have to factor it to test the + # structure we will have the factors and can then check + # to see whether it is a perfect number or not. So we + # skip the structure checks and go straight to the final + # test below. + result = abundance(n) == 0 + if result: + raise ValueError(filldedent('''In 1888, Sylvester stated: " + ...a prolonged meditation on the subject has satisfied + me that the existence of any one such [odd perfect number] + -- its escape, so to say, from the complex web of conditions + which hem it in on all sides -- would be little short of a + miracle." I guess SymPy just found that miracle and it + factors like this: %s''' % factorint(n))) + return result + + +def abundance(n): + """Returns the difference between the sum of the positive + proper divisors of a number and the number. + + Examples + ======== + + >>> from sympy.ntheory import abundance, is_perfect, is_abundant + >>> abundance(6) + 0 + >>> is_perfect(6) + True + >>> abundance(10) + -2 + >>> is_abundant(10) + False + """ + return _divisor_sigma(n) - 2 * n + + +def is_abundant(n): + """Returns True if ``n`` is an abundant number, else False. + + A abundant number is smaller than the sum of its positive proper divisors. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import is_abundant + >>> is_abundant(20) + True + >>> is_abundant(15) + False + + References + ========== + + .. [1] https://mathworld.wolfram.com/AbundantNumber.html + + """ + n = as_int(n) + if is_perfect(n): + return False + return n % 6 == 0 or bool(abundance(n) > 0) + + +def is_deficient(n): + """Returns True if ``n`` is a deficient number, else False. + + A deficient number is greater than the sum of its positive proper divisors. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import is_deficient + >>> is_deficient(20) + False + >>> is_deficient(15) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/DeficientNumber.html + + """ + n = as_int(n) + if is_perfect(n): + return False + return bool(abundance(n) < 0) + + +def is_amicable(m, n): + """Returns True if the numbers `m` and `n` are "amicable", else False. + + Amicable numbers are two different numbers so related that the sum + of the proper divisors of each is equal to that of the other. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import divisor_sigma + >>> from sympy.ntheory.factor_ import is_amicable + >>> is_amicable(220, 284) + True + >>> divisor_sigma(220) == divisor_sigma(284) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Amicable_numbers + + """ + return m != n and m + n == _divisor_sigma(m) == _divisor_sigma(n) + + +def is_carmichael(n): + """ Returns True if the numbers `n` is Carmichael number, else False. + + Parameters + ========== + + n : Integer + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Carmichael_number + .. [2] https://oeis.org/A002997 + + """ + if n < 561: + return False + return n % 2 and not isprime(n) and \ + all(e == 1 and (n - 1) % (p - 1) == 0 for p, e in factorint(n).items()) + + +def find_carmichael_numbers_in_range(x, y): + """ Returns a list of the number of Carmichael in the range + + See Also + ======== + + is_carmichael + + """ + if 0 <= x <= y: + if x % 2 == 0: + return [i for i in range(x + 1, y, 2) if is_carmichael(i)] + else: + return [i for i in range(x, y, 2) if is_carmichael(i)] + else: + raise ValueError('The provided range is not valid. x and y must be non-negative integers and x <= y') + + +def find_first_n_carmichaels(n): + """ Returns the first n Carmichael numbers. + + Parameters + ========== + + n : Integer + + See Also + ======== + + is_carmichael + + """ + i = 561 + carmichaels = [] + + while len(carmichaels) < n: + if is_carmichael(i): + carmichaels.append(i) + i += 2 + + return carmichaels + + +def dra(n, b): + """ + Returns the additive digital root of a natural number ``n`` in base ``b`` + which is a single digit value obtained by an iterative process of summing + digits, on each iteration using the result from the previous iteration to + compute a digit sum. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import dra + >>> dra(3110, 12) + 8 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Digital_root + + """ + + num = abs(as_int(n)) + b = as_int(b) + if b <= 1: + raise ValueError("Base should be an integer greater than 1") + + if num == 0: + return 0 + + return (1 + (num - 1) % (b - 1)) + + +def drm(n, b): + """ + Returns the multiplicative digital root of a natural number ``n`` in a given + base ``b`` which is a single digit value obtained by an iterative process of + multiplying digits, on each iteration using the result from the previous + iteration to compute the digit multiplication. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import drm + >>> drm(9876, 10) + 0 + + >>> drm(49, 10) + 8 + + References + ========== + + .. [1] https://mathworld.wolfram.com/MultiplicativeDigitalRoot.html + + """ + + n = abs(as_int(n)) + b = as_int(b) + if b <= 1: + raise ValueError("Base should be an integer greater than 1") + while n > b: + mul = 1 + while n > 1: + n, r = divmod(n, b) + if r == 0: + return 0 + mul *= r + n = mul + return n diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/generate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..855bb44acfcb6241e6b0bcb81e7a2cfc8ced861f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/generate.py @@ -0,0 +1,1157 @@ +""" +Generating and counting primes. + +""" + +from bisect import bisect, bisect_left +from itertools import count +# Using arrays for sieving instead of lists greatly reduces +# memory consumption +from array import array as _array + +from sympy.core.random import randint +from sympy.external.gmpy import sqrt +from .primetest import isprime +from sympy.utilities.decorator import deprecated +from sympy.utilities.misc import as_int + + +def _as_int_ceiling(a): + """ Wrapping ceiling in as_int will raise an error if there was a problem + determining whether the expression was exactly an integer or not.""" + from sympy.functions.elementary.integers import ceiling + return as_int(ceiling(a)) + + +class Sieve: + """A list of prime numbers, implemented as a dynamically + growing sieve of Eratosthenes. When a lookup is requested involving + an odd number that has not been sieved, the sieve is automatically + extended up to that number. Implementation details limit the number of + primes to ``2^32-1``. + + Examples + ======== + + >>> from sympy import sieve + >>> sieve._reset() # this line for doctest only + >>> 25 in sieve + False + >>> sieve._list + array('L', [2, 3, 5, 7, 11, 13, 17, 19, 23]) + """ + + # data shared (and updated) by all Sieve instances + def __init__(self, sieve_interval=1_000_000): + """ Initial parameters for the Sieve class. + + Parameters + ========== + + sieve_interval (int): Amount of memory to be used + + Raises + ====== + + ValueError + If ``sieve_interval`` is not positive. + + """ + self._n = 6 + self._list = _array('L', [2, 3, 5, 7, 11, 13]) # primes + self._tlist = _array('L', [0, 1, 1, 2, 2, 4]) # totient + self._mlist = _array('i', [0, 1, -1, -1, 0, -1]) # mobius + if sieve_interval <= 0: + raise ValueError("sieve_interval should be a positive integer") + self.sieve_interval = sieve_interval + assert all(len(i) == self._n for i in (self._list, self._tlist, self._mlist)) + + def __repr__(self): + return ("<%s sieve (%i): %i, %i, %i, ... %i, %i\n" + "%s sieve (%i): %i, %i, %i, ... %i, %i\n" + "%s sieve (%i): %i, %i, %i, ... %i, %i>") % ( + 'prime', len(self._list), + self._list[0], self._list[1], self._list[2], + self._list[-2], self._list[-1], + 'totient', len(self._tlist), + self._tlist[0], self._tlist[1], + self._tlist[2], self._tlist[-2], self._tlist[-1], + 'mobius', len(self._mlist), + self._mlist[0], self._mlist[1], + self._mlist[2], self._mlist[-2], self._mlist[-1]) + + def _reset(self, prime=None, totient=None, mobius=None): + """Reset all caches (default). To reset one or more set the + desired keyword to True.""" + if all(i is None for i in (prime, totient, mobius)): + prime = totient = mobius = True + if prime: + self._list = self._list[:self._n] + if totient: + self._tlist = self._tlist[:self._n] + if mobius: + self._mlist = self._mlist[:self._n] + + def extend(self, n): + """Grow the sieve to cover all primes <= n. + + Examples + ======== + + >>> from sympy import sieve + >>> sieve._reset() # this line for doctest only + >>> sieve.extend(30) + >>> sieve[10] == 29 + True + """ + n = int(n) + # `num` is even at any point in the function. + # This satisfies the condition required by `self._primerange`. + num = self._list[-1] + 1 + if n < num: + return + num2 = num**2 + while num2 <= n: + self._list += _array('L', self._primerange(num, num2)) + num, num2 = num2, num2**2 + # Merge the sieves + self._list += _array('L', self._primerange(num, n + 1)) + + def _primerange(self, a, b): + """ Generate all prime numbers in the range (a, b). + + Parameters + ========== + + a, b : positive integers assuming the following conditions + * a is an even number + * 2 < self._list[-1] < a < b < nextprime(self._list[-1])**2 + + Yields + ====== + + p (int): prime numbers such that ``a < p < b`` + + Examples + ======== + + >>> from sympy.ntheory.generate import Sieve + >>> s = Sieve() + >>> s._list[-1] + 13 + >>> list(s._primerange(18, 31)) + [19, 23, 29] + + """ + if b % 2: + b -= 1 + while a < b: + block_size = min(self.sieve_interval, (b - a) // 2) + # Create the list such that block[x] iff (a + 2x + 1) is prime. + # Note that even numbers are not considered here. + block = [True] * block_size + for p in self._list[1:bisect(self._list, sqrt(a + 2 * block_size + 1))]: + for t in range((-(a + 1 + p) // 2) % p, block_size, p): + block[t] = False + for idx, p in enumerate(block): + if p: + yield a + 2 * idx + 1 + a += 2 * block_size + + def extend_to_no(self, i): + """Extend to include the ith prime number. + + Parameters + ========== + + i : integer + + Examples + ======== + + >>> from sympy import sieve + >>> sieve._reset() # this line for doctest only + >>> sieve.extend_to_no(9) + >>> sieve._list + array('L', [2, 3, 5, 7, 11, 13, 17, 19, 23]) + + Notes + ===== + + The list is extended by 50% if it is too short, so it is + likely that it will be longer than requested. + """ + i = as_int(i) + while len(self._list) < i: + self.extend(int(self._list[-1] * 1.5)) + + def primerange(self, a, b=None): + """Generate all prime numbers in the range [2, a) or [a, b). + + Examples + ======== + + >>> from sympy import sieve, prime + + All primes less than 19: + + >>> print([i for i in sieve.primerange(19)]) + [2, 3, 5, 7, 11, 13, 17] + + All primes greater than or equal to 7 and less than 19: + + >>> print([i for i in sieve.primerange(7, 19)]) + [7, 11, 13, 17] + + All primes through the 10th prime + + >>> list(sieve.primerange(prime(10) + 1)) + [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + """ + if b is None: + b = _as_int_ceiling(a) + a = 2 + else: + a = max(2, _as_int_ceiling(a)) + b = _as_int_ceiling(b) + if a >= b: + return + self.extend(b) + yield from self._list[bisect_left(self._list, a): + bisect_left(self._list, b)] + + def totientrange(self, a, b): + """Generate all totient numbers for the range [a, b). + + Examples + ======== + + >>> from sympy import sieve + >>> print([i for i in sieve.totientrange(7, 18)]) + [6, 4, 6, 4, 10, 4, 12, 6, 8, 8, 16] + """ + a = max(1, _as_int_ceiling(a)) + b = _as_int_ceiling(b) + n = len(self._tlist) + if a >= b: + return + elif b <= n: + for i in range(a, b): + yield self._tlist[i] + else: + self._tlist += _array('L', range(n, b)) + for i in range(1, n): + ti = self._tlist[i] + if ti == i - 1: + startindex = (n + i - 1) // i * i + for j in range(startindex, b, i): + self._tlist[j] -= self._tlist[j] // i + if i >= a: + yield ti + + for i in range(n, b): + ti = self._tlist[i] + if ti == i: + for j in range(i, b, i): + self._tlist[j] -= self._tlist[j] // i + if i >= a: + yield self._tlist[i] + + def mobiusrange(self, a, b): + """Generate all mobius numbers for the range [a, b). + + Parameters + ========== + + a : integer + First number in range + + b : integer + First number outside of range + + Examples + ======== + + >>> from sympy import sieve + >>> print([i for i in sieve.mobiusrange(7, 18)]) + [-1, 0, 0, 1, -1, 0, -1, 1, 1, 0, -1] + """ + a = max(1, _as_int_ceiling(a)) + b = _as_int_ceiling(b) + n = len(self._mlist) + if a >= b: + return + elif b <= n: + for i in range(a, b): + yield self._mlist[i] + else: + self._mlist += _array('i', [0]*(b - n)) + for i in range(1, n): + mi = self._mlist[i] + startindex = (n + i - 1) // i * i + for j in range(startindex, b, i): + self._mlist[j] -= mi + if i >= a: + yield mi + + for i in range(n, b): + mi = self._mlist[i] + for j in range(2 * i, b, i): + self._mlist[j] -= mi + if i >= a: + yield mi + + def search(self, n): + """Return the indices i, j of the primes that bound n. + + If n is prime then i == j. + + Although n can be an expression, if ceiling cannot convert + it to an integer then an n error will be raised. + + Examples + ======== + + >>> from sympy import sieve + >>> sieve.search(25) + (9, 10) + >>> sieve.search(23) + (9, 9) + """ + test = _as_int_ceiling(n) + n = as_int(n) + if n < 2: + raise ValueError("n should be >= 2 but got: %s" % n) + if n > self._list[-1]: + self.extend(n) + b = bisect(self._list, n) + if self._list[b - 1] == test: + return b, b + else: + return b, b + 1 + + def __contains__(self, n): + try: + n = as_int(n) + assert n >= 2 + except (ValueError, AssertionError): + return False + if n % 2 == 0: + return n == 2 + a, b = self.search(n) + return a == b + + def __iter__(self): + for n in count(1): + yield self[n] + + def __getitem__(self, n): + """Return the nth prime number""" + if isinstance(n, slice): + self.extend_to_no(n.stop) + start = n.start if n.start is not None else 0 + if start < 1: + # sieve[:5] would be empty (starting at -1), let's + # just be explicit and raise. + raise IndexError("Sieve indices start at 1.") + return self._list[start - 1:n.stop - 1:n.step] + else: + if n < 1: + # offset is one, so forbid explicit access to sieve[0] + # (would surprisingly return the last one). + raise IndexError("Sieve indices start at 1.") + n = as_int(n) + self.extend_to_no(n) + return self._list[n - 1] + +# Generate a global object for repeated use in trial division etc +sieve = Sieve() + +def prime(nth): + r""" + Return the nth prime number, where primes are indexed starting from 1: + prime(1) = 2, prime(2) = 3, etc. + + Parameters + ========== + + nth : int + The position of the prime number to return (must be a positive integer). + + Returns + ======= + + int + The nth prime number. + + Examples + ======== + + >>> from sympy import prime + >>> prime(10) + 29 + >>> prime(1) + 2 + >>> prime(100000) + 1299709 + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if a number is prime. + primerange : Generate all primes in a given range. + primepi : Return the number of primes less than or equal to a given number. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Prime_number_theorem + .. [2] https://en.wikipedia.org/wiki/Logarithmic_integral_function + .. [3] https://en.wikipedia.org/wiki/Skewes%27_number + """ + n = as_int(nth) + if n < 1: + raise ValueError("nth must be a positive integer; prime(1) == 2") + + # Check if n is within the sieve range + if n <= len(sieve._list): + return sieve[n] + + from sympy.functions.elementary.exponential import log + from sympy.functions.special.error_functions import li + + if n < 1000: + # Extend sieve up to 8*n as this is empirically sufficient + sieve.extend(8 * n) + return sieve[n] + + a = 2 + # Estimate an upper bound for the nth prime using the prime number theorem + b = int(n * (log(n).evalf() + log(log(n)).evalf())) + + # Binary search for the least m such that li(m) > n + while a < b: + mid = (a + b) >> 1 + if li(mid).evalf() > n: + b = mid + else: + a = mid + 1 + + return nextprime(a - 1, n - _primepi(a - 1)) + + +@deprecated("""\ +The `sympy.ntheory.generate.primepi` has been moved to `sympy.functions.combinatorial.numbers.primepi`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def primepi(n): + r""" Represents the prime counting function pi(n) = the number + of prime numbers less than or equal to n. + + .. deprecated:: 1.13 + + The ``primepi`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.primepi` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + Algorithm Description: + + In sieve method, we remove all multiples of prime p + except p itself. + + Let phi(i,j) be the number of integers 2 <= k <= i + which remain after sieving from primes less than + or equal to j. + Clearly, pi(n) = phi(n, sqrt(n)) + + If j is not a prime, + phi(i,j) = phi(i, j - 1) + + if j is a prime, + We remove all numbers(except j) whose + smallest prime factor is j. + + Let $x= j \times a$ be such a number, where $2 \le a \le i / j$ + Now, after sieving from primes $\le j - 1$, + a must remain + (because x, and hence a has no prime factor $\le j - 1$) + Clearly, there are phi(i / j, j - 1) such a + which remain on sieving from primes $\le j - 1$ + + Now, if a is a prime less than equal to j - 1, + $x= j \times a$ has smallest prime factor = a, and + has already been removed(by sieving from a). + So, we do not need to remove it again. + (Note: there will be pi(j - 1) such x) + + Thus, number of x, that will be removed are: + phi(i / j, j - 1) - phi(j - 1, j - 1) + (Note that pi(j - 1) = phi(j - 1, j - 1)) + + $\Rightarrow$ phi(i,j) = phi(i, j - 1) - phi(i / j, j - 1) + phi(j - 1, j - 1) + + So,following recursion is used and implemented as dp: + + phi(a, b) = phi(a, b - 1), if b is not a prime + phi(a, b) = phi(a, b-1)-phi(a / b, b-1) + phi(b-1, b-1), if b is prime + + Clearly a is always of the form floor(n / k), + which can take at most $2\sqrt{n}$ values. + Two arrays arr1,arr2 are maintained + arr1[i] = phi(i, j), + arr2[i] = phi(n // i, j) + + Finally the answer is arr2[1] + + Examples + ======== + + >>> from sympy import primepi, prime, prevprime, isprime + >>> primepi(25) + 9 + + So there are 9 primes less than or equal to 25. Is 25 prime? + + >>> isprime(25) + False + + It is not. So the first prime less than 25 must be the + 9th prime: + + >>> prevprime(25) == prime(9) + True + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if n is prime + primerange : Generate all primes in a given range + prime : Return the nth prime + """ + from sympy.functions.combinatorial.numbers import primepi as func_primepi + return func_primepi(n) + + +def _primepi(n:int) -> int: + r""" Represents the prime counting function pi(n) = the number + of prime numbers less than or equal to n. + + Explanation + =========== + + In sieve method, we remove all multiples of prime p + except p itself. + + Let phi(i,j) be the number of integers 2 <= k <= i + which remain after sieving from primes less than + or equal to j. + Clearly, pi(n) = phi(n, sqrt(n)) + + If j is not a prime, + phi(i,j) = phi(i, j - 1) + + if j is a prime, + We remove all numbers(except j) whose + smallest prime factor is j. + + Let $x= j \times a$ be such a number, where $2 \le a \le i / j$ + Now, after sieving from primes $\le j - 1$, + a must remain + (because x, and hence a has no prime factor $\le j - 1$) + Clearly, there are phi(i / j, j - 1) such a + which remain on sieving from primes $\le j - 1$ + + Now, if a is a prime less than equal to j - 1, + $x= j \times a$ has smallest prime factor = a, and + has already been removed(by sieving from a). + So, we do not need to remove it again. + (Note: there will be pi(j - 1) such x) + + Thus, number of x, that will be removed are: + phi(i / j, j - 1) - phi(j - 1, j - 1) + (Note that pi(j - 1) = phi(j - 1, j - 1)) + + $\Rightarrow$ phi(i,j) = phi(i, j - 1) - phi(i / j, j - 1) + phi(j - 1, j - 1) + + So,following recursion is used and implemented as dp: + + phi(a, b) = phi(a, b - 1), if b is not a prime + phi(a, b) = phi(a, b-1)-phi(a / b, b-1) + phi(b-1, b-1), if b is prime + + Clearly a is always of the form floor(n / k), + which can take at most $2\sqrt{n}$ values. + Two arrays arr1,arr2 are maintained + arr1[i] = phi(i, j), + arr2[i] = phi(n // i, j) + + Finally the answer is arr2[1] + + Parameters + ========== + + n : int + + """ + if n < 2: + return 0 + if n <= sieve._list[-1]: + return sieve.search(n)[0] + lim = sqrt(n) + arr1 = [(i + 1) >> 1 for i in range(lim + 1)] + arr2 = [0] + [(n//i + 1) >> 1 for i in range(1, lim + 1)] + skip = [False] * (lim + 1) + for i in range(3, lim + 1, 2): + # Presently, arr1[k]=phi(k,i - 1), + # arr2[k] = phi(n // k,i - 1) # not all k's do this + if skip[i]: + # skip if i is a composite number + continue + p = arr1[i - 1] + for j in range(i, lim + 1, i): + skip[j] = True + # update arr2 + # phi(n/j, i) = phi(n/j, i-1) - phi(n/(i*j), i-1) + phi(i-1, i-1) + for j in range(1, min(n // (i * i), lim) + 1, 2): + # No need for arr2[j] in j such that skip[j] is True to + # compute the final required arr2[1]. + if skip[j]: + continue + st = i * j + if st <= lim: + arr2[j] -= arr2[st] - p + else: + arr2[j] -= arr1[n // st] - p + # update arr1 + # phi(j, i) = phi(j, i-1) - phi(j/i, i-1) + phi(i-1, i-1) + # where the range below i**2 is fixed and + # does not need to be calculated. + for j in range(lim, min(lim, i*i - 1), -1): + arr1[j] -= arr1[j // i] - p + return arr2[1] + + +def nextprime(n, ith=1): + """ Return the ith prime greater than n. + + Parameters + ========== + + n : integer + ith : positive integer + + Returns + ======= + + int : Return the ith prime greater than n + + Raises + ====== + + ValueError + If ``ith <= 0``. + If ``n`` or ``ith`` is not an integer. + + Notes + ===== + + Potential primes are located at 6*j +/- 1. This + property is used during searching. + + >>> from sympy import nextprime + >>> [(i, nextprime(i)) for i in range(10, 15)] + [(10, 11), (11, 13), (12, 13), (13, 17), (14, 17)] + >>> nextprime(2, ith=2) # the 2nd prime after 2 + 5 + + See Also + ======== + + prevprime : Return the largest prime smaller than n + primerange : Generate all primes in a given range + + """ + n = int(n) + i = as_int(ith) + if i <= 0: + raise ValueError("ith should be positive") + if n < 2: + n = 2 + i -= 1 + if n <= sieve._list[-2]: + l, _ = sieve.search(n) + if l + i - 1 < len(sieve._list): + return sieve._list[l + i - 1] + n = sieve._list[-1] + i += l - len(sieve._list) + nn = 6*(n//6) + if nn == n: + n += 1 + if isprime(n): + i -= 1 + if not i: + return n + n += 4 + elif n - nn == 5: + n += 2 + if isprime(n): + i -= 1 + if not i: + return n + n += 4 + else: + n = nn + 5 + while 1: + if isprime(n): + i -= 1 + if not i: + return n + n += 2 + if isprime(n): + i -= 1 + if not i: + return n + n += 4 + + +def prevprime(n): + """ Return the largest prime smaller than n. + + Notes + ===== + + Potential primes are located at 6*j +/- 1. This + property is used during searching. + + >>> from sympy import prevprime + >>> [(i, prevprime(i)) for i in range(10, 15)] + [(10, 7), (11, 7), (12, 11), (13, 11), (14, 13)] + + See Also + ======== + + nextprime : Return the ith prime greater than n + primerange : Generates all primes in a given range + """ + n = _as_int_ceiling(n) + if n < 3: + raise ValueError("no preceding primes") + if n < 8: + return {3: 2, 4: 3, 5: 3, 6: 5, 7: 5}[n] + if n <= sieve._list[-1]: + l, u = sieve.search(n) + if l == u: + return sieve[l-1] + else: + return sieve[l] + nn = 6*(n//6) + if n - nn <= 1: + n = nn - 1 + if isprime(n): + return n + n -= 4 + else: + n = nn + 1 + while 1: + if isprime(n): + return n + n -= 2 + if isprime(n): + return n + n -= 4 + + +def primerange(a, b=None): + """ Generate a list of all prime numbers in the range [2, a), + or [a, b). + + If the range exists in the default sieve, the values will + be returned from there; otherwise values will be returned + but will not modify the sieve. + + Examples + ======== + + >>> from sympy import primerange, prime + + All primes less than 19: + + >>> list(primerange(19)) + [2, 3, 5, 7, 11, 13, 17] + + All primes greater than or equal to 7 and less than 19: + + >>> list(primerange(7, 19)) + [7, 11, 13, 17] + + All primes through the 10th prime + + >>> list(primerange(prime(10) + 1)) + [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + The Sieve method, primerange, is generally faster but it will + occupy more memory as the sieve stores values. The default + instance of Sieve, named sieve, can be used: + + >>> from sympy import sieve + >>> list(sieve.primerange(1, 30)) + [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + Notes + ===== + + Some famous conjectures about the occurrence of primes in a given + range are [1]: + + - Twin primes: though often not, the following will give 2 primes + an infinite number of times: + primerange(6*n - 1, 6*n + 2) + - Legendre's: the following always yields at least one prime + primerange(n**2, (n+1)**2+1) + - Bertrand's (proven): there is always a prime in the range + primerange(n, 2*n) + - Brocard's: there are at least four primes in the range + primerange(prime(n)**2, prime(n+1)**2) + + The average gap between primes is log(n) [2]; the gap between + primes can be arbitrarily large since sequences of composite + numbers are arbitrarily large, e.g. the numbers in the sequence + n! + 2, n! + 3 ... n! + n are all composite. + + See Also + ======== + + prime : Return the nth prime + nextprime : Return the ith prime greater than n + prevprime : Return the largest prime smaller than n + randprime : Returns a random prime in a given range + primorial : Returns the product of primes based on condition + Sieve.primerange : return range from already computed primes + or extend the sieve to contain the requested + range. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Prime_number + .. [2] https://primes.utm.edu/notes/gaps.html + """ + if b is None: + a, b = 2, a + if a >= b: + return + # If we already have the range, return it. + largest_known_prime = sieve._list[-1] + if b <= largest_known_prime: + yield from sieve.primerange(a, b) + return + # If we know some of it, return it. + if a <= largest_known_prime: + yield from sieve._list[bisect_left(sieve._list, a):] + a = largest_known_prime + 1 + elif a % 2: + a -= 1 + tail = min(b, (largest_known_prime)**2) + if a < tail: + yield from sieve._primerange(a, tail) + a = tail + if b <= a: + return + # otherwise compute, without storing, the desired range. + while 1: + a = nextprime(a) + if a < b: + yield a + else: + return + + +def randprime(a, b): + """ Return a random prime number in the range [a, b). + + Bertrand's postulate assures that + randprime(a, 2*a) will always succeed for a > 1. + + Note that due to implementation difficulties, + the prime numbers chosen are not uniformly random. + For example, there are two primes in the range [112, 128), + ``113`` and ``127``, but ``randprime(112, 128)`` returns ``127`` + with a probability of 15/17. + + Examples + ======== + + >>> from sympy import randprime, isprime + >>> randprime(1, 30) #doctest: +SKIP + 13 + >>> isprime(randprime(1, 30)) + True + + See Also + ======== + + primerange : Generate all primes in a given range + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bertrand's_postulate + + """ + if a >= b: + return + a, b = map(int, (a, b)) + n = randint(a - 1, b) + p = nextprime(n) + if p >= b: + p = prevprime(b) + if p < a: + raise ValueError("no primes exist in the specified range") + return p + + +def primorial(n, nth=True): + """ + Returns the product of the first n primes (default) or + the primes less than or equal to n (when ``nth=False``). + + Examples + ======== + + >>> from sympy.ntheory.generate import primorial, primerange + >>> from sympy import factorint, Mul, primefactors, sqrt + >>> primorial(4) # the first 4 primes are 2, 3, 5, 7 + 210 + >>> primorial(4, nth=False) # primes <= 4 are 2 and 3 + 6 + >>> primorial(1) + 2 + >>> primorial(1, nth=False) + 1 + >>> primorial(sqrt(101), nth=False) + 210 + + One can argue that the primes are infinite since if you take + a set of primes and multiply them together (e.g. the primorial) and + then add or subtract 1, the result cannot be divided by any of the + original factors, hence either 1 or more new primes must divide this + product of primes. + + In this case, the number itself is a new prime: + + >>> factorint(primorial(4) + 1) + {211: 1} + + In this case two new primes are the factors: + + >>> factorint(primorial(4) - 1) + {11: 1, 19: 1} + + Here, some primes smaller and larger than the primes multiplied together + are obtained: + + >>> p = list(primerange(10, 20)) + >>> sorted(set(primefactors(Mul(*p) + 1)).difference(set(p))) + [2, 5, 31, 149] + + See Also + ======== + + primerange : Generate all primes in a given range + + """ + if nth: + n = as_int(n) + else: + n = int(n) + if n < 1: + raise ValueError("primorial argument must be >= 1") + p = 1 + if nth: + for i in range(1, n + 1): + p *= prime(i) + else: + for i in primerange(2, n + 1): + p *= i + return p + + +def cycle_length(f, x0, nmax=None, values=False): + """For a given iterated sequence, return a generator that gives + the length of the iterated cycle (lambda) and the length of terms + before the cycle begins (mu); if ``values`` is True then the + terms of the sequence will be returned instead. The sequence is + started with value ``x0``. + + Note: more than the first lambda + mu terms may be returned and this + is the cost of cycle detection with Brent's method; there are, however, + generally less terms calculated than would have been calculated if the + proper ending point were determined, e.g. by using Floyd's method. + + >>> from sympy.ntheory.generate import cycle_length + + This will yield successive values of i <-- func(i): + + >>> def gen(func, i): + ... while 1: + ... yield i + ... i = func(i) + ... + + A function is defined: + + >>> func = lambda i: (i**2 + 1) % 51 + + and given a seed of 4 and the mu and lambda terms calculated: + + >>> next(cycle_length(func, 4)) + (6, 3) + + We can see what is meant by looking at the output: + + >>> iter = cycle_length(func, 4, values=True) + >>> list(iter) + [4, 17, 35, 2, 5, 26, 14, 44, 50, 2, 5, 26, 14] + + There are 6 repeating values after the first 3. + + If a sequence is suspected of being longer than you might wish, ``nmax`` + can be used to exit early (and mu will be returned as None): + + >>> next(cycle_length(func, 4, nmax = 4)) + (4, None) + >>> list(cycle_length(func, 4, nmax = 4, values=True)) + [4, 17, 35, 2] + + Code modified from: + https://en.wikipedia.org/wiki/Cycle_detection. + """ + + nmax = int(nmax or 0) + + # main phase: search successive powers of two + power = lam = 1 + tortoise, hare = x0, f(x0) # f(x0) is the element/node next to x0. + i = 1 + if values: + yield tortoise + while tortoise != hare and (not nmax or i < nmax): + i += 1 + if power == lam: # time to start a new power of two? + tortoise = hare + power *= 2 + lam = 0 + if values: + yield hare + hare = f(hare) + lam += 1 + if nmax and i == nmax: + if values: + return + else: + yield nmax, None + return + if not values: + # Find the position of the first repetition of length lambda + mu = 0 + tortoise = hare = x0 + for i in range(lam): + hare = f(hare) + while tortoise != hare: + tortoise = f(tortoise) + hare = f(hare) + mu += 1 + yield lam, mu + + +def composite(nth): + """ Return the nth composite number, with the composite numbers indexed as + composite(1) = 4, composite(2) = 6, etc.... + + Examples + ======== + + >>> from sympy import composite + >>> composite(36) + 52 + >>> composite(1) + 4 + >>> composite(17737) + 20000 + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if n is prime + primerange : Generate all primes in a given range + primepi : Return the number of primes less than or equal to n + prime : Return the nth prime + compositepi : Return the number of positive composite numbers less than or equal to n + """ + n = as_int(nth) + if n < 1: + raise ValueError("nth must be a positive integer; composite(1) == 4") + composite_arr = [4, 6, 8, 9, 10, 12, 14, 15, 16, 18] + if n <= 10: + return composite_arr[n - 1] + + a, b = 4, sieve._list[-1] + if n <= b - _primepi(b) - 1: + while a < b - 1: + mid = (a + b) >> 1 + if mid - _primepi(mid) - 1 > n: + b = mid + else: + a = mid + if isprime(a): + a -= 1 + return a + + from sympy.functions.elementary.exponential import log + from sympy.functions.special.error_functions import li + a = 4 # Lower bound for binary search + b = int(n*(log(n) + log(log(n)))) # Upper bound for the search. + + while a < b: + mid = (a + b) >> 1 + if mid - li(mid) - 1 > n: + b = mid + else: + a = mid + 1 + + n_composites = a - _primepi(a) - 1 + while n_composites > n: + if not isprime(a): + n_composites -= 1 + a -= 1 + if isprime(a): + a -= 1 + return a + + +def compositepi(n): + """ Return the number of positive composite numbers less than or equal to n. + The first positive composite is 4, i.e. compositepi(4) = 1. + + Examples + ======== + + >>> from sympy import compositepi + >>> compositepi(25) + 15 + >>> compositepi(1000) + 831 + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if n is prime + primerange : Generate all primes in a given range + prime : Return the nth prime + primepi : Return the number of primes less than or equal to n + composite : Return the nth composite number + """ + n = int(n) + if n < 4: + return 0 + return n - _primepi(n) - 1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/modular.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/modular.py new file mode 100644 index 0000000000000000000000000000000000000000..628a3d8c5a7fb4b6c51ad337df66d74f90282496 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/modular.py @@ -0,0 +1,291 @@ +from math import prod + +from sympy.external.gmpy import gcd, gcdext +from sympy.ntheory.primetest import isprime +from sympy.polys.domains import ZZ +from sympy.polys.galoistools import gf_crt, gf_crt1, gf_crt2 +from sympy.utilities.misc import as_int + + +def symmetric_residue(a, m): + """Return the residual mod m such that it is within half of the modulus. + + >>> from sympy.ntheory.modular import symmetric_residue + >>> symmetric_residue(1, 6) + 1 + >>> symmetric_residue(4, 6) + -2 + """ + if a <= m // 2: + return a + return a - m + + +def crt(m, v, symmetric=False, check=True): + r"""Chinese Remainder Theorem. + + The moduli in m are assumed to be pairwise coprime. The output + is then an integer f, such that f = v_i mod m_i for each pair out + of v and m. If ``symmetric`` is False a positive integer will be + returned, else \|f\| will be less than or equal to the LCM of the + moduli, and thus f may be negative. + + If the moduli are not co-prime the correct result will be returned + if/when the test of the result is found to be incorrect. This result + will be None if there is no solution. + + The keyword ``check`` can be set to False if it is known that the moduli + are coprime. + + Examples + ======== + + As an example consider a set of residues ``U = [49, 76, 65]`` + and a set of moduli ``M = [99, 97, 95]``. Then we have:: + + >>> from sympy.ntheory.modular import crt + + >>> crt([99, 97, 95], [49, 76, 65]) + (639985, 912285) + + This is the correct result because:: + + >>> [639985 % m for m in [99, 97, 95]] + [49, 76, 65] + + If the moduli are not co-prime, you may receive an incorrect result + if you use ``check=False``: + + >>> crt([12, 6, 17], [3, 4, 2], check=False) + (954, 1224) + >>> [954 % m for m in [12, 6, 17]] + [6, 0, 2] + >>> crt([12, 6, 17], [3, 4, 2]) is None + True + >>> crt([3, 6], [2, 5]) + (5, 6) + + Note: the order of gf_crt's arguments is reversed relative to crt, + and that solve_congruence takes residue, modulus pairs. + + Programmer's note: rather than checking that all pairs of moduli share + no GCD (an O(n**2) test) and rather than factoring all moduli and seeing + that there is no factor in common, a check that the result gives the + indicated residuals is performed -- an O(n) operation. + + See Also + ======== + + solve_congruence + sympy.polys.galoistools.gf_crt : low level crt routine used by this routine + """ + if check: + m = list(map(as_int, m)) + v = list(map(as_int, v)) + + result = gf_crt(v, m, ZZ) + mm = prod(m) + + if check: + if not all(v % m == result % m for v, m in zip(v, m)): + result = solve_congruence(*list(zip(v, m)), + check=False, symmetric=symmetric) + if result is None: + return result + result, mm = result + + if symmetric: + return int(symmetric_residue(result, mm)), int(mm) + return int(result), int(mm) + + +def crt1(m): + """First part of Chinese Remainder Theorem, for multiple application. + + Examples + ======== + + >>> from sympy.ntheory.modular import crt, crt1, crt2 + >>> m = [99, 97, 95] + >>> v = [49, 76, 65] + + The following two codes have the same result. + + >>> crt(m, v) + (639985, 912285) + + >>> mm, e, s = crt1(m) + >>> crt2(m, v, mm, e, s) + (639985, 912285) + + However, it is faster when we want to fix ``m`` and + compute for multiple ``v``, i.e. the following cases: + + >>> mm, e, s = crt1(m) + >>> vs = [[52, 21, 37], [19, 46, 76]] + >>> for v in vs: + ... print(crt2(m, v, mm, e, s)) + (397042, 912285) + (803206, 912285) + + See Also + ======== + + sympy.polys.galoistools.gf_crt1 : low level crt routine used by this routine + sympy.ntheory.modular.crt + sympy.ntheory.modular.crt2 + + """ + + return gf_crt1(m, ZZ) + + +def crt2(m, v, mm, e, s, symmetric=False): + """Second part of Chinese Remainder Theorem, for multiple application. + + See ``crt1`` for usage. + + Examples + ======== + + >>> from sympy.ntheory.modular import crt1, crt2 + >>> mm, e, s = crt1([18, 42, 6]) + >>> crt2([18, 42, 6], [0, 0, 0], mm, e, s) + (0, 4536) + + See Also + ======== + + sympy.polys.galoistools.gf_crt2 : low level crt routine used by this routine + sympy.ntheory.modular.crt + sympy.ntheory.modular.crt1 + + """ + + result = gf_crt2(v, m, mm, e, s, ZZ) + + if symmetric: + return int(symmetric_residue(result, mm)), int(mm) + return int(result), int(mm) + + +def solve_congruence(*remainder_modulus_pairs, **hint): + """Compute the integer ``n`` that has the residual ``ai`` when it is + divided by ``mi`` where the ``ai`` and ``mi`` are given as pairs to + this function: ((a1, m1), (a2, m2), ...). If there is no solution, + return None. Otherwise return ``n`` and its modulus. + + The ``mi`` values need not be co-prime. If it is known that the moduli are + not co-prime then the hint ``check`` can be set to False (default=True) and + the check for a quicker solution via crt() (valid when the moduli are + co-prime) will be skipped. + + If the hint ``symmetric`` is True (default is False), the value of ``n`` + will be within 1/2 of the modulus, possibly negative. + + Examples + ======== + + >>> from sympy.ntheory.modular import solve_congruence + + What number is 2 mod 3, 3 mod 5 and 2 mod 7? + + >>> solve_congruence((2, 3), (3, 5), (2, 7)) + (23, 105) + >>> [23 % m for m in [3, 5, 7]] + [2, 3, 2] + + If you prefer to work with all remainder in one list and + all moduli in another, send the arguments like this: + + >>> solve_congruence(*zip((2, 3, 2), (3, 5, 7))) + (23, 105) + + The moduli need not be co-prime; in this case there may or + may not be a solution: + + >>> solve_congruence((2, 3), (4, 6)) is None + True + + >>> solve_congruence((2, 3), (5, 6)) + (5, 6) + + The symmetric flag will make the result be within 1/2 of the modulus: + + >>> solve_congruence((2, 3), (5, 6), symmetric=True) + (-1, 6) + + See Also + ======== + + crt : high level routine implementing the Chinese Remainder Theorem + + """ + def combine(c1, c2): + """Return the tuple (a, m) which satisfies the requirement + that n = a + i*m satisfy n = a1 + j*m1 and n = a2 = k*m2. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Method_of_successive_substitution + """ + a1, m1 = c1 + a2, m2 = c2 + a, b, c = m1, a2 - a1, m2 + g = gcd(a, b, c) + a, b, c = [i//g for i in [a, b, c]] + if a != 1: + g, inv_a, _ = gcdext(a, c) + if g != 1: + return None + b *= inv_a + a, m = a1 + m1*b, m1*c + return a, m + + rm = remainder_modulus_pairs + symmetric = hint.get('symmetric', False) + + if hint.get('check', True): + rm = [(as_int(r), as_int(m)) for r, m in rm] + + # ignore redundant pairs but raise an error otherwise; also + # make sure that a unique set of bases is sent to gf_crt if + # they are all prime. + # + # The routine will work out less-trivial violations and + # return None, e.g. for the pairs (1,3) and (14,42) there + # is no answer because 14 mod 42 (having a gcd of 14) implies + # (14/2) mod (42/2), (14/7) mod (42/7) and (14/14) mod (42/14) + # which, being 0 mod 3, is inconsistent with 1 mod 3. But to + # preprocess the input beyond checking of another pair with 42 + # or 3 as the modulus (for this example) is not necessary. + uniq = {} + for r, m in rm: + r %= m + if m in uniq: + if r != uniq[m]: + return None + continue + uniq[m] = r + rm = [(r, m) for m, r in uniq.items()] + del uniq + + # if the moduli are co-prime, the crt will be significantly faster; + # checking all pairs for being co-prime gets to be slow but a prime + # test is a good trade-off + if all(isprime(m) for r, m in rm): + r, m = list(zip(*rm)) + return crt(m, r, symmetric=symmetric, check=False) + + rv = (0, 1) + for rmi in rm: + rv = combine(rv, rmi) + if rv is None: + break + n, m = rv + n = n % m + else: + if symmetric: + return symmetric_residue(n, m), m + return n, m diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/multinomial.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/multinomial.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec50fdb533be547b9a8e60dc47568965bf89436 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/multinomial.py @@ -0,0 +1,188 @@ +from sympy.utilities.misc import as_int + + +def binomial_coefficients(n): + """Return a dictionary containing pairs :math:`{(k1,k2) : C_kn}` where + :math:`C_kn` are binomial coefficients and :math:`n=k1+k2`. + + Examples + ======== + + >>> from sympy.ntheory import binomial_coefficients + >>> binomial_coefficients(9) + {(0, 9): 1, (1, 8): 9, (2, 7): 36, (3, 6): 84, + (4, 5): 126, (5, 4): 126, (6, 3): 84, (7, 2): 36, (8, 1): 9, (9, 0): 1} + + See Also + ======== + + binomial_coefficients_list, multinomial_coefficients + """ + n = as_int(n) + d = {(0, n): 1, (n, 0): 1} + a = 1 + for k in range(1, n//2 + 1): + a = (a * (n - k + 1))//k + d[k, n - k] = d[n - k, k] = a + return d + + +def binomial_coefficients_list(n): + """ Return a list of binomial coefficients as rows of the Pascal's + triangle. + + Examples + ======== + + >>> from sympy.ntheory import binomial_coefficients_list + >>> binomial_coefficients_list(9) + [1, 9, 36, 84, 126, 126, 84, 36, 9, 1] + + See Also + ======== + + binomial_coefficients, multinomial_coefficients + """ + n = as_int(n) + d = [1] * (n + 1) + a = 1 + for k in range(1, n//2 + 1): + a = (a * (n - k + 1))//k + d[k] = d[n - k] = a + return d + + +def multinomial_coefficients(m, n): + r"""Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}`` + where ``C_kn`` are multinomial coefficients such that + ``n=k1+k2+..+km``. + + Examples + ======== + + >>> from sympy.ntheory import multinomial_coefficients + >>> multinomial_coefficients(2, 5) # indirect doctest + {(0, 5): 1, (1, 4): 5, (2, 3): 10, (3, 2): 10, (4, 1): 5, (5, 0): 1} + + Notes + ===== + + The algorithm is based on the following result: + + .. math:: + \binom{n}{k_1, \ldots, k_m} = + \frac{k_1 + 1}{n - k_1} \sum_{i=2}^m \binom{n}{k_1 + 1, \ldots, k_i - 1, \ldots} + + Code contributed to Sage by Yann Laigle-Chapuy, copied with permission + of the author. + + See Also + ======== + + binomial_coefficients_list, binomial_coefficients + """ + m = as_int(m) + n = as_int(n) + if not m: + if n: + return {} + return {(): 1} + if m == 2: + return binomial_coefficients(n) + if m >= 2*n and n > 1: + return dict(multinomial_coefficients_iterator(m, n)) + t = [n] + [0] * (m - 1) + r = {tuple(t): 1} + if n: + j = 0 # j will be the leftmost nonzero position + else: + j = m + # enumerate tuples in co-lex order + while j < m - 1: + # compute next tuple + tj = t[j] + if j: + t[j] = 0 + t[0] = tj + if tj > 1: + t[j + 1] += 1 + j = 0 + start = 1 + v = 0 + else: + j += 1 + start = j + 1 + v = r[tuple(t)] + t[j] += 1 + # compute the value + # NB: the initialization of v was done above + for k in range(start, m): + if t[k]: + t[k] -= 1 + v += r[tuple(t)] + t[k] += 1 + t[0] -= 1 + r[tuple(t)] = (v * tj) // (n - t[0]) + return r + + +def multinomial_coefficients_iterator(m, n, _tuple=tuple): + """multinomial coefficient iterator + + This routine has been optimized for `m` large with respect to `n` by taking + advantage of the fact that when the monomial tuples `t` are stripped of + zeros, their coefficient is the same as that of the monomial tuples from + ``multinomial_coefficients(n, n)``. Therefore, the latter coefficients are + precomputed to save memory and time. + + >>> from sympy.ntheory.multinomial import multinomial_coefficients + >>> m53, m33 = multinomial_coefficients(5,3), multinomial_coefficients(3,3) + >>> m53[(0,0,0,1,2)] == m53[(0,0,1,0,2)] == m53[(1,0,2,0,0)] == m33[(0,1,2)] + True + + Examples + ======== + + >>> from sympy.ntheory.multinomial import multinomial_coefficients_iterator + >>> it = multinomial_coefficients_iterator(20,3) + >>> next(it) + ((3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), 1) + """ + m = as_int(m) + n = as_int(n) + if m < 2*n or n == 1: + mc = multinomial_coefficients(m, n) + yield from mc.items() + else: + mc = multinomial_coefficients(n, n) + mc1 = {} + for k, v in mc.items(): + mc1[_tuple(filter(None, k))] = v + mc = mc1 + + t = [n] + [0] * (m - 1) + t1 = _tuple(t) + b = _tuple(filter(None, t1)) + yield (t1, mc[b]) + if n: + j = 0 # j will be the leftmost nonzero position + else: + j = m + # enumerate tuples in co-lex order + while j < m - 1: + # compute next tuple + tj = t[j] + if j: + t[j] = 0 + t[0] = tj + if tj > 1: + t[j + 1] += 1 + j = 0 + else: + j += 1 + t[j] += 1 + + t[0] -= 1 + t1 = _tuple(t) + b = _tuple(filter(None, t1)) + yield (t1, mc[b]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/partitions_.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/partitions_.py new file mode 100644 index 0000000000000000000000000000000000000000..953fa9e2fef146b0d3a9baad0ec5e1353ad6f237 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/partitions_.py @@ -0,0 +1,277 @@ +from mpmath.libmp import (fzero, from_int, from_rational, + fone, fhalf, bitcount, to_int, mpf_mul, mpf_div, mpf_sub, + mpf_add, mpf_sqrt, mpf_pi, mpf_cosh_sinh, mpf_cos, mpf_sin) +from .residue_ntheory import _sqrt_mod_prime_power, is_quad_residue +from sympy.utilities.decorator import deprecated +from sympy.utilities.memoization import recurrence_memo + +import math +from itertools import count + +def _pre(): + maxn = 10**5 + global _factor, _totient + _factor = [0]*maxn + _totient = [1]*maxn + lim = int(maxn**0.5) + 5 + for i in range(2, lim): + if _factor[i] == 0: + for j in range(i*i, maxn, i): + if _factor[j] == 0: + _factor[j] = i + for i in range(2, maxn): + if _factor[i] == 0: + _factor[i] = i + _totient[i] = i-1 + continue + x = _factor[i] + y = i//x + if y % x == 0: + _totient[i] = _totient[y]*x + else: + _totient[i] = _totient[y]*(x - 1) + +def _a(n, k, prec): + """ Compute the inner sum in HRR formula [1]_ + + References + ========== + + .. [1] https://msp.org/pjm/1956/6-1/pjm-v6-n1-p18-p.pdf + + """ + if k == 1: + return fone + + k1 = k + e = 0 + p = _factor[k] + while k1 % p == 0: + k1 //= p + e += 1 + k2 = k//k1 # k2 = p^e + v = 1 - 24*n + pi = mpf_pi(prec) + + if k1 == 1: + # k = p^e + if p == 2: + mod = 8*k + v = mod + v % mod + v = (v*pow(9, k - 1, mod)) % mod + m = _sqrt_mod_prime_power(v, 2, e + 3)[0] + arg = mpf_div(mpf_mul( + from_int(4*m), pi, prec), from_int(mod), prec) + return mpf_mul(mpf_mul( + from_int((-1)**e*(2 - (m % 4))), + mpf_sqrt(from_int(k), prec), prec), + mpf_sin(arg, prec), prec) + if p == 3: + mod = 3*k + v = mod + v % mod + if e > 1: + v = (v*pow(64, k//3 - 1, mod)) % mod + m = _sqrt_mod_prime_power(v, 3, e + 1)[0] + arg = mpf_div(mpf_mul(from_int(4*m), pi, prec), + from_int(mod), prec) + return mpf_mul(mpf_mul( + from_int(2*(-1)**(e + 1)*(3 - 2*(m % 3))), + mpf_sqrt(from_int(k//3), prec), prec), + mpf_sin(arg, prec), prec) + v = k + v % k + jacobi3 = -1 if k % 12 in [5, 7] else 1 + if v % p == 0: + if e == 1: + return mpf_mul( + from_int(jacobi3), + mpf_sqrt(from_int(k), prec), prec) + return fzero + if not is_quad_residue(v, p): + return fzero + _phi = p**(e - 1)*(p - 1) + v = (v*pow(576, _phi - 1, k)) + m = _sqrt_mod_prime_power(v, p, e)[0] + arg = mpf_div( + mpf_mul(from_int(4*m), pi, prec), + from_int(k), prec) + return mpf_mul(mpf_mul( + from_int(2*jacobi3), + mpf_sqrt(from_int(k), prec), prec), + mpf_cos(arg, prec), prec) + + if p != 2 or e >= 3: + d1, d2 = math.gcd(k1, 24), math.gcd(k2, 24) + e = 24//(d1*d2) + n1 = ((d2*e*n + (k2**2 - 1)//d1)* + pow(e*k2*k2*d2, _totient[k1] - 1, k1)) % k1 + n2 = ((d1*e*n + (k1**2 - 1)//d2)* + pow(e*k1*k1*d1, _totient[k2] - 1, k2)) % k2 + return mpf_mul(_a(n1, k1, prec), _a(n2, k2, prec), prec) + if e == 2: + n1 = ((8*n + 5)*pow(128, _totient[k1] - 1, k1)) % k1 + n2 = (4 + ((n - 2 - (k1**2 - 1)//8)*(k1**2)) % 4) % 4 + return mpf_mul(mpf_mul( + from_int(-1), + _a(n1, k1, prec), prec), + _a(n2, k2, prec)) + n1 = ((8*n + 1)*pow(32, _totient[k1] - 1, k1)) % k1 + n2 = (2 + (n - (k1**2 - 1)//8) % 2) % 2 + return mpf_mul(_a(n1, k1, prec), _a(n2, k2, prec), prec) + +def _d(n, j, prec, sq23pi, sqrt8): + """ + Compute the sinh term in the outer sum of the HRR formula. + The constants sqrt(2/3*pi) and sqrt(8) must be precomputed. + """ + j = from_int(j) + pi = mpf_pi(prec) + a = mpf_div(sq23pi, j, prec) + b = mpf_sub(from_int(n), from_rational(1, 24, prec), prec) + c = mpf_sqrt(b, prec) + ch, sh = mpf_cosh_sinh(mpf_mul(a, c), prec) + D = mpf_div( + mpf_sqrt(j, prec), + mpf_mul(mpf_mul(sqrt8, b), pi), prec) + E = mpf_sub(mpf_mul(a, ch), mpf_div(sh, c, prec), prec) + return mpf_mul(D, E) + + +@recurrence_memo([1, 1]) +def _partition_rec(n: int, prev) -> int: + """ Calculate the partition function P(n) + + Parameters + ========== + + n : int + nonnegative integer + + """ + v = 0 + penta = 0 # pentagonal number: 1, 5, 12, ... + for i in count(): + penta += 3*i + 1 + np = n - penta + if np < 0: + break + s = prev[np] + np -= i + 1 + # np = n - gp where gp = generalized pentagonal: 2, 7, 15, ... + if 0 <= np: + s += prev[np] + v += -s if i % 2 else s + return v + + +def _partition(n: int) -> int: + """ Calculate the partition function P(n) + + Parameters + ========== + + n : int + + """ + if n < 0: + return 0 + if (n <= 200_000 and n - _partition_rec.cache_length() < 70 or + _partition_rec.cache_length() == 2 and n < 14_400): + # There will be 2*10**5 elements created here + # and n elements created by partition, so in case we + # are going to be working with small n, we just + # use partition to calculate (and cache) the values + # since lookup is used there while summation, using + # _factor and _totient, will be used below. But we + # only do so if n is relatively close to the length + # of the cache since doing 1 calculation here is about + # the same as adding 70 elements to the cache. In addition, + # the startup here costs about the same as calculating the first + # 14,400 values via partition, so we delay startup here unless n + # is smaller than that. + return _partition_rec(n) + if '_factor' not in globals(): + _pre() + # Estimate number of bits in p(n). This formula could be tidied + pbits = int(( + math.pi*(2*n/3.)**0.5 - + math.log(4*n))/math.log(10) + 1) * \ + math.log2(10) + prec = p = int(pbits*1.1 + 100) + + # find the number of terms needed so rounded sum will be accurate + # using Rademacher's bound M(n, N) for the remainder after a partial + # sum of N terms (https://arxiv.org/pdf/1205.5991.pdf, (1.8)) + c1 = 44*math.pi**2/(225*math.sqrt(3)) + c2 = math.pi*math.sqrt(2)/75 + c3 = math.pi*math.sqrt(2/3) + def _M(n, N): + sqrt = math.sqrt + return c1/sqrt(N) + c2*sqrt(N/(n - 1))*math.sinh(c3*sqrt(n)/N) + big = max(9, math.ceil(n**0.5)) # should be too large (for n > 65, ceil should work) + assert _M(n, big) < 0.5 # else double big until too large + while big > 40 and _M(n, big) < 0.5: + big //= 2 + small = big + big = small*2 + while big - small > 1: + N = (big + small)//2 + if (er := _M(n, N)) < 0.5: + big = N + elif er >= 0.5: + small = N + M = big # done with function M; now have value + + # sanity check for expected size of answer + if M > 10**5: # i.e. M > maxn + raise ValueError("Input too big") # i.e. n > 149832547102 + + # calculate it + s = fzero + sq23pi = mpf_mul(mpf_sqrt(from_rational(2, 3, p), p), mpf_pi(p), p) + sqrt8 = mpf_sqrt(from_int(8), p) + for q in range(1, M): + a = _a(n, q, p) + d = _d(n, q, p, sq23pi, sqrt8) + s = mpf_add(s, mpf_mul(a, d), prec) + # On average, the terms decrease rapidly in magnitude. + # Dynamically reducing the precision greatly improves + # performance. + p = bitcount(abs(to_int(d))) + 50 + return int(to_int(mpf_add(s, fhalf, prec))) + + +@deprecated("""\ +The `sympy.ntheory.partitions_.npartitions` has been moved to `sympy.functions.combinatorial.numbers.partition`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def npartitions(n, verbose=False): + """ + Calculate the partition function P(n), i.e. the number of ways that + n can be written as a sum of positive integers. + + .. deprecated:: 1.13 + + The ``npartitions`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.partition` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + P(n) is computed using the Hardy-Ramanujan-Rademacher formula [1]_. + + + The correctness of this implementation has been tested through $10^{10}$. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import partition + >>> partition(25) + 1958 + + References + ========== + + .. [1] https://mathworld.wolfram.com/PartitionFunctionP.html + + """ + from sympy.functions.combinatorial.numbers import partition as func_partition + return func_partition(n) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/primetest.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/primetest.py new file mode 100644 index 0000000000000000000000000000000000000000..ff3cb82cc51bf57ca345a7d72ee715c861f62e2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/primetest.py @@ -0,0 +1,830 @@ +""" +Primality testing + +""" + +from itertools import count + +from sympy.core.sympify import sympify +from sympy.external.gmpy import (gmpy as _gmpy, gcd, jacobi, + is_square as gmpy_is_square, + bit_scan1, is_fermat_prp, is_euler_prp, + is_selfridge_prp, is_strong_selfridge_prp, + is_strong_bpsw_prp) +from sympy.external.ntheory import _lucas_sequence +from sympy.utilities.misc import as_int, filldedent + +# Note: This list should be updated whenever new Mersenne primes are found. +# Refer: https://www.mersenne.org/ +MERSENNE_PRIME_EXPONENTS = (2, 3, 5, 7, 13, 17, 19, 31, 61, 89, 107, 127, 521, 607, 1279, 2203, + 2281, 3217, 4253, 4423, 9689, 9941, 11213, 19937, 21701, 23209, 44497, 86243, 110503, 132049, + 216091, 756839, 859433, 1257787, 1398269, 2976221, 3021377, 6972593, 13466917, 20996011, 24036583, + 25964951, 30402457, 32582657, 37156667, 42643801, 43112609, 57885161, 74207281, 77232917, 82589933, + 136279841) + + +def is_fermat_pseudoprime(n, a): + r"""Returns True if ``n`` is prime or is an odd composite integer that + is coprime to ``a`` and satisfy the modular arithmetic congruence relation: + + .. math :: + a^{n-1} \equiv 1 \pmod{n} + + (where mod refers to the modulo operation). + + Parameters + ========== + + n : Integer + ``n`` is a positive integer. + a : Integer + ``a`` is a positive integer. + ``a`` and ``n`` should be relatively prime. + + Returns + ======= + + bool : If ``n`` is prime, it always returns ``True``. + The composite number that returns ``True`` is called an Fermat pseudoprime. + + Examples + ======== + + >>> from sympy.ntheory.primetest import is_fermat_pseudoprime + >>> from sympy.ntheory.factor_ import isprime + >>> for n in range(1, 1000): + ... if is_fermat_pseudoprime(n, 2) and not isprime(n): + ... print(n) + 341 + 561 + 645 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fermat_pseudoprime + """ + n, a = as_int(n), as_int(a) + if a == 1: + return n == 2 or bool(n % 2) + return is_fermat_prp(n, a) + + +def is_euler_pseudoprime(n, a): + r"""Returns True if ``n`` is prime or is an odd composite integer that + is coprime to ``a`` and satisfy the modular arithmetic congruence relation: + + .. math :: + a^{(n-1)/2} \equiv \pm 1 \pmod{n} + + (where mod refers to the modulo operation). + + Parameters + ========== + + n : Integer + ``n`` is a positive integer. + a : Integer + ``a`` is a positive integer. + ``a`` and ``n`` should be relatively prime. + + Returns + ======= + + bool : If ``n`` is prime, it always returns ``True``. + The composite number that returns ``True`` is called an Euler pseudoprime. + + Examples + ======== + + >>> from sympy.ntheory.primetest import is_euler_pseudoprime + >>> from sympy.ntheory.factor_ import isprime + >>> for n in range(1, 1000): + ... if is_euler_pseudoprime(n, 2) and not isprime(n): + ... print(n) + 341 + 561 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler_pseudoprime + """ + n, a = as_int(n), as_int(a) + if a < 1: + raise ValueError("a should be an integer greater than 0") + if n < 1: + raise ValueError("n should be an integer greater than 0") + if n == 1: + return False + if a == 1: + return n == 2 or bool(n % 2) # (prime or odd composite) + if n % 2 == 0: + return n == 2 + if gcd(n, a) != 1: + raise ValueError("The two numbers should be relatively prime") + return pow(a, (n - 1) // 2, n) in [1, n - 1] + + +def is_euler_jacobi_pseudoprime(n, a): + r"""Returns True if ``n`` is prime or is an odd composite integer that + is coprime to ``a`` and satisfy the modular arithmetic congruence relation: + + .. math :: + a^{(n-1)/2} \equiv \left(\frac{a}{n}\right) \pmod{n} + + (where mod refers to the modulo operation). + + Parameters + ========== + + n : Integer + ``n`` is a positive integer. + a : Integer + ``a`` is a positive integer. + ``a`` and ``n`` should be relatively prime. + + Returns + ======= + + bool : If ``n`` is prime, it always returns ``True``. + The composite number that returns ``True`` is called an Euler-Jacobi pseudoprime. + + Examples + ======== + + >>> from sympy.ntheory.primetest import is_euler_jacobi_pseudoprime + >>> from sympy.ntheory.factor_ import isprime + >>> for n in range(1, 1000): + ... if is_euler_jacobi_pseudoprime(n, 2) and not isprime(n): + ... print(n) + 561 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler%E2%80%93Jacobi_pseudoprime + """ + n, a = as_int(n), as_int(a) + if a == 1: + return n == 2 or bool(n % 2) + return is_euler_prp(n, a) + + +def is_square(n, prep=True): + """Return True if n == a * a for some integer a, else False. + If n is suspected of *not* being a square then this is a + quick method of confirming that it is not. + + Examples + ======== + + >>> from sympy.ntheory.primetest import is_square + >>> is_square(25) + True + >>> is_square(2) + False + + References + ========== + + .. [1] https://mersenneforum.org/showpost.php?p=110896 + + See Also + ======== + sympy.core.intfunc.isqrt + """ + if prep: + n = as_int(n) + if n < 0: + return False + if n in (0, 1): + return True + return gmpy_is_square(n) + + +def _test(n, base, s, t): + """Miller-Rabin strong pseudoprime test for one base. + Return False if n is definitely composite, True if n is + probably prime, with a probability greater than 3/4. + + """ + # do the Fermat test + b = pow(base, t, n) + if b == 1 or b == n - 1: + return True + for _ in range(s - 1): + b = pow(b, 2, n) + if b == n - 1: + return True + # see I. Niven et al. "An Introduction to Theory of Numbers", page 78 + if b == 1: + return False + return False + + +def mr(n, bases): + """Perform a Miller-Rabin strong pseudoprime test on n using a + given list of bases/witnesses. + + References + ========== + + .. [1] Richard Crandall & Carl Pomerance (2005), "Prime Numbers: + A Computational Perspective", Springer, 2nd edition, 135-138 + + A list of thresholds and the bases they require are here: + https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test#Deterministic_variants + + Examples + ======== + + >>> from sympy.ntheory.primetest import mr + >>> mr(1373651, [2, 3]) + False + >>> mr(479001599, [31, 73]) + True + + """ + from sympy.polys.domains import ZZ + + n = as_int(n) + if n < 2 or (n > 2 and n % 2 == 0): + return False + # remove powers of 2 from n-1 (= t * 2**s) + s = bit_scan1(n - 1) + t = n >> s + for base in bases: + # Bases >= n are wrapped, bases < 2 are invalid + if base >= n: + base %= n + if base >= 2: + base = ZZ(base) + if not _test(n, base, s, t): + return False + return True + + +def _lucas_extrastrong_params(n): + """Calculates the "extra strong" parameters (D, P, Q) for n. + + Parameters + ========== + + n : int + positive odd integer + + Returns + ======= + + D, P, Q: "extra strong" parameters. + ``(0, 0, 0)`` if we find a nontrivial divisor of ``n``. + + Examples + ======== + + >>> from sympy.ntheory.primetest import _lucas_extrastrong_params + >>> _lucas_extrastrong_params(101) + (12, 4, 1) + >>> _lucas_extrastrong_params(15) + (0, 0, 0) + + References + ========== + .. [1] OEIS A217719: Extra Strong Lucas Pseudoprimes + https://oeis.org/A217719 + .. [2] https://en.wikipedia.org/wiki/Lucas_pseudoprime + + """ + for P in count(3): + D = P**2 - 4 + j = jacobi(D, n) + if j == -1: + return (D, P, 1) + elif j == 0 and D % n: + return (0, 0, 0) + + +def is_lucas_prp(n): + """Standard Lucas compositeness test with Selfridge parameters. Returns + False if n is definitely composite, and True if n is a Lucas probable + prime. + + This is typically used in combination with the Miller-Rabin test. + + References + ========== + .. [1] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes, + Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417, + https://doi.org/10.1090%2FS0025-5718-1980-0583518-6 + http://mpqs.free.fr/LucasPseudoprimes.pdf + .. [2] OEIS A217120: Lucas Pseudoprimes + https://oeis.org/A217120 + .. [3] https://en.wikipedia.org/wiki/Lucas_pseudoprime + + Examples + ======== + + >>> from sympy.ntheory.primetest import isprime, is_lucas_prp + >>> for i in range(10000): + ... if is_lucas_prp(i) and not isprime(i): + ... print(i) + 323 + 377 + 1159 + 1829 + 3827 + 5459 + 5777 + 9071 + 9179 + """ + n = as_int(n) + if n < 2: + return False + return is_selfridge_prp(n) + + +def is_strong_lucas_prp(n): + """Strong Lucas compositeness test with Selfridge parameters. Returns + False if n is definitely composite, and True if n is a strong Lucas + probable prime. + + This is often used in combination with the Miller-Rabin test, and + in particular, when combined with M-R base 2 creates the strong BPSW test. + + References + ========== + .. [1] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes, + Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417, + https://doi.org/10.1090%2FS0025-5718-1980-0583518-6 + http://mpqs.free.fr/LucasPseudoprimes.pdf + .. [2] OEIS A217255: Strong Lucas Pseudoprimes + https://oeis.org/A217255 + .. [3] https://en.wikipedia.org/wiki/Lucas_pseudoprime + .. [4] https://en.wikipedia.org/wiki/Baillie-PSW_primality_test + + Examples + ======== + + >>> from sympy.ntheory.primetest import isprime, is_strong_lucas_prp + >>> for i in range(20000): + ... if is_strong_lucas_prp(i) and not isprime(i): + ... print(i) + 5459 + 5777 + 10877 + 16109 + 18971 + """ + n = as_int(n) + if n < 2: + return False + return is_strong_selfridge_prp(n) + + +def is_extra_strong_lucas_prp(n): + """Extra Strong Lucas compositeness test. Returns False if n is + definitely composite, and True if n is an "extra strong" Lucas probable + prime. + + The parameters are selected using P = 3, Q = 1, then incrementing P until + (D|n) == -1. The test itself is as defined in [1]_, from the + Mo and Jones preprint. The parameter selection and test are the same as + used in OEIS A217719, Perl's Math::Prime::Util, and the Lucas pseudoprime + page on Wikipedia. + + It is 20-50% faster than the strong test. + + Because of the different parameters selected, there is no relationship + between the strong Lucas pseudoprimes and extra strong Lucas pseudoprimes. + In particular, one is not a subset of the other. + + References + ========== + .. [1] Jon Grantham, Frobenius Pseudoprimes, + Math. Comp. Vol 70, Number 234 (2001), pp. 873-891, + https://doi.org/10.1090%2FS0025-5718-00-01197-2 + .. [2] OEIS A217719: Extra Strong Lucas Pseudoprimes + https://oeis.org/A217719 + .. [3] https://en.wikipedia.org/wiki/Lucas_pseudoprime + + Examples + ======== + + >>> from sympy.ntheory.primetest import isprime, is_extra_strong_lucas_prp + >>> for i in range(20000): + ... if is_extra_strong_lucas_prp(i) and not isprime(i): + ... print(i) + 989 + 3239 + 5777 + 10877 + """ + # Implementation notes: + # 1) the parameters differ from Thomas R. Nicely's. His parameter + # selection leads to pseudoprimes that overlap M-R tests, and + # contradict Baillie and Wagstaff's suggestion of (D|n) = -1. + # 2) The MathWorld page as of June 2013 specifies Q=-1. The Lucas + # sequence must have Q=1. See Grantham theorem 2.3, any of the + # references on the MathWorld page, or run it and see Q=-1 is wrong. + n = as_int(n) + if n == 2: + return True + if n < 2 or (n % 2) == 0: + return False + if gmpy_is_square(n): + return False + + D, P, Q = _lucas_extrastrong_params(n) + if D == 0: + return False + + # remove powers of 2 from n+1 (= k * 2**s) + s = bit_scan1(n + 1) + k = (n + 1) >> s + + U, V, _ = _lucas_sequence(n, P, Q, k) + + if U == 0 and (V == 2 or V == n - 2): + return True + for _ in range(1, s): + if V == 0: + return True + V = (V*V - 2) % n + return False + + +def proth_test(n): + r""" Test if the Proth number `n = k2^m + 1` is prime. where k is a positive odd number and `2^m > k`. + + Parameters + ========== + + n : Integer + ``n`` is Proth number + + Returns + ======= + + bool : If ``True``, then ``n`` is the Proth prime + + Raises + ====== + + ValueError + If ``n`` is not Proth number. + + Examples + ======== + + >>> from sympy.ntheory.primetest import proth_test + >>> proth_test(41) + True + >>> proth_test(57) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Proth_prime + + """ + n = as_int(n) + if n < 3: + raise ValueError("n is not Proth number") + m = bit_scan1(n - 1) + k = n >> m + if m < k.bit_length(): + raise ValueError("n is not Proth number") + if n % 3 == 0: + return n == 3 + if k % 3: # n % 12 == 5 + return pow(3, n >> 1, n) == n - 1 + # If `n` is a square number, then `jacobi(a, n) = 1` for any `a` + if gmpy_is_square(n): + return False + # `a` may be chosen at random. + # In any case, we want to find `a` such that `jacobi(a, n) = -1`. + for a in range(5, n): + j = jacobi(a, n) + if j == -1: + return pow(a, n >> 1, n) == n - 1 + if j == 0: + return False + + +def _lucas_lehmer_primality_test(p): + r""" Test if the Mersenne number `M_p = 2^p-1` is prime. + + Parameters + ========== + + p : int + ``p`` is an odd prime number + + Returns + ======= + + bool : If ``True``, then `M_p` is the Mersenne prime + + Examples + ======== + + >>> from sympy.ntheory.primetest import _lucas_lehmer_primality_test + >>> _lucas_lehmer_primality_test(5) # 2**5 - 1 = 31 is prime + True + >>> _lucas_lehmer_primality_test(11) # 2**11 - 1 = 2047 is not prime + False + + See Also + ======== + + is_mersenne_prime + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lucas%E2%80%93Lehmer_primality_test + + """ + v = 4 + m = 2**p - 1 + for _ in range(p - 2): + v = pow(v, 2, m) - 2 + return v == 0 + + +def is_mersenne_prime(n): + """Returns True if ``n`` is a Mersenne prime, else False. + + A Mersenne prime is a prime number having the form `2^i - 1`. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import is_mersenne_prime + >>> is_mersenne_prime(6) + False + >>> is_mersenne_prime(127) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/MersennePrime.html + + """ + n = as_int(n) + if n < 1: + return False + if n & (n + 1): + # n is not Mersenne number + return False + p = n.bit_length() + if p in MERSENNE_PRIME_EXPONENTS: + return True + if p < 65_000_000 or not isprime(p): + # According to GIMPS, verification was completed on September 19, 2023 for p less than 65 million. + # https://www.mersenne.org/report_milestones/ + # If p is composite number, then n=2**p-1 is composite number. + return False + result = _lucas_lehmer_primality_test(p) + if result: + raise ValueError(filldedent(''' + This Mersenne Prime, 2^%s - 1, should + be added to SymPy's known values.''' % p)) + return result + + +_MR_BASES_32 = [15591, 2018, 166, 7429, 8064, 16045, 10503, 4399, 1949, 1295, + 2776, 3620, 560, 3128, 5212, 2657, 2300, 2021, 4652, 1471, + 9336, 4018, 2398, 20462, 10277, 8028, 2213, 6219, 620, 3763, + 4852, 5012, 3185, 1333, 6227,5298, 1074, 2391, 5113, 7061, + 803, 1269, 3875, 422, 751, 580, 4729, 10239, 746, 2951, 556, + 2206, 3778, 481, 1522, 3476, 481, 2487, 3266, 5633, 488, 3373, + 6441, 3344, 17, 15105, 1490, 4154, 2036, 1882, 1813, 467, + 3307, 14042, 6371, 658, 1005, 903, 737, 1887, 7447, 1888, + 2848, 1784, 7559, 3400, 951, 13969, 4304, 177, 41, 19875, + 3110, 13221, 8726, 571, 7043, 6943, 1199, 352, 6435, 165, + 1169, 3315, 978, 233, 3003, 2562, 2994, 10587, 10030, 2377, + 1902, 5354, 4447, 1555, 263, 27027, 2283, 305, 669, 1912, 601, + 6186, 429, 1930, 14873, 1784, 1661, 524, 3577, 236, 2360, + 6146, 2850, 55637, 1753, 4178, 8466, 222, 2579, 2743, 2031, + 2226, 2276, 374, 2132, 813, 23788, 1610, 4422, 5159, 1725, + 3597, 3366, 14336, 579, 165, 1375, 10018, 12616, 9816, 1371, + 536, 1867, 10864, 857, 2206, 5788, 434, 8085, 17618, 727, + 3639, 1595, 4944, 2129, 2029, 8195, 8344, 6232, 9183, 8126, + 1870, 3296, 7455, 8947, 25017, 541, 19115, 368, 566, 5674, + 411, 522, 1027, 8215, 2050, 6544, 10049, 614, 774, 2333, 3007, + 35201, 4706, 1152, 1785, 1028, 1540, 3743, 493, 4474, 2521, + 26845, 8354, 864, 18915, 5465, 2447, 42, 4511, 1660, 166, + 1249, 6259, 2553, 304, 272, 7286, 73, 6554, 899, 2816, 5197, + 13330, 7054, 2818, 3199, 811, 922, 350, 7514, 4452, 3449, + 2663, 4708, 418, 1621, 1171, 3471, 88, 11345, 412, 1559, 194] + + +def isprime(n): + """ + Test if n is a prime number (True) or not (False). For n < 2^64 the + answer is definitive; larger n values have a small probability of actually + being pseudoprimes. + + Negative numbers (e.g. -2) are not considered prime. + + The first step is looking for trivial factors, which if found enables + a quick return. Next, if the sieve is large enough, use bisection search + on the sieve. For small numbers, a set of deterministic Miller-Rabin + tests are performed with bases that are known to have no counterexamples + in their range. Finally if the number is larger than 2^64, a strong + BPSW test is performed. While this is a probable prime test and we + believe counterexamples exist, there are no known counterexamples. + + Examples + ======== + + >>> from sympy.ntheory import isprime + >>> isprime(13) + True + >>> isprime(15) + False + + Notes + ===== + + This routine is intended only for integer input, not numerical + expressions which may represent numbers. Floats are also + rejected as input because they represent numbers of limited + precision. While it is tempting to permit 7.0 to represent an + integer there are errors that may "pass silently" if this is + allowed: + + >>> from sympy import Float, S + >>> int(1e3) == 1e3 == 10**3 + True + >>> int(1e23) == 1e23 + True + >>> int(1e23) == 10**23 + False + + >>> near_int = 1 + S(1)/10**19 + >>> near_int == int(near_int) + False + >>> n = Float(near_int, 10) # truncated by precision + >>> n % 1 == 0 + True + >>> n = Float(near_int, 20) + >>> n % 1 == 0 + False + + See Also + ======== + + sympy.ntheory.generate.primerange : Generates all primes in a given range + sympy.functions.combinatorial.numbers.primepi : Return the number of primes less than or equal to n + sympy.ntheory.generate.prime : Return the nth prime + + References + ========== + .. [1] https://en.wikipedia.org/wiki/Strong_pseudoprime + .. [2] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes, + Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417, + https://doi.org/10.1090%2FS0025-5718-1980-0583518-6 + http://mpqs.free.fr/LucasPseudoprimes.pdf + .. [3] https://en.wikipedia.org/wiki/Baillie-PSW_primality_test + """ + n = as_int(n) + + # Step 1, do quick composite testing via trial division. The individual + # modulo tests benchmark faster than one or two primorial igcds for me. + # The point here is just to speedily handle small numbers and many + # composites. Step 2 only requires that n <= 2 get handled here. + if n in [2, 3, 5]: + return True + if n < 2 or (n % 2) == 0 or (n % 3) == 0 or (n % 5) == 0: + return False + if n < 49: + return True + if (n % 7) == 0 or (n % 11) == 0 or (n % 13) == 0 or (n % 17) == 0 or \ + (n % 19) == 0 or (n % 23) == 0 or (n % 29) == 0 or (n % 31) == 0 or \ + (n % 37) == 0 or (n % 41) == 0 or (n % 43) == 0 or (n % 47) == 0: + return False + if n < 2809: + return True + if n < 65077: + # There are only five Euler pseudoprimes with a least prime factor greater than 47 + return pow(2, n >> 1, n) in [1, n - 1] and n not in [8321, 31621, 42799, 49141, 49981] + + # bisection search on the sieve if the sieve is large enough + from sympy.ntheory.generate import sieve as s + if n <= s._list[-1]: + l, u = s.search(n) + return l == u + from sympy.ntheory.factor_ import factor_cache + if (ret := factor_cache.get(n)) is not None: + return ret == n + + # If we have GMPY2, skip straight to step 3 and do a strong BPSW test. + # This should be a bit faster than our step 2, and for large values will + # be a lot faster than our step 3 (C+GMP vs. Python). + if _gmpy is not None: + return is_strong_bpsw_prp(n) + + + # Step 2: deterministic Miller-Rabin testing for numbers < 2^64. See: + # https://miller-rabin.appspot.com/ + # for lists. We have made sure the M-R routine will successfully handle + # bases larger than n, so we can use the minimal set. + # In September 2015 deterministic numbers were extended to over 2^81. + # https://arxiv.org/pdf/1509.00864.pdf + # https://oeis.org/A014233 + if n < 341531: + return mr(n, [9345883071009581737]) + if n < 4296595241: + # Michal Forisek and Jakub Jancina, + # Fast Primality Testing for Integers That Fit into a Machine Word + # https://ceur-ws.org/Vol-1326/020-Forisek.pdf + h = ((n >> 16) ^ n) * 0x45d9f3b + h = ((h >> 16) ^ h) * 0x45d9f3b + h = ((h >> 16) ^ h) & 255 + return mr(n, [_MR_BASES_32[h]]) + if n < 350269456337: + return mr(n, [4230279247111683200, 14694767155120705706, 16641139526367750375]) + if n < 55245642489451: + return mr(n, [2, 141889084524735, 1199124725622454117, 11096072698276303650]) + if n < 7999252175582851: + return mr(n, [2, 4130806001517, 149795463772692060, 186635894390467037, 3967304179347715805]) + if n < 585226005592931977: + return mr(n, [2, 123635709730000, 9233062284813009, 43835965440333360, 761179012939631437, 1263739024124850375]) + if n < 18446744073709551616: + return mr(n, [2, 325, 9375, 28178, 450775, 9780504, 1795265022]) + if n < 318665857834031151167461: + return mr(n, [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]) + if n < 3317044064679887385961981: + return mr(n, [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41]) + + # We could do this instead at any point: + #if n < 18446744073709551616: + # return mr(n, [2]) and is_extra_strong_lucas_prp(n) + + # Here are tests that are safe for MR routines that don't understand + # large bases. + #if n < 9080191: + # return mr(n, [31, 73]) + #if n < 19471033: + # return mr(n, [2, 299417]) + #if n < 38010307: + # return mr(n, [2, 9332593]) + #if n < 316349281: + # return mr(n, [11000544, 31481107]) + #if n < 4759123141: + # return mr(n, [2, 7, 61]) + #if n < 105936894253: + # return mr(n, [2, 1005905886, 1340600841]) + #if n < 31858317218647: + # return mr(n, [2, 642735, 553174392, 3046413974]) + #if n < 3071837692357849: + # return mr(n, [2, 75088, 642735, 203659041, 3613982119]) + #if n < 18446744073709551616: + # return mr(n, [2, 325, 9375, 28178, 450775, 9780504, 1795265022]) + + # Step 3: BPSW. + # + # Time for isprime(10**2000 + 4561), no gmpy or gmpy2 installed + # 44.0s old isprime using 46 bases + # 5.3s strong BPSW + one random base + # 4.3s extra strong BPSW + one random base + # 4.1s strong BPSW + # 3.2s extra strong BPSW + + # Classic BPSW from page 1401 of the paper. See alternate ideas below. + return is_strong_bpsw_prp(n) + + # Using extra strong test, which is somewhat faster + #return mr(n, [2]) and is_extra_strong_lucas_prp(n) + + # Add a random M-R base + #import random + #return mr(n, [2, random.randint(3, n-1)]) and is_strong_lucas_prp(n) + + +def is_gaussian_prime(num): + r"""Test if num is a Gaussian prime number. + + References + ========== + + .. [1] https://oeis.org/wiki/Gaussian_primes + """ + + num = sympify(num) + a, b = num.as_real_imag() + a = as_int(a, strict=False) + b = as_int(b, strict=False) + if a == 0: + b = abs(b) + return isprime(b) and b % 4 == 3 + elif b == 0: + a = abs(a) + return isprime(a) and a % 4 == 3 + return isprime(a**2 + b**2) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/qs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/qs.py new file mode 100644 index 0000000000000000000000000000000000000000..acc9a7b6e0151695538a99a738ef397166497ba5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/qs.py @@ -0,0 +1,451 @@ +from math import exp, log +from sympy.core.random import _randint +from sympy.external.gmpy import bit_scan1, gcd, invert, sqrt as isqrt +from sympy.ntheory.factor_ import _perfect_power +from sympy.ntheory.primetest import isprime +from sympy.ntheory.residue_ntheory import _sqrt_mod_prime_power + + +class SievePolynomial: + def __init__(self, a, b, N): + """This class denotes the sieve polynomial. + Provide methods to compute `(a*x + b)**2 - N` and + `a*x + b` when given `x`. + + Parameters + ========== + + a : parameter of the sieve polynomial + b : parameter of the sieve polynomial + N : number to be factored + + """ + self.a = a + self.b = b + self.a2 = a**2 + self.ab = 2*a*b + self.b2 = b**2 - N + + def eval_u(self, x): + return self.a*x + self.b + + def eval_v(self, x): + return (self.a2*x + self.ab)*x + self.b2 + + +class FactorBaseElem: + """This class stores an element of the `factor_base`. + """ + def __init__(self, prime, tmem_p, log_p): + """ + Initialization of factor_base_elem. + + Parameters + ========== + + prime : prime number of the factor_base + tmem_p : Integer square root of x**2 = n mod prime + log_p : Compute Natural Logarithm of the prime + """ + self.prime = prime + self.tmem_p = tmem_p + self.log_p = log_p + # `soln1` and `soln2` are solutions to + # the equation `(a*x + b)**2 - N = 0 (mod p)`. + self.soln1 = None + self.soln2 = None + self.b_ainv = None + + +def _generate_factor_base(prime_bound, n): + """Generate `factor_base` for Quadratic Sieve. The `factor_base` + consists of all the points whose ``legendre_symbol(n, p) == 1`` + and ``p < num_primes``. Along with the prime `factor_base` also stores + natural logarithm of prime and the residue n modulo p. + It also returns the of primes numbers in the `factor_base` which are + close to 1000 and 5000. + + Parameters + ========== + + prime_bound : upper prime bound of the factor_base + n : integer to be factored + """ + from sympy.ntheory.generate import sieve + factor_base = [] + idx_1000, idx_5000 = None, None + for prime in sieve.primerange(1, prime_bound): + if pow(n, (prime - 1) // 2, prime) == 1: + if prime > 1000 and idx_1000 is None: + idx_1000 = len(factor_base) - 1 + if prime > 5000 and idx_5000 is None: + idx_5000 = len(factor_base) - 1 + residue = _sqrt_mod_prime_power(n, prime, 1)[0] + log_p = round(log(prime)*2**10) + factor_base.append(FactorBaseElem(prime, residue, log_p)) + return idx_1000, idx_5000, factor_base + + +def _generate_polynomial(N, M, factor_base, idx_1000, idx_5000, randint): + """ Generate sieve polynomials indefinitely. + Information such as `soln1` in the `factor_base` associated with + the polynomial is modified in place. + + Parameters + ========== + + N : Number to be factored + M : sieve interval + factor_base : factor_base primes + idx_1000 : index of prime number in the factor_base near 1000 + idx_5000 : index of prime number in the factor_base near to 5000 + randint : A callable that takes two integers (a, b) and returns a random integer + n such that a <= n <= b, similar to `random.randint`. + """ + approx_val = log(2*N)/2 - log(M) + start = idx_1000 or 0 + end = idx_5000 or (len(factor_base) - 1) + while True: + # Choose `a` that is close to `sqrt(2*N) / M` + best_a, best_q, best_ratio = None, None, None + for _ in range(50): + a = 1 + q = [] + while log(a) < approx_val: + rand_p = 0 + while(rand_p == 0 or rand_p in q): + rand_p = randint(start, end) + p = factor_base[rand_p].prime + a *= p + q.append(rand_p) + ratio = exp(log(a) - approx_val) + if best_ratio is None or abs(ratio - 1) < abs(best_ratio - 1): + best_q = q + best_a = a + best_ratio = ratio + + # Set `b` using the Chinese remainder theorem + a = best_a + q = best_q + B = [] + for val in q: + q_l = factor_base[val].prime + gamma = factor_base[val].tmem_p * invert(a // q_l, q_l) % q_l + if 2*gamma > q_l: + gamma = q_l - gamma + B.append(a//q_l*gamma) + b = sum(B) + g = SievePolynomial(a, b, N) + for fb in factor_base: + if a % fb.prime == 0: + fb.soln1 = None + continue + a_inv = invert(a, fb.prime) + fb.b_ainv = [2*b_elem*a_inv % fb.prime for b_elem in B] + fb.soln1 = (a_inv*(fb.tmem_p - b)) % fb.prime + fb.soln2 = (a_inv*(-fb.tmem_p - b)) % fb.prime + yield g + + # Update `b` with Gray code + for i in range(1, 2**(len(B)-1)): + v = bit_scan1(i) + neg_pow = 2*((i >> (v + 1)) % 2) - 1 + b = g.b + 2*neg_pow*B[v] + a = g.a + g = SievePolynomial(a, b, N) + for fb in factor_base: + if fb.soln1 is None: + continue + fb.soln1 = (fb.soln1 - neg_pow*fb.b_ainv[v]) % fb.prime + fb.soln2 = (fb.soln2 - neg_pow*fb.b_ainv[v]) % fb.prime + yield g + + +def _gen_sieve_array(M, factor_base): + """Sieve Stage of the Quadratic Sieve. For every prime in the factor_base + that does not divide the coefficient `a` we add log_p over the sieve_array + such that ``-M <= soln1 + i*p <= M`` and ``-M <= soln2 + i*p <= M`` where `i` + is an integer. When p = 2 then log_p is only added using + ``-M <= soln1 + i*p <= M``. + + Parameters + ========== + + M : sieve interval + factor_base : factor_base primes + """ + sieve_array = [0]*(2*M + 1) + for factor in factor_base: + if factor.soln1 is None: #The prime does not divides a + continue + for idx in range((M + factor.soln1) % factor.prime, 2*M, factor.prime): + sieve_array[idx] += factor.log_p + if factor.prime == 2: + continue + #if prime is 2 then sieve only with soln_1_p + for idx in range((M + factor.soln2) % factor.prime, 2*M, factor.prime): + sieve_array[idx] += factor.log_p + return sieve_array + + +def _check_smoothness(num, factor_base): + r""" Check if `num` is smooth with respect to the given `factor_base` + and compute its factorization vector. + + Parameters + ========== + + num : integer whose smootheness is to be checked + factor_base : factor_base primes + """ + if num < 0: + num *= -1 + vec = 1 + else: + vec = 0 + for i, fb in enumerate(factor_base, 1): + if num % fb.prime: + continue + e = 1 + num //= fb.prime + while num % fb.prime == 0: + e += 1 + num //= fb.prime + if e % 2: + vec += 1 << i + return vec, num + + +def _trial_division_stage(N, M, factor_base, sieve_array, sieve_poly, partial_relations, ERROR_TERM): + """Trial division stage. Here we trial divide the values generetated + by sieve_poly in the sieve interval and if it is a smooth number then + it is stored in `smooth_relations`. Moreover, if we find two partial relations + with same large prime then they are combined to form a smooth relation. + First we iterate over sieve array and look for values which are greater + than accumulated_val, as these values have a high chance of being smooth + number. Then using these values we find smooth relations. + In general, let ``t**2 = u*p modN`` and ``r**2 = v*p modN`` be two partial relations + with the same large prime p. Then they can be combined ``(t*r/p)**2 = u*v modN`` + to form a smooth relation. + + Parameters + ========== + + N : Number to be factored + M : sieve interval + factor_base : factor_base primes + sieve_array : stores log_p values + sieve_poly : polynomial from which we find smooth relations + partial_relations : stores partial relations with one large prime + ERROR_TERM : error term for accumulated_val + """ + accumulated_val = (log(M) + log(N)/2 - ERROR_TERM) * 2**10 + smooth_relations = [] + proper_factor = set() + partial_relation_upper_bound = 128*factor_base[-1].prime + for x, val in enumerate(sieve_array, -M): + if val < accumulated_val: + continue + v = sieve_poly.eval_v(x) + vec, num = _check_smoothness(v, factor_base) + if num == 1: + smooth_relations.append((sieve_poly.eval_u(x), v, vec)) + elif num < partial_relation_upper_bound and isprime(num): + if N % num == 0: + proper_factor.add(num) + continue + u = sieve_poly.eval_u(x) + if num in partial_relations: + u_prev, v_prev, vec_prev = partial_relations.pop(num) + u = u*u_prev*invert(num, N) % N + v = v*v_prev // num**2 + vec ^= vec_prev + smooth_relations.append((u, v, vec)) + else: + partial_relations[num] = (u, v, vec) + return smooth_relations, proper_factor + + +def _find_factor(N, smooth_relations, col): + """ Finds proper factor of N using fast gaussian reduction for modulo 2 matrix. + + Parameters + ========== + + N : Number to be factored + smooth_relations : Smooth relations vectors matrix + col : Number of columns in the matrix + + Reference + ========== + + .. [1] A fast algorithm for gaussian elimination over GF(2) and + its implementation on the GAPP. Cetin K.Koc, Sarath N.Arachchige + """ + matrix = [s_relation[2] for s_relation in smooth_relations] + row = len(matrix) + mark = [False] * row + for pos in range(col): + m = 1 << pos + for i in range(row): + if p := matrix[i] & m: + add_col = p ^ matrix[i] + matrix[i] = m + mark[i] = True + for j in range(i + 1, row): + if matrix[j] & m: + matrix[j] ^= add_col + break + + for m, mat, rel in zip(mark, matrix, smooth_relations): + if m: + continue + u, v = rel[0], rel[1] + for m1, mat1, rel1 in zip(mark, matrix, smooth_relations): + if m1 and mat & mat1: + u *= rel1[0] + v *= rel1[1] + # assert is_square(v) + v = isqrt(v) + if 1 < (g := gcd(u - v, N)) < N: + yield g + + +def qs(N, prime_bound, M, ERROR_TERM=25, seed=1234): + """Performs factorization using Self-Initializing Quadratic Sieve. + In SIQS, let N be a number to be factored, and this N should not be a + perfect power. If we find two integers such that ``X**2 = Y**2 modN`` and + ``X != +-Y modN``, then `gcd(X + Y, N)` will reveal a proper factor of N. + In order to find these integers X and Y we try to find relations of form + t**2 = u modN where u is a product of small primes. If we have enough of + these relations then we can form ``(t1*t2...ti)**2 = u1*u2...ui modN`` such that + the right hand side is a square, thus we found a relation of ``X**2 = Y**2 modN``. + + Here, several optimizations are done like using multiple polynomials for + sieving, fast changing between polynomials and using partial relations. + The use of partial relations can speeds up the factoring by 2 times. + + Parameters + ========== + + N : Number to be Factored + prime_bound : upper bound for primes in the factor base + M : Sieve Interval + ERROR_TERM : Error term for checking smoothness + seed : seed of random number generator + + Returns + ======= + + set(int) : A set of factors of N without considering multiplicity. + Returns ``{N}`` if factorization fails. + + Examples + ======== + + >>> from sympy.ntheory import qs + >>> qs(25645121643901801, 2000, 10000) + {5394769, 4753701529} + >>> qs(9804659461513846513, 2000, 10000) + {4641991, 2112166839943} + + See Also + ======== + + qs_factor + + References + ========== + + .. [1] https://pdfs.semanticscholar.org/5c52/8a975c1405bd35c65993abf5a4edb667c1db.pdf + .. [2] https://www.rieselprime.de/ziki/Self-initializing_quadratic_sieve + """ + return set(qs_factor(N, prime_bound, M, ERROR_TERM, seed)) + + +def qs_factor(N, prime_bound, M, ERROR_TERM=25, seed=1234): + """ Performs factorization using Self-Initializing Quadratic Sieve. + + Parameters + ========== + + N : Number to be Factored + prime_bound : upper bound for primes in the factor base + M : Sieve Interval + ERROR_TERM : Error term for checking smoothness + seed : seed of random number generator + + Returns + ======= + + dict[int, int] : Factors of N. + Returns ``{N: 1}`` if factorization fails. + Note that the key is not always a prime number. + + Examples + ======== + + >>> from sympy.ntheory import qs_factor + >>> qs_factor(1009 * 100003, 2000, 10000) + {1009: 1, 100003: 1} + + See Also + ======== + + qs + + """ + if N < 2: + raise ValueError("N should be greater than 1") + factors = {} + smooth_relations = [] + partial_relations = {} + # Eliminate the possibility of even numbers, + # prime numbers, and perfect powers. + if N % 2 == 0: + e = 1 + N //= 2 + while N % 2 == 0: + N //= 2 + e += 1 + factors[2] = e + if isprime(N): + factors[N] = 1 + return factors + if result := _perfect_power(N, 3): + n, e = result + factors[n] = e + return factors + N_copy = N + randint = _randint(seed) + idx_1000, idx_5000, factor_base = _generate_factor_base(prime_bound, N) + threshold = len(factor_base) * 105//100 + for g in _generate_polynomial(N, M, factor_base, idx_1000, idx_5000, randint): + sieve_array = _gen_sieve_array(M, factor_base) + s_rel, p_f = _trial_division_stage(N, M, factor_base, sieve_array, g, partial_relations, ERROR_TERM) + smooth_relations += s_rel + for p in p_f: + if N_copy % p: + continue + e = 1 + N_copy //= p + while N_copy % p == 0: + N_copy //= p + e += 1 + factors[p] = e + if threshold <= len(smooth_relations): + break + + for factor in _find_factor(N, smooth_relations, len(factor_base) + 1): + if N_copy % factor == 0: + e = 1 + N_copy //= factor + while N_copy % factor == 0: + N_copy //= factor + e += 1 + factors[factor] = e + if N_copy == 1 or isprime(N_copy): + break + if N_copy != 1: + factors[N_copy] = 1 + return factors diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/residue_ntheory.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/residue_ntheory.py new file mode 100644 index 0000000000000000000000000000000000000000..eba024161194605aabebd10ee30bf09acb90270b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/residue_ntheory.py @@ -0,0 +1,1963 @@ +from __future__ import annotations + +from sympy.external.gmpy import (gcd, lcm, invert, sqrt, jacobi, + bit_scan1, remove) +from sympy.polys import Poly +from sympy.polys.domains import ZZ +from sympy.polys.galoistools import gf_crt1, gf_crt2, linear_congruence, gf_csolve +from .primetest import isprime +from .generate import primerange +from .factor_ import factorint, _perfect_power +from .modular import crt +from sympy.utilities.decorator import deprecated +from sympy.utilities.memoization import recurrence_memo +from sympy.utilities.misc import as_int +from sympy.utilities.iterables import iproduct +from sympy.core.random import _randint, randint + +from itertools import product + + +def n_order(a, n): + r""" Returns the order of ``a`` modulo ``n``. + + Explanation + =========== + + The order of ``a`` modulo ``n`` is the smallest integer + ``k`` such that `a^k` leaves a remainder of 1 with ``n``. + + Parameters + ========== + + a : integer + n : integer, n > 1. a and n should be relatively prime + + Returns + ======= + + int : the order of ``a`` modulo ``n`` + + Raises + ====== + + ValueError + If `n \le 1` or `\gcd(a, n) \neq 1`. + If ``a`` or ``n`` is not an integer. + + Examples + ======== + + >>> from sympy.ntheory import n_order + >>> n_order(3, 7) + 6 + >>> n_order(4, 7) + 3 + + See Also + ======== + + is_primitive_root + We say that ``a`` is a primitive root of ``n`` + when the order of ``a`` modulo ``n`` equals ``totient(n)`` + + """ + a, n = as_int(a), as_int(n) + if n <= 1: + raise ValueError("n should be an integer greater than 1") + a = a % n + # Trivial + if a == 1: + return 1 + if gcd(a, n) != 1: + raise ValueError("The two numbers should be relatively prime") + a_order = 1 + for p, e in factorint(n).items(): + pe = p**e + pe_order = (p - 1) * p**(e - 1) + factors = factorint(p - 1) + if e > 1: + factors[p] = e - 1 + order = 1 + for px, ex in factors.items(): + x = pow(a, pe_order // px**ex, pe) + while x != 1: + x = pow(x, px, pe) + order *= px + a_order = lcm(a_order, order) + return int(a_order) + + +def _primitive_root_prime_iter(p): + r""" Generates the primitive roots for a prime ``p``. + + Explanation + =========== + + The primitive roots generated are not necessarily sorted. + However, the first one is the smallest primitive root. + + Find the element whose order is ``p-1`` from the smaller one. + If we can find the first primitive root ``g``, we can use the following theorem. + + .. math :: + \operatorname{ord}(g^k) = \frac{\operatorname{ord}(g)}{\gcd(\operatorname{ord}(g), k)} + + From the assumption that `\operatorname{ord}(g)=p-1`, + it is a necessary and sufficient condition for + `\operatorname{ord}(g^k)=p-1` that `\gcd(p-1, k)=1`. + + Parameters + ========== + + p : odd prime + + Yields + ====== + + int + the primitive roots of ``p`` + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _primitive_root_prime_iter + >>> sorted(_primitive_root_prime_iter(19)) + [2, 3, 10, 13, 14, 15] + + References + ========== + + .. [1] W. Stein "Elementary Number Theory" (2011), page 44 + + """ + if p == 3: + yield 2 + return + # Let p = +-1 (mod 4a). Legendre symbol (a/p) = 1, so `a` is not the primitive root. + # Corollary : If p = +-1 (mod 8), then 2 is not the primitive root of p. + g_min = 3 if p % 8 in [1, 7] else 2 + if p < 41: + # small case + g = 5 if p == 23 else g_min + else: + v = [(p - 1) // i for i in factorint(p - 1).keys()] + for g in range(g_min, p): + if all(pow(g, pw, p) != 1 for pw in v): + break + yield g + # g**k is the primitive root of p iff gcd(p - 1, k) = 1 + for k in range(3, p, 2): + if gcd(p - 1, k) == 1: + yield pow(g, k, p) + + +def _primitive_root_prime_power_iter(p, e): + r""" Generates the primitive roots of `p^e`. + + Explanation + =========== + + Let ``g`` be the primitive root of ``p``. + If `g^{p-1} \not\equiv 1 \pmod{p^2}`, then ``g`` is primitive root of `p^e`. + Thus, if we find a primitive root ``g`` of ``p``, + then `g, g+p, g+2p, \ldots, g+(p-1)p` are primitive roots of `p^2` except one. + That one satisfies `\hat{g}^{p-1} \equiv 1 \pmod{p^2}`. + If ``h`` is the primitive root of `p^2`, + then `h, h+p^2, h+2p^2, \ldots, h+(p^{e-2}-1)p^e` are primitive roots of `p^e`. + + Parameters + ========== + + p : odd prime + e : positive integer + + Yields + ====== + + int + the primitive roots of `p^e` + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _primitive_root_prime_power_iter + >>> sorted(_primitive_root_prime_power_iter(5, 2)) + [2, 3, 8, 12, 13, 17, 22, 23] + + """ + if e == 1: + yield from _primitive_root_prime_iter(p) + else: + p2 = p**2 + for g in _primitive_root_prime_iter(p): + t = (g - pow(g, 2 - p, p2)) % p2 + for k in range(0, p2, p): + if k != t: + yield from (g + k + m for m in range(0, p**e, p2)) + + +def _primitive_root_prime_power2_iter(p, e): + r""" Generates the primitive roots of `2p^e`. + + Explanation + =========== + + If ``g`` is the primitive root of ``p**e``, + then the odd one of ``g`` and ``g+p**e`` is the primitive root of ``2*p**e``. + + Parameters + ========== + + p : odd prime + e : positive integer + + Yields + ====== + + int + the primitive roots of `2p^e` + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _primitive_root_prime_power2_iter + >>> sorted(_primitive_root_prime_power2_iter(5, 2)) + [3, 13, 17, 23, 27, 33, 37, 47] + + """ + for g in _primitive_root_prime_power_iter(p, e): + if g % 2 == 1: + yield g + else: + yield g + p**e + + +def primitive_root(p, smallest=True): + r""" Returns a primitive root of ``p`` or None. + + Explanation + =========== + + For the definition of primitive root, + see the explanation of ``is_primitive_root``. + + The primitive root of ``p`` exist only for + `p = 2, 4, q^e, 2q^e` (``q`` is an odd prime). + Now, if we know the primitive root of ``q``, + we can calculate the primitive root of `q^e`, + and if we know the primitive root of `q^e`, + we can calculate the primitive root of `2q^e`. + When there is no need to find the smallest primitive root, + this property can be used to obtain a fast primitive root. + On the other hand, when we want the smallest primitive root, + we naively determine whether it is a primitive root or not. + + Parameters + ========== + + p : integer, p > 1 + smallest : if True the smallest primitive root is returned or None + + Returns + ======= + + int | None : + If the primitive root exists, return the primitive root of ``p``. + If not, return None. + + Raises + ====== + + ValueError + If `p \le 1` or ``p`` is not an integer. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import primitive_root + >>> primitive_root(19) + 2 + >>> primitive_root(21) is None + True + >>> primitive_root(50, smallest=False) + 27 + + See Also + ======== + + is_primitive_root + + References + ========== + + .. [1] W. Stein "Elementary Number Theory" (2011), page 44 + .. [2] P. Hackman "Elementary Number Theory" (2009), Chapter C + + """ + p = as_int(p) + if p <= 1: + raise ValueError("p should be an integer greater than 1") + if p <= 4: + return p - 1 + p_even = p % 2 == 0 + if not p_even: + q = p # p is odd + elif p % 4: + q = p//2 # p had 1 factor of 2 + else: + return None # p had more than one factor of 2 + if isprime(q): + e = 1 + else: + m = _perfect_power(q, 3) + if not m: + return None + q, e = m + if not isprime(q): + return None + if not smallest: + if p_even: + return next(_primitive_root_prime_power2_iter(q, e)) + return next(_primitive_root_prime_power_iter(q, e)) + if p_even: + for i in range(3, p, 2): + if i % q and is_primitive_root(i, p): + return i + g = next(_primitive_root_prime_iter(q)) + if e == 1 or pow(g, q - 1, q**2) != 1: + return g + for i in range(g + 1, p): + if i % q and is_primitive_root(i, p): + return i + + +def is_primitive_root(a, p): + r""" Returns True if ``a`` is a primitive root of ``p``. + + Explanation + =========== + + ``a`` is said to be the primitive root of ``p`` if `\gcd(a, p) = 1` and + `\phi(p)` is the smallest positive number s.t. + + `a^{\phi(p)} \equiv 1 \pmod{p}`. + + where `\phi(p)` is Euler's totient function. + + The primitive root of ``p`` exist only for + `p = 2, 4, q^e, 2q^e` (``q`` is an odd prime). + Hence, if it is not such a ``p``, it returns False. + To determine the primitive root, we need to know + the prime factorization of ``q-1``. + The hardness of the determination depends on this complexity. + + Parameters + ========== + + a : integer + p : integer, ``p`` > 1. ``a`` and ``p`` should be relatively prime + + Returns + ======= + + bool : If True, ``a`` is the primitive root of ``p``. + + Raises + ====== + + ValueError + If `p \le 1` or `\gcd(a, p) \neq 1`. + If ``a`` or ``p`` is not an integer. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import totient + >>> from sympy.ntheory import is_primitive_root, n_order + >>> is_primitive_root(3, 10) + True + >>> is_primitive_root(9, 10) + False + >>> n_order(3, 10) == totient(10) + True + >>> n_order(9, 10) == totient(10) + False + + See Also + ======== + + primitive_root + + """ + a, p = as_int(a), as_int(p) + if p <= 1: + raise ValueError("p should be an integer greater than 1") + a = a % p + if gcd(a, p) != 1: + raise ValueError("The two numbers should be relatively prime") + # Primitive root of p exist only for + # p = 2, 4, q**e, 2*q**e (q is odd prime) + if p <= 4: + # The primitive root is only p-1. + return a == p - 1 + if p % 2: + q = p # p is odd + elif p % 4: + q = p//2 # p had 1 factor of 2 + else: + return False # p had more than one factor of 2 + if isprime(q): + group_order = q - 1 + factors = factorint(q - 1).keys() + else: + m = _perfect_power(q, 3) + if not m: + return False + q, e = m + if not isprime(q): + return False + group_order = q**(e - 1)*(q - 1) + factors = set(factorint(q - 1).keys()) + factors.add(q) + return all(pow(a, group_order // prime, p) != 1 for prime in factors) + + +def _sqrt_mod_tonelli_shanks(a, p): + """ + Returns the square root in the case of ``p`` prime with ``p == 1 (mod 8)`` + + Assume that the root exists. + + Parameters + ========== + + a : int + p : int + prime number. should be ``p % 8 == 1`` + + Returns + ======= + + int : Generally, there are two roots, but only one is returned. + Which one is returned is random. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _sqrt_mod_tonelli_shanks + >>> _sqrt_mod_tonelli_shanks(2, 17) in [6, 11] + True + + References + ========== + + .. [1] Carl Pomerance, Richard Crandall, Prime Numbers: A Computational Perspective, + 2nd Edition (2005), page 101, ISBN:978-0387252827 + + """ + s = bit_scan1(p - 1) + t = p >> s + # find a non-quadratic residue + if p % 12 == 5: + # Legendre symbol (3/p) == -1 if p % 12 in [5, 7] + d = 3 + elif p % 5 in [2, 3]: + # Legendre symbol (5/p) == -1 if p % 5 in [2, 3] + d = 5 + else: + while 1: + d = randint(6, p - 1) + if jacobi(d, p) == -1: + break + #assert legendre_symbol(d, p) == -1 + A = pow(a, t, p) + D = pow(d, t, p) + m = 0 + for i in range(s): + adm = A*pow(D, m, p) % p + adm = pow(adm, 2**(s - 1 - i), p) + if adm % p == p - 1: + m += 2**i + #assert A*pow(D, m, p) % p == 1 + x = pow(a, (t + 1)//2, p)*pow(D, m//2, p) % p + return x + + +def sqrt_mod(a, p, all_roots=False): + """ + Find a root of ``x**2 = a mod p``. + + Parameters + ========== + + a : integer + p : positive integer + all_roots : if True the list of roots is returned or None + + Notes + ===== + + If there is no root it is returned None; else the returned root + is less or equal to ``p // 2``; in general is not the smallest one. + It is returned ``p // 2`` only if it is the only root. + + Use ``all_roots`` only when it is expected that all the roots fit + in memory; otherwise use ``sqrt_mod_iter``. + + Examples + ======== + + >>> from sympy.ntheory import sqrt_mod + >>> sqrt_mod(11, 43) + 21 + >>> sqrt_mod(17, 32, True) + [7, 9, 23, 25] + """ + if all_roots: + return sorted(sqrt_mod_iter(a, p)) + p = abs(as_int(p)) + halfp = p // 2 + x = None + for r in sqrt_mod_iter(a, p): + if r < halfp: + return r + elif r > halfp: + return p - r + else: + x = r + return x + + +def sqrt_mod_iter(a, p, domain=int): + """ + Iterate over solutions to ``x**2 = a mod p``. + + Parameters + ========== + + a : integer + p : positive integer + domain : integer domain, ``int``, ``ZZ`` or ``Integer`` + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import sqrt_mod_iter + >>> list(sqrt_mod_iter(11, 43)) + [21, 22] + + See Also + ======== + + sqrt_mod : Same functionality, but you want a sorted list or only one solution. + + """ + a, p = as_int(a), abs(as_int(p)) + v = [] + pv = [] + _product = product + for px, ex in factorint(p).items(): + if a % px: + # `len(rx)` is at most 4 + rx = _sqrt_mod_prime_power(a, px, ex) + else: + # `len(list(rx))` can be assumed to be large. + # The `itertools.product` is disadvantageous in terms of memory usage. + # It is also inferior to iproduct in speed if not all Cartesian products are needed. + rx = _sqrt_mod1(a, px, ex) + _product = iproduct + if not rx: + return + v.append(rx) + pv.append(px**ex) + if len(v) == 1: + yield from map(domain, v[0]) + else: + mm, e, s = gf_crt1(pv, ZZ) + for vx in _product(*v): + yield domain(gf_crt2(vx, pv, mm, e, s, ZZ)) + + +def _sqrt_mod_prime_power(a, p, k): + """ + Find the solutions to ``x**2 = a mod p**k`` when ``a % p != 0``. + If no solution exists, return ``None``. + Solutions are returned in an ascending list. + + Parameters + ========== + + a : integer + p : prime number + k : positive integer + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _sqrt_mod_prime_power + >>> _sqrt_mod_prime_power(11, 43, 1) + [21, 22] + + References + ========== + + .. [1] P. Hackman "Elementary Number Theory" (2009), page 160 + .. [2] http://www.numbertheory.org/php/squareroot.html + .. [3] [Gathen99]_ + """ + pk = p**k + a = a % pk + + if p == 2: + # see Ref.[2] + if a % 8 != 1: + return None + # Trivial + if k <= 3: + return list(range(1, pk, 2)) + r = 1 + # r is one of the solutions to x**2 - a = 0 (mod 2**3). + # Hensel lift them to solutions of x**2 - a = 0 (mod 2**k) + # if r**2 - a = 0 mod 2**nx but not mod 2**(nx+1) + # then r + 2**(nx - 1) is a root mod 2**(nx+1) + for nx in range(3, k): + if ((r**2 - a) >> nx) % 2: + r += 1 << (nx - 1) + # r is a solution of x**2 - a = 0 (mod 2**k), and + # there exist other solutions -r, r+h, -(r+h), and these are all solutions. + h = 1 << (k - 1) + return sorted([r, pk - r, (r + h) % pk, -(r + h) % pk]) + + # If the Legendre symbol (a/p) is not 1, no solution exists. + if jacobi(a, p) != 1: + return None + if p % 4 == 3: + res = pow(a, (p + 1) // 4, p) + elif p % 8 == 5: + res = pow(a, (p + 3) // 8, p) + if pow(res, 2, p) != a % p: + res = res * pow(2, (p - 1) // 4, p) % p + else: + res = _sqrt_mod_tonelli_shanks(a, p) + if k > 1: + # Hensel lifting with Newton iteration, see Ref.[3] chapter 9 + # with f(x) = x**2 - a; one has f'(a) != 0 (mod p) for p != 2 + px = p + for _ in range(k.bit_length() - 1): + px = px**2 + frinv = invert(2*res, px) + res = (res - (res**2 - a)*frinv) % px + if k & (k - 1): # If k is not a power of 2 + frinv = invert(2*res, pk) + res = (res - (res**2 - a)*frinv) % pk + return sorted([res, pk - res]) + + +def _sqrt_mod1(a, p, n): + """ + Find solution to ``x**2 == a mod p**n`` when ``a % p == 0``. + If no solution exists, return ``None``. + + Parameters + ========== + + a : integer + p : prime number, p must divide a + n : positive integer + + References + ========== + + .. [1] http://www.numbertheory.org/php/squareroot.html + """ + pn = p**n + a = a % pn + if a == 0: + # case gcd(a, p**k) = p**n + return range(0, pn, p**((n + 1) // 2)) + # case gcd(a, p**k) = p**r, r < n + a, r = remove(a, p) + if r % 2 == 1: + return None + res = _sqrt_mod_prime_power(a, p, n - r) + if res is None: + return None + m = r // 2 + return (x for rx in res for x in range(rx*p**m, pn, p**(n - m))) + + +def is_quad_residue(a, p): + """ + Returns True if ``a`` (mod ``p``) is in the set of squares mod ``p``, + i.e a % p in set([i**2 % p for i in range(p)]). + + Parameters + ========== + + a : integer + p : positive integer + + Returns + ======= + + bool : If True, ``x**2 == a (mod p)`` has solution. + + Raises + ====== + + ValueError + If ``a``, ``p`` is not integer. + If ``p`` is not positive. + + Examples + ======== + + >>> from sympy.ntheory import is_quad_residue + >>> is_quad_residue(21, 100) + True + + Indeed, ``pow(39, 2, 100)`` would be 21. + + >>> is_quad_residue(21, 120) + False + + That is, for any integer ``x``, ``pow(x, 2, 120)`` is not 21. + + If ``p`` is an odd + prime, an iterative method is used to make the determination: + + >>> from sympy.ntheory import is_quad_residue + >>> sorted(set([i**2 % 7 for i in range(7)])) + [0, 1, 2, 4] + >>> [j for j in range(7) if is_quad_residue(j, 7)] + [0, 1, 2, 4] + + See Also + ======== + + legendre_symbol, jacobi_symbol, sqrt_mod + """ + a, p = as_int(a), as_int(p) + if p < 1: + raise ValueError('p must be > 0') + a %= p + if a < 2 or p < 3: + return True + # Since we want to compute the Jacobi symbol, + # we separate p into the odd part and the rest. + t = bit_scan1(p) + if t: + # The existence of a solution to a power of 2 is determined + # using the logic of `p==2` in `_sqrt_mod_prime_power` and `_sqrt_mod1`. + a_ = a % (1 << t) + if a_: + r = bit_scan1(a_) + if r % 2 or (a_ >> r) & 6: + return False + p >>= t + a %= p + if a < 2 or p < 3: + return True + # If Jacobi symbol is -1 or p is prime, can be determined by Jacobi symbol only + j = jacobi(a, p) + if j == -1 or isprime(p): + return j == 1 + # Checks if `x**2 = a (mod p)` has a solution + for px, ex in factorint(p).items(): + if a % px: + if jacobi(a, px) != 1: + return False + else: + a_ = a % px**ex + if a_ == 0: + continue + a_, r = remove(a_, px) + if r % 2 or jacobi(a_, px) != 1: + return False + return True + + +def is_nthpow_residue(a, n, m): + """ + Returns True if ``x**n == a (mod m)`` has solutions. + + References + ========== + + .. [1] P. Hackman "Elementary Number Theory" (2009), page 76 + + """ + a = a % m + a, n, m = as_int(a), as_int(n), as_int(m) + if m <= 0: + raise ValueError('m must be > 0') + if n < 0: + raise ValueError('n must be >= 0') + if n == 0: + if m == 1: + return False + return a == 1 + if a == 0: + return True + if n == 1: + return True + if n == 2: + return is_quad_residue(a, m) + return all(_is_nthpow_residue_bign_prime_power(a, n, p, e) + for p, e in factorint(m).items()) + + +def _is_nthpow_residue_bign_prime_power(a, n, p, k): + r""" + Returns True if `x^n = a \pmod{p^k}` has solutions for `n > 2`. + + Parameters + ========== + + a : positive integer + n : integer, n > 2 + p : prime number + k : positive integer + + """ + while a % p == 0: + a %= pow(p, k) + if not a: + return True + a, mu = remove(a, p) + if mu % n: + return False + k -= mu + if p != 2: + f = p**(k - 1)*(p - 1) # f = totient(p**k) + return pow(a, f // gcd(f, n), pow(p, k)) == 1 + if n & 1: + return True + c = min(bit_scan1(n) + 2, k) + return a % pow(2, c) == 1 + + +def _nthroot_mod1(s, q, p, all_roots): + """ + Root of ``x**q = s mod p``, ``p`` prime and ``q`` divides ``p - 1``. + Assume that the root exists. + + Parameters + ========== + + s : integer + q : integer, n > 2. ``q`` divides ``p - 1``. + p : prime number + all_roots : if False returns the smallest root, else the list of roots + + Returns + ======= + + list[int] | int : + Root of ``x**q = s mod p``. If ``all_roots == True``, + returned ascending list. otherwise, returned an int. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _nthroot_mod1 + >>> _nthroot_mod1(5, 3, 13, False) + 7 + >>> _nthroot_mod1(13, 4, 17, True) + [3, 5, 12, 14] + + References + ========== + + .. [1] A. M. Johnston, A Generalized qth Root Algorithm, + ACM-SIAM Symposium on Discrete Algorithms (1999), pp. 929-930 + + """ + g = next(_primitive_root_prime_iter(p)) + r = s + for qx, ex in factorint(q).items(): + f = (p - 1) // qx**ex + while f % qx == 0: + f //= qx + z = f*invert(-f, qx) + x = (1 + z) // qx + t = discrete_log(p, pow(r, f, p), pow(g, f*qx, p)) + for _ in range(ex): + # assert t == discrete_log(p, pow(r, f, p), pow(g, f*qx, p)) + r = pow(r, x, p)*pow(g, -z*t % (p - 1), p) % p + t //= qx + res = [r] + h = pow(g, (p - 1) // q, p) + #assert pow(h, q, p) == 1 + hx = r + for _ in range(q - 1): + hx = (hx*h) % p + res.append(hx) + if all_roots: + res.sort() + return res + return min(res) + + +def _nthroot_mod_prime_power(a, n, p, k): + """ Root of ``x**n = a mod p**k``. + + Parameters + ========== + + a : integer + n : integer, n > 2 + p : prime number + k : positive integer + + Returns + ======= + + list[int] : + Ascending list of roots of ``x**n = a mod p**k``. + If no solution exists, return ``[]``. + + """ + if not _is_nthpow_residue_bign_prime_power(a, n, p, k): + return [] + a_mod_p = a % p + if a_mod_p == 0: + base_roots = [0] + elif (p - 1) % n == 0: + base_roots = _nthroot_mod1(a_mod_p, n, p, all_roots=True) + else: + # The roots of ``x**n - a = 0 (mod p)`` are roots of + # ``gcd(x**n - a, x**(p - 1) - 1) = 0 (mod p)`` + pa = n + pb = p - 1 + b = 1 + if pa < pb: + a_mod_p, pa, b, pb = b, pb, a_mod_p, pa + # gcd(x**pa - a, x**pb - b) = gcd(x**pb - b, x**pc - c) + # where pc = pa % pb; c = b**-q * a mod p + while pb: + q, pc = divmod(pa, pb) + c = pow(b, -q, p) * a_mod_p % p + pa, pb = pb, pc + a_mod_p, b = b, c + if pa == 1: + base_roots = [a_mod_p] + elif pa == 2: + base_roots = sqrt_mod(a_mod_p, p, all_roots=True) + else: + base_roots = _nthroot_mod1(a_mod_p, pa, p, all_roots=True) + if k == 1: + return base_roots + a %= p**k + tot_roots = set() + for root in base_roots: + diff = pow(root, n - 1, p)*n % p + new_base = p + if diff != 0: + m_inv = invert(diff, p) + for _ in range(k - 1): + new_base *= p + tmp = pow(root, n, new_base) - a + tmp *= m_inv + root = (root - tmp) % new_base + tot_roots.add(root) + else: + roots_in_base = {root} + for _ in range(k - 1): + new_base *= p + new_roots = set() + for k_ in roots_in_base: + if pow(k_, n, new_base) != a % new_base: + continue + while k_ not in new_roots: + new_roots.add(k_) + k_ = (k_ + (new_base // p)) % new_base + roots_in_base = new_roots + tot_roots = tot_roots | roots_in_base + return sorted(tot_roots) + + +def nthroot_mod(a, n, p, all_roots=False): + """ + Find the solutions to ``x**n = a mod p``. + + Parameters + ========== + + a : integer + n : positive integer + p : positive integer + all_roots : if False returns the smallest root, else the list of roots + + Returns + ======= + + list[int] | int | None : + solutions to ``x**n = a mod p``. + The table of the output type is: + + ========== ========== ========== + all_roots has roots Returns + ========== ========== ========== + True Yes list[int] + True No [] + False Yes int + False No None + ========== ========== ========== + + Raises + ====== + + ValueError + If ``a``, ``n`` or ``p`` is not integer. + If ``n`` or ``p`` is not positive. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import nthroot_mod + >>> nthroot_mod(11, 4, 19) + 8 + >>> nthroot_mod(11, 4, 19, True) + [8, 11] + >>> nthroot_mod(68, 3, 109) + 23 + + References + ========== + + .. [1] P. Hackman "Elementary Number Theory" (2009), page 76 + + """ + a = a % p + a, n, p = as_int(a), as_int(n), as_int(p) + + if n < 1: + raise ValueError("n should be positive") + if p < 1: + raise ValueError("p should be positive") + if n == 1: + return [a] if all_roots else a + if n == 2: + return sqrt_mod(a, p, all_roots) + base = [] + prime_power = [] + for q, e in factorint(p).items(): + tot_roots = _nthroot_mod_prime_power(a, n, q, e) + if not tot_roots: + return [] if all_roots else None + prime_power.append(q**e) + base.append(sorted(tot_roots)) + P, E, S = gf_crt1(prime_power, ZZ) + ret = sorted(map(int, {gf_crt2(c, prime_power, P, E, S, ZZ) + for c in product(*base)})) + if all_roots: + return ret + if ret: + return ret[0] + + +def quadratic_residues(p) -> list[int]: + """ + Returns the list of quadratic residues. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import quadratic_residues + >>> quadratic_residues(7) + [0, 1, 2, 4] + """ + p = as_int(p) + r = {pow(i, 2, p) for i in range(p // 2 + 1)} + return sorted(r) + + +@deprecated("""\ +The `sympy.ntheory.residue_ntheory.legendre_symbol` has been moved to `sympy.functions.combinatorial.numbers.legendre_symbol`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def legendre_symbol(a, p): + r""" + Returns the Legendre symbol `(a / p)`. + + .. deprecated:: 1.13 + + The ``legendre_symbol`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.legendre_symbol` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + For an integer ``a`` and an odd prime ``p``, the Legendre symbol is + defined as + + .. math :: + \genfrac(){}{}{a}{p} = \begin{cases} + 0 & \text{if } p \text{ divides } a\\ + 1 & \text{if } a \text{ is a quadratic residue modulo } p\\ + -1 & \text{if } a \text{ is a quadratic nonresidue modulo } p + \end{cases} + + Parameters + ========== + + a : integer + p : odd prime + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import legendre_symbol + >>> [legendre_symbol(i, 7) for i in range(7)] + [0, 1, 1, -1, 1, -1, -1] + >>> sorted(set([i**2 % 7 for i in range(7)])) + [0, 1, 2, 4] + + See Also + ======== + + is_quad_residue, jacobi_symbol + + """ + from sympy.functions.combinatorial.numbers import legendre_symbol as _legendre_symbol + return _legendre_symbol(a, p) + + +@deprecated("""\ +The `sympy.ntheory.residue_ntheory.jacobi_symbol` has been moved to `sympy.functions.combinatorial.numbers.jacobi_symbol`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def jacobi_symbol(m, n): + r""" + Returns the Jacobi symbol `(m / n)`. + + .. deprecated:: 1.13 + + The ``jacobi_symbol`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.jacobi_symbol` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + For any integer ``m`` and any positive odd integer ``n`` the Jacobi symbol + is defined as the product of the Legendre symbols corresponding to the + prime factors of ``n``: + + .. math :: + \genfrac(){}{}{m}{n} = + \genfrac(){}{}{m}{p^{1}}^{\alpha_1} + \genfrac(){}{}{m}{p^{2}}^{\alpha_2} + ... + \genfrac(){}{}{m}{p^{k}}^{\alpha_k} + \text{ where } n = + p_1^{\alpha_1} + p_2^{\alpha_2} + ... + p_k^{\alpha_k} + + Like the Legendre symbol, if the Jacobi symbol `\genfrac(){}{}{m}{n} = -1` + then ``m`` is a quadratic nonresidue modulo ``n``. + + But, unlike the Legendre symbol, if the Jacobi symbol + `\genfrac(){}{}{m}{n} = 1` then ``m`` may or may not be a quadratic residue + modulo ``n``. + + Parameters + ========== + + m : integer + n : odd positive integer + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import jacobi_symbol, legendre_symbol + >>> from sympy import S + >>> jacobi_symbol(45, 77) + -1 + >>> jacobi_symbol(60, 121) + 1 + + The relationship between the ``jacobi_symbol`` and ``legendre_symbol`` can + be demonstrated as follows: + + >>> L = legendre_symbol + >>> S(45).factors() + {3: 2, 5: 1} + >>> jacobi_symbol(7, 45) == L(7, 3)**2 * L(7, 5)**1 + True + + See Also + ======== + + is_quad_residue, legendre_symbol + """ + from sympy.functions.combinatorial.numbers import jacobi_symbol as _jacobi_symbol + return _jacobi_symbol(m, n) + + +@deprecated("""\ +The `sympy.ntheory.residue_ntheory.mobius` has been moved to `sympy.functions.combinatorial.numbers.mobius`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def mobius(n): + """ + Mobius function maps natural number to {-1, 0, 1} + + .. deprecated:: 1.13 + + The ``mobius`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.mobius` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + It is defined as follows: + 1) `1` if `n = 1`. + 2) `0` if `n` has a squared prime factor. + 3) `(-1)^k` if `n` is a square-free positive integer with `k` + number of prime factors. + + It is an important multiplicative function in number theory + and combinatorics. It has applications in mathematical series, + algebraic number theory and also physics (Fermion operator has very + concrete realization with Mobius Function model). + + Parameters + ========== + + n : positive integer + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import mobius + >>> mobius(13*7) + 1 + >>> mobius(1) + 1 + >>> mobius(13*7*5) + -1 + >>> mobius(13**2) + 0 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/M%C3%B6bius_function + .. [2] Thomas Koshy "Elementary Number Theory with Applications" + + """ + from sympy.functions.combinatorial.numbers import mobius as _mobius + return _mobius(n) + + +def _discrete_log_trial_mul(n, a, b, order=None): + """ + Trial multiplication algorithm for computing the discrete logarithm of + ``a`` to the base ``b`` modulo ``n``. + + The algorithm finds the discrete logarithm using exhaustive search. This + naive method is used as fallback algorithm of ``discrete_log`` when the + group order is very small. The value ``n`` must be greater than 1. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_trial_mul + >>> _discrete_log_trial_mul(41, 15, 7) + 3 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + a %= n + b %= n + if order is None: + order = n + x = 1 + for i in range(order): + if x == a: + return i + x = x * b % n + raise ValueError("Log does not exist") + + +def _discrete_log_shanks_steps(n, a, b, order=None): + """ + Baby-step giant-step algorithm for computing the discrete logarithm of + ``a`` to the base ``b`` modulo ``n``. + + The algorithm is a time-memory trade-off of the method of exhaustive + search. It uses `O(sqrt(m))` memory, where `m` is the group order. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_shanks_steps + >>> _discrete_log_shanks_steps(41, 15, 7) + 3 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + a %= n + b %= n + if order is None: + order = n_order(b, n) + m = sqrt(order) + 1 + T = {} + x = 1 + for i in range(m): + T[x] = i + x = x * b % n + z = pow(b, -m, n) + x = a + for i in range(m): + if x in T: + return i * m + T[x] + x = x * z % n + raise ValueError("Log does not exist") + + +def _discrete_log_pollard_rho(n, a, b, order=None, retries=10, rseed=None): + """ + Pollard's Rho algorithm for computing the discrete logarithm of ``a`` to + the base ``b`` modulo ``n``. + + It is a randomized algorithm with the same expected running time as + ``_discrete_log_shanks_steps``, but requires a negligible amount of memory. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_pollard_rho + >>> _discrete_log_pollard_rho(227, 3**7, 3) + 7 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + a %= n + b %= n + + if order is None: + order = n_order(b, n) + randint = _randint(rseed) + + for i in range(retries): + aa = randint(1, order - 1) + ba = randint(1, order - 1) + xa = pow(b, aa, n) * pow(a, ba, n) % n + + c = xa % 3 + if c == 0: + xb = a * xa % n + ab = aa + bb = (ba + 1) % order + elif c == 1: + xb = xa * xa % n + ab = (aa + aa) % order + bb = (ba + ba) % order + else: + xb = b * xa % n + ab = (aa + 1) % order + bb = ba + + for j in range(order): + c = xa % 3 + if c == 0: + xa = a * xa % n + ba = (ba + 1) % order + elif c == 1: + xa = xa * xa % n + aa = (aa + aa) % order + ba = (ba + ba) % order + else: + xa = b * xa % n + aa = (aa + 1) % order + + c = xb % 3 + if c == 0: + xb = a * xb % n + bb = (bb + 1) % order + elif c == 1: + xb = xb * xb % n + ab = (ab + ab) % order + bb = (bb + bb) % order + else: + xb = b * xb % n + ab = (ab + 1) % order + + c = xb % 3 + if c == 0: + xb = a * xb % n + bb = (bb + 1) % order + elif c == 1: + xb = xb * xb % n + ab = (ab + ab) % order + bb = (bb + bb) % order + else: + xb = b * xb % n + ab = (ab + 1) % order + + if xa == xb: + r = (ba - bb) % order + try: + e = invert(r, order) * (ab - aa) % order + if (pow(b, e, n) - a) % n == 0: + return e + except ZeroDivisionError: + pass + break + raise ValueError("Pollard's Rho failed to find logarithm") + + +def _discrete_log_is_smooth(n: int, factorbase: list): + """Try to factor n with respect to a given factorbase. + Upon success a list of exponents with respect to the factorbase is returned. + Otherwise None.""" + factors = [0]*len(factorbase) + for i, p in enumerate(factorbase): + while n % p == 0: # divide by p as many times as possible + factors[i] += 1 + n = n // p + if n != 1: + return None # the number factors if at the end nothing is left + return factors + + +def _discrete_log_index_calculus(n, a, b, order, rseed=None): + """ + Index Calculus algorithm for computing the discrete logarithm of ``a`` to + the base ``b`` modulo ``n``. + + The group order must be given and prime. It is not suitable for small orders + and the algorithm might fail to find a solution in such situations. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_index_calculus + >>> _discrete_log_index_calculus(24570203447, 23859756228, 2, 12285101723) + 4519867240 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + randint = _randint(rseed) + from math import sqrt, exp, log + a %= n + b %= n + # assert isprime(order), "The order of the base must be prime." + # First choose a heuristic the bound B for the factorbase. + # We have added an extra term to the asymptotic value which + # is closer to the theoretical optimum for n up to 2^70. + B = int(exp(0.5 * sqrt( log(n) * log(log(n)) )*( 1 + 1/log(log(n)) ))) + max = 5 * B * B # expected number of tries to find a relation + factorbase = list(primerange(B)) # compute the factorbase + lf = len(factorbase) # length of the factorbase + ordermo = order-1 + abx = a + for x in range(order): + if abx == 1: + return (order - x) % order + relationa = _discrete_log_is_smooth(abx, factorbase) + if relationa: + relationa = [r % order for r in relationa] + [x] + break + abx = abx * b % n # abx = a*pow(b, x, n) % n + + else: + raise ValueError("Index Calculus failed") + + relations = [None] * lf + k = 1 # number of relations found + kk = 0 + while k < 3 * lf and kk < max: # find relations for all primes in our factor base + x = randint(1,ordermo) + relation = _discrete_log_is_smooth(pow(b,x,n), factorbase) + if relation is None: + kk += 1 + continue + k += 1 + kk = 0 + relation += [ x ] + index = lf # determine the index of the first nonzero entry + for i in range(lf): + ri = relation[i] % order + if ri> 0 and relations[i] is not None: # make this entry zero if we can + for j in range(lf+1): + relation[j] = (relation[j] - ri*relations[i][j]) % order + else: + relation[i] = ri + if relation[i] > 0 and index == lf: # is this the index of the first nonzero entry? + index = i + if index == lf or relations[index] is not None: # the relation contains no new information + continue + # the relation contains new information + rinv = pow(relation[index],-1,order) # normalize the first nonzero entry + for j in range(index,lf+1): + relation[j] = rinv * relation[j] % order + relations[index] = relation + for i in range(lf): # subtract the new relation from the one for a + if relationa[i] > 0 and relations[i] is not None: + rbi = relationa[i] + for j in range(lf+1): + relationa[j] = (relationa[j] - rbi*relations[i][j]) % order + if relationa[i] > 0: # the index of the first nonzero entry + break # we do not need to reduce further at this point + else: # all unknowns are gone + #print(f"Success after {k} relations out of {lf}") + x = (order -relationa[lf]) % order + if pow(b,x,n) == a: + return x + raise ValueError("Index Calculus failed") + raise ValueError("Index Calculus failed") + + +def _discrete_log_pohlig_hellman(n, a, b, order=None, order_factors=None): + """ + Pohlig-Hellman algorithm for computing the discrete logarithm of ``a`` to + the base ``b`` modulo ``n``. + + In order to compute the discrete logarithm, the algorithm takes advantage + of the factorization of the group order. It is more efficient when the + group order factors into many small primes. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_pohlig_hellman + >>> _discrete_log_pohlig_hellman(251, 210, 71) + 197 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + from .modular import crt + a %= n + b %= n + + if order is None: + order = n_order(b, n) + if order_factors is None: + order_factors = factorint(order) + l = [0] * len(order_factors) + + for i, (pi, ri) in enumerate(order_factors.items()): + for j in range(ri): + aj = pow(a * pow(b, -l[i], n), order // pi**(j + 1), n) + bj = pow(b, order // pi, n) + cj = discrete_log(n, aj, bj, pi, True) + l[i] += cj * pi**j + + d, _ = crt([pi**ri for pi, ri in order_factors.items()], l) + return d + + +def discrete_log(n, a, b, order=None, prime_order=None): + """ + Compute the discrete logarithm of ``a`` to the base ``b`` modulo ``n``. + + This is a recursive function to reduce the discrete logarithm problem in + cyclic groups of composite order to the problem in cyclic groups of prime + order. + + It employs different algorithms depending on the problem (subgroup order + size, prime order or not): + + * Trial multiplication + * Baby-step giant-step + * Pollard's Rho + * Index Calculus + * Pohlig-Hellman + + Examples + ======== + + >>> from sympy.ntheory import discrete_log + >>> discrete_log(41, 15, 7) + 3 + + References + ========== + + .. [1] https://mathworld.wolfram.com/DiscreteLogarithm.html + .. [2] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + + """ + from math import sqrt, log + n, a, b = as_int(n), as_int(a), as_int(b) + + if n < 1: + raise ValueError("n should be positive") + if n == 1: + return 0 + + if order is None: + # Compute the order and its factoring in one pass + # order = totient(n), factors = factorint(order) + factors = {} + for px, kx in factorint(n).items(): + if kx > 1: + if px in factors: + factors[px] += kx - 1 + else: + factors[px] = kx - 1 + for py, ky in factorint(px - 1).items(): + if py in factors: + factors[py] += ky + else: + factors[py] = ky + order = 1 + for px, kx in factors.items(): + order *= px**kx + # Now the `order` is the order of the group and factors = factorint(order) + # The order of `b` divides the order of the group. + order_factors = {} + for p, e in factors.items(): + i = 0 + for _ in range(e): + if pow(b, order // p, n) == 1: + order //= p + i += 1 + else: + break + if i < e: + order_factors[p] = e - i + + if prime_order is None: + prime_order = isprime(order) + + if order < 1000: + return _discrete_log_trial_mul(n, a, b, order) + elif prime_order: + # Shanks and Pollard rho are O(sqrt(order)) while index calculus is O(exp(2*sqrt(log(n)log(log(n))))) + # we compare the expected running times to determine the algorithm which is expected to be faster + if 4*sqrt(log(n)*log(log(n))) < log(order) - 10: # the number 10 was determined experimental + return _discrete_log_index_calculus(n, a, b, order) + elif order < 1000000000000: + # Shanks seems typically faster, but uses O(sqrt(order)) memory + return _discrete_log_shanks_steps(n, a, b, order) + return _discrete_log_pollard_rho(n, a, b, order) + + return _discrete_log_pohlig_hellman(n, a, b, order, order_factors) + + + +def quadratic_congruence(a, b, c, n): + r""" + Find the solutions to `a x^2 + b x + c \equiv 0 \pmod{n}`. + + Parameters + ========== + + a : int + b : int + c : int + n : int + A positive integer. + + Returns + ======= + + list[int] : + A sorted list of solutions. If no solution exists, ``[]``. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import quadratic_congruence + >>> quadratic_congruence(2, 5, 3, 7) # 2x^2 + 5x + 3 = 0 (mod 7) + [2, 6] + >>> quadratic_congruence(8, 6, 4, 15) # No solution + [] + + See Also + ======== + + polynomial_congruence : Solve the polynomial congruence + + """ + a = as_int(a) + b = as_int(b) + c = as_int(c) + n = as_int(n) + if n <= 1: + raise ValueError("n should be an integer greater than 1") + a %= n + b %= n + c %= n + + if a == 0: + return linear_congruence(b, -c, n) + if n == 2: + # assert a == 1 + roots = [] + if c == 0: + roots.append(0) + if (b + c) % 2: + roots.append(1) + return roots + if gcd(2*a, n) == 1: + inv_a = invert(a, n) + b *= inv_a + c *= inv_a + if b % 2: + b += n + b >>= 1 + return sorted((i - b) % n for i in sqrt_mod_iter(b**2 - c, n)) + res = set() + for i in sqrt_mod_iter(b**2 - 4*a*c, 4*a*n): + q, rem = divmod(i - b, 2*a) + if rem == 0: + res.add(q % n) + + return sorted(res) + + +def _valid_expr(expr): + """ + return coefficients of expr if it is a univariate polynomial + with integer coefficients else raise a ValueError. + """ + + if not expr.is_polynomial(): + raise ValueError("The expression should be a polynomial") + polynomial = Poly(expr) + if not polynomial.is_univariate: + raise ValueError("The expression should be univariate") + if not polynomial.domain == ZZ: + raise ValueError("The expression should should have integer coefficients") + return polynomial.all_coeffs() + + +def polynomial_congruence(expr, m): + """ + Find the solutions to a polynomial congruence equation modulo m. + + Parameters + ========== + + expr : integer coefficient polynomial + m : positive integer + + Examples + ======== + + >>> from sympy.ntheory import polynomial_congruence + >>> from sympy.abc import x + >>> expr = x**6 - 2*x**5 -35 + >>> polynomial_congruence(expr, 6125) + [3257] + + See Also + ======== + + sympy.polys.galoistools.gf_csolve : low level solving routine used by this routine + + """ + coefficients = _valid_expr(expr) + coefficients = [num % m for num in coefficients] + rank = len(coefficients) + if rank == 3: + return quadratic_congruence(*coefficients, m) + if rank == 2: + return quadratic_congruence(0, *coefficients, m) + if coefficients[0] == 1 and 1 + coefficients[-1] == sum(coefficients): + return nthroot_mod(-coefficients[-1], rank - 1, m, True) + return gf_csolve(coefficients, m) + + +def binomial_mod(n, m, k): + """Compute ``binomial(n, m) % k``. + + Explanation + =========== + + Returns ``binomial(n, m) % k`` using a generalization of Lucas' + Theorem for prime powers given by Granville [1]_, in conjunction with + the Chinese Remainder Theorem. The residue for each prime power + is calculated in time O(log^2(n) + q^4*log(n)log(p) + q^4*p*log^3(p)). + + Parameters + ========== + + n : an integer + m : an integer + k : a positive integer + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import binomial_mod + >>> binomial_mod(10, 2, 6) # binomial(10, 2) = 45 + 3 + >>> binomial_mod(17, 9, 10) # binomial(17, 9) = 24310 + 0 + + References + ========== + + .. [1] Binomial coefficients modulo prime powers, Andrew Granville, + Available: https://web.archive.org/web/20170202003812/http://www.dms.umontreal.ca/~andrew/PDF/BinCoeff.pdf + """ + if k < 1: raise ValueError('k is required to be positive') + # We decompose q into a product of prime powers and apply + # the generalization of Lucas' Theorem given by Granville + # to obtain binomial(n, k) mod p^e, and then use the Chinese + # Remainder Theorem to obtain the result mod q + if n < 0 or m < 0 or m > n: return 0 + factorisation = factorint(k) + residues = [_binomial_mod_prime_power(n, m, p, e) for p, e in factorisation.items()] + return crt([p**pw for p, pw in factorisation.items()], residues, check=False)[0] + + +def _binomial_mod_prime_power(n, m, p, q): + """Compute ``binomial(n, m) % p**q`` for a prime ``p``. + + Parameters + ========== + + n : positive integer + m : a nonnegative integer + p : a prime + q : a positive integer (the prime exponent) + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _binomial_mod_prime_power + >>> _binomial_mod_prime_power(10, 2, 3, 2) # binomial(10, 2) = 45 + 0 + >>> _binomial_mod_prime_power(17, 9, 2, 4) # binomial(17, 9) = 24310 + 6 + + References + ========== + + .. [1] Binomial coefficients modulo prime powers, Andrew Granville, + Available: https://web.archive.org/web/20170202003812/http://www.dms.umontreal.ca/~andrew/PDF/BinCoeff.pdf + """ + # Function/variable naming within this function follows Ref.[1] + # n!_p will be used to denote the product of integers <= n not divisible by + # p, with binomial(n, m)_p the same as binomial(n, m), but defined using + # n!_p in place of n! + modulo = pow(p, q) + + def up_factorial(u): + """Compute (u*p)!_p modulo p^q.""" + r = q // 2 + fac = prod = 1 + if r == 1 and p == 2 or 2*r + 1 in (p, p*p): + if q % 2 == 1: r += 1 + modulo, div = pow(p, 2*r), pow(p, 2*r - q) + else: + modulo, div = pow(p, 2*r + 1), pow(p, (2*r + 1) - q) + for j in range(1, r + 1): + for mul in range((j - 1)*p + 1, j*p): # ignore jp itself + fac *= mul + fac %= modulo + bj_ = bj(u, j, r) + prod *= pow(fac, bj_, modulo) + prod %= modulo + if p == 2: + sm = u // 2 + for j in range(1, r + 1): sm += j//2 * bj(u, j, r) + if sm % 2 == 1: prod *= -1 + prod %= modulo//div + return prod % modulo + + def bj(u, j, r): + """Compute the exponent of (j*p)!_p in the calculation of (u*p)!_p.""" + prod = u + for i in range(1, r + 1): + if i != j: prod *= u*u - i*i + for i in range(1, r + 1): + if i != j: prod //= j*j - i*i + return prod // j + + def up_plus_v_binom(u, v): + """Compute binomial(u*p + v, v)_p modulo p^q.""" + prod = 1 + div = invert(factorial(v), modulo) + for j in range(1, q): + b = div + for v_ in range(j*p + 1, j*p + v + 1): + b *= v_ + b %= modulo + aj = u + for i in range(1, q): + if i != j: aj *= u - i + for i in range(1, q): + if i != j: aj //= j - i + aj //= j + prod *= pow(b, aj, modulo) + prod %= modulo + return prod + + @recurrence_memo([1]) + def factorial(v, prev): + """Compute v! modulo p^q.""" + return v*prev[-1] % modulo + + def factorial_p(n): + """Compute n!_p modulo p^q.""" + u, v = divmod(n, p) + return (factorial(v) * up_factorial(u) * up_plus_v_binom(u, v)) % modulo + + prod = 1 + Nj, Mj, Rj = n, m, n - m + # e0 will be the p-adic valuation of binomial(n, m) at p + e0 = carry = eq_1 = j = 0 + while Nj: + numerator = factorial_p(Nj % modulo) + denominator = factorial_p(Mj % modulo) * factorial_p(Rj % modulo) % modulo + Nj, (Mj, mj), (Rj, rj) = Nj//p, divmod(Mj, p), divmod(Rj, p) + carry = (mj + rj + carry) // p + e0 += carry + if j >= q - 1: eq_1 += carry + prod *= numerator * invert(denominator, modulo) + prod %= modulo + j += 1 + + mul = pow(1 if p == 2 and q >= 3 else -1, eq_1, modulo) + return (pow(p, e0, modulo) * mul * prod) % modulo diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0b370dc5b3ca5feb06ffc04d1e69867cac1b2bf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_bbp_pi.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_bbp_pi.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb2353676162524b73ad2ac6bf15518fb2c594c3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_bbp_pi.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_continued_fraction.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_continued_fraction.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e3705d49dc760d7f308b0c86753adf3e143ef59 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_continued_fraction.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_digits.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_digits.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80ceceebf0adce615763722d310c9327398a3fed Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_digits.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_ecm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_ecm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d97d0ef5532b797570fac4cc54097bdff6011adc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_ecm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_egyptian_fraction.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_egyptian_fraction.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3e85e18c2acec9c2bdc1bb61f21182af246b470 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_egyptian_fraction.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_elliptic_curve.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_elliptic_curve.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03bd2f248b2a259c9931311f0b756666e8641c9b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_elliptic_curve.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_factor_.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_factor_.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ee9cf04187030f7558cdbc25ecc851320b84bca Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_factor_.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_generate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_generate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b37de078f4a0c59ebd76ac96549b586e9f3da68 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_generate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_hypothesis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_hypothesis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ef0af7f6797750431b92e78c6acc9d102ef0cad Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_hypothesis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_modular.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_modular.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b1ecea0a548070c7e87966ef23fd8736d577227 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_modular.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_multinomial.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_multinomial.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d1b727df390f6de0647d58f93939a522003cb7e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_multinomial.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_partitions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_partitions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e16abf821fd7a0eb849e242e5b5bb0671c9303d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_partitions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_primetest.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_primetest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df133704669bda96098c195938d04d7212a48996 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_primetest.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_qs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_qs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..665e19a86715a8ca25fa9470710238a8a6008381 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_qs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_residue.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_residue.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7120d701f057534a1d262d78c059c01db81bd0ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/__pycache__/test_residue.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_bbp_pi.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_bbp_pi.py new file mode 100644 index 0000000000000000000000000000000000000000..69c24970239cc45eef4140bf19dfd7d4f6a7e150 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_bbp_pi.py @@ -0,0 +1,134 @@ +from sympy.core.random import randint + +from sympy.ntheory.bbp_pi import pi_hex_digits +from sympy.testing.pytest import raises + + +# http://www.herongyang.com/Cryptography/Blowfish-First-8366-Hex-Digits-of-PI.html +# There are actually 8336 listed there; with the prepended 3 there are 8337 +# below +dig=''.join(''' +3243f6a8885a308d313198a2e03707344a4093822299f31d0082efa98ec4e6c89452821e638d013 +77be5466cf34e90c6cc0ac29b7c97c50dd3f84d5b5b54709179216d5d98979fb1bd1310ba698dfb5 +ac2ffd72dbd01adfb7b8e1afed6a267e96ba7c9045f12c7f9924a19947b3916cf70801f2e2858efc +16636920d871574e69a458fea3f4933d7e0d95748f728eb658718bcd5882154aee7b54a41dc25a59 +b59c30d5392af26013c5d1b023286085f0ca417918b8db38ef8e79dcb0603a180e6c9e0e8bb01e8a +3ed71577c1bd314b2778af2fda55605c60e65525f3aa55ab945748986263e8144055ca396a2aab10 +b6b4cc5c341141e8cea15486af7c72e993b3ee1411636fbc2a2ba9c55d741831f6ce5c3e169b8793 +1eafd6ba336c24cf5c7a325381289586773b8f48986b4bb9afc4bfe81b6628219361d809ccfb21a9 +91487cac605dec8032ef845d5de98575b1dc262302eb651b8823893e81d396acc50f6d6ff383f442 +392e0b4482a484200469c8f04a9e1f9b5e21c66842f6e96c9a670c9c61abd388f06a51a0d2d8542f +68960fa728ab5133a36eef0b6c137a3be4ba3bf0507efb2a98a1f1651d39af017666ca593e82430e +888cee8619456f9fb47d84a5c33b8b5ebee06f75d885c12073401a449f56c16aa64ed3aa62363f77 +061bfedf72429b023d37d0d724d00a1248db0fead349f1c09b075372c980991b7b25d479d8f6e8de +f7e3fe501ab6794c3b976ce0bd04c006bac1a94fb6409f60c45e5c9ec2196a246368fb6faf3e6c53 +b51339b2eb3b52ec6f6dfc511f9b30952ccc814544af5ebd09bee3d004de334afd660f2807192e4b +b3c0cba85745c8740fd20b5f39b9d3fbdb5579c0bd1a60320ad6a100c6402c7279679f25fefb1fa3 +cc8ea5e9f8db3222f83c7516dffd616b152f501ec8ad0552ab323db5fafd23876053317b483e00df +829e5c57bbca6f8ca01a87562edf1769dbd542a8f6287effc3ac6732c68c4f5573695b27b0bbca58 +c8e1ffa35db8f011a010fa3d98fd2183b84afcb56c2dd1d35b9a53e479b6f84565d28e49bc4bfb97 +90e1ddf2daa4cb7e3362fb1341cee4c6e8ef20cada36774c01d07e9efe2bf11fb495dbda4dae9091 +98eaad8e716b93d5a0d08ed1d0afc725e08e3c5b2f8e7594b78ff6e2fbf2122b648888b812900df0 +1c4fad5ea0688fc31cd1cff191b3a8c1ad2f2f2218be0e1777ea752dfe8b021fa1e5a0cc0fb56f74 +e818acf3d6ce89e299b4a84fe0fd13e0b77cc43b81d2ada8d9165fa2668095770593cc7314211a14 +77e6ad206577b5fa86c75442f5fb9d35cfebcdaf0c7b3e89a0d6411bd3ae1e7e4900250e2d2071b3 +5e226800bb57b8e0af2464369bf009b91e5563911d59dfa6aa78c14389d95a537f207d5ba202e5b9 +c5832603766295cfa911c819684e734a41b3472dca7b14a94a1b5100529a532915d60f573fbc9bc6 +e42b60a47681e6740008ba6fb5571be91ff296ec6b2a0dd915b6636521e7b9f9b6ff34052ec58556 +6453b02d5da99f8fa108ba47996e85076a4b7a70e9b5b32944db75092ec4192623ad6ea6b049a7df +7d9cee60b88fedb266ecaa8c71699a17ff5664526cc2b19ee1193602a575094c29a0591340e4183a +3e3f54989a5b429d656b8fe4d699f73fd6a1d29c07efe830f54d2d38e6f0255dc14cdd20868470eb +266382e9c6021ecc5e09686b3f3ebaefc93c9718146b6a70a1687f358452a0e286b79c5305aa5007 +373e07841c7fdeae5c8e7d44ec5716f2b8b03ada37f0500c0df01c1f040200b3ffae0cf51a3cb574 +b225837a58dc0921bdd19113f97ca92ff69432477322f547013ae5e58137c2dadcc8b576349af3dd +a7a94461460fd0030eecc8c73ea4751e41e238cd993bea0e2f3280bba1183eb3314e548b384f6db9 +086f420d03f60a04bf2cb8129024977c795679b072bcaf89afde9a771fd9930810b38bae12dccf3f +2e5512721f2e6b7124501adde69f84cd877a5847187408da17bc9f9abce94b7d8cec7aec3adb851d +fa63094366c464c3d2ef1c18473215d908dd433b3724c2ba1612a14d432a65c45150940002133ae4 +dd71dff89e10314e5581ac77d65f11199b043556f1d7a3c76b3c11183b5924a509f28fe6ed97f1fb +fa9ebabf2c1e153c6e86e34570eae96fb1860e5e0a5a3e2ab3771fe71c4e3d06fa2965dcb999e71d +0f803e89d65266c8252e4cc9789c10b36ac6150eba94e2ea78a5fc3c531e0a2df4f2f74ea7361d2b +3d1939260f19c279605223a708f71312b6ebadfe6eeac31f66e3bc4595a67bc883b17f37d1018cff +28c332ddefbe6c5aa56558218568ab9802eecea50fdb2f953b2aef7dad5b6e2f841521b628290761 +70ecdd4775619f151013cca830eb61bd960334fe1eaa0363cfb5735c904c70a239d59e9e0bcbaade +14eecc86bc60622ca79cab5cabb2f3846e648b1eaf19bdf0caa02369b9655abb5040685a323c2ab4 +b3319ee9d5c021b8f79b540b19875fa09995f7997e623d7da8f837889a97e32d7711ed935f166812 +810e358829c7e61fd696dedfa17858ba9957f584a51b2272639b83c3ff1ac24696cdb30aeb532e30 +548fd948e46dbc312858ebf2ef34c6ffeafe28ed61ee7c3c735d4a14d9e864b7e342105d14203e13 +e045eee2b6a3aaabeadb6c4f15facb4fd0c742f442ef6abbb5654f3b1d41cd2105d81e799e86854d +c7e44b476a3d816250cf62a1f25b8d2646fc8883a0c1c7b6a37f1524c369cb749247848a0b5692b2 +85095bbf00ad19489d1462b17423820e0058428d2a0c55f5ea1dadf43e233f70613372f0928d937e +41d65fecf16c223bdb7cde3759cbee74604085f2a7ce77326ea607808419f8509ee8efd85561d997 +35a969a7aac50c06c25a04abfc800bcadc9e447a2ec3453484fdd567050e1e9ec9db73dbd3105588 +cd675fda79e3674340c5c43465713e38d83d28f89ef16dff20153e21e78fb03d4ae6e39f2bdb83ad +f7e93d5a68948140f7f64c261c94692934411520f77602d4f7bcf46b2ed4a20068d40824713320f4 +6a43b7d4b7500061af1e39f62e9724454614214f74bf8b88404d95fc1d96b591af70f4ddd366a02f +45bfbc09ec03bd97857fac6dd031cb850496eb27b355fd3941da2547e6abca0a9a28507825530429 +f40a2c86dae9b66dfb68dc1462d7486900680ec0a427a18dee4f3ffea2e887ad8cb58ce0067af4d6 +b6aace1e7cd3375fecce78a399406b2a4220fe9e35d9f385b9ee39d7ab3b124e8b1dc9faf74b6d18 +5626a36631eae397b23a6efa74dd5b43326841e7f7ca7820fbfb0af54ed8feb397454056acba4895 +2755533a3a20838d87fe6ba9b7d096954b55a867bca1159a58cca9296399e1db33a62a4a563f3125 +f95ef47e1c9029317cfdf8e80204272f7080bb155c05282ce395c11548e4c66d2248c1133fc70f86 +dc07f9c9ee41041f0f404779a45d886e17325f51ebd59bc0d1f2bcc18f41113564257b7834602a9c +60dff8e8a31f636c1b0e12b4c202e1329eaf664fd1cad181156b2395e0333e92e13b240b62eebeb9 +2285b2a20ee6ba0d99de720c8c2da2f728d012784595b794fd647d0862e7ccf5f05449a36f877d48 +fac39dfd27f33e8d1e0a476341992eff743a6f6eabf4f8fd37a812dc60a1ebddf8991be14cdb6e6b +0dc67b55106d672c372765d43bdcd0e804f1290dc7cc00ffa3b5390f92690fed0b667b9ffbcedb7d +9ca091cf0bd9155ea3bb132f88515bad247b9479bf763bd6eb37392eb3cc1159798026e297f42e31 +2d6842ada7c66a2b3b12754ccc782ef11c6a124237b79251e706a1bbe64bfb63501a6b101811caed +fa3d25bdd8e2e1c3c9444216590a121386d90cec6ed5abea2a64af674eda86a85fbebfe98864e4c3 +fe9dbc8057f0f7c08660787bf86003604dd1fd8346f6381fb07745ae04d736fccc83426b33f01eab +71b08041873c005e5f77a057bebde8ae2455464299bf582e614e58f48ff2ddfda2f474ef388789bd +c25366f9c3c8b38e74b475f25546fcd9b97aeb26618b1ddf84846a0e79915f95e2466e598e20b457 +708cd55591c902de4cb90bace1bb8205d011a862487574a99eb77f19b6e0a9dc09662d09a1c43246 +33e85a1f0209f0be8c4a99a0251d6efe101ab93d1d0ba5a4dfa186f20f2868f169dcb7da83573906 +fea1e2ce9b4fcd7f5250115e01a70683faa002b5c40de6d0279af88c27773f8641c3604c0661a806 +b5f0177a28c0f586e0006058aa30dc7d6211e69ed72338ea6353c2dd94c2c21634bbcbee5690bcb6 +deebfc7da1ce591d766f05e4094b7c018839720a3d7c927c2486e3725f724d9db91ac15bb4d39eb8 +fced54557808fca5b5d83d7cd34dad0fc41e50ef5eb161e6f8a28514d96c51133c6fd5c7e756e14e +c4362abfceddc6c837d79a323492638212670efa8e406000e03a39ce37d3faf5cfabc277375ac52d +1b5cb0679e4fa33742d382274099bc9bbed5118e9dbf0f7315d62d1c7ec700c47bb78c1b6b21a190 +45b26eb1be6a366eb45748ab2fbc946e79c6a376d26549c2c8530ff8ee468dde7dd5730a1d4cd04d +c62939bbdba9ba4650ac9526e8be5ee304a1fad5f06a2d519a63ef8ce29a86ee22c089c2b843242e +f6a51e03aa9cf2d0a483c061ba9be96a4d8fe51550ba645bd62826a2f9a73a3ae14ba99586ef5562 +e9c72fefd3f752f7da3f046f6977fa0a5980e4a91587b086019b09e6ad3b3ee593e990fd5a9e34d7 +972cf0b7d9022b8b5196d5ac3a017da67dd1cf3ed67c7d2d281f9f25cfadf2b89b5ad6b4725a88f5 +4ce029ac71e019a5e647b0acfded93fa9be8d3c48d283b57ccf8d5662979132e28785f0191ed7560 +55f7960e44e3d35e8c15056dd488f46dba03a161250564f0bdc3eb9e153c9057a297271aeca93a07 +2a1b3f6d9b1e6321f5f59c66fb26dcf3197533d928b155fdf5035634828aba3cbb28517711c20ad9 +f8abcc5167ccad925f4de817513830dc8e379d58629320f991ea7a90c2fb3e7bce5121ce64774fbe +32a8b6e37ec3293d4648de53696413e680a2ae0810dd6db22469852dfd09072166b39a460a6445c0 +dd586cdecf1c20c8ae5bbef7dd1b588d40ccd2017f6bb4e3bbdda26a7e3a59ff453e350a44bcb4cd +d572eacea8fa6484bb8d6612aebf3c6f47d29be463542f5d9eaec2771bf64e6370740e0d8de75b13 +57f8721671af537d5d4040cb084eb4e2cc34d2466a0115af84e1b0042895983a1d06b89fb4ce6ea0 +486f3f3b823520ab82011a1d4b277227f8611560b1e7933fdcbb3a792b344525bda08839e151ce79 +4b2f32c9b7a01fbac9e01cc87ebcc7d1f6cf0111c3a1e8aac71a908749d44fbd9ad0dadecbd50ada +380339c32ac69136678df9317ce0b12b4ff79e59b743f5bb3af2d519ff27d9459cbf97222c15e6fc +2a0f91fc719b941525fae59361ceb69cebc2a8645912baa8d1b6c1075ee3056a0c10d25065cb03a4 +42e0ec6e0e1698db3b4c98a0be3278e9649f1f9532e0d392dfd3a0342b8971f21e1b0a74414ba334 +8cc5be7120c37632d8df359f8d9b992f2ee60b6f470fe3f11de54cda541edad891ce6279cfcd3e7e +6f1618b166fd2c1d05848fd2c5f6fb2299f523f357a632762393a8353156cccd02acf081625a75eb +b56e16369788d273ccde96629281b949d04c50901b71c65614e6c6c7bd327a140a45e1d006c3f27b +9ac9aa53fd62a80f00bb25bfe235bdd2f671126905b2040222b6cbcf7ccd769c2b53113ec01640e3 +d338abbd602547adf0ba38209cf746ce7677afa1c52075606085cbfe4e8ae88dd87aaaf9b04cf9aa +7e1948c25c02fb8a8c01c36ae4d6ebe1f990d4f869a65cdea03f09252dc208e69fb74e6132ce77e2 +5b578fdfe33ac372e6'''.split()) + + +def test_hex_pi_nth_digits(): + assert pi_hex_digits(0) == '3243f6a8885a30' + assert pi_hex_digits(1) == '243f6a8885a308' + assert pi_hex_digits(10000) == '68ac8fcfb8016c' + assert pi_hex_digits(13) == '08d313198a2e03' + assert pi_hex_digits(0, 3) == '324' + assert pi_hex_digits(0, 0) == '' + raises(ValueError, lambda: pi_hex_digits(-1)) + raises(ValueError, lambda: pi_hex_digits(0, -1)) + raises(ValueError, lambda: pi_hex_digits(3.14)) + + # this will pick a random segment to compute every time + # it is run. If it ever fails, there is an error in the + # computation. + n = randint(0, len(dig)) + prec = randint(0, len(dig) - n) + assert pi_hex_digits(n, prec) == dig[n: n + prec] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_continued_fraction.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_continued_fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca6088507f1d112e9146cd5249b1143f375c2cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_continued_fraction.py @@ -0,0 +1,77 @@ +import itertools +from sympy.core import GoldenRatio as phi +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.ntheory.continued_fraction import \ + (continued_fraction_periodic as cf_p, + continued_fraction_iterator as cf_i, + continued_fraction_convergents as cf_c, + continued_fraction_reduce as cf_r, + continued_fraction as cf) +from sympy.testing.pytest import raises + + +def test_continued_fraction(): + assert cf_p(1, 1, 10, 0) == cf_p(1, 1, 0, 1) + assert cf_p(1, -1, 10, 1) == cf_p(-1, 1, 10, -1) + t = sqrt(2) + assert cf((1 + t)*(1 - t)) == cf(-1) + for n in [0, 2, Rational(2, 3), sqrt(2), 3*sqrt(2), 1 + 2*sqrt(3)/5, + (2 - 3*sqrt(5))/7, 1 + sqrt(2), (-5 + sqrt(17))/4]: + assert (cf_r(cf(n)) - n).expand() == 0 + assert (cf_r(cf(-n)) + n).expand() == 0 + raises(ValueError, lambda: cf(sqrt(2 + sqrt(3)))) + raises(ValueError, lambda: cf(sqrt(2) + sqrt(3))) + raises(ValueError, lambda: cf(pi)) + raises(ValueError, lambda: cf(.1)) + + raises(ValueError, lambda: cf_p(1, 0, 0)) + raises(ValueError, lambda: cf_p(1, 1, -1)) + assert cf_p(4, 3, 0) == [1, 3] + assert cf_p(0, 3, 5) == [0, 1, [2, 1, 12, 1, 2, 2]] + assert cf_p(1, 1, 0) == [1] + assert cf_p(3, 4, 0) == [0, 1, 3] + assert cf_p(4, 5, 0) == [0, 1, 4] + assert cf_p(5, 6, 0) == [0, 1, 5] + assert cf_p(11, 13, 0) == [0, 1, 5, 2] + assert cf_p(16, 19, 0) == [0, 1, 5, 3] + assert cf_p(27, 32, 0) == [0, 1, 5, 2, 2] + assert cf_p(1, 2, 5) == [[1]] + assert cf_p(0, 1, 2) == [1, [2]] + assert cf_p(6, 7, 49) == [1, 1, 6] + assert cf_p(3796, 1387, 0) == [2, 1, 2, 1, 4] + assert cf_p(3245, 10000) == [0, 3, 12, 4, 13] + assert cf_p(1932, 2568) == [0, 1, 3, 26, 2] + assert cf_p(6589, 2569) == [2, 1, 1, 3, 2, 1, 3, 1, 23] + + def take(iterator, n=7): + return list(itertools.islice(iterator, n)) + + assert take(cf_i(phi)) == [1, 1, 1, 1, 1, 1, 1] + assert take(cf_i(pi)) == [3, 7, 15, 1, 292, 1, 1] + + assert list(cf_i(Rational(17, 12))) == [1, 2, 2, 2] + assert list(cf_i(Rational(-17, 12))) == [-2, 1, 1, 2, 2] + + assert list(cf_c([1, 6, 1, 8])) == [S.One, Rational(7, 6), Rational(8, 7), Rational(71, 62)] + assert list(cf_c([2])) == [S(2)] + assert list(cf_c([1, 1, 1, 1, 1, 1, 1])) == [S.One, S(2), Rational(3, 2), Rational(5, 3), + Rational(8, 5), Rational(13, 8), Rational(21, 13)] + assert list(cf_c([1, 6, Rational(-1, 2), 4])) == [S.One, Rational(7, 6), Rational(5, 4), Rational(3, 2)] + assert take(cf_c([[1]])) == [S.One, S(2), Rational(3, 2), Rational(5, 3), Rational(8, 5), + Rational(13, 8), Rational(21, 13)] + assert take(cf_c([1, [1, 2]])) == [S.One, S(2), Rational(5, 3), Rational(7, 4), Rational(19, 11), + Rational(26, 15), Rational(71, 41)] + + cf_iter_e = (2 if i == 1 else i // 3 * 2 if i % 3 == 0 else 1 for i in itertools.count(1)) + assert take(cf_c(cf_iter_e)) == [S(2), S(3), Rational(8, 3), Rational(11, 4), Rational(19, 7), + Rational(87, 32), Rational(106, 39)] + + assert cf_r([1, 6, 1, 8]) == Rational(71, 62) + assert cf_r([3]) == S(3) + assert cf_r([-1, 5, 1, 4]) == Rational(-24, 29) + assert (cf_r([0, 1, 1, 7, [24, 8]]) - (sqrt(3) + 2)/7).expand() == 0 + assert cf_r([1, 5, 9]) == Rational(55, 46) + assert (cf_r([[1]]) - (sqrt(5) + 1)/2).expand() == 0 + assert cf_r([-3, 1, 1, [2]]) == -1 - sqrt(2) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_digits.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_digits.py new file mode 100644 index 0000000000000000000000000000000000000000..4284805f4ffe5b9095eacb2e83f2cd8076db3ee4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_digits.py @@ -0,0 +1,55 @@ +from sympy.ntheory import count_digits, digits, is_palindromic +from sympy.core.intfunc import num_digits + +from sympy.testing.pytest import raises + + +def test_num_digits(): + # depending on whether one rounds up or down or uses log or log10, + # one or more of these will fail if you don't check for the off-by + # one condition + assert num_digits(2, 2) == 2 + assert num_digits(2**48 - 1, 2) == 48 + assert num_digits(1000, 10) == 4 + assert num_digits(125, 5) == 4 + assert num_digits(100, 16) == 2 + assert num_digits(-1000, 10) == 4 + # if changes are made to the function, this structured test over + # this range will expose problems + for base in range(2, 100): + for e in range(1, 100): + n = base**e + assert num_digits(n, base) == e + 1 + assert num_digits(n + 1, base) == e + 1 + assert num_digits(n - 1, base) == e + + +def test_digits(): + assert all(digits(n, 2)[1:] == [int(d) for d in format(n, 'b')] + for n in range(20)) + assert all(digits(n, 8)[1:] == [int(d) for d in format(n, 'o')] + for n in range(20)) + assert all(digits(n, 16)[1:] == [int(d, 16) for d in format(n, 'x')] + for n in range(20)) + assert digits(2345, 34) == [34, 2, 0, 33] + assert digits(384753, 71) == [71, 1, 5, 23, 4] + assert digits(93409, 10) == [10, 9, 3, 4, 0, 9] + assert digits(-92838, 11) == [-11, 6, 3, 8, 2, 9] + assert digits(35, 10) == [10, 3, 5] + assert digits(35, 10, 3) == [10, 0, 3, 5] + assert digits(-35, 10, 4) == [-10, 0, 0, 3, 5] + raises(ValueError, lambda: digits(2, 2, 1)) + + +def test_count_digits(): + assert count_digits(55, 2) == {1: 5, 0: 1} + assert count_digits(55, 10) == {5: 2} + n = count_digits(123) + assert n[4] == 0 and type(n[4]) is int + + +def test_is_palindromic(): + assert is_palindromic(-11) + assert is_palindromic(11) + assert is_palindromic(0o121, 8) + assert not is_palindromic(123) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_ecm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_ecm.py new file mode 100644 index 0000000000000000000000000000000000000000..7f134e4e1cf68231e9f89242d2b8476b9edeabb8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_ecm.py @@ -0,0 +1,63 @@ +from sympy.external.gmpy import invert +from sympy.ntheory.ecm import ecm, Point +from sympy.testing.pytest import slow + +@slow +def test_ecm(): + assert ecm(3146531246531241245132451321) == {3, 100327907731, 10454157497791297} + assert ecm(46167045131415113) == {43, 2634823, 407485517} + assert ecm(631211032315670776841) == {9312934919, 67777885039} + assert ecm(398883434337287) == {99476569, 4009823} + assert ecm(64211816600515193) == {281719, 359641, 633767} + assert ecm(4269021180054189416198169786894227) == {184039, 241603, 333331, 477973, 618619, 974123} + assert ecm(4516511326451341281684513) == {3, 39869, 131743543, 95542348571} + assert ecm(4132846513818654136451) == {47, 160343, 2802377, 195692803} + assert ecm(168541512131094651323) == {79, 113, 11011069, 1714635721} + #This takes ~10secs while factorint is not able to factorize this even in ~10mins + assert ecm(7060005655815754299976961394452809, B1=100000, B2=1000000) == {6988699669998001, 1010203040506070809} + + +def test_Point(): + #The curve is of the form y**2 = x**3 + a*x**2 + x + mod = 101 + a = 10 + a_24 = (a + 2)*invert(4, mod) + p1 = Point(10, 17, a_24, mod) + p2 = p1.double() + assert p2 == Point(68, 56, a_24, mod) + p4 = p2.double() + assert p4 == Point(22, 64, a_24, mod) + p8 = p4.double() + assert p8 == Point(71, 95, a_24, mod) + p16 = p8.double() + assert p16 == Point(5, 16, a_24, mod) + p32 = p16.double() + assert p32 == Point(33, 96, a_24, mod) + + # p3 = p2 + p1 + p3 = p2.add(p1, p1) + assert p3 == Point(1, 61, a_24, mod) + # p5 = p3 + p2 or p4 + p1 + p5 = p3.add(p2, p1) + assert p5 == Point(49, 90, a_24, mod) + assert p5 == p4.add(p1, p3) + # p6 = 2*p3 + p6 = p3.double() + assert p6 == Point(87, 43, a_24, mod) + assert p6 == p4.add(p2, p2) + # p7 = p5 + p2 + p7 = p5.add(p2, p3) + assert p7 == Point(69, 23, a_24, mod) + assert p7 == p4.add(p3, p1) + assert p7 == p6.add(p1, p5) + # p9 = p5 + p4 + p9 = p5.add(p4, p1) + assert p9 == Point(56, 99, a_24, mod) + assert p9 == p6.add(p3, p3) + assert p9 == p7.add(p2, p5) + assert p9 == p8.add(p1, p7) + + assert p5 == p1.mont_ladder(5) + assert p9 == p1.mont_ladder(9) + assert p16 == p1.mont_ladder(16) + assert p9 == p3.mont_ladder(3) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_egyptian_fraction.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_egyptian_fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a9fac578d93a88a648bdcf8dc34550cf4a7573 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_egyptian_fraction.py @@ -0,0 +1,49 @@ +from sympy.core.numbers import Rational +from sympy.ntheory.egyptian_fraction import egyptian_fraction +from sympy.core.add import Add +from sympy.testing.pytest import raises +from sympy.core.random import random_complex_number + + +def test_egyptian_fraction(): + def test_equality(r, alg="Greedy"): + return r == Add(*[Rational(1, i) for i in egyptian_fraction(r, alg)]) + + r = random_complex_number(a=0, c=1, b=0, d=0, rational=True) + assert test_equality(r) + + assert egyptian_fraction(Rational(4, 17)) == [5, 29, 1233, 3039345] + assert egyptian_fraction(Rational(7, 13), "Greedy") == [2, 26] + assert egyptian_fraction(Rational(23, 101), "Greedy") == \ + [5, 37, 1438, 2985448, 40108045937720] + assert egyptian_fraction(Rational(18, 23), "Takenouchi") == \ + [2, 6, 12, 35, 276, 2415] + assert egyptian_fraction(Rational(5, 6), "Graham Jewett") == \ + [6, 7, 8, 9, 10, 42, 43, 44, 45, 56, 57, 58, 72, 73, 90, 1806, 1807, + 1808, 1892, 1893, 1980, 3192, 3193, 3306, 5256, 3263442, 3263443, + 3267056, 3581556, 10192056, 10650056950806] + assert egyptian_fraction(Rational(5, 6), "Golomb") == [2, 6, 12, 20, 30] + assert egyptian_fraction(Rational(5, 121), "Golomb") == [25, 1225, 3577, 7081, 11737] + raises(ValueError, lambda: egyptian_fraction(Rational(-4, 9))) + assert egyptian_fraction(Rational(8, 3), "Golomb") == [1, 2, 3, 4, 5, 6, 7, + 14, 574, 2788, 6460, + 11590, 33062, 113820] + assert egyptian_fraction(Rational(355, 113)) == [1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 27, 744, 893588, + 1251493536607, + 20361068938197002344405230] + + +def test_input(): + r = (2,3), Rational(2, 3), (Rational(2), Rational(3)) + for m in ["Greedy", "Graham Jewett", "Takenouchi", "Golomb"]: + for i in r: + d = egyptian_fraction(i, m) + assert all(i.is_Integer for i in d) + if m == "Graham Jewett": + assert d == [3, 4, 12] + else: + assert d == [2, 6] + # check prefix + d = egyptian_fraction(Rational(5, 3)) + assert d == [1, 2, 6] and all(i.is_Integer for i in d) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_elliptic_curve.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_elliptic_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..7d49d8eac72cc622fb92dfca8c54e5cc6c8dfb8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_elliptic_curve.py @@ -0,0 +1,20 @@ +from sympy.ntheory.elliptic_curve import EllipticCurve + + +def test_elliptic_curve(): + # Point addition and multiplication + e3 = EllipticCurve(-1, 9) + p = e3(0, 3) + q = e3(-1, 3) + r = p + q + assert r.x == 1 and r.y == -3 + r = 2*p + q + assert r.x == 35 and r.y == 207 + r = -p + q + assert r.x == 37 and r.y == 225 + # Verify result in http://www.lmfdb.org/EllipticCurve/Q + # Discriminant + assert EllipticCurve(-1, 9).discriminant == -34928 + assert EllipticCurve(-2731, -55146, 1, 0, 1).discriminant == 25088 + # Torsion points + assert len(EllipticCurve(0, 1).torsion_points()) == 6 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_factor_.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_factor_.py new file mode 100644 index 0000000000000000000000000000000000000000..5174b842c49ef0e14c1ad38d2d9ad550c2a2a388 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_factor_.py @@ -0,0 +1,702 @@ +from sympy.core.containers import Dict +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.combinatorial.factorials import factorial as fac +from sympy.core.numbers import Integer, Rational +from sympy.external.gmpy import gcd + +from sympy.ntheory import (totient, + factorint, primefactors, divisors, nextprime, + pollard_rho, perfect_power, multiplicity, multiplicity_in_factorial, + divisor_count, primorial, pollard_pm1, divisor_sigma, + factorrat, reduced_totient) +from sympy.ntheory.factor_ import (smoothness, smoothness_p, proper_divisors, + antidivisors, antidivisor_count, _divisor_sigma, core, udivisors, udivisor_sigma, + udivisor_count, proper_divisor_count, primenu, primeomega, + mersenne_prime_exponent, is_perfect, is_abundant, + is_deficient, is_amicable, is_carmichael, find_carmichael_numbers_in_range, + find_first_n_carmichaels, dra, drm, _perfect_power, factor_cache) + +from sympy.testing.pytest import raises, slow + +from sympy.utilities.iterables import capture + + +def fac_multiplicity(n, p): + """Return the power of the prime number p in the + factorization of n!""" + if p > n: + return 0 + if p > n//2: + return 1 + q, m = n, 0 + while q >= p: + q //= p + m += q + return m + + +def multiproduct(seq=(), start=1): + """ + Return the product of a sequence of factors with multiplicities, + times the value of the parameter ``start``. The input may be a + sequence of (factor, exponent) pairs or a dict of such pairs. + + >>> multiproduct({3:7, 2:5}, 4) # = 3**7 * 2**5 * 4 + 279936 + + """ + if not seq: + return start + if isinstance(seq, dict): + seq = iter(seq.items()) + units = start + multi = [] + for base, exp in seq: + if not exp: + continue + elif exp == 1: + units *= base + else: + if exp % 2: + units *= base + multi.append((base, exp//2)) + return units * multiproduct(multi)**2 + + +def test_multiplicity(): + for b in range(2, 20): + for i in range(100): + assert multiplicity(b, b**i) == i + assert multiplicity(b, (b**i) * 23) == i + assert multiplicity(b, (b**i) * 1000249) == i + # Should be fast + assert multiplicity(10, 10**10023) == 10023 + # Should exit quickly + assert multiplicity(10**10, 10**10) == 1 + # Should raise errors for bad input + raises(ValueError, lambda: multiplicity(1, 1)) + raises(ValueError, lambda: multiplicity(1, 2)) + raises(ValueError, lambda: multiplicity(1.3, 2)) + raises(ValueError, lambda: multiplicity(2, 0)) + raises(ValueError, lambda: multiplicity(1.3, 0)) + + # handles Rationals + assert multiplicity(10, Rational(30, 7)) == 1 + assert multiplicity(Rational(2, 7), Rational(4, 7)) == 1 + assert multiplicity(Rational(1, 7), Rational(3, 49)) == 2 + assert multiplicity(Rational(2, 7), Rational(7, 2)) == -1 + assert multiplicity(3, Rational(1, 9)) == -2 + + +def test_multiplicity_in_factorial(): + n = fac(1000) + for i in (2, 4, 6, 12, 30, 36, 48, 60, 72, 96): + assert multiplicity(i, n) == multiplicity_in_factorial(i, 1000) + + +def test_private_perfect_power(): + assert _perfect_power(0) is False + assert _perfect_power(1) is False + assert _perfect_power(2) is False + assert _perfect_power(3) is False + for x in [2, 3, 5, 6, 7, 12, 15, 105, 100003]: + for y in range(2, 100): + assert _perfect_power(x**y) == (x, y) + if x & 1: + assert _perfect_power(x**y, next_p=3) == (x, y) + if x == 100003: + assert _perfect_power(x**y, next_p=100003) == (x, y) + assert _perfect_power(101*x**y) == False + # Catalan's conjecture + if x**y not in [8, 9]: + assert _perfect_power(x**y + 1) == False + assert _perfect_power(x**y - 1) == False + for x in range(1, 10): + for y in range(1, 10): + g = gcd(x, y) + if g == 1: + assert _perfect_power(5**x * 101**y) == False + else: + assert _perfect_power(5**x * 101**y) == (5**(x//g) * 101**(y//g), g) + + +def test_perfect_power(): + raises(ValueError, lambda: perfect_power(0.1)) + assert perfect_power(0) is False + assert perfect_power(1) is False + assert perfect_power(2) is False + assert perfect_power(3) is False + assert perfect_power(4) == (2, 2) + assert perfect_power(14) is False + assert perfect_power(25) == (5, 2) + assert perfect_power(22) is False + assert perfect_power(22, [2]) is False + assert perfect_power(137**(3*5*13)) == (137, 3*5*13) + assert perfect_power(137**(3*5*13) + 1) is False + assert perfect_power(137**(3*5*13) - 1) is False + assert perfect_power(103005006004**7) == (103005006004, 7) + assert perfect_power(103005006004**7 + 1) is False + assert perfect_power(103005006004**7 - 1) is False + assert perfect_power(103005006004**12) == (103005006004, 12) + assert perfect_power(103005006004**12 + 1) is False + assert perfect_power(103005006004**12 - 1) is False + assert perfect_power(2**10007) == (2, 10007) + assert perfect_power(2**10007 + 1) is False + assert perfect_power(2**10007 - 1) is False + assert perfect_power((9**99 + 1)**60) == (9**99 + 1, 60) + assert perfect_power((9**99 + 1)**60 + 1) is False + assert perfect_power((9**99 + 1)**60 - 1) is False + assert perfect_power((10**40000)**2, big=False) == (10**40000, 2) + assert perfect_power(10**100000) == (10, 100000) + assert perfect_power(10**100001) == (10, 100001) + assert perfect_power(13**4, [3, 5]) is False + assert perfect_power(3**4, [3, 10], factor=0) is False + assert perfect_power(3**3*5**3) == (15, 3) + assert perfect_power(2**3*5**5) is False + assert perfect_power(2*13**4) is False + assert perfect_power(2**5*3**3) is False + t = 2**24 + for d in divisors(24): + m = perfect_power(t*3**d) + assert m and m[1] == d or d == 1 + m = perfect_power(t*3**d, big=False) + assert m and m[1] == 2 or d == 1 or d == 3, (d, m) + + # negatives and non-integer rationals + assert perfect_power(-4) is False + assert perfect_power(-8) == (-2, 3) + assert perfect_power(-S(1)/8) == (-S(1)/2, 3) + assert perfect_power(S(1)/3) == False + assert perfect_power(-5**15) == (-5, 15) + assert perfect_power(-5**15, big=False) == (-3125, 3) + assert perfect_power(-5**15, [15]) == (-5, 15) + + n = -3 ** 60 + assert perfect_power(n) == (-81, 15) + assert perfect_power(n, big=False) == (-3486784401, 3) + assert perfect_power(n, [3, 5], big=True) == (-531441, 5) + assert perfect_power(n, [3, 5], big=False) == (-3486784401, 3) + assert perfect_power(n, [2]) == False + assert perfect_power(n, [2, 15]) == (-81, 15) + assert perfect_power(n, [2, 13]) == False + assert perfect_power(n, [17]) == False + assert perfect_power(n, [3]) == (-3486784401, 3) + assert perfect_power(n + 1) == False + + r = S(2) ** (2 * 5 * 7) / S(3) ** (2 * 7) + assert perfect_power(r) == (S(32) / 3, 14) + assert perfect_power(-r) == (-S(1024) / 9, 7) + assert perfect_power(r, big=False) == (S(34359738368) / 2187, 2) + assert perfect_power(r, [2, 5]) == (S(34359738368) / 2187, 2) + assert perfect_power(r, [5, 7]) == (S(1024) / 9, 7) + assert perfect_power(r, [5, 7], big=False) == (S(1024) / 9, 7) + assert perfect_power(r, [2, 5, 7], big=False) == (S(34359738368) / 2187, 2) + assert perfect_power(-r, [5, 7], big=False) == (-S(1024) / 9, 7) + + assert perfect_power(-S(1) / 8) == (-S(1) / 2, 3) + + assert perfect_power((-3)**60) == (3, 60) + assert perfect_power((-3)**61) == (-3, 61) + + assert perfect_power(S(2 ** 9) / 3 ** 12) == (S(8)/81, 3) + assert perfect_power(Rational(1, 2)**3) == (S.Half, 3) + assert perfect_power(Rational(-3, 2)**3) == (-3*S.Half, 3) + + +def test_factor_cache(): + factor_cache.cache_clear() + raises(ValueError, lambda: factor_cache.__setitem__(1, 5)) + raises(ValueError, lambda: factor_cache.__setitem__(10, 1)) + raises(ValueError, lambda: factor_cache.__setitem__(10, 10)) + raises(ValueError, lambda: factor_cache.__setitem__(10, 3)) + raises(ValueError, lambda: factor_cache.__setitem__(20, 4)) + factor_cache.maxsize = 3 + for i in range(2, 10): + factor_cache[5*i] = 5 + assert len(factor_cache) == 3 + factor_cache.maxsize = 5 + for i in range(2, 10): + factor_cache[5*i] = 5 + assert len(factor_cache) == 5 + factor_cache.maxsize = 2 + assert len(factor_cache) == 2 + factor_cache.maxsize =1000 + + factor_cache.cache_clear() + factor_cache[40] = 5 + assert factor_cache.get(40) == 5 + assert factor_cache.get(20) is None + assert factor_cache[40] == 5 + raises(KeyError, lambda: factor_cache[10]) + del factor_cache[40] + assert len(factor_cache) == 0 + raises(KeyError, lambda: factor_cache.__delitem__(40)) + factor_cache.add(100, [5, 2]) + assert len(factor_cache) == 2 + assert factor_cache[100] == 5 + + for n in [1000000007, 10000019*20000003]: + factorint(n) + assert n in factor_cache + + # Restore the initial state + factor_cache.cache_clear() + factor_cache.maxsize = 1000 + + +@slow +def test_factorint(): + assert primefactors(123456) == [2, 3, 643] + assert factorint(0) == {0: 1} + assert factorint(1) == {} + assert factorint(-1) == {-1: 1} + assert factorint(-2) == {-1: 1, 2: 1} + assert factorint(-16) == {-1: 1, 2: 4} + assert factorint(2) == {2: 1} + assert factorint(126) == {2: 1, 3: 2, 7: 1} + assert factorint(123456) == {2: 6, 3: 1, 643: 1} + assert factorint(5951757) == {3: 1, 7: 1, 29: 2, 337: 1} + assert factorint(64015937) == {7993: 1, 8009: 1} + assert factorint(2**(2**6) + 1) == {274177: 1, 67280421310721: 1} + #issue 19683 + assert factorint(10**38 - 1) == {3: 2, 11: 1, 909090909090909091: 1, 1111111111111111111: 1} + #issue 17676 + assert factorint(28300421052393658575) == {3: 1, 5: 2, 11: 2, 43: 1, 2063: 2, 4127: 1, 4129: 1} + assert factorint(2063**2 * 4127**1 * 4129**1) == {2063: 2, 4127: 1, 4129: 1} + assert factorint(2347**2 * 7039**1 * 7043**1) == {2347: 2, 7039: 1, 7043: 1} + + assert factorint(0, multiple=True) == [0] + assert factorint(1, multiple=True) == [] + assert factorint(-1, multiple=True) == [-1] + assert factorint(-2, multiple=True) == [-1, 2] + assert factorint(-16, multiple=True) == [-1, 2, 2, 2, 2] + assert factorint(2, multiple=True) == [2] + assert factorint(24, multiple=True) == [2, 2, 2, 3] + assert factorint(126, multiple=True) == [2, 3, 3, 7] + assert factorint(123456, multiple=True) == [2, 2, 2, 2, 2, 2, 3, 643] + assert factorint(5951757, multiple=True) == [3, 7, 29, 29, 337] + assert factorint(64015937, multiple=True) == [7993, 8009] + assert factorint(2**(2**6) + 1, multiple=True) == [274177, 67280421310721] + + assert factorint(fac(1, evaluate=False)) == {} + assert factorint(fac(7, evaluate=False)) == {2: 4, 3: 2, 5: 1, 7: 1} + assert factorint(fac(15, evaluate=False)) == \ + {2: 11, 3: 6, 5: 3, 7: 2, 11: 1, 13: 1} + assert factorint(fac(20, evaluate=False)) == \ + {2: 18, 3: 8, 5: 4, 7: 2, 11: 1, 13: 1, 17: 1, 19: 1} + assert factorint(fac(23, evaluate=False)) == \ + {2: 19, 3: 9, 5: 4, 7: 3, 11: 2, 13: 1, 17: 1, 19: 1, 23: 1} + + assert multiproduct(factorint(fac(200))) == fac(200) + assert multiproduct(factorint(fac(200, evaluate=False))) == fac(200) + for b, e in factorint(fac(150)).items(): + assert e == fac_multiplicity(150, b) + for b, e in factorint(fac(150, evaluate=False)).items(): + assert e == fac_multiplicity(150, b) + assert factorint(103005006059**7) == {103005006059: 7} + assert factorint(31337**191) == {31337: 191} + assert factorint(2**1000 * 3**500 * 257**127 * 383**60) == \ + {2: 1000, 3: 500, 257: 127, 383: 60} + assert len(factorint(fac(10000))) == 1229 + assert len(factorint(fac(10000, evaluate=False))) == 1229 + assert factorint(12932983746293756928584532764589230) == \ + {2: 1, 5: 1, 73: 1, 727719592270351: 1, 63564265087747: 1, 383: 1} + assert factorint(727719592270351) == {727719592270351: 1} + assert factorint(2**64 + 1, use_trial=False) == factorint(2**64 + 1) + for n in range(60000): + assert multiproduct(factorint(n)) == n + assert pollard_rho(2**64 + 1, seed=1) == 274177 + assert pollard_rho(19, seed=1) is None + assert factorint(3, limit=2) == {3: 1} + assert factorint(12345) == {3: 1, 5: 1, 823: 1} + assert factorint( + 12345, limit=3) == {4115: 1, 3: 1} # the 5 is greater than the limit + assert factorint(1, limit=1) == {} + assert factorint(0, 3) == {0: 1} + assert factorint(12, limit=1) == {12: 1} + assert factorint(30, limit=2) == {2: 1, 15: 1} + assert factorint(16, limit=2) == {2: 4} + assert factorint(124, limit=3) == {2: 2, 31: 1} + assert factorint(4*31**2, limit=3) == {2: 2, 31: 2} + p1 = nextprime(2**32) + p2 = nextprime(2**16) + p3 = nextprime(p2) + assert factorint(p1*p2*p3) == {p1: 1, p2: 1, p3: 1} + assert factorint(13*17*19, limit=15) == {13: 1, 17*19: 1} + assert factorint(1951*15013*15053, limit=2000) == {225990689: 1, 1951: 1} + assert factorint(primorial(17) + 1, use_pm1=0) == \ + {int(19026377261): 1, 3467: 1, 277: 1, 105229: 1} + # when prime b is closer than approx sqrt(8*p) to prime p then they are + # "close" and have a trivial factorization + a = nextprime(2**2**8) # 78 digits + b = nextprime(a + 2**2**4) + assert 'Fermat' in capture(lambda: factorint(a*b, verbose=1)) + + raises(ValueError, lambda: pollard_rho(4)) + raises(ValueError, lambda: pollard_pm1(3)) + raises(ValueError, lambda: pollard_pm1(10, B=2)) + # verbose coverage + n = nextprime(2**16)*nextprime(2**17)*nextprime(1901) + assert 'with primes' in capture(lambda: factorint(n, verbose=1)) + capture(lambda: factorint(nextprime(2**16)*1012, verbose=1)) + + n = nextprime(2**17) + capture(lambda: factorint(n**3, verbose=1)) # perfect power termination + capture(lambda: factorint(2*n, verbose=1)) # factoring complete msg + + # exceed 1st + n = nextprime(2**17) + n *= nextprime(n) + assert '1000' in capture(lambda: factorint(n, limit=1000, verbose=1)) + n *= nextprime(n) + assert len(factorint(n)) == 3 + assert len(factorint(n, limit=p1)) == 3 + n *= nextprime(2*n) + # exceed 2nd + assert '2001' in capture(lambda: factorint(n, limit=2000, verbose=1)) + assert capture( + lambda: factorint(n, limit=4000, verbose=1)).count('Pollard') == 2 + # non-prime pm1 result + n = nextprime(8069) + n *= nextprime(2*n)*nextprime(2*n, 2) + capture(lambda: factorint(n, verbose=1)) # non-prime pm1 result + # factor fermat composite + p1 = nextprime(2**17) + p2 = nextprime(2*p1) + assert factorint((p1*p2**2)**3) == {p1: 3, p2: 6} + # Test for non integer input + raises(ValueError, lambda: factorint(4.5)) + # test dict/Dict input + sans = '2**10*3**3' + n = {4: 2, 12: 3} + assert str(factorint(n)) == sans + assert str(factorint(Dict(n))) == sans + + +def test_divisors_and_divisor_count(): + assert divisors(-1) == [1] + assert divisors(0) == [] + assert divisors(1) == [1] + assert divisors(2) == [1, 2] + assert divisors(3) == [1, 3] + assert divisors(17) == [1, 17] + assert divisors(10) == [1, 2, 5, 10] + assert divisors(100) == [1, 2, 4, 5, 10, 20, 25, 50, 100] + assert divisors(101) == [1, 101] + assert type(divisors(2, generator=True)) is not list + + assert divisor_count(0) == 0 + assert divisor_count(-1) == 1 + assert divisor_count(1) == 1 + assert divisor_count(6) == 4 + assert divisor_count(12) == 6 + + assert divisor_count(180, 3) == divisor_count(180//3) + assert divisor_count(2*3*5, 7) == 0 + + +def test_proper_divisors_and_proper_divisor_count(): + assert proper_divisors(-1) == [] + assert proper_divisors(0) == [] + assert proper_divisors(1) == [] + assert proper_divisors(2) == [1] + assert proper_divisors(3) == [1] + assert proper_divisors(17) == [1] + assert proper_divisors(10) == [1, 2, 5] + assert proper_divisors(100) == [1, 2, 4, 5, 10, 20, 25, 50] + assert proper_divisors(1000000007) == [1] + assert type(proper_divisors(2, generator=True)) is not list + + assert proper_divisor_count(0) == 0 + assert proper_divisor_count(-1) == 0 + assert proper_divisor_count(1) == 0 + assert proper_divisor_count(36) == 8 + assert proper_divisor_count(2*3*5) == 7 + + +def test_udivisors_and_udivisor_count(): + assert udivisors(-1) == [1] + assert udivisors(0) == [] + assert udivisors(1) == [1] + assert udivisors(2) == [1, 2] + assert udivisors(3) == [1, 3] + assert udivisors(17) == [1, 17] + assert udivisors(10) == [1, 2, 5, 10] + assert udivisors(100) == [1, 4, 25, 100] + assert udivisors(101) == [1, 101] + assert udivisors(1000) == [1, 8, 125, 1000] + assert type(udivisors(2, generator=True)) is not list + + assert udivisor_count(0) == 0 + assert udivisor_count(-1) == 1 + assert udivisor_count(1) == 1 + assert udivisor_count(6) == 4 + assert udivisor_count(12) == 4 + + assert udivisor_count(180) == 8 + assert udivisor_count(2*3*5*7) == 16 + + +def test_issue_6981(): + S = set(divisors(4)).union(set(divisors(Integer(2)))) + assert S == {1,2,4} + + +def test_issue_4356(): + assert factorint(1030903) == {53: 2, 367: 1} + + +def test_divisors(): + assert divisors(28) == [1, 2, 4, 7, 14, 28] + assert list(divisors(3*5*7, 1)) == [1, 3, 5, 15, 7, 21, 35, 105] + assert divisors(0) == [] + + +def test_divisor_count(): + assert divisor_count(0) == 0 + assert divisor_count(6) == 4 + + +def test_proper_divisors(): + assert proper_divisors(-1) == [] + assert proper_divisors(28) == [1, 2, 4, 7, 14] + assert list(proper_divisors(3*5*7, True)) == [1, 3, 5, 15, 7, 21, 35] + + +def test_proper_divisor_count(): + assert proper_divisor_count(6) == 3 + assert proper_divisor_count(108) == 11 + + +def test_antidivisors(): + assert antidivisors(-1) == [] + assert antidivisors(-3) == [2] + assert antidivisors(14) == [3, 4, 9] + assert antidivisors(237) == [2, 5, 6, 11, 19, 25, 43, 95, 158] + assert antidivisors(12345) == [2, 6, 7, 10, 30, 1646, 3527, 4938, 8230] + assert antidivisors(393216) == [262144] + assert sorted(x for x in antidivisors(3*5*7, 1)) == \ + [2, 6, 10, 11, 14, 19, 30, 42, 70] + assert antidivisors(1) == [] + assert type(antidivisors(2, generator=True)) is not list + +def test_antidivisor_count(): + assert antidivisor_count(0) == 0 + assert antidivisor_count(-1) == 0 + assert antidivisor_count(-4) == 1 + assert antidivisor_count(20) == 3 + assert antidivisor_count(25) == 5 + assert antidivisor_count(38) == 7 + assert antidivisor_count(180) == 6 + assert antidivisor_count(2*3*5) == 3 + + +def test_smoothness_and_smoothness_p(): + assert smoothness(1) == (1, 1) + assert smoothness(2**4*3**2) == (3, 16) + + assert smoothness_p(10431, m=1) == \ + (1, [(3, (2, 2, 4)), (19, (1, 5, 5)), (61, (1, 31, 31))]) + assert smoothness_p(10431) == \ + (-1, [(3, (2, 2, 2)), (19, (1, 3, 9)), (61, (1, 5, 5))]) + assert smoothness_p(10431, power=1) == \ + (-1, [(3, (2, 2, 2)), (61, (1, 5, 5)), (19, (1, 3, 9))]) + assert smoothness_p(21477639576571, visual=1) == \ + 'p**i=4410317**1 has p-1 B=1787, B-pow=1787\n' + \ + 'p**i=4869863**1 has p-1 B=2434931, B-pow=2434931' + + +def test_visual_factorint(): + assert factorint(1, visual=1) == 1 + forty2 = factorint(42, visual=True) + assert type(forty2) == Mul + assert str(forty2) == '2**1*3**1*7**1' + assert factorint(1, visual=True) is S.One + no = {"evaluate": False} + assert factorint(42**2, visual=True) == Mul(Pow(2, 2, **no), + Pow(3, 2, **no), + Pow(7, 2, **no), **no) + assert -1 in factorint(-42, visual=True).args + + +def test_factorrat(): + assert str(factorrat(S(12)/1, visual=True)) == '2**2*3**1' + assert str(factorrat(Rational(1, 1), visual=True)) == '1' + assert str(factorrat(S(25)/14, visual=True)) == '5**2/(2*7)' + assert str(factorrat(Rational(25, 14), visual=True)) == '5**2/(2*7)' + assert str(factorrat(S(-25)/14/9, visual=True)) == '-1*5**2/(2*3**2*7)' + + assert factorrat(S(12)/1, multiple=True) == [2, 2, 3] + assert factorrat(Rational(1, 1), multiple=True) == [] + assert factorrat(S(25)/14, multiple=True) == [Rational(1, 7), S.Half, 5, 5] + assert factorrat(Rational(25, 14), multiple=True) == [Rational(1, 7), S.Half, 5, 5] + assert factorrat(Rational(12, 1), multiple=True) == [2, 2, 3] + assert factorrat(S(-25)/14/9, multiple=True) == \ + [-1, Rational(1, 7), Rational(1, 3), Rational(1, 3), S.Half, 5, 5] + + +def test_visual_io(): + sm = smoothness_p + fi = factorint + # with smoothness_p + n = 124 + d = fi(n) + m = fi(d, visual=True) + t = sm(n) + s = sm(t) + for th in [d, s, t, n, m]: + assert sm(th, visual=True) == s + assert sm(th, visual=1) == s + for th in [d, s, t, n, m]: + assert sm(th, visual=False) == t + assert [sm(th, visual=None) for th in [d, s, t, n, m]] == [s, d, s, t, t] + assert [sm(th, visual=2) for th in [d, s, t, n, m]] == [s, d, s, t, t] + + # with factorint + for th in [d, m, n]: + assert fi(th, visual=True) == m + assert fi(th, visual=1) == m + for th in [d, m, n]: + assert fi(th, visual=False) == d + assert [fi(th, visual=None) for th in [d, m, n]] == [m, d, d] + assert [fi(th, visual=0) for th in [d, m, n]] == [m, d, d] + + # test reevaluation + no = {"evaluate": False} + assert sm({4: 2}, visual=False) == sm(16) + assert sm(Mul(*[Pow(k, v, **no) for k, v in {4: 2, 2: 6}.items()], **no), + visual=False) == sm(2**10) + + assert fi({4: 2}, visual=False) == fi(16) + assert fi(Mul(*[Pow(k, v, **no) for k, v in {4: 2, 2: 6}.items()], **no), + visual=False) == fi(2**10) + + +def test_core(): + assert core(35**13, 10) == 42875 + assert core(210**2) == 1 + assert core(7776, 3) == 36 + assert core(10**27, 22) == 10**5 + assert core(537824) == 14 + assert core(1, 6) == 1 + + +def test__divisor_sigma(): + assert _divisor_sigma(23450) == 50592 + assert _divisor_sigma(23450, 0) == 24 + assert _divisor_sigma(23450, 1) == 50592 + assert _divisor_sigma(23450, 2) == 730747500 + assert _divisor_sigma(23450, 3) == 14666785333344 + A000005 = [1, 2, 2, 3, 2, 4, 2, 4, 3, 4, 2, 6, 2, 4, 4, 5, 2, 6, 2, 6, 4, + 4, 2, 8, 3, 4, 4, 6, 2, 8, 2, 6, 4, 4, 4, 9, 2, 4, 4, 8, 2, 8] + for n, val in enumerate(A000005, 1): + assert _divisor_sigma(n, 0) == val + A000203 = [1, 3, 4, 7, 6, 12, 8, 15, 13, 18, 12, 28, 14, 24, 24, 31, 18, + 39, 20, 42, 32, 36, 24, 60, 31, 42, 40, 56, 30, 72, 32, 63, 48] + for n, val in enumerate(A000203, 1): + assert _divisor_sigma(n, 1) == val + A001157 = [1, 5, 10, 21, 26, 50, 50, 85, 91, 130, 122, 210, 170, 250, 260, + 341, 290, 455, 362, 546, 500, 610, 530, 850, 651, 850, 820, 1050] + for n, val in enumerate(A001157, 1): + assert _divisor_sigma(n, 2) == val + + +def test_mersenne_prime_exponent(): + assert mersenne_prime_exponent(1) == 2 + assert mersenne_prime_exponent(4) == 7 + assert mersenne_prime_exponent(10) == 89 + assert mersenne_prime_exponent(25) == 21701 + raises(ValueError, lambda: mersenne_prime_exponent(52)) + raises(ValueError, lambda: mersenne_prime_exponent(0)) + + +def test_is_perfect(): + assert is_perfect(-6) is False + assert is_perfect(6) is True + assert is_perfect(15) is False + assert is_perfect(28) is True + assert is_perfect(400) is False + assert is_perfect(496) is True + assert is_perfect(8128) is True + assert is_perfect(10000) is False + + +def test_is_abundant(): + assert is_abundant(10) is False + assert is_abundant(12) is True + assert is_abundant(18) is True + assert is_abundant(21) is False + assert is_abundant(945) is True + + +def test_is_deficient(): + assert is_deficient(10) is True + assert is_deficient(22) is True + assert is_deficient(56) is False + assert is_deficient(20) is False + assert is_deficient(36) is False + + +def test_is_amicable(): + assert is_amicable(173, 129) is False + assert is_amicable(220, 284) is True + assert is_amicable(8756, 8756) is False + + +def test_is_carmichael(): + A002997 = [561, 1105, 1729, 2465, 2821, 6601, 8911, 10585, 15841, + 29341, 41041, 46657, 52633, 62745, 63973, 75361, 101101] + for n in range(1, 5000): + assert is_carmichael(n) == (n in A002997) + for n in A002997: + assert is_carmichael(n) + + +def test_find_carmichael_numbers_in_range(): + assert find_carmichael_numbers_in_range(0, 561) == [] + assert find_carmichael_numbers_in_range(561, 562) == [561] + assert find_carmichael_numbers_in_range(561, 1105) == find_carmichael_numbers_in_range(561, 562) + raises(ValueError, lambda: find_carmichael_numbers_in_range(-2, 2)) + raises(ValueError, lambda: find_carmichael_numbers_in_range(22, 2)) + + +def test_find_first_n_carmichaels(): + assert find_first_n_carmichaels(0) == [] + assert find_first_n_carmichaels(1) == [561] + assert find_first_n_carmichaels(2) == [561, 1105] + + +def test_dra(): + assert dra(19, 12) == 8 + assert dra(2718, 10) == 9 + assert dra(0, 22) == 0 + assert dra(23456789, 10) == 8 + raises(ValueError, lambda: dra(24, -2)) + raises(ValueError, lambda: dra(24.2, 5)) + +def test_drm(): + assert drm(19, 12) == 7 + assert drm(2718, 10) == 2 + assert drm(0, 15) == 0 + assert drm(234161, 10) == 6 + raises(ValueError, lambda: drm(24, -2)) + raises(ValueError, lambda: drm(11.6, 9)) + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert primenu(3) == 1 + with warns_deprecated_sympy(): + assert primeomega(3) == 1 + with warns_deprecated_sympy(): + assert totient(3) == 2 + with warns_deprecated_sympy(): + assert reduced_totient(3) == 2 + with warns_deprecated_sympy(): + assert divisor_sigma(3) == 4 + with warns_deprecated_sympy(): + assert udivisor_sigma(3) == 4 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_generate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e5918ffefede2e86f3be2b07d6c3a01c02e6e0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_generate.py @@ -0,0 +1,285 @@ +from bisect import bisect, bisect_left + +from sympy.functions.combinatorial.numbers import mobius, totient +from sympy.ntheory.generate import (sieve, Sieve) + +from sympy.ntheory import isprime, randprime, nextprime, prevprime, \ + primerange, primepi, prime, primorial, composite, compositepi +from sympy.ntheory.generate import cycle_length, _primepi +from sympy.ntheory.primetest import mr +from sympy.testing.pytest import raises + +def test_prime(): + assert prime(1) == 2 + assert prime(2) == 3 + assert prime(5) == 11 + assert prime(11) == 31 + assert prime(57) == 269 + assert prime(296) == 1949 + assert prime(559) == 4051 + assert prime(3000) == 27449 + assert prime(4096) == 38873 + assert prime(9096) == 94321 + assert prime(25023) == 287341 + assert prime(10000000) == 179424673 # issue #20951 + assert prime(99999999) == 2038074739 + raises(ValueError, lambda: prime(0)) + sieve.extend(3000) + assert prime(401) == 2749 + raises(ValueError, lambda: prime(-1)) + + +def test__primepi(): + assert _primepi(-1) == 0 + assert _primepi(1) == 0 + assert _primepi(2) == 1 + assert _primepi(5) == 3 + assert _primepi(11) == 5 + assert _primepi(57) == 16 + assert _primepi(296) == 62 + assert _primepi(559) == 102 + assert _primepi(3000) == 430 + assert _primepi(4096) == 564 + assert _primepi(9096) == 1128 + assert _primepi(25023) == 2763 + assert _primepi(10**8) == 5761455 + assert _primepi(253425253) == 13856396 + assert _primepi(8769575643) == 401464322 + sieve.extend(3000) + assert _primepi(2000) == 303 + + +def test_composite(): + from sympy.ntheory.generate import sieve + sieve._reset() + assert composite(1) == 4 + assert composite(2) == 6 + assert composite(5) == 10 + assert composite(11) == 20 + assert composite(41) == 58 + assert composite(57) == 80 + assert composite(296) == 370 + assert composite(559) == 684 + assert composite(3000) == 3488 + assert composite(4096) == 4736 + assert composite(9096) == 10368 + assert composite(25023) == 28088 + sieve.extend(3000) + assert composite(1957) == 2300 + assert composite(2568) == 2998 + raises(ValueError, lambda: composite(0)) + + +def test_compositepi(): + assert compositepi(1) == 0 + assert compositepi(2) == 0 + assert compositepi(5) == 1 + assert compositepi(11) == 5 + assert compositepi(57) == 40 + assert compositepi(296) == 233 + assert compositepi(559) == 456 + assert compositepi(3000) == 2569 + assert compositepi(4096) == 3531 + assert compositepi(9096) == 7967 + assert compositepi(25023) == 22259 + assert compositepi(10**8) == 94238544 + assert compositepi(253425253) == 239568856 + assert compositepi(8769575643) == 8368111320 + sieve.extend(3000) + assert compositepi(2321) == 1976 + + +def test_generate(): + from sympy.ntheory.generate import sieve + sieve._reset() + assert nextprime(-4) == 2 + assert nextprime(2) == 3 + assert nextprime(5) == 7 + assert nextprime(12) == 13 + assert prevprime(3) == 2 + assert prevprime(7) == 5 + assert prevprime(13) == 11 + assert prevprime(19) == 17 + assert prevprime(20) == 19 + + sieve.extend_to_no(9) + assert sieve._list[-1] == 23 + + assert sieve._list[-1] < 31 + assert 31 in sieve + + assert nextprime(90) == 97 + assert nextprime(10**40) == (10**40 + 121) + primelist = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, + 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, + 79, 83, 89, 97, 101, 103, 107, 109, 113, + 127, 131, 137, 139, 149, 151, 157, 163, + 167, 173, 179, 181, 191, 193, 197, 199, + 211, 223, 227, 229, 233, 239, 241, 251, + 257, 263, 269, 271, 277, 281, 283, 293] + for i in range(len(primelist) - 2): + for j in range(2, len(primelist) - i): + assert nextprime(primelist[i], j) == primelist[i + j] + if 3 < i: + assert nextprime(primelist[i] - 1, j) == primelist[i + j - 1] + raises(ValueError, lambda: nextprime(2, 0)) + raises(ValueError, lambda: nextprime(2, -1)) + assert prevprime(97) == 89 + assert prevprime(10**40) == (10**40 - 17) + + raises(ValueError, lambda: Sieve(0)) + raises(ValueError, lambda: Sieve(-1)) + for sieve_interval in [1, 10, 11, 1_000_000]: + s = Sieve(sieve_interval=sieve_interval) + for head in range(s._list[-1] + 1, (s._list[-1] + 1)**2, 2): + for tail in range(head + 1, (s._list[-1] + 1)**2): + A = list(s._primerange(head, tail)) + B = primelist[bisect(primelist, head):bisect_left(primelist, tail)] + assert A == B + for k in range(s._list[-1], primelist[-1] - 1, 2): + s = Sieve(sieve_interval=sieve_interval) + s.extend(k) + assert list(s._list) == primelist[:bisect(primelist, k)] + s.extend(primelist[-1]) + assert list(s._list) == primelist + + assert list(sieve.primerange(10, 1)) == [] + assert list(sieve.primerange(5, 9)) == [5, 7] + sieve._reset(prime=True) + assert list(sieve.primerange(2, 13)) == [2, 3, 5, 7, 11] + assert list(sieve.primerange(13)) == [2, 3, 5, 7, 11] + assert list(sieve.primerange(8)) == [2, 3, 5, 7] + assert list(sieve.primerange(-2)) == [] + assert list(sieve.primerange(29)) == [2, 3, 5, 7, 11, 13, 17, 19, 23] + assert list(sieve.primerange(34)) == [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] + + assert list(sieve.totientrange(5, 15)) == [4, 2, 6, 4, 6, 4, 10, 4, 12, 6] + sieve._reset(totient=True) + assert list(sieve.totientrange(3, 13)) == [2, 2, 4, 2, 6, 4, 6, 4, 10, 4] + assert list(sieve.totientrange(900, 1000)) == [totient(x) for x in range(900, 1000)] + assert list(sieve.totientrange(0, 1)) == [] + assert list(sieve.totientrange(1, 2)) == [1] + + assert list(sieve.mobiusrange(5, 15)) == [-1, 1, -1, 0, 0, 1, -1, 0, -1, 1] + sieve._reset(mobius=True) + assert list(sieve.mobiusrange(3, 13)) == [-1, 0, -1, 1, -1, 0, 0, 1, -1, 0] + assert list(sieve.mobiusrange(1050, 1100)) == [mobius(x) for x in range(1050, 1100)] + assert list(sieve.mobiusrange(0, 1)) == [] + assert list(sieve.mobiusrange(1, 2)) == [1] + + assert list(primerange(10, 1)) == [] + assert list(primerange(2, 7)) == [2, 3, 5] + assert list(primerange(2, 10)) == [2, 3, 5, 7] + assert list(primerange(1050, 1100)) == [1051, 1061, + 1063, 1069, 1087, 1091, 1093, 1097] + s = Sieve() + for i in range(30, 2350, 376): + for j in range(2, 5096, 1139): + A = list(s.primerange(i, i + j)) + B = list(primerange(i, i + j)) + assert A == B + s = Sieve() + sieve._reset(prime=True) + sieve.extend(13) + for i in range(200): + for j in range(i, 200): + A = list(s.primerange(i, j)) + B = list(primerange(i, j)) + assert A == B + sieve.extend(1000) + for a, b in [(901, 1103), # a < 1000 < b < 1000**2 + (806, 1002007), # a < 1000 < 1000**2 < b + (2000, 30001), # 1000 < a < b < 1000**2 + (100005, 1010001), # 1000 < a < 1000**2 < b + (1003003, 1005000), # 1000**2 < a < b + ]: + assert list(primerange(a, b)) == list(s.primerange(a, b)) + sieve._reset(prime=True) + sieve.extend(100000) + assert len(sieve._list) == len(set(sieve._list)) + s = Sieve() + assert s[10] == 29 + + assert nextprime(2, 2) == 5 + + raises(ValueError, lambda: totient(0)) + + raises(ValueError, lambda: primorial(0)) + + assert mr(1, [2]) is False + + func = lambda i: (i**2 + 1) % 51 + assert next(cycle_length(func, 4)) == (6, 3) + assert list(cycle_length(func, 4, values=True)) == \ + [4, 17, 35, 2, 5, 26, 14, 44, 50, 2, 5, 26, 14] + assert next(cycle_length(func, 4, nmax=5)) == (5, None) + assert list(cycle_length(func, 4, nmax=5, values=True)) == \ + [4, 17, 35, 2, 5] + sieve.extend(3000) + assert nextprime(2968) == 2969 + assert prevprime(2930) == 2927 + raises(ValueError, lambda: prevprime(1)) + raises(ValueError, lambda: prevprime(-4)) + + +def test_randprime(): + assert randprime(10, 1) is None + assert randprime(3, -3) is None + assert randprime(2, 3) == 2 + assert randprime(1, 3) == 2 + assert randprime(3, 5) == 3 + raises(ValueError, lambda: randprime(-12, -2)) + raises(ValueError, lambda: randprime(-10, 0)) + raises(ValueError, lambda: randprime(20, 22)) + raises(ValueError, lambda: randprime(0, 2)) + raises(ValueError, lambda: randprime(1, 2)) + for a in [100, 300, 500, 250000]: + for b in [100, 300, 500, 250000]: + p = randprime(a, a + b) + assert a <= p < (a + b) and isprime(p) + + +def test_primorial(): + assert primorial(1) == 2 + assert primorial(1, nth=0) == 1 + assert primorial(2) == 6 + assert primorial(2, nth=0) == 2 + assert primorial(4, nth=0) == 6 + + +def test_search(): + assert 2 in sieve + assert 2.1 not in sieve + assert 1 not in sieve + assert 2**1000 not in sieve + raises(ValueError, lambda: sieve.search(1)) + + +def test_sieve_slice(): + assert sieve[5] == 11 + assert list(sieve[5:10]) == [sieve[x] for x in range(5, 10)] + assert list(sieve[5:10:2]) == [sieve[x] for x in range(5, 10, 2)] + assert list(sieve[1:5]) == [2, 3, 5, 7] + raises(IndexError, lambda: sieve[:5]) + raises(IndexError, lambda: sieve[0]) + raises(IndexError, lambda: sieve[0:5]) + +def test_sieve_iter(): + values = [] + for value in sieve: + if value > 7: + break + values.append(value) + assert values == list(sieve[1:5]) + + +def test_sieve_repr(): + assert "sieve" in repr(sieve) + assert "prime" in repr(sieve) + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert primepi(0) == 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_hypothesis.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_hypothesis.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f4cbecdbb7a6b15b0e323700cda11039c968fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_hypothesis.py @@ -0,0 +1,24 @@ +from hypothesis import given +from hypothesis import strategies as st +from sympy import divisors +from sympy.functions.combinatorial.numbers import divisor_sigma, totient +from sympy.ntheory.primetest import is_square + + +@given(n=st.integers(1, 10**10)) +def test_tau_hypothesis(n): + div = divisors(n) + tau_n = len(div) + assert is_square(n) == (tau_n % 2 == 1) + sigmas = [divisor_sigma(i) for i in div] + totients = [totient(n // i) for i in div] + mul = [a * b for a, b in zip(sigmas, totients)] + assert n * tau_n == sum(mul) + + +@given(n=st.integers(1, 10**10)) +def test_totient_hypothesis(n): + assert totient(n) <= n + div = divisors(n) + totients = [totient(i) for i in div] + assert n == sum(totients) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_modular.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_modular.py new file mode 100644 index 0000000000000000000000000000000000000000..10ebb1d3d3bdf5f736a6229579ae4c42a805745e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_modular.py @@ -0,0 +1,34 @@ +from sympy.ntheory.modular import crt, crt1, crt2, solve_congruence +from sympy.testing.pytest import raises + + +def test_crt(): + def mcrt(m, v, r, symmetric=False): + assert crt(m, v, symmetric)[0] == r + mm, e, s = crt1(m) + assert crt2(m, v, mm, e, s, symmetric) == (r, mm) + + mcrt([2, 3, 5], [0, 0, 0], 0) + mcrt([2, 3, 5], [1, 1, 1], 1) + + mcrt([2, 3, 5], [-1, -1, -1], -1, True) + mcrt([2, 3, 5], [-1, -1, -1], 2*3*5 - 1, False) + + assert crt([656, 350], [811, 133], symmetric=True) == (-56917, 114800) + + +def test_modular(): + assert solve_congruence(*list(zip([3, 4, 2], [12, 35, 17]))) == (1719, 7140) + assert solve_congruence(*list(zip([3, 4, 2], [12, 6, 17]))) is None + assert solve_congruence(*list(zip([3, 4, 2], [13, 7, 17]))) == (172, 1547) + assert solve_congruence(*list(zip([-10, -3, -15], [13, 7, 17]))) == (172, 1547) + assert solve_congruence(*list(zip([-10, -3, 1, -15], [13, 7, 7, 17]))) is None + assert solve_congruence( + *list(zip([-10, -5, 2, -15], [13, 7, 7, 17]))) == (835, 1547) + assert solve_congruence( + *list(zip([-10, -5, 2, -15], [13, 7, 14, 17]))) == (2382, 3094) + assert solve_congruence( + *list(zip([-10, 2, 2, -15], [13, 7, 14, 17]))) == (2382, 3094) + assert solve_congruence(*list(zip((1, 1, 2), (3, 2, 4)))) is None + raises( + ValueError, lambda: solve_congruence(*list(zip([3, 4, 2], [12.1, 35, 17])))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_multinomial.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_multinomial.py new file mode 100644 index 0000000000000000000000000000000000000000..b455c5cc979b9ba9756c9da88c1694471115cd5d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_multinomial.py @@ -0,0 +1,48 @@ +from sympy.ntheory.multinomial import (binomial_coefficients, binomial_coefficients_list, multinomial_coefficients) +from sympy.ntheory.multinomial import multinomial_coefficients_iterator + + +def test_binomial_coefficients_list(): + assert binomial_coefficients_list(0) == [1] + assert binomial_coefficients_list(1) == [1, 1] + assert binomial_coefficients_list(2) == [1, 2, 1] + assert binomial_coefficients_list(3) == [1, 3, 3, 1] + assert binomial_coefficients_list(4) == [1, 4, 6, 4, 1] + assert binomial_coefficients_list(5) == [1, 5, 10, 10, 5, 1] + assert binomial_coefficients_list(6) == [1, 6, 15, 20, 15, 6, 1] + + +def test_binomial_coefficients(): + for n in range(15): + c = binomial_coefficients(n) + l = [c[k] for k in sorted(c)] + assert l == binomial_coefficients_list(n) + + +def test_multinomial_coefficients(): + assert multinomial_coefficients(1, 1) == {(1,): 1} + assert multinomial_coefficients(1, 2) == {(2,): 1} + assert multinomial_coefficients(1, 3) == {(3,): 1} + assert multinomial_coefficients(2, 0) == {(0, 0): 1} + assert multinomial_coefficients(2, 1) == {(0, 1): 1, (1, 0): 1} + assert multinomial_coefficients(2, 2) == {(2, 0): 1, (0, 2): 1, (1, 1): 2} + assert multinomial_coefficients(2, 3) == {(3, 0): 1, (1, 2): 3, (0, 3): 1, + (2, 1): 3} + assert multinomial_coefficients(3, 1) == {(1, 0, 0): 1, (0, 1, 0): 1, + (0, 0, 1): 1} + assert multinomial_coefficients(3, 2) == {(0, 1, 1): 2, (0, 0, 2): 1, + (1, 1, 0): 2, (0, 2, 0): 1, (1, 0, 1): 2, (2, 0, 0): 1} + mc = multinomial_coefficients(3, 3) + assert mc == {(2, 1, 0): 3, (0, 3, 0): 1, + (1, 0, 2): 3, (0, 2, 1): 3, (0, 1, 2): 3, (3, 0, 0): 1, + (2, 0, 1): 3, (1, 2, 0): 3, (1, 1, 1): 6, (0, 0, 3): 1} + assert dict(multinomial_coefficients_iterator(2, 0)) == {(0, 0): 1} + assert dict( + multinomial_coefficients_iterator(2, 1)) == {(0, 1): 1, (1, 0): 1} + assert dict(multinomial_coefficients_iterator(2, 2)) == \ + {(2, 0): 1, (0, 2): 1, (1, 1): 2} + assert dict(multinomial_coefficients_iterator(3, 3)) == mc + it = multinomial_coefficients_iterator(7, 2) + assert [next(it) for i in range(4)] == \ + [((2, 0, 0, 0, 0, 0, 0), 1), ((1, 1, 0, 0, 0, 0, 0), 2), + ((0, 2, 0, 0, 0, 0, 0), 1), ((1, 0, 1, 0, 0, 0, 0), 2)] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_partitions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_partitions.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb7fad3445068ae7ae4033c76c808e3c87347b6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_partitions.py @@ -0,0 +1,28 @@ +from sympy.ntheory.partitions_ import npartitions, _partition_rec, _partition + + +def test__partition_rec(): + A000041 = [1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135, + 176, 231, 297, 385, 490, 627, 792, 1002, 1255, 1575] + for n, val in enumerate(A000041): + assert _partition_rec(n) == val + + +def test__partition(): + assert [_partition(k) for k in range(13)] == \ + [1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77] + assert _partition(100) == 190569292 + assert _partition(200) == 3972999029388 + assert _partition(1000) == 24061467864032622473692149727991 + assert _partition(1001) == 25032297938763929621013218349796 + assert _partition(2000) == 4720819175619413888601432406799959512200344166 + assert _partition(10000) % 10**10 == 6916435144 + assert _partition(100000) % 10**10 == 9421098519 + assert _partition(10000000) % 10**10 == 7677288980 + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert npartitions(0) == 1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_primetest.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_primetest.py new file mode 100644 index 0000000000000000000000000000000000000000..8a56332941d9421bda4d6acc1e4b3406617cee2b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_primetest.py @@ -0,0 +1,235 @@ +from math import gcd + +from sympy.ntheory.generate import Sieve, sieve +from sympy.ntheory.primetest import (mr, _lucas_extrastrong_params, is_lucas_prp, is_square, + is_strong_lucas_prp, is_extra_strong_lucas_prp, + proth_test, isprime, is_euler_pseudoprime, + is_gaussian_prime, is_fermat_pseudoprime, is_euler_jacobi_pseudoprime, + MERSENNE_PRIME_EXPONENTS, _lucas_lehmer_primality_test, + is_mersenne_prime) + +from sympy.testing.pytest import slow, raises +from sympy.core.numbers import I, Float + + +def test_is_fermat_pseudoprime(): + assert is_fermat_pseudoprime(5, 1) + assert is_fermat_pseudoprime(9, 1) + + +def test_euler_pseudoprimes(): + assert is_euler_pseudoprime(13, 1) + assert is_euler_pseudoprime(15, 1) + assert is_euler_pseudoprime(17, 6) + assert is_euler_pseudoprime(101, 7) + assert is_euler_pseudoprime(1009, 10) + assert is_euler_pseudoprime(11287, 41) + + raises(ValueError, lambda: is_euler_pseudoprime(0, 4)) + raises(ValueError, lambda: is_euler_pseudoprime(3, 0)) + raises(ValueError, lambda: is_euler_pseudoprime(15, 6)) + + # A006970 + euler_prp = [341, 561, 1105, 1729, 1905, 2047, 2465, 3277, + 4033, 4681, 5461, 6601, 8321, 8481, 10261, 10585] + for p in euler_prp: + assert is_euler_pseudoprime(p, 2) + + # A048950 + euler_prp = [121, 703, 1729, 1891, 2821, 3281, 7381, 8401, 8911, 10585, + 12403, 15457, 15841, 16531, 18721, 19345, 23521, 24661, 28009] + for p in euler_prp: + assert is_euler_pseudoprime(p, 3) + + # A033181 + absolute_euler_prp = [1729, 2465, 15841, 41041, 46657, 75361, + 162401, 172081, 399001, 449065, 488881] + for p in absolute_euler_prp: + for a in range(2, p): + if gcd(a, p) != 1: + continue + assert is_euler_pseudoprime(p, a) + + +def test_is_euler_jacobi_pseudoprime(): + assert is_euler_jacobi_pseudoprime(11, 1) + assert is_euler_jacobi_pseudoprime(15, 1) + + +def test_lucas_extrastrong_params(): + assert _lucas_extrastrong_params(3) == (5, 3, 1) + assert _lucas_extrastrong_params(5) == (12, 4, 1) + assert _lucas_extrastrong_params(7) == (5, 3, 1) + assert _lucas_extrastrong_params(9) == (0, 0, 0) + assert _lucas_extrastrong_params(11) == (21, 5, 1) + assert _lucas_extrastrong_params(59) == (32, 6, 1) + assert _lucas_extrastrong_params(479) == (117, 11, 1) + + +def test_is_extra_strong_lucas_prp(): + assert is_extra_strong_lucas_prp(4) == False + assert is_extra_strong_lucas_prp(989) == True + assert is_extra_strong_lucas_prp(10877) == True + assert is_extra_strong_lucas_prp(9) == False + assert is_extra_strong_lucas_prp(16) == False + assert is_extra_strong_lucas_prp(169) == False + +@slow +def test_prps(): + oddcomposites = [n for n in range(1, 10**5) if + n % 2 and not isprime(n)] + # A checksum would be better. + assert sum(oddcomposites) == 2045603465 + assert [n for n in oddcomposites if mr(n, [2])] == [ + 2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, + 52633, 65281, 74665, 80581, 85489, 88357, 90751] + assert [n for n in oddcomposites if mr(n, [3])] == [ + 121, 703, 1891, 3281, 8401, 8911, 10585, 12403, 16531, + 18721, 19345, 23521, 31621, 44287, 47197, 55969, 63139, + 74593, 79003, 82513, 87913, 88573, 97567] + assert [n for n in oddcomposites if mr(n, [325])] == [ + 9, 25, 27, 49, 65, 81, 325, 341, 343, 697, 1141, 2059, + 2149, 3097, 3537, 4033, 4681, 4941, 5833, 6517, 7987, 8911, + 12403, 12913, 15043, 16021, 20017, 22261, 23221, 24649, + 24929, 31841, 35371, 38503, 43213, 44173, 47197, 50041, + 55909, 56033, 58969, 59089, 61337, 65441, 68823, 72641, + 76793, 78409, 85879] + assert not any(mr(n, [9345883071009581737]) for n in oddcomposites) + assert [n for n in oddcomposites if is_lucas_prp(n)] == [ + 323, 377, 1159, 1829, 3827, 5459, 5777, 9071, 9179, 10877, + 11419, 11663, 13919, 14839, 16109, 16211, 18407, 18971, + 19043, 22499, 23407, 24569, 25199, 25877, 26069, 27323, + 32759, 34943, 35207, 39059, 39203, 39689, 40309, 44099, + 46979, 47879, 50183, 51983, 53663, 56279, 58519, 60377, + 63881, 69509, 72389, 73919, 75077, 77219, 79547, 79799, + 82983, 84419, 86063, 90287, 94667, 97019, 97439] + assert [n for n in oddcomposites if is_strong_lucas_prp(n)] == [ + 5459, 5777, 10877, 16109, 18971, 22499, 24569, 25199, 40309, + 58519, 75077, 97439] + assert [n for n in oddcomposites if is_extra_strong_lucas_prp(n) + ] == [ + 989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059, + 72389, 73919, 75077] + + +def test_proth_test(): + # Proth number + A080075 = [3, 5, 9, 13, 17, 25, 33, 41, 49, 57, 65, + 81, 97, 113, 129, 145, 161, 177, 193] + # Proth prime + A080076 = [3, 5, 13, 17, 41, 97, 113, 193] + + for n in range(200): + if n in A080075: + assert proth_test(n) == (n in A080076) + else: + raises(ValueError, lambda: proth_test(n)) + + +def test_lucas_lehmer_primality_test(): + for p in sieve.primerange(3, 100): + assert _lucas_lehmer_primality_test(p) == (p in MERSENNE_PRIME_EXPONENTS) + + +def test_is_mersenne_prime(): + assert is_mersenne_prime(-3) is False + assert is_mersenne_prime(3) is True + assert is_mersenne_prime(10) is False + assert is_mersenne_prime(127) is True + assert is_mersenne_prime(511) is False + assert is_mersenne_prime(131071) is True + assert is_mersenne_prime(2147483647) is True + + +def test_isprime(): + s = Sieve() + s.extend(100000) + ps = set(s.primerange(2, 100001)) + for n in range(100001): + # if (n in ps) != isprime(n): print n + assert (n in ps) == isprime(n) + assert isprime(179424673) + assert isprime(20678048681) + assert isprime(1968188556461) + assert isprime(2614941710599) + assert isprime(65635624165761929287) + assert isprime(1162566711635022452267983) + assert isprime(77123077103005189615466924501) + assert isprime(3991617775553178702574451996736229) + assert isprime(273952953553395851092382714516720001799) + assert isprime(int(''' +531137992816767098689588206552468627329593117727031923199444138200403\ +559860852242739162502265229285668889329486246501015346579337652707239\ +409519978766587351943831270835393219031728127''')) + + # Some Mersenne primes + assert isprime(2**61 - 1) + assert isprime(2**89 - 1) + assert isprime(2**607 - 1) + # (but not all Mersenne's are primes + assert not isprime(2**601 - 1) + + # pseudoprimes + #------------- + # to some small bases + assert not isprime(2152302898747) + assert not isprime(3474749660383) + assert not isprime(341550071728321) + assert not isprime(3825123056546413051) + # passes the base set [2, 3, 7, 61, 24251] + assert not isprime(9188353522314541) + # large examples + assert not isprime(877777777777777777777777) + # conjectured psi_12 given at http://mathworld.wolfram.com/StrongPseudoprime.html + assert not isprime(318665857834031151167461) + # conjectured psi_17 given at http://mathworld.wolfram.com/StrongPseudoprime.html + assert not isprime(564132928021909221014087501701) + # Arnault's 1993 number; a factor of it is + # 400958216639499605418306452084546853005188166041132508774506\ + # 204738003217070119624271622319159721973358216316508535816696\ + # 9145233813917169287527980445796800452592031836601 + assert not isprime(int(''' +803837457453639491257079614341942108138837688287558145837488917522297\ +427376533365218650233616396004545791504202360320876656996676098728404\ +396540823292873879185086916685732826776177102938969773947016708230428\ +687109997439976544144845341155872450633409279022275296229414984230688\ +1685404326457534018329786111298960644845216191652872597534901''')) + # Arnault's 1995 number; can be factored as + # p1*(313*(p1 - 1) + 1)*(353*(p1 - 1) + 1) where p1 is + # 296744956686855105501541746429053327307719917998530433509950\ + # 755312768387531717701995942385964281211880336647542183455624\ + # 93168782883 + assert not isprime(int(''' +288714823805077121267142959713039399197760945927972270092651602419743\ +230379915273311632898314463922594197780311092934965557841894944174093\ +380561511397999942154241693397290542371100275104208013496673175515285\ +922696291677532547504444585610194940420003990443211677661994962953925\ +045269871932907037356403227370127845389912612030924484149472897688540\ +6024976768122077071687938121709811322297802059565867''')) + sieve.extend(3000) + assert isprime(2819) + assert not isprime(2931) + raises(ValueError, lambda: isprime(2.0)) + raises(ValueError, lambda: isprime(Float(2))) + + +def test_is_square(): + assert [i for i in range(25) if is_square(i)] == [0, 1, 4, 9, 16] + + # issue #17044 + assert not is_square(60 ** 3) + assert not is_square(60 ** 5) + assert not is_square(84 ** 7) + assert not is_square(105 ** 9) + assert not is_square(120 ** 3) + +def test_is_gaussianprime(): + assert is_gaussian_prime(7*I) + assert is_gaussian_prime(7) + assert is_gaussian_prime(2 + 3*I) + assert not is_gaussian_prime(2 + 2*I) + + +def test_issue_27145(): + #https://github.com/sympy/sympy/issues/27145 + assert [mr(i,[2,3,5,7]) for i in (1, 2, 6)] == [False, True, False] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_qs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_qs.py new file mode 100644 index 0000000000000000000000000000000000000000..16932dd61badf4a467e67fa52e0f473594fd057b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_qs.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import math +from sympy.core.random import _randint +from sympy.ntheory import qs, qs_factor +from sympy.ntheory.qs import SievePolynomial, _generate_factor_base, \ + _generate_polynomial, \ + _gen_sieve_array, _check_smoothness, _trial_division_stage, _find_factor +from sympy.testing.pytest import slow + + +@slow +def test_qs_1(): + assert qs(10009202107, 100, 10000) == {100043, 100049} + assert qs(211107295182713951054568361, 1000, 10000) == \ + {13791315212531, 15307263442931} + assert qs(980835832582657*990377764891511, 2000, 10000) == \ + {980835832582657, 990377764891511} + assert qs(18640889198609*20991129234731, 1000, 50000) == \ + {18640889198609, 20991129234731} + + +def test_qs_2() -> None: + n = 10009202107 + M = 50 + sieve_poly = SievePolynomial(10, 80, n) + assert sieve_poly.eval_v(10) == sieve_poly.eval_u(10)**2 - n == -10009169707 + assert sieve_poly.eval_v(5) == sieve_poly.eval_u(5)**2 - n == -10009185207 + + idx_1000, idx_5000, factor_base = _generate_factor_base(2000, n) + assert idx_1000 == 82 + assert [factor_base[i].prime for i in range(15)] == \ + [2, 3, 7, 11, 17, 19, 29, 31, 43, 59, 61, 67, 71, 73, 79] + assert [factor_base[i].tmem_p for i in range(15)] == \ + [1, 1, 3, 5, 3, 6, 6, 14, 1, 16, 24, 22, 18, 22, 15] + assert [factor_base[i].log_p for i in range(5)] == \ + [710, 1125, 1993, 2455, 2901] + + it = _generate_polynomial( + n, M, factor_base, idx_1000, idx_5000, _randint(0)) + g = next(it) + assert g.a == 1133107 + assert g.b == 682543 + assert [factor_base[i].soln1 for i in range(15)] == \ + [0, 0, 3, 7, 13, 0, 8, 19, 9, 43, 27, 25, 63, 29, 19] + assert [factor_base[i].soln2 for i in range(15)] == \ + [0, 1, 1, 3, 12, 16, 15, 6, 15, 1, 56, 55, 61, 58, 16] + assert [factor_base[i].b_ainv for i in range(5)] == \ + [[0, 0], [0, 2], [3, 0], [3, 9], [13, 13]] + + g_1 = next(it) + assert g_1.a == 1133107 + assert g_1.b == 136765 + + sieve_array = _gen_sieve_array(M, factor_base) + assert sieve_array[0:5] == [8424, 13603, 1835, 5335, 710] + + assert _check_smoothness(9645, factor_base) == (36028797018963972, 5) + assert _check_smoothness(210313, factor_base) == (20992, 1) + + partial_relations: dict[int, tuple[int, int]] = {} + smooth_relation, proper_factor = _trial_division_stage( + n, M, factor_base, sieve_array, sieve_poly, partial_relations, + ERROR_TERM=25*2**10) + + assert partial_relations == { + 8699: (440, -10009008507, 75557863761098695507973), + 166741: (490, -10008962007, 524341), + 131449: (530, -10008921207, 664613997892457936451903530140172325), + 6653: (550, -10008899607, 19342813113834066795307021) + } + assert [smooth_relation[i][0] for i in range(5)] == [ + -250, 1064469, 72819, 231957, 44167] + assert [smooth_relation[i][1] for i in range(5)] == [ + -10009139607, 1133094251961, 5302606761, 53804049849, 1950723889] + assert smooth_relation[0][2] == 89213869829863962596973701078031812362502145 + assert proper_factor == set() + + +def test_qs_3(): + N = 1817 + smooth_relations = [ + (2455024, 637, 8), + (-27993000, 81536, 10), + (11461840, 12544, 0), + (149, 20384, 10), + (-31138074, 19208, 2) + ] + assert next(_find_factor(N, smooth_relations, 4)) == 23 + + +def test_qs_4(): + N = 10007**2 * 10009 * 10037**3 * 10039 + for factor in qs(N, 1000, 2000): + assert N % factor == 0 + N //= factor + + +def test_qs_factor(): + assert qs_factor(1009 * 100003, 2000, 10000) == {1009: 1, 100003: 1} + n = 1009**2 * 2003**2*30011*400009 + factors = qs_factor(n, 2000, 10000) + assert len(factors) > 1 + assert math.prod(p**e for p, e in factors.items()) == n + + +def test_issue_27616(): + #https://github.com/sympy/sympy/issues/27616 + N = 9804659461513846513 + 1 + assert qs(N, 5000, 20000) is not None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_residue.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_residue.py new file mode 100644 index 0000000000000000000000000000000000000000..4d530905f39b88d8d7cc0e861ac6eadb2fa6f98a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/ntheory/tests/test_residue.py @@ -0,0 +1,349 @@ +from collections import defaultdict +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.combinatorial.numbers import totient +from sympy.ntheory import n_order, is_primitive_root, is_quad_residue, \ + legendre_symbol, jacobi_symbol, primerange, sqrt_mod, \ + primitive_root, quadratic_residues, is_nthpow_residue, nthroot_mod, \ + sqrt_mod_iter, mobius, discrete_log, quadratic_congruence, \ + polynomial_congruence, sieve +from sympy.ntheory.residue_ntheory import _primitive_root_prime_iter, \ + _primitive_root_prime_power_iter, _primitive_root_prime_power2_iter, \ + _nthroot_mod_prime_power, _discrete_log_trial_mul, _discrete_log_shanks_steps, \ + _discrete_log_pollard_rho, _discrete_log_index_calculus, _discrete_log_pohlig_hellman, \ + _binomial_mod_prime_power, binomial_mod +from sympy.polys.domains import ZZ +from sympy.testing.pytest import raises +from sympy.core.random import randint, choice + + +def test_residue(): + assert n_order(2, 13) == 12 + assert [n_order(a, 7) for a in range(1, 7)] == \ + [1, 3, 6, 3, 6, 2] + assert n_order(5, 17) == 16 + assert n_order(17, 11) == n_order(6, 11) + assert n_order(101, 119) == 6 + assert n_order(11, (10**50 + 151)**2) == 10000000000000000000000000000000000000000000000030100000000000000000000000000000000000000000000022650 + raises(ValueError, lambda: n_order(6, 9)) + + assert is_primitive_root(2, 7) is False + assert is_primitive_root(3, 8) is False + assert is_primitive_root(11, 14) is False + assert is_primitive_root(12, 17) == is_primitive_root(29, 17) + raises(ValueError, lambda: is_primitive_root(3, 6)) + + for p in primerange(3, 100): + li = list(_primitive_root_prime_iter(p)) + assert li[0] == min(li) + for g in li: + assert n_order(g, p) == p - 1 + assert len(li) == totient(totient(p)) + for e in range(1, 4): + li_power = list(_primitive_root_prime_power_iter(p, e)) + li_power2 = list(_primitive_root_prime_power2_iter(p, e)) + assert len(li_power) == len(li_power2) == totient(totient(p**e)) + assert primitive_root(97) == 5 + assert n_order(primitive_root(97, False), 97) == totient(97) + assert primitive_root(97**2) == 5 + assert n_order(primitive_root(97**2, False), 97**2) == totient(97**2) + assert primitive_root(40487) == 5 + assert n_order(primitive_root(40487, False), 40487) == totient(40487) + # note that primitive_root(40487) + 40487 = 40492 is a primitive root + # of 40487**2, but it is not the smallest + assert primitive_root(40487**2) == 10 + assert n_order(primitive_root(40487**2, False), 40487**2) == totient(40487**2) + assert primitive_root(82) == 7 + assert n_order(primitive_root(82, False), 82) == totient(82) + p = 10**50 + 151 + assert primitive_root(p) == 11 + assert n_order(primitive_root(p, False), p) == totient(p) + assert primitive_root(2*p) == 11 + assert n_order(primitive_root(2*p, False), 2*p) == totient(2*p) + assert primitive_root(p**2) == 11 + assert n_order(primitive_root(p**2, False), p**2) == totient(p**2) + assert primitive_root(4 * 11) is None and primitive_root(4 * 11, False) is None + assert primitive_root(15) is None and primitive_root(15, False) is None + raises(ValueError, lambda: primitive_root(-3)) + + assert is_quad_residue(3, 7) is False + assert is_quad_residue(10, 13) is True + assert is_quad_residue(12364, 139) == is_quad_residue(12364 % 139, 139) + assert is_quad_residue(207, 251) is True + assert is_quad_residue(0, 1) is True + assert is_quad_residue(1, 1) is True + assert is_quad_residue(0, 2) == is_quad_residue(1, 2) is True + assert is_quad_residue(1, 4) is True + assert is_quad_residue(2, 27) is False + assert is_quad_residue(13122380800, 13604889600) is True + assert [j for j in range(14) if is_quad_residue(j, 14)] == \ + [0, 1, 2, 4, 7, 8, 9, 11] + raises(ValueError, lambda: is_quad_residue(1.1, 2)) + raises(ValueError, lambda: is_quad_residue(2, 0)) + + assert quadratic_residues(S.One) == [0] + assert quadratic_residues(1) == [0] + assert quadratic_residues(12) == [0, 1, 4, 9] + assert quadratic_residues(13) == [0, 1, 3, 4, 9, 10, 12] + assert [len(quadratic_residues(i)) for i in range(1, 20)] == \ + [1, 2, 2, 2, 3, 4, 4, 3, 4, 6, 6, 4, 7, 8, 6, 4, 9, 8, 10] + + assert list(sqrt_mod_iter(6, 2)) == [0] + assert sqrt_mod(3, 13) == 4 + assert sqrt_mod(3, -13) == 4 + assert sqrt_mod(6, 23) == 11 + assert sqrt_mod(345, 690) == 345 + assert sqrt_mod(67, 101) == None + assert sqrt_mod(1020, 104729) == None + + for p in range(3, 100): + d = defaultdict(list) + for i in range(p): + d[pow(i, 2, p)].append(i) + for i in range(1, p): + it = sqrt_mod_iter(i, p) + v = sqrt_mod(i, p, True) + if v: + v = sorted(v) + assert d[i] == v + else: + assert not d[i] + + assert sqrt_mod(9, 27, True) == [3, 6, 12, 15, 21, 24] + assert sqrt_mod(9, 81, True) == [3, 24, 30, 51, 57, 78] + assert sqrt_mod(9, 3**5, True) == [3, 78, 84, 159, 165, 240] + assert sqrt_mod(81, 3**4, True) == [0, 9, 18, 27, 36, 45, 54, 63, 72] + assert sqrt_mod(81, 3**5, True) == [9, 18, 36, 45, 63, 72, 90, 99, 117,\ + 126, 144, 153, 171, 180, 198, 207, 225, 234] + assert sqrt_mod(81, 3**6, True) == [9, 72, 90, 153, 171, 234, 252, 315,\ + 333, 396, 414, 477, 495, 558, 576, 639, 657, 720] + assert sqrt_mod(81, 3**7, True) == [9, 234, 252, 477, 495, 720, 738, 963,\ + 981, 1206, 1224, 1449, 1467, 1692, 1710, 1935, 1953, 2178] + + for a, p in [(26214400, 32768000000), (26214400, 16384000000), + (262144, 1048576), (87169610025, 163443018796875), + (22315420166400, 167365651248000000)]: + assert pow(sqrt_mod(a, p), 2, p) == a + + n = 70 + a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+2) + it = sqrt_mod_iter(a, p) + for i in range(10): + assert pow(next(it), 2, p) == a + a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+3) + it = sqrt_mod_iter(a, p) + for i in range(2): + assert pow(next(it), 2, p) == a + n = 100 + a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+1) + it = sqrt_mod_iter(a, p) + for i in range(2): + assert pow(next(it), 2, p) == a + + assert type(next(sqrt_mod_iter(9, 27))) is int + assert type(next(sqrt_mod_iter(9, 27, ZZ))) is type(ZZ(1)) + assert type(next(sqrt_mod_iter(1, 7, ZZ))) is type(ZZ(1)) + + assert is_nthpow_residue(2, 1, 5) + + #issue 10816 + assert is_nthpow_residue(1, 0, 1) is False + assert is_nthpow_residue(1, 0, 2) is True + assert is_nthpow_residue(3, 0, 2) is True + assert is_nthpow_residue(0, 1, 8) is True + assert is_nthpow_residue(2, 3, 2) is True + assert is_nthpow_residue(2, 3, 9) is False + assert is_nthpow_residue(3, 5, 30) is True + assert is_nthpow_residue(21, 11, 20) is True + assert is_nthpow_residue(7, 10, 20) is False + assert is_nthpow_residue(5, 10, 20) is True + assert is_nthpow_residue(3, 10, 48) is False + assert is_nthpow_residue(1, 10, 40) is True + assert is_nthpow_residue(3, 10, 24) is False + assert is_nthpow_residue(1, 10, 24) is True + assert is_nthpow_residue(3, 10, 24) is False + assert is_nthpow_residue(2, 10, 48) is False + assert is_nthpow_residue(81, 3, 972) is False + assert is_nthpow_residue(243, 5, 5103) is True + assert is_nthpow_residue(243, 3, 1240029) is False + assert is_nthpow_residue(36010, 8, 87382) is True + assert is_nthpow_residue(28552, 6, 2218) is True + assert is_nthpow_residue(92712, 9, 50026) is True + x = {pow(i, 56, 1024) for i in range(1024)} + assert {a for a in range(1024) if is_nthpow_residue(a, 56, 1024)} == x + x = { pow(i, 256, 2048) for i in range(2048)} + assert {a for a in range(2048) if is_nthpow_residue(a, 256, 2048)} == x + x = { pow(i, 11, 324000) for i in range(1000)} + assert [ is_nthpow_residue(a, 11, 324000) for a in x] + x = { pow(i, 17, 22217575536) for i in range(1000)} + assert [ is_nthpow_residue(a, 17, 22217575536) for a in x] + assert is_nthpow_residue(676, 3, 5364) + assert is_nthpow_residue(9, 12, 36) + assert is_nthpow_residue(32, 10, 41) + assert is_nthpow_residue(4, 2, 64) + assert is_nthpow_residue(31, 4, 41) + assert not is_nthpow_residue(2, 2, 5) + assert is_nthpow_residue(8547, 12, 10007) + assert is_nthpow_residue(Dummy(even=True) + 3, 3, 2) == True + # _nthroot_mod_prime_power + for p in primerange(2, 10): + for a in range(3): + for n in range(3, 5): + ans = _nthroot_mod_prime_power(a, n, p, 1) + assert isinstance(ans, list) + if len(ans) == 0: + for b in range(p): + assert pow(b, n, p) != a % p + for k in range(2, 10): + assert _nthroot_mod_prime_power(a, n, p, k) == [] + else: + for b in range(p): + pred = pow(b, n, p) == a % p + assert not(pred ^ (b in ans)) + for k in range(2, 10): + ans = _nthroot_mod_prime_power(a, n, p, k) + if not ans: + break + for b in ans: + assert pow(b, n , p**k) == a + + assert nthroot_mod(Dummy(odd=True), 3, 2) == 1 + assert nthroot_mod(29, 31, 74) == 45 + assert nthroot_mod(1801, 11, 2663) == 44 + for a, q, p in [(51922, 2, 203017), (43, 3, 109), (1801, 11, 2663), + (26118163, 1303, 33333347), (1499, 7, 2663), (595, 6, 2663), + (1714, 12, 2663), (28477, 9, 33343)]: + r = nthroot_mod(a, q, p) + assert pow(r, q, p) == a + assert nthroot_mod(11, 3, 109) is None + assert nthroot_mod(16, 5, 36, True) == [4, 22] + assert nthroot_mod(9, 16, 36, True) == [3, 9, 15, 21, 27, 33] + assert nthroot_mod(4, 3, 3249000) is None + assert nthroot_mod(36010, 8, 87382, True) == [40208, 47174] + assert nthroot_mod(0, 12, 37, True) == [0] + assert nthroot_mod(0, 7, 100, True) == [0, 10, 20, 30, 40, 50, 60, 70, 80, 90] + assert nthroot_mod(4, 4, 27, True) == [5, 22] + assert nthroot_mod(4, 4, 121, True) == [19, 102] + assert nthroot_mod(2, 3, 7, True) == [] + for p in range(1, 20): + for a in range(p): + for n in range(1, p): + ans = nthroot_mod(a, n, p, True) + assert isinstance(ans, list) + for b in range(p): + pred = pow(b, n, p) == a + assert not(pred ^ (b in ans)) + ans2 = nthroot_mod(a, n, p, False) + if ans2 is None: + assert ans == [] + else: + assert ans2 in ans + + x = Symbol('x', positive=True) + i = Symbol('i', integer=True) + assert _discrete_log_trial_mul(587, 2**7, 2) == 7 + assert _discrete_log_trial_mul(941, 7**18, 7) == 18 + assert _discrete_log_trial_mul(389, 3**81, 3) == 81 + assert _discrete_log_trial_mul(191, 19**123, 19) == 123 + assert _discrete_log_shanks_steps(442879, 7**2, 7) == 2 + assert _discrete_log_shanks_steps(874323, 5**19, 5) == 19 + assert _discrete_log_shanks_steps(6876342, 7**71, 7) == 71 + assert _discrete_log_shanks_steps(2456747, 3**321, 3) == 321 + assert _discrete_log_pollard_rho(6013199, 2**6, 2, rseed=0) == 6 + assert _discrete_log_pollard_rho(6138719, 2**19, 2, rseed=0) == 19 + assert _discrete_log_pollard_rho(36721943, 2**40, 2, rseed=0) == 40 + assert _discrete_log_pollard_rho(24567899, 3**333, 3, rseed=0) == 333 + raises(ValueError, lambda: _discrete_log_pollard_rho(11, 7, 31, rseed=0)) + raises(ValueError, lambda: _discrete_log_pollard_rho(227, 3**7, 5, rseed=0)) + assert _discrete_log_index_calculus(983, 948, 2, 491) == 183 + assert _discrete_log_index_calculus(633383, 21794, 2, 316691) == 68048 + assert _discrete_log_index_calculus(941762639, 68822582, 2, 470881319) == 338029275 + assert _discrete_log_index_calculus(999231337607, 888188918786, 2, 499615668803) == 142811376514 + assert _discrete_log_index_calculus(47747730623, 19410045286, 43425105668, 645239603) == 590504662 + assert _discrete_log_pohlig_hellman(98376431, 11**9, 11) == 9 + assert _discrete_log_pohlig_hellman(78723213, 11**31, 11) == 31 + assert _discrete_log_pohlig_hellman(32942478, 11**98, 11) == 98 + assert _discrete_log_pohlig_hellman(14789363, 11**444, 11) == 444 + assert discrete_log(1, 0, 2) == 0 + raises(ValueError, lambda: discrete_log(-4, 1, 3)) + raises(ValueError, lambda: discrete_log(10, 3, 2)) + assert discrete_log(587, 2**9, 2) == 9 + assert discrete_log(2456747, 3**51, 3) == 51 + assert discrete_log(32942478, 11**127, 11) == 127 + assert discrete_log(432751500361, 7**324, 7) == 324 + assert discrete_log(265390227570863,184500076053622, 2) == 17835221372061 + assert discrete_log(22708823198678103974314518195029102158525052496759285596453269189798311427475159776411276642277139650833937, + 17463946429475485293747680247507700244427944625055089103624311227422110546803452417458985046168310373075327, + 123456) == 2068031853682195777930683306640554533145512201725884603914601918777510185469769997054750835368413389728895 + args = 5779, 3528, 6215 + assert discrete_log(*args) == 687 + assert discrete_log(*Tuple(*args)) == 687 + assert quadratic_congruence(400, 85, 125, 1600) == [295, 615, 935, 1255, 1575] + assert quadratic_congruence(3, 6, 5, 25) == [3, 20] + assert quadratic_congruence(120, 80, 175, 500) == [] + assert quadratic_congruence(15, 14, 7, 2) == [1] + assert quadratic_congruence(8, 15, 7, 29) == [10, 28] + assert quadratic_congruence(160, 200, 300, 461) == [144, 431] + assert quadratic_congruence(100000, 123456, 7415263, 48112959837082048697) == [30417843635344493501, 36001135160550533083] + assert quadratic_congruence(65, 121, 72, 277) == [249, 252] + assert quadratic_congruence(5, 10, 14, 2) == [0] + assert quadratic_congruence(10, 17, 19, 2) == [1] + assert quadratic_congruence(10, 14, 20, 2) == [0, 1] + assert quadratic_congruence(2**48-7, 2**48-1, 4, 2**48) == [8249717183797, 31960993774868] + assert polynomial_congruence(6*x**5 + 10*x**4 + 5*x**3 + x**2 + x + 1, + 972000) == [220999, 242999, 463999, 485999, 706999, 728999, 949999, 971999] + + assert polynomial_congruence(x**3 - 10*x**2 + 12*x - 82, 33075) == [30287] + assert polynomial_congruence(x**2 + x + 47, 2401) == [785, 1615] + assert polynomial_congruence(10*x**2 + 14*x + 20, 2) == [0, 1] + assert polynomial_congruence(x**3 + 3, 16) == [5] + assert polynomial_congruence(65*x**2 + 121*x + 72, 277) == [249, 252] + assert polynomial_congruence(x**4 - 4, 27) == [5, 22] + assert polynomial_congruence(35*x**3 - 6*x**2 - 567*x + 2308, 148225) == [86957, + 111157, 122531, 146731] + assert polynomial_congruence(x**16 - 9, 36) == [3, 9, 15, 21, 27, 33] + assert polynomial_congruence(x**6 - 2*x**5 - 35, 6125) == [3257] + raises(ValueError, lambda: polynomial_congruence(x**x, 6125)) + raises(ValueError, lambda: polynomial_congruence(x**i, 6125)) + raises(ValueError, lambda: polynomial_congruence(0.1*x**2 + 6, 100)) + + assert binomial_mod(-1, 1, 10) == 0 + assert binomial_mod(1, -1, 10) == 0 + raises(ValueError, lambda: binomial_mod(2, 1, -1)) + assert binomial_mod(51, 10, 10) == 0 + assert binomial_mod(10**3, 500, 3**6) == 567 + assert binomial_mod(10**18 - 1, 123456789, 4) == 0 + assert binomial_mod(10**18, 10**12, (10**5 + 3)**2) == 3744312326 + + +def test_binomial_p_pow(): + n, binomials, binomial = 1000, [1], 1 + for i in range(1, n + 1): + binomial *= n - i + 1 + binomial //= i + binomials.append(binomial) + + # Test powers of two, which the algorithm treats slightly differently + trials_2 = 100 + for _ in range(trials_2): + m, power = randint(0, n), randint(1, 20) + assert _binomial_mod_prime_power(n, m, 2, power) == binomials[m] % 2**power + + # Test against other prime powers + primes = list(sieve.primerange(2*n)) + trials = 1000 + for _ in range(trials): + m, prime, power = randint(0, n), choice(primes), randint(1, 10) + assert _binomial_mod_prime_power(n, m, prime, power) == binomials[m] % prime**power + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert mobius(3) == -1 + with warns_deprecated_sympy(): + assert legendre_symbol(2, 3) == -1 + with warns_deprecated_sympy(): + assert jacobi_symbol(2, 3) == -1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b39d031bca26bc599eb9eb0e12dfe48f7e6db174 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__init__.py @@ -0,0 +1,4 @@ +"""Used for translating a string into a SymPy expression. """ +__all__ = ['parse_expr'] + +from .sympy_parser import parse_expr diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c609eb39ce777e1514cde282a8c0a24464ff3fb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/ast_parser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/ast_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ad8fc81c69d21da27fdc68e3e570125b731c690 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/ast_parser.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/mathematica.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/mathematica.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec7a37c18dba61c21ea7e1413e63a701c9aa0eee Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/mathematica.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/maxima.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/maxima.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a66b7138cdf3b4c6665069ad8c50aeebf52a6e1c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/maxima.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/sym_expr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/sym_expr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..780f8470ebb995ebd3ba08c2db52905d1efc1462 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/sym_expr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/sympy_parser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/sympy_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..419b36dc5ce7ca32381716d0ce574d413f25a01a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/__pycache__/sympy_parser.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/ast_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/ast_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..95a773d5bec6e130810b7b7925fdff57270aec17 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/ast_parser.py @@ -0,0 +1,79 @@ +""" +This module implements the functionality to take any Python expression as a +string and fix all numbers and other things before evaluating it, +thus + +1/2 + +returns + +Integer(1)/Integer(2) + +We use the ast module for this. It is well documented at docs.python.org. + +Some tips to understand how this works: use dump() to get a nice +representation of any node. Then write a string of what you want to get, +e.g. "Integer(1)", parse it, dump it and you'll see that you need to do +"Call(Name('Integer', Load()), [node], [], None, None)". You do not need +to bother with lineno and col_offset, just call fix_missing_locations() +before returning the node. +""" + +from sympy.core.basic import Basic +from sympy.core.sympify import SympifyError + +from ast import parse, NodeTransformer, Call, Name, Load, \ + fix_missing_locations, Constant, Tuple + +class Transform(NodeTransformer): + + def __init__(self, local_dict, global_dict): + NodeTransformer.__init__(self) + self.local_dict = local_dict + self.global_dict = global_dict + + def visit_Constant(self, node): + if isinstance(node.value, int): + return fix_missing_locations(Call(func=Name('Integer', Load()), + args=[node], keywords=[])) + elif isinstance(node.value, float): + return fix_missing_locations(Call(func=Name('Float', Load()), + args=[node], keywords=[])) + return node + + def visit_Name(self, node): + if node.id in self.local_dict: + return node + elif node.id in self.global_dict: + name_obj = self.global_dict[node.id] + + if isinstance(name_obj, (Basic, type)) or callable(name_obj): + return node + elif node.id in ['True', 'False']: + return node + return fix_missing_locations(Call(func=Name('Symbol', Load()), + args=[Constant(node.id)], keywords=[])) + + def visit_Lambda(self, node): + args = [self.visit(arg) for arg in node.args.args] + body = self.visit(node.body) + n = Call(func=Name('Lambda', Load()), + args=[Tuple(args, Load()), body], keywords=[]) + return fix_missing_locations(n) + +def parse_expr(s, local_dict): + """ + Converts the string "s" to a SymPy expression, in local_dict. + + It converts all numbers to Integers before feeding it to Python and + automatically creates Symbols. + """ + global_dict = {} + exec('from sympy import *', global_dict) + try: + a = parse(s.strip(), mode="eval") + except SyntaxError: + raise SympifyError("Cannot parse %s." % repr(s)) + a = Transform(local_dict, global_dict).visit(a) + e = compile(a, "", "eval") + return eval(e, global_dict, local_dict) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/Autolev.g4 b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/Autolev.g4 new file mode 100644 index 0000000000000000000000000000000000000000..94feea5fa4f49e9d1054eca2cd60c996aebff7c2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/Autolev.g4 @@ -0,0 +1,118 @@ +grammar Autolev; + +options { + language = Python3; +} + +prog: stat+; + +stat: varDecl + | functionCall + | codeCommands + | massDecl + | inertiaDecl + | assignment + | settings + ; + +assignment: vec equals expr #vecAssign + | ID '[' index ']' equals expr #indexAssign + | ID diff? equals expr #regularAssign; + +equals: ('='|'+='|'-='|':='|'*='|'/='|'^='); + +index: expr (',' expr)* ; + +diff: ('\'')+; + +functionCall: ID '(' (expr (',' expr)*)? ')' + | (Mass|Inertia) '(' (ID (',' ID)*)? ')'; + +varDecl: varType varDecl2 (',' varDecl2)*; + +varType: Newtonian|Frames|Bodies|Particles|Points|Constants + | Specifieds|Imaginary|Variables ('\'')*|MotionVariables ('\'')*; + +varDecl2: ID ('{' INT ',' INT '}')? (('{' INT ':' INT (',' INT ':' INT)* '}'))? ('{' INT '}')? ('+'|'-')? ('\'')* ('=' expr)?; + +ranges: ('{' INT ':' INT (',' INT ':' INT)* '}'); + +massDecl: Mass massDecl2 (',' massDecl2)*; + +massDecl2: ID '=' expr; + +inertiaDecl: Inertia ID ('(' ID ')')? (',' expr)+; + +matrix: '[' expr ((','|';') expr)* ']'; +matrixInOutput: (ID (ID '=' (FLOAT|INT)?))|FLOAT|INT; + +codeCommands: units + | inputs + | outputs + | codegen + | commands; + +settings: ID (EXP|ID|FLOAT|INT)?; + +units: UnitSystem ID (',' ID)*; +inputs: Input inputs2 (',' inputs2)*; +id_diff: ID diff?; +inputs2: id_diff '=' expr expr?; +outputs: Output outputs2 (',' outputs2)*; +outputs2: expr expr?; +codegen: ID functionCall ('['matrixInOutput (',' matrixInOutput)*']')? ID'.'ID; + +commands: Save ID'.'ID + | Encode ID (',' ID)*; + +vec: ID ('>')+ + | '0>' + | '1>>'; + +expr: expr '^' expr # Exponent + | expr ('*'|'/') expr # MulDiv + | expr ('+'|'-') expr # AddSub + | EXP # exp + | '-' expr # negativeOne + | FLOAT # float + | INT # int + | ID('\'')* # id + | vec # VectorOrDyadic + | ID '['expr (',' expr)* ']' # Indexing + | functionCall # function + | matrix # matrices + | '(' expr ')' # parens + | expr '=' expr # idEqualsExpr + | expr ':' expr # colon + | ID? ranges ('\'')* # rangess + ; + +// These are to take care of the case insensitivity of Autolev. +Mass: ('M'|'m')('A'|'a')('S'|'s')('S'|'s'); +Inertia: ('I'|'i')('N'|'n')('E'|'e')('R'|'r')('T'|'t')('I'|'i')('A'|'a'); +Input: ('I'|'i')('N'|'n')('P'|'p')('U'|'u')('T'|'t')('S'|'s')?; +Output: ('O'|'o')('U'|'u')('T'|'t')('P'|'p')('U'|'u')('T'|'t'); +Save: ('S'|'s')('A'|'a')('V'|'v')('E'|'e'); +UnitSystem: ('U'|'u')('N'|'n')('I'|'i')('T'|'t')('S'|'s')('Y'|'y')('S'|'s')('T'|'t')('E'|'e')('M'|'m'); +Encode: ('E'|'e')('N'|'n')('C'|'c')('O'|'o')('D'|'d')('E'|'e'); +Newtonian: ('N'|'n')('E'|'e')('W'|'w')('T'|'t')('O'|'o')('N'|'n')('I'|'i')('A'|'a')('N'|'n'); +Frames: ('F'|'f')('R'|'r')('A'|'a')('M'|'m')('E'|'e')('S'|'s')?; +Bodies: ('B'|'b')('O'|'o')('D'|'d')('I'|'i')('E'|'e')('S'|'s')?; +Particles: ('P'|'p')('A'|'a')('R'|'r')('T'|'t')('I'|'i')('C'|'c')('L'|'l')('E'|'e')('S'|'s')?; +Points: ('P'|'p')('O'|'o')('I'|'i')('N'|'n')('T'|'t')('S'|'s')?; +Constants: ('C'|'c')('O'|'o')('N'|'n')('S'|'s')('T'|'t')('A'|'a')('N'|'n')('T'|'t')('S'|'s')?; +Specifieds: ('S'|'s')('P'|'p')('E'|'e')('C'|'c')('I'|'i')('F'|'f')('I'|'i')('E'|'e')('D'|'d')('S'|'s')?; +Imaginary: ('I'|'i')('M'|'m')('A'|'a')('G'|'g')('I'|'i')('N'|'n')('A'|'a')('R'|'r')('Y'|'y'); +Variables: ('V'|'v')('A'|'a')('R'|'r')('I'|'i')('A'|'a')('B'|'b')('L'|'l')('E'|'e')('S'|'s')?; +MotionVariables: ('M'|'m')('O'|'o')('T'|'t')('I'|'i')('O'|'o')('N'|'n')('V'|'v')('A'|'a')('R'|'r')('I'|'i')('A'|'a')('B'|'b')('L'|'l')('E'|'e')('S'|'s')?; + +fragment DIFF: ('\'')*; +fragment DIGIT: [0-9]; +INT: [0-9]+ ; // match integers +FLOAT: DIGIT+ '.' DIGIT* + | '.' DIGIT+; +EXP: FLOAT 'E' INT +| FLOAT 'E' '-' INT; +LINE_COMMENT : '%' .*? '\r'? '\n' -> skip ; +ID: [a-zA-Z][a-zA-Z0-9_]*; +WS: [ \t\r\n&]+ -> skip ; // toss out whitespace diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec81bb83325d68e1c11b43a1df5ec56846367e9f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__init__.py @@ -0,0 +1,97 @@ +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on + +@doctest_depends_on(modules=('antlr4',)) +def parse_autolev(autolev_code, include_numeric=False): + """Parses Autolev code (version 4.1) to SymPy code. + + Parameters + ========= + autolev_code : Can be an str or any object with a readlines() method (such as a file handle or StringIO). + include_numeric : boolean, optional + If True NumPy, PyDy, or other numeric code is included for numeric evaluation lines in the Autolev code. + + Returns + ======= + sympy_code : str + Equivalent SymPy and/or numpy/pydy code as the input code. + + + Example (Double Pendulum) + ========================= + >>> my_al_text = ("MOTIONVARIABLES' Q{2}', U{2}'", + ... "CONSTANTS L,M,G", + ... "NEWTONIAN N", + ... "FRAMES A,B", + ... "SIMPROT(N, A, 3, Q1)", + ... "SIMPROT(N, B, 3, Q2)", + ... "W_A_N>=U1*N3>", + ... "W_B_N>=U2*N3>", + ... "POINT O", + ... "PARTICLES P,R", + ... "P_O_P> = L*A1>", + ... "P_P_R> = L*B1>", + ... "V_O_N> = 0>", + ... "V2PTS(N, A, O, P)", + ... "V2PTS(N, B, P, R)", + ... "MASS P=M, R=M", + ... "Q1' = U1", + ... "Q2' = U2", + ... "GRAVITY(G*N1>)", + ... "ZERO = FR() + FRSTAR()", + ... "KANE()", + ... "INPUT M=1,G=9.81,L=1", + ... "INPUT Q1=.1,Q2=.2,U1=0,U2=0", + ... "INPUT TFINAL=10, INTEGSTP=.01", + ... "CODE DYNAMICS() some_filename.c") + >>> my_al_text = '\\n'.join(my_al_text) + >>> from sympy.parsing.autolev import parse_autolev + >>> print(parse_autolev(my_al_text, include_numeric=True)) + import sympy.physics.mechanics as _me + import sympy as _sm + import math as m + import numpy as _np + + q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') + q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) + l, m, g = _sm.symbols('l m g', real=True) + frame_n = _me.ReferenceFrame('n') + frame_a = _me.ReferenceFrame('a') + frame_b = _me.ReferenceFrame('b') + frame_a.orient(frame_n, 'Axis', [q1, frame_n.z]) + frame_b.orient(frame_n, 'Axis', [q2, frame_n.z]) + frame_a.set_ang_vel(frame_n, u1*frame_n.z) + frame_b.set_ang_vel(frame_n, u2*frame_n.z) + point_o = _me.Point('o') + particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) + particle_r = _me.Particle('r', _me.Point('r_pt'), _sm.Symbol('m')) + particle_p.point.set_pos(point_o, l*frame_a.x) + particle_r.point.set_pos(particle_p.point, l*frame_b.x) + point_o.set_vel(frame_n, 0) + particle_p.point.v2pt_theory(point_o,frame_n,frame_a) + particle_r.point.v2pt_theory(particle_p.point,frame_n,frame_b) + particle_p.mass = m + particle_r.mass = m + force_p = particle_p.mass*(g*frame_n.x) + force_r = particle_r.mass*(g*frame_n.x) + kd_eqs = [q1_d - u1, q2_d - u2] + forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x)), (particle_r.point,particle_r.mass*(g*frame_n.x))] + kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u1, u2], kd_eqs = kd_eqs) + fr, frstar = kane.kanes_equations([particle_p, particle_r], forceList) + zero = fr+frstar + from pydy.system import System + sys = System(kane, constants = {l:1, m:1, g:9.81}, + specifieds={}, + initial_conditions={q1:.1, q2:.2, u1:0, u2:0}, + times = _np.linspace(0.0, 10, 10/.01)) + + y=sys.integrate() + + """ + + _autolev = import_module( + 'sympy.parsing.autolev._parse_autolev_antlr', + import_kwargs={'fromlist': ['X']}) + + if _autolev is not None: + return _autolev.parse_autolev(autolev_code, include_numeric) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3a7bd5d49e504db8e9722d31d965f8a41021817 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/_build_autolev_antlr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/_build_autolev_antlr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acd359792c985eb87ab6d1bbf1f12f9fe5c1f7a4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/_build_autolev_antlr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/_parse_autolev_antlr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/_parse_autolev_antlr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c9d28e21d274460ee8743024211a4bb097d8a71 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/__pycache__/_parse_autolev_antlr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b71e9f51fd455558a9eb42dc840604c6c96e4b3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__init__.py @@ -0,0 +1,5 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c0dd63dad5b307819bcc890d1f8e74271ef1bc2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlexer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlexer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6473a04c28e09ac791247d331d06e5664a4bfd7d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlexer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlistener.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlistener.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d5b1108eca091a1a35de7ed6efd84bdd41b73a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlistener.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b3b1d27ade809a63d9fd328a1572c17625443e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py @@ -0,0 +1,253 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + + +def serializedATN(): + return [ + 4,0,49,393,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, + 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, + 19,2,20,7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2, + 26,7,26,2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7, + 32,2,33,7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2, + 39,7,39,2,40,7,40,2,41,7,41,2,42,7,42,2,43,7,43,2,44,7,44,2,45,7, + 45,2,46,7,46,2,47,7,47,2,48,7,48,2,49,7,49,2,50,7,50,1,0,1,0,1,1, + 1,1,1,2,1,2,1,3,1,3,1,3,1,4,1,4,1,4,1,5,1,5,1,5,1,6,1,6,1,6,1,7, + 1,7,1,7,1,8,1,8,1,8,1,9,1,9,1,10,1,10,1,11,1,11,1,12,1,12,1,13,1, + 13,1,14,1,14,1,15,1,15,1,16,1,16,1,17,1,17,1,18,1,18,1,19,1,19,1, + 20,1,20,1,21,1,21,1,21,1,22,1,22,1,22,1,22,1,23,1,23,1,24,1,24,1, + 25,1,25,1,26,1,26,1,26,1,26,1,26,1,27,1,27,1,27,1,27,1,27,1,27,1, + 27,1,27,1,28,1,28,1,28,1,28,1,28,1,28,3,28,184,8,28,1,29,1,29,1, + 29,1,29,1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,30,1,31,1,31,1,31,1, + 31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,32,1,32,1,32,1,32,1,32,1, + 32,1,32,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,34,1, + 34,1,34,1,34,1,34,1,34,3,34,232,8,34,1,35,1,35,1,35,1,35,1,35,1, + 35,3,35,240,8,35,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,3, + 36,251,8,36,1,37,1,37,1,37,1,37,1,37,1,37,3,37,259,8,37,1,38,1,38, + 1,38,1,38,1,38,1,38,1,38,1,38,1,38,3,38,270,8,38,1,39,1,39,1,39, + 1,39,1,39,1,39,1,39,1,39,1,39,1,39,3,39,282,8,39,1,40,1,40,1,40, + 1,40,1,40,1,40,1,40,1,40,1,40,1,40,1,41,1,41,1,41,1,41,1,41,1,41, + 1,41,1,41,1,41,3,41,303,8,41,1,42,1,42,1,42,1,42,1,42,1,42,1,42, + 1,42,1,42,1,42,1,42,1,42,1,42,1,42,1,42,3,42,320,8,42,1,43,5,43, + 323,8,43,10,43,12,43,326,9,43,1,44,1,44,1,45,4,45,331,8,45,11,45, + 12,45,332,1,46,4,46,336,8,46,11,46,12,46,337,1,46,1,46,5,46,342, + 8,46,10,46,12,46,345,9,46,1,46,1,46,4,46,349,8,46,11,46,12,46,350, + 3,46,353,8,46,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,3,47, + 364,8,47,1,48,1,48,5,48,368,8,48,10,48,12,48,371,9,48,1,48,3,48, + 374,8,48,1,48,1,48,1,48,1,48,1,49,1,49,5,49,382,8,49,10,49,12,49, + 385,9,49,1,50,4,50,388,8,50,11,50,12,50,389,1,50,1,50,1,369,0,51, + 1,1,3,2,5,3,7,4,9,5,11,6,13,7,15,8,17,9,19,10,21,11,23,12,25,13, + 27,14,29,15,31,16,33,17,35,18,37,19,39,20,41,21,43,22,45,23,47,24, + 49,25,51,26,53,27,55,28,57,29,59,30,61,31,63,32,65,33,67,34,69,35, + 71,36,73,37,75,38,77,39,79,40,81,41,83,42,85,43,87,0,89,0,91,44, + 93,45,95,46,97,47,99,48,101,49,1,0,24,2,0,77,77,109,109,2,0,65,65, + 97,97,2,0,83,83,115,115,2,0,73,73,105,105,2,0,78,78,110,110,2,0, + 69,69,101,101,2,0,82,82,114,114,2,0,84,84,116,116,2,0,80,80,112, + 112,2,0,85,85,117,117,2,0,79,79,111,111,2,0,86,86,118,118,2,0,89, + 89,121,121,2,0,67,67,99,99,2,0,68,68,100,100,2,0,87,87,119,119,2, + 0,70,70,102,102,2,0,66,66,98,98,2,0,76,76,108,108,2,0,71,71,103, + 103,1,0,48,57,2,0,65,90,97,122,4,0,48,57,65,90,95,95,97,122,4,0, + 9,10,13,13,32,32,38,38,410,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0,0, + 7,1,0,0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0,17, + 1,0,0,0,0,19,1,0,0,0,0,21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0,27, + 1,0,0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0,37, + 1,0,0,0,0,39,1,0,0,0,0,41,1,0,0,0,0,43,1,0,0,0,0,45,1,0,0,0,0,47, + 1,0,0,0,0,49,1,0,0,0,0,51,1,0,0,0,0,53,1,0,0,0,0,55,1,0,0,0,0,57, + 1,0,0,0,0,59,1,0,0,0,0,61,1,0,0,0,0,63,1,0,0,0,0,65,1,0,0,0,0,67, + 1,0,0,0,0,69,1,0,0,0,0,71,1,0,0,0,0,73,1,0,0,0,0,75,1,0,0,0,0,77, + 1,0,0,0,0,79,1,0,0,0,0,81,1,0,0,0,0,83,1,0,0,0,0,85,1,0,0,0,0,91, + 1,0,0,0,0,93,1,0,0,0,0,95,1,0,0,0,0,97,1,0,0,0,0,99,1,0,0,0,0,101, + 1,0,0,0,1,103,1,0,0,0,3,105,1,0,0,0,5,107,1,0,0,0,7,109,1,0,0,0, + 9,112,1,0,0,0,11,115,1,0,0,0,13,118,1,0,0,0,15,121,1,0,0,0,17,124, + 1,0,0,0,19,127,1,0,0,0,21,129,1,0,0,0,23,131,1,0,0,0,25,133,1,0, + 0,0,27,135,1,0,0,0,29,137,1,0,0,0,31,139,1,0,0,0,33,141,1,0,0,0, + 35,143,1,0,0,0,37,145,1,0,0,0,39,147,1,0,0,0,41,149,1,0,0,0,43,151, + 1,0,0,0,45,154,1,0,0,0,47,158,1,0,0,0,49,160,1,0,0,0,51,162,1,0, + 0,0,53,164,1,0,0,0,55,169,1,0,0,0,57,177,1,0,0,0,59,185,1,0,0,0, + 61,192,1,0,0,0,63,197,1,0,0,0,65,208,1,0,0,0,67,215,1,0,0,0,69,225, + 1,0,0,0,71,233,1,0,0,0,73,241,1,0,0,0,75,252,1,0,0,0,77,260,1,0, + 0,0,79,271,1,0,0,0,81,283,1,0,0,0,83,293,1,0,0,0,85,304,1,0,0,0, + 87,324,1,0,0,0,89,327,1,0,0,0,91,330,1,0,0,0,93,352,1,0,0,0,95,363, + 1,0,0,0,97,365,1,0,0,0,99,379,1,0,0,0,101,387,1,0,0,0,103,104,5, + 91,0,0,104,2,1,0,0,0,105,106,5,93,0,0,106,4,1,0,0,0,107,108,5,61, + 0,0,108,6,1,0,0,0,109,110,5,43,0,0,110,111,5,61,0,0,111,8,1,0,0, + 0,112,113,5,45,0,0,113,114,5,61,0,0,114,10,1,0,0,0,115,116,5,58, + 0,0,116,117,5,61,0,0,117,12,1,0,0,0,118,119,5,42,0,0,119,120,5,61, + 0,0,120,14,1,0,0,0,121,122,5,47,0,0,122,123,5,61,0,0,123,16,1,0, + 0,0,124,125,5,94,0,0,125,126,5,61,0,0,126,18,1,0,0,0,127,128,5,44, + 0,0,128,20,1,0,0,0,129,130,5,39,0,0,130,22,1,0,0,0,131,132,5,40, + 0,0,132,24,1,0,0,0,133,134,5,41,0,0,134,26,1,0,0,0,135,136,5,123, + 0,0,136,28,1,0,0,0,137,138,5,125,0,0,138,30,1,0,0,0,139,140,5,58, + 0,0,140,32,1,0,0,0,141,142,5,43,0,0,142,34,1,0,0,0,143,144,5,45, + 0,0,144,36,1,0,0,0,145,146,5,59,0,0,146,38,1,0,0,0,147,148,5,46, + 0,0,148,40,1,0,0,0,149,150,5,62,0,0,150,42,1,0,0,0,151,152,5,48, + 0,0,152,153,5,62,0,0,153,44,1,0,0,0,154,155,5,49,0,0,155,156,5,62, + 0,0,156,157,5,62,0,0,157,46,1,0,0,0,158,159,5,94,0,0,159,48,1,0, + 0,0,160,161,5,42,0,0,161,50,1,0,0,0,162,163,5,47,0,0,163,52,1,0, + 0,0,164,165,7,0,0,0,165,166,7,1,0,0,166,167,7,2,0,0,167,168,7,2, + 0,0,168,54,1,0,0,0,169,170,7,3,0,0,170,171,7,4,0,0,171,172,7,5,0, + 0,172,173,7,6,0,0,173,174,7,7,0,0,174,175,7,3,0,0,175,176,7,1,0, + 0,176,56,1,0,0,0,177,178,7,3,0,0,178,179,7,4,0,0,179,180,7,8,0,0, + 180,181,7,9,0,0,181,183,7,7,0,0,182,184,7,2,0,0,183,182,1,0,0,0, + 183,184,1,0,0,0,184,58,1,0,0,0,185,186,7,10,0,0,186,187,7,9,0,0, + 187,188,7,7,0,0,188,189,7,8,0,0,189,190,7,9,0,0,190,191,7,7,0,0, + 191,60,1,0,0,0,192,193,7,2,0,0,193,194,7,1,0,0,194,195,7,11,0,0, + 195,196,7,5,0,0,196,62,1,0,0,0,197,198,7,9,0,0,198,199,7,4,0,0,199, + 200,7,3,0,0,200,201,7,7,0,0,201,202,7,2,0,0,202,203,7,12,0,0,203, + 204,7,2,0,0,204,205,7,7,0,0,205,206,7,5,0,0,206,207,7,0,0,0,207, + 64,1,0,0,0,208,209,7,5,0,0,209,210,7,4,0,0,210,211,7,13,0,0,211, + 212,7,10,0,0,212,213,7,14,0,0,213,214,7,5,0,0,214,66,1,0,0,0,215, + 216,7,4,0,0,216,217,7,5,0,0,217,218,7,15,0,0,218,219,7,7,0,0,219, + 220,7,10,0,0,220,221,7,4,0,0,221,222,7,3,0,0,222,223,7,1,0,0,223, + 224,7,4,0,0,224,68,1,0,0,0,225,226,7,16,0,0,226,227,7,6,0,0,227, + 228,7,1,0,0,228,229,7,0,0,0,229,231,7,5,0,0,230,232,7,2,0,0,231, + 230,1,0,0,0,231,232,1,0,0,0,232,70,1,0,0,0,233,234,7,17,0,0,234, + 235,7,10,0,0,235,236,7,14,0,0,236,237,7,3,0,0,237,239,7,5,0,0,238, + 240,7,2,0,0,239,238,1,0,0,0,239,240,1,0,0,0,240,72,1,0,0,0,241,242, + 7,8,0,0,242,243,7,1,0,0,243,244,7,6,0,0,244,245,7,7,0,0,245,246, + 7,3,0,0,246,247,7,13,0,0,247,248,7,18,0,0,248,250,7,5,0,0,249,251, + 7,2,0,0,250,249,1,0,0,0,250,251,1,0,0,0,251,74,1,0,0,0,252,253,7, + 8,0,0,253,254,7,10,0,0,254,255,7,3,0,0,255,256,7,4,0,0,256,258,7, + 7,0,0,257,259,7,2,0,0,258,257,1,0,0,0,258,259,1,0,0,0,259,76,1,0, + 0,0,260,261,7,13,0,0,261,262,7,10,0,0,262,263,7,4,0,0,263,264,7, + 2,0,0,264,265,7,7,0,0,265,266,7,1,0,0,266,267,7,4,0,0,267,269,7, + 7,0,0,268,270,7,2,0,0,269,268,1,0,0,0,269,270,1,0,0,0,270,78,1,0, + 0,0,271,272,7,2,0,0,272,273,7,8,0,0,273,274,7,5,0,0,274,275,7,13, + 0,0,275,276,7,3,0,0,276,277,7,16,0,0,277,278,7,3,0,0,278,279,7,5, + 0,0,279,281,7,14,0,0,280,282,7,2,0,0,281,280,1,0,0,0,281,282,1,0, + 0,0,282,80,1,0,0,0,283,284,7,3,0,0,284,285,7,0,0,0,285,286,7,1,0, + 0,286,287,7,19,0,0,287,288,7,3,0,0,288,289,7,4,0,0,289,290,7,1,0, + 0,290,291,7,6,0,0,291,292,7,12,0,0,292,82,1,0,0,0,293,294,7,11,0, + 0,294,295,7,1,0,0,295,296,7,6,0,0,296,297,7,3,0,0,297,298,7,1,0, + 0,298,299,7,17,0,0,299,300,7,18,0,0,300,302,7,5,0,0,301,303,7,2, + 0,0,302,301,1,0,0,0,302,303,1,0,0,0,303,84,1,0,0,0,304,305,7,0,0, + 0,305,306,7,10,0,0,306,307,7,7,0,0,307,308,7,3,0,0,308,309,7,10, + 0,0,309,310,7,4,0,0,310,311,7,11,0,0,311,312,7,1,0,0,312,313,7,6, + 0,0,313,314,7,3,0,0,314,315,7,1,0,0,315,316,7,17,0,0,316,317,7,18, + 0,0,317,319,7,5,0,0,318,320,7,2,0,0,319,318,1,0,0,0,319,320,1,0, + 0,0,320,86,1,0,0,0,321,323,5,39,0,0,322,321,1,0,0,0,323,326,1,0, + 0,0,324,322,1,0,0,0,324,325,1,0,0,0,325,88,1,0,0,0,326,324,1,0,0, + 0,327,328,7,20,0,0,328,90,1,0,0,0,329,331,7,20,0,0,330,329,1,0,0, + 0,331,332,1,0,0,0,332,330,1,0,0,0,332,333,1,0,0,0,333,92,1,0,0,0, + 334,336,3,89,44,0,335,334,1,0,0,0,336,337,1,0,0,0,337,335,1,0,0, + 0,337,338,1,0,0,0,338,339,1,0,0,0,339,343,5,46,0,0,340,342,3,89, + 44,0,341,340,1,0,0,0,342,345,1,0,0,0,343,341,1,0,0,0,343,344,1,0, + 0,0,344,353,1,0,0,0,345,343,1,0,0,0,346,348,5,46,0,0,347,349,3,89, + 44,0,348,347,1,0,0,0,349,350,1,0,0,0,350,348,1,0,0,0,350,351,1,0, + 0,0,351,353,1,0,0,0,352,335,1,0,0,0,352,346,1,0,0,0,353,94,1,0,0, + 0,354,355,3,93,46,0,355,356,5,69,0,0,356,357,3,91,45,0,357,364,1, + 0,0,0,358,359,3,93,46,0,359,360,5,69,0,0,360,361,5,45,0,0,361,362, + 3,91,45,0,362,364,1,0,0,0,363,354,1,0,0,0,363,358,1,0,0,0,364,96, + 1,0,0,0,365,369,5,37,0,0,366,368,9,0,0,0,367,366,1,0,0,0,368,371, + 1,0,0,0,369,370,1,0,0,0,369,367,1,0,0,0,370,373,1,0,0,0,371,369, + 1,0,0,0,372,374,5,13,0,0,373,372,1,0,0,0,373,374,1,0,0,0,374,375, + 1,0,0,0,375,376,5,10,0,0,376,377,1,0,0,0,377,378,6,48,0,0,378,98, + 1,0,0,0,379,383,7,21,0,0,380,382,7,22,0,0,381,380,1,0,0,0,382,385, + 1,0,0,0,383,381,1,0,0,0,383,384,1,0,0,0,384,100,1,0,0,0,385,383, + 1,0,0,0,386,388,7,23,0,0,387,386,1,0,0,0,388,389,1,0,0,0,389,387, + 1,0,0,0,389,390,1,0,0,0,390,391,1,0,0,0,391,392,6,50,0,0,392,102, + 1,0,0,0,21,0,183,231,239,250,258,269,281,302,319,324,332,337,343, + 350,352,363,369,373,383,389,1,6,0,0 + ] + +class AutolevLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + T__2 = 3 + T__3 = 4 + T__4 = 5 + T__5 = 6 + T__6 = 7 + T__7 = 8 + T__8 = 9 + T__9 = 10 + T__10 = 11 + T__11 = 12 + T__12 = 13 + T__13 = 14 + T__14 = 15 + T__15 = 16 + T__16 = 17 + T__17 = 18 + T__18 = 19 + T__19 = 20 + T__20 = 21 + T__21 = 22 + T__22 = 23 + T__23 = 24 + T__24 = 25 + T__25 = 26 + Mass = 27 + Inertia = 28 + Input = 29 + Output = 30 + Save = 31 + UnitSystem = 32 + Encode = 33 + Newtonian = 34 + Frames = 35 + Bodies = 36 + Particles = 37 + Points = 38 + Constants = 39 + Specifieds = 40 + Imaginary = 41 + Variables = 42 + MotionVariables = 43 + INT = 44 + FLOAT = 45 + EXP = 46 + LINE_COMMENT = 47 + ID = 48 + WS = 49 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ "DEFAULT_MODE" ] + + literalNames = [ "", + "'['", "']'", "'='", "'+='", "'-='", "':='", "'*='", "'/='", + "'^='", "','", "'''", "'('", "')'", "'{'", "'}'", "':'", "'+'", + "'-'", "';'", "'.'", "'>'", "'0>'", "'1>>'", "'^'", "'*'", "'/'" ] + + symbolicNames = [ "", + "Mass", "Inertia", "Input", "Output", "Save", "UnitSystem", + "Encode", "Newtonian", "Frames", "Bodies", "Particles", "Points", + "Constants", "Specifieds", "Imaginary", "Variables", "MotionVariables", + "INT", "FLOAT", "EXP", "LINE_COMMENT", "ID", "WS" ] + + ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", + "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", + "T__14", "T__15", "T__16", "T__17", "T__18", "T__19", + "T__20", "T__21", "T__22", "T__23", "T__24", "T__25", + "Mass", "Inertia", "Input", "Output", "Save", "UnitSystem", + "Encode", "Newtonian", "Frames", "Bodies", "Particles", + "Points", "Constants", "Specifieds", "Imaginary", "Variables", + "MotionVariables", "DIFF", "DIGIT", "INT", "FLOAT", "EXP", + "LINE_COMMENT", "ID", "WS" ] + + grammarFileName = "Autolev.g4" + + def __init__(self, input=None, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py new file mode 100644 index 0000000000000000000000000000000000000000..6f391a298a71ecf2d04cf921a919cbb68b181fab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py @@ -0,0 +1,421 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +if __name__ is not None and "." in __name__: + from .autolevparser import AutolevParser +else: + from autolevparser import AutolevParser + +# This class defines a complete listener for a parse tree produced by AutolevParser. +class AutolevListener(ParseTreeListener): + + # Enter a parse tree produced by AutolevParser#prog. + def enterProg(self, ctx:AutolevParser.ProgContext): + pass + + # Exit a parse tree produced by AutolevParser#prog. + def exitProg(self, ctx:AutolevParser.ProgContext): + pass + + + # Enter a parse tree produced by AutolevParser#stat. + def enterStat(self, ctx:AutolevParser.StatContext): + pass + + # Exit a parse tree produced by AutolevParser#stat. + def exitStat(self, ctx:AutolevParser.StatContext): + pass + + + # Enter a parse tree produced by AutolevParser#vecAssign. + def enterVecAssign(self, ctx:AutolevParser.VecAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#vecAssign. + def exitVecAssign(self, ctx:AutolevParser.VecAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#indexAssign. + def enterIndexAssign(self, ctx:AutolevParser.IndexAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#indexAssign. + def exitIndexAssign(self, ctx:AutolevParser.IndexAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#regularAssign. + def enterRegularAssign(self, ctx:AutolevParser.RegularAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#regularAssign. + def exitRegularAssign(self, ctx:AutolevParser.RegularAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#equals. + def enterEquals(self, ctx:AutolevParser.EqualsContext): + pass + + # Exit a parse tree produced by AutolevParser#equals. + def exitEquals(self, ctx:AutolevParser.EqualsContext): + pass + + + # Enter a parse tree produced by AutolevParser#index. + def enterIndex(self, ctx:AutolevParser.IndexContext): + pass + + # Exit a parse tree produced by AutolevParser#index. + def exitIndex(self, ctx:AutolevParser.IndexContext): + pass + + + # Enter a parse tree produced by AutolevParser#diff. + def enterDiff(self, ctx:AutolevParser.DiffContext): + pass + + # Exit a parse tree produced by AutolevParser#diff. + def exitDiff(self, ctx:AutolevParser.DiffContext): + pass + + + # Enter a parse tree produced by AutolevParser#functionCall. + def enterFunctionCall(self, ctx:AutolevParser.FunctionCallContext): + pass + + # Exit a parse tree produced by AutolevParser#functionCall. + def exitFunctionCall(self, ctx:AutolevParser.FunctionCallContext): + pass + + + # Enter a parse tree produced by AutolevParser#varDecl. + def enterVarDecl(self, ctx:AutolevParser.VarDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#varDecl. + def exitVarDecl(self, ctx:AutolevParser.VarDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#varType. + def enterVarType(self, ctx:AutolevParser.VarTypeContext): + pass + + # Exit a parse tree produced by AutolevParser#varType. + def exitVarType(self, ctx:AutolevParser.VarTypeContext): + pass + + + # Enter a parse tree produced by AutolevParser#varDecl2. + def enterVarDecl2(self, ctx:AutolevParser.VarDecl2Context): + pass + + # Exit a parse tree produced by AutolevParser#varDecl2. + def exitVarDecl2(self, ctx:AutolevParser.VarDecl2Context): + pass + + + # Enter a parse tree produced by AutolevParser#ranges. + def enterRanges(self, ctx:AutolevParser.RangesContext): + pass + + # Exit a parse tree produced by AutolevParser#ranges. + def exitRanges(self, ctx:AutolevParser.RangesContext): + pass + + + # Enter a parse tree produced by AutolevParser#massDecl. + def enterMassDecl(self, ctx:AutolevParser.MassDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#massDecl. + def exitMassDecl(self, ctx:AutolevParser.MassDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#massDecl2. + def enterMassDecl2(self, ctx:AutolevParser.MassDecl2Context): + pass + + # Exit a parse tree produced by AutolevParser#massDecl2. + def exitMassDecl2(self, ctx:AutolevParser.MassDecl2Context): + pass + + + # Enter a parse tree produced by AutolevParser#inertiaDecl. + def enterInertiaDecl(self, ctx:AutolevParser.InertiaDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#inertiaDecl. + def exitInertiaDecl(self, ctx:AutolevParser.InertiaDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrix. + def enterMatrix(self, ctx:AutolevParser.MatrixContext): + pass + + # Exit a parse tree produced by AutolevParser#matrix. + def exitMatrix(self, ctx:AutolevParser.MatrixContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrixInOutput. + def enterMatrixInOutput(self, ctx:AutolevParser.MatrixInOutputContext): + pass + + # Exit a parse tree produced by AutolevParser#matrixInOutput. + def exitMatrixInOutput(self, ctx:AutolevParser.MatrixInOutputContext): + pass + + + # Enter a parse tree produced by AutolevParser#codeCommands. + def enterCodeCommands(self, ctx:AutolevParser.CodeCommandsContext): + pass + + # Exit a parse tree produced by AutolevParser#codeCommands. + def exitCodeCommands(self, ctx:AutolevParser.CodeCommandsContext): + pass + + + # Enter a parse tree produced by AutolevParser#settings. + def enterSettings(self, ctx:AutolevParser.SettingsContext): + pass + + # Exit a parse tree produced by AutolevParser#settings. + def exitSettings(self, ctx:AutolevParser.SettingsContext): + pass + + + # Enter a parse tree produced by AutolevParser#units. + def enterUnits(self, ctx:AutolevParser.UnitsContext): + pass + + # Exit a parse tree produced by AutolevParser#units. + def exitUnits(self, ctx:AutolevParser.UnitsContext): + pass + + + # Enter a parse tree produced by AutolevParser#inputs. + def enterInputs(self, ctx:AutolevParser.InputsContext): + pass + + # Exit a parse tree produced by AutolevParser#inputs. + def exitInputs(self, ctx:AutolevParser.InputsContext): + pass + + + # Enter a parse tree produced by AutolevParser#id_diff. + def enterId_diff(self, ctx:AutolevParser.Id_diffContext): + pass + + # Exit a parse tree produced by AutolevParser#id_diff. + def exitId_diff(self, ctx:AutolevParser.Id_diffContext): + pass + + + # Enter a parse tree produced by AutolevParser#inputs2. + def enterInputs2(self, ctx:AutolevParser.Inputs2Context): + pass + + # Exit a parse tree produced by AutolevParser#inputs2. + def exitInputs2(self, ctx:AutolevParser.Inputs2Context): + pass + + + # Enter a parse tree produced by AutolevParser#outputs. + def enterOutputs(self, ctx:AutolevParser.OutputsContext): + pass + + # Exit a parse tree produced by AutolevParser#outputs. + def exitOutputs(self, ctx:AutolevParser.OutputsContext): + pass + + + # Enter a parse tree produced by AutolevParser#outputs2. + def enterOutputs2(self, ctx:AutolevParser.Outputs2Context): + pass + + # Exit a parse tree produced by AutolevParser#outputs2. + def exitOutputs2(self, ctx:AutolevParser.Outputs2Context): + pass + + + # Enter a parse tree produced by AutolevParser#codegen. + def enterCodegen(self, ctx:AutolevParser.CodegenContext): + pass + + # Exit a parse tree produced by AutolevParser#codegen. + def exitCodegen(self, ctx:AutolevParser.CodegenContext): + pass + + + # Enter a parse tree produced by AutolevParser#commands. + def enterCommands(self, ctx:AutolevParser.CommandsContext): + pass + + # Exit a parse tree produced by AutolevParser#commands. + def exitCommands(self, ctx:AutolevParser.CommandsContext): + pass + + + # Enter a parse tree produced by AutolevParser#vec. + def enterVec(self, ctx:AutolevParser.VecContext): + pass + + # Exit a parse tree produced by AutolevParser#vec. + def exitVec(self, ctx:AutolevParser.VecContext): + pass + + + # Enter a parse tree produced by AutolevParser#parens. + def enterParens(self, ctx:AutolevParser.ParensContext): + pass + + # Exit a parse tree produced by AutolevParser#parens. + def exitParens(self, ctx:AutolevParser.ParensContext): + pass + + + # Enter a parse tree produced by AutolevParser#VectorOrDyadic. + def enterVectorOrDyadic(self, ctx:AutolevParser.VectorOrDyadicContext): + pass + + # Exit a parse tree produced by AutolevParser#VectorOrDyadic. + def exitVectorOrDyadic(self, ctx:AutolevParser.VectorOrDyadicContext): + pass + + + # Enter a parse tree produced by AutolevParser#Exponent. + def enterExponent(self, ctx:AutolevParser.ExponentContext): + pass + + # Exit a parse tree produced by AutolevParser#Exponent. + def exitExponent(self, ctx:AutolevParser.ExponentContext): + pass + + + # Enter a parse tree produced by AutolevParser#MulDiv. + def enterMulDiv(self, ctx:AutolevParser.MulDivContext): + pass + + # Exit a parse tree produced by AutolevParser#MulDiv. + def exitMulDiv(self, ctx:AutolevParser.MulDivContext): + pass + + + # Enter a parse tree produced by AutolevParser#AddSub. + def enterAddSub(self, ctx:AutolevParser.AddSubContext): + pass + + # Exit a parse tree produced by AutolevParser#AddSub. + def exitAddSub(self, ctx:AutolevParser.AddSubContext): + pass + + + # Enter a parse tree produced by AutolevParser#float. + def enterFloat(self, ctx:AutolevParser.FloatContext): + pass + + # Exit a parse tree produced by AutolevParser#float. + def exitFloat(self, ctx:AutolevParser.FloatContext): + pass + + + # Enter a parse tree produced by AutolevParser#int. + def enterInt(self, ctx:AutolevParser.IntContext): + pass + + # Exit a parse tree produced by AutolevParser#int. + def exitInt(self, ctx:AutolevParser.IntContext): + pass + + + # Enter a parse tree produced by AutolevParser#idEqualsExpr. + def enterIdEqualsExpr(self, ctx:AutolevParser.IdEqualsExprContext): + pass + + # Exit a parse tree produced by AutolevParser#idEqualsExpr. + def exitIdEqualsExpr(self, ctx:AutolevParser.IdEqualsExprContext): + pass + + + # Enter a parse tree produced by AutolevParser#negativeOne. + def enterNegativeOne(self, ctx:AutolevParser.NegativeOneContext): + pass + + # Exit a parse tree produced by AutolevParser#negativeOne. + def exitNegativeOne(self, ctx:AutolevParser.NegativeOneContext): + pass + + + # Enter a parse tree produced by AutolevParser#function. + def enterFunction(self, ctx:AutolevParser.FunctionContext): + pass + + # Exit a parse tree produced by AutolevParser#function. + def exitFunction(self, ctx:AutolevParser.FunctionContext): + pass + + + # Enter a parse tree produced by AutolevParser#rangess. + def enterRangess(self, ctx:AutolevParser.RangessContext): + pass + + # Exit a parse tree produced by AutolevParser#rangess. + def exitRangess(self, ctx:AutolevParser.RangessContext): + pass + + + # Enter a parse tree produced by AutolevParser#colon. + def enterColon(self, ctx:AutolevParser.ColonContext): + pass + + # Exit a parse tree produced by AutolevParser#colon. + def exitColon(self, ctx:AutolevParser.ColonContext): + pass + + + # Enter a parse tree produced by AutolevParser#id. + def enterId(self, ctx:AutolevParser.IdContext): + pass + + # Exit a parse tree produced by AutolevParser#id. + def exitId(self, ctx:AutolevParser.IdContext): + pass + + + # Enter a parse tree produced by AutolevParser#exp. + def enterExp(self, ctx:AutolevParser.ExpContext): + pass + + # Exit a parse tree produced by AutolevParser#exp. + def exitExp(self, ctx:AutolevParser.ExpContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrices. + def enterMatrices(self, ctx:AutolevParser.MatricesContext): + pass + + # Exit a parse tree produced by AutolevParser#matrices. + def exitMatrices(self, ctx:AutolevParser.MatricesContext): + pass + + + # Enter a parse tree produced by AutolevParser#Indexing. + def enterIndexing(self, ctx:AutolevParser.IndexingContext): + pass + + # Exit a parse tree produced by AutolevParser#Indexing. + def exitIndexing(self, ctx:AutolevParser.IndexingContext): + pass + + + +del AutolevParser diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py new file mode 100644 index 0000000000000000000000000000000000000000..e63ef1c110812580d06291ee7c7ec40b6a076cea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py @@ -0,0 +1,3063 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + +def serializedATN(): + return [ + 4,1,49,431,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13, + 2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20, + 7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26, + 2,27,7,27,1,0,4,0,58,8,0,11,0,12,0,59,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,3,1,69,8,1,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2, + 3,2,84,8,2,1,2,1,2,1,2,3,2,89,8,2,1,3,1,3,1,4,1,4,1,4,5,4,96,8,4, + 10,4,12,4,99,9,4,1,5,4,5,102,8,5,11,5,12,5,103,1,6,1,6,1,6,1,6,1, + 6,5,6,111,8,6,10,6,12,6,114,9,6,3,6,116,8,6,1,6,1,6,1,6,1,6,1,6, + 1,6,5,6,124,8,6,10,6,12,6,127,9,6,3,6,129,8,6,1,6,3,6,132,8,6,1, + 7,1,7,1,7,1,7,5,7,138,8,7,10,7,12,7,141,9,7,1,8,1,8,1,8,1,8,1,8, + 1,8,1,8,1,8,1,8,1,8,5,8,153,8,8,10,8,12,8,156,9,8,1,8,1,8,5,8,160, + 8,8,10,8,12,8,163,9,8,3,8,165,8,8,1,9,1,9,1,9,1,9,1,9,1,9,3,9,173, + 8,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,5,9,183,8,9,10,9,12,9,186,9, + 9,1,9,3,9,189,8,9,1,9,1,9,1,9,3,9,194,8,9,1,9,3,9,197,8,9,1,9,5, + 9,200,8,9,10,9,12,9,203,9,9,1,9,1,9,3,9,207,8,9,1,10,1,10,1,10,1, + 10,1,10,1,10,1,10,1,10,5,10,217,8,10,10,10,12,10,220,9,10,1,10,1, + 10,1,11,1,11,1,11,1,11,5,11,228,8,11,10,11,12,11,231,9,11,1,12,1, + 12,1,12,1,12,1,13,1,13,1,13,1,13,1,13,3,13,242,8,13,1,13,1,13,4, + 13,246,8,13,11,13,12,13,247,1,14,1,14,1,14,1,14,5,14,254,8,14,10, + 14,12,14,257,9,14,1,14,1,14,1,15,1,15,1,15,1,15,3,15,265,8,15,1, + 15,1,15,3,15,269,8,15,1,16,1,16,1,16,1,16,1,16,3,16,276,8,16,1,17, + 1,17,3,17,280,8,17,1,18,1,18,1,18,1,18,5,18,286,8,18,10,18,12,18, + 289,9,18,1,19,1,19,1,19,1,19,5,19,295,8,19,10,19,12,19,298,9,19, + 1,20,1,20,3,20,302,8,20,1,21,1,21,1,21,1,21,3,21,308,8,21,1,22,1, + 22,1,22,1,22,5,22,314,8,22,10,22,12,22,317,9,22,1,23,1,23,3,23,321, + 8,23,1,24,1,24,1,24,1,24,1,24,1,24,5,24,329,8,24,10,24,12,24,332, + 9,24,1,24,1,24,3,24,336,8,24,1,24,1,24,1,24,1,24,1,25,1,25,1,25, + 1,25,1,25,1,25,1,25,1,25,5,25,350,8,25,10,25,12,25,353,9,25,3,25, + 355,8,25,1,26,1,26,4,26,359,8,26,11,26,12,26,360,1,26,1,26,3,26, + 365,8,26,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,5,27,375,8,27,10, + 27,12,27,378,9,27,1,27,1,27,1,27,1,27,1,27,1,27,5,27,386,8,27,10, + 27,12,27,389,9,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,3, + 27,400,8,27,1,27,1,27,5,27,404,8,27,10,27,12,27,407,9,27,3,27,409, + 8,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27, + 1,27,1,27,1,27,5,27,426,8,27,10,27,12,27,429,9,27,1,27,0,1,54,28, + 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44, + 46,48,50,52,54,0,7,1,0,3,9,1,0,27,28,1,0,17,18,2,0,10,10,19,19,1, + 0,44,45,2,0,44,46,48,48,1,0,25,26,483,0,57,1,0,0,0,2,68,1,0,0,0, + 4,88,1,0,0,0,6,90,1,0,0,0,8,92,1,0,0,0,10,101,1,0,0,0,12,131,1,0, + 0,0,14,133,1,0,0,0,16,164,1,0,0,0,18,166,1,0,0,0,20,208,1,0,0,0, + 22,223,1,0,0,0,24,232,1,0,0,0,26,236,1,0,0,0,28,249,1,0,0,0,30,268, + 1,0,0,0,32,275,1,0,0,0,34,277,1,0,0,0,36,281,1,0,0,0,38,290,1,0, + 0,0,40,299,1,0,0,0,42,303,1,0,0,0,44,309,1,0,0,0,46,318,1,0,0,0, + 48,322,1,0,0,0,50,354,1,0,0,0,52,364,1,0,0,0,54,408,1,0,0,0,56,58, + 3,2,1,0,57,56,1,0,0,0,58,59,1,0,0,0,59,57,1,0,0,0,59,60,1,0,0,0, + 60,1,1,0,0,0,61,69,3,14,7,0,62,69,3,12,6,0,63,69,3,32,16,0,64,69, + 3,22,11,0,65,69,3,26,13,0,66,69,3,4,2,0,67,69,3,34,17,0,68,61,1, + 0,0,0,68,62,1,0,0,0,68,63,1,0,0,0,68,64,1,0,0,0,68,65,1,0,0,0,68, + 66,1,0,0,0,68,67,1,0,0,0,69,3,1,0,0,0,70,71,3,52,26,0,71,72,3,6, + 3,0,72,73,3,54,27,0,73,89,1,0,0,0,74,75,5,48,0,0,75,76,5,1,0,0,76, + 77,3,8,4,0,77,78,5,2,0,0,78,79,3,6,3,0,79,80,3,54,27,0,80,89,1,0, + 0,0,81,83,5,48,0,0,82,84,3,10,5,0,83,82,1,0,0,0,83,84,1,0,0,0,84, + 85,1,0,0,0,85,86,3,6,3,0,86,87,3,54,27,0,87,89,1,0,0,0,88,70,1,0, + 0,0,88,74,1,0,0,0,88,81,1,0,0,0,89,5,1,0,0,0,90,91,7,0,0,0,91,7, + 1,0,0,0,92,97,3,54,27,0,93,94,5,10,0,0,94,96,3,54,27,0,95,93,1,0, + 0,0,96,99,1,0,0,0,97,95,1,0,0,0,97,98,1,0,0,0,98,9,1,0,0,0,99,97, + 1,0,0,0,100,102,5,11,0,0,101,100,1,0,0,0,102,103,1,0,0,0,103,101, + 1,0,0,0,103,104,1,0,0,0,104,11,1,0,0,0,105,106,5,48,0,0,106,115, + 5,12,0,0,107,112,3,54,27,0,108,109,5,10,0,0,109,111,3,54,27,0,110, + 108,1,0,0,0,111,114,1,0,0,0,112,110,1,0,0,0,112,113,1,0,0,0,113, + 116,1,0,0,0,114,112,1,0,0,0,115,107,1,0,0,0,115,116,1,0,0,0,116, + 117,1,0,0,0,117,132,5,13,0,0,118,119,7,1,0,0,119,128,5,12,0,0,120, + 125,5,48,0,0,121,122,5,10,0,0,122,124,5,48,0,0,123,121,1,0,0,0,124, + 127,1,0,0,0,125,123,1,0,0,0,125,126,1,0,0,0,126,129,1,0,0,0,127, + 125,1,0,0,0,128,120,1,0,0,0,128,129,1,0,0,0,129,130,1,0,0,0,130, + 132,5,13,0,0,131,105,1,0,0,0,131,118,1,0,0,0,132,13,1,0,0,0,133, + 134,3,16,8,0,134,139,3,18,9,0,135,136,5,10,0,0,136,138,3,18,9,0, + 137,135,1,0,0,0,138,141,1,0,0,0,139,137,1,0,0,0,139,140,1,0,0,0, + 140,15,1,0,0,0,141,139,1,0,0,0,142,165,5,34,0,0,143,165,5,35,0,0, + 144,165,5,36,0,0,145,165,5,37,0,0,146,165,5,38,0,0,147,165,5,39, + 0,0,148,165,5,40,0,0,149,165,5,41,0,0,150,154,5,42,0,0,151,153,5, + 11,0,0,152,151,1,0,0,0,153,156,1,0,0,0,154,152,1,0,0,0,154,155,1, + 0,0,0,155,165,1,0,0,0,156,154,1,0,0,0,157,161,5,43,0,0,158,160,5, + 11,0,0,159,158,1,0,0,0,160,163,1,0,0,0,161,159,1,0,0,0,161,162,1, + 0,0,0,162,165,1,0,0,0,163,161,1,0,0,0,164,142,1,0,0,0,164,143,1, + 0,0,0,164,144,1,0,0,0,164,145,1,0,0,0,164,146,1,0,0,0,164,147,1, + 0,0,0,164,148,1,0,0,0,164,149,1,0,0,0,164,150,1,0,0,0,164,157,1, + 0,0,0,165,17,1,0,0,0,166,172,5,48,0,0,167,168,5,14,0,0,168,169,5, + 44,0,0,169,170,5,10,0,0,170,171,5,44,0,0,171,173,5,15,0,0,172,167, + 1,0,0,0,172,173,1,0,0,0,173,188,1,0,0,0,174,175,5,14,0,0,175,176, + 5,44,0,0,176,177,5,16,0,0,177,184,5,44,0,0,178,179,5,10,0,0,179, + 180,5,44,0,0,180,181,5,16,0,0,181,183,5,44,0,0,182,178,1,0,0,0,183, + 186,1,0,0,0,184,182,1,0,0,0,184,185,1,0,0,0,185,187,1,0,0,0,186, + 184,1,0,0,0,187,189,5,15,0,0,188,174,1,0,0,0,188,189,1,0,0,0,189, + 193,1,0,0,0,190,191,5,14,0,0,191,192,5,44,0,0,192,194,5,15,0,0,193, + 190,1,0,0,0,193,194,1,0,0,0,194,196,1,0,0,0,195,197,7,2,0,0,196, + 195,1,0,0,0,196,197,1,0,0,0,197,201,1,0,0,0,198,200,5,11,0,0,199, + 198,1,0,0,0,200,203,1,0,0,0,201,199,1,0,0,0,201,202,1,0,0,0,202, + 206,1,0,0,0,203,201,1,0,0,0,204,205,5,3,0,0,205,207,3,54,27,0,206, + 204,1,0,0,0,206,207,1,0,0,0,207,19,1,0,0,0,208,209,5,14,0,0,209, + 210,5,44,0,0,210,211,5,16,0,0,211,218,5,44,0,0,212,213,5,10,0,0, + 213,214,5,44,0,0,214,215,5,16,0,0,215,217,5,44,0,0,216,212,1,0,0, + 0,217,220,1,0,0,0,218,216,1,0,0,0,218,219,1,0,0,0,219,221,1,0,0, + 0,220,218,1,0,0,0,221,222,5,15,0,0,222,21,1,0,0,0,223,224,5,27,0, + 0,224,229,3,24,12,0,225,226,5,10,0,0,226,228,3,24,12,0,227,225,1, + 0,0,0,228,231,1,0,0,0,229,227,1,0,0,0,229,230,1,0,0,0,230,23,1,0, + 0,0,231,229,1,0,0,0,232,233,5,48,0,0,233,234,5,3,0,0,234,235,3,54, + 27,0,235,25,1,0,0,0,236,237,5,28,0,0,237,241,5,48,0,0,238,239,5, + 12,0,0,239,240,5,48,0,0,240,242,5,13,0,0,241,238,1,0,0,0,241,242, + 1,0,0,0,242,245,1,0,0,0,243,244,5,10,0,0,244,246,3,54,27,0,245,243, + 1,0,0,0,246,247,1,0,0,0,247,245,1,0,0,0,247,248,1,0,0,0,248,27,1, + 0,0,0,249,250,5,1,0,0,250,255,3,54,27,0,251,252,7,3,0,0,252,254, + 3,54,27,0,253,251,1,0,0,0,254,257,1,0,0,0,255,253,1,0,0,0,255,256, + 1,0,0,0,256,258,1,0,0,0,257,255,1,0,0,0,258,259,5,2,0,0,259,29,1, + 0,0,0,260,261,5,48,0,0,261,262,5,48,0,0,262,264,5,3,0,0,263,265, + 7,4,0,0,264,263,1,0,0,0,264,265,1,0,0,0,265,269,1,0,0,0,266,269, + 5,45,0,0,267,269,5,44,0,0,268,260,1,0,0,0,268,266,1,0,0,0,268,267, + 1,0,0,0,269,31,1,0,0,0,270,276,3,36,18,0,271,276,3,38,19,0,272,276, + 3,44,22,0,273,276,3,48,24,0,274,276,3,50,25,0,275,270,1,0,0,0,275, + 271,1,0,0,0,275,272,1,0,0,0,275,273,1,0,0,0,275,274,1,0,0,0,276, + 33,1,0,0,0,277,279,5,48,0,0,278,280,7,5,0,0,279,278,1,0,0,0,279, + 280,1,0,0,0,280,35,1,0,0,0,281,282,5,32,0,0,282,287,5,48,0,0,283, + 284,5,10,0,0,284,286,5,48,0,0,285,283,1,0,0,0,286,289,1,0,0,0,287, + 285,1,0,0,0,287,288,1,0,0,0,288,37,1,0,0,0,289,287,1,0,0,0,290,291, + 5,29,0,0,291,296,3,42,21,0,292,293,5,10,0,0,293,295,3,42,21,0,294, + 292,1,0,0,0,295,298,1,0,0,0,296,294,1,0,0,0,296,297,1,0,0,0,297, + 39,1,0,0,0,298,296,1,0,0,0,299,301,5,48,0,0,300,302,3,10,5,0,301, + 300,1,0,0,0,301,302,1,0,0,0,302,41,1,0,0,0,303,304,3,40,20,0,304, + 305,5,3,0,0,305,307,3,54,27,0,306,308,3,54,27,0,307,306,1,0,0,0, + 307,308,1,0,0,0,308,43,1,0,0,0,309,310,5,30,0,0,310,315,3,46,23, + 0,311,312,5,10,0,0,312,314,3,46,23,0,313,311,1,0,0,0,314,317,1,0, + 0,0,315,313,1,0,0,0,315,316,1,0,0,0,316,45,1,0,0,0,317,315,1,0,0, + 0,318,320,3,54,27,0,319,321,3,54,27,0,320,319,1,0,0,0,320,321,1, + 0,0,0,321,47,1,0,0,0,322,323,5,48,0,0,323,335,3,12,6,0,324,325,5, + 1,0,0,325,330,3,30,15,0,326,327,5,10,0,0,327,329,3,30,15,0,328,326, + 1,0,0,0,329,332,1,0,0,0,330,328,1,0,0,0,330,331,1,0,0,0,331,333, + 1,0,0,0,332,330,1,0,0,0,333,334,5,2,0,0,334,336,1,0,0,0,335,324, + 1,0,0,0,335,336,1,0,0,0,336,337,1,0,0,0,337,338,5,48,0,0,338,339, + 5,20,0,0,339,340,5,48,0,0,340,49,1,0,0,0,341,342,5,31,0,0,342,343, + 5,48,0,0,343,344,5,20,0,0,344,355,5,48,0,0,345,346,5,33,0,0,346, + 351,5,48,0,0,347,348,5,10,0,0,348,350,5,48,0,0,349,347,1,0,0,0,350, + 353,1,0,0,0,351,349,1,0,0,0,351,352,1,0,0,0,352,355,1,0,0,0,353, + 351,1,0,0,0,354,341,1,0,0,0,354,345,1,0,0,0,355,51,1,0,0,0,356,358, + 5,48,0,0,357,359,5,21,0,0,358,357,1,0,0,0,359,360,1,0,0,0,360,358, + 1,0,0,0,360,361,1,0,0,0,361,365,1,0,0,0,362,365,5,22,0,0,363,365, + 5,23,0,0,364,356,1,0,0,0,364,362,1,0,0,0,364,363,1,0,0,0,365,53, + 1,0,0,0,366,367,6,27,-1,0,367,409,5,46,0,0,368,369,5,18,0,0,369, + 409,3,54,27,12,370,409,5,45,0,0,371,409,5,44,0,0,372,376,5,48,0, + 0,373,375,5,11,0,0,374,373,1,0,0,0,375,378,1,0,0,0,376,374,1,0,0, + 0,376,377,1,0,0,0,377,409,1,0,0,0,378,376,1,0,0,0,379,409,3,52,26, + 0,380,381,5,48,0,0,381,382,5,1,0,0,382,387,3,54,27,0,383,384,5,10, + 0,0,384,386,3,54,27,0,385,383,1,0,0,0,386,389,1,0,0,0,387,385,1, + 0,0,0,387,388,1,0,0,0,388,390,1,0,0,0,389,387,1,0,0,0,390,391,5, + 2,0,0,391,409,1,0,0,0,392,409,3,12,6,0,393,409,3,28,14,0,394,395, + 5,12,0,0,395,396,3,54,27,0,396,397,5,13,0,0,397,409,1,0,0,0,398, + 400,5,48,0,0,399,398,1,0,0,0,399,400,1,0,0,0,400,401,1,0,0,0,401, + 405,3,20,10,0,402,404,5,11,0,0,403,402,1,0,0,0,404,407,1,0,0,0,405, + 403,1,0,0,0,405,406,1,0,0,0,406,409,1,0,0,0,407,405,1,0,0,0,408, + 366,1,0,0,0,408,368,1,0,0,0,408,370,1,0,0,0,408,371,1,0,0,0,408, + 372,1,0,0,0,408,379,1,0,0,0,408,380,1,0,0,0,408,392,1,0,0,0,408, + 393,1,0,0,0,408,394,1,0,0,0,408,399,1,0,0,0,409,427,1,0,0,0,410, + 411,10,16,0,0,411,412,5,24,0,0,412,426,3,54,27,17,413,414,10,15, + 0,0,414,415,7,6,0,0,415,426,3,54,27,16,416,417,10,14,0,0,417,418, + 7,2,0,0,418,426,3,54,27,15,419,420,10,3,0,0,420,421,5,3,0,0,421, + 426,3,54,27,4,422,423,10,2,0,0,423,424,5,16,0,0,424,426,3,54,27, + 3,425,410,1,0,0,0,425,413,1,0,0,0,425,416,1,0,0,0,425,419,1,0,0, + 0,425,422,1,0,0,0,426,429,1,0,0,0,427,425,1,0,0,0,427,428,1,0,0, + 0,428,55,1,0,0,0,429,427,1,0,0,0,50,59,68,83,88,97,103,112,115,125, + 128,131,139,154,161,164,172,184,188,193,196,201,206,218,229,241, + 247,255,264,268,275,279,287,296,301,307,315,320,330,335,351,354, + 360,364,376,387,399,405,408,425,427 + ] + +class AutolevParser ( Parser ): + + grammarFileName = "Autolev.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ "", "'['", "']'", "'='", "'+='", "'-='", "':='", + "'*='", "'/='", "'^='", "','", "'''", "'('", "')'", + "'{'", "'}'", "':'", "'+'", "'-'", "';'", "'.'", "'>'", + "'0>'", "'1>>'", "'^'", "'*'", "'/'" ] + + symbolicNames = [ "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "Mass", "Inertia", + "Input", "Output", "Save", "UnitSystem", "Encode", + "Newtonian", "Frames", "Bodies", "Particles", "Points", + "Constants", "Specifieds", "Imaginary", "Variables", + "MotionVariables", "INT", "FLOAT", "EXP", "LINE_COMMENT", + "ID", "WS" ] + + RULE_prog = 0 + RULE_stat = 1 + RULE_assignment = 2 + RULE_equals = 3 + RULE_index = 4 + RULE_diff = 5 + RULE_functionCall = 6 + RULE_varDecl = 7 + RULE_varType = 8 + RULE_varDecl2 = 9 + RULE_ranges = 10 + RULE_massDecl = 11 + RULE_massDecl2 = 12 + RULE_inertiaDecl = 13 + RULE_matrix = 14 + RULE_matrixInOutput = 15 + RULE_codeCommands = 16 + RULE_settings = 17 + RULE_units = 18 + RULE_inputs = 19 + RULE_id_diff = 20 + RULE_inputs2 = 21 + RULE_outputs = 22 + RULE_outputs2 = 23 + RULE_codegen = 24 + RULE_commands = 25 + RULE_vec = 26 + RULE_expr = 27 + + ruleNames = [ "prog", "stat", "assignment", "equals", "index", "diff", + "functionCall", "varDecl", "varType", "varDecl2", "ranges", + "massDecl", "massDecl2", "inertiaDecl", "matrix", "matrixInOutput", + "codeCommands", "settings", "units", "inputs", "id_diff", + "inputs2", "outputs", "outputs2", "codegen", "commands", + "vec", "expr" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + T__11=12 + T__12=13 + T__13=14 + T__14=15 + T__15=16 + T__16=17 + T__17=18 + T__18=19 + T__19=20 + T__20=21 + T__21=22 + T__22=23 + T__23=24 + T__24=25 + T__25=26 + Mass=27 + Inertia=28 + Input=29 + Output=30 + Save=31 + UnitSystem=32 + Encode=33 + Newtonian=34 + Frames=35 + Bodies=36 + Particles=37 + Points=38 + Constants=39 + Specifieds=40 + Imaginary=41 + Variables=42 + MotionVariables=43 + INT=44 + FLOAT=45 + EXP=46 + LINE_COMMENT=47 + ID=48 + WS=49 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class ProgContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def stat(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.StatContext) + else: + return self.getTypedRuleContext(AutolevParser.StatContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_prog + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterProg" ): + listener.enterProg(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitProg" ): + listener.exitProg(self) + + + + + def prog(self): + + localctx = AutolevParser.ProgContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_prog) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 57 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 56 + self.stat() + self.state = 59 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (((_la) & ~0x3f) == 0 and ((1 << _la) & 299067041120256) != 0): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class StatContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def varDecl(self): + return self.getTypedRuleContext(AutolevParser.VarDeclContext,0) + + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def codeCommands(self): + return self.getTypedRuleContext(AutolevParser.CodeCommandsContext,0) + + + def massDecl(self): + return self.getTypedRuleContext(AutolevParser.MassDeclContext,0) + + + def inertiaDecl(self): + return self.getTypedRuleContext(AutolevParser.InertiaDeclContext,0) + + + def assignment(self): + return self.getTypedRuleContext(AutolevParser.AssignmentContext,0) + + + def settings(self): + return self.getTypedRuleContext(AutolevParser.SettingsContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_stat + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterStat" ): + listener.enterStat(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitStat" ): + listener.exitStat(self) + + + + + def stat(self): + + localctx = AutolevParser.StatContext(self, self._ctx, self.state) + self.enterRule(localctx, 2, self.RULE_stat) + try: + self.state = 68 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,1,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 61 + self.varDecl() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 62 + self.functionCall() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 63 + self.codeCommands() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 64 + self.massDecl() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 65 + self.inertiaDecl() + pass + + elif la_ == 6: + self.enterOuterAlt(localctx, 6) + self.state = 66 + self.assignment() + pass + + elif la_ == 7: + self.enterOuterAlt(localctx, 7) + self.state = 67 + self.settings() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AssignmentContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_assignment + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class VecAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def vec(self): + return self.getTypedRuleContext(AutolevParser.VecContext,0) + + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVecAssign" ): + listener.enterVecAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVecAssign" ): + listener.exitVecAssign(self) + + + class RegularAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + def diff(self): + return self.getTypedRuleContext(AutolevParser.DiffContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRegularAssign" ): + listener.enterRegularAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRegularAssign" ): + listener.exitRegularAssign(self) + + + class IndexAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def index(self): + return self.getTypedRuleContext(AutolevParser.IndexContext,0) + + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndexAssign" ): + listener.enterIndexAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndexAssign" ): + listener.exitIndexAssign(self) + + + + def assignment(self): + + localctx = AutolevParser.AssignmentContext(self, self._ctx, self.state) + self.enterRule(localctx, 4, self.RULE_assignment) + self._la = 0 # Token type + try: + self.state = 88 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,3,self._ctx) + if la_ == 1: + localctx = AutolevParser.VecAssignContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 70 + self.vec() + self.state = 71 + self.equals() + self.state = 72 + self.expr(0) + pass + + elif la_ == 2: + localctx = AutolevParser.IndexAssignContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 74 + self.match(AutolevParser.ID) + self.state = 75 + self.match(AutolevParser.T__0) + self.state = 76 + self.index() + self.state = 77 + self.match(AutolevParser.T__1) + self.state = 78 + self.equals() + self.state = 79 + self.expr(0) + pass + + elif la_ == 3: + localctx = AutolevParser.RegularAssignContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 81 + self.match(AutolevParser.ID) + self.state = 83 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==11: + self.state = 82 + self.diff() + + + self.state = 85 + self.equals() + self.state = 86 + self.expr(0) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class EqualsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_equals + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterEquals" ): + listener.enterEquals(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitEquals" ): + listener.exitEquals(self) + + + + + def equals(self): + + localctx = AutolevParser.EqualsContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_equals) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 90 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 1016) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class IndexContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_index + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndex" ): + listener.enterIndex(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndex" ): + listener.exitIndex(self) + + + + + def index(self): + + localctx = AutolevParser.IndexContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_index) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 92 + self.expr(0) + self.state = 97 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 93 + self.match(AutolevParser.T__9) + self.state = 94 + self.expr(0) + self.state = 99 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class DiffContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_diff + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterDiff" ): + listener.enterDiff(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitDiff" ): + listener.exitDiff(self) + + + + + def diff(self): + + localctx = AutolevParser.DiffContext(self, self._ctx, self.state) + self.enterRule(localctx, 10, self.RULE_diff) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 101 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 100 + self.match(AutolevParser.T__10) + self.state = 103 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==11): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FunctionCallContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def Mass(self): + return self.getToken(AutolevParser.Mass, 0) + + def Inertia(self): + return self.getToken(AutolevParser.Inertia, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_functionCall + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFunctionCall" ): + listener.enterFunctionCall(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFunctionCall" ): + listener.exitFunctionCall(self) + + + + + def functionCall(self): + + localctx = AutolevParser.FunctionCallContext(self, self._ctx, self.state) + self.enterRule(localctx, 12, self.RULE_functionCall) + self._la = 0 # Token type + try: + self.state = 131 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 105 + self.match(AutolevParser.ID) + self.state = 106 + self.match(AutolevParser.T__11) + self.state = 115 + self._errHandler.sync(self) + _la = self._input.LA(1) + if ((_la) & ~0x3f) == 0 and ((1 << _la) & 404620694540290) != 0: + self.state = 107 + self.expr(0) + self.state = 112 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 108 + self.match(AutolevParser.T__9) + self.state = 109 + self.expr(0) + self.state = 114 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 117 + self.match(AutolevParser.T__12) + pass + elif token in [27, 28]: + self.enterOuterAlt(localctx, 2) + self.state = 118 + _la = self._input.LA(1) + if not(_la==27 or _la==28): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 119 + self.match(AutolevParser.T__11) + self.state = 128 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==48: + self.state = 120 + self.match(AutolevParser.ID) + self.state = 125 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 121 + self.match(AutolevParser.T__9) + self.state = 122 + self.match(AutolevParser.ID) + self.state = 127 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 130 + self.match(AutolevParser.T__12) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def varType(self): + return self.getTypedRuleContext(AutolevParser.VarTypeContext,0) + + + def varDecl2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.VarDecl2Context) + else: + return self.getTypedRuleContext(AutolevParser.VarDecl2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_varDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarDecl" ): + listener.enterVarDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarDecl" ): + listener.exitVarDecl(self) + + + + + def varDecl(self): + + localctx = AutolevParser.VarDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_varDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 133 + self.varType() + self.state = 134 + self.varDecl2() + self.state = 139 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 135 + self.match(AutolevParser.T__9) + self.state = 136 + self.varDecl2() + self.state = 141 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarTypeContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Newtonian(self): + return self.getToken(AutolevParser.Newtonian, 0) + + def Frames(self): + return self.getToken(AutolevParser.Frames, 0) + + def Bodies(self): + return self.getToken(AutolevParser.Bodies, 0) + + def Particles(self): + return self.getToken(AutolevParser.Particles, 0) + + def Points(self): + return self.getToken(AutolevParser.Points, 0) + + def Constants(self): + return self.getToken(AutolevParser.Constants, 0) + + def Specifieds(self): + return self.getToken(AutolevParser.Specifieds, 0) + + def Imaginary(self): + return self.getToken(AutolevParser.Imaginary, 0) + + def Variables(self): + return self.getToken(AutolevParser.Variables, 0) + + def MotionVariables(self): + return self.getToken(AutolevParser.MotionVariables, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_varType + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarType" ): + listener.enterVarType(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarType" ): + listener.exitVarType(self) + + + + + def varType(self): + + localctx = AutolevParser.VarTypeContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_varType) + self._la = 0 # Token type + try: + self.state = 164 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [34]: + self.enterOuterAlt(localctx, 1) + self.state = 142 + self.match(AutolevParser.Newtonian) + pass + elif token in [35]: + self.enterOuterAlt(localctx, 2) + self.state = 143 + self.match(AutolevParser.Frames) + pass + elif token in [36]: + self.enterOuterAlt(localctx, 3) + self.state = 144 + self.match(AutolevParser.Bodies) + pass + elif token in [37]: + self.enterOuterAlt(localctx, 4) + self.state = 145 + self.match(AutolevParser.Particles) + pass + elif token in [38]: + self.enterOuterAlt(localctx, 5) + self.state = 146 + self.match(AutolevParser.Points) + pass + elif token in [39]: + self.enterOuterAlt(localctx, 6) + self.state = 147 + self.match(AutolevParser.Constants) + pass + elif token in [40]: + self.enterOuterAlt(localctx, 7) + self.state = 148 + self.match(AutolevParser.Specifieds) + pass + elif token in [41]: + self.enterOuterAlt(localctx, 8) + self.state = 149 + self.match(AutolevParser.Imaginary) + pass + elif token in [42]: + self.enterOuterAlt(localctx, 9) + self.state = 150 + self.match(AutolevParser.Variables) + self.state = 154 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 151 + self.match(AutolevParser.T__10) + self.state = 156 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + elif token in [43]: + self.enterOuterAlt(localctx, 10) + self.state = 157 + self.match(AutolevParser.MotionVariables) + self.state = 161 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 158 + self.match(AutolevParser.T__10) + self.state = 163 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarDecl2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def INT(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.INT) + else: + return self.getToken(AutolevParser.INT, i) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_varDecl2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarDecl2" ): + listener.enterVarDecl2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarDecl2" ): + listener.exitVarDecl2(self) + + + + + def varDecl2(self): + + localctx = AutolevParser.VarDecl2Context(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_varDecl2) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 166 + self.match(AutolevParser.ID) + self.state = 172 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,15,self._ctx) + if la_ == 1: + self.state = 167 + self.match(AutolevParser.T__13) + self.state = 168 + self.match(AutolevParser.INT) + self.state = 169 + self.match(AutolevParser.T__9) + self.state = 170 + self.match(AutolevParser.INT) + self.state = 171 + self.match(AutolevParser.T__14) + + + self.state = 188 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,17,self._ctx) + if la_ == 1: + self.state = 174 + self.match(AutolevParser.T__13) + self.state = 175 + self.match(AutolevParser.INT) + self.state = 176 + self.match(AutolevParser.T__15) + self.state = 177 + self.match(AutolevParser.INT) + self.state = 184 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 178 + self.match(AutolevParser.T__9) + self.state = 179 + self.match(AutolevParser.INT) + self.state = 180 + self.match(AutolevParser.T__15) + self.state = 181 + self.match(AutolevParser.INT) + self.state = 186 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 187 + self.match(AutolevParser.T__14) + + + self.state = 193 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==14: + self.state = 190 + self.match(AutolevParser.T__13) + self.state = 191 + self.match(AutolevParser.INT) + self.state = 192 + self.match(AutolevParser.T__14) + + + self.state = 196 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==17 or _la==18: + self.state = 195 + _la = self._input.LA(1) + if not(_la==17 or _la==18): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + self.state = 201 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 198 + self.match(AutolevParser.T__10) + self.state = 203 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 206 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==3: + self.state = 204 + self.match(AutolevParser.T__2) + self.state = 205 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class RangesContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def INT(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.INT) + else: + return self.getToken(AutolevParser.INT, i) + + def getRuleIndex(self): + return AutolevParser.RULE_ranges + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRanges" ): + listener.enterRanges(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRanges" ): + listener.exitRanges(self) + + + + + def ranges(self): + + localctx = AutolevParser.RangesContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_ranges) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 208 + self.match(AutolevParser.T__13) + self.state = 209 + self.match(AutolevParser.INT) + self.state = 210 + self.match(AutolevParser.T__15) + self.state = 211 + self.match(AutolevParser.INT) + self.state = 218 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 212 + self.match(AutolevParser.T__9) + self.state = 213 + self.match(AutolevParser.INT) + self.state = 214 + self.match(AutolevParser.T__15) + self.state = 215 + self.match(AutolevParser.INT) + self.state = 220 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 221 + self.match(AutolevParser.T__14) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MassDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Mass(self): + return self.getToken(AutolevParser.Mass, 0) + + def massDecl2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.MassDecl2Context) + else: + return self.getTypedRuleContext(AutolevParser.MassDecl2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_massDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMassDecl" ): + listener.enterMassDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMassDecl" ): + listener.exitMassDecl(self) + + + + + def massDecl(self): + + localctx = AutolevParser.MassDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_massDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 223 + self.match(AutolevParser.Mass) + self.state = 224 + self.massDecl2() + self.state = 229 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 225 + self.match(AutolevParser.T__9) + self.state = 226 + self.massDecl2() + self.state = 231 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MassDecl2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_massDecl2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMassDecl2" ): + listener.enterMassDecl2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMassDecl2" ): + listener.exitMassDecl2(self) + + + + + def massDecl2(self): + + localctx = AutolevParser.MassDecl2Context(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_massDecl2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 232 + self.match(AutolevParser.ID) + self.state = 233 + self.match(AutolevParser.T__2) + self.state = 234 + self.expr(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class InertiaDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Inertia(self): + return self.getToken(AutolevParser.Inertia, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inertiaDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInertiaDecl" ): + listener.enterInertiaDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInertiaDecl" ): + listener.exitInertiaDecl(self) + + + + + def inertiaDecl(self): + + localctx = AutolevParser.InertiaDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_inertiaDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 236 + self.match(AutolevParser.Inertia) + self.state = 237 + self.match(AutolevParser.ID) + self.state = 241 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==12: + self.state = 238 + self.match(AutolevParser.T__11) + self.state = 239 + self.match(AutolevParser.ID) + self.state = 240 + self.match(AutolevParser.T__12) + + + self.state = 245 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 243 + self.match(AutolevParser.T__9) + self.state = 244 + self.expr(0) + self.state = 247 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==10): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MatrixContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_matrix + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrix" ): + listener.enterMatrix(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrix" ): + listener.exitMatrix(self) + + + + + def matrix(self): + + localctx = AutolevParser.MatrixContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_matrix) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 249 + self.match(AutolevParser.T__0) + self.state = 250 + self.expr(0) + self.state = 255 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10 or _la==19: + self.state = 251 + _la = self._input.LA(1) + if not(_la==10 or _la==19): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 252 + self.expr(0) + self.state = 257 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 258 + self.match(AutolevParser.T__1) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MatrixInOutputContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_matrixInOutput + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrixInOutput" ): + listener.enterMatrixInOutput(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrixInOutput" ): + listener.exitMatrixInOutput(self) + + + + + def matrixInOutput(self): + + localctx = AutolevParser.MatrixInOutputContext(self, self._ctx, self.state) + self.enterRule(localctx, 30, self.RULE_matrixInOutput) + self._la = 0 # Token type + try: + self.state = 268 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 260 + self.match(AutolevParser.ID) + + self.state = 261 + self.match(AutolevParser.ID) + self.state = 262 + self.match(AutolevParser.T__2) + self.state = 264 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==44 or _la==45: + self.state = 263 + _la = self._input.LA(1) + if not(_la==44 or _la==45): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + pass + elif token in [45]: + self.enterOuterAlt(localctx, 2) + self.state = 266 + self.match(AutolevParser.FLOAT) + pass + elif token in [44]: + self.enterOuterAlt(localctx, 3) + self.state = 267 + self.match(AutolevParser.INT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CodeCommandsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def units(self): + return self.getTypedRuleContext(AutolevParser.UnitsContext,0) + + + def inputs(self): + return self.getTypedRuleContext(AutolevParser.InputsContext,0) + + + def outputs(self): + return self.getTypedRuleContext(AutolevParser.OutputsContext,0) + + + def codegen(self): + return self.getTypedRuleContext(AutolevParser.CodegenContext,0) + + + def commands(self): + return self.getTypedRuleContext(AutolevParser.CommandsContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_codeCommands + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCodeCommands" ): + listener.enterCodeCommands(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCodeCommands" ): + listener.exitCodeCommands(self) + + + + + def codeCommands(self): + + localctx = AutolevParser.CodeCommandsContext(self, self._ctx, self.state) + self.enterRule(localctx, 32, self.RULE_codeCommands) + try: + self.state = 275 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [32]: + self.enterOuterAlt(localctx, 1) + self.state = 270 + self.units() + pass + elif token in [29]: + self.enterOuterAlt(localctx, 2) + self.state = 271 + self.inputs() + pass + elif token in [30]: + self.enterOuterAlt(localctx, 3) + self.state = 272 + self.outputs() + pass + elif token in [48]: + self.enterOuterAlt(localctx, 4) + self.state = 273 + self.codegen() + pass + elif token in [31, 33]: + self.enterOuterAlt(localctx, 5) + self.state = 274 + self.commands() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SettingsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def EXP(self): + return self.getToken(AutolevParser.EXP, 0) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_settings + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterSettings" ): + listener.enterSettings(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitSettings" ): + listener.exitSettings(self) + + + + + def settings(self): + + localctx = AutolevParser.SettingsContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_settings) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 277 + self.match(AutolevParser.ID) + self.state = 279 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.state = 278 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 404620279021568) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class UnitsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UnitSystem(self): + return self.getToken(AutolevParser.UnitSystem, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def getRuleIndex(self): + return AutolevParser.RULE_units + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterUnits" ): + listener.enterUnits(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitUnits" ): + listener.exitUnits(self) + + + + + def units(self): + + localctx = AutolevParser.UnitsContext(self, self._ctx, self.state) + self.enterRule(localctx, 36, self.RULE_units) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 281 + self.match(AutolevParser.UnitSystem) + self.state = 282 + self.match(AutolevParser.ID) + self.state = 287 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 283 + self.match(AutolevParser.T__9) + self.state = 284 + self.match(AutolevParser.ID) + self.state = 289 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class InputsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Input(self): + return self.getToken(AutolevParser.Input, 0) + + def inputs2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.Inputs2Context) + else: + return self.getTypedRuleContext(AutolevParser.Inputs2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inputs + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInputs" ): + listener.enterInputs(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInputs" ): + listener.exitInputs(self) + + + + + def inputs(self): + + localctx = AutolevParser.InputsContext(self, self._ctx, self.state) + self.enterRule(localctx, 38, self.RULE_inputs) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 290 + self.match(AutolevParser.Input) + self.state = 291 + self.inputs2() + self.state = 296 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 292 + self.match(AutolevParser.T__9) + self.state = 293 + self.inputs2() + self.state = 298 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Id_diffContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def diff(self): + return self.getTypedRuleContext(AutolevParser.DiffContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_id_diff + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterId_diff" ): + listener.enterId_diff(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitId_diff" ): + listener.exitId_diff(self) + + + + + def id_diff(self): + + localctx = AutolevParser.Id_diffContext(self, self._ctx, self.state) + self.enterRule(localctx, 40, self.RULE_id_diff) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 299 + self.match(AutolevParser.ID) + self.state = 301 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==11: + self.state = 300 + self.diff() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Inputs2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def id_diff(self): + return self.getTypedRuleContext(AutolevParser.Id_diffContext,0) + + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inputs2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInputs2" ): + listener.enterInputs2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInputs2" ): + listener.exitInputs2(self) + + + + + def inputs2(self): + + localctx = AutolevParser.Inputs2Context(self, self._ctx, self.state) + self.enterRule(localctx, 42, self.RULE_inputs2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 303 + self.id_diff() + self.state = 304 + self.match(AutolevParser.T__2) + self.state = 305 + self.expr(0) + self.state = 307 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,34,self._ctx) + if la_ == 1: + self.state = 306 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class OutputsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Output(self): + return self.getToken(AutolevParser.Output, 0) + + def outputs2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.Outputs2Context) + else: + return self.getTypedRuleContext(AutolevParser.Outputs2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_outputs + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterOutputs" ): + listener.enterOutputs(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitOutputs" ): + listener.exitOutputs(self) + + + + + def outputs(self): + + localctx = AutolevParser.OutputsContext(self, self._ctx, self.state) + self.enterRule(localctx, 44, self.RULE_outputs) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 309 + self.match(AutolevParser.Output) + self.state = 310 + self.outputs2() + self.state = 315 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 311 + self.match(AutolevParser.T__9) + self.state = 312 + self.outputs2() + self.state = 317 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Outputs2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_outputs2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterOutputs2" ): + listener.enterOutputs2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitOutputs2" ): + listener.exitOutputs2(self) + + + + + def outputs2(self): + + localctx = AutolevParser.Outputs2Context(self, self._ctx, self.state) + self.enterRule(localctx, 46, self.RULE_outputs2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 318 + self.expr(0) + self.state = 320 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,36,self._ctx) + if la_ == 1: + self.state = 319 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CodegenContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def matrixInOutput(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.MatrixInOutputContext) + else: + return self.getTypedRuleContext(AutolevParser.MatrixInOutputContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_codegen + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCodegen" ): + listener.enterCodegen(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCodegen" ): + listener.exitCodegen(self) + + + + + def codegen(self): + + localctx = AutolevParser.CodegenContext(self, self._ctx, self.state) + self.enterRule(localctx, 48, self.RULE_codegen) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 322 + self.match(AutolevParser.ID) + self.state = 323 + self.functionCall() + self.state = 335 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==1: + self.state = 324 + self.match(AutolevParser.T__0) + self.state = 325 + self.matrixInOutput() + self.state = 330 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 326 + self.match(AutolevParser.T__9) + self.state = 327 + self.matrixInOutput() + self.state = 332 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 333 + self.match(AutolevParser.T__1) + + + self.state = 337 + self.match(AutolevParser.ID) + self.state = 338 + self.match(AutolevParser.T__19) + self.state = 339 + self.match(AutolevParser.ID) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CommandsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Save(self): + return self.getToken(AutolevParser.Save, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def Encode(self): + return self.getToken(AutolevParser.Encode, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_commands + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCommands" ): + listener.enterCommands(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCommands" ): + listener.exitCommands(self) + + + + + def commands(self): + + localctx = AutolevParser.CommandsContext(self, self._ctx, self.state) + self.enterRule(localctx, 50, self.RULE_commands) + self._la = 0 # Token type + try: + self.state = 354 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [31]: + self.enterOuterAlt(localctx, 1) + self.state = 341 + self.match(AutolevParser.Save) + self.state = 342 + self.match(AutolevParser.ID) + self.state = 343 + self.match(AutolevParser.T__19) + self.state = 344 + self.match(AutolevParser.ID) + pass + elif token in [33]: + self.enterOuterAlt(localctx, 2) + self.state = 345 + self.match(AutolevParser.Encode) + self.state = 346 + self.match(AutolevParser.ID) + self.state = 351 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 347 + self.match(AutolevParser.T__9) + self.state = 348 + self.match(AutolevParser.ID) + self.state = 353 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VecContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_vec + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVec" ): + listener.enterVec(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVec" ): + listener.exitVec(self) + + + + + def vec(self): + + localctx = AutolevParser.VecContext(self, self._ctx, self.state) + self.enterRule(localctx, 52, self.RULE_vec) + try: + self.state = 364 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 356 + self.match(AutolevParser.ID) + self.state = 358 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 357 + self.match(AutolevParser.T__20) + + else: + raise NoViableAltException(self) + self.state = 360 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,41,self._ctx) + + pass + elif token in [22]: + self.enterOuterAlt(localctx, 2) + self.state = 362 + self.match(AutolevParser.T__21) + pass + elif token in [23]: + self.enterOuterAlt(localctx, 3) + self.state = 363 + self.match(AutolevParser.T__22) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_expr + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + class ParensContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterParens" ): + listener.enterParens(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitParens" ): + listener.exitParens(self) + + + class VectorOrDyadicContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def vec(self): + return self.getTypedRuleContext(AutolevParser.VecContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVectorOrDyadic" ): + listener.enterVectorOrDyadic(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVectorOrDyadic" ): + listener.exitVectorOrDyadic(self) + + + class ExponentContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterExponent" ): + listener.enterExponent(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitExponent" ): + listener.exitExponent(self) + + + class MulDivContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMulDiv" ): + listener.enterMulDiv(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMulDiv" ): + listener.exitMulDiv(self) + + + class AddSubContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterAddSub" ): + listener.enterAddSub(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitAddSub" ): + listener.exitAddSub(self) + + + class FloatContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFloat" ): + listener.enterFloat(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFloat" ): + listener.exitFloat(self) + + + class IntContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInt" ): + listener.enterInt(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInt" ): + listener.exitInt(self) + + + class IdEqualsExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIdEqualsExpr" ): + listener.enterIdEqualsExpr(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIdEqualsExpr" ): + listener.exitIdEqualsExpr(self) + + + class NegativeOneContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterNegativeOne" ): + listener.enterNegativeOne(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitNegativeOne" ): + listener.exitNegativeOne(self) + + + class FunctionContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFunction" ): + listener.enterFunction(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFunction" ): + listener.exitFunction(self) + + + class RangessContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ranges(self): + return self.getTypedRuleContext(AutolevParser.RangesContext,0) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRangess" ): + listener.enterRangess(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRangess" ): + listener.exitRangess(self) + + + class ColonContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterColon" ): + listener.enterColon(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitColon" ): + listener.exitColon(self) + + + class IdContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterId" ): + listener.enterId(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitId" ): + listener.exitId(self) + + + class ExpContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def EXP(self): + return self.getToken(AutolevParser.EXP, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterExp" ): + listener.enterExp(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitExp" ): + listener.exitExp(self) + + + class MatricesContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def matrix(self): + return self.getTypedRuleContext(AutolevParser.MatrixContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrices" ): + listener.enterMatrices(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrices" ): + listener.exitMatrices(self) + + + class IndexingContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndexing" ): + listener.enterIndexing(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndexing" ): + listener.exitIndexing(self) + + + + def expr(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = AutolevParser.ExprContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 54 + self.enterRecursionRule(localctx, 54, self.RULE_expr, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 408 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,47,self._ctx) + if la_ == 1: + localctx = AutolevParser.ExpContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 367 + self.match(AutolevParser.EXP) + pass + + elif la_ == 2: + localctx = AutolevParser.NegativeOneContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 368 + self.match(AutolevParser.T__17) + self.state = 369 + self.expr(12) + pass + + elif la_ == 3: + localctx = AutolevParser.FloatContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 370 + self.match(AutolevParser.FLOAT) + pass + + elif la_ == 4: + localctx = AutolevParser.IntContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 371 + self.match(AutolevParser.INT) + pass + + elif la_ == 5: + localctx = AutolevParser.IdContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 372 + self.match(AutolevParser.ID) + self.state = 376 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,43,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 373 + self.match(AutolevParser.T__10) + self.state = 378 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,43,self._ctx) + + pass + + elif la_ == 6: + localctx = AutolevParser.VectorOrDyadicContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 379 + self.vec() + pass + + elif la_ == 7: + localctx = AutolevParser.IndexingContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 380 + self.match(AutolevParser.ID) + self.state = 381 + self.match(AutolevParser.T__0) + self.state = 382 + self.expr(0) + self.state = 387 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 383 + self.match(AutolevParser.T__9) + self.state = 384 + self.expr(0) + self.state = 389 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 390 + self.match(AutolevParser.T__1) + pass + + elif la_ == 8: + localctx = AutolevParser.FunctionContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 392 + self.functionCall() + pass + + elif la_ == 9: + localctx = AutolevParser.MatricesContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 393 + self.matrix() + pass + + elif la_ == 10: + localctx = AutolevParser.ParensContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 394 + self.match(AutolevParser.T__11) + self.state = 395 + self.expr(0) + self.state = 396 + self.match(AutolevParser.T__12) + pass + + elif la_ == 11: + localctx = AutolevParser.RangessContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 399 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==48: + self.state = 398 + self.match(AutolevParser.ID) + + + self.state = 401 + self.ranges() + self.state = 405 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,46,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 402 + self.match(AutolevParser.T__10) + self.state = 407 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,46,self._ctx) + + pass + + + self._ctx.stop = self._input.LT(-1) + self.state = 427 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,49,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 425 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,48,self._ctx) + if la_ == 1: + localctx = AutolevParser.ExponentContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 410 + if not self.precpred(self._ctx, 16): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") + self.state = 411 + self.match(AutolevParser.T__23) + self.state = 412 + self.expr(17) + pass + + elif la_ == 2: + localctx = AutolevParser.MulDivContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 413 + if not self.precpred(self._ctx, 15): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") + self.state = 414 + _la = self._input.LA(1) + if not(_la==25 or _la==26): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 415 + self.expr(16) + pass + + elif la_ == 3: + localctx = AutolevParser.AddSubContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 416 + if not self.precpred(self._ctx, 14): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 14)") + self.state = 417 + _la = self._input.LA(1) + if not(_la==17 or _la==18): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 418 + self.expr(15) + pass + + elif la_ == 4: + localctx = AutolevParser.IdEqualsExprContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 419 + if not self.precpred(self._ctx, 3): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 420 + self.match(AutolevParser.T__2) + self.state = 421 + self.expr(4) + pass + + elif la_ == 5: + localctx = AutolevParser.ColonContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 422 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 423 + self.match(AutolevParser.T__15) + self.state = 424 + self.expr(3) + pass + + + self.state = 429 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,49,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): + if self._predicates == None: + self._predicates = dict() + self._predicates[27] = self.expr_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def expr_sempred(self, localctx:ExprContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 16) + + + if predIndex == 1: + return self.precpred(self._ctx, 15) + + + if predIndex == 2: + return self.precpred(self._ctx, 14) + + + if predIndex == 3: + return self.precpred(self._ctx, 3) + + + if predIndex == 4: + return self.precpred(self._ctx, 2) + + + + + diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..8314b2f546c0a18a8e281768b60d66556c852e3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py @@ -0,0 +1,86 @@ +import os +import subprocess +import glob + +from sympy.utilities.misc import debug + +here = os.path.dirname(__file__) +grammar_file = os.path.abspath(os.path.join(here, "Autolev.g4")) +dir_autolev_antlr = os.path.join(here, "_antlr") + +header = '''\ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +''' + + +def check_antlr_version(): + debug("Checking antlr4 version...") + + try: + debug(subprocess.check_output(["antlr4"]) + .decode('utf-8').split("\n")[0]) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + debug("The 'antlr4' command line tool is not installed, " + "or not on your PATH.\n" + "> Please refer to the README.md file for more information.") + return False + + +def build_parser(output_dir=dir_autolev_antlr): + check_antlr_version() + + debug("Updating ANTLR-generated code in {}".format(output_dir)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, "__init__.py"), "w+") as fp: + fp.write(header) + + args = [ + "antlr4", + grammar_file, + "-o", output_dir, + "-no-visitor", + ] + + debug("Running code generation...\n\t$ {}".format(" ".join(args))) + subprocess.check_output(args, cwd=output_dir) + + debug("Applying headers, removing unnecessary files and renaming...") + # Handle case insensitive file systems. If the files are already + # generated, they will be written to autolev* but Autolev*.* won't match them. + for path in (glob.glob(os.path.join(output_dir, "Autolev*.*")) or + glob.glob(os.path.join(output_dir, "autolev*.*"))): + + # Remove files ending in .interp or .tokens as they are not needed. + if not path.endswith(".py"): + os.unlink(path) + continue + + new_path = os.path.join(output_dir, os.path.basename(path).lower()) + with open(path, 'r') as f: + lines = [line.rstrip().replace('AutolevParser import', 'autolevparser import') +'\n' + for line in f] + + os.unlink(path) + + with open(new_path, "w") as out_file: + offset = 0 + while lines[offset].startswith('#'): + offset += 1 + out_file.write(header) + out_file.writelines(lines[offset:]) + + debug("\t{}".format(new_path)) + + return True + + +if __name__ == "__main__": + build_parser() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca2f8af88de18036b90788fd29d02707f098213 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py @@ -0,0 +1,2083 @@ +import collections +import warnings + +from sympy.external import import_module + +autolevparser = import_module('sympy.parsing.autolev._antlr.autolevparser', + import_kwargs={'fromlist': ['AutolevParser']}) +autolevlexer = import_module('sympy.parsing.autolev._antlr.autolevlexer', + import_kwargs={'fromlist': ['AutolevLexer']}) +autolevlistener = import_module('sympy.parsing.autolev._antlr.autolevlistener', + import_kwargs={'fromlist': ['AutolevListener']}) + +AutolevParser = getattr(autolevparser, 'AutolevParser', None) +AutolevLexer = getattr(autolevlexer, 'AutolevLexer', None) +AutolevListener = getattr(autolevlistener, 'AutolevListener', None) + + +def strfunc(z): + if z == 0: + return "" + elif z == 1: + return "_d" + else: + return "_" + "d" * z + +def declare_phy_entities(self, ctx, phy_type, i, j=None): + if phy_type in ("frame", "newtonian"): + declare_frames(self, ctx, i, j) + elif phy_type == "particle": + declare_particles(self, ctx, i, j) + elif phy_type == "point": + declare_points(self, ctx, i, j) + elif phy_type == "bodies": + declare_bodies(self, ctx, i, j) + +def declare_frames(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + name2 = "frame_" + name1 + if self.getValue(ctx.parentCtx.varType()) == "newtonian": + self.newtonian = name2 + + self.symbol_table2.update({name1: name2}) + + self.symbol_table.update({name1 + "1>": name2 + ".x"}) + self.symbol_table.update({name1 + "2>": name2 + ".y"}) + self.symbol_table.update({name1 + "3>": name2 + ".z"}) + + self.type2.update({name1: "frame"}) + self.write(name2 + " = " + "_me.ReferenceFrame('" + name1 + "')\n") + +def declare_points(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "point_" + name1 + + self.symbol_table2.update({name1: name2}) + self.type2.update({name1: "point"}) + self.write(name2 + " = " + "_me.Point('" + name1 + "')\n") + +def declare_particles(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "particle_" + name1 + + self.symbol_table2.update({name1: name2}) + self.type2.update({name1: "particle"}) + self.bodies.update({name1: name2}) + self.write(name2 + " = " + "_me.Particle('" + name1 + "', " + "_me.Point('" + + name1 + "_pt" + "'), " + "_sm.Symbol('m'))\n") + +def declare_bodies(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "body_" + name1 + self.bodies.update({name1: name2}) + masscenter = name2 + "_cm" + refFrame = name2 + "_f" + + self.symbol_table2.update({name1: name2}) + self.symbol_table2.update({name1 + "o": masscenter}) + self.symbol_table.update({name1 + "1>": refFrame+".x"}) + self.symbol_table.update({name1 + "2>": refFrame+".y"}) + self.symbol_table.update({name1 + "3>": refFrame+".z"}) + + self.type2.update({name1: "bodies"}) + self.type2.update({name1+"o": "point"}) + + self.write(masscenter + " = " + "_me.Point('" + name1 + "_cm" + "')\n") + if self.newtonian: + self.write(masscenter + ".set_vel(" + self.newtonian + ", " + "0)\n") + self.write(refFrame + " = " + "_me.ReferenceFrame('" + name1 + "_f" + "')\n") + # We set a dummy mass and inertia here. + # They will be reset using the setters later in the code anyway. + self.write(name2 + " = " + "_me.RigidBody('" + name1 + "', " + masscenter + ", " + + refFrame + ", " + "_sm.symbols('m'), (_me.outer(" + refFrame + + ".x," + refFrame + ".x)," + masscenter + "))\n") + +def inertia_func(self, v1, v2, l, frame): + + if self.type2[v1] == "particle": + l.append("_me.inertia_of_point_mass(" + self.bodies[v1] + ".mass, " + self.bodies[v1] + + ".point.pos_from(" + self.symbol_table2[v2] + "), " + frame + ")") + + elif self.type2[v1] == "bodies": + # Inertia has been defined about center of mass. + if self.inertia_point[v1] == v1 + "o": + # Asking point is cm as well + if v2 == self.inertia_point[v1]: + l.append(self.symbol_table2[v1] + ".inertia[0]") + + # Asking point is not cm + else: + l.append(self.bodies[v1] + ".inertia[0]" + " + " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[v2] + + "), " + frame + ")") + + # Inertia has been defined about another point + else: + # Asking point is the defined point + if v2 == self.inertia_point[v1]: + l.append(self.symbol_table2[v1] + ".inertia[0]") + # Asking point is cm + elif v2 == v1 + "o": + l.append(self.bodies[v1] + ".inertia[0]" + " - " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[self.inertia_point[v1]] + + "), " + frame + ")") + # Asking point is some other point + else: + l.append(self.bodies[v1] + ".inertia[0]" + " - " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[self.inertia_point[v1]] + + "), " + frame + ")" + " + " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[v2] + + "), " + frame + ")") + + +def processConstants(self, ctx): + # Process constant declarations of the type: Constants F = 3, g = 9.81 + name = ctx.ID().getText().lower() + if "=" in ctx.getText(): + self.symbol_table.update({name: name}) + # self.inputs.update({self.symbol_table[name]: self.getValue(ctx.getChild(2))}) + self.write(self.symbol_table[name] + " = " + "_sm.S(" + self.getValue(ctx.getChild(2)) + ")\n") + self.type.update({name: "constants"}) + return + + # Constants declarations of the type: Constants A, B + else: + if "{" not in ctx.getText(): + self.symbol_table[name] = name + self.type[name] = "constants" + + # Process constant declarations of the type: Constants C+, D- + if ctx.getChildCount() == 2: + # This is set for declaring nonpositive=True and nonnegative=True + if ctx.getChild(1).getText() == "+": + self.sign[name] = "+" + elif ctx.getChild(1).getText() == "-": + self.sign[name] = "-" + else: + if "{" not in ctx.getText(): + self.sign[name] = "o" + + # Process constant declarations of the type: Constants K{4}, a{1:2, 1:2}, b{1:2} + if "{" in ctx.getText(): + if ":" in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + + if ":" in ctx.getText(): + if "," in ctx.getText(): + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + for i in range(num1, num2): + for j in range(num3, num4): + self.symbol_table[name + str(i) + str(j)] = name + str(i) + str(j) + self.type[name + str(i) + str(j)] = "constants" + self.var_list.append(name + str(i) + str(j)) + self.sign[name + str(i) + str(j)] = "o" + else: + for i in range(num1, num2): + self.symbol_table[name + str(i)] = name + str(i) + self.type[name + str(i)] = "constants" + self.var_list.append(name + str(i)) + self.sign[name + str(i)] = "o" + + elif "," in ctx.getText(): + for i in range(1, int(ctx.INT(0).getText()) + 1): + for j in range(1, int(ctx.INT(1).getText()) + 1): + self.symbol_table[name] = name + str(i) + str(j) + self.type[name + str(i) + str(j)] = "constants" + self.var_list.append(name + str(i) + str(j)) + self.sign[name + str(i) + str(j)] = "o" + + else: + for i in range(num1, num2): + self.symbol_table[name + str(i)] = name + str(i) + self.type[name + str(i)] = "constants" + self.var_list.append(name + str(i)) + self.sign[name + str(i)] = "o" + + if "{" not in ctx.getText(): + self.var_list.append(name) + + +def writeConstants(self, ctx): + l1 = list(filter(lambda x: self.sign[x] == "o", self.var_list)) + l2 = list(filter(lambda x: self.sign[x] == "+", self.var_list)) + l3 = list(filter(lambda x: self.sign[x] == "-", self.var_list)) + try: + if self.settings["complex"] == "on": + real = ", real=True" + elif self.settings["complex"] == "off": + real = "" + except Exception: + real = ", real=True" + + if l1: + a = ", ".join(l1) + " = " + "_sm.symbols(" + "'" +\ + " ".join(l1) + "'" + real + ")\n" + self.write(a) + if l2: + a = ", ".join(l2) + " = " + "_sm.symbols(" + "'" +\ + " ".join(l2) + "'" + real + ", nonnegative=True)\n" + self.write(a) + if l3: + a = ", ".join(l3) + " = " + "_sm.symbols(" + "'" + \ + " ".join(l3) + "'" + real + ", nonpositive=True)\n" + self.write(a) + self.var_list = [] + + +def processVariables(self, ctx): + # Specified F = x*N1> + y*N2> + name = ctx.ID().getText().lower() + if "=" in ctx.getText(): + text = name + "'"*(ctx.getChildCount()-3) + self.write(text + " = " + self.getValue(ctx.expr()) + "\n") + return + + # Process variables of the type: Variables qA, qB + if ctx.getChildCount() == 1: + self.symbol_table[name] = name + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name: self.getValue(ctx.parentCtx.getChild(0))}) + + self.var_list.append(name) + self.sign[name] = 0 + + # Process variables of the type: Variables x', y'' + elif "'" in ctx.getText() and "{" not in ctx.getText(): + if ctx.getText().count("'") > self.maxDegree: + self.maxDegree = ctx.getText().count("'") + for i in range(ctx.getChildCount()): + self.sign[name + strfunc(i)] = i + self.symbol_table[name + "'"*i] = name + strfunc(i) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + "'"*i: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + strfunc(i)) + + elif "{" in ctx.getText(): + # Process variables of the type: Variables x{3}, y{2} + + if "'" in ctx.getText(): + dash_count = ctx.getText().count("'") + if dash_count > self.maxDegree: + self.maxDegree = dash_count + + if ":" in ctx.getText(): + # Variables C{1:2, 1:2} + if "," in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + # Variables C{1:2} + else: + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + + # Variables C{1,3} + elif "," in ctx.getText(): + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + num3 = 1 + num4 = int(ctx.INT(1).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + + for i in range(num1, num2): + try: + for j in range(num3, num4): + try: + for z in range(dash_count+1): + self.symbol_table.update({name + str(i) + str(j) + "'"*z: name + str(i) + str(j) + strfunc(z)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + str(j) + "'"*z: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + str(j) + strfunc(z)) + self.sign.update({name + str(i) + str(j) + strfunc(z): z}) + if dash_count > self.maxDegree: + self.maxDegree = dash_count + except Exception: + self.symbol_table.update({name + str(i) + str(j): name + str(i) + str(j)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + str(j): self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + str(j)) + self.sign.update({name + str(i) + str(j): 0}) + except Exception: + try: + for z in range(dash_count+1): + self.symbol_table.update({name + str(i) + "'"*z: name + str(i) + strfunc(z)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + "'"*z: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + strfunc(z)) + self.sign.update({name + str(i) + strfunc(z): z}) + if dash_count > self.maxDegree: + self.maxDegree = dash_count + except Exception: + self.symbol_table.update({name + str(i): name + str(i)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i): self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i)) + self.sign.update({name + str(i): 0}) + +def writeVariables(self, ctx): + #print(self.sign) + #print(self.symbol_table) + if self.var_list: + for i in range(self.maxDegree+1): + if i == 0: + j = "" + t = "" + else: + j = str(i) + t = ", " + l = [] + for k in list(filter(lambda x: self.sign[x] == i, self.var_list)): + if i == 0: + l.append(k) + if i == 1: + l.append(k[:-1]) + if i > 1: + l.append(k[:-2]) + a = ", ".join(list(filter(lambda x: self.sign[x] == i, self.var_list))) + " = " +\ + "_me.dynamicsymbols(" + "'" + " ".join(l) + "'" + t + j + ")\n" + l = [] + self.write(a) + self.maxDegree = 0 + self.var_list = [] + +def processImaginary(self, ctx): + name = ctx.ID().getText().lower() + self.symbol_table[name] = name + self.type[name] = "imaginary" + self.var_list.append(name) + + +def writeImaginary(self, ctx): + a = ", ".join(self.var_list) + " = " + "_sm.symbols(" + "'" + \ + " ".join(self.var_list) + "')\n" + b = ", ".join(self.var_list) + " = " + "_sm.I\n" + self.write(a) + self.write(b) + self.var_list = [] + +if AutolevListener: + class MyListener(AutolevListener): # type: ignore + def __init__(self, include_numeric=False): + # Stores data in tree nodes(tree annotation). Especially useful for expr reconstruction. + self.tree_property = {} + + # Stores the declared variables, constants etc as they are declared in Autolev and SymPy + # {"": ""}. + self.symbol_table = collections.OrderedDict() + + # Similar to symbol_table. Used for storing Physical entities like Frames, Points, + # Particles, Bodies etc + self.symbol_table2 = collections.OrderedDict() + + # Used to store nonpositive, nonnegative etc for constants and number of "'"s (order of diff) + # in variables. + self.sign = {} + + # Simple list used as a store to pass around variables between the 'process' and 'write' + # methods. + self.var_list = [] + + # Stores the type of a declared variable (constants, variables, specifieds etc) + self.type = collections.OrderedDict() + + # Similar to self.type. Used for storing the type of Physical entities like Frames, Points, + # Particles, Bodies etc + self.type2 = collections.OrderedDict() + + # These lists are used to distinguish matrix, numeric and vector expressions. + self.matrix_expr = [] + self.numeric_expr = [] + self.vector_expr = [] + self.fr_expr = [] + + self.output_code = [] + + # Stores the variables and their rhs for substituting upon the Autolev command EXPLICIT. + self.explicit = collections.OrderedDict() + + # Write code to import common dependencies. + self.output_code.append("import sympy.physics.mechanics as _me\n") + self.output_code.append("import sympy as _sm\n") + self.output_code.append("import math as m\n") + self.output_code.append("import numpy as _np\n") + self.output_code.append("\n") + + # Just a store for the max degree variable in a line. + self.maxDegree = 0 + + # Stores the input parameters which are then used for codegen and numerical analysis. + self.inputs = collections.OrderedDict() + # Stores the variables which appear in Output Autolev commands. + self.outputs = [] + # Stores the settings specified by the user. Ex: Complex on/off, Degrees on/off + self.settings = {} + # Boolean which changes the behaviour of some expression reconstruction + # when parsing Input Autolev commands. + self.in_inputs = False + self.in_outputs = False + + # Stores for the physical entities. + self.newtonian = None + self.bodies = collections.OrderedDict() + self.constants = [] + self.forces = collections.OrderedDict() + self.q_ind = [] + self.q_dep = [] + self.u_ind = [] + self.u_dep = [] + self.kd_eqs = [] + self.dependent_variables = [] + self.kd_equivalents = collections.OrderedDict() + self.kd_equivalents2 = collections.OrderedDict() + self.kd_eqs_supplied = None + self.kane_type = "no_args" + self.inertia_point = collections.OrderedDict() + self.kane_parsed = False + self.t = False + + # PyDy ode code will be included only if this flag is set to True. + self.include_numeric = include_numeric + + def write(self, string): + self.output_code.append(string) + + def getValue(self, node): + return self.tree_property[node] + + def setValue(self, node, value): + self.tree_property[node] = value + + def getSymbolTable(self): + return self.symbol_table + + def getType(self): + return self.type + + def exitVarDecl(self, ctx): + # This event method handles variable declarations. The parse tree node varDecl contains + # one or more varDecl2 nodes. Eg varDecl for 'Constants a{1:2, 1:2}, b{1:2}' has two varDecl2 + # nodes(one for a{1:2, 1:2} and one for b{1:2}). + + # Variable declarations are processed and stored in the event method exitVarDecl2. + # This stored information is used to write the final SymPy output code in the exitVarDecl event method. + + # determine the type of declaration + if self.getValue(ctx.varType()) == "constant": + writeConstants(self, ctx) + elif self.getValue(ctx.varType()) in\ + ("variable", "motionvariable", "motionvariable'", "specified"): + writeVariables(self, ctx) + elif self.getValue(ctx.varType()) == "imaginary": + writeImaginary(self, ctx) + + def exitVarType(self, ctx): + # Annotate the varType tree node with the type of the variable declaration. + name = ctx.getChild(0).getText().lower() + if name[-1] == "s" and name != "bodies": + self.setValue(ctx, name[:-1]) + else: + self.setValue(ctx, name) + + def exitVarDecl2(self, ctx): + # Variable declarations are processed and stored in the event method exitVarDecl2. + # This stored information is used to write the final SymPy output code in the exitVarDecl event method. + # This is the case for constants, variables, specifieds etc. + + # This isn't the case for all types of declarations though. For instance + # Frames A, B, C, N cannot be defined on one line in SymPy. So we do not append A, B, C, N + # to a var_list or use exitVarDecl. exitVarDecl2 directly writes out to the file. + + # determine the type of declaration + if self.getValue(ctx.parentCtx.varType()) == "constant": + processConstants(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) in \ + ("variable", "motionvariable", "motionvariable'", "specified"): + processVariables(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) == "imaginary": + processImaginary(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) in ("frame", "newtonian", "point", "particle", "bodies"): + if "{" in ctx.getText(): + if ":" in ctx.getText() and "," not in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + elif ":" not in ctx.getText() and "," in ctx.getText(): + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + num3 = 1 + num4 = int(ctx.INT(1).getText()) + 1 + elif ":" in ctx.getText() and "," in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + else: + num1 = 1 + num2 = 2 + for i in range(num1, num2): + try: + for j in range(num3, num4): + declare_phy_entities(self, ctx, self.getValue(ctx.parentCtx.varType()), i, j) + except Exception: + declare_phy_entities(self, ctx, self.getValue(ctx.parentCtx.varType()), i) + # ================== Subrules of parser rule expr (Start) ====================== # + + def exitId(self, ctx): + # Tree annotation for ID which is a labeled subrule of the parser rule expr. + # A_C + python_keywords = ["and", "as", "assert", "break", "class", "continue", "def", "del", "elif", "else", "except",\ + "exec", "finally", "for", "from", "global", "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",\ + "raise", "return", "try", "while", "with", "yield"] + + if ctx.ID().getText().lower() in python_keywords: + warnings.warn("Python keywords must not be used as identifiers. Please refer to the list of keywords at https://docs.python.org/2.5/ref/keywords.html", + SyntaxWarning) + + if "_" in ctx.ID().getText() and ctx.ID().getText().count('_') == 1: + e1, e2 = ctx.ID().getText().lower().split('_') + try: + if self.type2[e1] == "frame": + e1 = self.symbol_table2[e1] + elif self.type2[e1] == "bodies": + e1 = self.symbol_table2[e1] + "_f" + if self.type2[e2] == "frame": + e2 = self.symbol_table2[e2] + elif self.type2[e2] == "bodies": + e2 = self.symbol_table2[e2] + "_f" + + self.setValue(ctx, e1 + ".dcm(" + e2 + ")") + except Exception: + self.setValue(ctx, ctx.ID().getText().lower()) + else: + # Reserved constant Pi + if ctx.ID().getText().lower() == "pi": + self.setValue(ctx, "_sm.pi") + self.numeric_expr.append(ctx) + + # Reserved variable T (for time) + elif ctx.ID().getText().lower() == "t": + self.setValue(ctx, "_me.dynamicsymbols._t") + if not self.in_inputs and not self.in_outputs: + self.t = True + + else: + idText = ctx.ID().getText().lower() + "'"*(ctx.getChildCount() - 1) + if idText in self.type.keys() and self.type[idText] == "matrix": + self.matrix_expr.append(ctx) + if self.in_inputs: + try: + self.setValue(ctx, self.symbol_table[idText]) + except Exception: + self.setValue(ctx, idText.lower()) + else: + try: + self.setValue(ctx, self.symbol_table[idText]) + except Exception: + pass + + def exitInt(self, ctx): + # Tree annotation for int which is a labeled subrule of the parser rule expr. + int_text = ctx.INT().getText() + self.setValue(ctx, int_text) + self.numeric_expr.append(ctx) + + def exitFloat(self, ctx): + # Tree annotation for float which is a labeled subrule of the parser rule expr. + floatText = ctx.FLOAT().getText() + self.setValue(ctx, floatText) + self.numeric_expr.append(ctx) + + def exitAddSub(self, ctx): + # Tree annotation for AddSub which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr (+|-) expr + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + ctx.getChild(1).getText() + + self.getValue(ctx.expr(1))) + + def exitMulDiv(self, ctx): + # Tree annotation for MulDiv which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr (*|/) expr + try: + if ctx.expr(0) in self.vector_expr and ctx.expr(1) in self.vector_expr: + self.setValue(ctx, "_me.outer(" + self.getValue(ctx.expr(0)) + ", " + + self.getValue(ctx.expr(1)) + ")") + else: + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + ctx.getChild(1).getText() + + self.getValue(ctx.expr(1))) + except Exception: + pass + + def exitNegativeOne(self, ctx): + # Tree annotation for negativeOne which is a labeled subrule of the parser rule expr. + self.setValue(ctx, "-1*" + self.getValue(ctx.getChild(1))) + if ctx.getChild(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.getChild(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + + def exitParens(self, ctx): + # Tree annotation for parens which is a labeled subrule of the parser rule expr. + # The subrule is expr = '(' expr ')' + if ctx.expr() in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr() in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr() in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, "(" + self.getValue(ctx.expr()) + ")") + + def exitExponent(self, ctx): + # Tree annotation for Exponent which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr ^ expr + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + "**" + self.getValue(ctx.expr(1))) + + def exitExp(self, ctx): + s = ctx.EXP().getText()[ctx.EXP().getText().index('E')+1:] + if "-" in s: + s = s[0] + s[1:].lstrip("0") + else: + s = s.lstrip("0") + self.setValue(ctx, ctx.EXP().getText()[:ctx.EXP().getText().index('E')] + + "*10**(" + s + ")") + + def exitFunction(self, ctx): + # Tree annotation for function which is a labeled subrule of the parser rule expr. + + # The difference between this and FunctionCall is that this is used for non standalone functions + # appearing in expressions and assignments. + # Eg: + # When we come across a standalone function say Expand(E, n:m) then it is categorized as FunctionCall + # which is a parser rule in itself under rule stat. exitFunctionCall() takes care of it and writes to the file. + # + # On the other hand, while we come across E_diff = D(E, y), we annotate the tree node + # of the function D(E, y) with the SymPy equivalent in exitFunction(). + # In this case it is the method exitAssignment() that writes the code to the file and not exitFunction(). + + ch = ctx.getChild(0) + func_name = ch.getChild(0).getText().lower() + + # Expand(y, n:m) * + if func_name == "expand": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + # _sm.Matrix([i.expand() for i in z]).reshape(z.shape[0], z.shape[1]) + self.setValue(ctx, "_sm.Matrix([i.expand() for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + ")" + "." + "expand()") + + # Factor(y, x) * + elif func_name == "factor": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([_sm.factor(i, " + self.getValue(ch.expr(1)) + ") for i in " + + expr + "])" + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "_sm.factor(" + "(" + expr + ")" + + ", " + self.getValue(ch.expr(1)) + ")") + + # D(y, x) + elif func_name == "d": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.diff(" + self.getValue(ch.expr(1)) + ") for i in " + + expr + "])" + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if ch.getChildCount() == 8: + frame = self.symbol_table2[ch.expr(2).getText().lower()] + self.setValue(ctx, "(" + expr + ")" + "." + "diff(" + self.getValue(ch.expr(1)) + + ", " + frame + ")") + else: + self.setValue(ctx, "(" + expr + ")" + "." + "diff(" + + self.getValue(ch.expr(1)) + ")") + + # Dt(y) + elif func_name == "dt": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.vector_expr: + text = "dt(" + else: + text = "diff(_sm.Symbol('t')" + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i." + text + + ") for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if ch.getChildCount() == 6: + frame = self.symbol_table2[ch.expr(1).getText().lower()] + self.setValue(ctx, "(" + expr + ")" + "." + "dt(" + + frame + ")") + else: + self.setValue(ctx, "(" + expr + ")" + "." + text + ")") + + # Explicit(EXPRESS(IMPLICIT>,C)) + elif func_name == "explicit": + if ch.expr(0) in self.vector_expr: + self.vector_expr.append(ctx) + expr = self.getValue(ch.expr(0)) + if self.explicit.keys(): + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})") + else: + self.setValue(ctx, expr) + + # Taylor(y, 0:2, w=a, x=0) + # TODO: Currently only works with symbols. Make it work for dynamicsymbols. + elif func_name == "taylor": + exp = self.getValue(ch.expr(0)) + order = self.getValue(ch.expr(1).expr(1)) + x = (ch.getChildCount()-6)//2 + l = [] + for i in range(x): + index = 2 + i + child = ch.expr(index) + l.append(".series(" + self.getValue(child.getChild(0)) + + ", " + self.getValue(child.getChild(2)) + + ", " + order + ").removeO()") + self.setValue(ctx, "(" + exp + ")" + "".join(l)) + + # Evaluate(y, a=x, b=2) + elif func_name == "evaluate": + expr = self.getValue(ch.expr(0)) + l = [] + x = (ch.getChildCount()-4)//2 + for i in range(x): + index = 1 + i + child = ch.expr(index) + l.append(self.getValue(child.getChild(0)) + ":" + + self.getValue(child.getChild(2))) + + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.subs({" + ",".join(l) + "}) for i in " + + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if self.explicit: + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ",".join(explicit_list) + + "}).subs({" + ",".join(l) + "})") + else: + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ",".join(l) + "})") + + # Polynomial([a, b, c], x) + elif func_name == "polynomial": + self.setValue(ctx, "_sm.Poly(" + self.getValue(ch.expr(0)) + ", " + + self.getValue(ch.expr(1)) + ")") + + # Roots(Poly, x, 2) + # Roots([1; 2; 3; 4]) + elif func_name == "roots": + self.matrix_expr.append(ctx) + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.setValue(ctx, "[i.evalf() for i in " + "_sm.solve(" + + "_sm.Poly(" + expr + ", " + "x),x)]") + else: + self.setValue(ctx, "[i.evalf() for i in " + "_sm.solve(" + + expr + ", " + self.getValue(ch.expr(1)) + ")]") + + # Transpose(A), Inv(A) + elif func_name in ("transpose", "inv", "inverse"): + self.matrix_expr.append(ctx) + if func_name == "transpose": + e = ".T" + elif func_name in ("inv", "inverse"): + e = "**(-1)" + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + e) + + # Eig(A) + elif func_name == "eig": + # "_sm.Matrix([i.evalf() for i in " + + self.setValue(ctx, "_sm.Matrix([i.evalf() for i in (" + + self.getValue(ch.expr(0)) + ").eigenvals().keys()])") + + # Diagmat(n, m, x) + # Diagmat(3, 1) + elif func_name == "diagmat": + self.matrix_expr.append(ctx) + if ch.getChildCount() == 6: + l = [] + for i in range(int(self.getValue(ch.expr(0)))): + l.append(self.getValue(ch.expr(1)) + ",") + + self.setValue(ctx, "_sm.diag(" + ("".join(l))[:-1] + ")") + + elif ch.getChildCount() == 8: + # _sm.Matrix([x if i==j else 0 for i in range(n) for j in range(m)]).reshape(n, m) + n = self.getValue(ch.expr(0)) + m = self.getValue(ch.expr(1)) + x = self.getValue(ch.expr(2)) + self.setValue(ctx, "_sm.Matrix([" + x + " if i==j else 0 for i in range(" + + n + ") for j in range(" + m + ")]).reshape(" + n + ", " + m + ")") + + # Cols(A) + # Cols(A, 1) + # Cols(A, 1, 2:4, 3) + elif func_name in ("cols", "rows"): + self.matrix_expr.append(ctx) + if func_name == "cols": + e1 = ".cols" + e2 = ".T." + else: + e1 = ".rows" + e2 = "." + if ch.getChildCount() == 4: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + e1) + elif ch.getChildCount() == 6: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + + e1[:-1] + "(" + str(int(self.getValue(ch.expr(1))) - 1) + ")") + else: + l = [] + for i in range(4, ch.getChildCount()): + try: + if ch.getChild(i).getChildCount() > 1 and ch.getChild(i).getChild(1).getText() == ":": + for j in range(int(ch.getChild(i).getChild(0).getText()), + int(ch.getChild(i).getChild(2).getText())+1): + l.append("(" + self.getValue(ch.getChild(2)) + ")" + e2 + + "row(" + str(j-1) + ")") + else: + l.append("(" + self.getValue(ch.getChild(2)) + ")" + e2 + + "row(" + str(int(ch.getChild(i).getText())-1) + ")") + except Exception: + pass + self.setValue(ctx, "_sm.Matrix([" + ",".join(l) + "])") + + # Det(A) Trace(A) + elif func_name in ["det", "trace"]: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + "." + + func_name + "()") + + # Element(A, 2, 3) + elif func_name == "element": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + "[" + + str(int(self.getValue(ch.expr(1)))-1) + "," + + str(int(self.getValue(ch.expr(2)))-1) + "]") + + elif func_name in \ + ["cos", "sin", "tan", "cosh", "sinh", "tanh", "acos", "asin", "atan", + "log", "exp", "sqrt", "factorial", "floor", "sign"]: + self.setValue(ctx, "_sm." + func_name + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "ceil": + self.setValue(ctx, "_sm.ceiling" + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "sqr": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + + ")" + "**2") + + elif func_name == "log10": + self.setValue(ctx, "_sm.log" + + "(" + self.getValue(ch.expr(0)) + ", 10)") + + elif func_name == "atan2": + self.setValue(ctx, "_sm.atan2" + "(" + self.getValue(ch.expr(0)) + ", " + + self.getValue(ch.expr(1)) + ")") + + elif func_name in ["int", "round"]: + self.setValue(ctx, func_name + + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "abs": + self.setValue(ctx, "_sm.Abs(" + self.getValue(ch.expr(0)) + ")") + + elif func_name in ["max", "min"]: + # max(x, y, z) + l = [] + for i in range(1, ch.getChildCount()): + if ch.getChild(i) in self.tree_property.keys(): + l.append(self.getValue(ch.getChild(i))) + elif ch.getChild(i).getText() in [",", "(", ")"]: + l.append(ch.getChild(i).getText()) + self.setValue(ctx, "_sm." + ch.getChild(0).getText().capitalize() + "".join(l)) + + # Coef(y, x) + elif func_name == "coef": + #A41_A53=COEF([RHS(U4);RHS(U5)],[U1,U2,U3]) + if ch.expr(0) in self.matrix_expr and ch.expr(1) in self.matrix_expr: + icount = jcount = 0 + for i in range(ch.expr(0).getChild(0).getChildCount()): + try: + ch.expr(0).getChild(0).getChild(i).getRuleIndex() + icount+=1 + except Exception: + pass + for j in range(ch.expr(1).getChild(0).getChildCount()): + try: + ch.expr(1).getChild(0).getChild(j).getRuleIndex() + jcount+=1 + except Exception: + pass + l = [] + for i in range(icount): + for j in range(jcount): + # a41_a53[i,j] = u4.expand().coeff(u1) + l.append(self.getValue(ch.expr(0).getChild(0).expr(i)) + ".expand().coeff(" + + self.getValue(ch.expr(1).getChild(0).expr(j)) + ")") + self.setValue(ctx, "_sm.Matrix([" + ", ".join(l) + "]).reshape(" + str(icount) + ", " + str(jcount) + ")") + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + + ")" + ".expand().coeff(" + self.getValue(ch.expr(1)) + ")") + + # Exclude(y, x) Include(y, x) + elif func_name in ("exclude", "include"): + if func_name == "exclude": + e = "0" + else: + e = "1" + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.collect(" + self.getValue(ch.expr(1)) + "])" + + ".coeff(" + self.getValue(ch.expr(1)) + "," + e + ")" + "for i in " + expr + ")" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + + ")" + ".collect(" + self.getValue(ch.expr(1)) + ")" + + ".coeff(" + self.getValue(ch.expr(1)) + "," + e + ")") + + # RHS(y) + elif func_name == "rhs": + self.setValue(ctx, self.explicit[self.getValue(ch.expr(0))]) + + # Arrange(y, n, x) * + elif func_name == "arrange": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.collect(" + self.getValue(ch.expr(2)) + + ")" + "for i in " + expr + "])"+ + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + + ")" + ".collect(" + self.getValue(ch.expr(2)) + ")") + + # Replace(y, sin(x)=3) + elif func_name == "replace": + l = [] + for i in range(1, ch.getChildCount()): + try: + if ch.getChild(i).getChild(1).getText() == "=": + l.append(self.getValue(ch.getChild(i).getChild(0)) + + ":" + self.getValue(ch.getChild(i).getChild(2))) + except Exception: + pass + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.subs({" + ",".join(l) + "}) for i in " + + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + + ".subs({" + ",".join(l) + "})") + + # Dot(Loop>, N1>) + elif func_name == "dot": + l = [] + num = (ch.expr(1).getChild(0).getChildCount()-1)//2 + if ch.expr(1) in self.matrix_expr: + for i in range(num): + l.append("_me.dot(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1).getChild(0).expr(i)) + ")") + self.setValue(ctx, "_sm.Matrix([" + ",".join(l) + "]).reshape(" + str(num) + ", " + "1)") + else: + self.setValue(ctx, "_me.dot(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1)) + ")") + # Cross(w_A_N>, P_NA_AB>) + elif func_name == "cross": + self.vector_expr.append(ctx) + self.setValue(ctx, "_me.cross(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1)) + ")") + + # Mag(P_O_Q>) + elif func_name == "mag": + self.setValue(ctx, self.getValue(ch.expr(0)) + "." + "magnitude()") + + # MATRIX(A, I_R>>) + elif func_name == "matrix": + if self.type2[ch.expr(0).getText().lower()] == "frame": + text = "" + elif self.type2[ch.expr(0).getText().lower()] == "bodies": + text = "_f" + self.setValue(ctx, "(" + self.getValue(ch.expr(1)) + ")" + ".to_matrix(" + + self.symbol_table2[ch.expr(0).getText().lower()] + text + ")") + + # VECTOR(A, ROWS(EIGVECS,1)) + elif func_name == "vector": + if self.type2[ch.expr(0).getText().lower()] == "frame": + text = "" + elif self.type2[ch.expr(0).getText().lower()] == "bodies": + text = "_f" + v = self.getValue(ch.expr(1)) + f = self.symbol_table2[ch.expr(0).getText().lower()] + text + self.setValue(ctx, v + "[0]*" + f + ".x +" + v + "[1]*" + f + ".y +" + + v + "[2]*" + f + ".z") + + # Express(A2>, B) + # Here I am dealing with all the Inertia commands as I expect the users to use Inertia + # commands only with Express because SymPy needs the Reference frame to be specified unlike Autolev. + elif func_name == "express": + self.vector_expr.append(ctx) + if self.type2[ch.expr(1).getText().lower()] == "frame": + frame = self.symbol_table2[ch.expr(1).getText().lower()] + else: + frame = self.symbol_table2[ch.expr(1).getText().lower()] + "_f" + if ch.expr(0).getText().lower() == "1>>": + self.setValue(ctx, "_me.inertia(" + frame + ", 1, 1, 1)") + + elif '_' in ch.expr(0).getText().lower() and ch.expr(0).getText().lower().count('_') == 2\ + and ch.expr(0).getText().lower()[0] == "i" and ch.expr(0).getText().lower()[-2:] == ">>": + v1 = ch.expr(0).getText().lower()[:-2].split('_')[1] + v2 = ch.expr(0).getText().lower()[:-2].split('_')[2] + l = [] + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + elif ch.expr(0).getChild(0).getChild(0).getText().lower() == "inertia": + if ch.expr(0).getChild(0).getChildCount() == 4: + l = [] + v2 = ch.expr(0).getChild(0).ID(0).getText().lower() + for v1 in self.bodies: + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + else: + l = [] + l2 = [] + v2 = ch.expr(0).getChild(0).ID(0).getText().lower() + for i in range(1, (ch.expr(0).getChild(0).getChildCount()-2)//2): + l2.append(ch.expr(0).getChild(0).ID(i).getText().lower()) + for v1 in l2: + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + ".express(" + + self.symbol_table2[ch.expr(1).getText().lower()] + ")") + # CM(P) + elif func_name == "cm": + if self.type2[ch.expr(0).getText().lower()] == "point": + text = "" + else: + text = ".point" + if ch.getChildCount() == 4: + self.setValue(ctx, "_me.functions.center_of_mass(" + self.symbol_table2[ch.expr(0).getText().lower()] + + text + "," + ", ".join(self.bodies.values()) + ")") + else: + bodies = [] + for i in range(1, (ch.getChildCount()-1)//2): + bodies.append(self.symbol_table2[ch.expr(i).getText().lower()]) + self.setValue(ctx, "_me.functions.center_of_mass(" + self.symbol_table2[ch.expr(0).getText().lower()] + + text + "," + ", ".join(bodies) + ")") + + # PARTIALS(V_P1_E>,U1) + elif func_name == "partials": + speeds = [] + for i in range(1, (ch.getChildCount()-1)//2): + if self.kd_equivalents2: + speeds.append(self.kd_equivalents2[self.symbol_table[ch.expr(i).getText().lower()]]) + else: + speeds.append(self.symbol_table[ch.expr(i).getText().lower()]) + v1, v2, v3 = ch.expr(0).getText().lower().replace(">","").split('_') + if self.type2[v2] == "point": + point = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + point = self.symbol_table2[v2] + ".point" + frame = self.symbol_table2[v3] + self.setValue(ctx, point + ".partial_velocity(" + frame + ", " + ",".join(speeds) + ")") + + # UnitVec(A1>+A2>+A3>) + elif func_name == "unitvec": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + ".normalize()") + + # Units(deg, rad) + elif func_name == "units": + if ch.expr(0).getText().lower() == "deg" and ch.expr(1).getText().lower() == "rad": + factor = 0.0174533 + elif ch.expr(0).getText().lower() == "rad" and ch.expr(1).getText().lower() == "deg": + factor = 57.2958 + self.setValue(ctx, str(factor)) + # Mass(A) + elif func_name == "mass": + l = [] + try: + ch.ID(0).getText().lower() + for i in range((ch.getChildCount()-1)//2): + l.append(self.symbol_table2[ch.ID(i).getText().lower()] + ".mass") + self.setValue(ctx, "+".join(l)) + except Exception: + for i in self.bodies.keys(): + l.append(self.bodies[i] + ".mass") + self.setValue(ctx, "+".join(l)) + + # Fr() FrStar() + # _me.KanesMethod(n, q_ind, u_ind, kd, velocity_constraints).kanes_equations(pl, fl)[0] + elif func_name in ["fr", "frstar"]: + if not self.kane_parsed: + if self.kd_eqs: + for i in self.kd_eqs: + self.q_ind.append(self.symbol_table[i.strip().split('-')[0].replace("'","")]) + self.u_ind.append(self.symbol_table[i.strip().split('-')[1].replace("'","")]) + + for i in range(len(self.kd_eqs)): + self.kd_eqs[i] = self.symbol_table[self.kd_eqs[i].strip().split('-')[0]] + " - " +\ + self.symbol_table[self.kd_eqs[i].strip().split('-')[1]] + + # Do all of this if kd_eqs are not specified + if not self.kd_eqs: + self.kd_eqs_supplied = False + self.matrix_expr.append(ctx) + for i in self.type.keys(): + if self.type[i] == "motionvariable": + if self.sign[self.symbol_table[i.lower()]] == 0: + self.q_ind.append(self.symbol_table[i.lower()]) + elif self.sign[self.symbol_table[i.lower()]] == 1: + name = "u_" + self.symbol_table[i.lower()] + self.symbol_table.update({name: name}) + self.write(name + " = " + "_me.dynamicsymbols('" + name + "')\n") + if self.symbol_table[i.lower()] not in self.dependent_variables: + self.u_ind.append(name) + self.kd_equivalents.update({name: self.symbol_table[i.lower()]}) + else: + self.u_dep.append(name) + self.kd_equivalents.update({name: self.symbol_table[i.lower()]}) + + for i in self.kd_equivalents.keys(): + self.kd_eqs.append(self.kd_equivalents[i] + "-" + i) + + if not self.u_ind and not self.kd_eqs: + self.u_ind = self.q_ind.copy() + self.q_ind = [] + + # deal with velocity constraints + if self.dependent_variables: + for i in self.dependent_variables: + self.u_dep.append(i) + if i in self.u_ind: + self.u_ind.remove(i) + + + self.u_dep[:] = [i for i in self.u_dep if i not in self.kd_equivalents.values()] + + force_list = [] + for i in self.forces.keys(): + force_list.append("(" + i + "," + self.forces[i] + ")") + if self.u_dep: + u_dep_text = ", u_dependent=[" + ", ".join(self.u_dep) + "]" + else: + u_dep_text = "" + if self.dependent_variables: + velocity_constraints_text = ", velocity_constraints = velocity_constraints" + else: + velocity_constraints_text = "" + if ctx.parentCtx not in self.fr_expr: + self.write("kd_eqs = [" + ", ".join(self.kd_eqs) + "]\n") + self.write("forceList = " + "[" + ", ".join(force_list) + "]\n") + self.write("kane = _me.KanesMethod(" + self.newtonian + ", " + "q_ind=[" + + ",".join(self.q_ind) + "], " + "u_ind=[" + + ", ".join(self.u_ind) + "]" + u_dep_text + ", " + + "kd_eqs = kd_eqs" + velocity_constraints_text + ")\n") + self.write("fr, frstar = kane." + "kanes_equations([" + + ", ".join(self.bodies.values()) + "], forceList)\n") + self.fr_expr.append(ctx.parentCtx) + self.kane_parsed = True + self.setValue(ctx, func_name) + + def exitMatrices(self, ctx): + # Tree annotation for Matrices which is a labeled subrule of the parser rule expr. + + # MO = [a, b; c, d] + # we generate _sm.Matrix([a, b, c, d]).reshape(2, 2) + # The reshape values are determined by counting the "," and ";" in the Autolev matrix + + # Eg: + # [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12] + # semicolon_count = 3 and rows = 3+1 = 4 + # comma_count = 8 and cols = 8/rows + 1 = 8/4 + 1 = 3 + + # TODO** Parse block matrices + self.matrix_expr.append(ctx) + l = [] + semicolon_count = 0 + comma_count = 0 + for i in range(ctx.matrix().getChildCount()): + child = ctx.matrix().getChild(i) + if child == AutolevParser.ExprContext: + l.append(self.getValue(child)) + elif child.getText() == ";": + semicolon_count += 1 + l.append(",") + elif child.getText() == ",": + comma_count += 1 + l.append(",") + else: + try: + try: + l.append(self.getValue(child)) + except Exception: + l.append(self.symbol_table[child.getText().lower()]) + except Exception: + l.append(child.getText().lower()) + num_of_rows = semicolon_count + 1 + num_of_cols = (comma_count//num_of_rows) + 1 + + self.setValue(ctx, "_sm.Matrix(" + "".join(l) + ")" + ".reshape(" + + str(num_of_rows) + ", " + str(num_of_cols) + ")") + + def exitVectorOrDyadic(self, ctx): + self.vector_expr.append(ctx) + ch = ctx.vec() + + if ch.getChild(0).getText() == "0>": + self.setValue(ctx, "0") + + elif ch.getChild(0).getText() == "1>>": + self.setValue(ctx, "1>>") + + elif "_" in ch.ID().getText() and ch.ID().getText().count('_') == 2: + vec_text = ch.getText().lower() + v1, v2, v3 = ch.ID().getText().lower().split('_') + + if v1 == "p": + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + if self.type2[v3] == "point": + e3 = self.symbol_table2[v3] + elif self.type2[v3] == "particle": + e3 = self.symbol_table2[v3] + ".point" + get_vec = e3 + ".pos_from(" + e2 + ")" + self.setValue(ctx, get_vec) + + elif v1 in ("w", "alf"): + if v1 == "w": + text = ".ang_vel_in(" + elif v1 == "alf": + text = ".ang_acc_in(" + if self.type2[v2] == "bodies": + e2 = self.symbol_table2[v2] + "_f" + elif self.type2[v2] == "frame": + e2 = self.symbol_table2[v2] + if self.type2[v3] == "bodies": + e3 = self.symbol_table2[v3] + "_f" + elif self.type2[v3] == "frame": + e3 = self.symbol_table2[v3] + get_vec = e2 + text + e3 + ")" + self.setValue(ctx, get_vec) + + elif v1 in ("v", "a"): + if v1 == "v": + text = ".vel(" + elif v1 == "a": + text = ".acc(" + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + get_vec = e2 + text + self.symbol_table2[v3] + ")" + self.setValue(ctx, get_vec) + + else: + self.setValue(ctx, vec_text.replace(">", "")) + + else: + vec_text = ch.getText().lower() + name = self.symbol_table[vec_text] + self.setValue(ctx, name) + + def exitIndexing(self, ctx): + if ctx.getChildCount() == 4: + try: + int_text = str(int(self.getValue(ctx.getChild(2))) - 1) + except Exception: + int_text = self.getValue(ctx.getChild(2)) + " - 1" + self.setValue(ctx, ctx.ID().getText().lower() + "[" + int_text + "]") + elif ctx.getChildCount() == 6: + try: + int_text1 = str(int(self.getValue(ctx.getChild(2))) - 1) + except Exception: + int_text1 = self.getValue(ctx.getChild(2)) + " - 1" + try: + int_text2 = str(int(self.getValue(ctx.getChild(4))) - 1) + except Exception: + int_text2 = self.getValue(ctx.getChild(2)) + " - 1" + self.setValue(ctx, ctx.ID().getText().lower() + "[" + int_text1 + ", " + int_text2 + "]") + + + # ================== Subrules of parser rule expr (End) ====================== # + + def exitRegularAssign(self, ctx): + # Handle assignments of type ID = expr + if ctx.equals().getText() in ["=", "+=", "-=", "*=", "/="]: + equals = ctx.equals().getText() + elif ctx.equals().getText() == ":=": + equals = " = " + elif ctx.equals().getText() == "^=": + equals = "**=" + + try: + a = ctx.ID().getText().lower() + "'"*ctx.diff().getText().count("'") + except Exception: + a = ctx.ID().getText().lower() + + if a in self.type.keys() and self.type[a] in ("motionvariable", "motionvariable'") and\ + self.type[ctx.expr().getText().lower()] in ("motionvariable", "motionvariable'"): + b = ctx.expr().getText().lower() + if "'" in b and "'" not in a: + a, b = b, a + if not self.kane_parsed: + self.kd_eqs.append(a + "-" + b) + self.kd_equivalents.update({self.symbol_table[a]: + self.symbol_table[b]}) + self.kd_equivalents2.update({self.symbol_table[b]: + self.symbol_table[a]}) + + if a in self.symbol_table.keys() and a in self.type.keys() and self.type[a] in ("variable", "motionvariable"): + self.explicit.update({self.symbol_table[a]: self.getValue(ctx.expr())}) + + else: + if ctx.expr() in self.matrix_expr: + self.type.update({a: "matrix"}) + + try: + b = self.symbol_table[a] + except KeyError: + self.symbol_table[a] = a + + if "_" in a and a.count("_") == 1: + e1, e2 = a.split('_') + if e1 in self.type2.keys() and self.type2[e1] in ("frame", "bodies")\ + and e2 in self.type2.keys() and self.type2[e2] in ("frame", "bodies"): + if self.type2[e1] == "bodies": + t1 = "_f" + else: + t1 = "" + if self.type2[e2] == "bodies": + t2 = "_f" + else: + t2 = "" + + self.write(self.symbol_table2[e2] + t2 + ".orient(" + self.symbol_table2[e1] + + t1 + ", 'DCM', " + self.getValue(ctx.expr()) + ")\n") + else: + self.write(self.symbol_table[a] + " " + equals + " " + + self.getValue(ctx.expr()) + "\n") + else: + self.write(self.symbol_table[a] + " " + equals + " " + + self.getValue(ctx.expr()) + "\n") + + def exitIndexAssign(self, ctx): + # Handle assignments of type ID[index] = expr + if ctx.equals().getText() in ["=", "+=", "-=", "*=", "/="]: + equals = ctx.equals().getText() + elif ctx.equals().getText() == ":=": + equals = " = " + elif ctx.equals().getText() == "^=": + equals = "**=" + + text = ctx.ID().getText().lower() + self.type.update({text: "matrix"}) + # Handle assignments of type ID[2] = expr + if ctx.index().getChildCount() == 1: + if ctx.index().getChild(0).getText() == "1": + self.type.update({text: "matrix"}) + self.symbol_table.update({text: text}) + self.write(text + " = " + "_sm.Matrix([[0]])\n") + self.write(text + "[0] = " + self.getValue(ctx.expr()) + "\n") + else: + # m = m.row_insert(m.shape[0], _sm.Matrix([[0]])) + self.write(text + " = " + text + + ".row_insert(" + text + ".shape[0]" + ", " + "_sm.Matrix([[0]])" + ")\n") + self.write(text + "[" + text + ".shape[0]-1" + "] = " + self.getValue(ctx.expr()) + "\n") + + # Handle assignments of type ID[2, 2] = expr + elif ctx.index().getChildCount() == 3: + l = [] + try: + l.append(str(int(self.getValue(ctx.index().getChild(0)))-1)) + except Exception: + l.append(self.getValue(ctx.index().getChild(0)) + "-1") + l.append(",") + try: + l.append(str(int(self.getValue(ctx.index().getChild(2)))-1)) + except Exception: + l.append(self.getValue(ctx.index().getChild(2)) + "-1") + self.write(self.symbol_table[ctx.ID().getText().lower()] + + "[" + "".join(l) + "]" + " " + equals + " " + self.getValue(ctx.expr()) + "\n") + + def exitVecAssign(self, ctx): + # Handle assignments of the type vec = expr + ch = ctx.vec() + vec_text = ch.getText().lower() + + if "_" in ch.ID().getText(): + num = ch.ID().getText().count('_') + + if num == 2: + v1, v2, v3 = ch.ID().getText().lower().split('_') + + if v1 == "p": + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + if self.type2[v3] == "point": + e3 = self.symbol_table2[v3] + elif self.type2[v3] == "particle": + e3 = self.symbol_table2[v3] + ".point" + # ab.set_pos(na, la*a.x) + self.write(e3 + ".set_pos(" + e2 + ", " + self.getValue(ctx.expr()) + ")\n") + + elif v1 in ("w", "alf"): + if v1 == "w": + text = ".set_ang_vel(" + elif v1 == "alf": + text = ".set_ang_acc(" + # a.set_ang_vel(n, qad*a.z) + if self.type2[v2] == "bodies": + e2 = self.symbol_table2[v2] + "_f" + else: + e2 = self.symbol_table2[v2] + if self.type2[v3] == "bodies": + e3 = self.symbol_table2[v3] + "_f" + else: + e3 = self.symbol_table2[v3] + self.write(e2 + text + e3 + ", " + self.getValue(ctx.expr()) + ")\n") + + elif v1 in ("v", "a"): + if v1 == "v": + text = ".set_vel(" + elif v1 == "a": + text = ".set_acc(" + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + self.write(e2 + text + self.symbol_table2[v3] + + ", " + self.getValue(ctx.expr()) + ")\n") + elif v1 == "i": + if v2 in self.type2.keys() and self.type2[v2] == "bodies": + self.write(self.symbol_table2[v2] + ".inertia = (" + self.getValue(ctx.expr()) + + ", " + self.symbol_table2[v3] + ")\n") + self.inertia_point.update({v2: v3}) + elif v2 in self.type2.keys() and self.type2[v2] == "particle": + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + + elif num == 1: + v1, v2 = ch.ID().getText().lower().split('_') + + if v1 in ("force", "torque"): + if self.type2[v2] in ("point", "frame"): + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + self.symbol_table.update({vec_text: ch.ID().getText().lower()}) + + if e2 in self.forces.keys(): + self.forces[e2] = self.forces[e2] + " + " + self.getValue(ctx.expr()) + else: + self.forces.update({e2: self.getValue(ctx.expr())}) + self.write(ch.ID().getText().lower() + " = " + self.forces[e2] + "\n") + + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(name + " " + ctx.getChild(1).getText() + " " + self.getValue(ctx.expr()) + "\n") + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(name + " " + ctx.getChild(1).getText() + " " + self.getValue(ctx.expr()) + "\n") + + def enterInputs2(self, ctx): + self.in_inputs = True + + # Inputs + def exitInputs2(self, ctx): + # Stores numerical values given by the input command which + # are used for codegen and numerical analysis. + if ctx.getChildCount() == 3: + try: + self.inputs.update({self.symbol_table[ctx.id_diff().getText().lower()]: self.getValue(ctx.expr(0))}) + except Exception: + self.inputs.update({ctx.id_diff().getText().lower(): self.getValue(ctx.expr(0))}) + elif ctx.getChildCount() == 4: + try: + self.inputs.update({self.symbol_table[ctx.id_diff().getText().lower()]: + (self.getValue(ctx.expr(0)), self.getValue(ctx.expr(1)))}) + except Exception: + self.inputs.update({ctx.id_diff().getText().lower(): + (self.getValue(ctx.expr(0)), self.getValue(ctx.expr(1)))}) + + self.in_inputs = False + + def enterOutputs(self, ctx): + self.in_outputs = True + def exitOutputs(self, ctx): + self.in_outputs = False + + def exitOutputs2(self, ctx): + try: + if "[" in ctx.expr(1).getText(): + self.outputs.append(self.symbol_table[ctx.expr(0).getText().lower()] + + ctx.expr(1).getText().lower()) + else: + self.outputs.append(self.symbol_table[ctx.expr(0).getText().lower()]) + + except Exception: + pass + + # Code commands + def exitCodegen(self, ctx): + # Handles the CODE() command ie the solvers and the codgen part. + # Uses linsolve for the algebraic solvers and nsolve for non linear solvers. + + if ctx.functionCall().getChild(0).getText().lower() == "algebraic": + matrix_name = self.getValue(ctx.functionCall().expr(0)) + e = [] + d = [] + for i in range(1, (ctx.functionCall().getChildCount()-2)//2): + a = self.getValue(ctx.functionCall().expr(i)) + e.append(a) + + for i in self.inputs.keys(): + d.append(i + ":" + self.inputs[i]) + self.write(matrix_name + "_list" + " = " + "[]\n") + self.write("for i in " + matrix_name + ": " + matrix_name + + "_list" + ".append(i.subs({" + ", ".join(d) + "}))\n") + self.write("print(_sm.linsolve(" + matrix_name + "_list" + ", " + ",".join(e) + "))\n") + + elif ctx.functionCall().getChild(0).getText().lower() == "nonlinear": + e = [] + d = [] + guess = [] + for i in range(1, (ctx.functionCall().getChildCount()-2)//2): + a = self.getValue(ctx.functionCall().expr(i)) + e.append(a) + #print(self.inputs) + for i in self.inputs.keys(): + if i in self.symbol_table.keys(): + if type(self.inputs[i]) is tuple: + j, z = self.inputs[i] + else: + j = self.inputs[i] + z = "" + if i not in e: + if z == "deg": + d.append(i + ":" + "_np.deg2rad(" + j + ")") + else: + d.append(i + ":" + j) + else: + if z == "deg": + guess.append("_np.deg2rad(" + j + ")") + else: + guess.append(j) + + self.write("matrix_list" + " = " + "[]\n") + self.write("for i in " + self.getValue(ctx.functionCall().expr(0)) + ":") + self.write("matrix_list" + ".append(i.subs({" + ", ".join(d) + "}))\n") + self.write("print(_sm.nsolve(matrix_list," + "(" + ",".join(e) + ")" + + ",(" + ",".join(guess) + ")" + "))\n") + + elif ctx.functionCall().getChild(0).getText().lower() in ["ode", "dynamics"] and self.include_numeric: + if self.kane_type == "no_args": + for i in self.symbol_table.keys(): + try: + if self.type[i] == "constants" or self.type[self.symbol_table[i]] == "constants": + self.constants.append(self.symbol_table[i]) + except Exception: + pass + q_add_u = self.q_ind + self.q_dep + self.u_ind + self.u_dep + x0 = [] + for i in q_add_u: + try: + if i in self.inputs.keys(): + if type(self.inputs[i]) is tuple: + if self.inputs[i][1] == "deg": + x0.append(i + ":" + "_np.deg2rad(" + self.inputs[i][0] + ")") + else: + x0.append(i + ":" + self.inputs[i][0]) + else: + x0.append(i + ":" + self.inputs[i]) + elif self.kd_equivalents[i] in self.inputs.keys(): + if type(self.inputs[self.kd_equivalents[i]]) is tuple: + x0.append(i + ":" + self.inputs[self.kd_equivalents[i]][0]) + else: + x0.append(i + ":" + self.inputs[self.kd_equivalents[i]]) + except Exception: + pass + + # numerical constants + numerical_constants = [] + for i in self.constants: + if i in self.inputs.keys(): + if type(self.inputs[i]) is tuple: + numerical_constants.append(self.inputs[i][0]) + else: + numerical_constants.append(self.inputs[i]) + + # t = linspace + t_final = self.inputs["tfinal"] + integ_stp = self.inputs["integstp"] + + self.write("from pydy.system import System\n") + const_list = [] + if numerical_constants: + for i in range(len(self.constants)): + const_list.append(self.constants[i] + ":" + numerical_constants[i]) + specifieds = [] + if self.t: + specifieds.append("_me.dynamicsymbols('t')" + ":" + "lambda x, t: t") + + for i in self.inputs: + if i in self.symbol_table.keys() and self.symbol_table[i] not in\ + self.constants + self.q_ind + self.q_dep + self.u_ind + self.u_dep: + specifieds.append(self.symbol_table[i] + ":" + self.inputs[i]) + + self.write("sys = System(kane, constants = {" + ", ".join(const_list) + "},\n" + + "specifieds={" + ", ".join(specifieds) + "},\n" + + "initial_conditions={" + ", ".join(x0) + "},\n" + + "times = _np.linspace(0.0, " + str(t_final) + ", " + str(t_final) + + "/" + str(integ_stp) + "))\n\ny=sys.integrate()\n") + + # For outputs other than qs and us. + other_outputs = [] + for i in self.outputs: + if i not in q_add_u: + if "[" in i: + other_outputs.append((i[:-3] + i[-2], i[:-3] + "[" + str(int(i[-2])-1) + "]")) + else: + other_outputs.append((i, i)) + + for i in other_outputs: + self.write(i[0] + "_out" + " = " + "[]\n") + if other_outputs: + self.write("for i in y:\n") + self.write(" q_u_dict = dict(zip(sys.coordinates+sys.speeds, i))\n") + for i in other_outputs: + self.write(" "*4 + i[0] + "_out" + ".append(" + i[1] + ".subs(q_u_dict)" + + ".subs(sys.constants).evalf())\n") + + # Standalone function calls (used for dual functions) + def exitFunctionCall(self, ctx): + # Basically deals with standalone function calls ie functions which are not a part of + # expressions and assignments. Autolev Dual functions can both appear in standalone + # function calls and also on the right hand side as part of expr or assignment. + + # Dual functions are indicated by a * in the comments below + + # Checks if the function is a statement on its own + if ctx.parentCtx.getRuleIndex() == AutolevParser.RULE_stat: + func_name = ctx.getChild(0).getText().lower() + # Expand(E, n:m) * + if func_name == "expand": + # If the first argument is a pre declared variable. + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([i.expand() for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(symbol + " = " + symbol + "." + "expand()\n") + + # Factor(E, x) * + elif func_name == "factor": + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([_sm.factor(i," + self.getValue(ctx.expr(1)) + + ") for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(expr + " = " + "_sm.factor(" + expr + ", " + + self.getValue(ctx.expr(1)) + ")\n") + + # Solve(Zero, x, y) + elif func_name == "solve": + l = [] + l2 = [] + num = 0 + for i in range(1, ctx.getChildCount()): + if ctx.getChild(i).getText() == ",": + num+=1 + try: + l.append(self.getValue(ctx.getChild(i))) + except Exception: + l.append(ctx.getChild(i).getText()) + + if i != 2: + try: + l2.append(self.getValue(ctx.getChild(i))) + except Exception: + pass + + for i in l2: + self.explicit.update({i: "_sm.solve" + "".join(l) + "[" + i + "]"}) + + self.write("print(_sm.solve" + "".join(l) + ")\n") + + # Arrange(y, n, x) * + elif func_name == "arrange": + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([i.collect(" + self.getValue(ctx.expr(2)) + + ")" + "for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(self.getValue(ctx.expr(0)) + ".collect(" + + self.getValue(ctx.expr(2)) + ")\n") + + # Eig(M, EigenValue, EigenVec) + elif func_name == "eig": + self.symbol_table.update({ctx.expr(1).getText().lower(): ctx.expr(1).getText().lower()}) + self.symbol_table.update({ctx.expr(2).getText().lower(): ctx.expr(2).getText().lower()}) + # _sm.Matrix([i.evalf() for i in (i_s_so).eigenvals().keys()]) + self.write(ctx.expr(1).getText().lower() + " = " + + "_sm.Matrix([i.evalf() for i in " + + "(" + self.getValue(ctx.expr(0)) + ")" + ".eigenvals().keys()])\n") + # _sm.Matrix([i[2][0].evalf() for i in (i_s_o).eigenvects()]).reshape(i_s_o.shape[0], i_s_o.shape[1]) + self.write(ctx.expr(2).getText().lower() + " = " + + "_sm.Matrix([i[2][0].evalf() for i in " + "(" + self.getValue(ctx.expr(0)) + ")" + + ".eigenvects()]).reshape(" + self.getValue(ctx.expr(0)) + ".shape[0], " + + self.getValue(ctx.expr(0)) + ".shape[1])\n") + + # Simprot(N, A, 3, qA) + elif func_name == "simprot": + # A.orient(N, 'Axis', qA, N.z) + if self.type2[ctx.expr(0).getText().lower()] == "frame": + frame1 = self.symbol_table2[ctx.expr(0).getText().lower()] + elif self.type2[ctx.expr(0).getText().lower()] == "bodies": + frame1 = self.symbol_table2[ctx.expr(0).getText().lower()] + "_f" + if self.type2[ctx.expr(1).getText().lower()] == "frame": + frame2 = self.symbol_table2[ctx.expr(1).getText().lower()] + elif self.type2[ctx.expr(1).getText().lower()] == "bodies": + frame2 = self.symbol_table2[ctx.expr(1).getText().lower()] + "_f" + e2 = "" + if ctx.expr(2).getText()[0] == "-": + e2 = "-1*" + if ctx.expr(2).getText() in ("1", "-1"): + e = frame1 + ".x" + elif ctx.expr(2).getText() in ("2", "-2"): + e = frame1 + ".y" + elif ctx.expr(2).getText() in ("3", "-3"): + e = frame1 + ".z" + else: + e = self.getValue(ctx.expr(2)) + e2 = "" + + if "degrees" in self.settings.keys() and self.settings["degrees"] == "off": + value = self.getValue(ctx.expr(3)) + else: + if ctx.expr(3) in self.numeric_expr: + value = "_np.deg2rad(" + self.getValue(ctx.expr(3)) + ")" + else: + value = self.getValue(ctx.expr(3)) + self.write(frame2 + ".orient(" + frame1 + + ", " + "'Axis'" + ", " + "[" + value + + ", " + e2 + e + "]" + ")\n") + + # Express(A2>, B) * + elif func_name == "express": + if self.type2[ctx.expr(1).getText().lower()] == "bodies": + f = "_f" + else: + f = "" + + if '_' in ctx.expr(0).getText().lower() and ctx.expr(0).getText().count('_') == 2: + vec = ctx.expr(0).getText().lower().replace(">", "").split('_') + v1 = self.symbol_table2[vec[1]] + v2 = self.symbol_table2[vec[2]] + if vec[0] == "p": + self.write(v2 + ".set_pos(" + v1 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + elif vec[0] == "v": + self.write(v1 + ".set_vel(" + v2 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + elif vec[0] == "a": + self.write(v1 + ".set_acc(" + v2 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + else: + self.write(self.getValue(ctx.expr(0)) + " = " + "(" + self.getValue(ctx.expr(0)) + ")" + ".express(" + + self.symbol_table2[ctx.expr(1).getText().lower()] + f + ")\n") + else: + self.write(self.getValue(ctx.expr(0)) + " = " + "(" + self.getValue(ctx.expr(0)) + ")" + ".express(" + + self.symbol_table2[ctx.expr(1).getText().lower()] + f + ")\n") + + # Angvel(A, B) + elif func_name == "angvel": + self.write("print(" + self.symbol_table2[ctx.expr(1).getText().lower()] + + ".ang_vel_in(" + self.symbol_table2[ctx.expr(0).getText().lower()] + "))\n") + + # v2pts(N, A, O, P) + elif func_name in ("v2pts", "a2pts", "v2pt", "a1pt"): + if func_name == "v2pts": + text = ".v2pt_theory(" + elif func_name == "a2pts": + text = ".a2pt_theory(" + elif func_name == "v1pt": + text = ".v1pt_theory(" + elif func_name == "a1pt": + text = ".a1pt_theory(" + if self.type2[ctx.expr(1).getText().lower()] == "frame": + frame = self.symbol_table2[ctx.expr(1).getText().lower()] + elif self.type2[ctx.expr(1).getText().lower()] == "bodies": + frame = self.symbol_table2[ctx.expr(1).getText().lower()] + "_f" + expr_list = [] + for i in range(2, 4): + if self.type2[ctx.expr(i).getText().lower()] == "point": + expr_list.append(self.symbol_table2[ctx.expr(i).getText().lower()]) + elif self.type2[ctx.expr(i).getText().lower()] == "particle": + expr_list.append(self.symbol_table2[ctx.expr(i).getText().lower()] + ".point") + + self.write(expr_list[1] + text + expr_list[0] + + "," + self.symbol_table2[ctx.expr(0).getText().lower()] + "," + + frame + ")\n") + + # Gravity(g*N1>) + elif func_name == "gravity": + for i in self.bodies.keys(): + if self.type2[i] == "bodies": + e = self.symbol_table2[i] + ".masscenter" + elif self.type2[i] == "particle": + e = self.symbol_table2[i] + ".point" + if e in self.forces.keys(): + self.forces[e] = self.forces[e] + self.symbol_table2[i] +\ + ".mass*(" + self.getValue(ctx.expr(0)) + ")" + else: + self.forces.update({e: self.symbol_table2[i] + + ".mass*(" + self.getValue(ctx.expr(0)) + ")"}) + self.write("force_" + i + " = " + self.forces[e] + "\n") + + # Explicit(EXPRESS(IMPLICIT>,C)) + elif func_name == "explicit": + if ctx.expr(0) in self.vector_expr: + self.vector_expr.append(ctx) + expr = self.getValue(ctx.expr(0)) + if self.explicit.keys(): + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + if '_' in ctx.expr(0).getText().lower() and ctx.expr(0).getText().count('_') == 2: + vec = ctx.expr(0).getText().lower().replace(">", "").split('_') + v1 = self.symbol_table2[vec[1]] + v2 = self.symbol_table2[vec[2]] + if vec[0] == "p": + self.write(v2 + ".set_pos(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + elif vec[0] == "v": + self.write(v2 + ".set_vel(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + elif vec[0] == "a": + self.write(v2 + ".set_acc(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + else: + self.write(expr + " = " + "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})\n") + else: + self.write(expr + " = " + "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})\n") + + # Force(O/Q, -k*Stretch*Uvec>) + elif func_name in ("force", "torque"): + + if "/" in ctx.expr(0).getText().lower(): + p1 = ctx.expr(0).getText().lower().split('/')[0] + p2 = ctx.expr(0).getText().lower().split('/')[1] + if self.type2[p1] in ("point", "frame"): + pt1 = self.symbol_table2[p1] + elif self.type2[p1] == "particle": + pt1 = self.symbol_table2[p1] + ".point" + if self.type2[p2] in ("point", "frame"): + pt2 = self.symbol_table2[p2] + elif self.type2[p2] == "particle": + pt2 = self.symbol_table2[p2] + ".point" + if pt1 in self.forces.keys(): + self.forces[pt1] = self.forces[pt1] + " + -1*("+self.getValue(ctx.expr(1)) + ")" + self.write("force_" + p1 + " = " + self.forces[pt1] + "\n") + else: + self.forces.update({pt1: "-1*("+self.getValue(ctx.expr(1)) + ")"}) + self.write("force_" + p1 + " = " + self.forces[pt1] + "\n") + if pt2 in self.forces.keys(): + self.forces[pt2] = self.forces[pt2] + "+ " + self.getValue(ctx.expr(1)) + self.write("force_" + p2 + " = " + self.forces[pt2] + "\n") + else: + self.forces.update({pt2: self.getValue(ctx.expr(1))}) + self.write("force_" + p2 + " = " + self.forces[pt2] + "\n") + + elif ctx.expr(0).getChildCount() == 1: + p1 = ctx.expr(0).getText().lower() + if self.type2[p1] in ("point", "frame"): + pt1 = self.symbol_table2[p1] + elif self.type2[p1] == "particle": + pt1 = self.symbol_table2[p1] + ".point" + if pt1 in self.forces.keys(): + self.forces[pt1] = self.forces[pt1] + "+ -1*(" + self.getValue(ctx.expr(1)) + ")" + else: + self.forces.update({pt1: "-1*(" + self.getValue(ctx.expr(1)) + ")"}) + + # Constrain(Dependent[qB]) + elif func_name == "constrain": + if ctx.getChild(2).getChild(0).getText().lower() == "dependent": + self.write("velocity_constraints = [i for i in dependent]\n") + x = (ctx.expr(0).getChildCount()-2)//2 + for i in range(x): + self.dependent_variables.append(self.getValue(ctx.expr(0).expr(i))) + + # Kane() + elif func_name == "kane": + if ctx.getChildCount() == 3: + self.kane_type = "no_args" + + # Settings + def exitSettings(self, ctx): + # Stores settings like Complex on/off, Degrees on/off etc in self.settings. + try: + self.settings.update({ctx.getChild(0).getText().lower(): + ctx.getChild(1).getText().lower()}) + except Exception: + pass + + def exitMassDecl2(self, ctx): + # Used for declaring the masses of particles and rigidbodies. + particle = self.symbol_table2[ctx.getChild(0).getText().lower()] + if ctx.getText().count("=") == 2: + if ctx.expr().expr(1) in self.numeric_expr: + e = "_sm.S(" + self.getValue(ctx.expr().expr(1)) + ")" + else: + e = self.getValue(ctx.expr().expr(1)) + self.symbol_table.update({ctx.expr().expr(0).getText().lower(): ctx.expr().expr(0).getText().lower()}) + self.write(ctx.expr().expr(0).getText().lower() + " = " + e + "\n") + mass = ctx.expr().expr(0).getText().lower() + else: + try: + if ctx.expr() in self.numeric_expr: + mass = "_sm.S(" + self.getValue(ctx.expr()) + ")" + else: + mass = self.getValue(ctx.expr()) + except Exception: + a_text = ctx.expr().getText().lower() + self.symbol_table.update({a_text: a_text}) + self.type.update({a_text: "constants"}) + self.write(a_text + " = " + "_sm.symbols('" + a_text + "')\n") + mass = a_text + + self.write(particle + ".mass = " + mass + "\n") + + def exitInertiaDecl(self, ctx): + inertia_list = [] + try: + ctx.ID(1).getText() + num = 5 + except Exception: + num = 2 + for i in range((ctx.getChildCount()-num)//2): + try: + if ctx.expr(i) in self.numeric_expr: + inertia_list.append("_sm.S(" + self.getValue(ctx.expr(i)) + ")") + else: + inertia_list.append(self.getValue(ctx.expr(i))) + except Exception: + a_text = ctx.expr(i).getText().lower() + self.symbol_table.update({a_text: a_text}) + self.type.update({a_text: "constants"}) + self.write(a_text + " = " + "_sm.symbols('" + a_text + "')\n") + inertia_list.append(a_text) + + if len(inertia_list) < 6: + for i in range(6-len(inertia_list)): + inertia_list.append("0") + # body_a.inertia = (_me.inertia(body_a, I1, I2, I3, 0, 0, 0), body_a_cm) + try: + frame = self.symbol_table2[ctx.ID(1).getText().lower()] + point = self.symbol_table2[ctx.ID(0).getText().lower().split('_')[1]] + body = self.symbol_table2[ctx.ID(0).getText().lower().split('_')[0]] + self.inertia_point.update({ctx.ID(0).getText().lower().split('_')[0] + : ctx.ID(0).getText().lower().split('_')[1]}) + self.write(body + ".inertia" + " = " + "(_me.inertia(" + frame + ", " + + ", ".join(inertia_list) + "), " + point + ")\n") + + except Exception: + body_name = self.symbol_table2[ctx.ID(0).getText().lower()] + body_name_cm = body_name + "_cm" + self.inertia_point.update({ctx.ID(0).getText().lower(): ctx.ID(0).getText().lower() + "o"}) + self.write(body_name + ".inertia" + " = " + "(_me.inertia(" + body_name + "_f" + ", " + + ", ".join(inertia_list) + "), " + body_name_cm + ")\n") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..e43924aac30903ade996b31921d3960afae90284 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py @@ -0,0 +1,38 @@ +from importlib.metadata import version +from sympy.external import import_module + + +autolevparser = import_module('sympy.parsing.autolev._antlr.autolevparser', + import_kwargs={'fromlist': ['AutolevParser']}) +autolevlexer = import_module('sympy.parsing.autolev._antlr.autolevlexer', + import_kwargs={'fromlist': ['AutolevLexer']}) +autolevlistener = import_module('sympy.parsing.autolev._antlr.autolevlistener', + import_kwargs={'fromlist': ['AutolevListener']}) + +AutolevParser = getattr(autolevparser, 'AutolevParser', None) +AutolevLexer = getattr(autolevlexer, 'AutolevLexer', None) +AutolevListener = getattr(autolevlistener, 'AutolevListener', None) + + +def parse_autolev(autolev_code, include_numeric): + antlr4 = import_module('antlr4') + if not antlr4 or not version('antlr4-python3-runtime').startswith('4.11'): + raise ImportError("Autolev parsing requires the antlr4 Python package," + " provided by pip (antlr4-python3-runtime)" + " conda (antlr-python-runtime), version 4.11") + try: + l = autolev_code.readlines() + input_stream = antlr4.InputStream("".join(l)) + except Exception: + input_stream = antlr4.InputStream(autolev_code) + + if AutolevListener: + from ._listener_autolev_antlr import MyListener + lexer = AutolevLexer(input_stream) + token_stream = antlr4.CommonTokenStream(lexer) + parser = AutolevParser(token_stream) + tree = parser.prog() + my_listener = MyListener(include_numeric) + walker = antlr4.ParseTreeWalker() + walker.walk(my_listener, tree) + return "".join(my_listener.output_code) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/README.txt b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/README.txt new file mode 100644 index 0000000000000000000000000000000000000000..946b006bac33544fadd2dc6d24c22240c8fbc8e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/README.txt @@ -0,0 +1,9 @@ +# parsing/tests/test_autolev.py uses the .al files in this directory as inputs and checks +# the equivalence of the parser generated codes and the respective .py files. + +# By default, this directory contains tests for all rules of the parser. + +# Additional tests consisting of full physics examples shall be made available soon in +# the form of another repository. One shall be able to copy the contents of that repo +# to this folder and use those tests after uncommenting the respective code in +# parsing/tests/test_autolev.py. diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest1.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest1.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92be1967ef3090eed75ad6435a245b3c6ceced32 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest1.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest10.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest10.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af1f11c3aa3ecf35b6e0d361b50f4e05b3dc5dc9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest10.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest11.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest11.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da8f99c60f2cda1c47d273ac8ac3266e16427dda Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest11.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest12.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest12.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8278cbf51d19fc531f4d3d6fc797caf7a2ea894 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest12.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b60567dd92d853c8afed5de71ca14cdeb815dd1d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest3.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ebd79b6c31ac555ad35a4e048186152a0745cbb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest3.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest4.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest4.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8775f0886ee6e9e7ae0a5cf297793df56ea12216 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest4.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest5.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest5.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761ce0d311b9b55ee5b27fbc4003ece0b1baf2d6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest5.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest6.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest6.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b30f75843ac124588a00c3fd2aa7ac3d048982a9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest6.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest7.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest7.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca3faf6ad85898e5cb8695c7116df292ef109873 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest7.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest8.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest8.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3a3d67a9e177e1d3303226b5d6a7efb20874c3f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest8.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest9.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest9.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8553141862b8bdf3f8d7b760ef913ded345cdf7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest9.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/chaos_pendulum.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/chaos_pendulum.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26079620e81c291ded12b207640b5e52ae7fa14e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/chaos_pendulum.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/double_pendulum.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/double_pendulum.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c3478c4c5df60f245f6065aca333b61c471e942 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/double_pendulum.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/mass_spring_damper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/mass_spring_damper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a87099a82bfd45560001e61d346415f8b406054 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/mass_spring_damper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/non_min_pendulum.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/non_min_pendulum.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..055311a865cd646a381d92dc90dffbfa84499684 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/non_min_pendulum.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..3bbb4d51b853bfd759df38d666a42adc1cbea190 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al @@ -0,0 +1,33 @@ +CONSTANTS G,LB,W,H +MOTIONVARIABLES' THETA'',PHI'',OMEGA',ALPHA' +NEWTONIAN N +BODIES A,B +SIMPROT(N,A,2,THETA) +SIMPROT(A,B,3,PHI) +POINT O +LA = (LB-H/2)/2 +P_O_AO> = LA*A3> +P_O_BO> = LB*A3> +OMEGA = THETA' +ALPHA = PHI' +W_A_N> = OMEGA*N2> +W_B_A> = ALPHA*A3> +V_O_N> = 0> +V2PTS(N, A, O, AO) +V2PTS(N, A, O, BO) +MASS A=MA, B=MB +IAXX = 1/12*MA*(2*LA)^2 +IAYY = IAXX +IAZZ = 0 +IBXX = 1/12*MB*H^2 +IBYY = 1/12*MB*(W^2+H^2) +IBZZ = 1/12*MB*W^2 +INERTIA A, IAXX, IAYY, IAZZ +INERTIA B, IBXX, IBYY, IBZZ +GRAVITY(G*N3>) +ZERO = FR() + FRSTAR() +KANE() +INPUT LB=0.2,H=0.1,W=0.2,MA=0.01,MB=0.1,G=9.81 +INPUT THETA = 90 DEG, PHI = 0.5 DEG, OMEGA=0, ALPHA=0 +INPUT TFINAL=10, INTEGSTP=0.02 +CODE DYNAMICS() some_filename.c diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..4435635720bb38f40366f55bb3ace0f6f6899284 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py @@ -0,0 +1,55 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +g, lb, w, h = _sm.symbols('g lb w h', real=True) +theta, phi, omega, alpha = _me.dynamicsymbols('theta phi omega alpha') +theta_d, phi_d, omega_d, alpha_d = _me.dynamicsymbols('theta_ phi_ omega_ alpha_', 1) +theta_dd, phi_dd = _me.dynamicsymbols('theta_ phi_', 2) +frame_n = _me.ReferenceFrame('n') +body_a_cm = _me.Point('a_cm') +body_a_cm.set_vel(frame_n, 0) +body_a_f = _me.ReferenceFrame('a_f') +body_a = _me.RigidBody('a', body_a_cm, body_a_f, _sm.symbols('m'), (_me.outer(body_a_f.x,body_a_f.x),body_a_cm)) +body_b_cm = _me.Point('b_cm') +body_b_cm.set_vel(frame_n, 0) +body_b_f = _me.ReferenceFrame('b_f') +body_b = _me.RigidBody('b', body_b_cm, body_b_f, _sm.symbols('m'), (_me.outer(body_b_f.x,body_b_f.x),body_b_cm)) +body_a_f.orient(frame_n, 'Axis', [theta, frame_n.y]) +body_b_f.orient(body_a_f, 'Axis', [phi, body_a_f.z]) +point_o = _me.Point('o') +la = (lb-h/2)/2 +body_a_cm.set_pos(point_o, la*body_a_f.z) +body_b_cm.set_pos(point_o, lb*body_a_f.z) +body_a_f.set_ang_vel(frame_n, omega*frame_n.y) +body_b_f.set_ang_vel(body_a_f, alpha*body_a_f.z) +point_o.set_vel(frame_n, 0) +body_a_cm.v2pt_theory(point_o,frame_n,body_a_f) +body_b_cm.v2pt_theory(point_o,frame_n,body_a_f) +ma = _sm.symbols('ma') +body_a.mass = ma +mb = _sm.symbols('mb') +body_b.mass = mb +iaxx = 1/12*ma*(2*la)**2 +iayy = iaxx +iazz = 0 +ibxx = 1/12*mb*h**2 +ibyy = 1/12*mb*(w**2+h**2) +ibzz = 1/12*mb*w**2 +body_a.inertia = (_me.inertia(body_a_f, iaxx, iayy, iazz, 0, 0, 0), body_a_cm) +body_b.inertia = (_me.inertia(body_b_f, ibxx, ibyy, ibzz, 0, 0, 0), body_b_cm) +force_a = body_a.mass*(g*frame_n.z) +force_b = body_b.mass*(g*frame_n.z) +kd_eqs = [theta_d - omega, phi_d - alpha] +forceList = [(body_a.masscenter,body_a.mass*(g*frame_n.z)), (body_b.masscenter,body_b.mass*(g*frame_n.z))] +kane = _me.KanesMethod(frame_n, q_ind=[theta,phi], u_ind=[omega, alpha], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([body_a, body_b], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {g:9.81, lb:0.2, w:0.2, h:0.1, ma:0.01, mb:0.1}, +specifieds={}, +initial_conditions={theta:_np.deg2rad(90), phi:_np.deg2rad(0.5), omega:0, alpha:0}, +times = _np.linspace(0.0, 10, 10/0.02)) + +y=sys.integrate() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..0b6d72a072e093a6cb048a0b7976041ee9c2f4f3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al @@ -0,0 +1,25 @@ +MOTIONVARIABLES' Q{2}', U{2}' +CONSTANTS L,M,G +NEWTONIAN N +FRAMES A,B +SIMPROT(N, A, 3, Q1) +SIMPROT(N, B, 3, Q2) +W_A_N>=U1*N3> +W_B_N>=U2*N3> +POINT O +PARTICLES P,R +P_O_P> = L*A1> +P_P_R> = L*B1> +V_O_N> = 0> +V2PTS(N, A, O, P) +V2PTS(N, B, P, R) +MASS P=M, R=M +Q1' = U1 +Q2' = U2 +GRAVITY(G*N1>) +ZERO = FR() + FRSTAR() +KANE() +INPUT M=1,G=9.81,L=1 +INPUT Q1=.1,Q2=.2,U1=0,U2=0 +INPUT TFINAL=10, INTEGSTP=.01 +CODE DYNAMICS() some_filename.c diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..12c73c3b4b198399f4c45f5e00d556c859caff74 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py @@ -0,0 +1,39 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') +q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) +l, m, g = _sm.symbols('l m g', real=True) +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_a.orient(frame_n, 'Axis', [q1, frame_n.z]) +frame_b.orient(frame_n, 'Axis', [q2, frame_n.z]) +frame_a.set_ang_vel(frame_n, u1*frame_n.z) +frame_b.set_ang_vel(frame_n, u2*frame_n.z) +point_o = _me.Point('o') +particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) +particle_r = _me.Particle('r', _me.Point('r_pt'), _sm.Symbol('m')) +particle_p.point.set_pos(point_o, l*frame_a.x) +particle_r.point.set_pos(particle_p.point, l*frame_b.x) +point_o.set_vel(frame_n, 0) +particle_p.point.v2pt_theory(point_o,frame_n,frame_a) +particle_r.point.v2pt_theory(particle_p.point,frame_n,frame_b) +particle_p.mass = m +particle_r.mass = m +force_p = particle_p.mass*(g*frame_n.x) +force_r = particle_r.mass*(g*frame_n.x) +kd_eqs = [q1_d - u1, q2_d - u2] +forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x)), (particle_r.point,particle_r.mass*(g*frame_n.x))] +kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u1, u2], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([particle_p, particle_r], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {l:1, m:1, g:9.81}, +specifieds={}, +initial_conditions={q1:.1, q2:.2, u1:0, u2:0}, +times = _np.linspace(0.0, 10, 10/.01)) + +y=sys.integrate() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al new file mode 100644 index 0000000000000000000000000000000000000000..4892e5ca8cb18cad6b14a2a37cbdc1f7fb8217ac --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al @@ -0,0 +1,19 @@ +CONSTANTS M,K,B,G +MOTIONVARIABLES' POSITION',SPEED' +VARIABLES O +FORCE = O*SIN(T) +NEWTONIAN CEILING +POINTS ORIGIN +V_ORIGIN_CEILING> = 0> +PARTICLES BLOCK +P_ORIGIN_BLOCK> = POSITION*CEILING1> +MASS BLOCK=M +V_BLOCK_CEILING>=SPEED*CEILING1> +POSITION' = SPEED +FORCE_MAGNITUDE = M*G-K*POSITION-B*SPEED+FORCE +FORCE_BLOCK>=EXPLICIT(FORCE_MAGNITUDE*CEILING1>) +ZERO = FR() + FRSTAR() +KANE() +INPUT TFINAL=10.0, INTEGSTP=0.01 +INPUT M=1.0, K=1.0, B=0.2, G=9.8, POSITION=0.1, SPEED=-1.0, O=2 +CODE DYNAMICS() dummy_file.c diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5baab9642ff140e0ee81027a1e8f9152d7050c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py @@ -0,0 +1,31 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +m, k, b, g = _sm.symbols('m k b g', real=True) +position, speed = _me.dynamicsymbols('position speed') +position_d, speed_d = _me.dynamicsymbols('position_ speed_', 1) +o = _me.dynamicsymbols('o') +force = o*_sm.sin(_me.dynamicsymbols._t) +frame_ceiling = _me.ReferenceFrame('ceiling') +point_origin = _me.Point('origin') +point_origin.set_vel(frame_ceiling, 0) +particle_block = _me.Particle('block', _me.Point('block_pt'), _sm.Symbol('m')) +particle_block.point.set_pos(point_origin, position*frame_ceiling.x) +particle_block.mass = m +particle_block.point.set_vel(frame_ceiling, speed*frame_ceiling.x) +force_magnitude = m*g-k*position-b*speed+force +force_block = (force_magnitude*frame_ceiling.x).subs({position_d:speed}) +kd_eqs = [position_d - speed] +forceList = [(particle_block.point,(force_magnitude*frame_ceiling.x).subs({position_d:speed}))] +kane = _me.KanesMethod(frame_ceiling, q_ind=[position], u_ind=[speed], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([particle_block], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {m:1.0, k:1.0, b:0.2, g:9.8}, +specifieds={_me.dynamicsymbols('t'):lambda x, t: t, o:2}, +initial_conditions={position:0.1, speed:-1*1.0}, +times = _np.linspace(0.0, 10.0, 10.0/0.01)) + +y=sys.integrate() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..74f5062d80926db7acd634a04759abce857087e5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al @@ -0,0 +1,20 @@ +MOTIONVARIABLES' Q{2}'' +CONSTANTS L,M,G +NEWTONIAN N +POINT PN +V_PN_N> = 0> +THETA1 = ATAN(Q2/Q1) +FRAMES A +SIMPROT(N, A, 3, THETA1) +PARTICLES P +P_PN_P> = Q1*N1>+Q2*N2> +MASS P=M +V_P_N>=DT(P_P_PN>, N) +F_V = DOT(EXPRESS(V_P_N>,A), A1>) +GRAVITY(G*N1>) +DEPENDENT[1] = F_V +CONSTRAIN(DEPENDENT[Q1']) +ZERO=FR()+FRSTAR() +F_C = MAG(P_P_PN>)-L +CONFIG[1]=F_C +ZERO[2]=CONFIG[1] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..fc972ebd518e77da5e1902c149f2699979865e7f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py @@ -0,0 +1,36 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2 = _me.dynamicsymbols('q1 q2') +q1_d, q2_d = _me.dynamicsymbols('q1_ q2_', 1) +q1_dd, q2_dd = _me.dynamicsymbols('q1_ q2_', 2) +l, m, g = _sm.symbols('l m g', real=True) +frame_n = _me.ReferenceFrame('n') +point_pn = _me.Point('pn') +point_pn.set_vel(frame_n, 0) +theta1 = _sm.atan(q2/q1) +frame_a = _me.ReferenceFrame('a') +frame_a.orient(frame_n, 'Axis', [theta1, frame_n.z]) +particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) +particle_p.point.set_pos(point_pn, q1*frame_n.x+q2*frame_n.y) +particle_p.mass = m +particle_p.point.set_vel(frame_n, (point_pn.pos_from(particle_p.point)).dt(frame_n)) +f_v = _me.dot((particle_p.point.vel(frame_n)).express(frame_a), frame_a.x) +force_p = particle_p.mass*(g*frame_n.x) +dependent = _sm.Matrix([[0]]) +dependent[0] = f_v +velocity_constraints = [i for i in dependent] +u_q1_d = _me.dynamicsymbols('u_q1_d') +u_q2_d = _me.dynamicsymbols('u_q2_d') +kd_eqs = [q1_d-u_q1_d, q2_d-u_q2_d] +forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x))] +kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u_q2_d], u_dependent=[u_q1_d], kd_eqs = kd_eqs, velocity_constraints = velocity_constraints) +fr, frstar = kane.kanes_equations([particle_p], forceList) +zero = fr+frstar +f_c = point_pn.pos_from(particle_p.point).magnitude()-l +config = _sm.Matrix([[0]]) +config[0] = f_c +zero = zero.row_insert(zero.shape[0], _sm.Matrix([[0]])) +zero[zero.shape[0]-1] = config[0] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al new file mode 100644 index 0000000000000000000000000000000000000000..457e79fd646677c0decdc69f921bc05e9e0dcf51 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al @@ -0,0 +1,8 @@ +% ruletest1.al +CONSTANTS F = 3, G = 9.81 +CONSTANTS A, B +CONSTANTS S, S1, S2+, S3+, S4- +CONSTANTS K{4}, L{1:3}, P{1:2,1:3} +CONSTANTS C{2,3} +E1 = A*F + S2 - G +E2 = F^2 + K3*K2*G diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py new file mode 100644 index 0000000000000000000000000000000000000000..8466392ac930f13f2419c9c04eef9dcc2884e9bd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py @@ -0,0 +1,15 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +f = _sm.S(3) +g = _sm.S(9.81) +a, b = _sm.symbols('a b', real=True) +s, s1 = _sm.symbols('s s1', real=True) +s2, s3 = _sm.symbols('s2 s3', real=True, nonnegative=True) +s4 = _sm.symbols('s4', real=True, nonpositive=True) +k1, k2, k3, k4, l1, l2, l3, p11, p12, p13, p21, p22, p23 = _sm.symbols('k1 k2 k3 k4 l1 l2 l3 p11 p12 p13 p21 p22 p23', real=True) +c11, c12, c13, c21, c22, c23 = _sm.symbols('c11 c12 c13 c21 c22 c23', real=True) +e1 = a*f+s2-g +e2 = f**2+k3*k2*g diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al new file mode 100644 index 0000000000000000000000000000000000000000..9d5f76f063c43bcb5e2a8d4f29619a6952abf9e5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al @@ -0,0 +1,58 @@ +% ruletest10.al + +VARIABLES X,Y +COMPLEX ON +CONSTANTS A,B +E = A*(B*X+Y)^2 +M = [E;E] +EXPAND(E) +EXPAND(M) +FACTOR(E,X) +FACTOR(M,X) + +EQN[1] = A*X + B*Y +EQN[2] = 2*A*X - 3*B*Y +SOLVE(EQN, X, Y) +RHS_Y = RHS(Y) +E = (X+Y)^2 + 2*X^2 +ARRANGE(E, 2, X) + +CONSTANTS A,B,C +M = [A,B;C,0] +M2 = EVALUATE(M,A=1,B=2,C=3) +EIG(M2, EIGVALUE, EIGVEC) + +NEWTONIAN N +FRAMES A +SIMPROT(N, A, N1>, X) +DEGREES OFF +SIMPROT(N, A, N1>, PI/2) + +CONSTANTS C{3} +V> = C1*A1> + C2*A2> + C3*A3> +POINTS O, P +P_P_O> = C1*A1> +EXPRESS(V>,N) +EXPRESS(P_P_O>,N) +W_A_N> = C3*A3> +ANGVEL(A,N) + +V2PTS(N,A,O,P) +PARTICLES P{2} +V2PTS(N,A,P1,P2) +A2PTS(N,A,P1,P) + +BODIES B{2} +CONSTANT G +GRAVITY(G*N1>) + +VARIABLE Z +V> = X*A1> + Y*A3> +P_P_O> = X*A1> + Y*A2> +X = 2*Z +Y = Z +EXPLICIT(V>) +EXPLICIT(P_P_O>) + +FORCE(O/P1, X*Y*A1>) +FORCE(P2, X*Y*A1>) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9674e47d5f6132c5a79a33b9d8d55a131942d6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py @@ -0,0 +1,64 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a, b = _sm.symbols('a b', real=True) +e = a*(b*x+y)**2 +m = _sm.Matrix([e,e]).reshape(2, 1) +e = e.expand() +m = _sm.Matrix([i.expand() for i in m]).reshape((m).shape[0], (m).shape[1]) +e = _sm.factor(e, x) +m = _sm.Matrix([_sm.factor(i,x) for i in m]).reshape((m).shape[0], (m).shape[1]) +eqn = _sm.Matrix([[0]]) +eqn[0] = a*x+b*y +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = 2*a*x-3*b*y +print(_sm.solve(eqn,x,y)) +rhs_y = _sm.solve(eqn,x,y)[y] +e = (x+y)**2+2*x**2 +e.collect(x) +a, b, c = _sm.symbols('a b c', real=True) +m = _sm.Matrix([a,b,c,0]).reshape(2, 2) +m2 = _sm.Matrix([i.subs({a:1,b:2,c:3}) for i in m]).reshape((m).shape[0], (m).shape[1]) +eigvalue = _sm.Matrix([i.evalf() for i in (m2).eigenvals().keys()]) +eigvec = _sm.Matrix([i[2][0].evalf() for i in (m2).eigenvects()]).reshape(m2.shape[0], m2.shape[1]) +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +frame_a.orient(frame_n, 'Axis', [x, frame_n.x]) +frame_a.orient(frame_n, 'Axis', [_sm.pi/2, frame_n.x]) +c1, c2, c3 = _sm.symbols('c1 c2 c3', real=True) +v = c1*frame_a.x+c2*frame_a.y+c3*frame_a.z +point_o = _me.Point('o') +point_p = _me.Point('p') +point_o.set_pos(point_p, c1*frame_a.x) +v = (v).express(frame_n) +point_o.set_pos(point_p, (point_o.pos_from(point_p)).express(frame_n)) +frame_a.set_ang_vel(frame_n, c3*frame_a.z) +print(frame_n.ang_vel_in(frame_a)) +point_p.v2pt_theory(point_o,frame_n,frame_a) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +particle_p2.point.v2pt_theory(particle_p1.point,frame_n,frame_a) +point_p.a2pt_theory(particle_p1.point,frame_n,frame_a) +body_b1_cm = _me.Point('b1_cm') +body_b1_cm.set_vel(frame_n, 0) +body_b1_f = _me.ReferenceFrame('b1_f') +body_b1 = _me.RigidBody('b1', body_b1_cm, body_b1_f, _sm.symbols('m'), (_me.outer(body_b1_f.x,body_b1_f.x),body_b1_cm)) +body_b2_cm = _me.Point('b2_cm') +body_b2_cm.set_vel(frame_n, 0) +body_b2_f = _me.ReferenceFrame('b2_f') +body_b2 = _me.RigidBody('b2', body_b2_cm, body_b2_f, _sm.symbols('m'), (_me.outer(body_b2_f.x,body_b2_f.x),body_b2_cm)) +g = _sm.symbols('g', real=True) +force_p1 = particle_p1.mass*(g*frame_n.x) +force_p2 = particle_p2.mass*(g*frame_n.x) +force_b1 = body_b1.mass*(g*frame_n.x) +force_b2 = body_b2.mass*(g*frame_n.x) +z = _me.dynamicsymbols('z') +v = x*frame_a.x+y*frame_a.z +point_o.set_pos(point_p, x*frame_a.x+y*frame_a.y) +v = (v).subs({x:2*z, y:z}) +point_o.set_pos(point_p, (point_o.pos_from(point_p)).subs({x:2*z, y:z})) +force_o = -1*(x*y*frame_a.x) +force_p1 = particle_p1.mass*(g*frame_n.x)+ x*y*frame_a.x diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al new file mode 100644 index 0000000000000000000000000000000000000000..60934c1ca563024828110bfe984a90d5686b89e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al @@ -0,0 +1,6 @@ +VARIABLES X, Y +CONSTANTS A{1:2, 1:2}, B{1:2} +EQN[1] = A11*x + A12*y - B1 +EQN[2] = A21*x + A22*y - B2 +INPUT A11=2, A12=5, A21=3, A22=4, B1=7, B2=6 +CODE ALGEBRAIC(EQN, X, Y) some_filename.c diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec2397ea96261d7b582d1f699e3897caae88f20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py @@ -0,0 +1,14 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a11, a12, a21, a22, b1, b2 = _sm.symbols('a11 a12 a21 a22 b1 b2', real=True) +eqn = _sm.Matrix([[0]]) +eqn[0] = a11*x+a12*y-b1 +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = a21*x+a22*y-b2 +eqn_list = [] +for i in eqn: eqn_list.append(i.subs({a11:2, a12:5, a21:3, a22:4, b1:7, b2:6})) +print(_sm.linsolve(eqn_list, x,y)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al new file mode 100644 index 0000000000000000000000000000000000000000..f147f55afd1438436767960e0487d5d9e7161c8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al @@ -0,0 +1,7 @@ +VARIABLES X,Y +CONSTANTS A,B,R +EQN[1] = A*X^3+B*Y^2-R +EQN[2] = A*SIN(X)^2 + B*COS(2*Y) - R^2 +INPUT A=2.0, B=3.0, R=1.0 +INPUT X = 30 DEG, Y = 3.14 +CODE NONLINEAR(EQN,X,Y) some_filename.c diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7d996fa649f796a536dba20c1a36554acd8046 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py @@ -0,0 +1,14 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a, b, r = _sm.symbols('a b r', real=True) +eqn = _sm.Matrix([[0]]) +eqn[0] = a*x**3+b*y**2-r +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = a*_sm.sin(x)**2+b*_sm.cos(2*y)-r**2 +matrix_list = [] +for i in eqn:matrix_list.append(i.subs({a:2.0, b:3.0, r:1.0})) +print(_sm.nsolve(matrix_list,(x,y),(_np.deg2rad(30),3.14))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al new file mode 100644 index 0000000000000000000000000000000000000000..17937e58bd20a9fb82f44ccd05f0c081a1aa6c9b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al @@ -0,0 +1,12 @@ +% ruletest2.al +VARIABLES X1,X2 +SPECIFIED F1 = X1*X2 + 3*X1^2 +SPECIFIED F2=X1*T+X2*T^2 +VARIABLE X', Y'' +MOTIONVARIABLES Q{3}, U{2} +VARIABLES P{2}' +VARIABLE W{3}', R{2}'' +VARIABLES C{1:2, 1:2} +VARIABLES D{1,3} +VARIABLES J{1:2} +IMAGINARY N diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py new file mode 100644 index 0000000000000000000000000000000000000000..31c1d9974c2292466b805b91f8254bffaa94e2ac --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py @@ -0,0 +1,22 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x1, x2 = _me.dynamicsymbols('x1 x2') +f1 = x1*x2+3*x1**2 +f2 = x1*_me.dynamicsymbols._t+x2*_me.dynamicsymbols._t**2 +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +y_dd = _me.dynamicsymbols('y_', 2) +q1, q2, q3, u1, u2 = _me.dynamicsymbols('q1 q2 q3 u1 u2') +p1, p2 = _me.dynamicsymbols('p1 p2') +p1_d, p2_d = _me.dynamicsymbols('p1_ p2_', 1) +w1, w2, w3, r1, r2 = _me.dynamicsymbols('w1 w2 w3 r1 r2') +w1_d, w2_d, w3_d, r1_d, r2_d = _me.dynamicsymbols('w1_ w2_ w3_ r1_ r2_', 1) +r1_dd, r2_dd = _me.dynamicsymbols('r1_ r2_', 2) +c11, c12, c21, c22 = _me.dynamicsymbols('c11 c12 c21 c22') +d11, d12, d13 = _me.dynamicsymbols('d11 d12 d13') +j1, j2 = _me.dynamicsymbols('j1 j2') +n = _sm.symbols('n') +n = _sm.I diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al new file mode 100644 index 0000000000000000000000000000000000000000..f263f1802ebca2725481dd5fdd3540bf8e9f11bf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al @@ -0,0 +1,25 @@ +% ruletest3.al +FRAMES A, B +NEWTONIAN N + +VARIABLES X{3} +CONSTANTS L + +V1> = X1*A1> + X2*A2> + X3*A3> +V2> = X1*B1> + X2*B2> + X3*B3> +V3> = X1*N1> + X2*N2> + X3*N3> + +V> = V1> + V2> + V3> + +POINTS C, D +POINTS PO{3} + +PARTICLES L +PARTICLES P{3} + +BODIES S +BODIES R{2} + +V4> = X1*S1> + X2*S2> + X3*S3> + +P_C_SO> = L*N1> diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py new file mode 100644 index 0000000000000000000000000000000000000000..23f79aa571337f200b3ff4d56b5747f7704985c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py @@ -0,0 +1,37 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_n = _me.ReferenceFrame('n') +x1, x2, x3 = _me.dynamicsymbols('x1 x2 x3') +l = _sm.symbols('l', real=True) +v1 = x1*frame_a.x+x2*frame_a.y+x3*frame_a.z +v2 = x1*frame_b.x+x2*frame_b.y+x3*frame_b.z +v3 = x1*frame_n.x+x2*frame_n.y+x3*frame_n.z +v = v1+v2+v3 +point_c = _me.Point('c') +point_d = _me.Point('d') +point_po1 = _me.Point('po1') +point_po2 = _me.Point('po2') +point_po3 = _me.Point('po3') +particle_l = _me.Particle('l', _me.Point('l_pt'), _sm.Symbol('m')) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +particle_p3 = _me.Particle('p3', _me.Point('p3_pt'), _sm.Symbol('m')) +body_s_cm = _me.Point('s_cm') +body_s_cm.set_vel(frame_n, 0) +body_s_f = _me.ReferenceFrame('s_f') +body_s = _me.RigidBody('s', body_s_cm, body_s_f, _sm.symbols('m'), (_me.outer(body_s_f.x,body_s_f.x),body_s_cm)) +body_r1_cm = _me.Point('r1_cm') +body_r1_cm.set_vel(frame_n, 0) +body_r1_f = _me.ReferenceFrame('r1_f') +body_r1 = _me.RigidBody('r1', body_r1_cm, body_r1_f, _sm.symbols('m'), (_me.outer(body_r1_f.x,body_r1_f.x),body_r1_cm)) +body_r2_cm = _me.Point('r2_cm') +body_r2_cm.set_vel(frame_n, 0) +body_r2_f = _me.ReferenceFrame('r2_f') +body_r2 = _me.RigidBody('r2', body_r2_cm, body_r2_f, _sm.symbols('m'), (_me.outer(body_r2_f.x,body_r2_f.x),body_r2_cm)) +v4 = x1*body_s_f.x+x2*body_s_f.y+x3*body_s_f.z +body_s_cm.set_pos(point_c, l*frame_n.x) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al new file mode 100644 index 0000000000000000000000000000000000000000..7302bd7724bad9b763c75fe4230faa42b5070408 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al @@ -0,0 +1,20 @@ +% ruletest4.al + +FRAMES A, B +MOTIONVARIABLES Q{3} +SIMPROT(A, B, 1, Q3) +DCM = A_B +M = DCM*3 - A_B + +VARIABLES R +CIRCLE_AREA = PI*R^2 + +VARIABLES U, A +VARIABLES X, Y +S = U*T - 1/2*A*T^2 + +EXPR1 = 2*A*0.5 - 1.25 + 0.25 +EXPR2 = -X^2 + Y^2 + 0.25*(X+Y)^2 +EXPR3 = 0.5E-10 + +DYADIC>> = A1>*A1> + A2>*A2> + A3>*A3> diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py new file mode 100644 index 0000000000000000000000000000000000000000..74b18543e04d6c9e42dd569d2152040c13ae0899 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py @@ -0,0 +1,20 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +q1, q2, q3 = _me.dynamicsymbols('q1 q2 q3') +frame_b.orient(frame_a, 'Axis', [q3, frame_a.x]) +dcm = frame_a.dcm(frame_b) +m = dcm*3-frame_a.dcm(frame_b) +r = _me.dynamicsymbols('r') +circle_area = _sm.pi*r**2 +u, a = _me.dynamicsymbols('u a') +x, y = _me.dynamicsymbols('x y') +s = u*_me.dynamicsymbols._t-1/2*a*_me.dynamicsymbols._t**2 +expr1 = 2*a*0.5-1.25+0.25 +expr2 = -1*x**2+y**2+0.25*(x+y)**2 +expr3 = 0.5*10**(-10) +dyadic = _me.outer(frame_a.x, frame_a.x)+_me.outer(frame_a.y, frame_a.y)+_me.outer(frame_a.z, frame_a.z) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al new file mode 100644 index 0000000000000000000000000000000000000000..a859dc8bb1f0251af14809681d995c59b31377ba --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al @@ -0,0 +1,32 @@ +% ruletest5.al +VARIABLES X', Y' + +E1 = (X+Y)^2 + (X-Y)^3 +E2 = (X-Y)^2 +E3 = X^2 + Y^2 + 2*X*Y + +M1 = [E1;E2] +M2 = [(X+Y)^2,(X-Y)^2] +M3 = M1 + [X;Y] + +AM = EXPAND(M1) +CM = EXPAND([(X+Y)^2,(X-Y)^2]) +EM = EXPAND(M1 + [X;Y]) +F = EXPAND(E1) +G = EXPAND(E2) + +A = FACTOR(E3, X) +BM = FACTOR(M1, X) +CM = FACTOR(M1 + [X;Y], X) + +A = D(E3, X) +B = D(E3, Y) +CM = D(M2, X) +DM = D(M1 + [X;Y], X) +FRAMES A, B +A_B = [1,0,0;1,0,0;1,0,0] +V1> = X*A1> + Y*A2> + X*Y*A3> +E> = D(V1>, X, B) +FM = DT(M1) +GM = DT([(X+Y)^2,(X-Y)^2]) +H> = DT(V1>, B) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py new file mode 100644 index 0000000000000000000000000000000000000000..93684435b402f5b56e2f4a5c3c81500208556423 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py @@ -0,0 +1,33 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +e1 = (x+y)**2+(x-y)**3 +e2 = (x-y)**2 +e3 = x**2+y**2+2*x*y +m1 = _sm.Matrix([e1,e2]).reshape(2, 1) +m2 = _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2) +m3 = m1+_sm.Matrix([x,y]).reshape(2, 1) +am = _sm.Matrix([i.expand() for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +cm = _sm.Matrix([i.expand() for i in _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)]).reshape((_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[0], (_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[1]) +em = _sm.Matrix([i.expand() for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +f = (e1).expand() +g = (e2).expand() +a = _sm.factor((e3), x) +bm = _sm.Matrix([_sm.factor(i, x) for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +cm = _sm.Matrix([_sm.factor(i, x) for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +a = (e3).diff(x) +b = (e3).diff(y) +cm = _sm.Matrix([i.diff(x) for i in m2]).reshape((m2).shape[0], (m2).shape[1]) +dm = _sm.Matrix([i.diff(x) for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_b.orient(frame_a, 'DCM', _sm.Matrix([1,0,0,1,0,0,1,0,0]).reshape(3, 3)) +v1 = x*frame_a.x+y*frame_a.y+x*y*frame_a.z +e = (v1).diff(x, frame_b) +fm = _sm.Matrix([i.diff(_sm.Symbol('t')) for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +gm = _sm.Matrix([i.diff(_sm.Symbol('t')) for i in _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)]).reshape((_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[0], (_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[1]) +h = (v1).dt(frame_b) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al new file mode 100644 index 0000000000000000000000000000000000000000..7ec3ba61590e77772ae631237df048b932fe778c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al @@ -0,0 +1,41 @@ +% ruletest6.al +VARIABLES Q{2} +VARIABLES X,Y,Z +Q1 = X^2 + Y^2 +Q2 = X-Y +E = Q1 + Q2 +A = EXPLICIT(E) +E2 = COS(X) +E3 = COS(X*Y) +A = TAYLOR(E2, 0:2, X=0) +B = TAYLOR(E3, 0:2, X=0, Y=0) + +E = EXPAND((X+Y)^2) +A = EVALUATE(E, X=1, Y=Z) +BM = EVALUATE([E;2*E], X=1, Y=Z) + +E = Q1 + Q2 +A = EVALUATE(E, X=2, Y=Z^2) + +CONSTANTS J,K,L +P1 = POLYNOMIAL([J,K,L],X) +P2 = POLYNOMIAL(J*X+K,X,1) + +ROOT1 = ROOTS(P1, X, 2) +ROOT2 = ROOTS([1;2;3]) + +M = [1,2,3,4;5,6,7,8;9,10,11,12;13,14,15,16] + +AM = TRANSPOSE(M) + M +BM = EIG(M) +C1 = DIAGMAT(4, 1) +C2 = DIAGMAT(3, 4, 2) +DM = INV(M+C1) +E = DET(M+C1) + TRACE([1,0;0,1]) +F = ELEMENT(M, 2, 3) + +A = COLS(M) +BM = COLS(M, 1) +CM = COLS(M, 1, 2:4, 3) +DM = ROWS(M, 1) +EM = ROWS(M, 1, 2:4, 3) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py new file mode 100644 index 0000000000000000000000000000000000000000..85f1a0b49518bb0ae5766cbe91b9c24a1b8e9c20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py @@ -0,0 +1,36 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2 = _me.dynamicsymbols('q1 q2') +x, y, z = _me.dynamicsymbols('x y z') +e = q1+q2 +a = (e).subs({q1:x**2+y**2, q2:x-y}) +e2 = _sm.cos(x) +e3 = _sm.cos(x*y) +a = (e2).series(x, 0, 2).removeO() +b = (e3).series(x, 0, 2).removeO().series(y, 0, 2).removeO() +e = ((x+y)**2).expand() +a = (e).subs({q1:x**2+y**2,q2:x-y}).subs({x:1,y:z}) +bm = _sm.Matrix([i.subs({x:1,y:z}) for i in _sm.Matrix([e,2*e]).reshape(2, 1)]).reshape((_sm.Matrix([e,2*e]).reshape(2, 1)).shape[0], (_sm.Matrix([e,2*e]).reshape(2, 1)).shape[1]) +e = q1+q2 +a = (e).subs({q1:x**2+y**2,q2:x-y}).subs({x:2,y:z**2}) +j, k, l = _sm.symbols('j k l', real=True) +p1 = _sm.Poly(_sm.Matrix([j,k,l]).reshape(1, 3), x) +p2 = _sm.Poly(j*x+k, x) +root1 = [i.evalf() for i in _sm.solve(p1, x)] +root2 = [i.evalf() for i in _sm.solve(_sm.Poly(_sm.Matrix([1,2,3]).reshape(3, 1), x),x)] +m = _sm.Matrix([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]).reshape(4, 4) +am = (m).T+m +bm = _sm.Matrix([i.evalf() for i in (m).eigenvals().keys()]) +c1 = _sm.diag(1,1,1,1) +c2 = _sm.Matrix([2 if i==j else 0 for i in range(3) for j in range(4)]).reshape(3, 4) +dm = (m+c1)**(-1) +e = (m+c1).det()+(_sm.Matrix([1,0,0,1]).reshape(2, 2)).trace() +f = (m)[1,2] +a = (m).cols +bm = (m).col(0) +cm = _sm.Matrix([(m).T.row(0),(m).T.row(1),(m).T.row(2),(m).T.row(3),(m).T.row(2)]) +dm = (m).row(0) +em = _sm.Matrix([(m).row(0),(m).row(1),(m).row(2),(m).row(3),(m).row(2)]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al new file mode 100644 index 0000000000000000000000000000000000000000..2904a602f589645d22e1d3d378d077dd6a1ec27e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al @@ -0,0 +1,39 @@ +% ruletest7.al +VARIABLES X', Y' +E = COS(X) + SIN(X) + TAN(X)& ++ COSH(X) + SINH(X) + TANH(X)& ++ ACOS(X) + ASIN(X) + ATAN(X)& ++ LOG(X) + EXP(X) + SQRT(X)& ++ FACTORIAL(X) + CEIL(X) +& +FLOOR(X) + SIGN(X) + +E = SQR(X) + LOG10(X) + +A = ABS(-1) + INT(1.5) + ROUND(1.9) + +E1 = 2*X + 3*Y +E2 = X + Y + +AM = COEF([E1;E2], [X,Y]) +B = COEF(E1, X) +C = COEF(E2, Y) +D1 = EXCLUDE(E1, X) +D2 = INCLUDE(E1, X) +FM = ARRANGE([E1,E2],2,X) +F = ARRANGE(E1, 2, Y) +G = REPLACE(E1, X=2*X) +GM = REPLACE([E1;E2], X=3) + +FRAMES A, B +VARIABLES THETA +SIMPROT(A,B,3,THETA) +V1> = 2*A1> - 3*A2> + A3> +V2> = B1> + B2> + B3> +A = DOT(V1>, V2>) +BM = DOT(V1>, [V2>;2*V2>]) +C> = CROSS(V1>,V2>) +D = MAG(2*V1>) + MAG(3*V1>) +DYADIC>> = 3*A1>*A1> + A2>*A2> + 2*A3>*A3> +AM = MATRIX(B, DYADIC>>) +M = [1;2;3] +V> = VECTOR(A, M) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py new file mode 100644 index 0000000000000000000000000000000000000000..19147856dc3b0d451184a6bb539c1c331f61a6d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py @@ -0,0 +1,35 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +e = _sm.cos(x)+_sm.sin(x)+_sm.tan(x)+_sm.cosh(x)+_sm.sinh(x)+_sm.tanh(x)+_sm.acos(x)+_sm.asin(x)+_sm.atan(x)+_sm.log(x)+_sm.exp(x)+_sm.sqrt(x)+_sm.factorial(x)+_sm.ceiling(x)+_sm.floor(x)+_sm.sign(x) +e = (x)**2+_sm.log(x, 10) +a = _sm.Abs(-1*1)+int(1.5)+round(1.9) +e1 = 2*x+3*y +e2 = x+y +am = _sm.Matrix([e1.expand().coeff(x), e1.expand().coeff(y), e2.expand().coeff(x), e2.expand().coeff(y)]).reshape(2, 2) +b = (e1).expand().coeff(x) +c = (e2).expand().coeff(y) +d1 = (e1).collect(x).coeff(x,0) +d2 = (e1).collect(x).coeff(x,1) +fm = _sm.Matrix([i.collect(x)for i in _sm.Matrix([e1,e2]).reshape(1, 2)]).reshape((_sm.Matrix([e1,e2]).reshape(1, 2)).shape[0], (_sm.Matrix([e1,e2]).reshape(1, 2)).shape[1]) +f = (e1).collect(y) +g = (e1).subs({x:2*x}) +gm = _sm.Matrix([i.subs({x:3}) for i in _sm.Matrix([e1,e2]).reshape(2, 1)]).reshape((_sm.Matrix([e1,e2]).reshape(2, 1)).shape[0], (_sm.Matrix([e1,e2]).reshape(2, 1)).shape[1]) +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +theta = _me.dynamicsymbols('theta') +frame_b.orient(frame_a, 'Axis', [theta, frame_a.z]) +v1 = 2*frame_a.x-3*frame_a.y+frame_a.z +v2 = frame_b.x+frame_b.y+frame_b.z +a = _me.dot(v1, v2) +bm = _sm.Matrix([_me.dot(v1, v2),_me.dot(v1, 2*v2)]).reshape(2, 1) +c = _me.cross(v1, v2) +d = 2*v1.magnitude()+3*v1.magnitude() +dyadic = _me.outer(3*frame_a.x, frame_a.x)+_me.outer(frame_a.y, frame_a.y)+_me.outer(2*frame_a.z, frame_a.z) +am = (dyadic).to_matrix(frame_b) +m = _sm.Matrix([1,2,3]).reshape(3, 1) +v = m[0]*frame_a.x +m[1]*frame_a.y +m[2]*frame_a.z diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al new file mode 100644 index 0000000000000000000000000000000000000000..4b2462c51e6730f46bf60b4b21ab6cfbf1993640 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al @@ -0,0 +1,38 @@ +% ruletest8.al +FRAMES A +CONSTANTS C{3} +A>> = EXPRESS(1>>,A) +PARTICLES P1, P2 +BODIES R +R_A = [1,1,1;1,1,0;0,0,1] +POINT O +MASS P1=M1, P2=M2, R=MR +INERTIA R, I1, I2, I3 +P_P1_O> = C1*A1> +P_P2_O> = C2*A2> +P_RO_O> = C3*A3> +A>> = EXPRESS(I_P1_O>>, A) +A>> = EXPRESS(I_P2_O>>, A) +A>> = EXPRESS(I_R_O>>, A) +A>> = EXPRESS(INERTIA(O), A) +A>> = EXPRESS(INERTIA(O, P1, R), A) +A>> = EXPRESS(I_R_O>>, A) +A>> = EXPRESS(I_R_RO>>, A) + +P_P1_P2> = C1*A1> + C2*A2> +P_P1_RO> = C3*A1> +P_P2_RO> = C3*A2> + +B> = CM(O) +B> = CM(O, P1, R) +B> = CM(P1) + +MOTIONVARIABLES U{3} +V> = U1*A1> + U2*A2> + U3*A3> +U> = UNITVEC(V> + C1*A1>) +V_P1_A> = U1*A1> +A> = PARTIALS(V_P1_A>, U1) + +M = MASS(P1,R) +M = MASS(P2) +M = MASS() \ No newline at end of file diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py new file mode 100644 index 0000000000000000000000000000000000000000..6809c47138e40027c700536e807ca7cfa5f468d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py @@ -0,0 +1,49 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +c1, c2, c3 = _sm.symbols('c1 c2 c3', real=True) +a = _me.inertia(frame_a, 1, 1, 1) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +body_r_cm = _me.Point('r_cm') +body_r_f = _me.ReferenceFrame('r_f') +body_r = _me.RigidBody('r', body_r_cm, body_r_f, _sm.symbols('m'), (_me.outer(body_r_f.x,body_r_f.x),body_r_cm)) +frame_a.orient(body_r_f, 'DCM', _sm.Matrix([1,1,1,1,1,0,0,0,1]).reshape(3, 3)) +point_o = _me.Point('o') +m1 = _sm.symbols('m1') +particle_p1.mass = m1 +m2 = _sm.symbols('m2') +particle_p2.mass = m2 +mr = _sm.symbols('mr') +body_r.mass = mr +i1 = _sm.symbols('i1') +i2 = _sm.symbols('i2') +i3 = _sm.symbols('i3') +body_r.inertia = (_me.inertia(body_r_f, i1, i2, i3, 0, 0, 0), body_r_cm) +point_o.set_pos(particle_p1.point, c1*frame_a.x) +point_o.set_pos(particle_p2.point, c2*frame_a.y) +point_o.set_pos(body_r_cm, c3*frame_a.z) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p2.mass, particle_p2.point.pos_from(point_o), frame_a) +a = body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) + _me.inertia_of_point_mass(particle_p2.mass, particle_p2.point.pos_from(point_o), frame_a) + body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) + body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = body_r.inertia[0] +particle_p2.point.set_pos(particle_p1.point, c1*frame_a.x+c2*frame_a.y) +body_r_cm.set_pos(particle_p1.point, c3*frame_a.x) +body_r_cm.set_pos(particle_p2.point, c3*frame_a.y) +b = _me.functions.center_of_mass(point_o,particle_p1, particle_p2, body_r) +b = _me.functions.center_of_mass(point_o,particle_p1, body_r) +b = _me.functions.center_of_mass(particle_p1.point,particle_p1, particle_p2, body_r) +u1, u2, u3 = _me.dynamicsymbols('u1 u2 u3') +v = u1*frame_a.x+u2*frame_a.y+u3*frame_a.z +u = (v+c1*frame_a.x).normalize() +particle_p1.point.set_vel(frame_a, u1*frame_a.x) +a = particle_p1.point.partial_velocity(frame_a, u1) +m = particle_p1.mass+body_r.mass +m = particle_p2.mass +m = particle_p1.mass+particle_p2.mass+body_r.mass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al new file mode 100644 index 0000000000000000000000000000000000000000..df5c70f05b76fc215f829672e281491b0c96c6a6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al @@ -0,0 +1,54 @@ +% ruletest9.al +NEWTONIAN N +FRAMES A +A> = 0> +D>> = EXPRESS(1>>, A) + +POINTS PO{2} +PARTICLES P{2} +MOTIONVARIABLES' C{3}' +BODIES R +P_P1_PO2> = C1*A1> +V> = 2*P_P1_PO2> + C2*A2> + +W_A_N> = C3*A3> +V> = 2*W_A_N> + C2*A2> +W_R_N> = C3*A3> +V> = 2*W_R_N> + C2*A2> + +ALF_A_N> = DT(W_A_N>, A) +V> = 2*ALF_A_N> + C2*A2> + +V_P1_A> = C1*A1> + C3*A2> +A_RO_N> = C2*A2> +V_A> = CROSS(A_RO_N>, V_P1_A>) + +X_B_C> = V_A> +X_B_D> = 2*X_B_C> +A_B_C_D_E> = X_B_D>*2 + +A_B_C = 2*C1*C2*C3 +A_B_C += 2*C1 +A_B_C := 3*C1 + +MOTIONVARIABLES' Q{2}', U{2}' +Q1' = U1 +Q2' = U2 + +VARIABLES X'', Y'' +SPECIFIED YY +Y'' = X*X'^2 + 1 +YY = X*X'^2 + 1 + +M[1] = 2*X +M[2] = 2*Y +A = 2*M[1] + +M = [1,2,3;4,5,6;7,8,9] +M[1, 2] = 5 +A = M[1, 2]*2 + +FORCE_RO> = Q1*N1> +TORQUE_A> = Q2*N3> +FORCE_RO> = Q2*N2> +F> = FORCE_RO>*2 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py new file mode 100644 index 0000000000000000000000000000000000000000..09d8ae4ee8385bde5c38b946458a43c8ffdaa9b8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py @@ -0,0 +1,55 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +a = 0 +d = _me.inertia(frame_a, 1, 1, 1) +point_po1 = _me.Point('po1') +point_po2 = _me.Point('po2') +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +c1, c2, c3 = _me.dynamicsymbols('c1 c2 c3') +c1_d, c2_d, c3_d = _me.dynamicsymbols('c1_ c2_ c3_', 1) +body_r_cm = _me.Point('r_cm') +body_r_cm.set_vel(frame_n, 0) +body_r_f = _me.ReferenceFrame('r_f') +body_r = _me.RigidBody('r', body_r_cm, body_r_f, _sm.symbols('m'), (_me.outer(body_r_f.x,body_r_f.x),body_r_cm)) +point_po2.set_pos(particle_p1.point, c1*frame_a.x) +v = 2*point_po2.pos_from(particle_p1.point)+c2*frame_a.y +frame_a.set_ang_vel(frame_n, c3*frame_a.z) +v = 2*frame_a.ang_vel_in(frame_n)+c2*frame_a.y +body_r_f.set_ang_vel(frame_n, c3*frame_a.z) +v = 2*body_r_f.ang_vel_in(frame_n)+c2*frame_a.y +frame_a.set_ang_acc(frame_n, (frame_a.ang_vel_in(frame_n)).dt(frame_a)) +v = 2*frame_a.ang_acc_in(frame_n)+c2*frame_a.y +particle_p1.point.set_vel(frame_a, c1*frame_a.x+c3*frame_a.y) +body_r_cm.set_acc(frame_n, c2*frame_a.y) +v_a = _me.cross(body_r_cm.acc(frame_n), particle_p1.point.vel(frame_a)) +x_b_c = v_a +x_b_d = 2*x_b_c +a_b_c_d_e = x_b_d*2 +a_b_c = 2*c1*c2*c3 +a_b_c += 2*c1 +a_b_c = 3*c1 +q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') +q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +x_dd, y_dd = _me.dynamicsymbols('x_ y_', 2) +yy = _me.dynamicsymbols('yy') +yy = x*x_d**2+1 +m = _sm.Matrix([[0]]) +m[0] = 2*x +m = m.row_insert(m.shape[0], _sm.Matrix([[0]])) +m[m.shape[0]-1] = 2*y +a = 2*m[0] +m = _sm.Matrix([1,2,3,4,5,6,7,8,9]).reshape(3, 3) +m[0,1] = 5 +a = m[0, 1]*2 +force_ro = q1*frame_n.x +torque_a = q2*frame_n.z +force_ro = q1*frame_n.x + q2*frame_n.y +f = force_ro*2 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18d3d5301cb001c78fc4a9bc04b25aa36f282a93 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__init__.py @@ -0,0 +1 @@ +"""Used for translating C source code into a SymPy expression""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b76bf33a23ff470c4bc197a4237cc181fe73a17 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__pycache__/c_parser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__pycache__/c_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8fcef1d2f06c9492a50a90a3936ef6f1a3ad1df Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/__pycache__/c_parser.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/c_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/c_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7223f8351205272e803773589649fcf1902f15 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/c/c_parser.py @@ -0,0 +1,1059 @@ +from sympy.external import import_module +import os + +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +""" +This module contains all the necessary Classes and Function used to Parse C and +C++ code into SymPy expression +The module serves as a backend for SymPyExpression to parse C code +It is also dependent on Clang's AST and SymPy's Codegen AST. +The module only supports the features currently supported by the Clang and +codegen AST which will be updated as the development of codegen AST and this +module progresses. +You might find unexpected bugs and exceptions while using the module, feel free +to report them to the SymPy Issue Tracker + +Features Supported +================== + +- Variable Declarations (integers and reals) +- Assignment (using integer & floating literal and function calls) +- Function Definitions and Declaration +- Function Calls +- Compound statements, Return statements + +Notes +===== + +The module is dependent on an external dependency which needs to be installed +to use the features of this module. + +Clang: The C and C++ compiler which is used to extract an AST from the provided +C source code. + +References +========== + +.. [1] https://github.com/sympy/sympy/issues +.. [2] https://clang.llvm.org/docs/ +.. [3] https://clang.llvm.org/docs/IntroductionToTheClangAST.html + +""" + +if cin: + from sympy.codegen.ast import (Variable, Integer, Float, + FunctionPrototype, FunctionDefinition, FunctionCall, + none, Return, Assignment, intc, int8, int16, int64, + uint8, uint16, uint32, uint64, float32, float64, float80, + aug_assign, bool_, While, CodeBlock) + from sympy.codegen.cnodes import (PreDecrement, PostDecrement, + PreIncrement, PostIncrement) + from sympy.core import Add, Mod, Mul, Pow, Rel + from sympy.logic.boolalg import And, as_Boolean, Not, Or + from sympy.core.symbol import Symbol + from sympy.core.sympify import sympify + from sympy.logic.boolalg import (false, true) + import sys + import tempfile + + class BaseParser: + """Base Class for the C parser""" + + def __init__(self): + """Initializes the Base parser creating a Clang AST index""" + self.index = cin.Index.create() + + def diagnostics(self, out): + """Diagostics function for the Clang AST""" + for diag in self.tu.diagnostics: + # tu = translation unit + print('%s %s (line %s, col %s) %s' % ( + { + 4: 'FATAL', + 3: 'ERROR', + 2: 'WARNING', + 1: 'NOTE', + 0: 'IGNORED', + }[diag.severity], + diag.location.file, + diag.location.line, + diag.location.column, + diag.spelling + ), file=out) + + class CCodeConverter(BaseParser): + """The Code Convereter for Clang AST + + The converter object takes the C source code or file as input and + converts them to SymPy Expressions. + """ + + def __init__(self): + """Initializes the code converter""" + super().__init__() + self._py_nodes = [] + self._data_types = { + "void": { + cin.TypeKind.VOID: none + }, + "bool": { + cin.TypeKind.BOOL: bool_ + }, + "int": { + cin.TypeKind.SCHAR: int8, + cin.TypeKind.SHORT: int16, + cin.TypeKind.INT: intc, + cin.TypeKind.LONG: int64, + cin.TypeKind.UCHAR: uint8, + cin.TypeKind.USHORT: uint16, + cin.TypeKind.UINT: uint32, + cin.TypeKind.ULONG: uint64 + }, + "float": { + cin.TypeKind.FLOAT: float32, + cin.TypeKind.DOUBLE: float64, + cin.TypeKind.LONGDOUBLE: float80 + } + } + + def parse(self, filename, flags): + """Function to parse a file with C source code + + It takes the filename as an attribute and creates a Clang AST + Translation Unit parsing the file. + Then the transformation function is called on the translation unit, + whose results are collected into a list which is returned by the + function. + + Parameters + ========== + + filename : string + Path to the C file to be parsed + + flags: list + Arguments to be passed to Clang while parsing the C code + + Returns + ======= + + py_nodes: list + A list of SymPy AST nodes + + """ + filepath = os.path.abspath(filename) + self.tu = self.index.parse( + filepath, + args=flags, + options=cin.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD + ) + for child in self.tu.cursor.get_children(): + if child.kind == cin.CursorKind.VAR_DECL or child.kind == cin.CursorKind.FUNCTION_DECL: + self._py_nodes.append(self.transform(child)) + return self._py_nodes + + def parse_str(self, source, flags): + """Function to parse a string with C source code + + It takes the source code as an attribute, stores it in a temporary + file and creates a Clang AST Translation Unit parsing the file. + Then the transformation function is called on the translation unit, + whose results are collected into a list which is returned by the + function. + + Parameters + ========== + + source : string + A string containing the C source code to be parsed + + flags: list + Arguments to be passed to Clang while parsing the C code + + Returns + ======= + + py_nodes: list + A list of SymPy AST nodes + + """ + file = tempfile.NamedTemporaryFile(mode = 'w+', suffix = '.cpp') + file.write(source) + file.flush() + file.seek(0) + self.tu = self.index.parse( + file.name, + args=flags, + options=cin.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD + ) + file.close() + for child in self.tu.cursor.get_children(): + if child.kind == cin.CursorKind.VAR_DECL or child.kind == cin.CursorKind.FUNCTION_DECL: + self._py_nodes.append(self.transform(child)) + return self._py_nodes + + def transform(self, node): + """Transformation Function for Clang AST nodes + + It determines the kind of node and calls the respective + transformation function for that node. + + Raises + ====== + + NotImplementedError : if the transformation for the provided node + is not implemented + + """ + handler = getattr(self, 'transform_%s' % node.kind.name.lower(), None) + + if handler is None: + print( + "Ignoring node of type %s (%s)" % ( + node.kind, + ' '.join( + t.spelling for t in node.get_tokens()) + ), + file=sys.stderr + ) + + return handler(node) + + def transform_var_decl(self, node): + """Transformation Function for Variable Declaration + + Used to create nodes for variable declarations and assignments with + values or function call for the respective nodes in the clang AST + + Returns + ======= + + A variable node as Declaration, with the initial value if given + + Raises + ====== + + NotImplementedError : if called for data types not currently + implemented + + Notes + ===== + + The function currently supports following data types: + + Boolean: + bool, _Bool + + Integer: + 8-bit: signed char and unsigned char + 16-bit: short, short int, signed short, + signed short int, unsigned short, unsigned short int + 32-bit: int, signed int, unsigned int + 64-bit: long, long int, signed long, + signed long int, unsigned long, unsigned long int + + Floating point: + Single Precision: float + Double Precision: double + Extended Precision: long double + + """ + if node.type.kind in self._data_types["int"]: + type = self._data_types["int"][node.type.kind] + elif node.type.kind in self._data_types["float"]: + type = self._data_types["float"][node.type.kind] + elif node.type.kind in self._data_types["bool"]: + type = self._data_types["bool"][node.type.kind] + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + try: + children = node.get_children() + child = next(children) + + #ignoring namespace and type details for the variable + while child.kind == cin.CursorKind.NAMESPACE_REF or child.kind == cin.CursorKind.TYPE_REF: + child = next(children) + + val = self.transform(child) + + supported_rhs = [ + cin.CursorKind.INTEGER_LITERAL, + cin.CursorKind.FLOATING_LITERAL, + cin.CursorKind.UNEXPOSED_EXPR, + cin.CursorKind.BINARY_OPERATOR, + cin.CursorKind.PAREN_EXPR, + cin.CursorKind.UNARY_OPERATOR, + cin.CursorKind.CXX_BOOL_LITERAL_EXPR + ] + + if child.kind in supported_rhs: + if isinstance(val, str): + value = Symbol(val) + elif isinstance(val, bool): + if node.type.kind in self._data_types["int"]: + value = Integer(0) if val == False else Integer(1) + elif node.type.kind in self._data_types["float"]: + value = Float(0.0) if val == False else Float(1.0) + elif node.type.kind in self._data_types["bool"]: + value = sympify(val) + elif isinstance(val, (Integer, int, Float, float)): + if node.type.kind in self._data_types["int"]: + value = Integer(val) + elif node.type.kind in self._data_types["float"]: + value = Float(val) + elif node.type.kind in self._data_types["bool"]: + value = sympify(bool(val)) + else: + value = val + + return Variable( + node.spelling + ).as_Declaration( + type = type, + value = value + ) + + elif child.kind == cin.CursorKind.CALL_EXPR: + return Variable( + node.spelling + ).as_Declaration( + value = val + ) + + else: + raise NotImplementedError("Given " + "variable declaration \"{}\" " + "is not possible to parse yet!" + .format(" ".join( + t.spelling for t in node.get_tokens() + ) + )) + + except StopIteration: + return Variable( + node.spelling + ).as_Declaration( + type = type + ) + + def transform_function_decl(self, node): + """Transformation Function For Function Declaration + + Used to create nodes for function declarations and definitions for + the respective nodes in the clang AST + + Returns + ======= + + function : Codegen AST node + - FunctionPrototype node if function body is not present + - FunctionDefinition node if the function body is present + + + """ + + if node.result_type.kind in self._data_types["int"]: + ret_type = self._data_types["int"][node.result_type.kind] + elif node.result_type.kind in self._data_types["float"]: + ret_type = self._data_types["float"][node.result_type.kind] + elif node.result_type.kind in self._data_types["bool"]: + ret_type = self._data_types["bool"][node.result_type.kind] + elif node.result_type.kind in self._data_types["void"]: + ret_type = self._data_types["void"][node.result_type.kind] + else: + raise NotImplementedError("Only void, bool, int " + "and float are supported") + body = [] + param = [] + + # Subsequent nodes will be the parameters for the function. + for child in node.get_children(): + decl = self.transform(child) + if child.kind == cin.CursorKind.PARM_DECL: + param.append(decl) + elif child.kind == cin.CursorKind.COMPOUND_STMT: + for val in decl: + body.append(val) + else: + body.append(decl) + + if body == []: + function = FunctionPrototype( + return_type = ret_type, + name = node.spelling, + parameters = param + ) + else: + function = FunctionDefinition( + return_type = ret_type, + name = node.spelling, + parameters = param, + body = body + ) + return function + + def transform_parm_decl(self, node): + """Transformation function for Parameter Declaration + + Used to create parameter nodes for the required functions for the + respective nodes in the clang AST + + Returns + ======= + + param : Codegen AST Node + Variable node with the value and type of the variable + + Raises + ====== + + ValueError if multiple children encountered in the parameter node + + """ + if node.type.kind in self._data_types["int"]: + type = self._data_types["int"][node.type.kind] + elif node.type.kind in self._data_types["float"]: + type = self._data_types["float"][node.type.kind] + elif node.type.kind in self._data_types["bool"]: + type = self._data_types["bool"][node.type.kind] + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + try: + children = node.get_children() + child = next(children) + + # Any namespace nodes can be ignored + while child.kind in [cin.CursorKind.NAMESPACE_REF, + cin.CursorKind.TYPE_REF, + cin.CursorKind.TEMPLATE_REF]: + child = next(children) + + # If there is a child, it is the default value of the parameter. + lit = self.transform(child) + if node.type.kind in self._data_types["int"]: + val = Integer(lit) + elif node.type.kind in self._data_types["float"]: + val = Float(lit) + elif node.type.kind in self._data_types["bool"]: + val = sympify(bool(lit)) + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + + param = Variable( + node.spelling + ).as_Declaration( + type = type, + value = val + ) + except StopIteration: + param = Variable( + node.spelling + ).as_Declaration( + type = type + ) + + try: + self.transform(next(children)) + raise ValueError("Can't handle multiple children on parameter") + except StopIteration: + pass + + return param + + def transform_integer_literal(self, node): + """Transformation function for integer literal + + Used to get the value and type of the given integer literal. + + Returns + ======= + + val : list + List with two arguments type and Value + type contains the type of the integer + value contains the value stored in the variable + + Notes + ===== + + Only Base Integer type supported for now + + """ + try: + value = next(node.get_tokens()).spelling + except StopIteration: + # No tokens + value = node.literal + return int(value) + + def transform_floating_literal(self, node): + """Transformation function for floating literal + + Used to get the value and type of the given floating literal. + + Returns + ======= + + val : list + List with two arguments type and Value + type contains the type of float + value contains the value stored in the variable + + Notes + ===== + + Only Base Float type supported for now + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + # No tokens + value = node.literal + return float(value) + + def transform_string_literal(self, node): + #TODO: No string type in AST + #type = + #try: + # value = next(node.get_tokens()).spelling + #except (StopIteration, ValueError): + # No tokens + # value = node.literal + #val = [type, value] + #return val + pass + + def transform_character_literal(self, node): + """Transformation function for character literal + + Used to get the value of the given character literal. + + Returns + ======= + + val : int + val contains the ascii value of the character literal + + Notes + ===== + + Only for cases where character is assigned to a integer value, + since character literal is not in SymPy AST + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + # No tokens + value = node.literal + return ord(str(value[1])) + + def transform_cxx_bool_literal_expr(self, node): + """Transformation function for boolean literal + + Used to get the value of the given boolean literal. + + Returns + ======= + + value : bool + value contains the boolean value of the variable + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + value = node.literal + return True if value == 'true' else False + + def transform_unexposed_decl(self,node): + """Transformation function for unexposed declarations""" + pass + + def transform_unexposed_expr(self, node): + """Transformation function for unexposed expression + + Unexposed expressions are used to wrap float, double literals and + expressions + + Returns + ======= + + expr : Codegen AST Node + the result from the wrapped expression + + None : NoneType + No children are found for the node + + Raises + ====== + + ValueError if the expression contains multiple children + + """ + # Ignore unexposed nodes; pass whatever is the first + # (and should be only) child unaltered. + try: + children = node.get_children() + expr = self.transform(next(children)) + except StopIteration: + return None + + try: + next(children) + raise ValueError("Unexposed expression has > 1 children.") + except StopIteration: + pass + + return expr + + def transform_decl_ref_expr(self, node): + """Returns the name of the declaration reference""" + return node.spelling + + def transform_call_expr(self, node): + """Transformation function for a call expression + + Used to create function call nodes for the function calls present + in the C code + + Returns + ======= + + FunctionCall : Codegen AST Node + FunctionCall node with parameters if any parameters are present + + """ + param = [] + children = node.get_children() + child = next(children) + + while child.kind == cin.CursorKind.NAMESPACE_REF: + child = next(children) + while child.kind == cin.CursorKind.TYPE_REF: + child = next(children) + + first_child = self.transform(child) + try: + for child in children: + arg = self.transform(child) + if child.kind == cin.CursorKind.INTEGER_LITERAL: + param.append(Integer(arg)) + elif child.kind == cin.CursorKind.FLOATING_LITERAL: + param.append(Float(arg)) + else: + param.append(arg) + return FunctionCall(first_child, param) + + except StopIteration: + return FunctionCall(first_child) + + def transform_return_stmt(self, node): + """Returns the Return Node for a return statement""" + return Return(next(node.get_children()).spelling) + + def transform_compound_stmt(self, node): + """Transformation function for compound statements + + Returns + ======= + + expr : list + list of Nodes for the expressions present in the statement + + None : NoneType + if the compound statement is empty + + """ + expr = [] + children = node.get_children() + + for child in children: + expr.append(self.transform(child)) + return expr + + def transform_decl_stmt(self, node): + """Transformation function for declaration statements + + These statements are used to wrap different kinds of declararions + like variable or function declaration + The function calls the transformer function for the child of the + given node + + Returns + ======= + + statement : Codegen AST Node + contains the node returned by the children node for the type of + declaration + + Raises + ====== + + ValueError if multiple children present + + """ + try: + children = node.get_children() + statement = self.transform(next(children)) + except StopIteration: + pass + + try: + self.transform(next(children)) + raise ValueError("Don't know how to handle multiple statements") + except StopIteration: + pass + + return statement + + def transform_paren_expr(self, node): + """Transformation function for Parenthesized expressions + + Returns the result from its children nodes + + """ + return self.transform(next(node.get_children())) + + def transform_compound_assignment_operator(self, node): + """Transformation function for handling shorthand operators + + Returns + ======= + + augmented_assignment_expression: Codegen AST node + shorthand assignment expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If the shorthand operator for bitwise operators + (~=, ^=, &=, |=, <<=, >>=) is encountered + + """ + return self.transform_binary_operator(node) + + def transform_unary_operator(self, node): + """Transformation function for handling unary operators + + Returns + ======= + + unary_expression: Codegen AST node + simplified unary expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If dereferencing operator(*), address operator(&) or + bitwise NOT operator(~) is encountered + + """ + # supported operators list + operators_list = ['+', '-', '++', '--', '!'] + tokens = list(node.get_tokens()) + + # it can be either pre increment/decrement or any other operator from the list + if tokens[0].spelling in operators_list: + child = self.transform(next(node.get_children())) + # (decl_ref) e.g.; int a = ++b; or simply ++b; + if isinstance(child, str): + if tokens[0].spelling == '+': + return Symbol(child) + if tokens[0].spelling == '-': + return Mul(Symbol(child), -1) + if tokens[0].spelling == '++': + return PreIncrement(Symbol(child)) + if tokens[0].spelling == '--': + return PreDecrement(Symbol(child)) + if tokens[0].spelling == '!': + return Not(Symbol(child)) + # e.g.; int a = -1; or int b = -(1 + 2); + else: + if tokens[0].spelling == '+': + return child + if tokens[0].spelling == '-': + return Mul(child, -1) + if tokens[0].spelling == '!': + return Not(sympify(bool(child))) + + # it can be either post increment/decrement + # since variable name is obtained in token[0].spelling + elif tokens[1].spelling in ['++', '--']: + child = self.transform(next(node.get_children())) + if tokens[1].spelling == '++': + return PostIncrement(Symbol(child)) + if tokens[1].spelling == '--': + return PostDecrement(Symbol(child)) + else: + raise NotImplementedError("Dereferencing operator, " + "Address operator and bitwise NOT operator " + "have not been implemented yet!") + + def transform_binary_operator(self, node): + """Transformation function for handling binary operators + + Returns + ======= + + binary_expression: Codegen AST node + simplified binary expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If a bitwise operator or + unary operator(which is a child of any binary + operator in Clang AST) is encountered + + """ + # get all the tokens of assignment + # and store it in the tokens list + tokens = list(node.get_tokens()) + + # supported operators list + operators_list = ['+', '-', '*', '/', '%','=', + '>', '>=', '<', '<=', '==', '!=', '&&', '||', '+=', '-=', + '*=', '/=', '%='] + + # this stack will contain variable content + # and type of variable in the rhs + combined_variables_stack = [] + + # this stack will contain operators + # to be processed in the rhs + operators_stack = [] + + # iterate through every token + for token in tokens: + # token is either '(', ')' or + # any of the supported operators from the operator list + if token.kind == cin.TokenKind.PUNCTUATION: + + # push '(' to the operators stack + if token.spelling == '(': + operators_stack.append('(') + + elif token.spelling == ')': + # keep adding the expression to the + # combined variables stack unless + # '(' is found + while (operators_stack + and operators_stack[-1] != '('): + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation( + lhs, rhs, operator)) + + # pop '(' + operators_stack.pop() + + # token is an operator (supported) + elif token.spelling in operators_list: + while (operators_stack + and self.priority_of(token.spelling) + <= self.priority_of( + operators_stack[-1])): + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation( + lhs, rhs, operator)) + + # push current operator + operators_stack.append(token.spelling) + + # token is a bitwise operator + elif token.spelling in ['&', '|', '^', '<<', '>>']: + raise NotImplementedError( + "Bitwise operator has not been " + "implemented yet!") + + # token is a shorthand bitwise operator + elif token.spelling in ['&=', '|=', '^=', '<<=', + '>>=']: + raise NotImplementedError( + "Shorthand bitwise operator has not been " + "implemented yet!") + else: + raise NotImplementedError( + "Given token {} is not implemented yet!" + .format(token.spelling)) + + # token is an identifier(variable) + elif token.kind == cin.TokenKind.IDENTIFIER: + combined_variables_stack.append( + [token.spelling, 'identifier']) + + # token is a literal + elif token.kind == cin.TokenKind.LITERAL: + combined_variables_stack.append( + [token.spelling, 'literal']) + + # token is a keyword, either true or false + elif (token.kind == cin.TokenKind.KEYWORD + and token.spelling in ['true', 'false']): + combined_variables_stack.append( + [token.spelling, 'boolean']) + else: + raise NotImplementedError( + "Given token {} is not implemented yet!" + .format(token.spelling)) + + # process remaining operators + while operators_stack: + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation(lhs, rhs, operator)) + + return combined_variables_stack[-1][0] + + def priority_of(self, op): + """To get the priority of given operator""" + if op in ['=', '+=', '-=', '*=', '/=', '%=']: + return 1 + if op in ['&&', '||']: + return 2 + if op in ['<', '<=', '>', '>=', '==', '!=']: + return 3 + if op in ['+', '-']: + return 4 + if op in ['*', '/', '%']: + return 5 + return 0 + + def perform_operation(self, lhs, rhs, op): + """Performs operation supported by the SymPy core + + Returns + ======= + + combined_variable: list + contains variable content and type of variable + + """ + lhs_value = self.get_expr_for_operand(lhs) + rhs_value = self.get_expr_for_operand(rhs) + if op == '+': + return [Add(lhs_value, rhs_value), 'expr'] + if op == '-': + return [Add(lhs_value, -rhs_value), 'expr'] + if op == '*': + return [Mul(lhs_value, rhs_value), 'expr'] + if op == '/': + return [Mul(lhs_value, Pow(rhs_value, Integer(-1))), 'expr'] + if op == '%': + return [Mod(lhs_value, rhs_value), 'expr'] + if op in ['<', '<=', '>', '>=', '==', '!=']: + return [Rel(lhs_value, rhs_value, op), 'expr'] + if op == '&&': + return [And(as_Boolean(lhs_value), as_Boolean(rhs_value)), 'expr'] + if op == '||': + return [Or(as_Boolean(lhs_value), as_Boolean(rhs_value)), 'expr'] + if op == '=': + return [Assignment(Variable(lhs_value), rhs_value), 'expr'] + if op in ['+=', '-=', '*=', '/=', '%=']: + return [aug_assign(Variable(lhs_value), op[0], rhs_value), 'expr'] + + def get_expr_for_operand(self, combined_variable): + """Gives out SymPy Codegen AST node + + AST node returned is corresponding to + combined variable passed.Combined variable contains + variable content and type of variable + + """ + if combined_variable[1] == 'identifier': + return Symbol(combined_variable[0]) + if combined_variable[1] == 'literal': + if '.' in combined_variable[0]: + return Float(float(combined_variable[0])) + else: + return Integer(int(combined_variable[0])) + if combined_variable[1] == 'expr': + return combined_variable[0] + if combined_variable[1] == 'boolean': + return true if combined_variable[0] == 'true' else false + + def transform_null_stmt(self, node): + """Handles Null Statement and returns None""" + return none + + def transform_while_stmt(self, node): + """Transformation function for handling while statement + + Returns + ======= + + while statement : Codegen AST Node + contains the while statement node having condition and + statement block + + """ + children = node.get_children() + + condition = self.transform(next(children)) + statements = self.transform(next(children)) + + if isinstance(statements, list): + statement_block = CodeBlock(*statements) + else: + statement_block = CodeBlock(statements) + + return While(condition, statement_block) + + + +else: + class CCodeConverter(): # type: ignore + def __init__(self, *args, **kwargs): + raise ImportError("Module not Installed") + + +def parse_c(source): + """Function for converting a C source code + + The function reads the source code present in the given file and parses it + to give out SymPy Expressions + + Returns + ======= + + src : list + List of Python expression strings + + """ + converter = CCodeConverter() + if os.path.exists(source): + src = converter.parse(source, flags = []) + else: + src = converter.parse_str(source, flags = []) + return src diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c65e37cf3de2dddbcee0fa5c7eeac2fdc9f685db --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__init__.py @@ -0,0 +1 @@ +"""Used for translating Fortran source code into a SymPy expression. """ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb376fd4d7dec206066b397f9c8e7cecefe695b3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__pycache__/fortran_parser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__pycache__/fortran_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b55cc74f9428026d5d1b49331143be9d9291986 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/__pycache__/fortran_parser.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/fortran_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/fortran_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..504249f6119a59a90d91c5e989f893cffe20e643 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/fortran/fortran_parser.py @@ -0,0 +1,347 @@ +from sympy.external import import_module + +lfortran = import_module('lfortran') + +if lfortran: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Return, FunctionDefinition, Assignment) + from sympy.core import Add, Mul, Integer, Float + from sympy.core.symbol import Symbol + + asr_mod = lfortran.asr + asr = lfortran.asr.asr + src_to_ast = lfortran.ast.src_to_ast + ast_to_asr = lfortran.semantic.ast_to_asr.ast_to_asr + + """ + This module contains all the necessary Classes and Function used to Parse + Fortran code into SymPy expression + + The module and its API are currently under development and experimental. + It is also dependent on LFortran for the ASR that is converted to SymPy syntax + which is also under development. + The module only supports the features currently supported by the LFortran ASR + which will be updated as the development of LFortran and this module progresses + + You might find unexpected bugs and exceptions while using the module, feel free + to report them to the SymPy Issue Tracker + + The API for the module might also change while in development if better and + more effective ways are discovered for the process + + Features Supported + ================== + + - Variable Declarations (integers and reals) + - Function Definitions + - Assignments and Basic Binary Operations + + + Notes + ===== + + The module depends on an external dependency + + LFortran : Required to parse Fortran source code into ASR + + + References + ========== + + .. [1] https://github.com/sympy/sympy/issues + .. [2] https://gitlab.com/lfortran/lfortran + .. [3] https://docs.lfortran.org/ + + """ + + + class ASR2PyVisitor(asr.ASTVisitor): # type: ignore + """ + Visitor Class for LFortran ASR + + It is a Visitor class derived from asr.ASRVisitor which visits all the + nodes of the LFortran ASR and creates corresponding AST node for each + ASR node + + """ + + def __init__(self): + """Initialize the Parser""" + self._py_ast = [] + + def visit_TranslationUnit(self, node): + """ + Function to visit all the elements of the Translation Unit + created by LFortran ASR + """ + for s in node.global_scope.symbols: + sym = node.global_scope.symbols[s] + self.visit(sym) + for item in node.items: + self.visit(item) + + def visit_Assignment(self, node): + """Visitor Function for Assignment + + Visits each Assignment is the LFortran ASR and creates corresponding + assignment for SymPy. + + Notes + ===== + + The function currently only supports variable assignment and binary + operation assignments of varying multitudes. Any type of numberS or + array is not supported. + + Raises + ====== + + NotImplementedError() when called for Numeric assignments or Arrays + + """ + # TODO: Arithmetic Assignment + if isinstance(node.target, asr.Variable): + target = node.target + value = node.value + if isinstance(value, asr.Variable): + new_node = Assignment( + Variable( + target.name + ), + Variable( + value.name + ) + ) + elif (type(value) == asr.BinOp): + exp_ast = call_visitor(value) + for expr in exp_ast: + new_node = Assignment( + Variable(target.name), + expr + ) + else: + raise NotImplementedError("Numeric assignments not supported") + else: + raise NotImplementedError("Arrays not supported") + self._py_ast.append(new_node) + + def visit_BinOp(self, node): + """Visitor Function for Binary Operations + + Visits each binary operation present in the LFortran ASR like addition, + subtraction, multiplication, division and creates the corresponding + operation node in SymPy's AST + + In case of more than one binary operations, the function calls the + call_visitor() function on the child nodes of the binary operations + recursively until all the operations have been processed. + + Notes + ===== + + The function currently only supports binary operations with Variables + or other binary operations. Numerics are not supported as of yet. + + Raises + ====== + + NotImplementedError() when called for Numeric assignments + + """ + # TODO: Integer Binary Operations + op = node.op + lhs = node.left + rhs = node.right + + if (type(lhs) == asr.Variable): + left_value = Symbol(lhs.name) + elif(type(lhs) == asr.BinOp): + l_exp_ast = call_visitor(lhs) + for exp in l_exp_ast: + left_value = exp + else: + raise NotImplementedError("Numbers Currently not supported") + + if (type(rhs) == asr.Variable): + right_value = Symbol(rhs.name) + elif(type(rhs) == asr.BinOp): + r_exp_ast = call_visitor(rhs) + for exp in r_exp_ast: + right_value = exp + else: + raise NotImplementedError("Numbers Currently not supported") + + if isinstance(op, asr.Add): + new_node = Add(left_value, right_value) + elif isinstance(op, asr.Sub): + new_node = Add(left_value, -right_value) + elif isinstance(op, asr.Div): + new_node = Mul(left_value, 1/right_value) + elif isinstance(op, asr.Mul): + new_node = Mul(left_value, right_value) + + self._py_ast.append(new_node) + + def visit_Variable(self, node): + """Visitor Function for Variable Declaration + + Visits each variable declaration present in the ASR and creates a + Symbol declaration for each variable + + Notes + ===== + + The functions currently only support declaration of integer and + real variables. Other data types are still under development. + + Raises + ====== + + NotImplementedError() when called for unsupported data types + + """ + if isinstance(node.type, asr.Integer): + var_type = IntBaseType(String('integer')) + value = Integer(0) + elif isinstance(node.type, asr.Real): + var_type = FloatBaseType(String('real')) + value = Float(0.0) + else: + raise NotImplementedError("Data type not supported") + + if not (node.intent == 'in'): + new_node = Variable( + node.name + ).as_Declaration( + type = var_type, + value = value + ) + self._py_ast.append(new_node) + + def visit_Sequence(self, seq): + """Visitor Function for code sequence + + Visits a code sequence/ block and calls the visitor function on all the + children of the code block to create corresponding code in python + + """ + if seq is not None: + for node in seq: + self._py_ast.append(call_visitor(node)) + + def visit_Num(self, node): + """Visitor Function for Numbers in ASR + + This function is currently under development and will be updated + with improvements in the LFortran ASR + + """ + # TODO:Numbers when the LFortran ASR is updated + # self._py_ast.append(Integer(node.n)) + pass + + def visit_Function(self, node): + """Visitor Function for function Definitions + + Visits each function definition present in the ASR and creates a + function definition node in the Python AST with all the elements of the + given function + + The functions declare all the variables required as SymPy symbols in + the function before the function definition + + This function also the call_visior_function to parse the contents of + the function body + + """ + # TODO: Return statement, variable declaration + fn_args = [Variable(arg_iter.name) for arg_iter in node.args] + fn_body = [] + fn_name = node.name + for i in node.body: + fn_ast = call_visitor(i) + try: + fn_body_expr = fn_ast + except UnboundLocalError: + fn_body_expr = [] + for sym in node.symtab.symbols: + decl = call_visitor(node.symtab.symbols[sym]) + for symbols in decl: + fn_body.append(symbols) + for elem in fn_body_expr: + fn_body.append(elem) + fn_body.append( + Return( + Variable( + node.return_var.name + ) + ) + ) + if isinstance(node.return_var.type, asr.Integer): + ret_type = IntBaseType(String('integer')) + elif isinstance(node.return_var.type, asr.Real): + ret_type = FloatBaseType(String('real')) + else: + raise NotImplementedError("Data type not supported") + new_node = FunctionDefinition( + return_type = ret_type, + name = fn_name, + parameters = fn_args, + body = fn_body + ) + self._py_ast.append(new_node) + + def ret_ast(self): + """Returns the AST nodes""" + return self._py_ast +else: + class ASR2PyVisitor(): # type: ignore + def __init__(self, *args, **kwargs): + raise ImportError('lfortran not available') + +def call_visitor(fort_node): + """Calls the AST Visitor on the Module + + This function is used to call the AST visitor for a program or module + It imports all the required modules and calls the visit() function + on the given node + + Parameters + ========== + + fort_node : LFortran ASR object + Node for the operation for which the NodeVisitor is called + + Returns + ======= + + res_ast : list + list of SymPy AST Nodes + + """ + v = ASR2PyVisitor() + v.visit(fort_node) + res_ast = v.ret_ast() + return res_ast + + +def src_to_sympy(src): + """Wrapper function to convert the given Fortran source code to SymPy Expressions + + Parameters + ========== + + src : string + A string with the Fortran source code + + Returns + ======= + + py_src : string + A string with the Python source code compatible with SymPy + + """ + a_ast = src_to_ast(src, translation_unit=False) + a = ast_to_asr(a_ast) + py_src = call_visitor(a) + return py_src diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/LICENSE.txt b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..6bbfda911b2afada41a568218e31a6502dc68f44 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/LICENSE.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright 2016, latex2sympy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/LaTeX.g4 b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/LaTeX.g4 new file mode 100644 index 0000000000000000000000000000000000000000..fc2c30f9817931e2060b549a39f98a6a4f9cb1f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/LaTeX.g4 @@ -0,0 +1,312 @@ +/* + ANTLR4 LaTeX Math Grammar + + Ported from latex2sympy by @augustt198 https://github.com/augustt198/latex2sympy See license in + LICENSE.txt + */ + +/* + After changing this file, it is necessary to run `python setup.py antlr` in the root directory of + the repository. This will regenerate the code in `sympy/parsing/latex/_antlr/*.py`. + */ + +grammar LaTeX; + +options { + language = Python3; +} + +WS: [ \t\r\n]+ -> skip; +THINSPACE: ('\\,' | '\\thinspace') -> skip; +MEDSPACE: ('\\:' | '\\medspace') -> skip; +THICKSPACE: ('\\;' | '\\thickspace') -> skip; +QUAD: '\\quad' -> skip; +QQUAD: '\\qquad' -> skip; +NEGTHINSPACE: ('\\!' | '\\negthinspace') -> skip; +NEGMEDSPACE: '\\negmedspace' -> skip; +NEGTHICKSPACE: '\\negthickspace' -> skip; +CMD_LEFT: '\\left' -> skip; +CMD_RIGHT: '\\right' -> skip; + +IGNORE: + ( + '\\vrule' + | '\\vcenter' + | '\\vbox' + | '\\vskip' + | '\\vspace' + | '\\hfil' + | '\\*' + | '\\-' + | '\\.' + | '\\/' + | '\\"' + | '\\(' + | '\\=' + ) -> skip; + +ADD: '+'; +SUB: '-'; +MUL: '*'; +DIV: '/'; + +L_PAREN: '('; +R_PAREN: ')'; +L_BRACE: '{'; +R_BRACE: '}'; +L_BRACE_LITERAL: '\\{'; +R_BRACE_LITERAL: '\\}'; +L_BRACKET: '['; +R_BRACKET: ']'; + +BAR: '|'; + +R_BAR: '\\right|'; +L_BAR: '\\left|'; + +L_ANGLE: '\\langle'; +R_ANGLE: '\\rangle'; +FUNC_LIM: '\\lim'; +LIM_APPROACH_SYM: + '\\to' + | '\\rightarrow' + | '\\Rightarrow' + | '\\longrightarrow' + | '\\Longrightarrow'; +FUNC_INT: + '\\int' + | '\\int\\limits'; +FUNC_SUM: '\\sum'; +FUNC_PROD: '\\prod'; + +FUNC_EXP: '\\exp'; +FUNC_LOG: '\\log'; +FUNC_LG: '\\lg'; +FUNC_LN: '\\ln'; +FUNC_SIN: '\\sin'; +FUNC_COS: '\\cos'; +FUNC_TAN: '\\tan'; +FUNC_CSC: '\\csc'; +FUNC_SEC: '\\sec'; +FUNC_COT: '\\cot'; + +FUNC_ARCSIN: '\\arcsin'; +FUNC_ARCCOS: '\\arccos'; +FUNC_ARCTAN: '\\arctan'; +FUNC_ARCCSC: '\\arccsc'; +FUNC_ARCSEC: '\\arcsec'; +FUNC_ARCCOT: '\\arccot'; + +FUNC_SINH: '\\sinh'; +FUNC_COSH: '\\cosh'; +FUNC_TANH: '\\tanh'; +FUNC_ARSINH: '\\arsinh'; +FUNC_ARCOSH: '\\arcosh'; +FUNC_ARTANH: '\\artanh'; + +L_FLOOR: '\\lfloor'; +R_FLOOR: '\\rfloor'; +L_CEIL: '\\lceil'; +R_CEIL: '\\rceil'; + +FUNC_SQRT: '\\sqrt'; +FUNC_OVERLINE: '\\overline'; + +CMD_TIMES: '\\times'; +CMD_CDOT: '\\cdot'; +CMD_DIV: '\\div'; +CMD_FRAC: + '\\frac' + | '\\dfrac' + | '\\tfrac'; +CMD_BINOM: '\\binom'; +CMD_DBINOM: '\\dbinom'; +CMD_TBINOM: '\\tbinom'; + +CMD_MATHIT: '\\mathit'; + +UNDERSCORE: '_'; +CARET: '^'; +COLON: ':'; + +fragment WS_CHAR: [ \t\r\n]; +DIFFERENTIAL: 'd' WS_CHAR*? ([a-zA-Z] | '\\' [a-zA-Z]+); + +LETTER: [a-zA-Z]; +DIGIT: [0-9]; + +EQUAL: (('&' WS_CHAR*?)? '=') | ('=' (WS_CHAR*? '&')?); +NEQ: '\\neq'; + +LT: '<'; +LTE: ('\\leq' | '\\le' | LTE_Q | LTE_S); +LTE_Q: '\\leqq'; +LTE_S: '\\leqslant'; + +GT: '>'; +GTE: ('\\geq' | '\\ge' | GTE_Q | GTE_S); +GTE_Q: '\\geqq'; +GTE_S: '\\geqslant'; + +BANG: '!'; + +SINGLE_QUOTES: '\''+; + +SYMBOL: '\\' [a-zA-Z]+; + +math: relation; + +relation: + relation (EQUAL | LT | LTE | GT | GTE | NEQ) relation + | expr; + +equality: expr EQUAL expr; + +expr: additive; + +additive: additive (ADD | SUB) additive | mp; + +// mult part +mp: + mp (MUL | CMD_TIMES | CMD_CDOT | DIV | CMD_DIV | COLON) mp + | unary; + +mp_nofunc: + mp_nofunc ( + MUL + | CMD_TIMES + | CMD_CDOT + | DIV + | CMD_DIV + | COLON + ) mp_nofunc + | unary_nofunc; + +unary: (ADD | SUB) unary | postfix+; + +unary_nofunc: + (ADD | SUB) unary_nofunc + | postfix postfix_nofunc*; + +postfix: exp postfix_op*; +postfix_nofunc: exp_nofunc postfix_op*; +postfix_op: BANG | eval_at; + +eval_at: + BAR (eval_at_sup | eval_at_sub | eval_at_sup eval_at_sub); + +eval_at_sub: UNDERSCORE L_BRACE (expr | equality) R_BRACE; + +eval_at_sup: CARET L_BRACE (expr | equality) R_BRACE; + +exp: exp CARET (atom | L_BRACE expr R_BRACE) subexpr? | comp; + +exp_nofunc: + exp_nofunc CARET (atom | L_BRACE expr R_BRACE) subexpr? + | comp_nofunc; + +comp: + group + | abs_group + | func + | atom + | floor + | ceil; + +comp_nofunc: + group + | abs_group + | atom + | floor + | ceil; + +group: + L_PAREN expr R_PAREN + | L_BRACKET expr R_BRACKET + | L_BRACE expr R_BRACE + | L_BRACE_LITERAL expr R_BRACE_LITERAL; + +abs_group: BAR expr BAR; + +number: DIGIT+ (',' DIGIT DIGIT DIGIT)* ('.' DIGIT+)?; + +atom: (LETTER | SYMBOL) (subexpr? SINGLE_QUOTES? | SINGLE_QUOTES? subexpr?) + | number + | DIFFERENTIAL + | mathit + | frac + | binom + | bra + | ket; + +bra: L_ANGLE expr (R_BAR | BAR); +ket: (L_BAR | BAR) expr R_ANGLE; + +mathit: CMD_MATHIT L_BRACE mathit_text R_BRACE; +mathit_text: LETTER*; + +frac: CMD_FRAC (upperd = DIGIT | L_BRACE upper = expr R_BRACE) + (lowerd = DIGIT | L_BRACE lower = expr R_BRACE); + +binom: + (CMD_BINOM | CMD_DBINOM | CMD_TBINOM) L_BRACE n = expr R_BRACE L_BRACE k = expr R_BRACE; + +floor: L_FLOOR val = expr R_FLOOR; +ceil: L_CEIL val = expr R_CEIL; + +func_normal: + FUNC_EXP + | FUNC_LOG + | FUNC_LG + | FUNC_LN + | FUNC_SIN + | FUNC_COS + | FUNC_TAN + | FUNC_CSC + | FUNC_SEC + | FUNC_COT + | FUNC_ARCSIN + | FUNC_ARCCOS + | FUNC_ARCTAN + | FUNC_ARCCSC + | FUNC_ARCSEC + | FUNC_ARCCOT + | FUNC_SINH + | FUNC_COSH + | FUNC_TANH + | FUNC_ARSINH + | FUNC_ARCOSH + | FUNC_ARTANH; + +func: + func_normal (subexpr? supexpr? | supexpr? subexpr?) ( + L_PAREN func_arg R_PAREN + | func_arg_noparens + ) + | (LETTER | SYMBOL) (subexpr? SINGLE_QUOTES? | SINGLE_QUOTES? subexpr?) // e.g. f(x), f_1'(x) + L_PAREN args R_PAREN + | FUNC_INT (subexpr supexpr | supexpr subexpr)? ( + additive? DIFFERENTIAL + | frac + | additive + ) + | FUNC_SQRT (L_BRACKET root = expr R_BRACKET)? L_BRACE base = expr R_BRACE + | FUNC_OVERLINE L_BRACE base = expr R_BRACE + | (FUNC_SUM | FUNC_PROD) (subeq supexpr | supexpr subeq) mp + | FUNC_LIM limit_sub mp; + +args: (expr ',' args) | expr; + +limit_sub: + UNDERSCORE L_BRACE (LETTER | SYMBOL) LIM_APPROACH_SYM expr ( + CARET ((L_BRACE (ADD | SUB) R_BRACE) | ADD | SUB) + )? R_BRACE; + +func_arg: expr | (expr ',' func_arg); +func_arg_noparens: mp_nofunc; + +subexpr: UNDERSCORE (atom | L_BRACE expr R_BRACE); +supexpr: CARET (atom | L_BRACE expr R_BRACE); + +subeq: UNDERSCORE L_BRACE equality R_BRACE; +supeq: UNDERSCORE L_BRACE equality R_BRACE; diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9466d37b8b06f1f292c73f975e44d21c96da10d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__init__.py @@ -0,0 +1,204 @@ +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on +from re import compile as rcompile + +from sympy.parsing.latex.lark import LarkLaTeXParser, TransformToSymPyExpr, parse_latex_lark # noqa + +from .errors import LaTeXParsingError # noqa + + +IGNORE_L = r"\s*[{]*\s*" +IGNORE_R = r"\s*[}]*\s*" +NO_LEFT = r"(? len(latex_str): + e = len(latex_str) + eellipsis = "" + + if x[3] in END_DELIM_REPR: + err = (f"Extra '{x[2]}' at index {x[0]} or " + "missing corresponding " + f"'{BEGIN_DELIM_REPR[MATRIX_DELIMS_INV[x[3]]]}' " + f"in LaTeX string: {sellipsis}{latex_str[s:e]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + + if x[7] is None: + err = (f"Extra '{x[2]}' at index {x[0]} or " + "missing corresponding " + f"'{END_DELIM_REPR[MATRIX_DELIMS[x[3]]]}' " + f"in LaTeX string: {sellipsis}{latex_str[s:e]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + + correct_end_regex = MATRIX_DELIMS[x[3]] + sellipsis = "..." if x[0] > 0 else "" + eellipsis = "..." if x[5] < len(latex_str) else "" + if x[7] != correct_end_regex: + err = ("Expected " + f"'{END_DELIM_REPR[correct_end_regex]}' " + f"to close the '{x[2]}' at index {x[0]} but " + f"found '{x[6]}' at index {x[4]} of LaTeX " + f"string instead: {sellipsis}{latex_str[x[0]:x[5]]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + +__doctest_requires__ = {('parse_latex',): ['antlr4', 'lark']} + + +@doctest_depends_on(modules=('antlr4', 'lark')) +def parse_latex(s, strict=False, backend="antlr"): + r"""Converts the input LaTeX string ``s`` to a SymPy ``Expr``. + + Parameters + ========== + + s : str + The LaTeX string to parse. In Python source containing LaTeX, + *raw strings* (denoted with ``r"``, like this one) are preferred, + as LaTeX makes liberal use of the ``\`` character, which would + trigger escaping in normal Python strings. + backend : str, optional + Currently, there are two backends supported: ANTLR, and Lark. + The default setting is to use the ANTLR backend, which can be + changed to Lark if preferred. + + Use ``backend="antlr"`` for the ANTLR-based parser, and + ``backend="lark"`` for the Lark-based parser. + + The ``backend`` option is case-sensitive, and must be in + all lowercase. + strict : bool, optional + This option is only available with the ANTLR backend. + + If True, raise an exception if the string cannot be parsed as + valid LaTeX. If False, try to recover gracefully from common + mistakes. + + Examples + ======== + + >>> from sympy.parsing.latex import parse_latex + >>> expr = parse_latex(r"\frac {1 + \sqrt {\a}} {\b}") + >>> expr + (sqrt(a) + 1)/b + >>> expr.evalf(4, subs=dict(a=5, b=2)) + 1.618 + >>> func = parse_latex(r"\int_1^\alpha \dfrac{\mathrm{d}t}{t}", backend="lark") + >>> func.evalf(subs={"alpha": 2}) + 0.693147180559945 + """ + + check_matrix_delimiters(s) + + if backend == "antlr": + _latex = import_module( + 'sympy.parsing.latex._parse_latex_antlr', + import_kwargs={'fromlist': ['X']}) + + if _latex is not None: + return _latex.parse_latex(s, strict) + elif backend == "lark": + return parse_latex_lark(s) + else: + raise NotImplementedError(f"Using the '{backend}' backend in the LaTeX" \ + " parser is not supported.") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce5f35676c21377400e74d832647dbe4a89cdb5e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/_build_latex_antlr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/_build_latex_antlr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b34ec3f9a4af4649193ff187d939890c2db6d13 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/_build_latex_antlr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/_parse_latex_antlr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/_parse_latex_antlr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a59df8fd409bd4f90f868b7c23f13221cc32aada Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/_parse_latex_antlr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/errors.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/errors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a72bd98051f730ece6e5a6ff5455bcec69decff Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/__pycache__/errors.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d690e1eb8631ee7731fc1875769d3a4704a1743 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__init__.py @@ -0,0 +1,9 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4928df7a777daf3fbe5a93c79540173c7b593358 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexlexer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexlexer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a30ceb581c77f321be4a20716bbe9a89bd7e050 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexlexer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/latexlexer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/latexlexer.py new file mode 100644 index 0000000000000000000000000000000000000000..46ca959736c967782eef360b9b3268ccd0be0979 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/latexlexer.py @@ -0,0 +1,512 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + + +def serializedATN(): + return [ + 4,0,91,911,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, + 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, + 19,2,20,7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2, + 26,7,26,2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7, + 32,2,33,7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2, + 39,7,39,2,40,7,40,2,41,7,41,2,42,7,42,2,43,7,43,2,44,7,44,2,45,7, + 45,2,46,7,46,2,47,7,47,2,48,7,48,2,49,7,49,2,50,7,50,2,51,7,51,2, + 52,7,52,2,53,7,53,2,54,7,54,2,55,7,55,2,56,7,56,2,57,7,57,2,58,7, + 58,2,59,7,59,2,60,7,60,2,61,7,61,2,62,7,62,2,63,7,63,2,64,7,64,2, + 65,7,65,2,66,7,66,2,67,7,67,2,68,7,68,2,69,7,69,2,70,7,70,2,71,7, + 71,2,72,7,72,2,73,7,73,2,74,7,74,2,75,7,75,2,76,7,76,2,77,7,77,2, + 78,7,78,2,79,7,79,2,80,7,80,2,81,7,81,2,82,7,82,2,83,7,83,2,84,7, + 84,2,85,7,85,2,86,7,86,2,87,7,87,2,88,7,88,2,89,7,89,2,90,7,90,2, + 91,7,91,1,0,1,0,1,1,1,1,1,2,4,2,191,8,2,11,2,12,2,192,1,2,1,2,1, + 3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,3,3,209,8,3,1,3,1, + 3,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,3,4,224,8,4,1,4,1, + 4,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,3,5,241,8, + 5,1,5,1,5,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,7,1,7,1,7,1,7,1,7,1, + 7,1,7,1,7,1,7,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1, + 8,1,8,1,8,3,8,277,8,8,1,8,1,8,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1, + 9,1,9,1,9,1,9,1,9,1,9,1,9,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10, + 1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,11,1,11,1,11,1,11, + 1,11,1,11,1,11,1,11,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,3,13, + 381,8,13,1,13,1,13,1,14,1,14,1,15,1,15,1,16,1,16,1,17,1,17,1,18, + 1,18,1,19,1,19,1,20,1,20,1,21,1,21,1,22,1,22,1,22,1,23,1,23,1,23, + 1,24,1,24,1,25,1,25,1,26,1,26,1,27,1,27,1,27,1,27,1,27,1,27,1,27, + 1,27,1,28,1,28,1,28,1,28,1,28,1,28,1,28,1,29,1,29,1,29,1,29,1,29, + 1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,30,1,30,1,30,1,30,1,31,1,31, + 1,31,1,31,1,31,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,3,32,504,8,32,1,33,1,33,1,33,1,33, + 1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,3,33,521, + 8,33,1,34,1,34,1,34,1,34,1,34,1,35,1,35,1,35,1,35,1,35,1,35,1,36, + 1,36,1,36,1,36,1,36,1,37,1,37,1,37,1,37,1,37,1,38,1,38,1,38,1,38, + 1,39,1,39,1,39,1,39,1,40,1,40,1,40,1,40,1,40,1,41,1,41,1,41,1,41, + 1,41,1,42,1,42,1,42,1,42,1,42,1,43,1,43,1,43,1,43,1,43,1,44,1,44, + 1,44,1,44,1,44,1,45,1,45,1,45,1,45,1,45,1,46,1,46,1,46,1,46,1,46, + 1,46,1,46,1,46,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,48,1,48, + 1,48,1,48,1,48,1,48,1,48,1,48,1,49,1,49,1,49,1,49,1,49,1,49,1,49, + 1,49,1,50,1,50,1,50,1,50,1,50,1,50,1,50,1,50,1,51,1,51,1,51,1,51, + 1,51,1,51,1,51,1,51,1,52,1,52,1,52,1,52,1,52,1,52,1,53,1,53,1,53, + 1,53,1,53,1,53,1,54,1,54,1,54,1,54,1,54,1,54,1,55,1,55,1,55,1,55, + 1,55,1,55,1,55,1,55,1,56,1,56,1,56,1,56,1,56,1,56,1,56,1,56,1,57, + 1,57,1,57,1,57,1,57,1,57,1,57,1,57,1,58,1,58,1,58,1,58,1,58,1,58, + 1,58,1,58,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,60,1,60,1,60, + 1,60,1,60,1,60,1,60,1,61,1,61,1,61,1,61,1,61,1,61,1,61,1,62,1,62, + 1,62,1,62,1,62,1,62,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63, + 1,63,1,64,1,64,1,64,1,64,1,64,1,64,1,64,1,65,1,65,1,65,1,65,1,65, + 1,65,1,66,1,66,1,66,1,66,1,66,1,67,1,67,1,67,1,67,1,67,1,67,1,67, + 1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,3,67,753,8,67, + 1,68,1,68,1,68,1,68,1,68,1,68,1,68,1,69,1,69,1,69,1,69,1,69,1,69, + 1,69,1,69,1,70,1,70,1,70,1,70,1,70,1,70,1,70,1,70,1,71,1,71,1,71, + 1,71,1,71,1,71,1,71,1,71,1,72,1,72,1,73,1,73,1,74,1,74,1,75,1,75, + 1,76,1,76,5,76,796,8,76,10,76,12,76,799,9,76,1,76,1,76,1,76,4,76, + 804,8,76,11,76,12,76,805,3,76,808,8,76,1,77,1,77,1,78,1,78,1,79, + 1,79,5,79,816,8,79,10,79,12,79,819,9,79,3,79,821,8,79,1,79,1,79, + 1,79,5,79,826,8,79,10,79,12,79,829,9,79,1,79,3,79,832,8,79,3,79, + 834,8,79,1,80,1,80,1,80,1,80,1,80,1,81,1,81,1,82,1,82,1,82,1,82, + 1,82,1,82,1,82,1,82,1,82,3,82,852,8,82,1,83,1,83,1,83,1,83,1,83, + 1,83,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,85,1,85, + 1,86,1,86,1,86,1,86,1,86,1,86,1,86,1,86,1,86,3,86,881,8,86,1,87, + 1,87,1,87,1,87,1,87,1,87,1,88,1,88,1,88,1,88,1,88,1,88,1,88,1,88, + 1,88,1,88,1,89,1,89,1,90,4,90,902,8,90,11,90,12,90,903,1,91,1,91, + 4,91,908,8,91,11,91,12,91,909,3,797,817,827,0,92,1,1,3,2,5,3,7,4, + 9,5,11,6,13,7,15,8,17,9,19,10,21,11,23,12,25,13,27,14,29,15,31,16, + 33,17,35,18,37,19,39,20,41,21,43,22,45,23,47,24,49,25,51,26,53,27, + 55,28,57,29,59,30,61,31,63,32,65,33,67,34,69,35,71,36,73,37,75,38, + 77,39,79,40,81,41,83,42,85,43,87,44,89,45,91,46,93,47,95,48,97,49, + 99,50,101,51,103,52,105,53,107,54,109,55,111,56,113,57,115,58,117, + 59,119,60,121,61,123,62,125,63,127,64,129,65,131,66,133,67,135,68, + 137,69,139,70,141,71,143,72,145,73,147,74,149,75,151,0,153,76,155, + 77,157,78,159,79,161,80,163,81,165,82,167,83,169,84,171,85,173,86, + 175,87,177,88,179,89,181,90,183,91,1,0,3,3,0,9,10,13,13,32,32,2, + 0,65,90,97,122,1,0,48,57,949,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0, + 0,7,1,0,0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0, + 17,1,0,0,0,0,19,1,0,0,0,0,21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0, + 27,1,0,0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0, + 37,1,0,0,0,0,39,1,0,0,0,0,41,1,0,0,0,0,43,1,0,0,0,0,45,1,0,0,0,0, + 47,1,0,0,0,0,49,1,0,0,0,0,51,1,0,0,0,0,53,1,0,0,0,0,55,1,0,0,0,0, + 57,1,0,0,0,0,59,1,0,0,0,0,61,1,0,0,0,0,63,1,0,0,0,0,65,1,0,0,0,0, + 67,1,0,0,0,0,69,1,0,0,0,0,71,1,0,0,0,0,73,1,0,0,0,0,75,1,0,0,0,0, + 77,1,0,0,0,0,79,1,0,0,0,0,81,1,0,0,0,0,83,1,0,0,0,0,85,1,0,0,0,0, + 87,1,0,0,0,0,89,1,0,0,0,0,91,1,0,0,0,0,93,1,0,0,0,0,95,1,0,0,0,0, + 97,1,0,0,0,0,99,1,0,0,0,0,101,1,0,0,0,0,103,1,0,0,0,0,105,1,0,0, + 0,0,107,1,0,0,0,0,109,1,0,0,0,0,111,1,0,0,0,0,113,1,0,0,0,0,115, + 1,0,0,0,0,117,1,0,0,0,0,119,1,0,0,0,0,121,1,0,0,0,0,123,1,0,0,0, + 0,125,1,0,0,0,0,127,1,0,0,0,0,129,1,0,0,0,0,131,1,0,0,0,0,133,1, + 0,0,0,0,135,1,0,0,0,0,137,1,0,0,0,0,139,1,0,0,0,0,141,1,0,0,0,0, + 143,1,0,0,0,0,145,1,0,0,0,0,147,1,0,0,0,0,149,1,0,0,0,0,153,1,0, + 0,0,0,155,1,0,0,0,0,157,1,0,0,0,0,159,1,0,0,0,0,161,1,0,0,0,0,163, + 1,0,0,0,0,165,1,0,0,0,0,167,1,0,0,0,0,169,1,0,0,0,0,171,1,0,0,0, + 0,173,1,0,0,0,0,175,1,0,0,0,0,177,1,0,0,0,0,179,1,0,0,0,0,181,1, + 0,0,0,0,183,1,0,0,0,1,185,1,0,0,0,3,187,1,0,0,0,5,190,1,0,0,0,7, + 208,1,0,0,0,9,223,1,0,0,0,11,240,1,0,0,0,13,244,1,0,0,0,15,252,1, + 0,0,0,17,276,1,0,0,0,19,280,1,0,0,0,21,295,1,0,0,0,23,312,1,0,0, + 0,25,320,1,0,0,0,27,380,1,0,0,0,29,384,1,0,0,0,31,386,1,0,0,0,33, + 388,1,0,0,0,35,390,1,0,0,0,37,392,1,0,0,0,39,394,1,0,0,0,41,396, + 1,0,0,0,43,398,1,0,0,0,45,400,1,0,0,0,47,403,1,0,0,0,49,406,1,0, + 0,0,51,408,1,0,0,0,53,410,1,0,0,0,55,412,1,0,0,0,57,420,1,0,0,0, + 59,427,1,0,0,0,61,435,1,0,0,0,63,443,1,0,0,0,65,503,1,0,0,0,67,520, + 1,0,0,0,69,522,1,0,0,0,71,527,1,0,0,0,73,533,1,0,0,0,75,538,1,0, + 0,0,77,543,1,0,0,0,79,547,1,0,0,0,81,551,1,0,0,0,83,556,1,0,0,0, + 85,561,1,0,0,0,87,566,1,0,0,0,89,571,1,0,0,0,91,576,1,0,0,0,93,581, + 1,0,0,0,95,589,1,0,0,0,97,597,1,0,0,0,99,605,1,0,0,0,101,613,1,0, + 0,0,103,621,1,0,0,0,105,629,1,0,0,0,107,635,1,0,0,0,109,641,1,0, + 0,0,111,647,1,0,0,0,113,655,1,0,0,0,115,663,1,0,0,0,117,671,1,0, + 0,0,119,679,1,0,0,0,121,687,1,0,0,0,123,694,1,0,0,0,125,701,1,0, + 0,0,127,707,1,0,0,0,129,717,1,0,0,0,131,724,1,0,0,0,133,730,1,0, + 0,0,135,752,1,0,0,0,137,754,1,0,0,0,139,761,1,0,0,0,141,769,1,0, + 0,0,143,777,1,0,0,0,145,785,1,0,0,0,147,787,1,0,0,0,149,789,1,0, + 0,0,151,791,1,0,0,0,153,793,1,0,0,0,155,809,1,0,0,0,157,811,1,0, + 0,0,159,833,1,0,0,0,161,835,1,0,0,0,163,840,1,0,0,0,165,851,1,0, + 0,0,167,853,1,0,0,0,169,859,1,0,0,0,171,869,1,0,0,0,173,880,1,0, + 0,0,175,882,1,0,0,0,177,888,1,0,0,0,179,898,1,0,0,0,181,901,1,0, + 0,0,183,905,1,0,0,0,185,186,5,44,0,0,186,2,1,0,0,0,187,188,5,46, + 0,0,188,4,1,0,0,0,189,191,7,0,0,0,190,189,1,0,0,0,191,192,1,0,0, + 0,192,190,1,0,0,0,192,193,1,0,0,0,193,194,1,0,0,0,194,195,6,2,0, + 0,195,6,1,0,0,0,196,197,5,92,0,0,197,209,5,44,0,0,198,199,5,92,0, + 0,199,200,5,116,0,0,200,201,5,104,0,0,201,202,5,105,0,0,202,203, + 5,110,0,0,203,204,5,115,0,0,204,205,5,112,0,0,205,206,5,97,0,0,206, + 207,5,99,0,0,207,209,5,101,0,0,208,196,1,0,0,0,208,198,1,0,0,0,209, + 210,1,0,0,0,210,211,6,3,0,0,211,8,1,0,0,0,212,213,5,92,0,0,213,224, + 5,58,0,0,214,215,5,92,0,0,215,216,5,109,0,0,216,217,5,101,0,0,217, + 218,5,100,0,0,218,219,5,115,0,0,219,220,5,112,0,0,220,221,5,97,0, + 0,221,222,5,99,0,0,222,224,5,101,0,0,223,212,1,0,0,0,223,214,1,0, + 0,0,224,225,1,0,0,0,225,226,6,4,0,0,226,10,1,0,0,0,227,228,5,92, + 0,0,228,241,5,59,0,0,229,230,5,92,0,0,230,231,5,116,0,0,231,232, + 5,104,0,0,232,233,5,105,0,0,233,234,5,99,0,0,234,235,5,107,0,0,235, + 236,5,115,0,0,236,237,5,112,0,0,237,238,5,97,0,0,238,239,5,99,0, + 0,239,241,5,101,0,0,240,227,1,0,0,0,240,229,1,0,0,0,241,242,1,0, + 0,0,242,243,6,5,0,0,243,12,1,0,0,0,244,245,5,92,0,0,245,246,5,113, + 0,0,246,247,5,117,0,0,247,248,5,97,0,0,248,249,5,100,0,0,249,250, + 1,0,0,0,250,251,6,6,0,0,251,14,1,0,0,0,252,253,5,92,0,0,253,254, + 5,113,0,0,254,255,5,113,0,0,255,256,5,117,0,0,256,257,5,97,0,0,257, + 258,5,100,0,0,258,259,1,0,0,0,259,260,6,7,0,0,260,16,1,0,0,0,261, + 262,5,92,0,0,262,277,5,33,0,0,263,264,5,92,0,0,264,265,5,110,0,0, + 265,266,5,101,0,0,266,267,5,103,0,0,267,268,5,116,0,0,268,269,5, + 104,0,0,269,270,5,105,0,0,270,271,5,110,0,0,271,272,5,115,0,0,272, + 273,5,112,0,0,273,274,5,97,0,0,274,275,5,99,0,0,275,277,5,101,0, + 0,276,261,1,0,0,0,276,263,1,0,0,0,277,278,1,0,0,0,278,279,6,8,0, + 0,279,18,1,0,0,0,280,281,5,92,0,0,281,282,5,110,0,0,282,283,5,101, + 0,0,283,284,5,103,0,0,284,285,5,109,0,0,285,286,5,101,0,0,286,287, + 5,100,0,0,287,288,5,115,0,0,288,289,5,112,0,0,289,290,5,97,0,0,290, + 291,5,99,0,0,291,292,5,101,0,0,292,293,1,0,0,0,293,294,6,9,0,0,294, + 20,1,0,0,0,295,296,5,92,0,0,296,297,5,110,0,0,297,298,5,101,0,0, + 298,299,5,103,0,0,299,300,5,116,0,0,300,301,5,104,0,0,301,302,5, + 105,0,0,302,303,5,99,0,0,303,304,5,107,0,0,304,305,5,115,0,0,305, + 306,5,112,0,0,306,307,5,97,0,0,307,308,5,99,0,0,308,309,5,101,0, + 0,309,310,1,0,0,0,310,311,6,10,0,0,311,22,1,0,0,0,312,313,5,92,0, + 0,313,314,5,108,0,0,314,315,5,101,0,0,315,316,5,102,0,0,316,317, + 5,116,0,0,317,318,1,0,0,0,318,319,6,11,0,0,319,24,1,0,0,0,320,321, + 5,92,0,0,321,322,5,114,0,0,322,323,5,105,0,0,323,324,5,103,0,0,324, + 325,5,104,0,0,325,326,5,116,0,0,326,327,1,0,0,0,327,328,6,12,0,0, + 328,26,1,0,0,0,329,330,5,92,0,0,330,331,5,118,0,0,331,332,5,114, + 0,0,332,333,5,117,0,0,333,334,5,108,0,0,334,381,5,101,0,0,335,336, + 5,92,0,0,336,337,5,118,0,0,337,338,5,99,0,0,338,339,5,101,0,0,339, + 340,5,110,0,0,340,341,5,116,0,0,341,342,5,101,0,0,342,381,5,114, + 0,0,343,344,5,92,0,0,344,345,5,118,0,0,345,346,5,98,0,0,346,347, + 5,111,0,0,347,381,5,120,0,0,348,349,5,92,0,0,349,350,5,118,0,0,350, + 351,5,115,0,0,351,352,5,107,0,0,352,353,5,105,0,0,353,381,5,112, + 0,0,354,355,5,92,0,0,355,356,5,118,0,0,356,357,5,115,0,0,357,358, + 5,112,0,0,358,359,5,97,0,0,359,360,5,99,0,0,360,381,5,101,0,0,361, + 362,5,92,0,0,362,363,5,104,0,0,363,364,5,102,0,0,364,365,5,105,0, + 0,365,381,5,108,0,0,366,367,5,92,0,0,367,381,5,42,0,0,368,369,5, + 92,0,0,369,381,5,45,0,0,370,371,5,92,0,0,371,381,5,46,0,0,372,373, + 5,92,0,0,373,381,5,47,0,0,374,375,5,92,0,0,375,381,5,34,0,0,376, + 377,5,92,0,0,377,381,5,40,0,0,378,379,5,92,0,0,379,381,5,61,0,0, + 380,329,1,0,0,0,380,335,1,0,0,0,380,343,1,0,0,0,380,348,1,0,0,0, + 380,354,1,0,0,0,380,361,1,0,0,0,380,366,1,0,0,0,380,368,1,0,0,0, + 380,370,1,0,0,0,380,372,1,0,0,0,380,374,1,0,0,0,380,376,1,0,0,0, + 380,378,1,0,0,0,381,382,1,0,0,0,382,383,6,13,0,0,383,28,1,0,0,0, + 384,385,5,43,0,0,385,30,1,0,0,0,386,387,5,45,0,0,387,32,1,0,0,0, + 388,389,5,42,0,0,389,34,1,0,0,0,390,391,5,47,0,0,391,36,1,0,0,0, + 392,393,5,40,0,0,393,38,1,0,0,0,394,395,5,41,0,0,395,40,1,0,0,0, + 396,397,5,123,0,0,397,42,1,0,0,0,398,399,5,125,0,0,399,44,1,0,0, + 0,400,401,5,92,0,0,401,402,5,123,0,0,402,46,1,0,0,0,403,404,5,92, + 0,0,404,405,5,125,0,0,405,48,1,0,0,0,406,407,5,91,0,0,407,50,1,0, + 0,0,408,409,5,93,0,0,409,52,1,0,0,0,410,411,5,124,0,0,411,54,1,0, + 0,0,412,413,5,92,0,0,413,414,5,114,0,0,414,415,5,105,0,0,415,416, + 5,103,0,0,416,417,5,104,0,0,417,418,5,116,0,0,418,419,5,124,0,0, + 419,56,1,0,0,0,420,421,5,92,0,0,421,422,5,108,0,0,422,423,5,101, + 0,0,423,424,5,102,0,0,424,425,5,116,0,0,425,426,5,124,0,0,426,58, + 1,0,0,0,427,428,5,92,0,0,428,429,5,108,0,0,429,430,5,97,0,0,430, + 431,5,110,0,0,431,432,5,103,0,0,432,433,5,108,0,0,433,434,5,101, + 0,0,434,60,1,0,0,0,435,436,5,92,0,0,436,437,5,114,0,0,437,438,5, + 97,0,0,438,439,5,110,0,0,439,440,5,103,0,0,440,441,5,108,0,0,441, + 442,5,101,0,0,442,62,1,0,0,0,443,444,5,92,0,0,444,445,5,108,0,0, + 445,446,5,105,0,0,446,447,5,109,0,0,447,64,1,0,0,0,448,449,5,92, + 0,0,449,450,5,116,0,0,450,504,5,111,0,0,451,452,5,92,0,0,452,453, + 5,114,0,0,453,454,5,105,0,0,454,455,5,103,0,0,455,456,5,104,0,0, + 456,457,5,116,0,0,457,458,5,97,0,0,458,459,5,114,0,0,459,460,5,114, + 0,0,460,461,5,111,0,0,461,504,5,119,0,0,462,463,5,92,0,0,463,464, + 5,82,0,0,464,465,5,105,0,0,465,466,5,103,0,0,466,467,5,104,0,0,467, + 468,5,116,0,0,468,469,5,97,0,0,469,470,5,114,0,0,470,471,5,114,0, + 0,471,472,5,111,0,0,472,504,5,119,0,0,473,474,5,92,0,0,474,475,5, + 108,0,0,475,476,5,111,0,0,476,477,5,110,0,0,477,478,5,103,0,0,478, + 479,5,114,0,0,479,480,5,105,0,0,480,481,5,103,0,0,481,482,5,104, + 0,0,482,483,5,116,0,0,483,484,5,97,0,0,484,485,5,114,0,0,485,486, + 5,114,0,0,486,487,5,111,0,0,487,504,5,119,0,0,488,489,5,92,0,0,489, + 490,5,76,0,0,490,491,5,111,0,0,491,492,5,110,0,0,492,493,5,103,0, + 0,493,494,5,114,0,0,494,495,5,105,0,0,495,496,5,103,0,0,496,497, + 5,104,0,0,497,498,5,116,0,0,498,499,5,97,0,0,499,500,5,114,0,0,500, + 501,5,114,0,0,501,502,5,111,0,0,502,504,5,119,0,0,503,448,1,0,0, + 0,503,451,1,0,0,0,503,462,1,0,0,0,503,473,1,0,0,0,503,488,1,0,0, + 0,504,66,1,0,0,0,505,506,5,92,0,0,506,507,5,105,0,0,507,508,5,110, + 0,0,508,521,5,116,0,0,509,510,5,92,0,0,510,511,5,105,0,0,511,512, + 5,110,0,0,512,513,5,116,0,0,513,514,5,92,0,0,514,515,5,108,0,0,515, + 516,5,105,0,0,516,517,5,109,0,0,517,518,5,105,0,0,518,519,5,116, + 0,0,519,521,5,115,0,0,520,505,1,0,0,0,520,509,1,0,0,0,521,68,1,0, + 0,0,522,523,5,92,0,0,523,524,5,115,0,0,524,525,5,117,0,0,525,526, + 5,109,0,0,526,70,1,0,0,0,527,528,5,92,0,0,528,529,5,112,0,0,529, + 530,5,114,0,0,530,531,5,111,0,0,531,532,5,100,0,0,532,72,1,0,0,0, + 533,534,5,92,0,0,534,535,5,101,0,0,535,536,5,120,0,0,536,537,5,112, + 0,0,537,74,1,0,0,0,538,539,5,92,0,0,539,540,5,108,0,0,540,541,5, + 111,0,0,541,542,5,103,0,0,542,76,1,0,0,0,543,544,5,92,0,0,544,545, + 5,108,0,0,545,546,5,103,0,0,546,78,1,0,0,0,547,548,5,92,0,0,548, + 549,5,108,0,0,549,550,5,110,0,0,550,80,1,0,0,0,551,552,5,92,0,0, + 552,553,5,115,0,0,553,554,5,105,0,0,554,555,5,110,0,0,555,82,1,0, + 0,0,556,557,5,92,0,0,557,558,5,99,0,0,558,559,5,111,0,0,559,560, + 5,115,0,0,560,84,1,0,0,0,561,562,5,92,0,0,562,563,5,116,0,0,563, + 564,5,97,0,0,564,565,5,110,0,0,565,86,1,0,0,0,566,567,5,92,0,0,567, + 568,5,99,0,0,568,569,5,115,0,0,569,570,5,99,0,0,570,88,1,0,0,0,571, + 572,5,92,0,0,572,573,5,115,0,0,573,574,5,101,0,0,574,575,5,99,0, + 0,575,90,1,0,0,0,576,577,5,92,0,0,577,578,5,99,0,0,578,579,5,111, + 0,0,579,580,5,116,0,0,580,92,1,0,0,0,581,582,5,92,0,0,582,583,5, + 97,0,0,583,584,5,114,0,0,584,585,5,99,0,0,585,586,5,115,0,0,586, + 587,5,105,0,0,587,588,5,110,0,0,588,94,1,0,0,0,589,590,5,92,0,0, + 590,591,5,97,0,0,591,592,5,114,0,0,592,593,5,99,0,0,593,594,5,99, + 0,0,594,595,5,111,0,0,595,596,5,115,0,0,596,96,1,0,0,0,597,598,5, + 92,0,0,598,599,5,97,0,0,599,600,5,114,0,0,600,601,5,99,0,0,601,602, + 5,116,0,0,602,603,5,97,0,0,603,604,5,110,0,0,604,98,1,0,0,0,605, + 606,5,92,0,0,606,607,5,97,0,0,607,608,5,114,0,0,608,609,5,99,0,0, + 609,610,5,99,0,0,610,611,5,115,0,0,611,612,5,99,0,0,612,100,1,0, + 0,0,613,614,5,92,0,0,614,615,5,97,0,0,615,616,5,114,0,0,616,617, + 5,99,0,0,617,618,5,115,0,0,618,619,5,101,0,0,619,620,5,99,0,0,620, + 102,1,0,0,0,621,622,5,92,0,0,622,623,5,97,0,0,623,624,5,114,0,0, + 624,625,5,99,0,0,625,626,5,99,0,0,626,627,5,111,0,0,627,628,5,116, + 0,0,628,104,1,0,0,0,629,630,5,92,0,0,630,631,5,115,0,0,631,632,5, + 105,0,0,632,633,5,110,0,0,633,634,5,104,0,0,634,106,1,0,0,0,635, + 636,5,92,0,0,636,637,5,99,0,0,637,638,5,111,0,0,638,639,5,115,0, + 0,639,640,5,104,0,0,640,108,1,0,0,0,641,642,5,92,0,0,642,643,5,116, + 0,0,643,644,5,97,0,0,644,645,5,110,0,0,645,646,5,104,0,0,646,110, + 1,0,0,0,647,648,5,92,0,0,648,649,5,97,0,0,649,650,5,114,0,0,650, + 651,5,115,0,0,651,652,5,105,0,0,652,653,5,110,0,0,653,654,5,104, + 0,0,654,112,1,0,0,0,655,656,5,92,0,0,656,657,5,97,0,0,657,658,5, + 114,0,0,658,659,5,99,0,0,659,660,5,111,0,0,660,661,5,115,0,0,661, + 662,5,104,0,0,662,114,1,0,0,0,663,664,5,92,0,0,664,665,5,97,0,0, + 665,666,5,114,0,0,666,667,5,116,0,0,667,668,5,97,0,0,668,669,5,110, + 0,0,669,670,5,104,0,0,670,116,1,0,0,0,671,672,5,92,0,0,672,673,5, + 108,0,0,673,674,5,102,0,0,674,675,5,108,0,0,675,676,5,111,0,0,676, + 677,5,111,0,0,677,678,5,114,0,0,678,118,1,0,0,0,679,680,5,92,0,0, + 680,681,5,114,0,0,681,682,5,102,0,0,682,683,5,108,0,0,683,684,5, + 111,0,0,684,685,5,111,0,0,685,686,5,114,0,0,686,120,1,0,0,0,687, + 688,5,92,0,0,688,689,5,108,0,0,689,690,5,99,0,0,690,691,5,101,0, + 0,691,692,5,105,0,0,692,693,5,108,0,0,693,122,1,0,0,0,694,695,5, + 92,0,0,695,696,5,114,0,0,696,697,5,99,0,0,697,698,5,101,0,0,698, + 699,5,105,0,0,699,700,5,108,0,0,700,124,1,0,0,0,701,702,5,92,0,0, + 702,703,5,115,0,0,703,704,5,113,0,0,704,705,5,114,0,0,705,706,5, + 116,0,0,706,126,1,0,0,0,707,708,5,92,0,0,708,709,5,111,0,0,709,710, + 5,118,0,0,710,711,5,101,0,0,711,712,5,114,0,0,712,713,5,108,0,0, + 713,714,5,105,0,0,714,715,5,110,0,0,715,716,5,101,0,0,716,128,1, + 0,0,0,717,718,5,92,0,0,718,719,5,116,0,0,719,720,5,105,0,0,720,721, + 5,109,0,0,721,722,5,101,0,0,722,723,5,115,0,0,723,130,1,0,0,0,724, + 725,5,92,0,0,725,726,5,99,0,0,726,727,5,100,0,0,727,728,5,111,0, + 0,728,729,5,116,0,0,729,132,1,0,0,0,730,731,5,92,0,0,731,732,5,100, + 0,0,732,733,5,105,0,0,733,734,5,118,0,0,734,134,1,0,0,0,735,736, + 5,92,0,0,736,737,5,102,0,0,737,738,5,114,0,0,738,739,5,97,0,0,739, + 753,5,99,0,0,740,741,5,92,0,0,741,742,5,100,0,0,742,743,5,102,0, + 0,743,744,5,114,0,0,744,745,5,97,0,0,745,753,5,99,0,0,746,747,5, + 92,0,0,747,748,5,116,0,0,748,749,5,102,0,0,749,750,5,114,0,0,750, + 751,5,97,0,0,751,753,5,99,0,0,752,735,1,0,0,0,752,740,1,0,0,0,752, + 746,1,0,0,0,753,136,1,0,0,0,754,755,5,92,0,0,755,756,5,98,0,0,756, + 757,5,105,0,0,757,758,5,110,0,0,758,759,5,111,0,0,759,760,5,109, + 0,0,760,138,1,0,0,0,761,762,5,92,0,0,762,763,5,100,0,0,763,764,5, + 98,0,0,764,765,5,105,0,0,765,766,5,110,0,0,766,767,5,111,0,0,767, + 768,5,109,0,0,768,140,1,0,0,0,769,770,5,92,0,0,770,771,5,116,0,0, + 771,772,5,98,0,0,772,773,5,105,0,0,773,774,5,110,0,0,774,775,5,111, + 0,0,775,776,5,109,0,0,776,142,1,0,0,0,777,778,5,92,0,0,778,779,5, + 109,0,0,779,780,5,97,0,0,780,781,5,116,0,0,781,782,5,104,0,0,782, + 783,5,105,0,0,783,784,5,116,0,0,784,144,1,0,0,0,785,786,5,95,0,0, + 786,146,1,0,0,0,787,788,5,94,0,0,788,148,1,0,0,0,789,790,5,58,0, + 0,790,150,1,0,0,0,791,792,7,0,0,0,792,152,1,0,0,0,793,797,5,100, + 0,0,794,796,3,151,75,0,795,794,1,0,0,0,796,799,1,0,0,0,797,798,1, + 0,0,0,797,795,1,0,0,0,798,807,1,0,0,0,799,797,1,0,0,0,800,808,7, + 1,0,0,801,803,5,92,0,0,802,804,7,1,0,0,803,802,1,0,0,0,804,805,1, + 0,0,0,805,803,1,0,0,0,805,806,1,0,0,0,806,808,1,0,0,0,807,800,1, + 0,0,0,807,801,1,0,0,0,808,154,1,0,0,0,809,810,7,1,0,0,810,156,1, + 0,0,0,811,812,7,2,0,0,812,158,1,0,0,0,813,817,5,38,0,0,814,816,3, + 151,75,0,815,814,1,0,0,0,816,819,1,0,0,0,817,818,1,0,0,0,817,815, + 1,0,0,0,818,821,1,0,0,0,819,817,1,0,0,0,820,813,1,0,0,0,820,821, + 1,0,0,0,821,822,1,0,0,0,822,834,5,61,0,0,823,831,5,61,0,0,824,826, + 3,151,75,0,825,824,1,0,0,0,826,829,1,0,0,0,827,828,1,0,0,0,827,825, + 1,0,0,0,828,830,1,0,0,0,829,827,1,0,0,0,830,832,5,38,0,0,831,827, + 1,0,0,0,831,832,1,0,0,0,832,834,1,0,0,0,833,820,1,0,0,0,833,823, + 1,0,0,0,834,160,1,0,0,0,835,836,5,92,0,0,836,837,5,110,0,0,837,838, + 5,101,0,0,838,839,5,113,0,0,839,162,1,0,0,0,840,841,5,60,0,0,841, + 164,1,0,0,0,842,843,5,92,0,0,843,844,5,108,0,0,844,845,5,101,0,0, + 845,852,5,113,0,0,846,847,5,92,0,0,847,848,5,108,0,0,848,852,5,101, + 0,0,849,852,3,167,83,0,850,852,3,169,84,0,851,842,1,0,0,0,851,846, + 1,0,0,0,851,849,1,0,0,0,851,850,1,0,0,0,852,166,1,0,0,0,853,854, + 5,92,0,0,854,855,5,108,0,0,855,856,5,101,0,0,856,857,5,113,0,0,857, + 858,5,113,0,0,858,168,1,0,0,0,859,860,5,92,0,0,860,861,5,108,0,0, + 861,862,5,101,0,0,862,863,5,113,0,0,863,864,5,115,0,0,864,865,5, + 108,0,0,865,866,5,97,0,0,866,867,5,110,0,0,867,868,5,116,0,0,868, + 170,1,0,0,0,869,870,5,62,0,0,870,172,1,0,0,0,871,872,5,92,0,0,872, + 873,5,103,0,0,873,874,5,101,0,0,874,881,5,113,0,0,875,876,5,92,0, + 0,876,877,5,103,0,0,877,881,5,101,0,0,878,881,3,175,87,0,879,881, + 3,177,88,0,880,871,1,0,0,0,880,875,1,0,0,0,880,878,1,0,0,0,880,879, + 1,0,0,0,881,174,1,0,0,0,882,883,5,92,0,0,883,884,5,103,0,0,884,885, + 5,101,0,0,885,886,5,113,0,0,886,887,5,113,0,0,887,176,1,0,0,0,888, + 889,5,92,0,0,889,890,5,103,0,0,890,891,5,101,0,0,891,892,5,113,0, + 0,892,893,5,115,0,0,893,894,5,108,0,0,894,895,5,97,0,0,895,896,5, + 110,0,0,896,897,5,116,0,0,897,178,1,0,0,0,898,899,5,33,0,0,899,180, + 1,0,0,0,900,902,5,39,0,0,901,900,1,0,0,0,902,903,1,0,0,0,903,901, + 1,0,0,0,903,904,1,0,0,0,904,182,1,0,0,0,905,907,5,92,0,0,906,908, + 7,1,0,0,907,906,1,0,0,0,908,909,1,0,0,0,909,907,1,0,0,0,909,910, + 1,0,0,0,910,184,1,0,0,0,22,0,192,208,223,240,276,380,503,520,752, + 797,805,807,817,820,827,831,833,851,880,903,909,1,6,0,0 + ] + +class LaTeXLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + WS = 3 + THINSPACE = 4 + MEDSPACE = 5 + THICKSPACE = 6 + QUAD = 7 + QQUAD = 8 + NEGTHINSPACE = 9 + NEGMEDSPACE = 10 + NEGTHICKSPACE = 11 + CMD_LEFT = 12 + CMD_RIGHT = 13 + IGNORE = 14 + ADD = 15 + SUB = 16 + MUL = 17 + DIV = 18 + L_PAREN = 19 + R_PAREN = 20 + L_BRACE = 21 + R_BRACE = 22 + L_BRACE_LITERAL = 23 + R_BRACE_LITERAL = 24 + L_BRACKET = 25 + R_BRACKET = 26 + BAR = 27 + R_BAR = 28 + L_BAR = 29 + L_ANGLE = 30 + R_ANGLE = 31 + FUNC_LIM = 32 + LIM_APPROACH_SYM = 33 + FUNC_INT = 34 + FUNC_SUM = 35 + FUNC_PROD = 36 + FUNC_EXP = 37 + FUNC_LOG = 38 + FUNC_LG = 39 + FUNC_LN = 40 + FUNC_SIN = 41 + FUNC_COS = 42 + FUNC_TAN = 43 + FUNC_CSC = 44 + FUNC_SEC = 45 + FUNC_COT = 46 + FUNC_ARCSIN = 47 + FUNC_ARCCOS = 48 + FUNC_ARCTAN = 49 + FUNC_ARCCSC = 50 + FUNC_ARCSEC = 51 + FUNC_ARCCOT = 52 + FUNC_SINH = 53 + FUNC_COSH = 54 + FUNC_TANH = 55 + FUNC_ARSINH = 56 + FUNC_ARCOSH = 57 + FUNC_ARTANH = 58 + L_FLOOR = 59 + R_FLOOR = 60 + L_CEIL = 61 + R_CEIL = 62 + FUNC_SQRT = 63 + FUNC_OVERLINE = 64 + CMD_TIMES = 65 + CMD_CDOT = 66 + CMD_DIV = 67 + CMD_FRAC = 68 + CMD_BINOM = 69 + CMD_DBINOM = 70 + CMD_TBINOM = 71 + CMD_MATHIT = 72 + UNDERSCORE = 73 + CARET = 74 + COLON = 75 + DIFFERENTIAL = 76 + LETTER = 77 + DIGIT = 78 + EQUAL = 79 + NEQ = 80 + LT = 81 + LTE = 82 + LTE_Q = 83 + LTE_S = 84 + GT = 85 + GTE = 86 + GTE_Q = 87 + GTE_S = 88 + BANG = 89 + SINGLE_QUOTES = 90 + SYMBOL = 91 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ "DEFAULT_MODE" ] + + literalNames = [ "", + "','", "'.'", "'\\quad'", "'\\qquad'", "'\\negmedspace'", "'\\negthickspace'", + "'\\left'", "'\\right'", "'+'", "'-'", "'*'", "'/'", "'('", + "')'", "'{'", "'}'", "'\\{'", "'\\}'", "'['", "']'", "'|'", + "'\\right|'", "'\\left|'", "'\\langle'", "'\\rangle'", "'\\lim'", + "'\\sum'", "'\\prod'", "'\\exp'", "'\\log'", "'\\lg'", "'\\ln'", + "'\\sin'", "'\\cos'", "'\\tan'", "'\\csc'", "'\\sec'", "'\\cot'", + "'\\arcsin'", "'\\arccos'", "'\\arctan'", "'\\arccsc'", "'\\arcsec'", + "'\\arccot'", "'\\sinh'", "'\\cosh'", "'\\tanh'", "'\\arsinh'", + "'\\arcosh'", "'\\artanh'", "'\\lfloor'", "'\\rfloor'", "'\\lceil'", + "'\\rceil'", "'\\sqrt'", "'\\overline'", "'\\times'", "'\\cdot'", + "'\\div'", "'\\binom'", "'\\dbinom'", "'\\tbinom'", "'\\mathit'", + "'_'", "'^'", "':'", "'\\neq'", "'<'", "'\\leqq'", "'\\leqslant'", + "'>'", "'\\geqq'", "'\\geqslant'", "'!'" ] + + symbolicNames = [ "", + "WS", "THINSPACE", "MEDSPACE", "THICKSPACE", "QUAD", "QQUAD", + "NEGTHINSPACE", "NEGMEDSPACE", "NEGTHICKSPACE", "CMD_LEFT", + "CMD_RIGHT", "IGNORE", "ADD", "SUB", "MUL", "DIV", "L_PAREN", + "R_PAREN", "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", "R_BRACE_LITERAL", + "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", "L_BAR", "L_ANGLE", + "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", "FUNC_INT", "FUNC_SUM", + "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", "FUNC_LG", "FUNC_LN", "FUNC_SIN", + "FUNC_COS", "FUNC_TAN", "FUNC_CSC", "FUNC_SEC", "FUNC_COT", + "FUNC_ARCSIN", "FUNC_ARCCOS", "FUNC_ARCTAN", "FUNC_ARCCSC", + "FUNC_ARCSEC", "FUNC_ARCCOT", "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", + "FUNC_ARSINH", "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", "R_FLOOR", + "L_CEIL", "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", "CMD_TIMES", + "CMD_CDOT", "CMD_DIV", "CMD_FRAC", "CMD_BINOM", "CMD_DBINOM", + "CMD_TBINOM", "CMD_MATHIT", "UNDERSCORE", "CARET", "COLON", + "DIFFERENTIAL", "LETTER", "DIGIT", "EQUAL", "NEQ", "LT", "LTE", + "LTE_Q", "LTE_S", "GT", "GTE", "GTE_Q", "GTE_S", "BANG", "SINGLE_QUOTES", + "SYMBOL" ] + + ruleNames = [ "T__0", "T__1", "WS", "THINSPACE", "MEDSPACE", "THICKSPACE", + "QUAD", "QQUAD", "NEGTHINSPACE", "NEGMEDSPACE", "NEGTHICKSPACE", + "CMD_LEFT", "CMD_RIGHT", "IGNORE", "ADD", "SUB", "MUL", + "DIV", "L_PAREN", "R_PAREN", "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", + "R_BRACE_LITERAL", "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", + "L_BAR", "L_ANGLE", "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", + "FUNC_INT", "FUNC_SUM", "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", + "FUNC_LG", "FUNC_LN", "FUNC_SIN", "FUNC_COS", "FUNC_TAN", + "FUNC_CSC", "FUNC_SEC", "FUNC_COT", "FUNC_ARCSIN", "FUNC_ARCCOS", + "FUNC_ARCTAN", "FUNC_ARCCSC", "FUNC_ARCSEC", "FUNC_ARCCOT", + "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", "FUNC_ARSINH", + "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", "R_FLOOR", "L_CEIL", + "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", "CMD_TIMES", "CMD_CDOT", + "CMD_DIV", "CMD_FRAC", "CMD_BINOM", "CMD_DBINOM", "CMD_TBINOM", + "CMD_MATHIT", "UNDERSCORE", "CARET", "COLON", "WS_CHAR", + "DIFFERENTIAL", "LETTER", "DIGIT", "EQUAL", "NEQ", "LT", + "LTE", "LTE_Q", "LTE_S", "GT", "GTE", "GTE_Q", "GTE_S", + "BANG", "SINGLE_QUOTES", "SYMBOL" ] + + grammarFileName = "LaTeX.g4" + + def __init__(self, input=None, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/latexparser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/latexparser.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f58119055ded8f77380bbef52c77ddd6a01cfe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_antlr/latexparser.py @@ -0,0 +1,3652 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + +def serializedATN(): + return [ + 4,1,91,522,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13, + 2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20, + 7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26, + 2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7,32,2,33, + 7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2,39,7,39, + 2,40,7,40,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,5,1,91,8,1,10,1,12,1,94, + 9,1,1,2,1,2,1,2,1,2,1,3,1,3,1,4,1,4,1,4,1,4,1,4,1,4,5,4,108,8,4, + 10,4,12,4,111,9,4,1,5,1,5,1,5,1,5,1,5,1,5,5,5,119,8,5,10,5,12,5, + 122,9,5,1,6,1,6,1,6,1,6,1,6,1,6,5,6,130,8,6,10,6,12,6,133,9,6,1, + 7,1,7,1,7,4,7,138,8,7,11,7,12,7,139,3,7,142,8,7,1,8,1,8,1,8,1,8, + 5,8,148,8,8,10,8,12,8,151,9,8,3,8,153,8,8,1,9,1,9,5,9,157,8,9,10, + 9,12,9,160,9,9,1,10,1,10,5,10,164,8,10,10,10,12,10,167,9,10,1,11, + 1,11,3,11,171,8,11,1,12,1,12,1,12,1,12,1,12,1,12,3,12,179,8,12,1, + 13,1,13,1,13,1,13,3,13,185,8,13,1,13,1,13,1,14,1,14,1,14,1,14,3, + 14,193,8,14,1,14,1,14,1,15,1,15,1,15,1,15,1,15,1,15,1,15,1,15,1, + 15,1,15,3,15,207,8,15,1,15,3,15,210,8,15,5,15,212,8,15,10,15,12, + 15,215,9,15,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,3, + 16,227,8,16,1,16,3,16,230,8,16,5,16,232,8,16,10,16,12,16,235,9,16, + 1,17,1,17,1,17,1,17,1,17,1,17,3,17,243,8,17,1,18,1,18,1,18,1,18, + 1,18,3,18,250,8,18,1,19,1,19,1,19,1,19,1,19,1,19,1,19,1,19,1,19, + 1,19,1,19,1,19,1,19,1,19,1,19,1,19,3,19,268,8,19,1,20,1,20,1,20, + 1,20,1,21,4,21,275,8,21,11,21,12,21,276,1,21,1,21,1,21,1,21,5,21, + 283,8,21,10,21,12,21,286,9,21,1,21,1,21,4,21,290,8,21,11,21,12,21, + 291,3,21,294,8,21,1,22,1,22,3,22,298,8,22,1,22,3,22,301,8,22,1,22, + 3,22,304,8,22,1,22,3,22,307,8,22,3,22,309,8,22,1,22,1,22,1,22,1, + 22,1,22,1,22,1,22,3,22,318,8,22,1,23,1,23,1,23,1,23,1,24,1,24,1, + 24,1,24,1,25,1,25,1,25,1,25,1,25,1,26,5,26,334,8,26,10,26,12,26, + 337,9,26,1,27,1,27,1,27,1,27,1,27,1,27,3,27,345,8,27,1,27,1,27,1, + 27,1,27,1,27,3,27,352,8,27,1,28,1,28,1,28,1,28,1,28,1,28,1,28,1, + 28,1,29,1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,31,1,31,1,32,1,32,3, + 32,374,8,32,1,32,3,32,377,8,32,1,32,3,32,380,8,32,1,32,3,32,383, + 8,32,3,32,385,8,32,1,32,1,32,1,32,1,32,1,32,3,32,392,8,32,1,32,1, + 32,3,32,396,8,32,1,32,3,32,399,8,32,1,32,3,32,402,8,32,1,32,3,32, + 405,8,32,3,32,407,8,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1, + 32,1,32,1,32,3,32,420,8,32,1,32,3,32,423,8,32,1,32,1,32,1,32,3,32, + 428,8,32,1,32,1,32,1,32,1,32,1,32,3,32,435,8,32,1,32,1,32,1,32,1, + 32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,3, + 32,453,8,32,1,32,1,32,1,32,1,32,1,32,1,32,3,32,461,8,32,1,33,1,33, + 1,33,1,33,1,33,3,33,468,8,33,1,34,1,34,1,34,1,34,1,34,1,34,1,34, + 1,34,1,34,1,34,1,34,3,34,481,8,34,3,34,483,8,34,1,34,1,34,1,35,1, + 35,1,35,1,35,1,35,3,35,492,8,35,1,36,1,36,1,37,1,37,1,37,1,37,1, + 37,1,37,3,37,502,8,37,1,38,1,38,1,38,1,38,1,38,1,38,3,38,510,8,38, + 1,39,1,39,1,39,1,39,1,39,1,40,1,40,1,40,1,40,1,40,1,40,0,6,2,8,10, + 12,30,32,41,0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36, + 38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78,80, + 0,9,2,0,79,82,85,86,1,0,15,16,3,0,17,18,65,67,75,75,2,0,77,77,91, + 91,1,0,27,28,2,0,27,27,29,29,1,0,69,71,1,0,37,58,1,0,35,36,563,0, + 82,1,0,0,0,2,84,1,0,0,0,4,95,1,0,0,0,6,99,1,0,0,0,8,101,1,0,0,0, + 10,112,1,0,0,0,12,123,1,0,0,0,14,141,1,0,0,0,16,152,1,0,0,0,18,154, + 1,0,0,0,20,161,1,0,0,0,22,170,1,0,0,0,24,172,1,0,0,0,26,180,1,0, + 0,0,28,188,1,0,0,0,30,196,1,0,0,0,32,216,1,0,0,0,34,242,1,0,0,0, + 36,249,1,0,0,0,38,267,1,0,0,0,40,269,1,0,0,0,42,274,1,0,0,0,44,317, + 1,0,0,0,46,319,1,0,0,0,48,323,1,0,0,0,50,327,1,0,0,0,52,335,1,0, + 0,0,54,338,1,0,0,0,56,353,1,0,0,0,58,361,1,0,0,0,60,365,1,0,0,0, + 62,369,1,0,0,0,64,460,1,0,0,0,66,467,1,0,0,0,68,469,1,0,0,0,70,491, + 1,0,0,0,72,493,1,0,0,0,74,495,1,0,0,0,76,503,1,0,0,0,78,511,1,0, + 0,0,80,516,1,0,0,0,82,83,3,2,1,0,83,1,1,0,0,0,84,85,6,1,-1,0,85, + 86,3,6,3,0,86,92,1,0,0,0,87,88,10,2,0,0,88,89,7,0,0,0,89,91,3,2, + 1,3,90,87,1,0,0,0,91,94,1,0,0,0,92,90,1,0,0,0,92,93,1,0,0,0,93,3, + 1,0,0,0,94,92,1,0,0,0,95,96,3,6,3,0,96,97,5,79,0,0,97,98,3,6,3,0, + 98,5,1,0,0,0,99,100,3,8,4,0,100,7,1,0,0,0,101,102,6,4,-1,0,102,103, + 3,10,5,0,103,109,1,0,0,0,104,105,10,2,0,0,105,106,7,1,0,0,106,108, + 3,8,4,3,107,104,1,0,0,0,108,111,1,0,0,0,109,107,1,0,0,0,109,110, + 1,0,0,0,110,9,1,0,0,0,111,109,1,0,0,0,112,113,6,5,-1,0,113,114,3, + 14,7,0,114,120,1,0,0,0,115,116,10,2,0,0,116,117,7,2,0,0,117,119, + 3,10,5,3,118,115,1,0,0,0,119,122,1,0,0,0,120,118,1,0,0,0,120,121, + 1,0,0,0,121,11,1,0,0,0,122,120,1,0,0,0,123,124,6,6,-1,0,124,125, + 3,16,8,0,125,131,1,0,0,0,126,127,10,2,0,0,127,128,7,2,0,0,128,130, + 3,12,6,3,129,126,1,0,0,0,130,133,1,0,0,0,131,129,1,0,0,0,131,132, + 1,0,0,0,132,13,1,0,0,0,133,131,1,0,0,0,134,135,7,1,0,0,135,142,3, + 14,7,0,136,138,3,18,9,0,137,136,1,0,0,0,138,139,1,0,0,0,139,137, + 1,0,0,0,139,140,1,0,0,0,140,142,1,0,0,0,141,134,1,0,0,0,141,137, + 1,0,0,0,142,15,1,0,0,0,143,144,7,1,0,0,144,153,3,16,8,0,145,149, + 3,18,9,0,146,148,3,20,10,0,147,146,1,0,0,0,148,151,1,0,0,0,149,147, + 1,0,0,0,149,150,1,0,0,0,150,153,1,0,0,0,151,149,1,0,0,0,152,143, + 1,0,0,0,152,145,1,0,0,0,153,17,1,0,0,0,154,158,3,30,15,0,155,157, + 3,22,11,0,156,155,1,0,0,0,157,160,1,0,0,0,158,156,1,0,0,0,158,159, + 1,0,0,0,159,19,1,0,0,0,160,158,1,0,0,0,161,165,3,32,16,0,162,164, + 3,22,11,0,163,162,1,0,0,0,164,167,1,0,0,0,165,163,1,0,0,0,165,166, + 1,0,0,0,166,21,1,0,0,0,167,165,1,0,0,0,168,171,5,89,0,0,169,171, + 3,24,12,0,170,168,1,0,0,0,170,169,1,0,0,0,171,23,1,0,0,0,172,178, + 5,27,0,0,173,179,3,28,14,0,174,179,3,26,13,0,175,176,3,28,14,0,176, + 177,3,26,13,0,177,179,1,0,0,0,178,173,1,0,0,0,178,174,1,0,0,0,178, + 175,1,0,0,0,179,25,1,0,0,0,180,181,5,73,0,0,181,184,5,21,0,0,182, + 185,3,6,3,0,183,185,3,4,2,0,184,182,1,0,0,0,184,183,1,0,0,0,185, + 186,1,0,0,0,186,187,5,22,0,0,187,27,1,0,0,0,188,189,5,74,0,0,189, + 192,5,21,0,0,190,193,3,6,3,0,191,193,3,4,2,0,192,190,1,0,0,0,192, + 191,1,0,0,0,193,194,1,0,0,0,194,195,5,22,0,0,195,29,1,0,0,0,196, + 197,6,15,-1,0,197,198,3,34,17,0,198,213,1,0,0,0,199,200,10,2,0,0, + 200,206,5,74,0,0,201,207,3,44,22,0,202,203,5,21,0,0,203,204,3,6, + 3,0,204,205,5,22,0,0,205,207,1,0,0,0,206,201,1,0,0,0,206,202,1,0, + 0,0,207,209,1,0,0,0,208,210,3,74,37,0,209,208,1,0,0,0,209,210,1, + 0,0,0,210,212,1,0,0,0,211,199,1,0,0,0,212,215,1,0,0,0,213,211,1, + 0,0,0,213,214,1,0,0,0,214,31,1,0,0,0,215,213,1,0,0,0,216,217,6,16, + -1,0,217,218,3,36,18,0,218,233,1,0,0,0,219,220,10,2,0,0,220,226, + 5,74,0,0,221,227,3,44,22,0,222,223,5,21,0,0,223,224,3,6,3,0,224, + 225,5,22,0,0,225,227,1,0,0,0,226,221,1,0,0,0,226,222,1,0,0,0,227, + 229,1,0,0,0,228,230,3,74,37,0,229,228,1,0,0,0,229,230,1,0,0,0,230, + 232,1,0,0,0,231,219,1,0,0,0,232,235,1,0,0,0,233,231,1,0,0,0,233, + 234,1,0,0,0,234,33,1,0,0,0,235,233,1,0,0,0,236,243,3,38,19,0,237, + 243,3,40,20,0,238,243,3,64,32,0,239,243,3,44,22,0,240,243,3,58,29, + 0,241,243,3,60,30,0,242,236,1,0,0,0,242,237,1,0,0,0,242,238,1,0, + 0,0,242,239,1,0,0,0,242,240,1,0,0,0,242,241,1,0,0,0,243,35,1,0,0, + 0,244,250,3,38,19,0,245,250,3,40,20,0,246,250,3,44,22,0,247,250, + 3,58,29,0,248,250,3,60,30,0,249,244,1,0,0,0,249,245,1,0,0,0,249, + 246,1,0,0,0,249,247,1,0,0,0,249,248,1,0,0,0,250,37,1,0,0,0,251,252, + 5,19,0,0,252,253,3,6,3,0,253,254,5,20,0,0,254,268,1,0,0,0,255,256, + 5,25,0,0,256,257,3,6,3,0,257,258,5,26,0,0,258,268,1,0,0,0,259,260, + 5,21,0,0,260,261,3,6,3,0,261,262,5,22,0,0,262,268,1,0,0,0,263,264, + 5,23,0,0,264,265,3,6,3,0,265,266,5,24,0,0,266,268,1,0,0,0,267,251, + 1,0,0,0,267,255,1,0,0,0,267,259,1,0,0,0,267,263,1,0,0,0,268,39,1, + 0,0,0,269,270,5,27,0,0,270,271,3,6,3,0,271,272,5,27,0,0,272,41,1, + 0,0,0,273,275,5,78,0,0,274,273,1,0,0,0,275,276,1,0,0,0,276,274,1, + 0,0,0,276,277,1,0,0,0,277,284,1,0,0,0,278,279,5,1,0,0,279,280,5, + 78,0,0,280,281,5,78,0,0,281,283,5,78,0,0,282,278,1,0,0,0,283,286, + 1,0,0,0,284,282,1,0,0,0,284,285,1,0,0,0,285,293,1,0,0,0,286,284, + 1,0,0,0,287,289,5,2,0,0,288,290,5,78,0,0,289,288,1,0,0,0,290,291, + 1,0,0,0,291,289,1,0,0,0,291,292,1,0,0,0,292,294,1,0,0,0,293,287, + 1,0,0,0,293,294,1,0,0,0,294,43,1,0,0,0,295,308,7,3,0,0,296,298,3, + 74,37,0,297,296,1,0,0,0,297,298,1,0,0,0,298,300,1,0,0,0,299,301, + 5,90,0,0,300,299,1,0,0,0,300,301,1,0,0,0,301,309,1,0,0,0,302,304, + 5,90,0,0,303,302,1,0,0,0,303,304,1,0,0,0,304,306,1,0,0,0,305,307, + 3,74,37,0,306,305,1,0,0,0,306,307,1,0,0,0,307,309,1,0,0,0,308,297, + 1,0,0,0,308,303,1,0,0,0,309,318,1,0,0,0,310,318,3,42,21,0,311,318, + 5,76,0,0,312,318,3,50,25,0,313,318,3,54,27,0,314,318,3,56,28,0,315, + 318,3,46,23,0,316,318,3,48,24,0,317,295,1,0,0,0,317,310,1,0,0,0, + 317,311,1,0,0,0,317,312,1,0,0,0,317,313,1,0,0,0,317,314,1,0,0,0, + 317,315,1,0,0,0,317,316,1,0,0,0,318,45,1,0,0,0,319,320,5,30,0,0, + 320,321,3,6,3,0,321,322,7,4,0,0,322,47,1,0,0,0,323,324,7,5,0,0,324, + 325,3,6,3,0,325,326,5,31,0,0,326,49,1,0,0,0,327,328,5,72,0,0,328, + 329,5,21,0,0,329,330,3,52,26,0,330,331,5,22,0,0,331,51,1,0,0,0,332, + 334,5,77,0,0,333,332,1,0,0,0,334,337,1,0,0,0,335,333,1,0,0,0,335, + 336,1,0,0,0,336,53,1,0,0,0,337,335,1,0,0,0,338,344,5,68,0,0,339, + 345,5,78,0,0,340,341,5,21,0,0,341,342,3,6,3,0,342,343,5,22,0,0,343, + 345,1,0,0,0,344,339,1,0,0,0,344,340,1,0,0,0,345,351,1,0,0,0,346, + 352,5,78,0,0,347,348,5,21,0,0,348,349,3,6,3,0,349,350,5,22,0,0,350, + 352,1,0,0,0,351,346,1,0,0,0,351,347,1,0,0,0,352,55,1,0,0,0,353,354, + 7,6,0,0,354,355,5,21,0,0,355,356,3,6,3,0,356,357,5,22,0,0,357,358, + 5,21,0,0,358,359,3,6,3,0,359,360,5,22,0,0,360,57,1,0,0,0,361,362, + 5,59,0,0,362,363,3,6,3,0,363,364,5,60,0,0,364,59,1,0,0,0,365,366, + 5,61,0,0,366,367,3,6,3,0,367,368,5,62,0,0,368,61,1,0,0,0,369,370, + 7,7,0,0,370,63,1,0,0,0,371,384,3,62,31,0,372,374,3,74,37,0,373,372, + 1,0,0,0,373,374,1,0,0,0,374,376,1,0,0,0,375,377,3,76,38,0,376,375, + 1,0,0,0,376,377,1,0,0,0,377,385,1,0,0,0,378,380,3,76,38,0,379,378, + 1,0,0,0,379,380,1,0,0,0,380,382,1,0,0,0,381,383,3,74,37,0,382,381, + 1,0,0,0,382,383,1,0,0,0,383,385,1,0,0,0,384,373,1,0,0,0,384,379, + 1,0,0,0,385,391,1,0,0,0,386,387,5,19,0,0,387,388,3,70,35,0,388,389, + 5,20,0,0,389,392,1,0,0,0,390,392,3,72,36,0,391,386,1,0,0,0,391,390, + 1,0,0,0,392,461,1,0,0,0,393,406,7,3,0,0,394,396,3,74,37,0,395,394, + 1,0,0,0,395,396,1,0,0,0,396,398,1,0,0,0,397,399,5,90,0,0,398,397, + 1,0,0,0,398,399,1,0,0,0,399,407,1,0,0,0,400,402,5,90,0,0,401,400, + 1,0,0,0,401,402,1,0,0,0,402,404,1,0,0,0,403,405,3,74,37,0,404,403, + 1,0,0,0,404,405,1,0,0,0,405,407,1,0,0,0,406,395,1,0,0,0,406,401, + 1,0,0,0,407,408,1,0,0,0,408,409,5,19,0,0,409,410,3,66,33,0,410,411, + 5,20,0,0,411,461,1,0,0,0,412,419,5,34,0,0,413,414,3,74,37,0,414, + 415,3,76,38,0,415,420,1,0,0,0,416,417,3,76,38,0,417,418,3,74,37, + 0,418,420,1,0,0,0,419,413,1,0,0,0,419,416,1,0,0,0,419,420,1,0,0, + 0,420,427,1,0,0,0,421,423,3,8,4,0,422,421,1,0,0,0,422,423,1,0,0, + 0,423,424,1,0,0,0,424,428,5,76,0,0,425,428,3,54,27,0,426,428,3,8, + 4,0,427,422,1,0,0,0,427,425,1,0,0,0,427,426,1,0,0,0,428,461,1,0, + 0,0,429,434,5,63,0,0,430,431,5,25,0,0,431,432,3,6,3,0,432,433,5, + 26,0,0,433,435,1,0,0,0,434,430,1,0,0,0,434,435,1,0,0,0,435,436,1, + 0,0,0,436,437,5,21,0,0,437,438,3,6,3,0,438,439,5,22,0,0,439,461, + 1,0,0,0,440,441,5,64,0,0,441,442,5,21,0,0,442,443,3,6,3,0,443,444, + 5,22,0,0,444,461,1,0,0,0,445,452,7,8,0,0,446,447,3,78,39,0,447,448, + 3,76,38,0,448,453,1,0,0,0,449,450,3,76,38,0,450,451,3,78,39,0,451, + 453,1,0,0,0,452,446,1,0,0,0,452,449,1,0,0,0,453,454,1,0,0,0,454, + 455,3,10,5,0,455,461,1,0,0,0,456,457,5,32,0,0,457,458,3,68,34,0, + 458,459,3,10,5,0,459,461,1,0,0,0,460,371,1,0,0,0,460,393,1,0,0,0, + 460,412,1,0,0,0,460,429,1,0,0,0,460,440,1,0,0,0,460,445,1,0,0,0, + 460,456,1,0,0,0,461,65,1,0,0,0,462,463,3,6,3,0,463,464,5,1,0,0,464, + 465,3,66,33,0,465,468,1,0,0,0,466,468,3,6,3,0,467,462,1,0,0,0,467, + 466,1,0,0,0,468,67,1,0,0,0,469,470,5,73,0,0,470,471,5,21,0,0,471, + 472,7,3,0,0,472,473,5,33,0,0,473,482,3,6,3,0,474,480,5,74,0,0,475, + 476,5,21,0,0,476,477,7,1,0,0,477,481,5,22,0,0,478,481,5,15,0,0,479, + 481,5,16,0,0,480,475,1,0,0,0,480,478,1,0,0,0,480,479,1,0,0,0,481, + 483,1,0,0,0,482,474,1,0,0,0,482,483,1,0,0,0,483,484,1,0,0,0,484, + 485,5,22,0,0,485,69,1,0,0,0,486,492,3,6,3,0,487,488,3,6,3,0,488, + 489,5,1,0,0,489,490,3,70,35,0,490,492,1,0,0,0,491,486,1,0,0,0,491, + 487,1,0,0,0,492,71,1,0,0,0,493,494,3,12,6,0,494,73,1,0,0,0,495,501, + 5,73,0,0,496,502,3,44,22,0,497,498,5,21,0,0,498,499,3,6,3,0,499, + 500,5,22,0,0,500,502,1,0,0,0,501,496,1,0,0,0,501,497,1,0,0,0,502, + 75,1,0,0,0,503,509,5,74,0,0,504,510,3,44,22,0,505,506,5,21,0,0,506, + 507,3,6,3,0,507,508,5,22,0,0,508,510,1,0,0,0,509,504,1,0,0,0,509, + 505,1,0,0,0,510,77,1,0,0,0,511,512,5,73,0,0,512,513,5,21,0,0,513, + 514,3,4,2,0,514,515,5,22,0,0,515,79,1,0,0,0,516,517,5,73,0,0,517, + 518,5,21,0,0,518,519,3,4,2,0,519,520,5,22,0,0,520,81,1,0,0,0,59, + 92,109,120,131,139,141,149,152,158,165,170,178,184,192,206,209,213, + 226,229,233,242,249,267,276,284,291,293,297,300,303,306,308,317, + 335,344,351,373,376,379,382,384,391,395,398,401,404,406,419,422, + 427,434,452,460,467,480,482,491,501,509 + ] + +class LaTeXParser ( Parser ): + + grammarFileName = "LaTeX.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ "", "','", "'.'", "", "", + "", "", "'\\quad'", "'\\qquad'", + "", "'\\negmedspace'", "'\\negthickspace'", + "'\\left'", "'\\right'", "", "'+'", "'-'", + "'*'", "'/'", "'('", "')'", "'{'", "'}'", "'\\{'", + "'\\}'", "'['", "']'", "'|'", "'\\right|'", "'\\left|'", + "'\\langle'", "'\\rangle'", "'\\lim'", "", + "", "'\\sum'", "'\\prod'", "'\\exp'", "'\\log'", + "'\\lg'", "'\\ln'", "'\\sin'", "'\\cos'", "'\\tan'", + "'\\csc'", "'\\sec'", "'\\cot'", "'\\arcsin'", "'\\arccos'", + "'\\arctan'", "'\\arccsc'", "'\\arcsec'", "'\\arccot'", + "'\\sinh'", "'\\cosh'", "'\\tanh'", "'\\arsinh'", "'\\arcosh'", + "'\\artanh'", "'\\lfloor'", "'\\rfloor'", "'\\lceil'", + "'\\rceil'", "'\\sqrt'", "'\\overline'", "'\\times'", + "'\\cdot'", "'\\div'", "", "'\\binom'", "'\\dbinom'", + "'\\tbinom'", "'\\mathit'", "'_'", "'^'", "':'", "", + "", "", "", "'\\neq'", "'<'", + "", "'\\leqq'", "'\\leqslant'", "'>'", "", + "'\\geqq'", "'\\geqslant'", "'!'" ] + + symbolicNames = [ "", "", "", "WS", "THINSPACE", + "MEDSPACE", "THICKSPACE", "QUAD", "QQUAD", "NEGTHINSPACE", + "NEGMEDSPACE", "NEGTHICKSPACE", "CMD_LEFT", "CMD_RIGHT", + "IGNORE", "ADD", "SUB", "MUL", "DIV", "L_PAREN", "R_PAREN", + "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", "R_BRACE_LITERAL", + "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", "L_BAR", + "L_ANGLE", "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", + "FUNC_INT", "FUNC_SUM", "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", + "FUNC_LG", "FUNC_LN", "FUNC_SIN", "FUNC_COS", "FUNC_TAN", + "FUNC_CSC", "FUNC_SEC", "FUNC_COT", "FUNC_ARCSIN", + "FUNC_ARCCOS", "FUNC_ARCTAN", "FUNC_ARCCSC", "FUNC_ARCSEC", + "FUNC_ARCCOT", "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", + "FUNC_ARSINH", "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", + "R_FLOOR", "L_CEIL", "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", + "CMD_TIMES", "CMD_CDOT", "CMD_DIV", "CMD_FRAC", "CMD_BINOM", + "CMD_DBINOM", "CMD_TBINOM", "CMD_MATHIT", "UNDERSCORE", + "CARET", "COLON", "DIFFERENTIAL", "LETTER", "DIGIT", + "EQUAL", "NEQ", "LT", "LTE", "LTE_Q", "LTE_S", "GT", + "GTE", "GTE_Q", "GTE_S", "BANG", "SINGLE_QUOTES", + "SYMBOL" ] + + RULE_math = 0 + RULE_relation = 1 + RULE_equality = 2 + RULE_expr = 3 + RULE_additive = 4 + RULE_mp = 5 + RULE_mp_nofunc = 6 + RULE_unary = 7 + RULE_unary_nofunc = 8 + RULE_postfix = 9 + RULE_postfix_nofunc = 10 + RULE_postfix_op = 11 + RULE_eval_at = 12 + RULE_eval_at_sub = 13 + RULE_eval_at_sup = 14 + RULE_exp = 15 + RULE_exp_nofunc = 16 + RULE_comp = 17 + RULE_comp_nofunc = 18 + RULE_group = 19 + RULE_abs_group = 20 + RULE_number = 21 + RULE_atom = 22 + RULE_bra = 23 + RULE_ket = 24 + RULE_mathit = 25 + RULE_mathit_text = 26 + RULE_frac = 27 + RULE_binom = 28 + RULE_floor = 29 + RULE_ceil = 30 + RULE_func_normal = 31 + RULE_func = 32 + RULE_args = 33 + RULE_limit_sub = 34 + RULE_func_arg = 35 + RULE_func_arg_noparens = 36 + RULE_subexpr = 37 + RULE_supexpr = 38 + RULE_subeq = 39 + RULE_supeq = 40 + + ruleNames = [ "math", "relation", "equality", "expr", "additive", "mp", + "mp_nofunc", "unary", "unary_nofunc", "postfix", "postfix_nofunc", + "postfix_op", "eval_at", "eval_at_sub", "eval_at_sup", + "exp", "exp_nofunc", "comp", "comp_nofunc", "group", + "abs_group", "number", "atom", "bra", "ket", "mathit", + "mathit_text", "frac", "binom", "floor", "ceil", "func_normal", + "func", "args", "limit_sub", "func_arg", "func_arg_noparens", + "subexpr", "supexpr", "subeq", "supeq" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + WS=3 + THINSPACE=4 + MEDSPACE=5 + THICKSPACE=6 + QUAD=7 + QQUAD=8 + NEGTHINSPACE=9 + NEGMEDSPACE=10 + NEGTHICKSPACE=11 + CMD_LEFT=12 + CMD_RIGHT=13 + IGNORE=14 + ADD=15 + SUB=16 + MUL=17 + DIV=18 + L_PAREN=19 + R_PAREN=20 + L_BRACE=21 + R_BRACE=22 + L_BRACE_LITERAL=23 + R_BRACE_LITERAL=24 + L_BRACKET=25 + R_BRACKET=26 + BAR=27 + R_BAR=28 + L_BAR=29 + L_ANGLE=30 + R_ANGLE=31 + FUNC_LIM=32 + LIM_APPROACH_SYM=33 + FUNC_INT=34 + FUNC_SUM=35 + FUNC_PROD=36 + FUNC_EXP=37 + FUNC_LOG=38 + FUNC_LG=39 + FUNC_LN=40 + FUNC_SIN=41 + FUNC_COS=42 + FUNC_TAN=43 + FUNC_CSC=44 + FUNC_SEC=45 + FUNC_COT=46 + FUNC_ARCSIN=47 + FUNC_ARCCOS=48 + FUNC_ARCTAN=49 + FUNC_ARCCSC=50 + FUNC_ARCSEC=51 + FUNC_ARCCOT=52 + FUNC_SINH=53 + FUNC_COSH=54 + FUNC_TANH=55 + FUNC_ARSINH=56 + FUNC_ARCOSH=57 + FUNC_ARTANH=58 + L_FLOOR=59 + R_FLOOR=60 + L_CEIL=61 + R_CEIL=62 + FUNC_SQRT=63 + FUNC_OVERLINE=64 + CMD_TIMES=65 + CMD_CDOT=66 + CMD_DIV=67 + CMD_FRAC=68 + CMD_BINOM=69 + CMD_DBINOM=70 + CMD_TBINOM=71 + CMD_MATHIT=72 + UNDERSCORE=73 + CARET=74 + COLON=75 + DIFFERENTIAL=76 + LETTER=77 + DIGIT=78 + EQUAL=79 + NEQ=80 + LT=81 + LTE=82 + LTE_Q=83 + LTE_S=84 + GT=85 + GTE=86 + GTE_Q=87 + GTE_S=88 + BANG=89 + SINGLE_QUOTES=90 + SYMBOL=91 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class MathContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def relation(self): + return self.getTypedRuleContext(LaTeXParser.RelationContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_math + + + + + def math(self): + + localctx = LaTeXParser.MathContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_math) + try: + self.enterOuterAlt(localctx, 1) + self.state = 82 + self.relation(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class RelationContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def relation(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.RelationContext) + else: + return self.getTypedRuleContext(LaTeXParser.RelationContext,i) + + + def EQUAL(self): + return self.getToken(LaTeXParser.EQUAL, 0) + + def LT(self): + return self.getToken(LaTeXParser.LT, 0) + + def LTE(self): + return self.getToken(LaTeXParser.LTE, 0) + + def GT(self): + return self.getToken(LaTeXParser.GT, 0) + + def GTE(self): + return self.getToken(LaTeXParser.GTE, 0) + + def NEQ(self): + return self.getToken(LaTeXParser.NEQ, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_relation + + + + def relation(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.RelationContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 2 + self.enterRecursionRule(localctx, 2, self.RULE_relation, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 85 + self.expr() + self._ctx.stop = self._input.LT(-1) + self.state = 92 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,0,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.RelationContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_relation) + self.state = 87 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 88 + _la = self._input.LA(1) + if not((((_la - 79)) & ~0x3f) == 0 and ((1 << (_la - 79)) & 207) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 89 + self.relation(3) + self.state = 94 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,0,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class EqualityContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def EQUAL(self): + return self.getToken(LaTeXParser.EQUAL, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_equality + + + + + def equality(self): + + localctx = LaTeXParser.EqualityContext(self, self._ctx, self.state) + self.enterRule(localctx, 4, self.RULE_equality) + try: + self.enterOuterAlt(localctx, 1) + self.state = 95 + self.expr() + self.state = 96 + self.match(LaTeXParser.EQUAL) + self.state = 97 + self.expr() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def additive(self): + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_expr + + + + + def expr(self): + + localctx = LaTeXParser.ExprContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_expr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 99 + self.additive(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AdditiveContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def mp(self): + return self.getTypedRuleContext(LaTeXParser.MpContext,0) + + + def additive(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.AdditiveContext) + else: + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,i) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_additive + + + + def additive(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.AdditiveContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 8 + self.enterRecursionRule(localctx, 8, self.RULE_additive, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 102 + self.mp(0) + self._ctx.stop = self._input.LT(-1) + self.state = 109 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,1,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.AdditiveContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_additive) + self.state = 104 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 105 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 106 + self.additive(3) + self.state = 111 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,1,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class MpContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary(self): + return self.getTypedRuleContext(LaTeXParser.UnaryContext,0) + + + def mp(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.MpContext) + else: + return self.getTypedRuleContext(LaTeXParser.MpContext,i) + + + def MUL(self): + return self.getToken(LaTeXParser.MUL, 0) + + def CMD_TIMES(self): + return self.getToken(LaTeXParser.CMD_TIMES, 0) + + def CMD_CDOT(self): + return self.getToken(LaTeXParser.CMD_CDOT, 0) + + def DIV(self): + return self.getToken(LaTeXParser.DIV, 0) + + def CMD_DIV(self): + return self.getToken(LaTeXParser.CMD_DIV, 0) + + def COLON(self): + return self.getToken(LaTeXParser.COLON, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mp + + + + def mp(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.MpContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 10 + self.enterRecursionRule(localctx, 10, self.RULE_mp, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 113 + self.unary() + self._ctx.stop = self._input.LT(-1) + self.state = 120 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,2,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.MpContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_mp) + self.state = 115 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 116 + _la = self._input.LA(1) + if not((((_la - 17)) & ~0x3f) == 0 and ((1 << (_la - 17)) & 290200700988686339) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 117 + self.mp(3) + self.state = 122 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,2,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class Mp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Unary_nofuncContext,0) + + + def mp_nofunc(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Mp_nofuncContext) + else: + return self.getTypedRuleContext(LaTeXParser.Mp_nofuncContext,i) + + + def MUL(self): + return self.getToken(LaTeXParser.MUL, 0) + + def CMD_TIMES(self): + return self.getToken(LaTeXParser.CMD_TIMES, 0) + + def CMD_CDOT(self): + return self.getToken(LaTeXParser.CMD_CDOT, 0) + + def DIV(self): + return self.getToken(LaTeXParser.DIV, 0) + + def CMD_DIV(self): + return self.getToken(LaTeXParser.CMD_DIV, 0) + + def COLON(self): + return self.getToken(LaTeXParser.COLON, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mp_nofunc + + + + def mp_nofunc(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.Mp_nofuncContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 12 + self.enterRecursionRule(localctx, 12, self.RULE_mp_nofunc, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 124 + self.unary_nofunc() + self._ctx.stop = self._input.LT(-1) + self.state = 131 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.Mp_nofuncContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_mp_nofunc) + self.state = 126 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 127 + _la = self._input.LA(1) + if not((((_la - 17)) & ~0x3f) == 0 and ((1 << (_la - 17)) & 290200700988686339) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 128 + self.mp_nofunc(3) + self.state = 133 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class UnaryContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary(self): + return self.getTypedRuleContext(LaTeXParser.UnaryContext,0) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def postfix(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.PostfixContext) + else: + return self.getTypedRuleContext(LaTeXParser.PostfixContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_unary + + + + + def unary(self): + + localctx = LaTeXParser.UnaryContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_unary) + self._la = 0 # Token type + try: + self.state = 141 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [15, 16]: + self.enterOuterAlt(localctx, 1) + self.state = 134 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 135 + self.unary() + pass + elif token in [19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 137 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 136 + self.postfix() + + else: + raise NoViableAltException(self) + self.state = 139 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,4,self._ctx) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Unary_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Unary_nofuncContext,0) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def postfix(self): + return self.getTypedRuleContext(LaTeXParser.PostfixContext,0) + + + def postfix_nofunc(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_nofuncContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_nofuncContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_unary_nofunc + + + + + def unary_nofunc(self): + + localctx = LaTeXParser.Unary_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_unary_nofunc) + self._la = 0 # Token type + try: + self.state = 152 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [15, 16]: + self.enterOuterAlt(localctx, 1) + self.state = 143 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 144 + self.unary_nofunc() + pass + elif token in [19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 145 + self.postfix() + self.state = 149 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,6,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 146 + self.postfix_nofunc() + self.state = 151 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,6,self._ctx) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class PostfixContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def exp(self): + return self.getTypedRuleContext(LaTeXParser.ExpContext,0) + + + def postfix_op(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_opContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_opContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix + + + + + def postfix(self): + + localctx = LaTeXParser.PostfixContext(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_postfix) + try: + self.enterOuterAlt(localctx, 1) + self.state = 154 + self.exp(0) + self.state = 158 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 155 + self.postfix_op() + self.state = 160 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Postfix_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def exp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Exp_nofuncContext,0) + + + def postfix_op(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_opContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_opContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix_nofunc + + + + + def postfix_nofunc(self): + + localctx = LaTeXParser.Postfix_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_postfix_nofunc) + try: + self.enterOuterAlt(localctx, 1) + self.state = 161 + self.exp_nofunc(0) + self.state = 165 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 162 + self.postfix_op() + self.state = 167 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Postfix_opContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BANG(self): + return self.getToken(LaTeXParser.BANG, 0) + + def eval_at(self): + return self.getTypedRuleContext(LaTeXParser.Eval_atContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix_op + + + + + def postfix_op(self): + + localctx = LaTeXParser.Postfix_opContext(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_postfix_op) + try: + self.state = 170 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [89]: + self.enterOuterAlt(localctx, 1) + self.state = 168 + self.match(LaTeXParser.BANG) + pass + elif token in [27]: + self.enterOuterAlt(localctx, 2) + self.state = 169 + self.eval_at() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_atContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def eval_at_sup(self): + return self.getTypedRuleContext(LaTeXParser.Eval_at_supContext,0) + + + def eval_at_sub(self): + return self.getTypedRuleContext(LaTeXParser.Eval_at_subContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at + + + + + def eval_at(self): + + localctx = LaTeXParser.Eval_atContext(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_eval_at) + try: + self.enterOuterAlt(localctx, 1) + self.state = 172 + self.match(LaTeXParser.BAR) + self.state = 178 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,11,self._ctx) + if la_ == 1: + self.state = 173 + self.eval_at_sup() + pass + + elif la_ == 2: + self.state = 174 + self.eval_at_sub() + pass + + elif la_ == 3: + self.state = 175 + self.eval_at_sup() + self.state = 176 + self.eval_at_sub() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_at_subContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at_sub + + + + + def eval_at_sub(self): + + localctx = LaTeXParser.Eval_at_subContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_eval_at_sub) + try: + self.enterOuterAlt(localctx, 1) + self.state = 180 + self.match(LaTeXParser.UNDERSCORE) + self.state = 181 + self.match(LaTeXParser.L_BRACE) + self.state = 184 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,12,self._ctx) + if la_ == 1: + self.state = 182 + self.expr() + pass + + elif la_ == 2: + self.state = 183 + self.equality() + pass + + + self.state = 186 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_at_supContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at_sup + + + + + def eval_at_sup(self): + + localctx = LaTeXParser.Eval_at_supContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_eval_at_sup) + try: + self.enterOuterAlt(localctx, 1) + self.state = 188 + self.match(LaTeXParser.CARET) + self.state = 189 + self.match(LaTeXParser.L_BRACE) + self.state = 192 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,13,self._ctx) + if la_ == 1: + self.state = 190 + self.expr() + pass + + elif la_ == 2: + self.state = 191 + self.equality() + pass + + + self.state = 194 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExpContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def comp(self): + return self.getTypedRuleContext(LaTeXParser.CompContext,0) + + + def exp(self): + return self.getTypedRuleContext(LaTeXParser.ExpContext,0) + + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_exp + + + + def exp(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.ExpContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 30 + self.enterRecursionRule(localctx, 30, self.RULE_exp, _p) + try: + self.enterOuterAlt(localctx, 1) + self.state = 197 + self.comp() + self._ctx.stop = self._input.LT(-1) + self.state = 213 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,16,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.ExpContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_exp) + self.state = 199 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 200 + self.match(LaTeXParser.CARET) + self.state = 206 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 201 + self.atom() + pass + elif token in [21]: + self.state = 202 + self.match(LaTeXParser.L_BRACE) + self.state = 203 + self.expr() + self.state = 204 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 209 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,15,self._ctx) + if la_ == 1: + self.state = 208 + self.subexpr() + + + self.state = 215 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,16,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class Exp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def comp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Comp_nofuncContext,0) + + + def exp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Exp_nofuncContext,0) + + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_exp_nofunc + + + + def exp_nofunc(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.Exp_nofuncContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 32 + self.enterRecursionRule(localctx, 32, self.RULE_exp_nofunc, _p) + try: + self.enterOuterAlt(localctx, 1) + self.state = 217 + self.comp_nofunc() + self._ctx.stop = self._input.LT(-1) + self.state = 233 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,19,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.Exp_nofuncContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_exp_nofunc) + self.state = 219 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 220 + self.match(LaTeXParser.CARET) + self.state = 226 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 221 + self.atom() + pass + elif token in [21]: + self.state = 222 + self.match(LaTeXParser.L_BRACE) + self.state = 223 + self.expr() + self.state = 224 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 229 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,18,self._ctx) + if la_ == 1: + self.state = 228 + self.subexpr() + + + self.state = 235 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,19,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class CompContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def group(self): + return self.getTypedRuleContext(LaTeXParser.GroupContext,0) + + + def abs_group(self): + return self.getTypedRuleContext(LaTeXParser.Abs_groupContext,0) + + + def func(self): + return self.getTypedRuleContext(LaTeXParser.FuncContext,0) + + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def floor(self): + return self.getTypedRuleContext(LaTeXParser.FloorContext,0) + + + def ceil(self): + return self.getTypedRuleContext(LaTeXParser.CeilContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_comp + + + + + def comp(self): + + localctx = LaTeXParser.CompContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_comp) + try: + self.state = 242 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,20,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 236 + self.group() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 237 + self.abs_group() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 238 + self.func() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 239 + self.atom() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 240 + self.floor() + pass + + elif la_ == 6: + self.enterOuterAlt(localctx, 6) + self.state = 241 + self.ceil() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Comp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def group(self): + return self.getTypedRuleContext(LaTeXParser.GroupContext,0) + + + def abs_group(self): + return self.getTypedRuleContext(LaTeXParser.Abs_groupContext,0) + + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def floor(self): + return self.getTypedRuleContext(LaTeXParser.FloorContext,0) + + + def ceil(self): + return self.getTypedRuleContext(LaTeXParser.CeilContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_comp_nofunc + + + + + def comp_nofunc(self): + + localctx = LaTeXParser.Comp_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 36, self.RULE_comp_nofunc) + try: + self.state = 249 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,21,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 244 + self.group() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 245 + self.abs_group() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 246 + self.atom() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 247 + self.floor() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 248 + self.ceil() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class GroupContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def L_PAREN(self): + return self.getToken(LaTeXParser.L_PAREN, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_PAREN(self): + return self.getToken(LaTeXParser.R_PAREN, 0) + + def L_BRACKET(self): + return self.getToken(LaTeXParser.L_BRACKET, 0) + + def R_BRACKET(self): + return self.getToken(LaTeXParser.R_BRACKET, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def L_BRACE_LITERAL(self): + return self.getToken(LaTeXParser.L_BRACE_LITERAL, 0) + + def R_BRACE_LITERAL(self): + return self.getToken(LaTeXParser.R_BRACE_LITERAL, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_group + + + + + def group(self): + + localctx = LaTeXParser.GroupContext(self, self._ctx, self.state) + self.enterRule(localctx, 38, self.RULE_group) + try: + self.state = 267 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [19]: + self.enterOuterAlt(localctx, 1) + self.state = 251 + self.match(LaTeXParser.L_PAREN) + self.state = 252 + self.expr() + self.state = 253 + self.match(LaTeXParser.R_PAREN) + pass + elif token in [25]: + self.enterOuterAlt(localctx, 2) + self.state = 255 + self.match(LaTeXParser.L_BRACKET) + self.state = 256 + self.expr() + self.state = 257 + self.match(LaTeXParser.R_BRACKET) + pass + elif token in [21]: + self.enterOuterAlt(localctx, 3) + self.state = 259 + self.match(LaTeXParser.L_BRACE) + self.state = 260 + self.expr() + self.state = 261 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [23]: + self.enterOuterAlt(localctx, 4) + self.state = 263 + self.match(LaTeXParser.L_BRACE_LITERAL) + self.state = 264 + self.expr() + self.state = 265 + self.match(LaTeXParser.R_BRACE_LITERAL) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Abs_groupContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BAR(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.BAR) + else: + return self.getToken(LaTeXParser.BAR, i) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_abs_group + + + + + def abs_group(self): + + localctx = LaTeXParser.Abs_groupContext(self, self._ctx, self.state) + self.enterRule(localctx, 40, self.RULE_abs_group) + try: + self.enterOuterAlt(localctx, 1) + self.state = 269 + self.match(LaTeXParser.BAR) + self.state = 270 + self.expr() + self.state = 271 + self.match(LaTeXParser.BAR) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class NumberContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def DIGIT(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.DIGIT) + else: + return self.getToken(LaTeXParser.DIGIT, i) + + def getRuleIndex(self): + return LaTeXParser.RULE_number + + + + + def number(self): + + localctx = LaTeXParser.NumberContext(self, self._ctx, self.state) + self.enterRule(localctx, 42, self.RULE_number) + try: + self.enterOuterAlt(localctx, 1) + self.state = 274 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 273 + self.match(LaTeXParser.DIGIT) + + else: + raise NoViableAltException(self) + self.state = 276 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,23,self._ctx) + + self.state = 284 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,24,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 278 + self.match(LaTeXParser.T__0) + self.state = 279 + self.match(LaTeXParser.DIGIT) + self.state = 280 + self.match(LaTeXParser.DIGIT) + self.state = 281 + self.match(LaTeXParser.DIGIT) + self.state = 286 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,24,self._ctx) + + self.state = 293 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,26,self._ctx) + if la_ == 1: + self.state = 287 + self.match(LaTeXParser.T__1) + self.state = 289 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 288 + self.match(LaTeXParser.DIGIT) + + else: + raise NoViableAltException(self) + self.state = 291 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,25,self._ctx) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AtomContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def SINGLE_QUOTES(self): + return self.getToken(LaTeXParser.SINGLE_QUOTES, 0) + + def number(self): + return self.getTypedRuleContext(LaTeXParser.NumberContext,0) + + + def DIFFERENTIAL(self): + return self.getToken(LaTeXParser.DIFFERENTIAL, 0) + + def mathit(self): + return self.getTypedRuleContext(LaTeXParser.MathitContext,0) + + + def frac(self): + return self.getTypedRuleContext(LaTeXParser.FracContext,0) + + + def binom(self): + return self.getTypedRuleContext(LaTeXParser.BinomContext,0) + + + def bra(self): + return self.getTypedRuleContext(LaTeXParser.BraContext,0) + + + def ket(self): + return self.getTypedRuleContext(LaTeXParser.KetContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_atom + + + + + def atom(self): + + localctx = LaTeXParser.AtomContext(self, self._ctx, self.state) + self.enterRule(localctx, 44, self.RULE_atom) + self._la = 0 # Token type + try: + self.state = 317 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [77, 91]: + self.enterOuterAlt(localctx, 1) + self.state = 295 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 308 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,31,self._ctx) + if la_ == 1: + self.state = 297 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,27,self._ctx) + if la_ == 1: + self.state = 296 + self.subexpr() + + + self.state = 300 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + if la_ == 1: + self.state = 299 + self.match(LaTeXParser.SINGLE_QUOTES) + + + pass + + elif la_ == 2: + self.state = 303 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,29,self._ctx) + if la_ == 1: + self.state = 302 + self.match(LaTeXParser.SINGLE_QUOTES) + + + self.state = 306 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.state = 305 + self.subexpr() + + + pass + + + pass + elif token in [78]: + self.enterOuterAlt(localctx, 2) + self.state = 310 + self.number() + pass + elif token in [76]: + self.enterOuterAlt(localctx, 3) + self.state = 311 + self.match(LaTeXParser.DIFFERENTIAL) + pass + elif token in [72]: + self.enterOuterAlt(localctx, 4) + self.state = 312 + self.mathit() + pass + elif token in [68]: + self.enterOuterAlt(localctx, 5) + self.state = 313 + self.frac() + pass + elif token in [69, 70, 71]: + self.enterOuterAlt(localctx, 6) + self.state = 314 + self.binom() + pass + elif token in [30]: + self.enterOuterAlt(localctx, 7) + self.state = 315 + self.bra() + pass + elif token in [27, 29]: + self.enterOuterAlt(localctx, 8) + self.state = 316 + self.ket() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BraContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def L_ANGLE(self): + return self.getToken(LaTeXParser.L_ANGLE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BAR(self): + return self.getToken(LaTeXParser.R_BAR, 0) + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_bra + + + + + def bra(self): + + localctx = LaTeXParser.BraContext(self, self._ctx, self.state) + self.enterRule(localctx, 46, self.RULE_bra) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 319 + self.match(LaTeXParser.L_ANGLE) + self.state = 320 + self.expr() + self.state = 321 + _la = self._input.LA(1) + if not(_la==27 or _la==28): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class KetContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_ANGLE(self): + return self.getToken(LaTeXParser.R_ANGLE, 0) + + def L_BAR(self): + return self.getToken(LaTeXParser.L_BAR, 0) + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_ket + + + + + def ket(self): + + localctx = LaTeXParser.KetContext(self, self._ctx, self.state) + self.enterRule(localctx, 48, self.RULE_ket) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 323 + _la = self._input.LA(1) + if not(_la==27 or _la==29): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 324 + self.expr() + self.state = 325 + self.match(LaTeXParser.R_ANGLE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MathitContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CMD_MATHIT(self): + return self.getToken(LaTeXParser.CMD_MATHIT, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def mathit_text(self): + return self.getTypedRuleContext(LaTeXParser.Mathit_textContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mathit + + + + + def mathit(self): + + localctx = LaTeXParser.MathitContext(self, self._ctx, self.state) + self.enterRule(localctx, 50, self.RULE_mathit) + try: + self.enterOuterAlt(localctx, 1) + self.state = 327 + self.match(LaTeXParser.CMD_MATHIT) + self.state = 328 + self.match(LaTeXParser.L_BRACE) + self.state = 329 + self.mathit_text() + self.state = 330 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Mathit_textContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def LETTER(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.LETTER) + else: + return self.getToken(LaTeXParser.LETTER, i) + + def getRuleIndex(self): + return LaTeXParser.RULE_mathit_text + + + + + def mathit_text(self): + + localctx = LaTeXParser.Mathit_textContext(self, self._ctx, self.state) + self.enterRule(localctx, 52, self.RULE_mathit_text) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 335 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==77: + self.state = 332 + self.match(LaTeXParser.LETTER) + self.state = 337 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FracContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.upperd = None # Token + self.upper = None # ExprContext + self.lowerd = None # Token + self.lower = None # ExprContext + + def CMD_FRAC(self): + return self.getToken(LaTeXParser.CMD_FRAC, 0) + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def DIGIT(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.DIGIT) + else: + return self.getToken(LaTeXParser.DIGIT, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_frac + + + + + def frac(self): + + localctx = LaTeXParser.FracContext(self, self._ctx, self.state) + self.enterRule(localctx, 54, self.RULE_frac) + try: + self.enterOuterAlt(localctx, 1) + self.state = 338 + self.match(LaTeXParser.CMD_FRAC) + self.state = 344 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [78]: + self.state = 339 + localctx.upperd = self.match(LaTeXParser.DIGIT) + pass + elif token in [21]: + self.state = 340 + self.match(LaTeXParser.L_BRACE) + self.state = 341 + localctx.upper = self.expr() + self.state = 342 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 351 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [78]: + self.state = 346 + localctx.lowerd = self.match(LaTeXParser.DIGIT) + pass + elif token in [21]: + self.state = 347 + self.match(LaTeXParser.L_BRACE) + self.state = 348 + localctx.lower = self.expr() + self.state = 349 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BinomContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.n = None # ExprContext + self.k = None # ExprContext + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def CMD_BINOM(self): + return self.getToken(LaTeXParser.CMD_BINOM, 0) + + def CMD_DBINOM(self): + return self.getToken(LaTeXParser.CMD_DBINOM, 0) + + def CMD_TBINOM(self): + return self.getToken(LaTeXParser.CMD_TBINOM, 0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_binom + + + + + def binom(self): + + localctx = LaTeXParser.BinomContext(self, self._ctx, self.state) + self.enterRule(localctx, 56, self.RULE_binom) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 353 + _la = self._input.LA(1) + if not((((_la - 69)) & ~0x3f) == 0 and ((1 << (_la - 69)) & 7) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 354 + self.match(LaTeXParser.L_BRACE) + self.state = 355 + localctx.n = self.expr() + self.state = 356 + self.match(LaTeXParser.R_BRACE) + self.state = 357 + self.match(LaTeXParser.L_BRACE) + self.state = 358 + localctx.k = self.expr() + self.state = 359 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FloorContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.val = None # ExprContext + + def L_FLOOR(self): + return self.getToken(LaTeXParser.L_FLOOR, 0) + + def R_FLOOR(self): + return self.getToken(LaTeXParser.R_FLOOR, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_floor + + + + + def floor(self): + + localctx = LaTeXParser.FloorContext(self, self._ctx, self.state) + self.enterRule(localctx, 58, self.RULE_floor) + try: + self.enterOuterAlt(localctx, 1) + self.state = 361 + self.match(LaTeXParser.L_FLOOR) + self.state = 362 + localctx.val = self.expr() + self.state = 363 + self.match(LaTeXParser.R_FLOOR) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CeilContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.val = None # ExprContext + + def L_CEIL(self): + return self.getToken(LaTeXParser.L_CEIL, 0) + + def R_CEIL(self): + return self.getToken(LaTeXParser.R_CEIL, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_ceil + + + + + def ceil(self): + + localctx = LaTeXParser.CeilContext(self, self._ctx, self.state) + self.enterRule(localctx, 60, self.RULE_ceil) + try: + self.enterOuterAlt(localctx, 1) + self.state = 365 + self.match(LaTeXParser.L_CEIL) + self.state = 366 + localctx.val = self.expr() + self.state = 367 + self.match(LaTeXParser.R_CEIL) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_normalContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def FUNC_EXP(self): + return self.getToken(LaTeXParser.FUNC_EXP, 0) + + def FUNC_LOG(self): + return self.getToken(LaTeXParser.FUNC_LOG, 0) + + def FUNC_LG(self): + return self.getToken(LaTeXParser.FUNC_LG, 0) + + def FUNC_LN(self): + return self.getToken(LaTeXParser.FUNC_LN, 0) + + def FUNC_SIN(self): + return self.getToken(LaTeXParser.FUNC_SIN, 0) + + def FUNC_COS(self): + return self.getToken(LaTeXParser.FUNC_COS, 0) + + def FUNC_TAN(self): + return self.getToken(LaTeXParser.FUNC_TAN, 0) + + def FUNC_CSC(self): + return self.getToken(LaTeXParser.FUNC_CSC, 0) + + def FUNC_SEC(self): + return self.getToken(LaTeXParser.FUNC_SEC, 0) + + def FUNC_COT(self): + return self.getToken(LaTeXParser.FUNC_COT, 0) + + def FUNC_ARCSIN(self): + return self.getToken(LaTeXParser.FUNC_ARCSIN, 0) + + def FUNC_ARCCOS(self): + return self.getToken(LaTeXParser.FUNC_ARCCOS, 0) + + def FUNC_ARCTAN(self): + return self.getToken(LaTeXParser.FUNC_ARCTAN, 0) + + def FUNC_ARCCSC(self): + return self.getToken(LaTeXParser.FUNC_ARCCSC, 0) + + def FUNC_ARCSEC(self): + return self.getToken(LaTeXParser.FUNC_ARCSEC, 0) + + def FUNC_ARCCOT(self): + return self.getToken(LaTeXParser.FUNC_ARCCOT, 0) + + def FUNC_SINH(self): + return self.getToken(LaTeXParser.FUNC_SINH, 0) + + def FUNC_COSH(self): + return self.getToken(LaTeXParser.FUNC_COSH, 0) + + def FUNC_TANH(self): + return self.getToken(LaTeXParser.FUNC_TANH, 0) + + def FUNC_ARSINH(self): + return self.getToken(LaTeXParser.FUNC_ARSINH, 0) + + def FUNC_ARCOSH(self): + return self.getToken(LaTeXParser.FUNC_ARCOSH, 0) + + def FUNC_ARTANH(self): + return self.getToken(LaTeXParser.FUNC_ARTANH, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_func_normal + + + + + def func_normal(self): + + localctx = LaTeXParser.Func_normalContext(self, self._ctx, self.state) + self.enterRule(localctx, 62, self.RULE_func_normal) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 369 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 576460614864470016) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.root = None # ExprContext + self.base = None # ExprContext + + def func_normal(self): + return self.getTypedRuleContext(LaTeXParser.Func_normalContext,0) + + + def L_PAREN(self): + return self.getToken(LaTeXParser.L_PAREN, 0) + + def func_arg(self): + return self.getTypedRuleContext(LaTeXParser.Func_argContext,0) + + + def R_PAREN(self): + return self.getToken(LaTeXParser.R_PAREN, 0) + + def func_arg_noparens(self): + return self.getTypedRuleContext(LaTeXParser.Func_arg_noparensContext,0) + + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def supexpr(self): + return self.getTypedRuleContext(LaTeXParser.SupexprContext,0) + + + def args(self): + return self.getTypedRuleContext(LaTeXParser.ArgsContext,0) + + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def SINGLE_QUOTES(self): + return self.getToken(LaTeXParser.SINGLE_QUOTES, 0) + + def FUNC_INT(self): + return self.getToken(LaTeXParser.FUNC_INT, 0) + + def DIFFERENTIAL(self): + return self.getToken(LaTeXParser.DIFFERENTIAL, 0) + + def frac(self): + return self.getTypedRuleContext(LaTeXParser.FracContext,0) + + + def additive(self): + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,0) + + + def FUNC_SQRT(self): + return self.getToken(LaTeXParser.FUNC_SQRT, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def L_BRACKET(self): + return self.getToken(LaTeXParser.L_BRACKET, 0) + + def R_BRACKET(self): + return self.getToken(LaTeXParser.R_BRACKET, 0) + + def FUNC_OVERLINE(self): + return self.getToken(LaTeXParser.FUNC_OVERLINE, 0) + + def mp(self): + return self.getTypedRuleContext(LaTeXParser.MpContext,0) + + + def FUNC_SUM(self): + return self.getToken(LaTeXParser.FUNC_SUM, 0) + + def FUNC_PROD(self): + return self.getToken(LaTeXParser.FUNC_PROD, 0) + + def subeq(self): + return self.getTypedRuleContext(LaTeXParser.SubeqContext,0) + + + def FUNC_LIM(self): + return self.getToken(LaTeXParser.FUNC_LIM, 0) + + def limit_sub(self): + return self.getTypedRuleContext(LaTeXParser.Limit_subContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func + + + + + def func(self): + + localctx = LaTeXParser.FuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 64, self.RULE_func) + self._la = 0 # Token type + try: + self.state = 460 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58]: + self.enterOuterAlt(localctx, 1) + self.state = 371 + self.func_normal() + self.state = 384 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,40,self._ctx) + if la_ == 1: + self.state = 373 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 372 + self.subexpr() + + + self.state = 376 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 375 + self.supexpr() + + + pass + + elif la_ == 2: + self.state = 379 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 378 + self.supexpr() + + + self.state = 382 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 381 + self.subexpr() + + + pass + + + self.state = 391 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,41,self._ctx) + if la_ == 1: + self.state = 386 + self.match(LaTeXParser.L_PAREN) + self.state = 387 + self.func_arg() + self.state = 388 + self.match(LaTeXParser.R_PAREN) + pass + + elif la_ == 2: + self.state = 390 + self.func_arg_noparens() + pass + + + pass + elif token in [77, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 393 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 406 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,46,self._ctx) + if la_ == 1: + self.state = 395 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 394 + self.subexpr() + + + self.state = 398 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==90: + self.state = 397 + self.match(LaTeXParser.SINGLE_QUOTES) + + + pass + + elif la_ == 2: + self.state = 401 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==90: + self.state = 400 + self.match(LaTeXParser.SINGLE_QUOTES) + + + self.state = 404 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 403 + self.subexpr() + + + pass + + + self.state = 408 + self.match(LaTeXParser.L_PAREN) + self.state = 409 + self.args() + self.state = 410 + self.match(LaTeXParser.R_PAREN) + pass + elif token in [34]: + self.enterOuterAlt(localctx, 3) + self.state = 412 + self.match(LaTeXParser.FUNC_INT) + self.state = 419 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [73]: + self.state = 413 + self.subexpr() + self.state = 414 + self.supexpr() + pass + elif token in [74]: + self.state = 416 + self.supexpr() + self.state = 417 + self.subexpr() + pass + elif token in [15, 16, 19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + pass + else: + pass + self.state = 427 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,49,self._ctx) + if la_ == 1: + self.state = 422 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,48,self._ctx) + if la_ == 1: + self.state = 421 + self.additive(0) + + + self.state = 424 + self.match(LaTeXParser.DIFFERENTIAL) + pass + + elif la_ == 2: + self.state = 425 + self.frac() + pass + + elif la_ == 3: + self.state = 426 + self.additive(0) + pass + + + pass + elif token in [63]: + self.enterOuterAlt(localctx, 4) + self.state = 429 + self.match(LaTeXParser.FUNC_SQRT) + self.state = 434 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==25: + self.state = 430 + self.match(LaTeXParser.L_BRACKET) + self.state = 431 + localctx.root = self.expr() + self.state = 432 + self.match(LaTeXParser.R_BRACKET) + + + self.state = 436 + self.match(LaTeXParser.L_BRACE) + self.state = 437 + localctx.base = self.expr() + self.state = 438 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [64]: + self.enterOuterAlt(localctx, 5) + self.state = 440 + self.match(LaTeXParser.FUNC_OVERLINE) + self.state = 441 + self.match(LaTeXParser.L_BRACE) + self.state = 442 + localctx.base = self.expr() + self.state = 443 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [35, 36]: + self.enterOuterAlt(localctx, 6) + self.state = 445 + _la = self._input.LA(1) + if not(_la==35 or _la==36): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 452 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [73]: + self.state = 446 + self.subeq() + self.state = 447 + self.supexpr() + pass + elif token in [74]: + self.state = 449 + self.supexpr() + self.state = 450 + self.subeq() + pass + else: + raise NoViableAltException(self) + + self.state = 454 + self.mp(0) + pass + elif token in [32]: + self.enterOuterAlt(localctx, 7) + self.state = 456 + self.match(LaTeXParser.FUNC_LIM) + self.state = 457 + self.limit_sub() + self.state = 458 + self.mp(0) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ArgsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def args(self): + return self.getTypedRuleContext(LaTeXParser.ArgsContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_args + + + + + def args(self): + + localctx = LaTeXParser.ArgsContext(self, self._ctx, self.state) + self.enterRule(localctx, 66, self.RULE_args) + try: + self.state = 467 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,53,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 462 + self.expr() + self.state = 463 + self.match(LaTeXParser.T__0) + self.state = 464 + self.args() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 466 + self.expr() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Limit_subContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def LIM_APPROACH_SYM(self): + return self.getToken(LaTeXParser.LIM_APPROACH_SYM, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_limit_sub + + + + + def limit_sub(self): + + localctx = LaTeXParser.Limit_subContext(self, self._ctx, self.state) + self.enterRule(localctx, 68, self.RULE_limit_sub) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 469 + self.match(LaTeXParser.UNDERSCORE) + self.state = 470 + self.match(LaTeXParser.L_BRACE) + self.state = 471 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 472 + self.match(LaTeXParser.LIM_APPROACH_SYM) + self.state = 473 + self.expr() + self.state = 482 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 474 + self.match(LaTeXParser.CARET) + self.state = 480 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [21]: + self.state = 475 + self.match(LaTeXParser.L_BRACE) + self.state = 476 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 477 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [15]: + self.state = 478 + self.match(LaTeXParser.ADD) + pass + elif token in [16]: + self.state = 479 + self.match(LaTeXParser.SUB) + pass + else: + raise NoViableAltException(self) + + + + self.state = 484 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_argContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def func_arg(self): + return self.getTypedRuleContext(LaTeXParser.Func_argContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func_arg + + + + + def func_arg(self): + + localctx = LaTeXParser.Func_argContext(self, self._ctx, self.state) + self.enterRule(localctx, 70, self.RULE_func_arg) + try: + self.state = 491 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,56,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 486 + self.expr() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 487 + self.expr() + self.state = 488 + self.match(LaTeXParser.T__0) + self.state = 489 + self.func_arg() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_arg_noparensContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def mp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Mp_nofuncContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func_arg_noparens + + + + + def func_arg_noparens(self): + + localctx = LaTeXParser.Func_arg_noparensContext(self, self._ctx, self.state) + self.enterRule(localctx, 72, self.RULE_func_arg_noparens) + try: + self.enterOuterAlt(localctx, 1) + self.state = 493 + self.mp_nofunc(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SubexprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_subexpr + + + + + def subexpr(self): + + localctx = LaTeXParser.SubexprContext(self, self._ctx, self.state) + self.enterRule(localctx, 74, self.RULE_subexpr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 495 + self.match(LaTeXParser.UNDERSCORE) + self.state = 501 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 496 + self.atom() + pass + elif token in [21]: + self.state = 497 + self.match(LaTeXParser.L_BRACE) + self.state = 498 + self.expr() + self.state = 499 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SupexprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_supexpr + + + + + def supexpr(self): + + localctx = LaTeXParser.SupexprContext(self, self._ctx, self.state) + self.enterRule(localctx, 76, self.RULE_supexpr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 503 + self.match(LaTeXParser.CARET) + self.state = 509 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 504 + self.atom() + pass + elif token in [21]: + self.state = 505 + self.match(LaTeXParser.L_BRACE) + self.state = 506 + self.expr() + self.state = 507 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SubeqContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_subeq + + + + + def subeq(self): + + localctx = LaTeXParser.SubeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 78, self.RULE_subeq) + try: + self.enterOuterAlt(localctx, 1) + self.state = 511 + self.match(LaTeXParser.UNDERSCORE) + self.state = 512 + self.match(LaTeXParser.L_BRACE) + self.state = 513 + self.equality() + self.state = 514 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SupeqContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_supeq + + + + + def supeq(self): + + localctx = LaTeXParser.SupeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 80, self.RULE_supeq) + try: + self.enterOuterAlt(localctx, 1) + self.state = 516 + self.match(LaTeXParser.UNDERSCORE) + self.state = 517 + self.match(LaTeXParser.L_BRACE) + self.state = 518 + self.equality() + self.state = 519 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): + if self._predicates == None: + self._predicates = dict() + self._predicates[1] = self.relation_sempred + self._predicates[4] = self.additive_sempred + self._predicates[5] = self.mp_sempred + self._predicates[6] = self.mp_nofunc_sempred + self._predicates[15] = self.exp_sempred + self._predicates[16] = self.exp_nofunc_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def relation_sempred(self, localctx:RelationContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 2) + + + def additive_sempred(self, localctx:AdditiveContext, predIndex:int): + if predIndex == 1: + return self.precpred(self._ctx, 2) + + + def mp_sempred(self, localctx:MpContext, predIndex:int): + if predIndex == 2: + return self.precpred(self._ctx, 2) + + + def mp_nofunc_sempred(self, localctx:Mp_nofuncContext, predIndex:int): + if predIndex == 3: + return self.precpred(self._ctx, 2) + + + def exp_sempred(self, localctx:ExpContext, predIndex:int): + if predIndex == 4: + return self.precpred(self._ctx, 2) + + + def exp_nofunc_sempred(self, localctx:Exp_nofuncContext, predIndex:int): + if predIndex == 5: + return self.precpred(self._ctx, 2) + + + + + diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_build_latex_antlr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_build_latex_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..ee50da5b7861154823812c7773360b53dfd29ff6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_build_latex_antlr.py @@ -0,0 +1,91 @@ +import os +import subprocess +import glob + +from sympy.utilities.misc import debug + +here = os.path.dirname(__file__) +grammar_file = os.path.abspath(os.path.join(here, "LaTeX.g4")) +dir_latex_antlr = os.path.join(here, "_antlr") + +header = '''\ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +''' + + +def check_antlr_version(): + debug("Checking antlr4 version...") + + try: + debug(subprocess.check_output(["antlr4"]) + .decode('utf-8').split("\n")[0]) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + debug("The 'antlr4' command line tool is not installed, " + "or not on your PATH.\n" + "> Please refer to the README.md file for more information.") + return False + + +def build_parser(output_dir=dir_latex_antlr): + check_antlr_version() + + debug("Updating ANTLR-generated code in {}".format(output_dir)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, "__init__.py"), "w+") as fp: + fp.write(header) + + args = [ + "antlr4", + grammar_file, + "-o", output_dir, + # for now, not generating these as latex2sympy did not use them + "-no-visitor", + "-no-listener", + ] + + debug("Running code generation...\n\t$ {}".format(" ".join(args))) + subprocess.check_output(args, cwd=output_dir) + + debug("Applying headers, removing unnecessary files and renaming...") + # Handle case insensitive file systems. If the files are already + # generated, they will be written to latex* but LaTeX*.* won't match them. + for path in (glob.glob(os.path.join(output_dir, "LaTeX*.*")) or + glob.glob(os.path.join(output_dir, "latex*.*"))): + + # Remove files ending in .interp or .tokens as they are not needed. + if not path.endswith(".py"): + os.unlink(path) + continue + + new_path = os.path.join(output_dir, os.path.basename(path).lower()) + with open(path, 'r') as f: + lines = [line.rstrip() + '\n' for line in f] + + os.unlink(path) + + with open(new_path, "w") as out_file: + offset = 0 + while lines[offset].startswith('#'): + offset += 1 + out_file.write(header) + out_file.writelines(lines[offset:]) + + debug("\t{}".format(new_path)) + + return True + + +if __name__ == "__main__": + build_parser() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_parse_latex_antlr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_parse_latex_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..26604375b3a9622f8c1dacdb1d678d09c2c3ad41 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/_parse_latex_antlr.py @@ -0,0 +1,607 @@ +# Ported from latex2sympy by @augustt198 +# https://github.com/augustt198/latex2sympy +# See license in LICENSE.txt +from importlib.metadata import version +import sympy +from sympy.external import import_module +from sympy.printing.str import StrPrinter +from sympy.physics.quantum.state import Bra, Ket + +from .errors import LaTeXParsingError + + +LaTeXParser = LaTeXLexer = MathErrorListener = None + +try: + LaTeXParser = import_module('sympy.parsing.latex._antlr.latexparser', + import_kwargs={'fromlist': ['LaTeXParser']}).LaTeXParser + LaTeXLexer = import_module('sympy.parsing.latex._antlr.latexlexer', + import_kwargs={'fromlist': ['LaTeXLexer']}).LaTeXLexer +except Exception: + pass + +ErrorListener = import_module('antlr4.error.ErrorListener', + warn_not_installed=True, + import_kwargs={'fromlist': ['ErrorListener']} + ) + + + +if ErrorListener: + class MathErrorListener(ErrorListener.ErrorListener): # type:ignore # noqa:F811 + def __init__(self, src): + super(ErrorListener.ErrorListener, self).__init__() + self.src = src + + def syntaxError(self, recog, symbol, line, col, msg, e): + fmt = "%s\n%s\n%s" + marker = "~" * col + "^" + + if msg.startswith("missing"): + err = fmt % (msg, self.src, marker) + elif msg.startswith("no viable"): + err = fmt % ("I expected something else here", self.src, marker) + elif msg.startswith("mismatched"): + names = LaTeXParser.literalNames + expected = [ + names[i] for i in e.getExpectedTokens() if i < len(names) + ] + if len(expected) < 10: + expected = " ".join(expected) + err = (fmt % ("I expected one of these: " + expected, self.src, + marker)) + else: + err = (fmt % ("I expected something else here", self.src, + marker)) + else: + err = fmt % ("I don't understand this", self.src, marker) + raise LaTeXParsingError(err) + + +def parse_latex(sympy, strict=False): + antlr4 = import_module('antlr4') + + if None in [antlr4, MathErrorListener] or \ + not version('antlr4-python3-runtime').startswith('4.11'): + raise ImportError("LaTeX parsing requires the antlr4 Python package," + " provided by pip (antlr4-python3-runtime) or" + " conda (antlr-python-runtime), version 4.11") + + sympy = sympy.strip() + matherror = MathErrorListener(sympy) + + stream = antlr4.InputStream(sympy) + lex = LaTeXLexer(stream) + lex.removeErrorListeners() + lex.addErrorListener(matherror) + + tokens = antlr4.CommonTokenStream(lex) + parser = LaTeXParser(tokens) + + # remove default console error listener + parser.removeErrorListeners() + parser.addErrorListener(matherror) + + relation = parser.math().relation() + if strict and (relation.start.start != 0 or relation.stop.stop != len(sympy) - 1): + raise LaTeXParsingError("Invalid LaTeX") + expr = convert_relation(relation) + + return expr + + +def convert_relation(rel): + if rel.expr(): + return convert_expr(rel.expr()) + + lh = convert_relation(rel.relation(0)) + rh = convert_relation(rel.relation(1)) + if rel.LT(): + return sympy.StrictLessThan(lh, rh) + elif rel.LTE(): + return sympy.LessThan(lh, rh) + elif rel.GT(): + return sympy.StrictGreaterThan(lh, rh) + elif rel.GTE(): + return sympy.GreaterThan(lh, rh) + elif rel.EQUAL(): + return sympy.Eq(lh, rh) + elif rel.NEQ(): + return sympy.Ne(lh, rh) + + +def convert_expr(expr): + return convert_add(expr.additive()) + + +def convert_add(add): + if add.ADD(): + lh = convert_add(add.additive(0)) + rh = convert_add(add.additive(1)) + return sympy.Add(lh, rh, evaluate=False) + elif add.SUB(): + lh = convert_add(add.additive(0)) + rh = convert_add(add.additive(1)) + if hasattr(rh, "is_Atom") and rh.is_Atom: + return sympy.Add(lh, -1 * rh, evaluate=False) + return sympy.Add(lh, sympy.Mul(-1, rh, evaluate=False), evaluate=False) + else: + return convert_mp(add.mp()) + + +def convert_mp(mp): + if hasattr(mp, 'mp'): + mp_left = mp.mp(0) + mp_right = mp.mp(1) + else: + mp_left = mp.mp_nofunc(0) + mp_right = mp.mp_nofunc(1) + + if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + return sympy.Mul(lh, rh, evaluate=False) + elif mp.DIV() or mp.CMD_DIV() or mp.COLON(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False) + else: + if hasattr(mp, 'unary'): + return convert_unary(mp.unary()) + else: + return convert_unary(mp.unary_nofunc()) + + +def convert_unary(unary): + if hasattr(unary, 'unary'): + nested_unary = unary.unary() + else: + nested_unary = unary.unary_nofunc() + if hasattr(unary, 'postfix_nofunc'): + first = unary.postfix() + tail = unary.postfix_nofunc() + postfix = [first] + tail + else: + postfix = unary.postfix() + + if unary.ADD(): + return convert_unary(nested_unary) + elif unary.SUB(): + numabs = convert_unary(nested_unary) + # Use Integer(-n) instead of Mul(-1, n) + return -numabs + elif postfix: + return convert_postfix_list(postfix) + + +def convert_postfix_list(arr, i=0): + if i >= len(arr): + raise LaTeXParsingError("Index out of bounds") + + res = convert_postfix(arr[i]) + if isinstance(res, sympy.Expr): + if i == len(arr) - 1: + return res # nothing to multiply by + else: + if i > 0: + left = convert_postfix(arr[i - 1]) + right = convert_postfix(arr[i + 1]) + if isinstance(left, sympy.Expr) and isinstance( + right, sympy.Expr): + left_syms = convert_postfix(arr[i - 1]).atoms(sympy.Symbol) + right_syms = convert_postfix(arr[i + 1]).atoms( + sympy.Symbol) + # if the left and right sides contain no variables and the + # symbol in between is 'x', treat as multiplication. + if not (left_syms or right_syms) and str(res) == 'x': + return convert_postfix_list(arr, i + 1) + # multiply by next + return sympy.Mul( + res, convert_postfix_list(arr, i + 1), evaluate=False) + else: # must be derivative + wrt = res[0] + if i == len(arr) - 1: + raise LaTeXParsingError("Expected expression for derivative") + else: + expr = convert_postfix_list(arr, i + 1) + return sympy.Derivative(expr, wrt) + + +def do_subs(expr, at): + if at.expr(): + at_expr = convert_expr(at.expr()) + syms = at_expr.atoms(sympy.Symbol) + if len(syms) == 0: + return expr + elif len(syms) > 0: + sym = next(iter(syms)) + return expr.subs(sym, at_expr) + elif at.equality(): + lh = convert_expr(at.equality().expr(0)) + rh = convert_expr(at.equality().expr(1)) + return expr.subs(lh, rh) + + +def convert_postfix(postfix): + if hasattr(postfix, 'exp'): + exp_nested = postfix.exp() + else: + exp_nested = postfix.exp_nofunc() + + exp = convert_exp(exp_nested) + for op in postfix.postfix_op(): + if op.BANG(): + if isinstance(exp, list): + raise LaTeXParsingError("Cannot apply postfix to derivative") + exp = sympy.factorial(exp, evaluate=False) + elif op.eval_at(): + ev = op.eval_at() + at_b = None + at_a = None + if ev.eval_at_sup(): + at_b = do_subs(exp, ev.eval_at_sup()) + if ev.eval_at_sub(): + at_a = do_subs(exp, ev.eval_at_sub()) + if at_b is not None and at_a is not None: + exp = sympy.Add(at_b, -1 * at_a, evaluate=False) + elif at_b is not None: + exp = at_b + elif at_a is not None: + exp = at_a + + return exp + + +def convert_exp(exp): + if hasattr(exp, 'exp'): + exp_nested = exp.exp() + else: + exp_nested = exp.exp_nofunc() + + if exp_nested: + base = convert_exp(exp_nested) + if isinstance(base, list): + raise LaTeXParsingError("Cannot raise derivative to power") + if exp.atom(): + exponent = convert_atom(exp.atom()) + elif exp.expr(): + exponent = convert_expr(exp.expr()) + return sympy.Pow(base, exponent, evaluate=False) + else: + if hasattr(exp, 'comp'): + return convert_comp(exp.comp()) + else: + return convert_comp(exp.comp_nofunc()) + + +def convert_comp(comp): + if comp.group(): + return convert_expr(comp.group().expr()) + elif comp.abs_group(): + return sympy.Abs(convert_expr(comp.abs_group().expr()), evaluate=False) + elif comp.atom(): + return convert_atom(comp.atom()) + elif comp.floor(): + return convert_floor(comp.floor()) + elif comp.ceil(): + return convert_ceil(comp.ceil()) + elif comp.func(): + return convert_func(comp.func()) + + +def convert_atom(atom): + if atom.LETTER(): + sname = atom.LETTER().getText() + if atom.subexpr(): + if atom.subexpr().expr(): # subscript is expr + subscript = convert_expr(atom.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(atom.subexpr().atom()) + sname += '_{' + StrPrinter().doprint(subscript) + '}' + if atom.SINGLE_QUOTES(): + sname += atom.SINGLE_QUOTES().getText() # put after subscript for easy identify + return sympy.Symbol(sname) + elif atom.SYMBOL(): + s = atom.SYMBOL().getText()[1:] + if s == "infty": + return sympy.oo + else: + if atom.subexpr(): + subscript = None + if atom.subexpr().expr(): # subscript is expr + subscript = convert_expr(atom.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(atom.subexpr().atom()) + subscriptName = StrPrinter().doprint(subscript) + s += '_{' + subscriptName + '}' + return sympy.Symbol(s) + elif atom.number(): + s = atom.number().getText().replace(",", "") + return sympy.Number(s) + elif atom.DIFFERENTIAL(): + var = get_differential_var(atom.DIFFERENTIAL()) + return sympy.Symbol('d' + var.name) + elif atom.mathit(): + text = rule2text(atom.mathit().mathit_text()) + return sympy.Symbol(text) + elif atom.frac(): + return convert_frac(atom.frac()) + elif atom.binom(): + return convert_binom(atom.binom()) + elif atom.bra(): + val = convert_expr(atom.bra().expr()) + return Bra(val) + elif atom.ket(): + val = convert_expr(atom.ket().expr()) + return Ket(val) + + +def rule2text(ctx): + stream = ctx.start.getInputStream() + # starting index of starting token + startIdx = ctx.start.start + # stopping index of stopping token + stopIdx = ctx.stop.stop + + return stream.getText(startIdx, stopIdx) + + +def convert_frac(frac): + diff_op = False + partial_op = False + if frac.lower and frac.upper: + lower_itv = frac.lower.getSourceInterval() + lower_itv_len = lower_itv[1] - lower_itv[0] + 1 + if (frac.lower.start == frac.lower.stop + and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL): + wrt = get_differential_var_str(frac.lower.start.text) + diff_op = True + elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL + and frac.lower.start.text == '\\partial' + and (frac.lower.stop.type == LaTeXLexer.LETTER + or frac.lower.stop.type == LaTeXLexer.SYMBOL)): + partial_op = True + wrt = frac.lower.stop.text + if frac.lower.stop.type == LaTeXLexer.SYMBOL: + wrt = wrt[1:] + + if diff_op or partial_op: + wrt = sympy.Symbol(wrt) + if (diff_op and frac.upper.start == frac.upper.stop + and frac.upper.start.type == LaTeXLexer.LETTER + and frac.upper.start.text == 'd'): + return [wrt] + elif (partial_op and frac.upper.start == frac.upper.stop + and frac.upper.start.type == LaTeXLexer.SYMBOL + and frac.upper.start.text == '\\partial'): + return [wrt] + upper_text = rule2text(frac.upper) + + expr_top = None + if diff_op and upper_text.startswith('d'): + expr_top = parse_latex(upper_text[1:]) + elif partial_op and frac.upper.start.text == '\\partial': + expr_top = parse_latex(upper_text[len('\\partial'):]) + if expr_top: + return sympy.Derivative(expr_top, wrt) + if frac.upper: + expr_top = convert_expr(frac.upper) + else: + expr_top = sympy.Number(frac.upperd.text) + if frac.lower: + expr_bot = convert_expr(frac.lower) + else: + expr_bot = sympy.Number(frac.lowerd.text) + inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False) + if expr_top == 1: + return inverse_denom + else: + return sympy.Mul(expr_top, inverse_denom, evaluate=False) + +def convert_binom(binom): + expr_n = convert_expr(binom.n) + expr_k = convert_expr(binom.k) + return sympy.binomial(expr_n, expr_k, evaluate=False) + +def convert_floor(floor): + val = convert_expr(floor.val) + return sympy.floor(val, evaluate=False) + +def convert_ceil(ceil): + val = convert_expr(ceil.val) + return sympy.ceiling(val, evaluate=False) + +def convert_func(func): + if func.func_normal(): + if func.L_PAREN(): # function called with parenthesis + arg = convert_func_arg(func.func_arg()) + else: + arg = convert_func_arg(func.func_arg_noparens()) + + name = func.func_normal().start.text[1:] + + # change arc -> a + if name in [ + "arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot" + ]: + name = "a" + name[3:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + if name in ["arsinh", "arcosh", "artanh"]: + name = "a" + name[2:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + + if name == "exp": + expr = sympy.exp(arg, evaluate=False) + + if name in ("log", "lg", "ln"): + if func.subexpr(): + if func.subexpr().expr(): + base = convert_expr(func.subexpr().expr()) + else: + base = convert_atom(func.subexpr().atom()) + elif name == "lg": # ISO 80000-2:2019 + base = 10 + elif name in ("ln", "log"): # SymPy's latex printer prints ln as log by default + base = sympy.E + expr = sympy.log(arg, base, evaluate=False) + + func_pow = None + should_pow = True + if func.supexpr(): + if func.supexpr().expr(): + func_pow = convert_expr(func.supexpr().expr()) + else: + func_pow = convert_atom(func.supexpr().atom()) + + if name in [ + "sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", + "tanh" + ]: + if func_pow == -1: + name = "a" + name + should_pow = False + expr = getattr(sympy.functions, name)(arg, evaluate=False) + + if func_pow and should_pow: + expr = sympy.Pow(expr, func_pow, evaluate=False) + + return expr + elif func.LETTER() or func.SYMBOL(): + if func.LETTER(): + fname = func.LETTER().getText() + elif func.SYMBOL(): + fname = func.SYMBOL().getText()[1:] + fname = str(fname) # can't be unicode + if func.subexpr(): + if func.subexpr().expr(): # subscript is expr + subscript = convert_expr(func.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(func.subexpr().atom()) + subscriptName = StrPrinter().doprint(subscript) + fname += '_{' + subscriptName + '}' + if func.SINGLE_QUOTES(): + fname += func.SINGLE_QUOTES().getText() + input_args = func.args() + output_args = [] + while input_args.args(): # handle multiple arguments to function + output_args.append(convert_expr(input_args.expr())) + input_args = input_args.args() + output_args.append(convert_expr(input_args.expr())) + return sympy.Function(fname)(*output_args) + elif func.FUNC_INT(): + return handle_integral(func) + elif func.FUNC_SQRT(): + expr = convert_expr(func.base) + if func.root: + r = convert_expr(func.root) + return sympy.root(expr, r, evaluate=False) + else: + return sympy.sqrt(expr, evaluate=False) + elif func.FUNC_OVERLINE(): + expr = convert_expr(func.base) + return sympy.conjugate(expr, evaluate=False) + elif func.FUNC_SUM(): + return handle_sum_or_prod(func, "summation") + elif func.FUNC_PROD(): + return handle_sum_or_prod(func, "product") + elif func.FUNC_LIM(): + return handle_limit(func) + + +def convert_func_arg(arg): + if hasattr(arg, 'expr'): + return convert_expr(arg.expr()) + else: + return convert_mp(arg.mp_nofunc()) + + +def handle_integral(func): + if func.additive(): + integrand = convert_add(func.additive()) + elif func.frac(): + integrand = convert_frac(func.frac()) + else: + integrand = 1 + + int_var = None + if func.DIFFERENTIAL(): + int_var = get_differential_var(func.DIFFERENTIAL()) + else: + for sym in integrand.atoms(sympy.Symbol): + s = str(sym) + if len(s) > 1 and s[0] == 'd': + if s[1] == '\\': + int_var = sympy.Symbol(s[2:]) + else: + int_var = sympy.Symbol(s[1:]) + int_sym = sym + if int_var: + integrand = integrand.subs(int_sym, 1) + else: + # Assume dx by default + int_var = sympy.Symbol('x') + + if func.subexpr(): + if func.subexpr().atom(): + lower = convert_atom(func.subexpr().atom()) + else: + lower = convert_expr(func.subexpr().expr()) + if func.supexpr().atom(): + upper = convert_atom(func.supexpr().atom()) + else: + upper = convert_expr(func.supexpr().expr()) + return sympy.Integral(integrand, (int_var, lower, upper)) + else: + return sympy.Integral(integrand, int_var) + + +def handle_sum_or_prod(func, name): + val = convert_mp(func.mp()) + iter_var = convert_expr(func.subeq().equality().expr(0)) + start = convert_expr(func.subeq().equality().expr(1)) + if func.supexpr().expr(): # ^{expr} + end = convert_expr(func.supexpr().expr()) + else: # ^atom + end = convert_atom(func.supexpr().atom()) + + if name == "summation": + return sympy.Sum(val, (iter_var, start, end)) + elif name == "product": + return sympy.Product(val, (iter_var, start, end)) + + +def handle_limit(func): + sub = func.limit_sub() + if sub.LETTER(): + var = sympy.Symbol(sub.LETTER().getText()) + elif sub.SYMBOL(): + var = sympy.Symbol(sub.SYMBOL().getText()[1:]) + else: + var = sympy.Symbol('x') + if sub.SUB(): + direction = "-" + elif sub.ADD(): + direction = "+" + else: + direction = "+-" + approaching = convert_expr(sub.expr()) + content = convert_mp(func.mp()) + + return sympy.Limit(content, var, approaching, direction) + + +def get_differential_var(d): + text = get_differential_var_str(d.getText()) + return sympy.Symbol(text) + + +def get_differential_var_str(text): + for i in range(1, len(text)): + c = text[i] + if not (c == " " or c == "\r" or c == "\n" or c == "\t"): + idx = i + break + text = text[idx:] + if text[0] == "\\": + text = text[1:] + return text diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/errors.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..d8c3ef9f06279df42d4b2054acc4cfe39b6682a5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/errors.py @@ -0,0 +1,2 @@ +class LaTeXParsingError(Exception): + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92e58d3172e100cc376d0b416b3835d164bd5647 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__init__.py @@ -0,0 +1,2 @@ +from .latex_parser import parse_latex_lark, LarkLaTeXParser # noqa +from .transformer import TransformToSymPyExpr # noqa diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82c8d6119607011734bca61c0505e6df59b34322 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/latex_parser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/latex_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df981b3fc7e01820d1487947743235999b40d2f9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/latex_parser.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/transformer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61ab2e90c508168598e26c66affa334da4794cdf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/__pycache__/transformer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark new file mode 100644 index 0000000000000000000000000000000000000000..7439fab9dcac284dc3c9b5fbfa4fc6db8b29dfd2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark @@ -0,0 +1,28 @@ +// Greek symbols +// TODO: Shouold we include the uppercase variants for the symbols where the uppercase variant doesn't have a separate meaning? +ALPHA: "\\alpha" +BETA: "\\beta" +GAMMA: "\\gamma" +DELTA: "\\delta" // TODO: Should this be included? Delta usually denotes other things. +EPSILON: "\\epsilon" | "\\varepsilon" +ZETA: "\\zeta" +ETA: "\\eta" +THETA: "\\theta" | "\\vartheta" +// TODO: Should I add iota to the list? +KAPPA: "\\kappa" +LAMBDA: "\\lambda" // TODO: What about the uppercase variant? +MU: "\\mu" +NU: "\\nu" +XI: "\\xi" +// TODO: Should there be a separate note for transforming \pi into sympy.pi? +RHO: "\\rho" | "\\varrho" +// TODO: What should we do about sigma? +TAU: "\\tau" +UPSILON: "\\upsilon" +PHI: "\\phi" | "\\varphi" +CHI: "\\chi" +PSI: "\\psi" +OMEGA: "\\omega" + +GREEK_SYMBOL: ALPHA | BETA | GAMMA | DELTA | EPSILON | ZETA | ETA | THETA | KAPPA + | LAMBDA | MU | NU | XI | RHO | TAU | UPSILON | PHI | CHI | PSI | OMEGA diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/grammar/latex.lark b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/grammar/latex.lark new file mode 100644 index 0000000000000000000000000000000000000000..43e8d0e9105fa4da9bcdd2c0fa6111f6d523c9a9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/grammar/latex.lark @@ -0,0 +1,403 @@ +%ignore /[ \t\n\r]+/ + +%ignore "\\," | "\\thinspace" | "\\:" | "\\medspace" | "\\;" | "\\thickspace" +%ignore "\\quad" | "\\qquad" +%ignore "\\!" | "\\negthinspace" | "\\negmedspace" | "\\negthickspace" +%ignore "\\vrule" | "\\vcenter" | "\\vbox" | "\\vskip" | "\\vspace" | "\\hfill" +%ignore "\\*" | "\\-" | "\\." | "\\/" | "\\(" | "\\=" + +%ignore "\\left" | "\\right" +%ignore "\\limits" | "\\nolimits" +%ignore "\\displaystyle" + +///////////////////// tokens /////////////////////// + +// basic binary operators +ADD: "+" +SUB: "-" +MUL: "*" +DIV: "/" + +// tokens with distinct left and right symbols +L_BRACE: "{" +R_BRACE: "}" +L_BRACE_LITERAL: "\\{" +R_BRACE_LITERAL: "\\}" +L_BRACKET: "[" +R_BRACKET: "]" +L_CEIL: "\\lceil" +R_CEIL: "\\rceil" +L_FLOOR: "\\lfloor" +R_FLOOR: "\\rfloor" +L_PAREN: "(" +R_PAREN: ")" + +// limit, integral, sum, and product symbols +FUNC_LIM: "\\lim" +LIM_APPROACH_SYM: "\\to" | "\\rightarrow" | "\\Rightarrow" | "\\longrightarrow" | "\\Longrightarrow" +FUNC_INT: "\\int" | "\\intop" +FUNC_SUM: "\\sum" +FUNC_PROD: "\\prod" + +// common functions +FUNC_EXP: "\\exp" +FUNC_LOG: "\\log" +FUNC_LN: "\\ln" +FUNC_LG: "\\lg" +FUNC_MIN: "\\min" +FUNC_MAX: "\\max" + +// trigonometric functions +FUNC_SIN: "\\sin" +FUNC_COS: "\\cos" +FUNC_TAN: "\\tan" +FUNC_CSC: "\\csc" +FUNC_SEC: "\\sec" +FUNC_COT: "\\cot" + +// inverse trigonometric functions +FUNC_ARCSIN: "\\arcsin" +FUNC_ARCCOS: "\\arccos" +FUNC_ARCTAN: "\\arctan" +FUNC_ARCCSC: "\\arccsc" +FUNC_ARCSEC: "\\arcsec" +FUNC_ARCCOT: "\\arccot" + +// hyperbolic trigonometric functions +FUNC_SINH: "\\sinh" +FUNC_COSH: "\\cosh" +FUNC_TANH: "\\tanh" +FUNC_ARSINH: "\\arsinh" +FUNC_ARCOSH: "\\arcosh" +FUNC_ARTANH: "\\artanh" + +FUNC_SQRT: "\\sqrt" + +// miscellaneous symbols +CMD_TIMES: "\\times" +CMD_CDOT: "\\cdot" +CMD_DIV: "\\div" +CMD_FRAC: "\\frac" | "\\dfrac" | "\\tfrac" | "\\nicefrac" +CMD_BINOM: "\\binom" | "\\dbinom" | "\\tbinom" +CMD_OVERLINE: "\\overline" +CMD_LANGLE: "\\langle" +CMD_RANGLE: "\\rangle" + +CMD_MATHIT: "\\mathit" + +CMD_INFTY: "\\infty" + +BANG: "!" +BAR: "|" +CARET: "^" +COLON: ":" +UNDERSCORE: "_" + +// relational symbols +EQUAL: "=" +NOT_EQUAL: "\\neq" | "\\ne" +LT: "<" +LTE: "\\leq" | "\\le" | "\\leqslant" +GT: ">" +GTE: "\\geq" | "\\ge" | "\\geqslant" + +DIV_SYMBOL: CMD_DIV | DIV +MUL_SYMBOL: MUL | CMD_TIMES | CMD_CDOT + +%import .greek_symbols.GREEK_SYMBOL + +UPRIGHT_DIFFERENTIAL_SYMBOL: "\\text{d}" | "\\mathrm{d}" +DIFFERENTIAL_SYMBOL: "d" | UPRIGHT_DIFFERENTIAL_SYMBOL + +// disallow "d" as a variable name because we want to parse "d" as a differential symbol. +SYMBOL: /[a-zA-Z]'*/ +GREEK_SYMBOL_WITH_PRIMES: GREEK_SYMBOL "'"* +LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT: /([a-zA-Z]'*)_(([A-Za-z0-9]|[a-zA-Z]+)|\{([A-Za-z0-9]|[a-zA-Z]+'*)\})/ +LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT: /([a-zA-Z]'*)_/ GREEK_SYMBOL | /([a-zA-Z]'*)_/ L_BRACE GREEK_SYMBOL_WITH_PRIMES R_BRACE +// best to define the variant with braces like that instead of shoving it all into one case like in +// /([a-zA-Z])_/ L_BRACE? GREEK_SYMBOL R_BRACE? because then we can easily error out on input like +// r"h_{\theta" +GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT: GREEK_SYMBOL_WITH_PRIMES /_(([A-Za-z0-9]|[a-zA-Z]+)|\{([A-Za-z0-9]|[a-zA-Z]+'*)\})/ +GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT: GREEK_SYMBOL_WITH_PRIMES /_/ (GREEK_SYMBOL | L_BRACE GREEK_SYMBOL_WITH_PRIMES R_BRACE) +MULTI_LETTER_SYMBOL: /[a-zA-Z]+(\s+[a-zA-Z]+)*'*/ + +%import common.DIGIT -> DIGIT + +CMD_PRIME: "\\prime" +CMD_ASTERISK: "\\ast" + +PRIMES: "'"+ +STARS: "*"+ +PRIMES_VIA_CMD: CMD_PRIME+ +STARS_VIA_CMD: CMD_ASTERISK+ + +CMD_IMAGINARY_UNIT: "\\imaginaryunit" + +CMD_BEGIN: "\\begin" +CMD_END: "\\end" + +// matrices +IGNORE_L: /[ \t\n\r]*/ L_BRACE* /[ \t\n\r]*/ +IGNORE_R: /[ \t\n\r]*/ R_BRACE* /[ \t\n\r]*/ +ARRAY_MATRIX_BEGIN: L_BRACE "array" R_BRACE L_BRACE /[^}]*/ R_BRACE +ARRAY_MATRIX_END: L_BRACE "array" R_BRACE +AMSMATH_MATRIX: L_BRACE "matrix" R_BRACE +AMSMATH_PMATRIX: L_BRACE "pmatrix" R_BRACE +AMSMATH_BMATRIX: L_BRACE "bmatrix" R_BRACE +// Without the (L|R)_PARENs and (L|R)_BRACKETs, a matrix defined using +// \begin{array}...\end{array} or \begin{matrix}...\end{matrix} must +// not qualify as a complete matrix expression; this is done so that +// if we have \begin{array}...\end{array} or \begin{matrix}...\end{matrix} +// between BAR pairs, then they should be interpreted as determinants as +// opposed to sympy.Abs (absolute value) applied to a matrix. +CMD_BEGIN_AMSPMATRIX_AMSBMATRIX: CMD_BEGIN (AMSMATH_PMATRIX | AMSMATH_BMATRIX) +CMD_BEGIN_ARRAY_AMSMATRIX: (L_PAREN | L_BRACKET) IGNORE_L CMD_BEGIN (ARRAY_MATRIX_BEGIN | AMSMATH_MATRIX) +CMD_MATRIX_BEGIN: CMD_BEGIN_AMSPMATRIX_AMSBMATRIX | CMD_BEGIN_ARRAY_AMSMATRIX +CMD_END_AMSPMATRIX_AMSBMATRIX: CMD_END (AMSMATH_PMATRIX | AMSMATH_BMATRIX) +CMD_END_ARRAY_AMSMATRIX: CMD_END (ARRAY_MATRIX_END | AMSMATH_MATRIX) IGNORE_R "\\right"? (R_PAREN | R_BRACKET) +CMD_MATRIX_END: CMD_END_AMSPMATRIX_AMSBMATRIX | CMD_END_ARRAY_AMSMATRIX +MATRIX_COL_DELIM: "&" +MATRIX_ROW_DELIM: "\\\\" +FUNC_MATRIX_TRACE: "\\trace" +FUNC_MATRIX_ADJUGATE: "\\adjugate" + +// determinants +AMSMATH_VMATRIX: L_BRACE "vmatrix" R_BRACE +CMD_DETERMINANT_BEGIN_SIMPLE: CMD_BEGIN AMSMATH_VMATRIX +CMD_DETERMINANT_BEGIN_VARIANT: BAR IGNORE_L CMD_BEGIN (ARRAY_MATRIX_BEGIN | AMSMATH_MATRIX) +CMD_DETERMINANT_BEGIN: CMD_DETERMINANT_BEGIN_SIMPLE | CMD_DETERMINANT_BEGIN_VARIANT +CMD_DETERMINANT_END_SIMPLE: CMD_END AMSMATH_VMATRIX +CMD_DETERMINANT_END_VARIANT: CMD_END (ARRAY_MATRIX_END | AMSMATH_MATRIX) IGNORE_R "\\right"? BAR +CMD_DETERMINANT_END: CMD_DETERMINANT_END_SIMPLE | CMD_DETERMINANT_END_VARIANT +FUNC_DETERMINANT: "\\det" + +//////////////////// grammar ////////////////////// + +latex_string: _relation | _expression + +_one_letter_symbol: SYMBOL + | LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT + | LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT + | GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_PRIMES +// LuaTeX-generated outputs of \mathit{foo'} and \mathit{foo}' +// seem to be the same on the surface. We allow both styles. +multi_letter_symbol: CMD_MATHIT L_BRACE MULTI_LETTER_SYMBOL R_BRACE + | CMD_MATHIT L_BRACE MULTI_LETTER_SYMBOL R_BRACE /'+/ +number: /\d+(\.\d*)?/ | CMD_IMAGINARY_UNIT + +_atomic_expr: _one_letter_symbol + | multi_letter_symbol + | number + | CMD_INFTY + +group_round_parentheses: L_PAREN _expression R_PAREN +group_square_brackets: L_BRACKET _expression R_BRACKET +group_curly_parentheses: L_BRACE _expression R_BRACE + +_relation: eq | ne | lt | lte | gt | gte + +eq: _expression EQUAL _expression +ne: _expression NOT_EQUAL _expression +lt: _expression LT _expression +lte: _expression LTE _expression +gt: _expression GT _expression +gte: _expression GTE _expression + +_expression_core: _atomic_expr | group_curly_parentheses + +add: _expression ADD _expression_mul + | ADD _expression_mul +sub: _expression SUB _expression_mul + | SUB _expression_mul +mul: _expression_mul MUL_SYMBOL _expression_power +div: _expression_mul DIV_SYMBOL _expression_power + +adjacent_expressions: (_one_letter_symbol | number) _expression_mul + | group_round_parentheses (group_round_parentheses | _one_letter_symbol) + | _function _function + | fraction _expression_mul + +_expression_func: _expression_core + | group_round_parentheses + | fraction + | binomial + | _function + | _integral// | derivative + | limit + | matrix + +_expression_power: _expression_func | superscript | matrix_prime | symbol_prime + +_expression_mul: _expression_power + | mul | div | adjacent_expressions + | summation | product + +_expression: _expression_mul | add | sub + +_limit_dir: "+" | "-" | L_BRACE ("+" | "-") R_BRACE + +limit_dir_expr: _expression CARET _limit_dir + +group_curly_parentheses_lim: L_BRACE _expression LIM_APPROACH_SYM (limit_dir_expr | _expression) R_BRACE + +limit: FUNC_LIM UNDERSCORE group_curly_parentheses_lim _expression + +differential: DIFFERENTIAL_SYMBOL _one_letter_symbol + +//_derivative_operator: CMD_FRAC L_BRACE DIFFERENTIAL_SYMBOL R_BRACE L_BRACE differential R_BRACE + +//derivative: _derivative_operator _expression + +_integral: normal_integral | integral_with_special_fraction + +normal_integral: FUNC_INT _expression DIFFERENTIAL_SYMBOL _one_letter_symbol + | FUNC_INT (CARET _expression_core UNDERSCORE _expression_core)? _expression? DIFFERENTIAL_SYMBOL _one_letter_symbol + | FUNC_INT (UNDERSCORE _expression_core CARET _expression_core)? _expression? DIFFERENTIAL_SYMBOL _one_letter_symbol + +group_curly_parentheses_int: L_BRACE _expression? differential R_BRACE + +special_fraction: CMD_FRAC group_curly_parentheses_int group_curly_parentheses + +integral_with_special_fraction: FUNC_INT special_fraction + | FUNC_INT (CARET _expression_core UNDERSCORE _expression_core)? special_fraction + | FUNC_INT (UNDERSCORE _expression_core CARET _expression_core)? special_fraction + +group_curly_parentheses_special: UNDERSCORE L_BRACE _atomic_expr EQUAL _atomic_expr R_BRACE CARET _expression_core + | CARET _expression_core UNDERSCORE L_BRACE _atomic_expr EQUAL _atomic_expr R_BRACE + +summation: FUNC_SUM group_curly_parentheses_special _expression + | FUNC_SUM group_curly_parentheses_special _expression + +product: FUNC_PROD group_curly_parentheses_special _expression + | FUNC_PROD group_curly_parentheses_special _expression + +superscript: _expression_func CARET (_expression_power | CMD_PRIME | CMD_ASTERISK) + | _expression_func CARET L_BRACE (PRIMES | STARS | PRIMES_VIA_CMD | STARS_VIA_CMD) R_BRACE + +matrix_prime: (matrix | group_round_parentheses) PRIMES + +symbol_prime: (LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT + | LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT + | GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT) PRIMES + +fraction: _basic_fraction + | _simple_fraction + | _general_fraction + +_basic_fraction: CMD_FRAC DIGIT (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_simple_fraction: CMD_FRAC DIGIT group_curly_parentheses + | CMD_FRAC group_curly_parentheses (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_general_fraction: CMD_FRAC group_curly_parentheses group_curly_parentheses + +binomial: _basic_binomial + | _simple_binomial + | _general_binomial + +_basic_binomial: CMD_BINOM DIGIT (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_simple_binomial: CMD_BINOM DIGIT group_curly_parentheses + | CMD_BINOM group_curly_parentheses (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_general_binomial: CMD_BINOM group_curly_parentheses group_curly_parentheses + +list_of_expressions: _expression ("," _expression)* + +function_applied: _one_letter_symbol L_PAREN list_of_expressions R_PAREN + +min: FUNC_MIN L_PAREN list_of_expressions R_PAREN + +max: FUNC_MAX L_PAREN list_of_expressions R_PAREN + +bra: CMD_LANGLE _expression BAR + +ket: BAR _expression CMD_RANGLE + +inner_product: CMD_LANGLE _expression BAR _expression CMD_RANGLE + +_function: function_applied + | abs | floor | ceil + | _trigonometric_function | _inverse_trigonometric_function + | _trigonometric_function_power + | _hyperbolic_trigonometric_function | _inverse_hyperbolic_trigonometric_function + | exponential + | log + | square_root + | factorial + | conjugate + | max | min + | bra | ket | inner_product + | determinant + | trace + | adjugate + +exponential: FUNC_EXP _expression + +log: FUNC_LOG _expression + | FUNC_LN _expression + | FUNC_LG _expression + | FUNC_LOG UNDERSCORE (DIGIT | _one_letter_symbol) _expression + | FUNC_LOG UNDERSCORE group_curly_parentheses _expression + +square_root: FUNC_SQRT group_curly_parentheses + | FUNC_SQRT group_square_brackets group_curly_parentheses + +factorial: _expression_func BANG + +conjugate: CMD_OVERLINE group_curly_parentheses + | CMD_OVERLINE DIGIT + +_trigonometric_function: sin | cos | tan | csc | sec | cot + +sin: FUNC_SIN _expression +cos: FUNC_COS _expression +tan: FUNC_TAN _expression +csc: FUNC_CSC _expression +sec: FUNC_SEC _expression +cot: FUNC_COT _expression + +_trigonometric_function_power: sin_power | cos_power | tan_power | csc_power | sec_power | cot_power + +sin_power: FUNC_SIN CARET _expression_core _expression +cos_power: FUNC_COS CARET _expression_core _expression +tan_power: FUNC_TAN CARET _expression_core _expression +csc_power: FUNC_CSC CARET _expression_core _expression +sec_power: FUNC_SEC CARET _expression_core _expression +cot_power: FUNC_COT CARET _expression_core _expression + +_hyperbolic_trigonometric_function: sinh | cosh | tanh + +sinh: FUNC_SINH _expression +cosh: FUNC_COSH _expression +tanh: FUNC_TANH _expression + +_inverse_trigonometric_function: arcsin | arccos | arctan | arccsc | arcsec | arccot + +arcsin: FUNC_ARCSIN _expression +arccos: FUNC_ARCCOS _expression +arctan: FUNC_ARCTAN _expression +arccsc: FUNC_ARCCSC _expression +arcsec: FUNC_ARCSEC _expression +arccot: FUNC_ARCCOT _expression + +_inverse_hyperbolic_trigonometric_function: asinh | acosh | atanh + +asinh: FUNC_ARSINH _expression +acosh: FUNC_ARCOSH _expression +atanh: FUNC_ARTANH _expression + +abs: BAR _expression BAR +floor: L_FLOOR _expression R_FLOOR +ceil: L_CEIL _expression R_CEIL + +matrix: CMD_MATRIX_BEGIN matrix_body CMD_MATRIX_END +matrix_body: matrix_row (MATRIX_ROW_DELIM matrix_row)* (MATRIX_ROW_DELIM)? +matrix_row: _expression (MATRIX_COL_DELIM _expression)* +determinant: (CMD_DETERMINANT_BEGIN matrix_body CMD_DETERMINANT_END) + | FUNC_DETERMINANT _expression +trace: FUNC_MATRIX_TRACE _expression +adjugate: FUNC_MATRIX_ADJUGATE _expression diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/latex_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/latex_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..29f594b0de4bfd4648df1554d5863a37afff035f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/latex_parser.py @@ -0,0 +1,145 @@ +import os +import logging +import re +from pathlib import Path + +from sympy.external import import_module +from sympy.parsing.latex.lark.transformer import TransformToSymPyExpr + +_lark = import_module("lark") + + +class LarkLaTeXParser: + r"""Class for converting input `\mathrm{\LaTeX}` strings into SymPy Expressions. + It holds all the necessary internal data for doing so, and exposes hooks for + customizing its behavior. + + Parameters + ========== + + print_debug_output : bool, optional + + If set to ``True``, prints debug output to the logger. Defaults to ``False``. + + transform : bool, optional + + If set to ``True``, the class runs the Transformer class on the parse tree + generated by running ``Lark.parse`` on the input string. Defaults to ``True``. + + Setting it to ``False`` can help with debugging the `\mathrm{\LaTeX}` grammar. + + grammar_file : str, optional + + The path to the grammar file that the parser should use. If set to ``None``, + it uses the default grammar, which is in ``grammar/latex.lark``, relative to + the ``sympy/parsing/latex/lark/`` directory. + + transformer : str, optional + + The name of the Transformer class to use. If set to ``None``, it uses the + default transformer class, which is :py:func:`TransformToSymPyExpr`. + + """ + def __init__(self, print_debug_output=False, transform=True, grammar_file=None, transformer=None): + grammar_dir_path = os.path.join(os.path.dirname(__file__), "grammar/") + + if grammar_file is None: + latex_grammar = Path(os.path.join(grammar_dir_path, "latex.lark")).read_text(encoding="utf-8") + else: + latex_grammar = Path(grammar_file).read_text(encoding="utf-8") + + self.parser = _lark.Lark( + latex_grammar, + source_path=grammar_dir_path, + parser="earley", + start="latex_string", + lexer="auto", + ambiguity="explicit", + propagate_positions=False, + maybe_placeholders=False, + keep_all_tokens=True) + + self.print_debug_output = print_debug_output + self.transform_expr = transform + + if transformer is None: + self.transformer = TransformToSymPyExpr() + else: + self.transformer = transformer() + + def doparse(self, s: str): + if self.print_debug_output: + _lark.logger.setLevel(logging.DEBUG) + + parse_tree = self.parser.parse(s) + + if not self.transform_expr: + # exit early and return the parse tree + _lark.logger.debug("expression = %s", s) + _lark.logger.debug(parse_tree) + _lark.logger.debug(parse_tree.pretty()) + return parse_tree + + if self.print_debug_output: + # print this stuff before attempting to run the transformer + _lark.logger.debug("expression = %s", s) + # print the `parse_tree` variable + _lark.logger.debug(parse_tree.pretty()) + + sympy_expression = self.transformer.transform(parse_tree) + + if self.print_debug_output: + _lark.logger.debug("SymPy expression = %s", sympy_expression) + + return sympy_expression + + +if _lark is not None: + _lark_latex_parser = LarkLaTeXParser() + + +def parse_latex_lark(s: str): + """ + Experimental LaTeX parser using Lark. + + This function is still under development and its API may change with the + next releases of SymPy. + """ + if _lark is None: + raise ImportError("Lark is probably not installed") + return _lark_latex_parser.doparse(s) + + +def _pretty_print_lark_trees(tree, indent=0, show_expr=True): + if isinstance(tree, _lark.Token): + return tree.value + + data = str(tree.data) + + is_expr = data.startswith("expression") + + if is_expr: + data = re.sub(r"^expression", "E", data) + + is_ambig = (data == "_ambig") + + if is_ambig: + new_indent = indent + 2 + else: + new_indent = indent + + output = "" + show_node = not is_expr or show_expr + + if show_node: + output += str(data) + "(" + + if is_ambig: + output += "\n" + "\n".join([" " * new_indent + _pretty_print_lark_trees(i, new_indent, show_expr) for i in tree.children]) + else: + output += ",".join([_pretty_print_lark_trees(i, new_indent, show_expr) for i in tree.children]) + + if show_node: + output += ")" + + return output diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/transformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd514b6517336207a57de6d28bcce25858071dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/latex/lark/transformer.py @@ -0,0 +1,730 @@ +import re + +import sympy +from sympy.external import import_module +from sympy.parsing.latex.errors import LaTeXParsingError + +lark = import_module("lark") + +if lark: + from lark import Transformer, Token, Tree # type: ignore +else: + class Transformer: # type: ignore + def transform(self, *args): + pass + + + class Token: # type: ignore + pass + + + class Tree: # type: ignore + pass + + +# noinspection PyPep8Naming,PyMethodMayBeStatic +class TransformToSymPyExpr(Transformer): + """Returns a SymPy expression that is generated by traversing the ``lark.Tree`` + passed to the ``.transform()`` function. + + Notes + ===== + + **This class is never supposed to be used directly.** + + In order to tweak the behavior of this class, it has to be subclassed and then after + the required modifications are made, the name of the new class should be passed to + the :py:class:`LarkLaTeXParser` class by using the ``transformer`` argument in the + constructor. + + Parameters + ========== + + visit_tokens : bool, optional + For information about what this option does, see `here + `_. + + Note that the option must be set to ``True`` for the default parser to work. + """ + + SYMBOL = sympy.Symbol + DIGIT = sympy.core.numbers.Integer + + def CMD_INFTY(self, tokens): + return sympy.oo + + def GREEK_SYMBOL_WITH_PRIMES(self, tokens): + # we omit the first character because it is a backslash. Also, if the variable name has "var" in it, + # like "varphi" or "varepsilon", we remove that too + variable_name = re.sub("var", "", tokens[1:]) + + return sympy.Symbol(variable_name) + + def LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + if sub.startswith("{"): + return sympy.Symbol("%s_{%s}" % (base, sub[1:-1])) + else: + return sympy.Symbol("%s_{%s}" % (base, sub)) + + def GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + greek_letter = re.sub("var", "", base[1:]) + + if sub.startswith("{"): + return sympy.Symbol("%s_{%s}" % (greek_letter, sub[1:-1])) + else: + return sympy.Symbol("%s_{%s}" % (greek_letter, sub)) + + def LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + if sub.startswith("{"): + greek_letter = sub[2:-1] + else: + greek_letter = sub[1:] + + greek_letter = re.sub("var", "", greek_letter) + return sympy.Symbol("%s_{%s}" % (base, greek_letter)) + + + def GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + greek_base = re.sub("var", "", base[1:]) + + if sub.startswith("{"): + greek_sub = sub[2:-1] + else: + greek_sub = sub[1:] + + greek_sub = re.sub("var", "", greek_sub) + return sympy.Symbol("%s_{%s}" % (greek_base, greek_sub)) + + def multi_letter_symbol(self, tokens): + if len(tokens) == 4: # no primes (single quotes) on symbol + return sympy.Symbol(tokens[2]) + if len(tokens) == 5: # there are primes on the symbol + return sympy.Symbol(tokens[2] + tokens[4]) + + def number(self, tokens): + if tokens[0].type == "CMD_IMAGINARY_UNIT": + return sympy.I + + if "." in tokens[0]: + return sympy.core.numbers.Float(tokens[0]) + else: + return sympy.core.numbers.Integer(tokens[0]) + + def latex_string(self, tokens): + return tokens[0] + + def group_round_parentheses(self, tokens): + return tokens[1] + + def group_square_brackets(self, tokens): + return tokens[1] + + def group_curly_parentheses(self, tokens): + return tokens[1] + + def eq(self, tokens): + return sympy.Eq(tokens[0], tokens[2]) + + def ne(self, tokens): + return sympy.Ne(tokens[0], tokens[2]) + + def lt(self, tokens): + return sympy.Lt(tokens[0], tokens[2]) + + def lte(self, tokens): + return sympy.Le(tokens[0], tokens[2]) + + def gt(self, tokens): + return sympy.Gt(tokens[0], tokens[2]) + + def gte(self, tokens): + return sympy.Ge(tokens[0], tokens[2]) + + def add(self, tokens): + if len(tokens) == 2: # +a + return tokens[1] + if len(tokens) == 3: # a + b + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatAdd(lh, rh) + + return sympy.Add(lh, rh) + + def sub(self, tokens): + if len(tokens) == 2: # -a + x = tokens[1] + + if self._obj_is_sympy_Matrix(x): + return sympy.MatMul(-1, x) + + return -x + if len(tokens) == 3: # a - b + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatAdd(lh, sympy.MatMul(-1, rh)) + + return sympy.Add(lh, -rh) + + def mul(self, tokens): + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatMul(lh, rh) + + return sympy.Mul(lh, rh) + + def div(self, tokens): + return self._handle_division(tokens[0], tokens[2]) + + def adjacent_expressions(self, tokens): + # Most of the time, if two expressions are next to each other, it means implicit multiplication, + # but not always + from sympy.physics.quantum import Bra, Ket + if isinstance(tokens[0], Ket) and isinstance(tokens[1], Bra): + from sympy.physics.quantum import OuterProduct + return OuterProduct(tokens[0], tokens[1]) + elif tokens[0] == sympy.Symbol("d"): + # If the leftmost token is a "d", then it is highly likely that this is a differential + return tokens[0], tokens[1] + elif isinstance(tokens[0], tuple): + # then we have a derivative + return sympy.Derivative(tokens[1], tokens[0][1]) + else: + return sympy.Mul(tokens[0], tokens[1]) + + def superscript(self, tokens): + def isprime(x): + return isinstance(x, Token) and x.type == "PRIMES" + + def iscmdprime(x): + return isinstance(x, Token) and (x.type == "PRIMES_VIA_CMD" + or x.type == "CMD_PRIME") + + def isstar(x): + return isinstance(x, Token) and x.type == "STARS" + + def iscmdstar(x): + return isinstance(x, Token) and (x.type == "STARS_VIA_CMD" + or x.type == "CMD_ASTERISK") + + base = tokens[0] + if len(tokens) == 3: # a^b OR a^\prime OR a^\ast + sup = tokens[2] + if len(tokens) == 5: + # a^{'}, a^{''}, ... OR + # a^{*}, a^{**}, ... OR + # a^{\prime}, a^{\prime\prime}, ... OR + # a^{\ast}, a^{\ast\ast}, ... + sup = tokens[3] + + if self._obj_is_sympy_Matrix(base): + if sup == sympy.Symbol("T"): + return sympy.Transpose(base) + if sup == sympy.Symbol("H"): + return sympy.adjoint(base) + if isprime(sup): + sup = sup.value + if len(sup) % 2 == 0: + return base + return sympy.Transpose(base) + if iscmdprime(sup): + sup = sup.value + if (len(sup)/len(r"\prime")) % 2 == 0: + return base + return sympy.Transpose(base) + if isstar(sup): + sup = sup.value + # need .doit() in order to be consistent with + # sympy.adjoint() which returns the evaluated adjoint + # of a matrix + if len(sup) % 2 == 0: + return base.doit() + return sympy.adjoint(base) + if iscmdstar(sup): + sup = sup.value + # need .doit() for same reason as above + if (len(sup)/len(r"\ast")) % 2 == 0: + return base.doit() + return sympy.adjoint(base) + + if isprime(sup) or iscmdprime(sup) or isstar(sup) or iscmdstar(sup): + raise LaTeXParsingError(f"{base} with superscript {sup} is not understood.") + + return sympy.Pow(base, sup) + + def matrix_prime(self, tokens): + base = tokens[0] + primes = tokens[1].value + + if not self._obj_is_sympy_Matrix(base): + raise LaTeXParsingError(f"({base}){primes} is not understood.") + + if len(primes) % 2 == 0: + return base + + return sympy.Transpose(base) + + def symbol_prime(self, tokens): + base = tokens[0] + primes = tokens[1].value + + return sympy.Symbol(f"{base.name}{primes}") + + def fraction(self, tokens): + numerator = tokens[1] + if isinstance(tokens[2], tuple): + # we only need the variable w.r.t. which we are differentiating + _, variable = tokens[2] + + # we will pass this information upwards + return "derivative", variable + else: + denominator = tokens[2] + return self._handle_division(numerator, denominator) + + def binomial(self, tokens): + return sympy.binomial(tokens[1], tokens[2]) + + def normal_integral(self, tokens): + underscore_index = None + caret_index = None + + if "_" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the lower bound of the integral + underscore_index = tokens.index("_") + + if "^" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the upper bound of the integral + caret_index = tokens.index("^") + + lower_bound = tokens[underscore_index + 1] if underscore_index else None + upper_bound = tokens[caret_index + 1] if caret_index else None + + differential_symbol = self._extract_differential_symbol(tokens) + + if differential_symbol is None: + raise LaTeXParsingError("Differential symbol was not found in the expression." + "Valid differential symbols are \"d\", \"\\text{d}, and \"\\mathrm{d}\".") + + # else we can assume that a differential symbol was found + differential_variable_index = tokens.index(differential_symbol) + 1 + differential_variable = tokens[differential_variable_index] + + # we can't simply do something like `if (lower_bound and not upper_bound) ...` because this would + # evaluate to `True` if the `lower_bound` is 0 and upper bound is non-zero + if lower_bound is not None and upper_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Lower bound for the integral was found, but upper bound was not found.") + + if upper_bound is not None and lower_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Upper bound for the integral was found, but lower bound was not found.") + + # check if any expression was given or not. If it wasn't, then set the integrand to 1. + if underscore_index is not None and underscore_index == differential_variable_index - 3: + # The Token at differential_variable_index - 2 should be the integrand. However, if going one more step + # backwards after that gives us the underscore, then that means that there _was_ no integrand. + # Example: \int^7_0 dx + integrand = 1 + elif caret_index is not None and caret_index == differential_variable_index - 3: + # The Token at differential_variable_index - 2 should be the integrand. However, if going one more step + # backwards after that gives us the caret, then that means that there _was_ no integrand. + # Example: \int_0^7 dx + integrand = 1 + elif differential_variable_index == 2: + # this means we have something like "\int dx", because the "\int" symbol will always be + # at index 0 in `tokens` + integrand = 1 + else: + # The Token at differential_variable_index - 1 is the differential symbol itself, so we need to go one + # more step before that. + integrand = tokens[differential_variable_index - 2] + + if lower_bound is not None: + # then we have a definite integral + + # we can assume that either both the lower and upper bounds are given, or + # neither of them are + return sympy.Integral(integrand, (differential_variable, lower_bound, upper_bound)) + else: + # we have an indefinite integral + return sympy.Integral(integrand, differential_variable) + + def group_curly_parentheses_int(self, tokens): + # return signature is a tuple consisting of the expression in the numerator, along with the variable of + # integration + if len(tokens) == 3: + return 1, tokens[1] + elif len(tokens) == 4: + return tokens[1], tokens[2] + # there are no other possibilities + + def special_fraction(self, tokens): + numerator, variable = tokens[1] + denominator = tokens[2] + + # We pass the integrand, along with information about the variable of integration, upw + return sympy.Mul(numerator, sympy.Pow(denominator, -1)), variable + + def integral_with_special_fraction(self, tokens): + underscore_index = None + caret_index = None + + if "_" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the lower bound of the integral + underscore_index = tokens.index("_") + + if "^" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the upper bound of the integral + caret_index = tokens.index("^") + + lower_bound = tokens[underscore_index + 1] if underscore_index else None + upper_bound = tokens[caret_index + 1] if caret_index else None + + # we can't simply do something like `if (lower_bound and not upper_bound) ...` because this would + # evaluate to `True` if the `lower_bound` is 0 and upper bound is non-zero + if lower_bound is not None and upper_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Lower bound for the integral was found, but upper bound was not found.") + + if upper_bound is not None and lower_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Upper bound for the integral was found, but lower bound was not found.") + + integrand, differential_variable = tokens[-1] + + if lower_bound is not None: + # then we have a definite integral + + # we can assume that either both the lower and upper bounds are given, or + # neither of them are + return sympy.Integral(integrand, (differential_variable, lower_bound, upper_bound)) + else: + # we have an indefinite integral + return sympy.Integral(integrand, differential_variable) + + def group_curly_parentheses_special(self, tokens): + underscore_index = tokens.index("_") + caret_index = tokens.index("^") + + # given the type of expressions we are parsing, we can assume that the lower limit + # will always use braces around its arguments. This is because we don't support + # converting unconstrained sums into SymPy expressions. + + # first we isolate the bottom limit + left_brace_index = tokens.index("{", underscore_index) + right_brace_index = tokens.index("}", underscore_index) + + bottom_limit = tokens[left_brace_index + 1: right_brace_index] + + # next, we isolate the upper limit + top_limit = tokens[caret_index + 1:] + + # the code below will be useful for supporting things like `\sum_{n = 0}^{n = 5} n^2` + # if "{" in top_limit: + # left_brace_index = tokens.index("{", caret_index) + # if left_brace_index != -1: + # # then there's a left brace in the string, and we need to find the closing right brace + # right_brace_index = tokens.index("}", caret_index) + # top_limit = tokens[left_brace_index + 1: right_brace_index] + + # print(f"top limit = {top_limit}") + + index_variable = bottom_limit[0] + lower_limit = bottom_limit[-1] + upper_limit = top_limit[0] # for now, the index will always be 0 + + # print(f"return value = ({index_variable}, {lower_limit}, {upper_limit})") + + return index_variable, lower_limit, upper_limit + + def summation(self, tokens): + return sympy.Sum(tokens[2], tokens[1]) + + def product(self, tokens): + return sympy.Product(tokens[2], tokens[1]) + + def limit_dir_expr(self, tokens): + caret_index = tokens.index("^") + + if "{" in tokens: + left_curly_brace_index = tokens.index("{", caret_index) + direction = tokens[left_curly_brace_index + 1] + else: + direction = tokens[caret_index + 1] + + if direction == "+": + return tokens[0], "+" + elif direction == "-": + return tokens[0], "-" + else: + return tokens[0], "+-" + + def group_curly_parentheses_lim(self, tokens): + limit_variable = tokens[1] + if isinstance(tokens[3], tuple): + destination, direction = tokens[3] + else: + destination = tokens[3] + direction = "+-" + + return limit_variable, destination, direction + + def limit(self, tokens): + limit_variable, destination, direction = tokens[2] + + return sympy.Limit(tokens[-1], limit_variable, destination, direction) + + def differential(self, tokens): + return tokens[1] + + def derivative(self, tokens): + return sympy.Derivative(tokens[-1], tokens[5]) + + def list_of_expressions(self, tokens): + if len(tokens) == 1: + # we return it verbatim because the function_applied node expects + # a list + return tokens + else: + def remove_tokens(args): + if isinstance(args, Token): + if args.type != "COMMA": + # An unexpected token was encountered + raise LaTeXParsingError("A comma token was expected, but some other token was encountered.") + return False + return True + + return filter(remove_tokens, tokens) + + def function_applied(self, tokens): + return sympy.Function(tokens[0])(*tokens[2]) + + def min(self, tokens): + return sympy.Min(*tokens[2]) + + def max(self, tokens): + return sympy.Max(*tokens[2]) + + def bra(self, tokens): + from sympy.physics.quantum import Bra + return Bra(tokens[1]) + + def ket(self, tokens): + from sympy.physics.quantum import Ket + return Ket(tokens[1]) + + def inner_product(self, tokens): + from sympy.physics.quantum import Bra, Ket, InnerProduct + return InnerProduct(Bra(tokens[1]), Ket(tokens[3])) + + def sin(self, tokens): + return sympy.sin(tokens[1]) + + def cos(self, tokens): + return sympy.cos(tokens[1]) + + def tan(self, tokens): + return sympy.tan(tokens[1]) + + def csc(self, tokens): + return sympy.csc(tokens[1]) + + def sec(self, tokens): + return sympy.sec(tokens[1]) + + def cot(self, tokens): + return sympy.cot(tokens[1]) + + def sin_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.asin(tokens[-1]) + else: + return sympy.Pow(sympy.sin(tokens[-1]), exponent) + + def cos_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acos(tokens[-1]) + else: + return sympy.Pow(sympy.cos(tokens[-1]), exponent) + + def tan_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.atan(tokens[-1]) + else: + return sympy.Pow(sympy.tan(tokens[-1]), exponent) + + def csc_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acsc(tokens[-1]) + else: + return sympy.Pow(sympy.csc(tokens[-1]), exponent) + + def sec_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.asec(tokens[-1]) + else: + return sympy.Pow(sympy.sec(tokens[-1]), exponent) + + def cot_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acot(tokens[-1]) + else: + return sympy.Pow(sympy.cot(tokens[-1]), exponent) + + def arcsin(self, tokens): + return sympy.asin(tokens[1]) + + def arccos(self, tokens): + return sympy.acos(tokens[1]) + + def arctan(self, tokens): + return sympy.atan(tokens[1]) + + def arccsc(self, tokens): + return sympy.acsc(tokens[1]) + + def arcsec(self, tokens): + return sympy.asec(tokens[1]) + + def arccot(self, tokens): + return sympy.acot(tokens[1]) + + def sinh(self, tokens): + return sympy.sinh(tokens[1]) + + def cosh(self, tokens): + return sympy.cosh(tokens[1]) + + def tanh(self, tokens): + return sympy.tanh(tokens[1]) + + def asinh(self, tokens): + return sympy.asinh(tokens[1]) + + def acosh(self, tokens): + return sympy.acosh(tokens[1]) + + def atanh(self, tokens): + return sympy.atanh(tokens[1]) + + def abs(self, tokens): + return sympy.Abs(tokens[1]) + + def floor(self, tokens): + return sympy.floor(tokens[1]) + + def ceil(self, tokens): + return sympy.ceiling(tokens[1]) + + def factorial(self, tokens): + return sympy.factorial(tokens[0]) + + def conjugate(self, tokens): + return sympy.conjugate(tokens[1]) + + def square_root(self, tokens): + if len(tokens) == 2: + # then there was no square bracket argument + return sympy.sqrt(tokens[1]) + elif len(tokens) == 3: + # then there _was_ a square bracket argument + return sympy.root(tokens[2], tokens[1]) + + def exponential(self, tokens): + return sympy.exp(tokens[1]) + + def log(self, tokens): + if tokens[0].type == "FUNC_LG": + # we don't need to check if there's an underscore or not because having one + # in this case would be meaningless + # TODO: ANTLR refers to ISO 80000-2:2019. should we keep base 10 or base 2? + return sympy.log(tokens[1], 10) + elif tokens[0].type == "FUNC_LN": + return sympy.log(tokens[1]) + elif tokens[0].type == "FUNC_LOG": + # we check if a base was specified or not + if "_" in tokens: + # then a base was specified + return sympy.log(tokens[3], tokens[2]) + else: + # a base was not specified + return sympy.log(tokens[1]) + + def _extract_differential_symbol(self, s: str): + differential_symbols = {"d", r"\text{d}", r"\mathrm{d}"} + + differential_symbol = next((symbol for symbol in differential_symbols if symbol in s), None) + + return differential_symbol + + def matrix(self, tokens): + def is_matrix_row(x): + return (isinstance(x, Tree) and x.data == "matrix_row") + + def is_not_col_delim(y): + return (not isinstance(y, Token) or y.type != "MATRIX_COL_DELIM") + + matrix_body = tokens[1].children + return sympy.Matrix([[y for y in x.children if is_not_col_delim(y)] + for x in matrix_body if is_matrix_row(x)]) + + def determinant(self, tokens): + if len(tokens) == 2: # \det A + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take determinant of non-matrix.") + + return tokens[1].det() + + if len(tokens) == 3: # | A | + return self.matrix(tokens).det() + + def trace(self, tokens): + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take trace of non-matrix.") + + return sympy.Trace(tokens[1]) + + def adjugate(self, tokens): + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take adjugate of non-matrix.") + + # need .doit() since MatAdd does not support .adjugate() method + return tokens[1].doit().adjugate() + + def _obj_is_sympy_Matrix(self, obj): + if hasattr(obj, "is_Matrix"): + return obj.is_Matrix + + return isinstance(obj, sympy.Matrix) + + def _handle_division(self, numerator, denominator): + if self._obj_is_sympy_Matrix(denominator): + raise LaTeXParsingError("Cannot divide by matrices like this since " + "it is not clear if left or right multiplication " + "by the inverse is intended. Try explicitly " + "multiplying by the inverse instead.") + + if self._obj_is_sympy_Matrix(numerator): + return sympy.MatMul(numerator, sympy.Pow(denominator, -1)) + + return sympy.Mul(numerator, sympy.Pow(denominator, -1)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/mathematica.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..b5824a8c33ee402d03e6c5617eeeea21d4a457d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/mathematica.py @@ -0,0 +1,1085 @@ +from __future__ import annotations +import re +import typing +from itertools import product +from typing import Any, Callable + +import sympy +from sympy import Mul, Add, Pow, Rational, log, exp, sqrt, cos, sin, tan, asin, acos, acot, asec, acsc, sinh, cosh, tanh, asinh, \ + acosh, atanh, acoth, asech, acsch, expand, im, flatten, polylog, cancel, expand_trig, sign, simplify, \ + UnevaluatedExpr, S, atan, atan2, Mod, Max, Min, rf, Ei, Si, Ci, airyai, airyaiprime, airybi, primepi, prime, \ + isprime, cot, sec, csc, csch, sech, coth, Function, I, pi, Tuple, GreaterThan, StrictGreaterThan, StrictLessThan, \ + LessThan, Equality, Or, And, Lambda, Integer, Dummy, symbols +from sympy.core.sympify import sympify, _sympify +from sympy.functions.special.bessel import airybiprime +from sympy.functions.special.error_functions import li +from sympy.utilities.exceptions import sympy_deprecation_warning + + +def mathematica(s, additional_translations=None): + sympy_deprecation_warning( + """The ``mathematica`` function for the Mathematica parser is now +deprecated. Use ``parse_mathematica`` instead. +The parameter ``additional_translation`` can be replaced by SymPy's +.replace( ) or .subs( ) methods on the output expression instead.""", + deprecated_since_version="1.11", + active_deprecations_target="mathematica-parser-new", + ) + parser = MathematicaParser(additional_translations) + return sympify(parser._parse_old(s)) + + +def parse_mathematica(s): + """ + Translate a string containing a Wolfram Mathematica expression to a SymPy + expression. + + If the translator is unable to find a suitable SymPy expression, the + ``FullForm`` of the Mathematica expression will be output, using SymPy + ``Function`` objects as nodes of the syntax tree. + + Examples + ======== + + >>> from sympy.parsing.mathematica import parse_mathematica + >>> parse_mathematica("Sin[x]^2 Tan[y]") + sin(x)**2*tan(y) + >>> e = parse_mathematica("F[7,5,3]") + >>> e + F(7, 5, 3) + >>> from sympy import Function, Max, Min + >>> e.replace(Function("F"), lambda *x: Max(*x)*Min(*x)) + 21 + + Both standard input form and Mathematica full form are supported: + + >>> parse_mathematica("x*(a + b)") + x*(a + b) + >>> parse_mathematica("Times[x, Plus[a, b]]") + x*(a + b) + + To get a matrix from Wolfram's code: + + >>> m = parse_mathematica("{{a, b}, {c, d}}") + >>> m + ((a, b), (c, d)) + >>> from sympy import Matrix + >>> Matrix(m) + Matrix([ + [a, b], + [c, d]]) + + If the translation into equivalent SymPy expressions fails, an SymPy + expression equivalent to Wolfram Mathematica's "FullForm" will be created: + + >>> parse_mathematica("x_.") + Optional(Pattern(x, Blank())) + >>> parse_mathematica("Plus @@ {x, y, z}") + Apply(Plus, (x, y, z)) + >>> parse_mathematica("f[x_, 3] := x^3 /; x > 0") + SetDelayed(f(Pattern(x, Blank()), 3), Condition(x**3, x > 0)) + """ + parser = MathematicaParser() + return parser.parse(s) + + +def _parse_Function(*args): + if len(args) == 1: + arg = args[0] + Slot = Function("Slot") + slots = arg.atoms(Slot) + numbers = [a.args[0] for a in slots] + number_of_arguments = max(numbers) + if isinstance(number_of_arguments, Integer): + variables = symbols(f"dummy0:{number_of_arguments}", cls=Dummy) + return Lambda(variables, arg.xreplace({Slot(i+1): v for i, v in enumerate(variables)})) + return Lambda((), arg) + elif len(args) == 2: + variables = args[0] + body = args[1] + return Lambda(variables, body) + else: + raise SyntaxError("Function node expects 1 or 2 arguments") + + +def _deco(cls): + cls._initialize_class() + return cls + + +@_deco +class MathematicaParser: + """ + An instance of this class converts a string of a Wolfram Mathematica + expression to a SymPy expression. + + The main parser acts internally in three stages: + + 1. tokenizer: tokenizes the Mathematica expression and adds the missing * + operators. Handled by ``_from_mathematica_to_tokens(...)`` + 2. full form list: sort the list of strings output by the tokenizer into a + syntax tree of nested lists and strings, equivalent to Mathematica's + ``FullForm`` expression output. This is handled by the function + ``_from_tokens_to_fullformlist(...)``. + 3. SymPy expression: the syntax tree expressed as full form list is visited + and the nodes with equivalent classes in SymPy are replaced. Unknown + syntax tree nodes are cast to SymPy ``Function`` objects. This is + handled by ``_from_fullformlist_to_sympy(...)``. + + """ + + # left: Mathematica, right: SymPy + CORRESPONDENCES = { + 'Sqrt[x]': 'sqrt(x)', + 'Rational[x,y]': 'Rational(x,y)', + 'Exp[x]': 'exp(x)', + 'Log[x]': 'log(x)', + 'Log[x,y]': 'log(y,x)', + 'Log2[x]': 'log(x,2)', + 'Log10[x]': 'log(x,10)', + 'Mod[x,y]': 'Mod(x,y)', + 'Max[*x]': 'Max(*x)', + 'Min[*x]': 'Min(*x)', + 'Pochhammer[x,y]':'rf(x,y)', + 'ArcTan[x,y]':'atan2(y,x)', + 'ExpIntegralEi[x]': 'Ei(x)', + 'SinIntegral[x]': 'Si(x)', + 'CosIntegral[x]': 'Ci(x)', + 'AiryAi[x]': 'airyai(x)', + 'AiryAiPrime[x]': 'airyaiprime(x)', + 'AiryBi[x]' :'airybi(x)', + 'AiryBiPrime[x]' :'airybiprime(x)', + 'LogIntegral[x]':' li(x)', + 'PrimePi[x]': 'primepi(x)', + 'Prime[x]': 'prime(x)', + 'PrimeQ[x]': 'isprime(x)' + } + + # trigonometric, e.t.c. + for arc, tri, h in product(('', 'Arc'), ( + 'Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'), ('', 'h')): + fm = arc + tri + h + '[x]' + if arc: # arc func + fs = 'a' + tri.lower() + h + '(x)' + else: # non-arc func + fs = tri.lower() + h + '(x)' + CORRESPONDENCES.update({fm: fs}) + + REPLACEMENTS = { + ' ': '', + '^': '**', + '{': '[', + '}': ']', + } + + RULES = { + # a single whitespace to '*' + 'whitespace': ( + re.compile(r''' + (?:(?<=[a-zA-Z\d])|(?<=\d\.)) # a letter or a number + \s+ # any number of whitespaces + (?:(?=[a-zA-Z\d])|(?=\.\d)) # a letter or a number + ''', re.VERBOSE), + '*'), + + # add omitted '*' character + 'add*_1': ( + re.compile(r''' + (?:(?<=[])\d])|(?<=\d\.)) # ], ) or a number + # '' + (?=[(a-zA-Z]) # ( or a single letter + ''', re.VERBOSE), + '*'), + + # add omitted '*' character (variable letter preceding) + 'add*_2': ( + re.compile(r''' + (?<=[a-zA-Z]) # a letter + \( # ( as a character + (?=.) # any characters + ''', re.VERBOSE), + '*('), + + # convert 'Pi' to 'pi' + 'Pi': ( + re.compile(r''' + (?: + \A|(?<=[^a-zA-Z]) + ) + Pi # 'Pi' is 3.14159... in Mathematica + (?=[^a-zA-Z]) + ''', re.VERBOSE), + 'pi'), + } + + # Mathematica function name pattern + FM_PATTERN = re.compile(r''' + (?: + \A|(?<=[^a-zA-Z]) # at the top or a non-letter + ) + [A-Z][a-zA-Z\d]* # Function + (?=\[) # [ as a character + ''', re.VERBOSE) + + # list or matrix pattern (for future usage) + ARG_MTRX_PATTERN = re.compile(r''' + \{.*\} + ''', re.VERBOSE) + + # regex string for function argument pattern + ARGS_PATTERN_TEMPLATE = r''' + (?: + \A|(?<=[^a-zA-Z]) + ) + {arguments} # model argument like x, y,... + (?=[^a-zA-Z]) + ''' + + # will contain transformed CORRESPONDENCES dictionary + TRANSLATIONS: dict[tuple[str, int], dict[str, Any]] = {} + + # cache for a raw users' translation dictionary + cache_original: dict[tuple[str, int], dict[str, Any]] = {} + + # cache for a compiled users' translation dictionary + cache_compiled: dict[tuple[str, int], dict[str, Any]] = {} + + @classmethod + def _initialize_class(cls): + # get a transformed CORRESPONDENCES dictionary + d = cls._compile_dictionary(cls.CORRESPONDENCES) + cls.TRANSLATIONS.update(d) + + def __init__(self, additional_translations=None): + self.translations = {} + + # update with TRANSLATIONS (class constant) + self.translations.update(self.TRANSLATIONS) + + if additional_translations is None: + additional_translations = {} + + # check the latest added translations + if self.__class__.cache_original != additional_translations: + if not isinstance(additional_translations, dict): + raise ValueError('The argument must be dict type') + + # get a transformed additional_translations dictionary + d = self._compile_dictionary(additional_translations) + + # update cache + self.__class__.cache_original = additional_translations + self.__class__.cache_compiled = d + + # merge user's own translations + self.translations.update(self.__class__.cache_compiled) + + @classmethod + def _compile_dictionary(cls, dic): + # for return + d = {} + + for fm, fs in dic.items(): + # check function form + cls._check_input(fm) + cls._check_input(fs) + + # uncover '*' hiding behind a whitespace + fm = cls._apply_rules(fm, 'whitespace') + fs = cls._apply_rules(fs, 'whitespace') + + # remove whitespace(s) + fm = cls._replace(fm, ' ') + fs = cls._replace(fs, ' ') + + # search Mathematica function name + m = cls.FM_PATTERN.search(fm) + + # if no-hit + if m is None: + err = "'{f}' function form is invalid.".format(f=fm) + raise ValueError(err) + + # get Mathematica function name like 'Log' + fm_name = m.group() + + # get arguments of Mathematica function + args, end = cls._get_args(m) + + # function side check. (e.g.) '2*Func[x]' is invalid. + if m.start() != 0 or end != len(fm): + err = "'{f}' function form is invalid.".format(f=fm) + raise ValueError(err) + + # check the last argument's 1st character + if args[-1][0] == '*': + key_arg = '*' + else: + key_arg = len(args) + + key = (fm_name, key_arg) + + # convert '*x' to '\\*x' for regex + re_args = [x if x[0] != '*' else '\\' + x for x in args] + + # for regex. Example: (?:(x|y|z)) + xyz = '(?:(' + '|'.join(re_args) + '))' + + # string for regex compile + patStr = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz) + + pat = re.compile(patStr, re.VERBOSE) + + # update dictionary + d[key] = {} + d[key]['fs'] = fs # SymPy function template + d[key]['args'] = args # args are ['x', 'y'] for example + d[key]['pat'] = pat + + return d + + def _convert_function(self, s): + '''Parse Mathematica function to SymPy one''' + + # compiled regex object + pat = self.FM_PATTERN + + scanned = '' # converted string + cur = 0 # position cursor + while True: + m = pat.search(s) + + if m is None: + # append the rest of string + scanned += s + break + + # get Mathematica function name + fm = m.group() + + # get arguments, and the end position of fm function + args, end = self._get_args(m) + + # the start position of fm function + bgn = m.start() + + # convert Mathematica function to SymPy one + s = self._convert_one_function(s, fm, args, bgn, end) + + # update cursor + cur = bgn + + # append converted part + scanned += s[:cur] + + # shrink s + s = s[cur:] + + return scanned + + def _convert_one_function(self, s, fm, args, bgn, end): + # no variable-length argument + if (fm, len(args)) in self.translations: + key = (fm, len(args)) + + # x, y,... model arguments + x_args = self.translations[key]['args'] + + # make CORRESPONDENCES between model arguments and actual ones + d = dict(zip(x_args, args)) + + # with variable-length argument + elif (fm, '*') in self.translations: + key = (fm, '*') + + # x, y,..*args (model arguments) + x_args = self.translations[key]['args'] + + # make CORRESPONDENCES between model arguments and actual ones + d = {} + for i, x in enumerate(x_args): + if x[0] == '*': + d[x] = ','.join(args[i:]) + break + d[x] = args[i] + + # out of self.translations + else: + err = "'{f}' is out of the whitelist.".format(f=fm) + raise ValueError(err) + + # template string of converted function + template = self.translations[key]['fs'] + + # regex pattern for x_args + pat = self.translations[key]['pat'] + + scanned = '' + cur = 0 + while True: + m = pat.search(template) + + if m is None: + scanned += template + break + + # get model argument + x = m.group() + + # get a start position of the model argument + xbgn = m.start() + + # add the corresponding actual argument + scanned += template[:xbgn] + d[x] + + # update cursor to the end of the model argument + cur = m.end() + + # shrink template + template = template[cur:] + + # update to swapped string + s = s[:bgn] + scanned + s[end:] + + return s + + @classmethod + def _get_args(cls, m): + '''Get arguments of a Mathematica function''' + + s = m.string # whole string + anc = m.end() + 1 # pointing the first letter of arguments + square, curly = [], [] # stack for brackets + args = [] + + # current cursor + cur = anc + for i, c in enumerate(s[anc:], anc): + # extract one argument + if c == ',' and (not square) and (not curly): + args.append(s[cur:i]) # add an argument + cur = i + 1 # move cursor + + # handle list or matrix (for future usage) + if c == '{': + curly.append(c) + elif c == '}': + curly.pop() + + # seek corresponding ']' with skipping irrevant ones + if c == '[': + square.append(c) + elif c == ']': + if square: + square.pop() + else: # empty stack + args.append(s[cur:i]) + break + + # the next position to ']' bracket (the function end) + func_end = i + 1 + + return args, func_end + + @classmethod + def _replace(cls, s, bef): + aft = cls.REPLACEMENTS[bef] + s = s.replace(bef, aft) + return s + + @classmethod + def _apply_rules(cls, s, bef): + pat, aft = cls.RULES[bef] + return pat.sub(aft, s) + + @classmethod + def _check_input(cls, s): + for bracket in (('[', ']'), ('{', '}'), ('(', ')')): + if s.count(bracket[0]) != s.count(bracket[1]): + err = "'{f}' function form is invalid.".format(f=s) + raise ValueError(err) + + if '{' in s: + err = "Currently list is not supported." + raise ValueError(err) + + def _parse_old(self, s): + # input check + self._check_input(s) + + # uncover '*' hiding behind a whitespace + s = self._apply_rules(s, 'whitespace') + + # remove whitespace(s) + s = self._replace(s, ' ') + + # add omitted '*' character + s = self._apply_rules(s, 'add*_1') + s = self._apply_rules(s, 'add*_2') + + # translate function + s = self._convert_function(s) + + # '^' to '**' + s = self._replace(s, '^') + + # 'Pi' to 'pi' + s = self._apply_rules(s, 'Pi') + + # '{', '}' to '[', ']', respectively +# s = cls._replace(s, '{') # currently list is not taken into account +# s = cls._replace(s, '}') + + return s + + def parse(self, s): + s2 = self._from_mathematica_to_tokens(s) + s3 = self._from_tokens_to_fullformlist(s2) + s4 = self._from_fullformlist_to_sympy(s3) + return s4 + + INFIX = "Infix" + PREFIX = "Prefix" + POSTFIX = "Postfix" + FLAT = "Flat" + RIGHT = "Right" + LEFT = "Left" + + _mathematica_op_precedence: list[tuple[str, str | None, dict[str, str | Callable]]] = [ + (POSTFIX, None, {";": lambda x: x + ["Null"] if isinstance(x, list) and x and x[0] == "CompoundExpression" else ["CompoundExpression", x, "Null"]}), + (INFIX, FLAT, {";": "CompoundExpression"}), + (INFIX, RIGHT, {"=": "Set", ":=": "SetDelayed", "+=": "AddTo", "-=": "SubtractFrom", "*=": "TimesBy", "/=": "DivideBy"}), + (INFIX, LEFT, {"//": lambda x, y: [x, y]}), + (POSTFIX, None, {"&": "Function"}), + (INFIX, LEFT, {"/.": "ReplaceAll"}), + (INFIX, RIGHT, {"->": "Rule", ":>": "RuleDelayed"}), + (INFIX, LEFT, {"/;": "Condition"}), + (INFIX, FLAT, {"|": "Alternatives"}), + (POSTFIX, None, {"..": "Repeated", "...": "RepeatedNull"}), + (INFIX, FLAT, {"||": "Or"}), + (INFIX, FLAT, {"&&": "And"}), + (PREFIX, None, {"!": "Not"}), + (INFIX, FLAT, {"===": "SameQ", "=!=": "UnsameQ"}), + (INFIX, FLAT, {"==": "Equal", "!=": "Unequal", "<=": "LessEqual", "<": "Less", ">=": "GreaterEqual", ">": "Greater"}), + (INFIX, None, {";;": "Span"}), + (INFIX, FLAT, {"+": "Plus", "-": "Plus"}), + (INFIX, FLAT, {"*": "Times", "/": "Times"}), + (INFIX, FLAT, {".": "Dot"}), + (PREFIX, None, {"-": lambda x: MathematicaParser._get_neg(x), + "+": lambda x: x}), + (INFIX, RIGHT, {"^": "Power"}), + (INFIX, RIGHT, {"@@": "Apply", "/@": "Map", "//@": "MapAll", "@@@": lambda x, y: ["Apply", x, y, ["List", "1"]]}), + (POSTFIX, None, {"'": "Derivative", "!": "Factorial", "!!": "Factorial2", "--": "Decrement"}), + (INFIX, None, {"[": lambda x, y: [x, *y], "[[": lambda x, y: ["Part", x, *y]}), + (PREFIX, None, {"{": lambda x: ["List", *x], "(": lambda x: x[0]}), + (INFIX, None, {"?": "PatternTest"}), + (POSTFIX, None, { + "_": lambda x: ["Pattern", x, ["Blank"]], + "_.": lambda x: ["Optional", ["Pattern", x, ["Blank"]]], + "__": lambda x: ["Pattern", x, ["BlankSequence"]], + "___": lambda x: ["Pattern", x, ["BlankNullSequence"]], + }), + (INFIX, None, {"_": lambda x, y: ["Pattern", x, ["Blank", y]]}), + (PREFIX, None, {"#": "Slot", "##": "SlotSequence"}), + ] + + _missing_arguments_default = { + "#": lambda: ["Slot", "1"], + "##": lambda: ["SlotSequence", "1"], + } + + _literal = r"[A-Za-z][A-Za-z0-9]*" + _number = r"(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)" + + _enclosure_open = ["(", "[", "[[", "{"] + _enclosure_close = [")", "]", "]]", "}"] + + @classmethod + def _get_neg(cls, x): + return f"-{x}" if isinstance(x, str) and re.match(MathematicaParser._number, x) else ["Times", "-1", x] + + @classmethod + def _get_inv(cls, x): + return ["Power", x, "-1"] + + _regex_tokenizer = None + + def _get_tokenizer(self): + if self._regex_tokenizer is not None: + # Check if the regular expression has already been compiled: + return self._regex_tokenizer + tokens = [self._literal, self._number] + tokens_escape = self._enclosure_open[:] + self._enclosure_close[:] + for typ, strat, symdict in self._mathematica_op_precedence: + for k in symdict: + tokens_escape.append(k) + tokens_escape.sort(key=lambda x: -len(x)) + tokens.extend(map(re.escape, tokens_escape)) + tokens.append(",") + tokens.append("\n") + tokenizer = re.compile("(" + "|".join(tokens) + ")") + self._regex_tokenizer = tokenizer + return self._regex_tokenizer + + def _from_mathematica_to_tokens(self, code: str): + tokenizer = self._get_tokenizer() + + # Find strings: + code_splits: list[str | list] = [] + while True: + string_start = code.find("\"") + if string_start == -1: + if len(code) > 0: + code_splits.append(code) + break + match_end = re.search(r'(? 0: + code_splits.append(code[:string_start]) + code_splits.append(["_Str", code[string_start+1:string_end].replace('\\"', '"')]) + code = code[string_end+1:] + + # Remove comments: + for i, code_split in enumerate(code_splits): + if isinstance(code_split, list): + continue + while True: + pos_comment_start = code_split.find("(*") + if pos_comment_start == -1: + break + pos_comment_end = code_split.find("*)") + if pos_comment_end == -1 or pos_comment_end < pos_comment_start: + raise SyntaxError("mismatch in comment (* *) code") + code_split = code_split[:pos_comment_start] + code_split[pos_comment_end+2:] + code_splits[i] = code_split + + # Tokenize the input strings with a regular expression: + token_lists = [tokenizer.findall(i) if isinstance(i, str) and i.isascii() else [i] for i in code_splits] + tokens = [j for i in token_lists for j in i] + + # Remove newlines at the beginning + while tokens and tokens[0] == "\n": + tokens.pop(0) + # Remove newlines at the end + while tokens and tokens[-1] == "\n": + tokens.pop(-1) + + return tokens + + def _is_op(self, token: str | list) -> bool: + if isinstance(token, list): + return False + if re.match(self._literal, token): + return False + if re.match("-?" + self._number, token): + return False + return True + + def _is_valid_star1(self, token: str | list) -> bool: + if token in (")", "}"): + return True + return not self._is_op(token) + + def _is_valid_star2(self, token: str | list) -> bool: + if token in ("(", "{"): + return True + return not self._is_op(token) + + def _from_tokens_to_fullformlist(self, tokens: list): + stack: list[list] = [[]] + open_seq = [] + pointer: int = 0 + while pointer < len(tokens): + token = tokens[pointer] + if token in self._enclosure_open: + stack[-1].append(token) + open_seq.append(token) + stack.append([]) + elif token == ",": + if len(stack[-1]) == 0 and stack[-2][-1] == open_seq[-1]: + raise SyntaxError("%s cannot be followed by comma ," % open_seq[-1]) + stack[-1] = self._parse_after_braces(stack[-1]) + stack.append([]) + elif token in self._enclosure_close: + ind = self._enclosure_close.index(token) + if self._enclosure_open[ind] != open_seq[-1]: + unmatched_enclosure = SyntaxError("unmatched enclosure") + if token == "]]" and open_seq[-1] == "[": + if open_seq[-2] == "[": + # These two lines would be logically correct, but are + # unnecessary: + # token = "]" + # tokens[pointer] = "]" + tokens.insert(pointer+1, "]") + elif open_seq[-2] == "[[": + if tokens[pointer+1] == "]": + tokens[pointer+1] = "]]" + elif tokens[pointer+1] == "]]": + tokens[pointer+1] = "]]" + tokens.insert(pointer+2, "]") + else: + raise unmatched_enclosure + else: + raise unmatched_enclosure + if len(stack[-1]) == 0 and stack[-2][-1] == "(": + raise SyntaxError("( ) not valid syntax") + last_stack = self._parse_after_braces(stack[-1], True) + stack[-1] = last_stack + new_stack_element = [] + while stack[-1][-1] != open_seq[-1]: + new_stack_element.append(stack.pop()) + new_stack_element.reverse() + if open_seq[-1] == "(" and len(new_stack_element) != 1: + raise SyntaxError("( must be followed by one expression, %i detected" % len(new_stack_element)) + stack[-1].append(new_stack_element) + open_seq.pop(-1) + else: + stack[-1].append(token) + pointer += 1 + if len(stack) != 1: + raise RuntimeError("Stack should have only one element") + return self._parse_after_braces(stack[0]) + + def _util_remove_newlines(self, lines: list, tokens: list, inside_enclosure: bool): + pointer = 0 + size = len(tokens) + while pointer < size: + token = tokens[pointer] + if token == "\n": + if inside_enclosure: + # Ignore newlines inside enclosures + tokens.pop(pointer) + size -= 1 + continue + if pointer == 0: + tokens.pop(0) + size -= 1 + continue + if pointer > 1: + try: + prev_expr = self._parse_after_braces(tokens[:pointer], inside_enclosure) + except SyntaxError: + tokens.pop(pointer) + size -= 1 + continue + else: + prev_expr = tokens[0] + if len(prev_expr) > 0 and prev_expr[0] == "CompoundExpression": + lines.extend(prev_expr[1:]) + else: + lines.append(prev_expr) + for i in range(pointer): + tokens.pop(0) + size -= pointer + pointer = 0 + continue + pointer += 1 + + def _util_add_missing_asterisks(self, tokens: list): + size: int = len(tokens) + pointer: int = 0 + while pointer < size: + if (pointer > 0 and + self._is_valid_star1(tokens[pointer - 1]) and + self._is_valid_star2(tokens[pointer])): + # This is a trick to add missing * operators in the expression, + # `"*" in op_dict` makes sure the precedence level is the same as "*", + # while `not self._is_op( ... )` makes sure this and the previous + # expression are not operators. + if tokens[pointer] == "(": + # ( has already been processed by now, replace: + tokens[pointer] = "*" + tokens[pointer + 1] = tokens[pointer + 1][0] + else: + tokens.insert(pointer, "*") + pointer += 1 + size += 1 + pointer += 1 + + def _parse_after_braces(self, tokens: list, inside_enclosure: bool = False): + op_dict: dict + changed: bool = False + lines: list = [] + + self._util_remove_newlines(lines, tokens, inside_enclosure) + + for op_type, grouping_strat, op_dict in reversed(self._mathematica_op_precedence): + if "*" in op_dict: + self._util_add_missing_asterisks(tokens) + size: int = len(tokens) + pointer: int = 0 + while pointer < size: + token = tokens[pointer] + if isinstance(token, str) and token in op_dict: + op_name: str | Callable = op_dict[token] + node: list + first_index: int + if isinstance(op_name, str): + node = [op_name] + first_index = 1 + else: + node = [] + first_index = 0 + if token in ("+", "-") and op_type == self.PREFIX and pointer > 0 and not self._is_op(tokens[pointer - 1]): + # Make sure that PREFIX + - don't match expressions like a + b or a - b, + # the INFIX + - are supposed to match that expression: + pointer += 1 + continue + if op_type == self.INFIX: + if pointer == 0 or pointer == size - 1 or self._is_op(tokens[pointer - 1]) or self._is_op(tokens[pointer + 1]): + pointer += 1 + continue + changed = True + tokens[pointer] = node + if op_type == self.INFIX: + arg1 = tokens.pop(pointer-1) + arg2 = tokens.pop(pointer) + if token == "/": + arg2 = self._get_inv(arg2) + elif token == "-": + arg2 = self._get_neg(arg2) + pointer -= 1 + size -= 2 + node.append(arg1) + node_p = node + if grouping_strat == self.FLAT: + while pointer + 2 < size and self._check_op_compatible(tokens[pointer+1], token): + node_p.append(arg2) + other_op = tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + if other_op == "/": + arg2 = self._get_inv(arg2) + elif other_op == "-": + arg2 = self._get_neg(arg2) + size -= 2 + node_p.append(arg2) + elif grouping_strat == self.RIGHT: + while pointer + 2 < size and tokens[pointer+1] == token: + node_p.append([op_name, arg2]) + node_p = node_p[-1] + tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + size -= 2 + node_p.append(arg2) + elif grouping_strat == self.LEFT: + while pointer + 1 < size and tokens[pointer+1] == token: + if isinstance(op_name, str): + node_p[first_index] = [op_name, node_p[first_index], arg2] + else: + node_p[first_index] = op_name(node_p[first_index], arg2) + tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + size -= 2 + node_p.append(arg2) + else: + node.append(arg2) + elif op_type == self.PREFIX: + if grouping_strat is not None: + raise TypeError("'Prefix' op_type should not have a grouping strat") + if pointer == size - 1 or self._is_op(tokens[pointer + 1]): + tokens[pointer] = self._missing_arguments_default[token]() + else: + node.append(tokens.pop(pointer+1)) + size -= 1 + elif op_type == self.POSTFIX: + if grouping_strat is not None: + raise TypeError("'Prefix' op_type should not have a grouping strat") + if pointer == 0 or self._is_op(tokens[pointer - 1]): + tokens[pointer] = self._missing_arguments_default[token]() + else: + node.append(tokens.pop(pointer-1)) + pointer -= 1 + size -= 1 + if isinstance(op_name, Callable): # type: ignore + op_call: Callable = typing.cast(Callable, op_name) + new_node = op_call(*node) + node.clear() + if isinstance(new_node, list): + node.extend(new_node) + else: + tokens[pointer] = new_node + pointer += 1 + if len(tokens) > 1 or (len(lines) == 0 and len(tokens) == 0): + if changed: + # Trick to deal with cases in which an operator with lower + # precedence should be transformed before an operator of higher + # precedence. Such as in the case of `#&[x]` (that is + # equivalent to `Lambda(d_, d_)(x)` in SymPy). In this case the + # operator `&` has lower precedence than `[`, but needs to be + # evaluated first because otherwise `# (&[x])` is not a valid + # expression: + return self._parse_after_braces(tokens, inside_enclosure) + raise SyntaxError("unable to create a single AST for the expression") + if len(lines) > 0: + if tokens[0] and tokens[0][0] == "CompoundExpression": + tokens = tokens[0][1:] + compound_expression = ["CompoundExpression", *lines, *tokens] + return compound_expression + return tokens[0] + + def _check_op_compatible(self, op1: str, op2: str): + if op1 == op2: + return True + muldiv = {"*", "/"} + addsub = {"+", "-"} + if op1 in muldiv and op2 in muldiv: + return True + if op1 in addsub and op2 in addsub: + return True + return False + + def _from_fullform_to_fullformlist(self, wmexpr: str): + """ + Parses FullForm[Downvalues[]] generated by Mathematica + """ + out: list = [] + stack = [out] + generator = re.finditer(r'[\[\],]', wmexpr) + last_pos = 0 + for match in generator: + if match is None: + break + position = match.start() + last_expr = wmexpr[last_pos:position].replace(',', '').replace(']', '').replace('[', '').strip() + + if match.group() == ',': + if last_expr != '': + stack[-1].append(last_expr) + elif match.group() == ']': + if last_expr != '': + stack[-1].append(last_expr) + stack.pop() + elif match.group() == '[': + stack[-1].append([last_expr]) + stack.append(stack[-1][-1]) + last_pos = match.end() + return out[0] + + def _from_fullformlist_to_fullformsympy(self, pylist: list): + from sympy import Function, Symbol + + def converter(expr): + if isinstance(expr, list): + if len(expr) > 0: + head = expr[0] + args = [converter(arg) for arg in expr[1:]] + return Function(head)(*args) + else: + raise ValueError("Empty list of expressions") + elif isinstance(expr, str): + return Symbol(expr) + else: + return _sympify(expr) + + return converter(pylist) + + _node_conversions = { + "Times": Mul, + "Plus": Add, + "Power": Pow, + "Rational": Rational, + "Log": lambda *a: log(*reversed(a)), + "Log2": lambda x: log(x, 2), + "Log10": lambda x: log(x, 10), + "Exp": exp, + "Sqrt": sqrt, + + "Sin": sin, + "Cos": cos, + "Tan": tan, + "Cot": cot, + "Sec": sec, + "Csc": csc, + + "ArcSin": asin, + "ArcCos": acos, + "ArcTan": lambda *a: atan2(*reversed(a)) if len(a) == 2 else atan(*a), + "ArcCot": acot, + "ArcSec": asec, + "ArcCsc": acsc, + + "Sinh": sinh, + "Cosh": cosh, + "Tanh": tanh, + "Coth": coth, + "Sech": sech, + "Csch": csch, + + "ArcSinh": asinh, + "ArcCosh": acosh, + "ArcTanh": atanh, + "ArcCoth": acoth, + "ArcSech": asech, + "ArcCsch": acsch, + + "Expand": expand, + "Im": im, + "Re": sympy.re, + "Flatten": flatten, + "Polylog": polylog, + "Cancel": cancel, + # Gamma=gamma, + "TrigExpand": expand_trig, + "Sign": sign, + "Simplify": simplify, + "Defer": UnevaluatedExpr, + "Identity": S, + # Sum=Sum_doit, + # Module=With, + # Block=With, + "Null": lambda *a: S.Zero, + "Mod": Mod, + "Max": Max, + "Min": Min, + "Pochhammer": rf, + "ExpIntegralEi": Ei, + "SinIntegral": Si, + "CosIntegral": Ci, + "AiryAi": airyai, + "AiryAiPrime": airyaiprime, + "AiryBi": airybi, + "AiryBiPrime": airybiprime, + "LogIntegral": li, + "PrimePi": primepi, + "Prime": prime, + "PrimeQ": isprime, + + "List": Tuple, + "Greater": StrictGreaterThan, + "GreaterEqual": GreaterThan, + "Less": StrictLessThan, + "LessEqual": LessThan, + "Equal": Equality, + "Or": Or, + "And": And, + + "Function": _parse_Function, + } + + _atom_conversions = { + "I": I, + "Pi": pi, + } + + def _from_fullformlist_to_sympy(self, full_form_list): + + def recurse(expr): + if isinstance(expr, list): + if isinstance(expr[0], list): + head = recurse(expr[0]) + else: + head = self._node_conversions.get(expr[0], Function(expr[0])) + return head(*[recurse(arg) for arg in expr[1:]]) + else: + return self._atom_conversions.get(expr, sympify(expr)) + + return recurse(full_form_list) + + def _from_fullformsympy_to_sympy(self, mform): + + expr = mform + for mma_form, sympy_node in self._node_conversions.items(): + expr = expr.replace(Function(mma_form), sympy_node) + return expr diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/maxima.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/maxima.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8ee5b17bb03a36e338803cb10f9ebf22763c2c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/maxima.py @@ -0,0 +1,71 @@ +import re +from sympy.concrete.products import product +from sympy.concrete.summations import Sum +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import (cos, sin) + + +class MaximaHelpers: + def maxima_expand(expr): + return expr.expand() + + def maxima_float(expr): + return expr.evalf() + + def maxima_trigexpand(expr): + return expr.expand(trig=True) + + def maxima_sum(a1, a2, a3, a4): + return Sum(a1, (a2, a3, a4)).doit() + + def maxima_product(a1, a2, a3, a4): + return product(a1, (a2, a3, a4)) + + def maxima_csc(expr): + return 1/sin(expr) + + def maxima_sec(expr): + return 1/cos(expr) + +sub_dict = { + 'pi': re.compile(r'%pi'), + 'E': re.compile(r'%e'), + 'I': re.compile(r'%i'), + '**': re.compile(r'\^'), + 'oo': re.compile(r'\binf\b'), + '-oo': re.compile(r'\bminf\b'), + "'-'": re.compile(r'\bminus\b'), + 'maxima_expand': re.compile(r'\bexpand\b'), + 'maxima_float': re.compile(r'\bfloat\b'), + 'maxima_trigexpand': re.compile(r'\btrigexpand'), + 'maxima_sum': re.compile(r'\bsum\b'), + 'maxima_product': re.compile(r'\bproduct\b'), + 'cancel': re.compile(r'\bratsimp\b'), + 'maxima_csc': re.compile(r'\bcsc\b'), + 'maxima_sec': re.compile(r'\bsec\b') +} + +var_name = re.compile(r'^\s*(\w+)\s*:') + + +def parse_maxima(str, globals=None, name_dict={}): + str = str.strip() + str = str.rstrip('; ') + + for k, v in sub_dict.items(): + str = v.sub(k, str) + + assign_var = None + var_match = var_name.search(str) + if var_match: + assign_var = var_match.group(1) + str = str[var_match.end():].strip() + + dct = MaximaHelpers.__dict__.copy() + dct.update(name_dict) + obj = sympify(str, locals=dct) + + if assign_var and globals: + globals[assign_var] = obj + + return obj diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/sym_expr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/sym_expr.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbd0e94eb51147b51825fcf15cbec5ae18bb1b6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/sym_expr.py @@ -0,0 +1,279 @@ +from sympy.printing import pycode, ccode, fcode +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on + +lfortran = import_module('lfortran') +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if lfortran: + from sympy.parsing.fortran.fortran_parser import src_to_sympy +if cin: + from sympy.parsing.c.c_parser import parse_c + +@doctest_depends_on(modules=['lfortran', 'clang.cindex']) +class SymPyExpression: # type: ignore + """Class to store and handle SymPy expressions + + This class will hold SymPy Expressions and handle the API for the + conversion to and from different languages. + + It works with the C and the Fortran Parser to generate SymPy expressions + which are stored here and which can be converted to multiple language's + source code. + + Notes + ===== + + The module and its API are currently under development and experimental + and can be changed during development. + + The Fortran parser does not support numeric assignments, so all the + variables have been Initialized to zero. + + The module also depends on external dependencies: + + - LFortran which is required to use the Fortran parser + - Clang which is required for the C parser + + Examples + ======== + + Example of parsing C code: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src = ''' + ... int a,b; + ... float c = 2, d =4; + ... ''' + >>> a = SymPyExpression(src, 'c') + >>> a.return_expr() + [Declaration(Variable(a, type=intc)), + Declaration(Variable(b, type=intc)), + Declaration(Variable(c, type=float32, value=2.0)), + Declaration(Variable(d, type=float32, value=4.0))] + + An example of variable definition: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src2, 'f') + >>> p.convert_to_c() + ['int a = 0', 'int b = 0', 'int c = 0', 'int d = 0', 'double p = 0.0', 'double q = 0.0', 'double r = 0.0', 'double s = 0.0'] + + An example of Assignment: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer :: a, b, c, d, e + ... d = a + b - c + ... e = b * d + c * e / a + ... ''' + >>> p = SymPyExpression(src3, 'f') + >>> p.convert_to_python() + ['a = 0', 'b = 0', 'c = 0', 'd = 0', 'e = 0', 'd = a + b - c', 'e = b*d + c*e/a'] + + An example of function definition: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src = ''' + ... integer function f(a,b) + ... integer, intent(in) :: a, b + ... integer :: r + ... end function + ... ''' + >>> a = SymPyExpression(src, 'f') + >>> a.convert_to_python() + ['def f(a, b):\\n f = 0\\n r = 0\\n return f'] + + """ + + def __init__(self, source_code = None, mode = None): + """Constructor for SymPyExpression class""" + super().__init__() + if not(mode or source_code): + self._expr = [] + elif mode: + if source_code: + if mode.lower() == 'f': + if lfortran: + self._expr = src_to_sympy(source_code) + else: + raise ImportError("LFortran is not installed, cannot parse Fortran code") + elif mode.lower() == 'c': + if cin: + self._expr = parse_c(source_code) + else: + raise ImportError("Clang is not installed, cannot parse C code") + else: + raise NotImplementedError( + 'Parser for specified language is not implemented' + ) + else: + raise ValueError('Source code not present') + else: + raise ValueError('Please specify a mode for conversion') + + def convert_to_expr(self, src_code, mode): + """Converts the given source code to SymPy Expressions + + Attributes + ========== + + src_code : String + the source code or filename of the source code that is to be + converted + + mode: String + the mode to determine which parser is to be used according to + the language of the source code + f or F for Fortran + c or C for C/C++ + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer function f(a,b) result(r) + ... integer, intent(in) :: a, b + ... integer :: x + ... r = a + b -x + ... end function + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src3, 'f') + >>> p.return_expr() + [FunctionDefinition(integer, name=f, parameters=(Variable(a), Variable(b)), body=CodeBlock( + Declaration(Variable(r, type=integer, value=0)), + Declaration(Variable(x, type=integer, value=0)), + Assignment(Variable(r), a + b - x), + Return(Variable(r)) + ))] + + + + + """ + if mode.lower() == 'f': + if lfortran: + self._expr = src_to_sympy(src_code) + else: + raise ImportError("LFortran is not installed, cannot parse Fortran code") + elif mode.lower() == 'c': + if cin: + self._expr = parse_c(src_code) + else: + raise ImportError("Clang is not installed, cannot parse C code") + else: + raise NotImplementedError( + "Parser for specified language has not been implemented" + ) + + def convert_to_python(self): + """Returns a list with Python code for the SymPy expressions + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression(src2, 'f') + >>> p.convert_to_python() + ['a = 0', 'b = 0', 'c = 0', 'd = 0', 'p = 0.0', 'q = 0.0', 'r = 0.0', 's = 0.0', 'c = a/b', 'd = c/a', 's = p/q', 'r = q/p'] + + """ + self._pycode = [] + for iter in self._expr: + self._pycode.append(pycode(iter)) + return self._pycode + + def convert_to_c(self): + """Returns a list with the c source code for the SymPy expressions + + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src2, 'f') + >>> p.convert_to_c() + ['int a = 0', 'int b = 0', 'int c = 0', 'int d = 0', 'double p = 0.0', 'double q = 0.0', 'double r = 0.0', 'double s = 0.0', 'c = a/b;', 'd = c/a;', 's = p/q;', 'r = q/p;'] + + """ + self._ccode = [] + for iter in self._expr: + self._ccode.append(ccode(iter)) + return self._ccode + + def convert_to_fortran(self): + """Returns a list with the fortran source code for the SymPy expressions + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression(src2, 'f') + >>> p.convert_to_fortran() + [' integer*4 a', ' integer*4 b', ' integer*4 c', ' integer*4 d', ' real*8 p', ' real*8 q', ' real*8 r', ' real*8 s', ' c = a/b', ' d = c/a', ' s = p/q', ' r = q/p'] + + """ + self._fcode = [] + for iter in self._expr: + self._fcode.append(fcode(iter)) + return self._fcode + + def return_expr(self): + """Returns the expression list + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer function f(a,b) + ... integer, intent(in) :: a, b + ... integer :: r + ... r = a+b + ... f = r + ... end function + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src3, 'f') + >>> p.return_expr() + [FunctionDefinition(integer, name=f, parameters=(Variable(a), Variable(b)), body=CodeBlock( + Declaration(Variable(f, type=integer, value=0)), + Declaration(Variable(r, type=integer, value=0)), + Assignment(Variable(f), Variable(r)), + Return(Variable(f)) + ))] + + """ + return self._expr diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/sympy_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/sympy_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfda9ce0f73ffa3773031c48b9e9c245f69fe0b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/sympy_parser.py @@ -0,0 +1,1270 @@ +"""Transform a string with Python-like source code into SymPy expression. """ +from __future__ import annotations +from tokenize import (generate_tokens, untokenize, TokenError, + NUMBER, STRING, NAME, OP, ENDMARKER, ERRORTOKEN, NEWLINE) + +from keyword import iskeyword + +import ast +import unicodedata +from io import StringIO +import builtins +import types +from typing import Any, Callable +from functools import reduce +from sympy.assumptions.ask import AssumptionKeys +from sympy.core.basic import Basic +from sympy.core import Symbol +from sympy.core.function import Function +from sympy.utilities.misc import func_name +from sympy.functions.elementary.miscellaneous import Max, Min + + +null = '' + +TOKEN = tuple[int, str] +DICT = dict[str, Any] +TRANS = Callable[[list[TOKEN], DICT, DICT], list[TOKEN]] + +def _token_splittable(token_name: str) -> bool: + """ + Predicate for whether a token name can be split into multiple tokens. + + A token is splittable if it does not contain an underscore character and + it is not the name of a Greek letter. This is used to implicitly convert + expressions like 'xyz' into 'x*y*z'. + """ + if '_' in token_name: + return False + try: + return not unicodedata.lookup('GREEK SMALL LETTER ' + token_name) + except KeyError: + return len(token_name) > 1 + + +def _token_callable(token: TOKEN, local_dict: DICT, global_dict: DICT, nextToken=None): + """ + Predicate for whether a token name represents a callable function. + + Essentially wraps ``callable``, but looks up the token name in the + locals and globals. + """ + func = local_dict.get(token[1]) + if not func: + func = global_dict.get(token[1]) + return callable(func) and not isinstance(func, Symbol) + + +def _add_factorial_tokens(name: str, result: list[TOKEN]) -> list[TOKEN]: + if result == [] or result[-1][1] == '(': + raise TokenError() + + beginning = [(NAME, name), (OP, '(')] + end = [(OP, ')')] + + diff = 0 + length = len(result) + + for index, token in enumerate(result[::-1]): + toknum, tokval = token + i = length - index - 1 + + if tokval == ')': + diff += 1 + elif tokval == '(': + diff -= 1 + + if diff == 0: + if i - 1 >= 0 and result[i - 1][0] == NAME: + return result[:i - 1] + beginning + result[i - 1:] + end + else: + return result[:i] + beginning + result[i:] + end + + return result + + +class ParenthesisGroup(list[TOKEN]): + """List of tokens representing an expression in parentheses.""" + pass + + +class AppliedFunction: + """ + A group of tokens representing a function and its arguments. + + `exponent` is for handling the shorthand sin^2, ln^2, etc. + """ + def __init__(self, function: TOKEN, args: ParenthesisGroup, exponent=None): + if exponent is None: + exponent = [] + self.function = function + self.args = args + self.exponent = exponent + self.items = ['function', 'args', 'exponent'] + + def expand(self) -> list[TOKEN]: + """Return a list of tokens representing the function""" + return [self.function, *self.args] + + def __getitem__(self, index): + return getattr(self, self.items[index]) + + def __repr__(self): + return "AppliedFunction(%s, %s, %s)" % (self.function, self.args, + self.exponent) + + +def _flatten(result: list[TOKEN | AppliedFunction]): + result2: list[TOKEN] = [] + for tok in result: + if isinstance(tok, AppliedFunction): + result2.extend(tok.expand()) + else: + result2.append(tok) + return result2 + + +def _group_parentheses(recursor: TRANS): + def _inner(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Group tokens between parentheses with ParenthesisGroup. + + Also processes those tokens recursively. + + """ + result: list[TOKEN | ParenthesisGroup] = [] + stacks: list[ParenthesisGroup] = [] + stacklevel = 0 + for token in tokens: + if token[0] == OP: + if token[1] == '(': + stacks.append(ParenthesisGroup([])) + stacklevel += 1 + elif token[1] == ')': + stacks[-1].append(token) + stack = stacks.pop() + + if len(stacks) > 0: + # We don't recurse here since the upper-level stack + # would reprocess these tokens + stacks[-1].extend(stack) + else: + # Recurse here to handle nested parentheses + # Strip off the outer parentheses to avoid an infinite loop + inner = stack[1:-1] + inner = recursor(inner, + local_dict, + global_dict) + parenGroup = [stack[0]] + inner + [stack[-1]] + result.append(ParenthesisGroup(parenGroup)) + stacklevel -= 1 + continue + if stacklevel: + stacks[-1].append(token) + else: + result.append(token) + if stacklevel: + raise TokenError("Mismatched parentheses") + return result + return _inner + + +def _apply_functions(tokens: list[TOKEN | ParenthesisGroup], local_dict: DICT, global_dict: DICT): + """Convert a NAME token + ParenthesisGroup into an AppliedFunction. + + Note that ParenthesisGroups, if not applied to any function, are + converted back into lists of tokens. + + """ + result: list[TOKEN | AppliedFunction] = [] + symbol = None + for tok in tokens: + if isinstance(tok, ParenthesisGroup): + if symbol and _token_callable(symbol, local_dict, global_dict): + result[-1] = AppliedFunction(symbol, tok) + symbol = None + else: + result.extend(tok) + elif tok[0] == NAME: + symbol = tok + result.append(tok) + else: + symbol = None + result.append(tok) + return result + + +def _implicit_multiplication(tokens: list[TOKEN | AppliedFunction], local_dict: DICT, global_dict: DICT): + """Implicitly adds '*' tokens. + + Cases: + + - Two AppliedFunctions next to each other ("sin(x)cos(x)") + + - AppliedFunction next to an open parenthesis ("sin x (cos x + 1)") + + - A close parenthesis next to an AppliedFunction ("(x+2)sin x")\ + + - A close parenthesis next to an open parenthesis ("(x+2)(x+3)") + + - AppliedFunction next to an implicitly applied function ("sin(x)cos x") + + """ + result: list[TOKEN | AppliedFunction] = [] + skip = False + for tok, nextTok in zip(tokens, tokens[1:]): + result.append(tok) + if skip: + skip = False + continue + if tok[0] == OP and tok[1] == '.' and nextTok[0] == NAME: + # Dotted name. Do not do implicit multiplication + skip = True + continue + if isinstance(tok, AppliedFunction): + if isinstance(nextTok, AppliedFunction): + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Applied function followed by an open parenthesis + if tok.function[1] == "Function": + tok.function = (tok.function[0], 'Symbol') + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Applied function followed by implicitly applied function + result.append((OP, '*')) + else: + if tok == (OP, ')'): + if isinstance(nextTok, AppliedFunction): + # Close parenthesis followed by an applied function + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Close parenthesis followed by an implicitly applied function + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Close parenthesis followed by an open parenthesis + result.append((OP, '*')) + elif tok[0] == NAME and not _token_callable(tok, local_dict, global_dict): + if isinstance(nextTok, AppliedFunction) or \ + (nextTok[0] == NAME and _token_callable(nextTok, local_dict, global_dict)): + # Constant followed by (implicitly applied) function + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Constant followed by parenthesis + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Constant followed by constant + result.append((OP, '*')) + if tokens: + result.append(tokens[-1]) + return result + + +def _implicit_application(tokens: list[TOKEN | AppliedFunction], local_dict: DICT, global_dict: DICT): + """Adds parentheses as needed after functions.""" + result: list[TOKEN | AppliedFunction] = [] + appendParen = 0 # number of closing parentheses to add + skip = 0 # number of tokens to delay before adding a ')' (to + # capture **, ^, etc.) + exponentSkip = False # skipping tokens before inserting parentheses to + # work with function exponentiation + for tok, nextTok in zip(tokens, tokens[1:]): + result.append(tok) + if (tok[0] == NAME and nextTok[0] not in [OP, ENDMARKER, NEWLINE]): + if _token_callable(tok, local_dict, global_dict, nextTok): # type: ignore + result.append((OP, '(')) + appendParen += 1 + # name followed by exponent - function exponentiation + elif (tok[0] == NAME and nextTok[0] == OP and nextTok[1] == '**'): + if _token_callable(tok, local_dict, global_dict): # type: ignore + exponentSkip = True + elif exponentSkip: + # if the last token added was an applied function (i.e. the + # power of the function exponent) OR a multiplication (as + # implicit multiplication would have added an extraneous + # multiplication) + if (isinstance(tok, AppliedFunction) + or (tok[0] == OP and tok[1] == '*')): + # don't add anything if the next token is a multiplication + # or if there's already a parenthesis (if parenthesis, still + # stop skipping tokens) + if not (nextTok[0] == OP and nextTok[1] == '*'): + if not(nextTok[0] == OP and nextTok[1] == '('): + result.append((OP, '(')) + appendParen += 1 + exponentSkip = False + elif appendParen: + if nextTok[0] == OP and nextTok[1] in ('^', '**', '*'): + skip = 1 + continue + if skip: + skip -= 1 + continue + result.append((OP, ')')) + appendParen -= 1 + + if tokens: + result.append(tokens[-1]) + + if appendParen: + result.extend([(OP, ')')] * appendParen) + return result + + +def function_exponentiation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Allows functions to be exponentiated, e.g. ``cos**2(x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, function_exponentiation) + >>> transformations = standard_transformations + (function_exponentiation,) + >>> parse_expr('sin**4(x)', transformations=transformations) + sin(x)**4 + """ + result: list[TOKEN] = [] + exponent: list[TOKEN] = [] + consuming_exponent = False + level = 0 + for tok, nextTok in zip(tokens, tokens[1:]): + if tok[0] == NAME and nextTok[0] == OP and nextTok[1] == '**': + if _token_callable(tok, local_dict, global_dict): + consuming_exponent = True + elif consuming_exponent: + if tok[0] == NAME and tok[1] == 'Function': + tok = (NAME, 'Symbol') + exponent.append(tok) + + # only want to stop after hitting ) + if tok[0] == nextTok[0] == OP and tok[1] == ')' and nextTok[1] == '(': + consuming_exponent = False + # if implicit multiplication was used, we may have )*( instead + if tok[0] == nextTok[0] == OP and tok[1] == '*' and nextTok[1] == '(': + consuming_exponent = False + del exponent[-1] + continue + elif exponent and not consuming_exponent: + if tok[0] == OP: + if tok[1] == '(': + level += 1 + elif tok[1] == ')': + level -= 1 + if level == 0: + result.append(tok) + result.extend(exponent) + exponent = [] + continue + result.append(tok) + if tokens: + result.append(tokens[-1]) + if exponent: + result.extend(exponent) + return result + + +def split_symbols_custom(predicate: Callable[[str], bool]): + """Creates a transformation that splits symbol names. + + ``predicate`` should return True if the symbol name is to be split. + + For instance, to retain the default behavior but avoid splitting certain + symbol names, a predicate like this would work: + + + >>> from sympy.parsing.sympy_parser import (parse_expr, _token_splittable, + ... standard_transformations, implicit_multiplication, + ... split_symbols_custom) + >>> def can_split(symbol): + ... if symbol not in ('list', 'of', 'unsplittable', 'names'): + ... return _token_splittable(symbol) + ... return False + ... + >>> transformation = split_symbols_custom(can_split) + >>> parse_expr('unsplittable', transformations=standard_transformations + + ... (transformation, implicit_multiplication)) + unsplittable + """ + def _split_symbols(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + result: list[TOKEN] = [] + split = False + split_previous=False + + for tok in tokens: + if split_previous: + # throw out closing parenthesis of Symbol that was split + split_previous=False + continue + split_previous=False + + if tok[0] == NAME and tok[1] in ['Symbol', 'Function']: + split = True + + elif split and tok[0] == NAME: + symbol = tok[1][1:-1] + + if predicate(symbol): + tok_type = result[-2][1] # Symbol or Function + del result[-2:] # Get rid of the call to Symbol + + i = 0 + while i < len(symbol): + char = symbol[i] + if char in local_dict or char in global_dict: + result.append((NAME, "%s" % char)) + elif char.isdigit(): + chars = [char] + for i in range(i + 1, len(symbol)): + if not symbol[i].isdigit(): + i -= 1 + break + chars.append(symbol[i]) + char = ''.join(chars) + result.extend([(NAME, 'Number'), (OP, '('), + (NAME, "'%s'" % char), (OP, ')')]) + else: + use = tok_type if i == len(symbol) else 'Symbol' + result.extend([(NAME, use), (OP, '('), + (NAME, "'%s'" % char), (OP, ')')]) + i += 1 + + # Set split_previous=True so will skip + # the closing parenthesis of the original Symbol + split = False + split_previous = True + continue + + else: + split = False + + result.append(tok) + + return result + + return _split_symbols + + +#: Splits symbol names for implicit multiplication. +#: +#: Intended to let expressions like ``xyz`` be parsed as ``x*y*z``. Does not +#: split Greek character names, so ``theta`` will *not* become +#: ``t*h*e*t*a``. Generally this should be used with +#: ``implicit_multiplication``. +split_symbols = split_symbols_custom(_token_splittable) + + +def implicit_multiplication(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Makes the multiplication operator optional in most cases. + + Use this before :func:`implicit_application`, otherwise expressions like + ``sin 2x`` will be parsed as ``x * sin(2)`` rather than ``sin(2*x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_multiplication) + >>> transformations = standard_transformations + (implicit_multiplication,) + >>> parse_expr('3 x y', transformations=transformations) + 3*x*y + """ + # These are interdependent steps, so we don't expose them separately + res1 = _group_parentheses(implicit_multiplication)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _implicit_multiplication(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +def implicit_application(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Makes parentheses optional in some cases for function calls. + + Use this after :func:`implicit_multiplication`, otherwise expressions + like ``sin 2x`` will be parsed as ``x * sin(2)`` rather than + ``sin(2*x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_application) + >>> transformations = standard_transformations + (implicit_application,) + >>> parse_expr('cot z + csc z', transformations=transformations) + cot(z) + csc(z) + """ + res1 = _group_parentheses(implicit_application)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _implicit_application(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +def implicit_multiplication_application(result: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Allows a slightly relaxed syntax. + + - Parentheses for single-argument method calls are optional. + + - Multiplication is implicit. + + - Symbol names can be split (i.e. spaces are not needed between + symbols). + + - Functions can be exponentiated. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_multiplication_application) + >>> parse_expr("10sin**2 x**2 + 3xyz + tan theta", + ... transformations=(standard_transformations + + ... (implicit_multiplication_application,))) + 3*x*y*z + 10*sin(x**2)**2 + tan(theta) + + """ + for step in (split_symbols, implicit_multiplication, + implicit_application, function_exponentiation): + result = step(result, local_dict, global_dict) + + return result + + +def auto_symbol(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Inserts calls to ``Symbol``/``Function`` for undefined variables.""" + result: list[TOKEN] = [] + prevTok = (-1, '') + + tokens.append((-1, '')) # so zip traverses all tokens + for tok, nextTok in zip(tokens, tokens[1:]): + tokNum, tokVal = tok + nextTokNum, nextTokVal = nextTok + if tokNum == NAME: + name = tokVal + + if (name in ['True', 'False', 'None'] + or iskeyword(name) + # Don't convert attribute access + or (prevTok[0] == OP and prevTok[1] == '.') + # Don't convert keyword arguments + or (prevTok[0] == OP and prevTok[1] in ('(', ',') + and nextTokNum == OP and nextTokVal == '=') + # the name has already been defined + or name in local_dict and local_dict[name] is not null): + result.append((NAME, name)) + continue + elif name in local_dict: + local_dict.setdefault(null, set()).add(name) + if nextTokVal == '(': + local_dict[name] = Function(name) + else: + local_dict[name] = Symbol(name) + result.append((NAME, name)) + continue + elif name in global_dict: + obj = global_dict[name] + if isinstance(obj, (AssumptionKeys, Basic, type)) or callable(obj): + result.append((NAME, name)) + continue + + result.extend([ + (NAME, 'Symbol' if nextTokVal != '(' else 'Function'), + (OP, '('), + (NAME, repr(str(name))), + (OP, ')'), + ]) + else: + result.append((tokNum, tokVal)) + + prevTok = (tokNum, tokVal) + + return result + + +def lambda_notation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Substitutes "lambda" with its SymPy equivalent Lambda(). + However, the conversion does not take place if only "lambda" + is passed because that is a syntax error. + + """ + result: list[TOKEN] = [] + flag = False + toknum, tokval = tokens[0] + tokLen = len(tokens) + + if toknum == NAME and tokval == 'lambda': + if tokLen == 2 or tokLen == 3 and tokens[1][0] == NEWLINE: + # In Python 3.6.7+, inputs without a newline get NEWLINE added to + # the tokens + result.extend(tokens) + elif tokLen > 2: + result.extend([ + (NAME, 'Lambda'), + (OP, '('), + (OP, '('), + (OP, ')'), + (OP, ')'), + ]) + for tokNum, tokVal in tokens[1:]: + if tokNum == OP and tokVal == ':': + tokVal = ',' + flag = True + if not flag and tokNum == OP and tokVal in ('*', '**'): + raise TokenError("Starred arguments in lambda not supported") + if flag: + result.insert(-1, (tokNum, tokVal)) + else: + result.insert(-2, (tokNum, tokVal)) + else: + result.extend(tokens) + + return result + + +def factorial_notation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Allows standard notation for factorial.""" + result: list[TOKEN] = [] + nfactorial = 0 + for toknum, tokval in tokens: + if toknum == OP and tokval == "!": + # In Python 3.12 "!" are OP instead of ERRORTOKEN + nfactorial += 1 + elif toknum == ERRORTOKEN: + op = tokval + if op == '!': + nfactorial += 1 + else: + nfactorial = 0 + result.append((OP, op)) + else: + if nfactorial == 1: + result = _add_factorial_tokens('factorial', result) + elif nfactorial == 2: + result = _add_factorial_tokens('factorial2', result) + elif nfactorial > 2: + raise TokenError + nfactorial = 0 + result.append((toknum, tokval)) + return result + + +def convert_xor(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Treats XOR, ``^``, as exponentiation, ``**``.""" + result: list[TOKEN] = [] + for toknum, tokval in tokens: + if toknum == OP: + if tokval == '^': + result.append((OP, '**')) + else: + result.append((toknum, tokval)) + else: + result.append((toknum, tokval)) + + return result + + +def repeated_decimals(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """ + Allows 0.2[1] notation to represent the repeated decimal 0.2111... (19/90) + + Run this before auto_number. + + """ + result: list[TOKEN] = [] + + def is_digit(s): + return all(i in '0123456789_' for i in s) + + # num will running match any DECIMAL [ INTEGER ] + num: list[TOKEN] = [] + for toknum, tokval in tokens: + if toknum == NUMBER: + if (not num and '.' in tokval and 'e' not in tokval.lower() and + 'j' not in tokval.lower()): + num.append((toknum, tokval)) + elif is_digit(tokval) and (len(num) == 2 or + len(num) == 3 and is_digit(num[-1][1])): + num.append((toknum, tokval)) + else: + num = [] + elif toknum == OP: + if tokval == '[' and len(num) == 1: + num.append((OP, tokval)) + elif tokval == ']' and len(num) >= 3: + num.append((OP, tokval)) + elif tokval == '.' and not num: + # handle .[1] + num.append((NUMBER, '0.')) + else: + num = [] + else: + num = [] + + result.append((toknum, tokval)) + + if num and num[-1][1] == ']': + # pre.post[repetend] = a + b/c + d/e where a = pre, b/c = post, + # and d/e = repetend + result = result[:-len(num)] + pre, post = num[0][1].split('.') + repetend = num[2][1] + if len(num) == 5: + repetend += num[3][1] + + pre = pre.replace('_', '') + post = post.replace('_', '') + repetend = repetend.replace('_', '') + + zeros = '0'*len(post) + post, repetends = [w.lstrip('0') for w in [post, repetend]] + # or else interpreted as octal + + a = pre or '0' + b, c = post or '0', '1' + zeros + d, e = repetends, ('9'*len(repetend)) + zeros + + seq = [ + (OP, '('), + (NAME, 'Integer'), + (OP, '('), + (NUMBER, a), + (OP, ')'), + (OP, '+'), + (NAME, 'Rational'), + (OP, '('), + (NUMBER, b), + (OP, ','), + (NUMBER, c), + (OP, ')'), + (OP, '+'), + (NAME, 'Rational'), + (OP, '('), + (NUMBER, d), + (OP, ','), + (NUMBER, e), + (OP, ')'), + (OP, ')'), + ] + result.extend(seq) + num = [] + + return result + + +def auto_number(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """ + Converts numeric literals to use SymPy equivalents. + + Complex numbers use ``I``, integer literals use ``Integer``, and float + literals use ``Float``. + + """ + result: list[TOKEN] = [] + + for toknum, tokval in tokens: + if toknum == NUMBER: + number = tokval + postfix = [] + + if number.endswith(('j', 'J')): + number = number[:-1] + postfix = [(OP, '*'), (NAME, 'I')] + + if '.' in number or (('e' in number or 'E' in number) and + not (number.startswith(('0x', '0X')))): + seq = [(NAME, 'Float'), (OP, '('), + (NUMBER, repr(str(number))), (OP, ')')] + else: + seq = [(NAME, 'Integer'), (OP, '('), ( + NUMBER, number), (OP, ')')] + + result.extend(seq + postfix) + else: + result.append((toknum, tokval)) + + return result + + +def rationalize(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Converts floats into ``Rational``. Run AFTER ``auto_number``.""" + result: list[TOKEN] = [] + passed_float = False + for toknum, tokval in tokens: + if toknum == NAME: + if tokval == 'Float': + passed_float = True + tokval = 'Rational' + result.append((toknum, tokval)) + elif passed_float == True and toknum == NUMBER: + passed_float = False + result.append((STRING, tokval)) + else: + result.append((toknum, tokval)) + + return result + + +def _transform_equals_sign(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Transforms the equals sign ``=`` to instances of Eq. + + This is a helper function for ``convert_equals_signs``. + Works with expressions containing one equals sign and no + nesting. Expressions like ``(1=2)=False`` will not work with this + and should be used with ``convert_equals_signs``. + + Examples: 1=2 to Eq(1,2) + 1*2=x to Eq(1*2, x) + + This does not deal with function arguments yet. + + """ + result: list[TOKEN] = [] + if (OP, "=") in tokens: + result.append((NAME, "Eq")) + result.append((OP, "(")) + for token in tokens: + if token == (OP, "="): + result.append((OP, ",")) + continue + result.append(token) + result.append((OP, ")")) + else: + result = tokens + return result + + +def convert_equals_signs(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """ Transforms all the equals signs ``=`` to instances of Eq. + + Parses the equals signs in the expression and replaces them with + appropriate Eq instances. Also works with nested equals signs. + + Does not yet play well with function arguments. + For example, the expression ``(x=y)`` is ambiguous and can be interpreted + as x being an argument to a function and ``convert_equals_signs`` will not + work for this. + + See also + ======== + convert_equality_operators + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, convert_equals_signs) + >>> parse_expr("1*2=x", transformations=( + ... standard_transformations + (convert_equals_signs,))) + Eq(2, x) + >>> parse_expr("(1*2=x)=False", transformations=( + ... standard_transformations + (convert_equals_signs,))) + Eq(Eq(2, x), False) + + """ + res1 = _group_parentheses(convert_equals_signs)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _transform_equals_sign(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +#: Standard transformations for :func:`parse_expr`. +#: Inserts calls to :class:`~.Symbol`, :class:`~.Integer`, and other SymPy +#: datatypes and allows the use of standard factorial notation (e.g. ``x!``). +standard_transformations: tuple[TRANS, ...] \ + = (lambda_notation, auto_symbol, repeated_decimals, auto_number, + factorial_notation) + + +def stringify_expr(s: str, local_dict: DICT, global_dict: DICT, + transformations: tuple[TRANS, ...]) -> str: + """ + Converts the string ``s`` to Python code, in ``local_dict`` + + Generally, ``parse_expr`` should be used. + """ + + tokens = [] + input_code = StringIO(s.strip()) + for toknum, tokval, _, _, _ in generate_tokens(input_code.readline): + tokens.append((toknum, tokval)) + + for transform in transformations: + tokens = transform(tokens, local_dict, global_dict) + + return untokenize(tokens) + + +def eval_expr(code, local_dict: DICT, global_dict: DICT): + """ + Evaluate Python code generated by ``stringify_expr``. + + Generally, ``parse_expr`` should be used. + """ + expr = eval( + code, global_dict, local_dict) # take local objects in preference + return expr + + +def parse_expr(s: str, local_dict: DICT | None = None, + transformations: tuple[TRANS, ...] | str \ + = standard_transformations, + global_dict: DICT | None = None, evaluate=True): + """Converts the string ``s`` to a SymPy expression, in ``local_dict``. + + .. warning:: + Note that this function uses ``eval``, and thus shouldn't be used on + unsanitized input. + + Parameters + ========== + + s : str + The string to parse. + + local_dict : dict, optional + A dictionary of local variables to use when parsing. + + global_dict : dict, optional + A dictionary of global variables. By default, this is initialized + with ``from sympy import *``; provide this parameter to override + this behavior (for instance, to parse ``"Q & S"``). + + transformations : tuple or str + A tuple of transformation functions used to modify the tokens of the + parsed expression before evaluation. The default transformations + convert numeric literals into their SymPy equivalents, convert + undefined variables into SymPy symbols, and allow the use of standard + mathematical factorial notation (e.g. ``x!``). Selection via + string is available (see below). + + evaluate : bool, optional + When False, the order of the arguments will remain as they were in the + string and automatic simplification that would normally occur is + suppressed. (see examples) + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import parse_expr + >>> parse_expr("1/2") + 1/2 + >>> type(_) + + >>> from sympy.parsing.sympy_parser import standard_transformations,\\ + ... implicit_multiplication_application + >>> transformations = (standard_transformations + + ... (implicit_multiplication_application,)) + >>> parse_expr("2x", transformations=transformations) + 2*x + + When evaluate=False, some automatic simplifications will not occur: + + >>> parse_expr("2**3"), parse_expr("2**3", evaluate=False) + (8, 2**3) + + In addition the order of the arguments will not be made canonical. + This feature allows one to tell exactly how the expression was entered: + + >>> a = parse_expr('1 + x', evaluate=False) + >>> b = parse_expr('x + 1', evaluate=False) + >>> a == b + False + >>> a.args + (1, x) + >>> b.args + (x, 1) + + Note, however, that when these expressions are printed they will + appear the same: + + >>> assert str(a) == str(b) + + As a convenience, transformations can be seen by printing ``transformations``: + + >>> from sympy.parsing.sympy_parser import transformations + + >>> print(transformations) + 0: lambda_notation + 1: auto_symbol + 2: repeated_decimals + 3: auto_number + 4: factorial_notation + 5: implicit_multiplication_application + 6: convert_xor + 7: implicit_application + 8: implicit_multiplication + 9: convert_equals_signs + 10: function_exponentiation + 11: rationalize + + The ``T`` object provides a way to select these transformations: + + >>> from sympy.parsing.sympy_parser import T + + If you print it, you will see the same list as shown above. + + >>> str(T) == str(transformations) + True + + Standard slicing will return a tuple of transformations: + + >>> T[:5] == standard_transformations + True + + So ``T`` can be used to specify the parsing transformations: + + >>> parse_expr("2x", transformations=T[:5]) + Traceback (most recent call last): + ... + SyntaxError: invalid syntax + >>> parse_expr("2x", transformations=T[:6]) + 2*x + >>> parse_expr('.3', transformations=T[3, 11]) + 3/10 + >>> parse_expr('.3x', transformations=T[:]) + 3*x/10 + + As a further convenience, strings 'implicit' and 'all' can be used + to select 0-5 and all the transformations, respectively. + + >>> parse_expr('.3x', transformations='all') + 3*x/10 + + See Also + ======== + + stringify_expr, eval_expr, standard_transformations, + implicit_multiplication_application + + """ + + if local_dict is None: + local_dict = {} + elif not isinstance(local_dict, dict): + raise TypeError('expecting local_dict to be a dict') + elif null in local_dict: + raise ValueError('cannot use "" in local_dict') + + if global_dict is None: + global_dict = {} + exec('from sympy import *', global_dict) + + builtins_dict = vars(builtins) + for name, obj in builtins_dict.items(): + if isinstance(obj, types.BuiltinFunctionType): + global_dict[name] = obj + global_dict['max'] = Max + global_dict['min'] = Min + + elif not isinstance(global_dict, dict): + raise TypeError('expecting global_dict to be a dict') + + transformations = transformations or () + if isinstance(transformations, str): + if transformations == 'all': + _transformations = T[:] + elif transformations == 'implicit': + _transformations = T[:6] + else: + raise ValueError('unknown transformation group name') + else: + _transformations = transformations + + code = stringify_expr(s, local_dict, global_dict, _transformations) + + if not evaluate: + code = compile(evaluateFalse(code), '', 'eval') # type: ignore + + try: + rv = eval_expr(code, local_dict, global_dict) + # restore neutral definitions for names + for i in local_dict.pop(null, ()): + local_dict[i] = null + return rv + except Exception as e: + # restore neutral definitions for names + for i in local_dict.pop(null, ()): + local_dict[i] = null + raise e from ValueError(f"Error from parse_expr with transformed code: {code!r}") + + +def evaluateFalse(s: str): + """ + Replaces operators with the SymPy equivalent and sets evaluate=False. + """ + node = ast.parse(s) + transformed_node = EvaluateFalseTransformer().visit(node) + # node is a Module, we want an Expression + transformed_node = ast.Expression(transformed_node.body[0].value) + + return ast.fix_missing_locations(transformed_node) + + +class EvaluateFalseTransformer(ast.NodeTransformer): + operators = { + ast.Add: 'Add', + ast.Mult: 'Mul', + ast.Pow: 'Pow', + ast.Sub: 'Add', + ast.Div: 'Mul', + ast.BitOr: 'Or', + ast.BitAnd: 'And', + ast.BitXor: 'Not', + } + functions = ( + 'Abs', 'im', 're', 'sign', 'arg', 'conjugate', + 'acos', 'acot', 'acsc', 'asec', 'asin', 'atan', + 'acosh', 'acoth', 'acsch', 'asech', 'asinh', 'atanh', + 'cos', 'cot', 'csc', 'sec', 'sin', 'tan', + 'cosh', 'coth', 'csch', 'sech', 'sinh', 'tanh', + 'exp', 'ln', 'log', 'sqrt', 'cbrt', + ) + + relational_operators = { + ast.NotEq: 'Ne', + ast.Lt: 'Lt', + ast.LtE: 'Le', + ast.Gt: 'Gt', + ast.GtE: 'Ge', + ast.Eq: 'Eq' + } + def visit_Compare(self, node): + def reducer(acc, op_right): + result, left = acc + op, right = op_right + if op.__class__ not in self.relational_operators: + raise ValueError("Only equation or inequality operators are supported") + new = ast.Call( + func=ast.Name( + id=self.relational_operators[op.__class__], ctx=ast.Load() + ), + args=[self.visit(left), self.visit(right)], + keywords=[ast.keyword(arg="evaluate", value=ast.Constant(value=False))], + ) + return result + [new], right + + args, _ = reduce( + reducer, zip(node.ops, node.comparators), ([], node.left) + ) + if len(args) == 1: + return args[0] + return ast.Call( + func=ast.Name(id=self.operators[ast.BitAnd], ctx=ast.Load()), + args=args, + keywords=[ast.keyword(arg="evaluate", value=ast.Constant(value=False))], + ) + + def flatten(self, args, func): + result = [] + for arg in args: + if isinstance(arg, ast.Call): + arg_func = arg.func + if isinstance(arg_func, ast.Call): + arg_func = arg_func.func + if arg_func.id == func: + result.extend(self.flatten(arg.args, func)) + else: + result.append(arg) + else: + result.append(arg) + return result + + def visit_BinOp(self, node): + if node.op.__class__ in self.operators: + sympy_class = self.operators[node.op.__class__] + right = self.visit(node.right) + left = self.visit(node.left) + + rev = False + if isinstance(node.op, ast.Sub): + right = ast.Call( + func=ast.Name(id='Mul', ctx=ast.Load()), + args=[ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1)), right], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + elif isinstance(node.op, ast.Div): + if isinstance(node.left, ast.UnaryOp): + left, right = right, left + rev = True + left = ast.Call( + func=ast.Name(id='Pow', ctx=ast.Load()), + args=[left, ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1))], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + else: + right = ast.Call( + func=ast.Name(id='Pow', ctx=ast.Load()), + args=[right, ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1))], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + + if rev: # undo reversal + left, right = right, left + new_node = ast.Call( + func=ast.Name(id=sympy_class, ctx=ast.Load()), + args=[left, right], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + + if sympy_class in ('Add', 'Mul'): + # Denest Add or Mul as appropriate + new_node.args = self.flatten(new_node.args, sympy_class) + + return new_node + return node + + def visit_Call(self, node): + new_node = self.generic_visit(node) + if isinstance(node.func, ast.Name) and node.func.id in self.functions: + new_node.keywords.append(ast.keyword(arg='evaluate', value=ast.Constant(value=False))) + return new_node + + +_transformation = { # items can be added but never re-ordered +0: lambda_notation, +1: auto_symbol, +2: repeated_decimals, +3: auto_number, +4: factorial_notation, +5: implicit_multiplication_application, +6: convert_xor, +7: implicit_application, +8: implicit_multiplication, +9: convert_equals_signs, +10: function_exponentiation, +11: rationalize} + +transformations = '\n'.join('%s: %s' % (i, func_name(f)) for i, f in _transformation.items()) + + +class _T(): + """class to retrieve transformations from a given slice + + EXAMPLES + ======== + + >>> from sympy.parsing.sympy_parser import T, standard_transformations + >>> assert T[:5] == standard_transformations + """ + def __init__(self): + self.N = len(_transformation) + + def __str__(self): + return transformations + + def __getitem__(self, t): + if not type(t) is tuple: + t = (t,) + i = [] + for ti in t: + if type(ti) is int: + i.append(range(self.N)[ti]) + elif type(ti) is slice: + i.extend(range(*ti.indices(self.N))) + else: + raise TypeError('unexpected slice arg') + return tuple([_transformation[_] for _ in i]) + +T = _T() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7d2130f3590671230d21398036c1d93ae14e85f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_ast_parser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_ast_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39a5b83fe37d1d4e23ee68d7b1c8ff373fe08eb2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_ast_parser.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_autolev.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_autolev.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a18e26c583017af0e4df5170432eacc6d20bf69 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_autolev.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_custom_latex.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_custom_latex.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a32a9e4fcb5bf7e16b97b28b0bd5d6465a73d4f1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_custom_latex.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_fortran_parser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_fortran_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e62a34521c67a2a365eeafefb7d54307ad0b0b47 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_fortran_parser.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_implicit_multiplication_application.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_implicit_multiplication_application.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bddaf40e611f088c8b2efa7ac20b063a0e519ac9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_implicit_multiplication_application.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7b64affbc8289081d7fb4948097a490f3f5300c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex_deps.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex_deps.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e24264350b4d940f27ba600c5f3368d65e89f365 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex_deps.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex_lark.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex_lark.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d43689a6c3d86134eadbc3525fd4628444597878 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_latex_lark.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_mathematica.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_mathematica.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27548858e112d76a0fdac156746b510dbf2b682d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_mathematica.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_maxima.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_maxima.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f9b8827a9834c0117c99375d574f16cff474cc0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_maxima.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_sym_expr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_sym_expr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..799c56f488a502695b9c57a7f042ba8d3fd7fcd3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_sym_expr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_sympy_parser.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_sympy_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..796886ec7365877a3bb1c9fb884e3d937ef041e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/__pycache__/test_sympy_parser.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_ast_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_ast_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..24572190df72f9be11b5830355b0d6b9e3bb53ad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_ast_parser.py @@ -0,0 +1,25 @@ +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.parsing.ast_parser import parse_expr +from sympy.testing.pytest import raises +from sympy.core.sympify import SympifyError +import warnings + +def test_parse_expr(): + a, b = symbols('a, b') + # tests issue_16393 + assert parse_expr('a + b', {}) == a + b + raises(SympifyError, lambda: parse_expr('a + ', {})) + + # tests Transform.visit_Constant + assert parse_expr('1 + 2', {}) == S(3) + assert parse_expr('1 + 2.0', {}) == S(3.0) + + # tests Transform.visit_Name + assert parse_expr('Rational(1, 2)', {}) == S(1)/2 + assert parse_expr('a', {'a': a}) == a + + # tests issue_23092 + with warnings.catch_warnings(): + warnings.simplefilter('error') + assert parse_expr('6 * 7', {}) == S(42) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_autolev.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_autolev.py new file mode 100644 index 0000000000000000000000000000000000000000..dfcaef13565c5e2187dc6e90113b407a7967c331 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_autolev.py @@ -0,0 +1,178 @@ +import os + +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.external import import_module +from sympy.testing.pytest import skip +from sympy.parsing.autolev import parse_autolev + +antlr4 = import_module("antlr4") + +if not antlr4: + disabled = True + +FILE_DIR = os.path.dirname( + os.path.dirname(os.path.abspath(os.path.realpath(__file__)))) + + +def _test_examples(in_filename, out_filename, test_name=""): + + in_file_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + in_filename) + correct_file_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + out_filename) + with open(in_file_path) as f: + generated_code = parse_autolev(f, include_numeric=True) + + with open(correct_file_path) as f: + for idx, line1 in enumerate(f): + if line1.startswith("#"): + break + try: + line2 = generated_code.split('\n')[idx] + assert line1.rstrip() == line2.rstrip() + except Exception: + msg = 'mismatch in ' + test_name + ' in line no: {0}' + raise AssertionError(msg.format(idx+1)) + + +def test_rule_tests(): + + l = ["ruletest1", "ruletest2", "ruletest3", "ruletest4", "ruletest5", + "ruletest6", "ruletest7", "ruletest8", "ruletest9", "ruletest10", + "ruletest11", "ruletest12"] + + for i in l: + in_filepath = i + ".al" + out_filepath = i + ".py" + _test_examples(in_filepath, out_filepath, i) + + +def test_pydy_examples(): + + l = ["mass_spring_damper", "chaos_pendulum", "double_pendulum", + "non_min_pendulum"] + + for i in l: + in_filepath = os.path.join("pydy-example-repo", i + ".al") + out_filepath = os.path.join("pydy-example-repo", i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_autolev_tutorial(): + + dir_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + 'autolev-tutorial') + + if os.path.isdir(dir_path): + l = ["tutor1", "tutor2", "tutor3", "tutor4", "tutor5", "tutor6", + "tutor7"] + for i in l: + in_filepath = os.path.join("autolev-tutorial", i + ".al") + out_filepath = os.path.join("autolev-tutorial", i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_dynamics_online(): + + dir_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + 'dynamics-online') + + if os.path.isdir(dir_path): + ch1 = ["1-4", "1-5", "1-6", "1-7", "1-8", "1-9_1", "1-9_2", "1-9_3"] + ch2 = ["2-1", "2-2", "2-3", "2-4", "2-5", "2-6", "2-7", "2-8", "2-9", + "circular"] + ch3 = ["3-1_1", "3-1_2", "3-2_1", "3-2_2", "3-2_3", "3-2_4", "3-2_5", + "3-3"] + ch4 = ["4-1_1", "4-2_1", "4-4_1", "4-4_2", "4-5_1", "4-5_2"] + chapters = [(ch1, "ch1"), (ch2, "ch2"), (ch3, "ch3"), (ch4, "ch4")] + for ch, name in chapters: + for i in ch: + in_filepath = os.path.join("dynamics-online", name, i + ".al") + out_filepath = os.path.join("dynamics-online", name, i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_output_01(): + """Autolev example calculates the position, velocity, and acceleration of a + point and expresses in a single reference frame:: + + (1) FRAMES C,D,F + (2) VARIABLES FD'',DC'' + (3) CONSTANTS R,L + (4) POINTS O,E + (5) SIMPROT(F,D,1,FD) + -> (6) F_D = [1, 0, 0; 0, COS(FD), -SIN(FD); 0, SIN(FD), COS(FD)] + (7) SIMPROT(D,C,2,DC) + -> (8) D_C = [COS(DC), 0, SIN(DC); 0, 1, 0; -SIN(DC), 0, COS(DC)] + (9) W_C_F> = EXPRESS(W_C_F>, F) + -> (10) W_C_F> = FD'*F1> + COS(FD)*DC'*F2> + SIN(FD)*DC'*F3> + (11) P_O_E>=R*D2>-L*C1> + (12) P_O_E>=EXPRESS(P_O_E>, D) + -> (13) P_O_E> = -L*COS(DC)*D1> + R*D2> + L*SIN(DC)*D3> + (14) V_E_F>=EXPRESS(DT(P_O_E>,F),D) + -> (15) V_E_F> = L*SIN(DC)*DC'*D1> - L*SIN(DC)*FD'*D2> + (R*FD'+L*COS(DC)*DC')*D3> + (16) A_E_F>=EXPRESS(DT(V_E_F>,F),D) + -> (17) A_E_F> = L*(COS(DC)*DC'^2+SIN(DC)*DC'')*D1> + (-R*FD'^2-2*L*COS(DC)*DC'*FD'-L*SIN(DC)*FD'')*D2> + (R*FD''+L*COS(DC)*DC''-L*SIN(DC)*DC'^2-L*SIN(DC)*FD'^2)*D3> + + """ + + if not antlr4: + skip('Test skipped: antlr4 is not installed.') + + autolev_input = """\ +FRAMES C,D,F +VARIABLES FD'',DC'' +CONSTANTS R,L +POINTS O,E +SIMPROT(F,D,1,FD) +SIMPROT(D,C,2,DC) +W_C_F>=EXPRESS(W_C_F>,F) +P_O_E>=R*D2>-L*C1> +P_O_E>=EXPRESS(P_O_E>,D) +V_E_F>=EXPRESS(DT(P_O_E>,F),D) +A_E_F>=EXPRESS(DT(V_E_F>,F),D)\ +""" + + sympy_input = parse_autolev(autolev_input) + + g = {} + l = {} + exec(sympy_input, g, l) + + w_c_f = l['frame_c'].ang_vel_in(l['frame_f']) + # P_O_E> means "the position of point E wrt to point O" + p_o_e = l['point_e'].pos_from(l['point_o']) + v_e_f = l['point_e'].vel(l['frame_f']) + a_e_f = l['point_e'].acc(l['frame_f']) + + # NOTE : The Autolev outputs above were manually transformed into + # equivalent SymPy physics vector expressions. Would be nice to automate + # this transformation. + expected_w_c_f = (l['fd'].diff()*l['frame_f'].x + + cos(l['fd'])*l['dc'].diff()*l['frame_f'].y + + sin(l['fd'])*l['dc'].diff()*l['frame_f'].z) + + assert (w_c_f - expected_w_c_f).simplify() == 0 + + expected_p_o_e = (-l['l']*cos(l['dc'])*l['frame_d'].x + + l['r']*l['frame_d'].y + + l['l']*sin(l['dc'])*l['frame_d'].z) + + assert (p_o_e - expected_p_o_e).simplify() == 0 + + expected_v_e_f = (l['l']*sin(l['dc'])*l['dc'].diff()*l['frame_d'].x - + l['l']*sin(l['dc'])*l['fd'].diff()*l['frame_d'].y + + (l['r']*l['fd'].diff() + + l['l']*cos(l['dc'])*l['dc'].diff())*l['frame_d'].z) + assert (v_e_f - expected_v_e_f).simplify() == 0 + + expected_a_e_f = (l['l']*(cos(l['dc'])*l['dc'].diff()**2 + + sin(l['dc'])*l['dc'].diff().diff())*l['frame_d'].x + + (-l['r']*l['fd'].diff()**2 - + 2*l['l']*cos(l['dc'])*l['dc'].diff()*l['fd'].diff() - + l['l']*sin(l['dc'])*l['fd'].diff().diff())*l['frame_d'].y + + (l['r']*l['fd'].diff().diff() + + l['l']*cos(l['dc'])*l['dc'].diff().diff() - + l['l']*sin(l['dc'])*l['dc'].diff()**2 - + l['l']*sin(l['dc'])*l['fd'].diff()**2)*l['frame_d'].z) + assert (a_e_f - expected_a_e_f).simplify() == 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_c_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_c_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..b74622e40030cba180cb4fc354216ccca119baec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_c_parser.py @@ -0,0 +1,5248 @@ +from sympy.parsing.sym_expr import SymPyExpression +from sympy.testing.pytest import raises, XFAIL +from sympy.external import import_module + +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if cin: + from sympy.codegen.ast import (Variable, String, Return, + FunctionDefinition, Integer, Float, Declaration, CodeBlock, + FunctionPrototype, FunctionCall, NoneToken, Assignment, Type, + IntBaseType, SignedIntType, UnsignedIntType, FloatType, + AddAugmentedAssignment, SubAugmentedAssignment, + MulAugmentedAssignment, DivAugmentedAssignment, + ModAugmentedAssignment, While) + from sympy.codegen.cnodes import (PreDecrement, PostDecrement, + PreIncrement, PostIncrement) + from sympy.core import (Add, Mul, Mod, Pow, Rational, + StrictLessThan, LessThan, StrictGreaterThan, GreaterThan, + Equality, Unequality) + from sympy.logic.boolalg import And, Not, Or + from sympy.core.symbol import Symbol + from sympy.logic.boolalg import (false, true) + import os + + def test_variable(): + c_src1 = ( + 'int a;' + '\n' + + 'int b;' + '\n' + ) + c_src2 = ( + 'float a;' + '\n' + + 'float b;' + '\n' + ) + c_src3 = ( + 'int a;' + '\n' + + 'float b;' + '\n' + + 'int c;' + ) + c_src4 = ( + 'int x = 1, y = 6.78;' + '\n' + + 'float p = 2, q = 9.67;' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + + assert res1[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + + assert res3[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('intc')) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('intc')), + value=Integer(6) + ) + ) + + assert res4[2] == Declaration( + Variable( + Symbol('p'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.0', precision=53) + ) + ) + + assert res4[3] == Declaration( + Variable( + Symbol('q'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('9.67', precision=53) + ) + ) + + + def test_int(): + c_src1 = 'int a = 1;' + c_src2 = ( + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + ) + c_src3 = 'int a = 2.345, b = 5.67;' + c_src4 = 'int p = 6, q = 23.45;' + c_src5 = "int x = '0', y = 'a';" + c_src6 = "int r = true, s = false;" + + # cin.TypeKind.UCHAR + c_src_type1 = ( + "signed char a = 1, b = 5.1;" + ) + + # cin.TypeKind.SHORT + c_src_type2 = ( + "short a = 1, b = 5.1;" + "signed short c = 1, d = 5.1;" + "short int e = 1, f = 5.1;" + "signed short int g = 1, h = 5.1;" + ) + + # cin.TypeKind.INT + c_src_type3 = ( + "signed int a = 1, b = 5.1;" + "int c = 1, d = 5.1;" + ) + + # cin.TypeKind.LONG + c_src_type4 = ( + "long a = 1, b = 5.1;" + "long int c = 1, d = 5.1;" + ) + + # cin.TypeKind.UCHAR + c_src_type5 = "unsigned char a = 1, b = 5.1;" + + # cin.TypeKind.USHORT + c_src_type6 = ( + "unsigned short a = 1, b = 5.1;" + "unsigned short int c = 1, d = 5.1;" + ) + + # cin.TypeKind.UINT + c_src_type7 = "unsigned int a = 1, b = 5.1;" + + # cin.TypeKind.ULONG + c_src_type8 = ( + "unsigned long a = 1, b = 5.1;" + "unsigned long int c = 1, d = 5.1;" + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + + res_type1 = SymPyExpression(c_src_type1, 'c').return_expr() + res_type2 = SymPyExpression(c_src_type2, 'c').return_expr() + res_type3 = SymPyExpression(c_src_type3, 'c').return_expr() + res_type4 = SymPyExpression(c_src_type4, 'c').return_expr() + res_type5 = SymPyExpression(c_src_type5, 'c').return_expr() + res_type6 = SymPyExpression(c_src_type6, 'c').return_expr() + res_type7 = SymPyExpression(c_src_type7, 'c').return_expr() + res_type8 = SymPyExpression(c_src_type8, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('p'), + type=IntBaseType(String('intc')), + value=Integer(6) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('q'), + type=IntBaseType(String('intc')), + value=Integer(23) + ) + ) + + assert res5[0] == Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('intc')), + value=Integer(48) + ) + ) + + assert res5[1] == Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('intc')), + value=Integer(97) + ) + ) + + assert res6[0] == Declaration( + Variable( + Symbol('r'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res6[1] == Declaration( + Variable( + Symbol('s'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ) + + assert res_type1[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int8'), + nbits=Integer(8) + ), + value=Integer(1) + ) + ) + + assert res_type1[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int8'), + nbits=Integer(8) + ), + value=Integer(5) + ) + ) + + assert res_type2[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[2] == Declaration( + Variable(Symbol('c'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[3] == Declaration( + Variable( + Symbol('d'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[4] == Declaration( + Variable( + Symbol('e'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[5] == Declaration( + Variable( + Symbol('f'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[6] == Declaration( + Variable( + Symbol('g'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[7] == Declaration( + Variable( + Symbol('h'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res_type3[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res_type3[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res_type3[3] == Declaration( + Variable( + Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res_type4[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type4[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type4[2] == Declaration( + Variable( + Symbol('c'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type4[3] == Declaration( + Variable( + Symbol('d'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type5[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint8'), + nbits=Integer(8) + ), + value=Integer(1) + ) + ) + + assert res_type5[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint8'), + nbits=Integer(8) + ), + value=Integer(5) + ) + ) + + assert res_type6[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type6[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type6[2] == Declaration( + Variable( + Symbol('c'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type6[3] == Declaration( + Variable( + Symbol('d'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type7[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint32'), + nbits=Integer(32) + ), + value=Integer(1) + ) + ) + + assert res_type7[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint32'), + nbits=Integer(32) + ), + value=Integer(5) + ) + ) + + assert res_type8[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type8[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type8[2] == Declaration( + Variable( + Symbol('c'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type8[3] == Declaration( + Variable( + Symbol('d'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + + def test_float(): + c_src1 = 'float a = 1.0;' + c_src2 = ( + 'float a = 1.25;' + '\n' + + 'float b = 2.39;' + '\n' + ) + c_src3 = 'float x = 1, y = 2;' + c_src4 = 'float p = 5, e = 7.89;' + c_src5 = 'float r = true, s = false;' + + # cin.TypeKind.FLOAT + c_src_type1 = 'float x = 1, y = 2.5;' + + # cin.TypeKind.DOUBLE + c_src_type2 = 'double x = 1, y = 2.5;' + + # cin.TypeKind.LONGDOUBLE + c_src_type3 = 'long double x = 1, y = 2.5;' + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + res_type1 = SymPyExpression(c_src_type1, 'c').return_expr() + res_type2 = SymPyExpression(c_src_type2, 'c').return_expr() + res_type3 = SymPyExpression(c_src_type3, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.3900000000000001', precision=53) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.0', precision=53) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('p'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('5.0', precision=53) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('e'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('7.89', precision=53) + ) + ) + + assert res5[0] == Declaration( + Variable( + Symbol('r'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res5[1] == Declaration( + Variable( + Symbol('s'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('0.0', precision=53) + ) + ) + + assert res_type1[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type1[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + assert res_type2[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float64'), + nbits=Integer(64), + nmant=Integer(52), + nexp=Integer(11) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type2[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float64'), + nbits=Integer(64), + nmant=Integer(52), + nexp=Integer(11) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res_type3[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float80'), + nbits=Integer(80), + nmant=Integer(63), + nexp=Integer(15) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type3[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float80'), + nbits=Integer(80), + nmant=Integer(63), + nexp=Integer(15) + ), + value=Float('2.5', precision=53) + ) + ) + + + def test_bool(): + c_src1 = ( + 'bool a = true, b = false;' + ) + + c_src2 = ( + 'bool a = 1, b = 0;' + ) + + c_src3 = ( + 'bool a = 10, b = 20;' + ) + + c_src4 = ( + 'bool a = 19.1, b = 9.0, c = 0.0;' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=false + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true) + ) + + assert res2[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=false + ) + ) + + assert res3[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res3[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res4[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true) + ) + + assert res4[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res4[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_function(): + c_src1 = ( + 'void fun1()' + '\n' + + '{' + '\n' + + 'int a;' + '\n' + + '}' + ) + c_src2 = ( + 'int fun2()' + '\n' + + '{'+ '\n' + + 'int a;' + '\n' + + 'return a;' + '\n' + + '}' + ) + c_src3 = ( + 'float fun3()' + '\n' + + '{' + '\n' + + 'float b;' + '\n' + + 'return b;' + '\n' + + '}' + ) + c_src4 = ( + 'float fun4()' + '\n' + + '{}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Return('a') + ) + ) + + assert res3[0] == FunctionDefinition( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun3'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Return('b') + ) + ) + + assert res4[0] == FunctionPrototype( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun4'), + parameters=() + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_parameters(): + c_src1 = ( + 'void fun1( int a)' + '\n' + + '{' + '\n' + + 'int i;' + '\n' + + '}' + ) + c_src2 = ( + 'int fun2(float x, float y)' + '\n' + + '{'+ '\n' + + 'int a;' + '\n' + + 'return a;' + '\n' + + '}' + ) + c_src3 = ( + 'float fun3(int p, float q, int r)' + '\n' + + '{' + '\n' + + 'float b;' + '\n' + + 'return b;' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('i'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Return('a') + ) + ) + + assert res3[0] == FunctionDefinition( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun3'), + parameters=( + Variable( + Symbol('p'), + type=IntBaseType(String('intc')) + ), + Variable( + Symbol('q'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable( + Symbol('r'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Return('b') + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_function_call(): + c_src1 = ( + 'int fun1(int x)' + '\n' + + '{' + '\n' + + 'return x;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int x = fun1(2);' + '\n' + + '}' + ) + + c_src2 = ( + 'int fun2(int a, int b, int c)' + '\n' + + '{' + '\n' + + 'return a;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int y = fun2(2, 3, 4);' + '\n' + + '}' + ) + + c_src3 = ( + 'int fun3(int a, int b, int c)' + '\n' + + '{' + '\n' + + 'return b;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int p;' + '\n' + + 'int q;' + '\n' + + 'int r;' + '\n' + + 'int z = fun3(p, q, r);' + '\n' + + '}' + ) + + c_src4 = ( + 'int fun4(float a, float b, int c)' + '\n' + + '{' + '\n' + + 'return c;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'float x;' + '\n' + + 'float y;' + '\n' + + 'int z;' + '\n' + + 'int i = fun4(x, y, z)' + '\n' + + '}' + ) + + c_src5 = ( + 'int fun()' + '\n' + + '{' + '\n' + + 'return 1;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int a = fun()' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + + assert res1[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun1'), + parameters=(Variable(Symbol('x'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Return('x') + ) + ) + + assert res1[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('x'), + value=FunctionCall(String('fun1'), + function_args=( + Integer(2), + ) + ) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=(Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('a') + ) + ) + + assert res2[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('y'), + value=FunctionCall( + String('fun2'), + function_args=( + Integer(2), + Integer(3), + Integer(4) + ) + ) + ) + ) + ) + ) + + assert res3[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun3'), + parameters=( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('b') + ) + ) + + assert res3[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('p'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('q'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('r'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('z'), + value=FunctionCall( + String('fun3'), + function_args=( + Symbol('p'), + Symbol('q'), + Symbol('r') + ) + ) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun4'), + parameters=(Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('c') + ) + ) + + assert res4[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('z'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('i'), + value=FunctionCall(String('fun4'), + function_args=( + Symbol('x'), + Symbol('y'), + Symbol('z') + ) + ) + ) + ) + ) + ) + + assert res5[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun'), + parameters=(), + body=CodeBlock( + Return('') + ) + ) + + assert res5[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + value=FunctionCall(String('fun'), + function_args=() + ) + ) + ) + ) + ) + + + def test_parse(): + c_src1 = ( + 'int a;' + '\n' + + 'int b;' + '\n' + ) + c_src2 = ( + 'void fun1()' + '\n' + + '{' + '\n' + + 'int a;' + '\n' + + '}' + ) + + f1 = open('..a.h', 'w') + f2 = open('..b.h', 'w') + + f1.write(c_src1) + f2. write(c_src2) + + f1.close() + f2.close() + + res1 = SymPyExpression('..a.h', 'c').return_expr() + res2 = SymPyExpression('..b.h', 'c').return_expr() + + os.remove('..a.h') + os.remove('..b.h') + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + assert res1[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')) + ) + ) + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + + def test_binary_operators(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 1;' + '\n' + + '}' + ) + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 0;' + '\n' + + 'a = a + 1;' + '\n' + + 'a = 3*a - 10;' + '\n' + + '}' + ) + c_src3 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'a = 1 + a - 3 * 6;' + '\n' + + '}' + ) + c_src4 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'a = 100;' + '\n' + + 'b = a*a + a*a + a + 19*a + 1 + 24;' + '\n' + + '}' + ) + c_src5 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'int c;' + '\n' + + 'int d;' + '\n' + + 'a = 1;' + '\n' + + 'b = 2;' + '\n' + + 'c = b;' + '\n' + + 'd = ((a+b)*(a+c))*((c-d)*(a+c));' + '\n' + + '}' + ) + c_src6 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'int c;' + '\n' + + 'int d;' + '\n' + + 'a = 1;' + '\n' + + 'b = 2;' + '\n' + + 'c = 3;' + '\n' + + 'd = (a*a*a*a + 3*b*b + b + b + c*d);' + '\n' + + '}' + ) + c_src7 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 1.01;' + '\n' + + '}' + ) + + c_src8 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 10.0 + 2.5;' + '\n' + + '}' + ) + + c_src9 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 10.0 / 2.5;' + '\n' + + '}' + ) + + c_src10 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 100 / 4;' + '\n' + + '}' + ) + + c_src11 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 20 - 100 / 4 * 5 + 10;' + '\n' + + '}' + ) + + c_src12 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = (20 - 100) / 4 * (5 + 10);' + '\n' + + '}' + ) + + c_src13 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'float c;' + '\n' + + 'c = b/a;' + '\n' + + '}' + ) + + c_src14 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 2;' + '\n' + + 'int d = 5;' + '\n' + + 'int n = 10;' + '\n' + + 'int s;' + '\n' + + 's = (a/2)*(2*a + (n-1)*d);' + '\n' + + '}' + ) + + c_src15 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 1 % 2;' + '\n' + + '}' + ) + + c_src16 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 2;' + '\n' + + 'int b;' + '\n' + + 'b = a % 3;' + '\n' + + '}' + ) + + c_src17 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int c;' + '\n' + + 'c = a % b;' + '\n' + + '}' + ) + + c_src18 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c;' + '\n' + + 'c = (a + b * (100/a)) % mod;' + '\n' + + '}' + ) + + c_src19 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c;' + '\n' + + 'c = ((a % mod + b % mod) % mod' \ + '* (a % mod - b % mod) % mod) % mod;' + '\n' + + '}' + ) + + c_src20 = ( + 'void func()'+ + '{' + '\n' + + 'bool a' + '\n' + + 'bool b;' + '\n' + + 'a = 1 == 2;' + '\n' + + 'b = 1 != 2;' + '\n' + + '}' + ) + + c_src21 = ( + 'void func()'+ + '{' + '\n' + + 'bool a;' + '\n' + + 'bool b;' + '\n' + + 'bool c;' + '\n' + + 'bool d;' + '\n' + + 'a = 1 == 2;' + '\n' + + 'b = 1 <= 2;' + '\n' + + 'c = 1 > 2;' + '\n' + + 'd = 1 >= 2;' + '\n' + + '}' + ) + + c_src22 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + 'bool c7;' + '\n' + + 'bool c8;' + '\n' + + + 'c1 = a == 1;' + '\n' + + 'c2 = b == 2;' + '\n' + + + 'c3 = 1 != a;' + '\n' + + 'c4 = 1 != b;' + '\n' + + + 'c5 = a < 0;' + '\n' + + 'c6 = b <= 10;' + '\n' + + 'c7 = a > 0;' + '\n' + + 'c8 = b >= 11;' + '\n' + + '}' + ) + + c_src23 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 3;' + '\n' + + 'int b = 4;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a == b;' + '\n' + + 'c2 = a != b;' + '\n' + + 'c3 = a < b;' + '\n' + + 'c4 = a <= b;' + '\n' + + 'c5 = a > b;' + '\n' + + 'c6 = a >= b;' + '\n' + + '}' + ) + + c_src24 = ( + 'void func()'+ + '{' + '\n' + + 'float a = 1.25' + 'float b = 2.5;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a == 1.25;' + '\n' + + 'c2 = b == 2.54;' + '\n' + + + 'c3 = 1.2 != a;' + '\n' + + 'c4 = 1.5 != b;' + '\n' + + '}' + ) + + c_src25 = ( + 'void func()'+ + '{' + '\n' + + 'float a = 1.25' + '\n' + + 'float b = 2.5;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a == b;' + '\n' + + 'c2 = a != b;' + '\n' + + 'c3 = a < b;' + '\n' + + 'c4 = a <= b;' + '\n' + + 'c5 = a > b;' + '\n' + + 'c6 = a >= b;' + '\n' + + '}' + ) + + c_src26 = ( + 'void func()'+ + '{' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = true == true;' + '\n' + + 'c2 = true == false;' + '\n' + + 'c3 = false == false;' + '\n' + + + 'c4 = true != true;' + '\n' + + 'c5 = true != false;' + '\n' + + 'c6 = false != false;' + '\n' + + '}' + ) + + c_src27 = ( + 'void func()'+ + '{' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = true && true;' + '\n' + + 'c2 = true && false;' + '\n' + + 'c3 = false && false;' + '\n' + + + 'c4 = true || true;' + '\n' + + 'c5 = true || false;' + '\n' + + 'c6 = false || false;' + '\n' + + '}' + ) + + c_src28 = ( + 'void func()'+ + '{' + '\n' + + 'bool a;' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a && true;' + '\n' + + 'c2 = false && a;' + '\n' + + + 'c3 = true || a;' + '\n' + + 'c4 = a || false;' + '\n' + + '}' + ) + + c_src29 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a && 1;' + '\n' + + 'c2 = a && 0;' + '\n' + + + 'c3 = a || 1;' + '\n' + + 'c4 = 0 || a;' + '\n' + + '}' + ) + + c_src30 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'bool c;'+ '\n' + + 'bool d;'+ '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a && b;' + '\n' + + 'c2 = a && c;' + '\n' + + 'c3 = c && d;' + '\n' + + + 'c4 = a || b;' + '\n' + + 'c5 = a || c;' + '\n' + + 'c6 = c || d;' + '\n' + + '}' + ) + + c_src_raise1 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = -1;' + '\n' + + '}' + ) + + c_src_raise2 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = -+1;' + '\n' + + '}' + ) + + c_src_raise3 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 2*-2;' + '\n' + + '}' + ) + + c_src_raise4 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = (int)2.0;' + '\n' + + '}' + ) + + c_src_raise5 = ( + 'void func()'+ + '{' + '\n' + + 'int a=100;' + '\n' + + 'a = (a==100)?(1):(0);' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + res7 = SymPyExpression(c_src7, 'c').return_expr() + res8 = SymPyExpression(c_src8, 'c').return_expr() + res9 = SymPyExpression(c_src9, 'c').return_expr() + res10 = SymPyExpression(c_src10, 'c').return_expr() + res11 = SymPyExpression(c_src11, 'c').return_expr() + res12 = SymPyExpression(c_src12, 'c').return_expr() + res13 = SymPyExpression(c_src13, 'c').return_expr() + res14 = SymPyExpression(c_src14, 'c').return_expr() + res15 = SymPyExpression(c_src15, 'c').return_expr() + res16 = SymPyExpression(c_src16, 'c').return_expr() + res17 = SymPyExpression(c_src17, 'c').return_expr() + res18 = SymPyExpression(c_src18, 'c').return_expr() + res19 = SymPyExpression(c_src19, 'c').return_expr() + res20 = SymPyExpression(c_src20, 'c').return_expr() + res21 = SymPyExpression(c_src21, 'c').return_expr() + res22 = SymPyExpression(c_src22, 'c').return_expr() + res23 = SymPyExpression(c_src23, 'c').return_expr() + res24 = SymPyExpression(c_src24, 'c').return_expr() + res25 = SymPyExpression(c_src25, 'c').return_expr() + res26 = SymPyExpression(c_src26, 'c').return_expr() + res27 = SymPyExpression(c_src27, 'c').return_expr() + res28 = SymPyExpression(c_src28, 'c').return_expr() + res29 = SymPyExpression(c_src29, 'c').return_expr() + res30 = SymPyExpression(c_src30, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment(Variable(Symbol('a')), Integer(1)) + ) + ) + + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(0))), + Assignment( + Variable(Symbol('a')), + Add(Symbol('a'), + Integer(1)) + ), + Assignment(Variable(Symbol('a')), + Add( + Mul( + Integer(3), + Symbol('a')), + Integer(-10) + ) + ) + ) + ) + + assert res3[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Assignment( + Variable(Symbol('a')), + Add( + Symbol('a'), + Integer(-17) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(100)), + Assignment( + Variable(Symbol('b')), + Add( + Mul( + Integer(2), + Pow( + Symbol('a'), + Integer(2)) + ), + Mul( + Integer(20), + Symbol('a')), + Integer(25) + ) + ) + ) + ) + + assert res5[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1)), + Assignment( + Variable(Symbol('b')), + Integer(2) + ), + Assignment( + Variable(Symbol('c')), + Symbol('b')), + Assignment( + Variable(Symbol('d')), + Mul( + Add( + Symbol('a'), + Symbol('b')), + Pow( + Add( + Symbol('a'), + Symbol('c') + ), + Integer(2) + ), + Add( + Symbol('c'), + Mul( + Integer(-1), + Symbol('d') + ) + ) + ) + ) + ) + ) + + assert res6[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1) + ), + Assignment( + Variable(Symbol('b')), + Integer(2) + ), + Assignment( + Variable(Symbol('c')), + Integer(3) + ), + Assignment( + Variable(Symbol('d')), + Add( + Pow( + Symbol('a'), + Integer(4) + ), + Mul( + Integer(3), + Pow( + Symbol('b'), + Integer(2) + ) + ), + Mul( + Integer(2), + Symbol('b') + ), + Mul( + Symbol('c'), + Symbol('d') + ) + ) + ) + ) + ) + + assert res7[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('1.01', precision=53) + ) + ) + ) + + assert res8[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('12.5', precision=53) + ) + ) + ) + + assert res9[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('4.0', precision=53) + ) + ) + ) + + assert res10[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(25) + ) + ) + ) + + assert res11[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(-95) + ) + ) + ) + + assert res12[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(-300) + ) + ) + ) + + assert res13[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('c')), + Mul( + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ) + ) + ) + + assert res14[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ), + Declaration( + Variable(Symbol('n'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('s'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('s')), + Mul( + Rational(1, 2), + Symbol('a'), + Add( + Mul( + Integer(2), + Symbol('a') + ), + Mul( + Symbol('d'), + Add( + Symbol('n'), + Integer(-1) + ) + ) + ) + ) + ) + ) + ) + + assert res15[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1) + ) + ) + ) + + assert res16[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('b')), + Mod( + Symbol('a'), + Integer(3) + ) + ) + ) + ) + + assert res17[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res18[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Add( + Symbol('a'), + Mul( + Integer(100), + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + ) + + assert res19[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Mul( + Add( + Mod( + Symbol('a'), + Symbol('mod') + ), + Mul( + Integer(-1), + Mod( + Symbol('b'), + Symbol('mod') + ) + ) + ), + Mod( + Add( + Symbol('a'), + Symbol('b') + ), + Symbol('mod') + ) + ), + Symbol('mod') + ) + ) + ) + ) + + assert res20[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('a')), + false + ), + Assignment( + Variable(Symbol('b')), + true + ) + ) + ) + + assert res21[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('a')), + false + ), + Assignment( + Variable(Symbol('b')), + true + ), + Assignment( + Variable(Symbol('c')), + false + ), + Assignment( + Variable(Symbol('d')), + false + ) + ) + ) + + assert res22[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c7'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c8'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Integer(1) + ) + ), + Assignment( + Variable(Symbol('c2')), + Equality( + Symbol('b'), + Integer(2) + ) + ), + Assignment( + Variable(Symbol('c3')), + Unequality( + Integer(1), + Symbol('a') + ) + ), + Assignment( + Variable(Symbol('c4')), + Unequality( + Integer(1), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictLessThan( + Symbol('a'), + Integer(0) + ) + ), + Assignment( + Variable(Symbol('c6')), + LessThan( + Symbol('b'), + Integer(10) + ) + ), + Assignment( + Variable(Symbol('c7')), + StrictGreaterThan( + Symbol('a'), + Integer(0) + ) + ), + Assignment( + Variable(Symbol('c8')), + GreaterThan( + Symbol('b'), + Integer(11) + ) + ) + ) + ) + + assert res23[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + Unequality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c3')), + StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c4')), + LessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c6')), + GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res24[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Float('1.25', precision=53) + ) + ), + Assignment( + Variable(Symbol('c3')), + Unequality( + Float('1.2', precision=53), + Symbol('a') + ) + ) + ) + ) + + + assert res25[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ), + Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool') + ) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + Unequality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c3')), + StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c4')), + LessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c6')), + GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res26[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), body=CodeBlock( + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + true + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + false + ), + Assignment( + Variable(Symbol('c5')), + true + ), + Assignment( + Variable(Symbol('c6')), + false + ) + ) + ) + + assert res27[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + true + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + false + ), + Assignment( + Variable(Symbol('c4')), + true + ), + Assignment( + Variable(Symbol('c5')), + true + ), + Assignment( + Variable(Symbol('c6')), + false) + ) + ) + + assert res28[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Symbol('a') + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + Symbol('a') + ) + ) + ) + + assert res29[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Symbol('a') + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + Symbol('a') + ) + ) + ) + + assert res30[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + And( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + And( + Symbol('a'), + Symbol('c') + ) + ), + Assignment( + Variable(Symbol('c3')), + And( + Symbol('c'), + Symbol('d') + ) + ), + Assignment( + Variable(Symbol('c4')), + Or( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + Or( + Symbol('a'), + Symbol('c') + ) + ), + Assignment( + Variable(Symbol('c6')), + Or( + Symbol('c'), + Symbol('d') + ) + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise3, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise4, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise5, 'c')) + + + @XFAIL + def test_var_decl(): + c_src1 = ( + 'int b = 100;' + '\n' + + 'int a = b;' + '\n' + ) + + c_src2 = ( + 'int a = 1;' + '\n' + + 'int b = a + 1;' + '\n' + ) + + c_src3 = ( + 'float a = 10.0 + 2.5;' + '\n' + + 'float b = a * 20.0;' + '\n' + ) + + c_src4 = ( + 'int a = 1 + 100 - 3 * 6;' + '\n' + ) + + c_src5 = ( + 'int a = (((1 + 100) * 12) - 3) * (6 - 10);' + '\n' + ) + + c_src6 = ( + 'int b = 2;' + '\n' + + 'int c = 3;' + '\n' + + 'int a = b + c * 4;' + '\n' + ) + + c_src7 = ( + 'int b = 1;' + '\n' + + 'int c = b + 2;' + '\n' + + 'int a = 10 * b * b * c;' + '\n' + ) + + c_src8 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + 'int temp = a;' + '\n' + + 'a = b;' + '\n' + + 'b = temp;' + '\n' + + '}' + ) + + c_src9 = ( + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + 'int c = a;' + '\n' + + 'int d = a + b + c;' + '\n' + + 'int e = a*a*a + 3*a*a*b + 3*a*b*b + b*b*b;' + '\n' + 'int f = (a + b + c) * (a + b - c);' + '\n' + + 'int g = (a + b + c + d)*(a + b + c + d)*(a * (b - c));' + + '\n' + ) + + c_src10 = ( + 'float a = 10.0;' + '\n' + + 'float b = 2.5;' + '\n' + + 'float c = a*a + 2*a*b + b*b;' + '\n' + ) + + c_src11 = ( + 'float a = 10.0 / 2.5;' + '\n' + ) + + c_src12 = ( + 'int a = 100 / 4;' + '\n' + ) + + c_src13 = ( + 'int a = 20 - 100 / 4 * 5 + 10;' + '\n' + ) + + c_src14 = ( + 'int a = (20 - 100) / 4 * (5 + 10);' + '\n' + ) + + c_src15 = ( + 'int a = 4;' + '\n' + + 'int b = 2;' + '\n' + + 'float c = b/a;' + '\n' + ) + + c_src16 = ( + 'int a = 2;' + '\n' + + 'int d = 5;' + '\n' + + 'int n = 10;' + '\n' + + 'int s = (a/2)*(2*a + (n-1)*d);' + '\n' + ) + + c_src17 = ( + 'int a = 1 % 2;' + '\n' + ) + + c_src18 = ( + 'int a = 2;' + '\n' + + 'int b = a % 3;' + '\n' + ) + + c_src19 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int c = a % b;' + '\n' + ) + + c_src20 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c = (a + b * (100/a)) % mod;' + '\n' + ) + + c_src21 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c = ((a % mod + b % mod) % mod *' \ + '(a % mod - b % mod) % mod) % mod;' + '\n' + ) + + c_src22 = ( + 'bool a = 1 == 2, b = 1 != 2;' + ) + + c_src23 = ( + 'bool a = 1 < 2, b = 1 <= 2, c = 1 > 2, d = 1 >= 2;' + ) + + c_src24 = ( + 'int a = 1, b = 2;' + '\n' + + + 'bool c1 = a == 1;' + '\n' + + 'bool c2 = b == 2;' + '\n' + + + 'bool c3 = 1 != a;' + '\n' + + 'bool c4 = 1 != b;' + '\n' + + + 'bool c5 = a < 0;' + '\n' + + 'bool c6 = b <= 10;' + '\n' + + 'bool c7 = a > 0;' + '\n' + + 'bool c8 = b >= 11;' + + ) + + c_src25 = ( + 'int a = 3, b = 4;' + '\n' + + + 'bool c1 = a == b;' + '\n' + + 'bool c2 = a != b;' + '\n' + + 'bool c3 = a < b;' + '\n' + + 'bool c4 = a <= b;' + '\n' + + 'bool c5 = a > b;' + '\n' + + 'bool c6 = a >= b;' + ) + + c_src26 = ( + 'float a = 1.25, b = 2.5;' + '\n' + + + 'bool c1 = a == 1.25;' + '\n' + + 'bool c2 = b == 2.54;' + '\n' + + + 'bool c3 = 1.2 != a;' + '\n' + + 'bool c4 = 1.5 != b;' + ) + + c_src27 = ( + 'float a = 1.25, b = 2.5;' + '\n' + + + 'bool c1 = a == b;' + '\n' + + 'bool c2 = a != b;' + '\n' + + 'bool c3 = a < b;' + '\n' + + 'bool c4 = a <= b;' + '\n' + + 'bool c5 = a > b;' + '\n' + + 'bool c6 = a >= b;' + ) + + c_src28 = ( + 'bool c1 = true == true;' + '\n' + + 'bool c2 = true == false;' + '\n' + + 'bool c3 = false == false;' + '\n' + + + 'bool c4 = true != true;' + '\n' + + 'bool c5 = true != false;' + '\n' + + 'bool c6 = false != false;' + ) + + c_src29 = ( + 'bool c1 = true && true;' + '\n' + + 'bool c2 = true && false;' + '\n' + + 'bool c3 = false && false;' + '\n' + + + 'bool c4 = true || true;' + '\n' + + 'bool c5 = true || false;' + '\n' + + 'bool c6 = false || false;' + ) + + c_src30 = ( + 'bool a = false;' + '\n' + + + 'bool c1 = a && true;' + '\n' + + 'bool c2 = false && a;' + '\n' + + + 'bool c3 = true || a;' + '\n' + + 'bool c4 = a || false;' + ) + + c_src31 = ( + 'int a = 1;' + '\n' + + + 'bool c1 = a && 1;' + '\n' + + 'bool c2 = a && 0;' + '\n' + + + 'bool c3 = a || 1;' + '\n' + + 'bool c4 = 0 || a;' + ) + + c_src32 = ( + 'int a = 1, b = 0;' + '\n' + + 'bool c = false, d = true;'+ '\n' + + + 'bool c1 = a && b;' + '\n' + + 'bool c2 = a && c;' + '\n' + + 'bool c3 = c && d;' + '\n' + + + 'bool c4 = a || b;' + '\n' + + 'bool c5 = a || c;' + '\n' + + 'bool c6 = c || d;' + ) + + c_src_raise1 = ( + "char a = 'b';" + ) + + c_src_raise2 = ( + 'int a[] = {10, 20};' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + res7 = SymPyExpression(c_src7, 'c').return_expr() + res8 = SymPyExpression(c_src8, 'c').return_expr() + res9 = SymPyExpression(c_src9, 'c').return_expr() + res10 = SymPyExpression(c_src10, 'c').return_expr() + res11 = SymPyExpression(c_src11, 'c').return_expr() + res12 = SymPyExpression(c_src12, 'c').return_expr() + res13 = SymPyExpression(c_src13, 'c').return_expr() + res14 = SymPyExpression(c_src14, 'c').return_expr() + res15 = SymPyExpression(c_src15, 'c').return_expr() + res16 = SymPyExpression(c_src16, 'c').return_expr() + res17 = SymPyExpression(c_src17, 'c').return_expr() + res18 = SymPyExpression(c_src18, 'c').return_expr() + res19 = SymPyExpression(c_src19, 'c').return_expr() + res20 = SymPyExpression(c_src20, 'c').return_expr() + res21 = SymPyExpression(c_src21, 'c').return_expr() + res22 = SymPyExpression(c_src22, 'c').return_expr() + res23 = SymPyExpression(c_src23, 'c').return_expr() + res24 = SymPyExpression(c_src24, 'c').return_expr() + res25 = SymPyExpression(c_src25, 'c').return_expr() + res26 = SymPyExpression(c_src26, 'c').return_expr() + res27 = SymPyExpression(c_src27, 'c').return_expr() + res28 = SymPyExpression(c_src28, 'c').return_expr() + res29 = SymPyExpression(c_src29, 'c').return_expr() + res30 = SymPyExpression(c_src30, 'c').return_expr() + res31 = SymPyExpression(c_src31, 'c').return_expr() + res32 = SymPyExpression(c_src32, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Symbol('b') + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration(Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res3[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('12.5', precision=53) + ) + ) + + assert res3[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Mul( + Float('20.0', precision=53), + Symbol('a') + ) + ) + ) + + assert res4[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(83) + ) + ) + + assert res5[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-4836) + ) + ) + + assert res6[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res6[1] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res6[2] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('b'), + Mul( + Integer(4), + Symbol('c') + ) + ) + ) + ) + + assert res7[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res7[1] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('b'), + Integer(2) + ) + ) + ) + + assert res7[2] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Mul( + Integer(10), + Pow( + Symbol('b'), + Integer(2) + ), + Symbol('c') + ) + ) + ) + + assert res8[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('temp'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ), + Assignment( + Variable(Symbol('a')), + Symbol('b') + ), + Assignment( + Variable(Symbol('b')), + Symbol('temp') + ) + ) + ) + + assert res9[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res9[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res9[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ) + + assert res9[3] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Symbol('b'), + Symbol('c') + ) + ) + ) + + assert res9[4] == Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=Add( + Pow( + Symbol('a'), + Integer(3) + ), + Mul( + Integer(3), + Pow( + Symbol('a'), + Integer(2) + ), + Symbol('b') + ), + Mul( + Integer(3), + Symbol('a'), + Pow( + Symbol('b'), + Integer(2) + ) + ), + Pow( + Symbol('b'), + Integer(3) + ) + ) + ) + ) + + assert res9[5] == Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=Mul( + Add( + Symbol('a'), + Symbol('b'), + Mul( + Integer(-1), + Symbol('c') + ) + ), + Add( + Symbol('a'), + Symbol('b'), + Symbol('c') + ) + ) + ) + ) + + assert res9[6] == Declaration( + Variable(Symbol('g'), + type=IntBaseType(String('intc')), + value=Mul( + Symbol('a'), + Add( + Symbol('b'), + Mul( + Integer(-1), + Symbol('c') + ) + ), + Pow( + Add( + Symbol('a'), + Symbol('b'), + Symbol('c'), + Symbol('d') + ), + Integer(2) + ) + ) + ) + ) + + assert res10[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('10.0', precision=53) + ) + ) + + assert res10[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res10[2] == Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Add( + Pow( + Symbol('a'), + Integer(2) + ), + Mul( + Integer(2), + Symbol('a'), + Symbol('b') + ), + Pow( + Symbol('b'), + Integer(2) + ) + ) + ) + ) + + assert res11[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('4.0', precision=53) + ) + ) + + assert res12[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(25) + ) + ) + + assert res13[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-95) + ) + ) + + assert res14[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-300) + ) + ) + + assert res15[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + + assert res15[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res15[2] == Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Mul( + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ) + ) + + assert res16[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res16[1] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res16[2] == Declaration( + Variable(Symbol('n'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ) + + assert res16[3] == Declaration( + Variable(Symbol('s'), + type=IntBaseType(String('intc')), + value=Mul( + Rational(1, 2), + Symbol('a'), + Add( + Mul( + Integer(2), + Symbol('a') + ), + Mul( + Symbol('d'), + Add( + Symbol('n'), + Integer(-1) + ) + ) + ) + ) + ) + ) + + assert res17[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res18[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res18[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Mod( + Symbol('a'), + Integer(3) + ) + ) + ) + + assert res19[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + assert res19[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res19[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res20[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res20[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res20[2] == Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ) + + assert res20[3] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Add( + Symbol('a'), + Mul( + Integer(100), + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + + assert res21[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res21[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res21[2] == Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ) + + assert res21[3] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Mul( + Add( + Symbol('a'), + Mul( + Integer(-1), + Symbol('b') + ) + ), + Add( + Symbol('a'), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + + assert res22[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=false + ) + ) + + assert res22[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + assert res23[3] == Declaration( + Variable(Symbol('d'), + type=Type(String('bool')), + value=false + ) + ) + + assert res24[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res24[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res24[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res24[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Equality( + Symbol('b'), + Integer(2) + ) + ) + ) + + assert res24[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=Unequality( + Integer(1), + Symbol('a') + ) + ) + ) + + assert res24[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Unequality( + Integer(1), + Symbol('b') + ) + ) + ) + + assert res24[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictLessThan(Symbol('a'), + Integer(0) + ) + ) + ) + + assert res24[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=LessThan( + Symbol('b'), + Integer(10) + ) + ) + ) + + assert res24[8] == Declaration( + Variable(Symbol('c7'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Integer(0) + ) + ) + ) + + assert res24[9] == Declaration( + Variable(Symbol('c8'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('b'), + Integer(11) + ) + ) + ) + + assert res25[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res25[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + + assert res25[2] == Declaration(Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Unequality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=LessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res26[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res26[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res26[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Float('1.25', precision=53) + ) + ) + ) + + assert res26[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Equality( + Symbol('b'), + Float('2.54', precision=53) + ) + ) + ) + + assert res26[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=Unequality( + Float('1.2', precision=53), + Symbol('a') + ) + ) + ) + + assert res26[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Unequality( + Float('1.5', precision=53), + Symbol('b') + ) + ) + ) + + assert res27[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res27[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res27[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Unequality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=LessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res28[0] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[1] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res28[2] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[3] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=false + ) + ) + + assert res28[4] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[5] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[0] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[1] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[2] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[3] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[4] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[5] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[1] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res30[2] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[3] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res30[4] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res31[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res31[1] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res31[2] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res31[3] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res31[4] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res32[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res32[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ) + + assert res32[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + assert res32[3] == Declaration( + Variable(Symbol('d'), + type=Type(String('bool')), + value=true + ) + ) + + assert res32[4] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=And( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res32[5] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=And( + Symbol('a'), + Symbol('c') + ) + ) + ) + + assert res32[6] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=And( + Symbol('c'), + Symbol('d') + ) + ) + ) + + assert res32[7] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Or( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res32[8] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=Or( + Symbol('a'), + Symbol('c') + ) + ) + ) + + assert res32[9] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=Or( + Symbol('c'), + Symbol('d') + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + + + def test_paren_expr(): + c_src1 = ( + 'int a = (1);' + 'int b = (1 + 2 * 3);' + ) + + c_src2 = ( + 'int a = 1, b = 2, c = 3;' + 'int d = (a);' + 'int e = (a + 1);' + 'int f = (a + b * c - d / e);' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(7) + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res2[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res2[3] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ) + + assert res2[4] == Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res2[5] == Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Mul( + Symbol('b'), + Symbol('c') + ), + Mul( + Integer(-1), + Symbol('d'), + Pow( + Symbol('e'), + Integer(-1) + ) + ) + ) + ) + ) + + + def test_unary_operators(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = 20;' + '\n' + + '++a;' + '\n' + + '--b;' + '\n' + + 'a++;' + '\n' + + 'b--;' + '\n' + + '}' + ) + + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = -100;' + '\n' + + 'int c = +19;' + '\n' + + 'int d = ++a;' + '\n' + + 'int e = --b;' + '\n' + + 'int f = a++;' + '\n' + + 'int g = b--;' + '\n' + + 'bool h = !false;' + '\n' + + 'bool i = !d;' + '\n' + + 'bool j = !0;' + '\n' + + 'bool k = !10.0;' + '\n' + + '}' + ) + + c_src_raise1 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = ~a;' + '\n' + + '}' + ) + + c_src_raise2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = *&a;' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(20) + ) + ), + PreIncrement(Symbol('a')), + PreDecrement(Symbol('b')), + PostIncrement(Symbol('a')), + PostDecrement(Symbol('b')) + ) + ) + + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(-100) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(19) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=PreIncrement(Symbol('a')) + ) + ), + Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=PreDecrement(Symbol('b')) + ) + ), + Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=PostIncrement(Symbol('a')) + ) + ), + Declaration( + Variable(Symbol('g'), + type=IntBaseType(String('intc')), + value=PostDecrement(Symbol('b')) + ) + ), + Declaration( + Variable(Symbol('h'), + type=Type(String('bool')), + value=true + ) + ), + Declaration( + Variable(Symbol('i'), + type=Type(String('bool')), + value=Not(Symbol('d')) + ) + ), + Declaration( + Variable(Symbol('j'), + type=Type(String('bool')), + value=true + ) + ), + Declaration( + Variable(Symbol('k'), + type=Type(String('bool')), + value=false + ) + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + + + def test_compound_assignment_operator(): + c_src = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'a += 10;' + '\n' + + 'a -= 10;' + '\n' + + 'a *= 10;' + '\n' + + 'a /= 10;' + '\n' + + 'a %= 10;' + '\n' + + '}' + ) + + res = SymPyExpression(c_src, 'c').return_expr() + + assert res[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + AddAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + SubAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + MulAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + DivAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + ModAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ) + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_while_stmt(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 0;' + '\n' + + 'while(i < 10)' + '\n' + + '{' + '\n' + + 'i++;' + '\n' + + '}' + '}' + ) + + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 0;' + '\n' + + 'while(i < 10)' + '\n' + + 'i++;' + '\n' + + '}' + ) + + c_src3 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 10;' + '\n' + + 'int cnt = 0;' + '\n' + + 'while(i > 0)' + '\n' + + '{' + '\n' + + 'i--;' + '\n' + + 'cnt++;' + '\n' + + '}' + '\n' + + '}' + ) + + c_src4 = ( + 'int digit_sum(int n)'+ + '{' + '\n' + + 'int sum = 0;' + '\n' + + 'while(n > 0)' + '\n' + + '{' + '\n' + + 'sum += (n % 10);' + '\n' + + 'n /= 10;' + '\n' + + '}' + '\n' + + 'return sum;' + '\n' + + '}' + ) + + c_src5 = ( + 'void func()'+ + '{' + '\n' + + 'while(1);' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('i'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictLessThan( + Symbol('i'), + Integer(10) + ), + body=CodeBlock( + PostIncrement( + Symbol('i') + ) + ) + ) + ) + ) + + assert res2[0] == res1[0] + + assert res3[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('i'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable( + Symbol('cnt'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictGreaterThan( + Symbol('i'), + Integer(0) + ), + body=CodeBlock( + PostDecrement( + Symbol('i') + ), + PostIncrement( + Symbol('cnt') + ) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('digit_sum'), + parameters=( + Variable( + Symbol('n'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('sum'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictGreaterThan( + Symbol('n'), + Integer(0) + ), + body=CodeBlock( + AddAugmentedAssignment( + Variable( + Symbol('sum') + ), + Mod( + Symbol('n'), + Integer(10) + ) + ), + DivAugmentedAssignment( + Variable( + Symbol('n') + ), + Integer(10) + ) + ) + ), + Return('sum') + ) + ) + + assert res5[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + While( + Integer(1), + body=CodeBlock( + NoneToken() + ) + ) + ) + ) + + +else: + def test_raise(): + from sympy.parsing.c.c_parser import CCodeConverter + raises(ImportError, lambda: CCodeConverter()) + raises(ImportError, lambda: SymPyExpression(' ', mode = 'c')) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_custom_latex.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_custom_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..f5eff1c9ec79528c7f9e3a06cf9e2f84c86091ee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_custom_latex.py @@ -0,0 +1,69 @@ +import os +import tempfile +from pathlib import Path + +import sympy +from sympy.testing.pytest import raises +from sympy.parsing.latex.lark import LarkLaTeXParser, TransformToSymPyExpr, parse_latex_lark +from sympy.external import import_module + +lark = import_module("lark") + +# disable tests if lark is not present +disabled = lark is None + +grammar_file = os.path.join(os.path.dirname(__file__), "../latex/lark/grammar/latex.lark") + +modification1 = """ +%override DIV_SYMBOL: DIV +%override MUL_SYMBOL: MUL | CMD_TIMES +""" + +modification2 = r""" +%override number: /\d+(,\d*)?/ +""" + +def init_custom_parser(modification, transformer=None): + latex_grammar = Path(grammar_file).read_text(encoding="utf-8") + latex_grammar += modification + + with tempfile.NamedTemporaryFile() as f: + f.write(bytes(latex_grammar, encoding="utf8")) + f.flush() + + parser = LarkLaTeXParser(grammar_file=f.name, transformer=transformer) + + return parser + +def test_custom1(): + # Removes the parser's ability to understand \cdot and \div. + + parser = init_custom_parser(modification1) + + with raises(lark.exceptions.UnexpectedCharacters): + parser.doparse(r"a \cdot b") + parser.doparse(r"x \div y") + +class CustomTransformer(TransformToSymPyExpr): + def number(self, tokens): + if "," in tokens[0]: + # The Float constructor expects a dot as the decimal separator + return sympy.core.numbers.Float(tokens[0].replace(",", ".")) + else: + return sympy.core.numbers.Integer(tokens[0]) + +def test_custom2(): + # Makes the parser parse commas as the decimal separator instead of dots + + parser = init_custom_parser(modification2, CustomTransformer) + + with raises(lark.exceptions.UnexpectedCharacters): + # Asserting that the default parser cannot parse numbers which have commas as + # the decimal separator + parse_latex_lark("100,1") + parse_latex_lark("0,009") + + parser.doparse("100,1") + parser.doparse("0,009") + parser.doparse("2,71828") + parser.doparse("3,14159") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_fortran_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_fortran_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcd54533ef231dd0a116910453dff0e993bc727 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_fortran_parser.py @@ -0,0 +1,406 @@ +from sympy.testing.pytest import raises +from sympy.parsing.sym_expr import SymPyExpression +from sympy.external import import_module + +lfortran = import_module('lfortran') + +if lfortran: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Return, FunctionDefinition, Assignment, + Declaration, CodeBlock) + from sympy.core import Integer, Float, Add + from sympy.core.symbol import Symbol + + + expr1 = SymPyExpression() + expr2 = SymPyExpression() + src = """\ + integer :: a, b, c, d + real :: p, q, r, s + """ + + + def test_sym_expr(): + src1 = ( + src + + """\ + d = a + b -c + """ + ) + expr3 = SymPyExpression(src,'f') + expr4 = SymPyExpression(src1,'f') + ls1 = expr3.return_expr() + ls2 = expr4.return_expr() + for i in range(0, 7): + assert isinstance(ls1[i], Declaration) + assert isinstance(ls2[i], Declaration) + assert isinstance(ls2[8], Assignment) + assert ls1[0] == Declaration( + Variable( + Symbol('a'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[1] == Declaration( + Variable( + Symbol('b'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[2] == Declaration( + Variable( + Symbol('c'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[3] == Declaration( + Variable( + Symbol('d'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[4] == Declaration( + Variable( + Symbol('p'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[5] == Declaration( + Variable( + Symbol('q'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[6] == Declaration( + Variable( + Symbol('r'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[7] == Declaration( + Variable( + Symbol('s'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls2[8] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('b') - Symbol('c') + ) + + def test_assignment(): + src1 = ( + src + + """\ + a = b + c = d + p = q + r = s + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(0, 12): + if iter < 8: + assert isinstance(ls1[iter], Declaration) + else: + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('a')), + Variable(Symbol('b')) + ) + assert ls1[9] == Assignment( + Variable(Symbol('c')), + Variable(Symbol('d')) + ) + assert ls1[10] == Assignment( + Variable(Symbol('p')), + Variable(Symbol('q')) + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Variable(Symbol('s')) + ) + + + def test_binop_add(): + src1 = ( + src + + """\ + c = a + b + d = a + c + s = p + q + r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') + Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') + Symbol('q') + Symbol('r') + ) + + + def test_binop_sub(): + src1 = ( + src + + """\ + c = a - b + d = a - c + s = p - q - r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') - Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') - Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') - Symbol('q') - Symbol('r') + ) + + + def test_binop_mul(): + src1 = ( + src + + """\ + c = a * b + d = a * c + s = p * q * r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') * Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') * Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') * Symbol('q') * Symbol('r') + ) + + + def test_binop_div(): + src1 = ( + src + + """\ + c = a / b + d = a / c + s = p / q + r = q / p + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 12): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') / Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') / Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') / Symbol('q') + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Symbol('q') / Symbol('p') + ) + + def test_mul_binop(): + src1 = ( + src + + """\ + d = a + b - c + c = a * b + d + s = p * q / r + r = p * s + q / p + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 12): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('b') - Symbol('c') + ) + assert ls1[9] == Assignment( + Variable(Symbol('c')), + Symbol('a') * Symbol('b') + Symbol('d') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') * Symbol('q') / Symbol('r') + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Symbol('p') * Symbol('s') + Symbol('q') / Symbol('p') + ) + + + def test_function(): + src1 = """\ + integer function f(a,b) + integer :: x, y + f = x + y + end function + """ + expr1.convert_to_expr(src1, 'f') + for iter in expr1.return_expr(): + assert isinstance(iter, FunctionDefinition) + assert iter == FunctionDefinition( + IntBaseType(String('integer')), + name=String('f'), + parameters=( + Variable(Symbol('a')), + Variable(Symbol('b')) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('f'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Assignment( + Variable(Symbol('f')), + Add(Symbol('x'), Symbol('y')) + ), + Return(Variable(Symbol('f'))) + ) + ) + + + def test_var(): + expr1.convert_to_expr(src, 'f') + ls = expr1.return_expr() + for iter in expr1.return_expr(): + assert isinstance(iter, Declaration) + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[4] == Declaration( + Variable( + Symbol('p'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[5] == Declaration( + Variable( + Symbol('q'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[6] == Declaration( + Variable( + Symbol('r'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[7] == Declaration( + Variable( + Symbol('s'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + +else: + def test_raise(): + from sympy.parsing.fortran.fortran_parser import ASR2PyVisitor + raises(ImportError, lambda: ASR2PyVisitor()) + raises(ImportError, lambda: SymPyExpression(' ', mode = 'f')) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py new file mode 100644 index 0000000000000000000000000000000000000000..56df361e77b0c0f94bdb53b03e0dc30a8a10899f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py @@ -0,0 +1,195 @@ +import sympy +from sympy.parsing.sympy_parser import ( + parse_expr, + standard_transformations, + convert_xor, + implicit_multiplication_application, + implicit_multiplication, + implicit_application, + function_exponentiation, + split_symbols, + split_symbols_custom, + _token_splittable +) +from sympy.testing.pytest import raises + + +def test_implicit_multiplication(): + cases = { + '5x': '5*x', + 'abc': 'a*b*c', + '3sin(x)': '3*sin(x)', + '(x+1)(x+2)': '(x+1)*(x+2)', + '(5 x**2)sin(x)': '(5*x**2)*sin(x)', + '2 sin(x) cos(x)': '2*sin(x)*cos(x)', + 'pi x': 'pi*x', + 'x pi': 'x*pi', + 'E x': 'E*x', + 'EulerGamma y': 'EulerGamma*y', + 'E pi': 'E*pi', + 'pi (x + 2)': 'pi*(x+2)', + '(x + 2) pi': '(x+2)*pi', + 'pi sin(x)': 'pi*sin(x)', + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (split_symbols, + implicit_multiplication) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + application = ['sin x', 'cos 2*x', 'sin cos x'] + for case in application: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + raises(TypeError, + lambda: parse_expr('sin**2(x)', transformations=transformations2)) + + +def test_implicit_application(): + cases = { + 'factorial': 'factorial', + 'sin x': 'sin(x)', + 'tan y**3': 'tan(y**3)', + 'cos 2*x': 'cos(2*x)', + '(cot)': 'cot', + 'sin cos tan x': 'sin(cos(tan(x)))' + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (implicit_application,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal), (implicit, normal) + + multiplication = ['x y', 'x sin x', '2x'] + for case in multiplication: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + raises(TypeError, + lambda: parse_expr('sin**2(x)', transformations=transformations2)) + + +def test_function_exponentiation(): + cases = { + 'sin**2(x)': 'sin(x)**2', + 'exp^y(z)': 'exp(z)^y', + 'sin**2(E^(x))': 'sin(E^(x))**2' + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (function_exponentiation,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + other_implicit = ['x y', 'x sin x', '2x', 'sin x', + 'cos 2*x', 'sin cos x'] + for case in other_implicit: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + + assert parse_expr('x**2', local_dict={ 'x': sympy.Symbol('x') }, + transformations=transformations2) == parse_expr('x**2') + + +def test_symbol_splitting(): + # By default Greek letter names should not be split (lambda is a keyword + # so skip it) + transformations = standard_transformations + (split_symbols,) + greek_letters = ('alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta', + 'eta', 'theta', 'iota', 'kappa', 'mu', 'nu', 'xi', + 'omicron', 'pi', 'rho', 'sigma', 'tau', 'upsilon', + 'phi', 'chi', 'psi', 'omega') + + for letter in greek_letters: + assert(parse_expr(letter, transformations=transformations) == + parse_expr(letter)) + + # Make sure symbol splitting resolves names + transformations += (implicit_multiplication,) + local_dict = { 'e': sympy.E } + cases = { + 'xe': 'E*x', + 'Iy': 'I*y', + 'ee': 'E*E', + } + for case, expected in cases.items(): + assert(parse_expr(case, local_dict=local_dict, + transformations=transformations) == + parse_expr(expected)) + + # Make sure custom splitting works + def can_split(symbol): + if symbol not in ('unsplittable', 'names'): + return _token_splittable(symbol) + return False + transformations = standard_transformations + transformations += (split_symbols_custom(can_split), + implicit_multiplication) + + assert(parse_expr('unsplittable', transformations=transformations) == + parse_expr('unsplittable')) + assert(parse_expr('names', transformations=transformations) == + parse_expr('names')) + assert(parse_expr('xy', transformations=transformations) == + parse_expr('x*y')) + for letter in greek_letters: + assert(parse_expr(letter, transformations=transformations) == + parse_expr(letter)) + + +def test_all_implicit_steps(): + cases = { + '2x': '2*x', # implicit multiplication + 'x y': 'x*y', + 'xy': 'x*y', + 'sin x': 'sin(x)', # add parentheses + '2sin x': '2*sin(x)', + 'x y z': 'x*y*z', + 'sin(2 * 3x)': 'sin(2 * 3 * x)', + 'sin(x) (1 + cos(x))': 'sin(x) * (1 + cos(x))', + '(x + 2) sin(x)': '(x + 2) * sin(x)', + '(x + 2) sin x': '(x + 2) * sin(x)', + 'sin(sin x)': 'sin(sin(x))', + 'sin x!': 'sin(factorial(x))', + 'sin x!!': 'sin(factorial2(x))', + 'factorial': 'factorial', # don't apply a bare function + 'x sin x': 'x * sin(x)', # both application and multiplication + 'xy sin x': 'x * y * sin(x)', + '(x+2)(x+3)': '(x + 2) * (x+3)', + 'x**2 + 2xy + y**2': 'x**2 + 2 * x * y + y**2', # split the xy + 'pi': 'pi', # don't mess with constants + 'None': 'None', + 'ln sin x': 'ln(sin(x))', # multiple implicit function applications + 'sin x**2': 'sin(x**2)', # implicit application to an exponential + 'alpha': 'Symbol("alpha")', # don't split Greek letters/subscripts + 'x_2': 'Symbol("x_2")', + 'sin^2 x**2': 'sin(x**2)**2', # function raised to a power + 'sin**3(x)': 'sin(x)**3', + '(factorial)': 'factorial', + 'tan 3x': 'tan(3*x)', + 'sin^2(3*E^(x))': 'sin(3*E**(x))**2', + 'sin**2(E^(3x))': 'sin(E**(3*x))**2', + 'sin^2 (3x*E^(x))': 'sin(3*x*E^x)**2', + 'pi sin x': 'pi*sin(x)', + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (implicit_multiplication_application,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + +def test_no_methods_implicit_multiplication(): + # Issue 21020 + u = sympy.Symbol('u') + transformations = standard_transformations + \ + (implicit_multiplication,) + expr = parse_expr('x.is_polynomial(x)', transformations=transformations) + assert expr == True + expr = parse_expr('(exp(x) / (1 + exp(2x))).subs(exp(x), u)', + transformations=transformations) + assert expr == u/(u**2 + 1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..49a48966eacaa1cd7a242dcd0e7699c992bb1268 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex.py @@ -0,0 +1,358 @@ +from sympy.testing.pytest import raises, XFAIL +from sympy.external import import_module + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.numbers import (E, oo) +from sympy.core.power import Pow +from sympy.core.relational import (GreaterThan, LessThan, StrictGreaterThan, StrictLessThan, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, conjugate) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import (ceiling, floor) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, csc, sec, sin, tan) +from sympy.integrals.integrals import Integral +from sympy.series.limits import Limit + +from sympy.core.relational import Eq, Ne, Lt, Le, Gt, Ge +from sympy.physics.quantum.state import Bra, Ket +from sympy.abc import x, y, z, a, b, c, t, k, n +antlr4 = import_module("antlr4") + +# disable tests if antlr4-python3-runtime is not present +disabled = antlr4 is None + +theta = Symbol('theta') +f = Function('f') + + +# shorthand definitions +def _Add(a, b): + return Add(a, b, evaluate=False) + + +def _Mul(a, b): + return Mul(a, b, evaluate=False) + + +def _Pow(a, b): + return Pow(a, b, evaluate=False) + + +def _Sqrt(a): + return sqrt(a, evaluate=False) + + +def _Conjugate(a): + return conjugate(a, evaluate=False) + + +def _Abs(a): + return Abs(a, evaluate=False) + + +def _factorial(a): + return factorial(a, evaluate=False) + + +def _exp(a): + return exp(a, evaluate=False) + + +def _log(a, b): + return log(a, b, evaluate=False) + + +def _binomial(n, k): + return binomial(n, k, evaluate=False) + + +def test_import(): + from sympy.parsing.latex._build_latex_antlr import ( + build_parser, + check_antlr_version, + dir_latex_antlr + ) + # XXX: It would be better to come up with a test for these... + del build_parser, check_antlr_version, dir_latex_antlr + + +# These LaTeX strings should parse to the corresponding SymPy expression +GOOD_PAIRS = [ + (r"0", 0), + (r"1", 1), + (r"-3.14", -3.14), + (r"(-7.13)(1.5)", _Mul(-7.13, 1.5)), + (r"x", x), + (r"2x", 2*x), + (r"x^2", x**2), + (r"x^\frac{1}{2}", _Pow(x, _Pow(2, -1))), + (r"x^{3 + 1}", x**_Add(3, 1)), + (r"-c", -c), + (r"a \cdot b", a * b), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", _Add(a+b, -a)), + (r"a^2 + b^2 = c^2", Eq(a**2 + b**2, c**2)), + (r"(x + y) z", _Mul(_Add(x, y), z)), + (r"a'b+ab'", _Add(_Mul(Symbol("a'"), b), _Mul(a, Symbol("b'")))), + (r"y''_1", Symbol("y_{1}''")), + (r"y_1''", Symbol("y_{1}''")), + (r"\left(x + y\right) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left[x + y\right] z", _Mul(_Add(x, y), z)), + (r"\left\{x + y\right\} z", _Mul(_Add(x, y), z)), + (r"1+1", _Add(1, 1)), + (r"0+1", _Add(0, 1)), + (r"1*2", _Mul(1, 2)), + (r"0*1", _Mul(0, 1)), + (r"1 \times 2 ", _Mul(1, 2)), + (r"x = y", Eq(x, y)), + (r"x \neq y", Ne(x, y)), + (r"x < y", Lt(x, y)), + (r"x > y", Gt(x, y)), + (r"x \leq y", Le(x, y)), + (r"x \geq y", Ge(x, y)), + (r"x \le y", Le(x, y)), + (r"x \ge y", Ge(x, y)), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\sin \theta", sin(theta)), + (r"\sin(\theta)", sin(theta)), + (r"\sin^{-1} a", asin(a)), + (r"\sin a \cos b", _Mul(sin(a), cos(b))), + (r"\sin \cos \theta", sin(cos(theta))), + (r"\sin(\cos \theta)", sin(cos(theta))), + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", _Pow(2, -1)), + (r"\frac12y", _Mul(_Pow(2, -1), y)), + (r"\frac1234", _Mul(_Pow(2, -1), 34)), + (r"\frac2{3}", _Mul(2, _Pow(3, -1))), + (r"\frac{\sin{x}}2", _Mul(sin(x), _Pow(2, -1))), + (r"\frac{a + b}{c}", _Mul(a + b, _Pow(c, -1))), + (r"\frac{7}{3}", _Mul(7, _Pow(3, -1))), + (r"(\csc x)(\sec y)", csc(x)*sec(y)), + (r"\lim_{x \to 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \rightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \Rightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \longrightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \Longrightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \to 3^{+}} a", Limit(a, x, 3, dir='+')), + (r"\lim_{x \to 3^{-}} a", Limit(a, x, 3, dir='-')), + (r"\lim_{x \to 3^+} a", Limit(a, x, 3, dir='+')), + (r"\lim_{x \to 3^-} a", Limit(a, x, 3, dir='-')), + (r"\infty", oo), + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(_Pow(x, -1), x, oo)), + (r"\frac{d}{dx} x", Derivative(x, x)), + (r"\frac{d}{dt} x", Derivative(x, t)), + (r"f(x)", f(x)), + (r"f(x, y)", f(x, y)), + (r"f(x, y, z)", f(x, y, z)), + (r"f'_1(x)", Function("f_{1}'")(x)), + (r"f_{1}''(x+y)", Function("f_{1}''")(x+y)), + (r"\frac{d f(x)}{dx}", Derivative(f(x), x)), + (r"\frac{d\theta(x)}{dx}", Derivative(Function('theta')(x), x)), + (r"x \neq y", Unequality(x, y)), + (r"|x|", _Abs(x)), + (r"||x||", _Abs(Abs(x))), + (r"|x||y|", _Abs(x)*_Abs(y)), + (r"||x||y||", _Abs(_Abs(x)*_Abs(y))), + (r"\pi^{|xy|}", Symbol('pi')**_Abs(x*y)), + (r"\int x dx", Integral(x, x)), + (r"\int x d\theta", Integral(x, theta)), + (r"\int (x^2 - y)dx", Integral(x**2 - y, x)), + (r"\int x + a dx", Integral(_Add(x, a), x)), + (r"\int da", Integral(1, a)), + (r"\int_0^7 dx", Integral(1, (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(x, (x, 0, 1))), + (r"\int_a^b x dx", Integral(x, (x, a, b))), + (r"\int^b_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^b x dx", Integral(x, (x, a, b))), + (r"\int^{b}_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(x, (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(x, (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int (x+a)", Integral(_Add(x, a), x)), + (r"\int a + b + c dx", Integral(_Add(_Add(a, b), c), x)), + (r"\int \frac{dz}{z}", Integral(Pow(z, -1), z)), + (r"\int \frac{3 dz}{z}", Integral(3*Pow(z, -1), z)), + (r"\int \frac{1}{x} dx", Integral(Pow(x, -1), x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", + Integral(_Add(_Pow(a, -1), Pow(b, -1)), x)), + (r"\int \frac{3 \cdot d\theta}{\theta}", + Integral(3*_Pow(theta, -1), theta)), + (r"\int \frac{1}{x} + 1 dx", Integral(_Add(_Pow(x, -1), 1), x)), + (r"x_0", Symbol('x_{0}')), + (r"x_{1}", Symbol('x_{1}')), + (r"x_a", Symbol('x_{a}')), + (r"x_{b}", Symbol('x_{b}')), + (r"h_\theta", Symbol('h_{theta}')), + (r"h_{\theta}", Symbol('h_{theta}')), + (r"h_{\theta}(x_0, x_1)", + Function('h_{theta}')(Symbol('x_{0}'), Symbol('x_{1}'))), + (r"x!", _factorial(x)), + (r"100!", _factorial(100)), + (r"\theta!", _factorial(theta)), + (r"(x + 1)!", _factorial(_Add(x, 1))), + (r"(x!)!", _factorial(_factorial(x))), + (r"x!!!", _factorial(_factorial(_factorial(x)))), + (r"5!7!", _Mul(_factorial(5), _factorial(7))), + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(_Add(x, b))), + (r"\sqrt[3]{\sin x}", root(sin(x), 3)), + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", _Sqrt(_Mul(12, _Pow(6, -1)))), + (r"\overline{z}", _Conjugate(z)), + (r"\overline{\overline{z}}", _Conjugate(_Conjugate(z))), + (r"\overline{x + y}", _Conjugate(_Add(x, y))), + (r"\overline{x} + \overline{y}", _Conjugate(x) + _Conjugate(y)), + (r"x < y", StrictLessThan(x, y)), + (r"x \leq y", LessThan(x, y)), + (r"x > y", StrictGreaterThan(x, y)), + (r"x \geq y", GreaterThan(x, y)), + (r"\mathit{x}", Symbol('x')), + (r"\mathit{test}", Symbol('test')), + (r"\mathit{TEST}", Symbol('TEST')), + (r"\mathit{HELLO world}", Symbol('HELLO world')), + (r"\sum_{k = 1}^{3} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(c, (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(k**2, (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", + Sum(_Pow(_factorial(n), -1), (n, 0, oo))), + (r"\prod_{a = b}^{c} x", Product(x, (a, b, c))), + (r"\prod_{a = b}^c x", Product(x, (a, b, c))), + (r"\prod^{c}_{a = b} x", Product(x, (a, b, c))), + (r"\prod^c_{a = b} x", Product(x, (a, b, c))), + (r"\exp x", _exp(x)), + (r"\exp(x)", _exp(x)), + (r"\lg x", _log(x, 10)), + (r"\ln x", _log(x, E)), + (r"\ln xy", _log(x*y, E)), + (r"\log x", _log(x, E)), + (r"\log xy", _log(x*y, E)), + (r"\log_{2} x", _log(x, 2)), + (r"\log_{a} x", _log(x, a)), + (r"\log_{11} x", _log(x, 11)), + (r"\log_{a^2} x", _log(x, _Pow(a, 2))), + (r"[x]", x), + (r"[a + b]", _Add(a, b)), + (r"\frac{d}{dx} [ \tan x ]", Derivative(tan(x), x)), + (r"\binom{n}{k}", _binomial(n, k)), + (r"\tbinom{n}{k}", _binomial(n, k)), + (r"\dbinom{n}{k}", _binomial(n, k)), + (r"\binom{n}{0}", _binomial(n, 0)), + (r"x^\binom{n}{k}", _Pow(x, _binomial(n, k))), + (r"a \, b", _Mul(a, b)), + (r"a \thinspace b", _Mul(a, b)), + (r"a \: b", _Mul(a, b)), + (r"a \medspace b", _Mul(a, b)), + (r"a \; b", _Mul(a, b)), + (r"a \thickspace b", _Mul(a, b)), + (r"a \quad b", _Mul(a, b)), + (r"a \qquad b", _Mul(a, b)), + (r"a \! b", _Mul(a, b)), + (r"a \negthinspace b", _Mul(a, b)), + (r"a \negmedspace b", _Mul(a, b)), + (r"a \negthickspace b", _Mul(a, b)), + (r"\int x \, dx", Integral(x, x)), + (r"\log_2 x", _log(x, 2)), + (r"\log_a x", _log(x, a)), + (r"5^0 - 4^0", _Add(_Pow(5, 0), _Mul(-1, _Pow(4, 0)))), + (r"3x - 1", _Add(_Mul(3, x), -1)) +] + + +def test_parseable(): + from sympy.parsing.latex import parse_latex + for latex_str, sympy_expr in GOOD_PAIRS: + assert parse_latex(latex_str) == sympy_expr, latex_str + +# These bad LaTeX strings should raise a LaTeXParsingError when parsed +BAD_STRINGS = [ + r"(", + r")", + r"\frac{d}{dx}", + r"(\frac{d}{dx})", + r"\sqrt{}", + r"\sqrt", + r"\overline{}", + r"\overline", + r"{", + r"}", + r"\mathit{x + y}", + r"\mathit{21}", + r"\frac{2}{}", + r"\frac{}{2}", + r"\int", + r"!", + r"!0", + r"_", + r"^", + r"|", + r"||x|", + r"()", + r"((((((((((((((((()))))))))))))))))", + r"-", + r"\frac{d}{dx} + \frac{d}{dt}", + r"f(x,,y)", + r"f(x,y,", + r"\sin^x", + r"\cos^2", + r"@", + r"#", + r"$", + r"%", + r"&", + r"*", + r"" "\\", + r"~", + r"\frac{(2 + x}{1 - x)}", +] + +def test_not_parseable(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str) + +# At time of migration from latex2sympy, should fail but doesn't +FAILING_BAD_STRINGS = [ + r"\cos 1 \cos", + r"f(,", + r"f()", + r"a \div \div b", + r"a \cdot \cdot b", + r"a // b", + r"a +", + r"1.1.1", + r"1 +", + r"a / b /", +] + +@XFAIL +def test_failing_not_parseable(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in FAILING_BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str) + +# In strict mode, FAILING_BAD_STRINGS would fail +def test_strict_mode(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in FAILING_BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str, strict=True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex_deps.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex_deps.py new file mode 100644 index 0000000000000000000000000000000000000000..7df44c2b19e34024db6e898f7c4eac962dcaa1c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex_deps.py @@ -0,0 +1,16 @@ +from sympy.external import import_module +from sympy.testing.pytest import ignore_warnings, raises + +antlr4 = import_module("antlr4", warn_not_installed=False) + +# disable tests if antlr4-python3-runtime is not present +if antlr4: + disabled = True + + +def test_no_import(): + from sympy.parsing.latex import parse_latex + + with ignore_warnings(UserWarning): + with raises(ImportError): + parse_latex('1 + 1') diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex_lark.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex_lark.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1f72a66c788ac41d923005ea988664d05a16c1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_latex_lark.py @@ -0,0 +1,872 @@ +from sympy.testing.pytest import XFAIL +from sympy.parsing.latex.lark import parse_latex_lark +from sympy.external import import_module + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import Derivative, Function +from sympy.core.numbers import E, oo, Rational +from sympy.core.power import Pow +from sympy.core.parameters import evaluate +from sympy.core.relational import GreaterThan, LessThan, StrictGreaterThan, StrictLessThan, Unequality +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import binomial, factorial +from sympy.functions.elementary.complexes import Abs, conjugate +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import root, sqrt, Min, Max +from sympy.functions.elementary.trigonometric import asin, cos, csc, sec, sin, tan +from sympy.integrals.integrals import Integral +from sympy.series.limits import Limit +from sympy import Matrix, MatAdd, MatMul, Transpose, Trace +from sympy import I + +from sympy.core.relational import Eq, Ne, Lt, Le, Gt, Ge +from sympy.physics.quantum import Bra, Ket, InnerProduct +from sympy.abc import x, y, z, a, b, c, d, t, k, n + +from .test_latex import theta, f, _Add, _Mul, _Pow, _Sqrt, _Conjugate, _Abs, _factorial, _exp, _binomial + +lark = import_module("lark") + +# disable tests if lark is not present +disabled = lark is None + +# shorthand definitions that are only needed for the Lark LaTeX parser +def _Min(*args): + return Min(*args, evaluate=False) + + +def _Max(*args): + return Max(*args, evaluate=False) + + +def _log(a, b=E): + if b == E: + return log(a, evaluate=False) + else: + return log(a, b, evaluate=False) + + +def _MatAdd(a, b): + return MatAdd(a, b, evaluate=False) + + +def _MatMul(a, b): + return MatMul(a, b, evaluate=False) + + +# These LaTeX strings should parse to the corresponding SymPy expression +SYMBOL_EXPRESSION_PAIRS = [ + (r"x_0", Symbol('x_{0}')), + (r"x_{1}", Symbol('x_{1}')), + (r"x_a", Symbol('x_{a}')), + (r"x_{b}", Symbol('x_{b}')), + (r"h_\theta", Symbol('h_{theta}')), + (r"h_{\theta}", Symbol('h_{theta}')), + (r"y''_1", Symbol("y''_{1}")), + (r"y_1''", Symbol("y_{1}''")), + (r"\mathit{x}", Symbol('x')), + (r"\mathit{test}", Symbol('test')), + (r"\mathit{TEST}", Symbol('TEST')), + (r"\mathit{HELLO world}", Symbol('HELLO world')), + (r"a'", Symbol("a'")), + (r"a''", Symbol("a''")), + (r"\alpha'", Symbol("alpha'")), + (r"\alpha''", Symbol("alpha''")), + (r"a_b", Symbol("a_{b}")), + (r"a_b'", Symbol("a_{b}'")), + (r"a'_b", Symbol("a'_{b}")), + (r"a'_b'", Symbol("a'_{b}'")), + (r"a_{b'}", Symbol("a_{b'}")), + (r"a_{b'}'", Symbol("a_{b'}'")), + (r"a'_{b'}", Symbol("a'_{b'}")), + (r"a'_{b'}'", Symbol("a'_{b'}'")), + (r"\mathit{foo}'", Symbol("foo'")), + (r"\mathit{foo'}", Symbol("foo'")), + (r"\mathit{foo'}'", Symbol("foo''")), + (r"a_b''", Symbol("a_{b}''")), + (r"a''_b", Symbol("a''_{b}")), + (r"a''_b'''", Symbol("a''_{b}'''")), + (r"a_{b''}", Symbol("a_{b''}")), + (r"a_{b''}''", Symbol("a_{b''}''")), + (r"a''_{b''}", Symbol("a''_{b''}")), + (r"a''_{b''}'''", Symbol("a''_{b''}'''")), + (r"\mathit{foo}''", Symbol("foo''")), + (r"\mathit{foo''}", Symbol("foo''")), + (r"\mathit{foo''}'''", Symbol("foo'''''")), + (r"a_\alpha", Symbol("a_{alpha}")), + (r"a_\alpha'", Symbol("a_{alpha}'")), + (r"a'_\alpha", Symbol("a'_{alpha}")), + (r"a'_\alpha'", Symbol("a'_{alpha}'")), + (r"a_{\alpha'}", Symbol("a_{alpha'}")), + (r"a_{\alpha'}'", Symbol("a_{alpha'}'")), + (r"a'_{\alpha'}", Symbol("a'_{alpha'}")), + (r"a'_{\alpha'}'", Symbol("a'_{alpha'}'")), + (r"a_\alpha''", Symbol("a_{alpha}''")), + (r"a''_\alpha", Symbol("a''_{alpha}")), + (r"a''_\alpha'''", Symbol("a''_{alpha}'''")), + (r"a_{\alpha''}", Symbol("a_{alpha''}")), + (r"a_{\alpha''}''", Symbol("a_{alpha''}''")), + (r"a''_{\alpha''}", Symbol("a''_{alpha''}")), + (r"a''_{\alpha''}'''", Symbol("a''_{alpha''}'''")), + (r"\alpha_b", Symbol("alpha_{b}")), + (r"\alpha_b'", Symbol("alpha_{b}'")), + (r"\alpha'_b", Symbol("alpha'_{b}")), + (r"\alpha'_b'", Symbol("alpha'_{b}'")), + (r"\alpha_{b'}", Symbol("alpha_{b'}")), + (r"\alpha_{b'}'", Symbol("alpha_{b'}'")), + (r"\alpha'_{b'}", Symbol("alpha'_{b'}")), + (r"\alpha'_{b'}'", Symbol("alpha'_{b'}'")), + (r"\alpha_b''", Symbol("alpha_{b}''")), + (r"\alpha''_b", Symbol("alpha''_{b}")), + (r"\alpha''_b'''", Symbol("alpha''_{b}'''")), + (r"\alpha_{b''}", Symbol("alpha_{b''}")), + (r"\alpha_{b''}''", Symbol("alpha_{b''}''")), + (r"\alpha''_{b''}", Symbol("alpha''_{b''}")), + (r"\alpha''_{b''}'''", Symbol("alpha''_{b''}'''")), + (r"\alpha_\beta", Symbol("alpha_{beta}")), + (r"\alpha_{\beta}", Symbol("alpha_{beta}")), + (r"\alpha_{\beta'}", Symbol("alpha_{beta'}")), + (r"\alpha_{\beta''}", Symbol("alpha_{beta''}")), + (r"\alpha'_\beta", Symbol("alpha'_{beta}")), + (r"\alpha'_{\beta}", Symbol("alpha'_{beta}")), + (r"\alpha'_{\beta'}", Symbol("alpha'_{beta'}")), + (r"\alpha'_{\beta''}", Symbol("alpha'_{beta''}")), + (r"\alpha''_\beta", Symbol("alpha''_{beta}")), + (r"\alpha''_{\beta}", Symbol("alpha''_{beta}")), + (r"\alpha''_{\beta'}", Symbol("alpha''_{beta'}")), + (r"\alpha''_{\beta''}", Symbol("alpha''_{beta''}")), + (r"\alpha_\beta'", Symbol("alpha_{beta}'")), + (r"\alpha_{\beta}'", Symbol("alpha_{beta}'")), + (r"\alpha_{\beta'}'", Symbol("alpha_{beta'}'")), + (r"\alpha_{\beta''}'", Symbol("alpha_{beta''}'")), + (r"\alpha'_\beta'", Symbol("alpha'_{beta}'")), + (r"\alpha'_{\beta}'", Symbol("alpha'_{beta}'")), + (r"\alpha'_{\beta'}'", Symbol("alpha'_{beta'}'")), + (r"\alpha'_{\beta''}'", Symbol("alpha'_{beta''}'")), + (r"\alpha''_\beta'", Symbol("alpha''_{beta}'")), + (r"\alpha''_{\beta}'", Symbol("alpha''_{beta}'")), + (r"\alpha''_{\beta'}'", Symbol("alpha''_{beta'}'")), + (r"\alpha''_{\beta''}'", Symbol("alpha''_{beta''}'")), + (r"\alpha_\beta''", Symbol("alpha_{beta}''")), + (r"\alpha_{\beta}''", Symbol("alpha_{beta}''")), + (r"\alpha_{\beta'}''", Symbol("alpha_{beta'}''")), + (r"\alpha_{\beta''}''", Symbol("alpha_{beta''}''")), + (r"\alpha'_\beta''", Symbol("alpha'_{beta}''")), + (r"\alpha'_{\beta}''", Symbol("alpha'_{beta}''")), + (r"\alpha'_{\beta'}''", Symbol("alpha'_{beta'}''")), + (r"\alpha'_{\beta''}''", Symbol("alpha'_{beta''}''")), + (r"\alpha''_\beta''", Symbol("alpha''_{beta}''")), + (r"\alpha''_{\beta}''", Symbol("alpha''_{beta}''")), + (r"\alpha''_{\beta'}''", Symbol("alpha''_{beta'}''")), + (r"\alpha''_{\beta''}''", Symbol("alpha''_{beta''}''")) + +] + +UNEVALUATED_SIMPLE_EXPRESSION_PAIRS = [ + (r"0", 0), + (r"1", 1), + (r"-3.14", -3.14), + (r"(-7.13)(1.5)", _Mul(-7.13, 1.5)), + (r"1+1", _Add(1, 1)), + (r"0+1", _Add(0, 1)), + (r"1*2", _Mul(1, 2)), + (r"0*1", _Mul(0, 1)), + (r"x", x), + (r"2x", 2 * x), + (r"3x - 1", _Add(_Mul(3, x), -1)), + (r"-c", -c), + (r"\infty", oo), + (r"a \cdot b", a * b), + (r"1 \times 2 ", _Mul(1, 2)), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", _Add(a + b, -a)), + (r"(x + y) z", _Mul(_Add(x, y), z)), + (r"a'b+ab'", _Add(_Mul(Symbol("a'"), b), _Mul(a, Symbol("b'")))) +] + +EVALUATED_SIMPLE_EXPRESSION_PAIRS = [ + (r"(-7.13)(1.5)", -10.695), + (r"1+1", 2), + (r"0+1", 1), + (r"1*2", 2), + (r"0*1", 0), + (r"2x", 2 * x), + (r"3x - 1", 3 * x - 1), + (r"-c", -c), + (r"a \cdot b", a * b), + (r"1 \times 2 ", 2), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", b), + (r"(x + y) z", (x + y) * z), +] + +UNEVALUATED_FRACTION_EXPRESSION_PAIRS = [ + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", _Mul(1, _Pow(2, -1))), + (r"\frac12y", _Mul(_Mul(1, _Pow(2, -1)), y)), + (r"\frac1234", _Mul(_Mul(1, _Pow(2, -1)), 34)), + (r"\frac2{3}", _Mul(2, _Pow(3, -1))), + (r"\frac{a + b}{c}", _Mul(a + b, _Pow(c, -1))), + (r"\frac{7}{3}", _Mul(7, _Pow(3, -1))) +] + +EVALUATED_FRACTION_EXPRESSION_PAIRS = [ + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", Rational(1, 2)), + (r"\frac12y", y / 2), + (r"\frac1234", 17), + (r"\frac2{3}", Rational(2, 3)), + (r"\frac{a + b}{c}", (a + b) / c), + (r"\frac{7}{3}", Rational(7, 3)) +] + +RELATION_EXPRESSION_PAIRS = [ + (r"x = y", Eq(x, y)), + (r"x \neq y", Ne(x, y)), + (r"x < y", Lt(x, y)), + (r"x > y", Gt(x, y)), + (r"x \leq y", Le(x, y)), + (r"x \geq y", Ge(x, y)), + (r"x \le y", Le(x, y)), + (r"x \ge y", Ge(x, y)), + (r"x < y", StrictLessThan(x, y)), + (r"x \leq y", LessThan(x, y)), + (r"x > y", StrictGreaterThan(x, y)), + (r"x \geq y", GreaterThan(x, y)), + (r"x \neq y", Unequality(x, y)), # same as 2nd one in the list + (r"a^2 + b^2 = c^2", Eq(a**2 + b**2, c**2)) +] + +UNEVALUATED_POWER_EXPRESSION_PAIRS = [ + (r"x^2", x ** 2), + (r"x^\frac{1}{2}", _Pow(x, _Mul(1, _Pow(2, -1)))), + (r"x^{3 + 1}", x ** _Add(3, 1)), + (r"\pi^{|xy|}", Symbol('pi') ** _Abs(x * y)), + (r"5^0 - 4^0", _Add(_Pow(5, 0), _Mul(-1, _Pow(4, 0)))) +] + +EVALUATED_POWER_EXPRESSION_PAIRS = [ + (r"x^2", x ** 2), + (r"x^\frac{1}{2}", sqrt(x)), + (r"x^{3 + 1}", x ** 4), + (r"\pi^{|xy|}", Symbol('pi') ** _Abs(x * y)), + (r"5^0 - 4^0", 0) +] + +UNEVALUATED_INTEGRAL_EXPRESSION_PAIRS = [ + (r"\int x dx", Integral(_Mul(1, x), x)), + (r"\int x \, dx", Integral(_Mul(1, x), x)), + (r"\int x d\theta", Integral(_Mul(1, x), theta)), + (r"\int (x^2 - y)dx", Integral(_Mul(1, x ** 2 - y), x)), + (r"\int x + a dx", Integral(_Mul(1, _Add(x, a)), x)), + (r"\int da", Integral(_Mul(1, 1), a)), + (r"\int_0^7 dx", Integral(_Mul(1, 1), (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(_Mul(1, x), (x, 0, 1))), + (r"\int_a^b x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^b_a x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{a}^b x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^{b}_a x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int a + b + c dx", Integral(_Mul(1, _Add(_Add(a, b), c)), x)), + (r"\int \frac{dz}{z}", Integral(_Mul(1, _Mul(1, Pow(z, -1))), z)), + (r"\int \frac{3 dz}{z}", Integral(_Mul(1, _Mul(3, _Pow(z, -1))), z)), + (r"\int \frac{1}{x} dx", Integral(_Mul(1, _Mul(1, Pow(x, -1))), x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", + Integral(_Mul(1, _Add(_Mul(1, _Pow(a, -1)), _Mul(1, Pow(b, -1)))), x)), + (r"\int \frac{1}{x} + 1 dx", Integral(_Mul(1, _Add(_Mul(1, _Pow(x, -1)), 1)), x)) +] + +EVALUATED_INTEGRAL_EXPRESSION_PAIRS = [ + (r"\int x dx", Integral(x, x)), + (r"\int x \, dx", Integral(x, x)), + (r"\int x d\theta", Integral(x, theta)), + (r"\int (x^2 - y)dx", Integral(x ** 2 - y, x)), + (r"\int x + a dx", Integral(x + a, x)), + (r"\int da", Integral(1, a)), + (r"\int_0^7 dx", Integral(1, (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(x, (x, 0, 1))), + (r"\int_a^b x dx", Integral(x, (x, a, b))), + (r"\int^b_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^b x dx", Integral(x, (x, a, b))), + (r"\int^{b}_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(x, (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(x, (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int a + b + c dx", Integral(a + b + c, x)), + (r"\int \frac{dz}{z}", Integral(Pow(z, -1), z)), + (r"\int \frac{3 dz}{z}", Integral(3 * Pow(z, -1), z)), + (r"\int \frac{1}{x} dx", Integral(1 / x, x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", Integral(1 / a + 1 / b, x)), + (r"\int \frac{1}{a} - \frac{1}{b} dx", Integral(1 / a - 1 / b, x)), + (r"\int \frac{1}{x} + 1 dx", Integral(1 / x + 1, x)) +] + +DERIVATIVE_EXPRESSION_PAIRS = [ + (r"\frac{d}{dx} x", Derivative(x, x)), + (r"\frac{d}{dt} x", Derivative(x, t)), + (r"\frac{d}{dx} ( \tan x )", Derivative(tan(x), x)), + (r"\frac{d f(x)}{dx}", Derivative(f(x), x)), + (r"\frac{d\theta(x)}{dx}", Derivative(Function('theta')(x), x)) +] + +TRIGONOMETRIC_EXPRESSION_PAIRS = [ + (r"\sin \theta", sin(theta)), + (r"\sin(\theta)", sin(theta)), + (r"\sin^{-1} a", asin(a)), + (r"\sin a \cos b", _Mul(sin(a), cos(b))), + (r"\sin \cos \theta", sin(cos(theta))), + (r"\sin(\cos \theta)", sin(cos(theta))), + (r"(\csc x)(\sec y)", csc(x) * sec(y)), + (r"\frac{\sin{x}}2", _Mul(sin(x), _Pow(2, -1))) +] + +UNEVALUATED_LIMIT_EXPRESSION_PAIRS = [ + (r"\lim_{x \to 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \rightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \Rightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \longrightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \Longrightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \to 3^{+}} a", Limit(a, x, 3, dir="+")), + (r"\lim_{x \to 3^{-}} a", Limit(a, x, 3, dir="-")), + (r"\lim_{x \to 3^+} a", Limit(a, x, 3, dir="+")), + (r"\lim_{x \to 3^-} a", Limit(a, x, 3, dir="-")), + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(_Mul(1, _Pow(x, -1)), x, oo)) +] + +EVALUATED_LIMIT_EXPRESSION_PAIRS = [ + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(1 / x, x, oo)) +] + +UNEVALUATED_SQRT_EXPRESSION_PAIRS = [ + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(_Add(x, b))), + (r"\sqrt[3]{\sin x}", _Pow(sin(x), _Pow(3, -1))), + # the above test needed to be handled differently than the ones below because root + # acts differently if its second argument is a number + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", _Sqrt(_Mul(12, _Pow(6, -1)))) +] + +EVALUATED_SQRT_EXPRESSION_PAIRS = [ + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(x + b)), + (r"\sqrt[3]{\sin x}", root(sin(x), 3)), + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", sqrt(2)) +] + +UNEVALUATED_FACTORIAL_EXPRESSION_PAIRS = [ + (r"x!", _factorial(x)), + (r"100!", _factorial(100)), + (r"\theta!", _factorial(theta)), + (r"(x + 1)!", _factorial(_Add(x, 1))), + (r"(x!)!", _factorial(_factorial(x))), + (r"x!!!", _factorial(_factorial(_factorial(x)))), + (r"5!7!", _Mul(_factorial(5), _factorial(7))) +] + +EVALUATED_FACTORIAL_EXPRESSION_PAIRS = [ + (r"x!", factorial(x)), + (r"100!", factorial(100)), + (r"\theta!", factorial(theta)), + (r"(x + 1)!", factorial(x + 1)), + (r"(x!)!", factorial(factorial(x))), + (r"x!!!", factorial(factorial(factorial(x)))), + (r"5!7!", factorial(5) * factorial(7)), + (r"24! \times 24!", factorial(24) * factorial(24)) +] + +UNEVALUATED_SUM_EXPRESSION_PAIRS = [ + (r"\sum_{k = 1}^{3} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(_Mul(1, k ** 2), (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", + Sum(_Mul(1, _Mul(1, _Pow(_factorial(n), -1))), (n, 0, oo))) +] + +EVALUATED_SUM_EXPRESSION_PAIRS = [ + (r"\sum_{k = 1}^{3} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(c, (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(k ** 2, (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", Sum(1 / factorial(n), (n, 0, oo))) +] + +UNEVALUATED_PRODUCT_EXPRESSION_PAIRS = [ + (r"\prod_{a = b}^{c} x", Product(x, (a, b, c))), + (r"\prod_{a = b}^c x", Product(x, (a, b, c))), + (r"\prod^{c}_{a = b} x", Product(x, (a, b, c))), + (r"\prod^c_{a = b} x", Product(x, (a, b, c))) +] + +APPLIED_FUNCTION_EXPRESSION_PAIRS = [ + (r"f(x)", f(x)), + (r"f(x, y)", f(x, y)), + (r"f(x, y, z)", f(x, y, z)), + (r"f'_1(x)", Function("f_{1}'")(x)), + (r"f_{1}''(x+y)", Function("f_{1}''")(x + y)), + (r"h_{\theta}(x_0, x_1)", + Function('h_{theta}')(Symbol('x_{0}'), Symbol('x_{1}'))) +] + +UNEVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS = [ + (r"|x|", _Abs(x)), + (r"||x||", _Abs(Abs(x))), + (r"|x||y|", _Abs(x) * _Abs(y)), + (r"||x||y||", _Abs(_Abs(x) * _Abs(y))), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\exp x", _exp(x)), + (r"\exp(x)", _exp(x)), + (r"\lg x", _log(x, 10)), + (r"\ln x", _log(x)), + (r"\ln xy", _log(x * y)), + (r"\log x", _log(x)), + (r"\log xy", _log(x * y)), + (r"\log_{2} x", _log(x, 2)), + (r"\log_{a} x", _log(x, a)), + (r"\log_{11} x", _log(x, 11)), + (r"\log_{a^2} x", _log(x, _Pow(a, 2))), + (r"\log_2 x", _log(x, 2)), + (r"\log_a x", _log(x, a)), + (r"\overline{z}", _Conjugate(z)), + (r"\overline{\overline{z}}", _Conjugate(_Conjugate(z))), + (r"\overline{x + y}", _Conjugate(_Add(x, y))), + (r"\overline{x} + \overline{y}", _Conjugate(x) + _Conjugate(y)), + (r"\min(a, b)", _Min(a, b)), + (r"\min(a, b, c - d, xy)", _Min(a, b, c - d, x * y)), + (r"\max(a, b)", _Max(a, b)), + (r"\max(a, b, c - d, xy)", _Max(a, b, c - d, x * y)), + # physics things don't have an `evaluate=False` variant + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\langle x | y \rangle", InnerProduct(Bra('x'), Ket('y'))), +] + +EVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS = [ + (r"|x|", Abs(x)), + (r"||x||", Abs(Abs(x))), + (r"|x||y|", Abs(x) * Abs(y)), + (r"||x||y||", Abs(Abs(x) * Abs(y))), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\exp x", exp(x)), + (r"\exp(x)", exp(x)), + (r"\lg x", log(x, 10)), + (r"\ln x", log(x)), + (r"\ln xy", log(x * y)), + (r"\log x", log(x)), + (r"\log xy", log(x * y)), + (r"\log_{2} x", log(x, 2)), + (r"\log_{a} x", log(x, a)), + (r"\log_{11} x", log(x, 11)), + (r"\log_{a^2} x", log(x, _Pow(a, 2))), + (r"\log_2 x", log(x, 2)), + (r"\log_a x", log(x, a)), + (r"\overline{z}", conjugate(z)), + (r"\overline{\overline{z}}", conjugate(conjugate(z))), + (r"\overline{x + y}", conjugate(x + y)), + (r"\overline{x} + \overline{y}", conjugate(x) + conjugate(y)), + (r"\min(a, b)", Min(a, b)), + (r"\min(a, b, c - d, xy)", Min(a, b, c - d, x * y)), + (r"\max(a, b)", Max(a, b)), + (r"\max(a, b, c - d, xy)", Max(a, b, c - d, x * y)), + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\langle x | y \rangle", InnerProduct(Bra('x'), Ket('y'))), +] + +SPACING_RELATED_EXPRESSION_PAIRS = [ + (r"a \, b", _Mul(a, b)), + (r"a \thinspace b", _Mul(a, b)), + (r"a \: b", _Mul(a, b)), + (r"a \medspace b", _Mul(a, b)), + (r"a \; b", _Mul(a, b)), + (r"a \thickspace b", _Mul(a, b)), + (r"a \quad b", _Mul(a, b)), + (r"a \qquad b", _Mul(a, b)), + (r"a \! b", _Mul(a, b)), + (r"a \negthinspace b", _Mul(a, b)), + (r"a \negmedspace b", _Mul(a, b)), + (r"a \negthickspace b", _Mul(a, b)) +] + +UNEVALUATED_BINOMIAL_EXPRESSION_PAIRS = [ + (r"\binom{n}{k}", _binomial(n, k)), + (r"\tbinom{n}{k}", _binomial(n, k)), + (r"\dbinom{n}{k}", _binomial(n, k)), + (r"\binom{n}{0}", _binomial(n, 0)), + (r"x^\binom{n}{k}", _Pow(x, _binomial(n, k))) +] + +EVALUATED_BINOMIAL_EXPRESSION_PAIRS = [ + (r"\binom{n}{k}", binomial(n, k)), + (r"\tbinom{n}{k}", binomial(n, k)), + (r"\dbinom{n}{k}", binomial(n, k)), + (r"\binom{n}{0}", binomial(n, 0)), + (r"x^\binom{n}{k}", x ** binomial(n, k)) +] + +MISCELLANEOUS_EXPRESSION_PAIRS = [ + (r"\left(x + y\right) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), +] + +UNEVALUATED_LITERAL_COMPLEX_NUMBER_EXPRESSION_PAIRS = [ + (r"\imaginaryunit^2", _Pow(I, 2)), + (r"|\imaginaryunit|", _Abs(I)), + (r"\overline{\imaginaryunit}", _Conjugate(I)), + (r"\imaginaryunit+\imaginaryunit", _Add(I, I)), + (r"\imaginaryunit-\imaginaryunit", _Add(I, -I)), + (r"\imaginaryunit*\imaginaryunit", _Mul(I, I)), + (r"\imaginaryunit/\imaginaryunit", _Mul(I, _Pow(I, -1))), + (r"(1+\imaginaryunit)/|1+\imaginaryunit|", _Mul(_Add(1, I), _Pow(_Abs(_Add(1, I)), -1))) +] + +UNEVALUATED_MATRIX_EXPRESSION_PAIRS = [ + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}", + Matrix([[a, b], [x, y]])), + (r"\begin{pmatrix}a & b \\x & y\\\end{pmatrix}", + Matrix([[a, b], [x, y]])), + (r"\begin{bmatrix}a & b \\x & y\end{bmatrix}", + Matrix([[a, b], [x, y]])), + (r"\left(\begin{matrix}a & b \\x & y\end{matrix}\right)", + Matrix([[a, b], [x, y]])), + (r"\left[\begin{matrix}a & b \\x & y\end{matrix}\right]", + Matrix([[a, b], [x, y]])), + (r"\left[\begin{array}{cc}a & b \\x & y\end{array}\right]", + Matrix([[a, b], [x, y]])), + (r"\left(\begin{array}{cc}a & b \\x & y\end{array}\right)", + Matrix([[a, b], [x, y]])), + (r"\left( { \begin{array}{cc}a & b \\x & y\end{array} } \right)", + Matrix([[a, b], [x, y]])), + (r"+\begin{pmatrix}a & b \\x & y\end{pmatrix}", + Matrix([[a, b], [x, y]])), + ((r"\begin{pmatrix}x & y \\a & b\end{pmatrix}+" + r"\begin{pmatrix}a & b \\x & y\end{pmatrix}"), + _MatAdd(Matrix([[x, y], [a, b]]), Matrix([[a, b], [x, y]]))), + (r"-\begin{pmatrix}a & b \\x & y\end{pmatrix}", + _MatMul(-1, Matrix([[a, b], [x, y]]))), + ((r"\begin{pmatrix}x & y \\a & b\end{pmatrix}-" + r"\begin{pmatrix}a & b \\x & y\end{pmatrix}"), + _MatAdd(Matrix([[x, y], [a, b]]), _MatMul(-1, Matrix([[a, b], [x, y]])))), + ((r"\begin{pmatrix}a & b & c \\x & y & z \\a & b & c \end{pmatrix}*" + r"\begin{pmatrix}x & y & z \\a & b & c \\a & b & c \end{pmatrix}*" + r"\begin{pmatrix}a & b & c \\x & y & z \\x & y & z \end{pmatrix}"), + _MatMul(_MatMul(Matrix([[a, b, c], [x, y, z], [a, b, c]]), + Matrix([[x, y, z], [a, b, c], [a, b, c]])), + Matrix([[a, b, c], [x, y, z], [x, y, z]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/2", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(2, -1))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^2", + _Pow(Matrix([[a, b], [x, y]]), 2)), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^{-1}", + _Pow(Matrix([[a, b], [x, y]]), -1)), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^T", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^{T}", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^\mathit{T}", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}^T", + Transpose(Matrix([[1, 2], [3, 4]]))), + ((r"(\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}+" + r"\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}^T)*" + r"\begin{bmatrix}1\\0\end{bmatrix}"), + _MatMul(_MatAdd(Matrix([[1, 2], [3, 4]]), + Transpose(Matrix([[1, 2], [3, 4]]))), + Matrix([[1], [0]]))), + ((r"(\begin{pmatrix}a & b \\x & y\end{pmatrix}+" + r"\begin{pmatrix}x & y \\a & b\end{pmatrix})^2"), + _Pow(_MatAdd(Matrix([[a, b], [x, y]]), + Matrix([[x, y], [a, b]])), 2)), + ((r"(\begin{pmatrix}a & b \\x & y\end{pmatrix}+" + r"\begin{pmatrix}x & y \\a & b\end{pmatrix})^T"), + Transpose(_MatAdd(Matrix([[a, b], [x, y]]), + Matrix([[x, y], [a, b]])))), + (r"\overline{\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}}", + _Conjugate(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))) +] + +EVALUATED_MATRIX_EXPRESSION_PAIRS = [ + (r"\det\left(\left[ { \begin{array}{cc}a&b\\x&y\end{array} } \right]\right)", + Matrix([[a, b], [x, y]]).det()), + (r"\det \begin{pmatrix}1&2\\3&4\end{pmatrix}", -2), + (r"\det{\begin{pmatrix}1&2\\3&4\end{pmatrix}}", -2), + (r"\det(\begin{pmatrix}1&2\\3&4\end{pmatrix})", -2), + (r"\det\left(\begin{pmatrix}1&2\\3&4\end{pmatrix}\right)", -2), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/\begin{vmatrix}a & b \\x & y\end{vmatrix}", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/|\begin{matrix}a & b \\x & y\end{matrix}|", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\frac{\begin{pmatrix}a & b \\x & y\end{pmatrix}}{| { \begin{matrix}a & b \\x & y\end{matrix} } |}", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\overline{\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix}}", + Matrix([[-I, 1-I], [I, 4]])), + (r"\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix}^H", + Matrix([[-I, I], [1-I, 4]])), + (r"\trace(\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix})", + Trace(Matrix([[I, 1+I], [-I, 4]]))), + (r"\adjugate(\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix})", + Matrix([[4, -2], [-3, 1]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^\ast", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast\ast}", + Matrix([[2*I, 4], [6, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast\ast\ast}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{*}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{**}", + Matrix([[2*I, 4], [6, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{***}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^\prime", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime\prime}", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime\prime\prime}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{'}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{''}", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{'''}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})'", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})''", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})'''", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"\det(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + (_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))).det()), + (r"\trace(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + Trace(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"\adjugate(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + (Matrix([[8, -4], [-6, 2*I]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^T", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^H", + (Matrix([[-2*I, 6], [4, 8]]))) +] + + +def test_symbol_expressions(): + expected_failures = {6, 7} + for i, (latex_str, sympy_expr) in enumerate(SYMBOL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_simple_expressions(): + expected_failures = {20} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_SIMPLE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_SIMPLE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_fraction_expressions(): + for latex_str, sympy_expr in UNEVALUATED_FRACTION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_FRACTION_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_relation_expressions(): + for latex_str, sympy_expr in RELATION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + +def test_power_expressions(): + expected_failures = {3} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_POWER_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_POWER_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_integral_expressions(): + expected_failures = {14} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_INTEGRAL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, i + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_INTEGRAL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_derivative_expressions(): + expected_failures = {3, 4} + for i, (latex_str, sympy_expr) in enumerate(DERIVATIVE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(DERIVATIVE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_trigonometric_expressions(): + expected_failures = {3} + for i, (latex_str, sympy_expr) in enumerate(TRIGONOMETRIC_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_limit_expressions(): + for latex_str, sympy_expr in UNEVALUATED_LIMIT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_square_root_expressions(): + for latex_str, sympy_expr in UNEVALUATED_SQRT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_SQRT_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_factorial_expressions(): + for latex_str, sympy_expr in UNEVALUATED_FACTORIAL_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_FACTORIAL_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_sum_expressions(): + for latex_str, sympy_expr in UNEVALUATED_SUM_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_SUM_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_product_expressions(): + for latex_str, sympy_expr in UNEVALUATED_PRODUCT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + +@XFAIL +def test_applied_function_expressions(): + expected_failures = {0, 3, 4} # 0 is ambiguous, and the others require not-yet-added features + # not sure why 1, and 2 are failing + for i, (latex_str, sympy_expr) in enumerate(APPLIED_FUNCTION_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_common_function_expressions(): + for latex_str, sympy_expr in UNEVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +# unhandled bug causing these to fail +@XFAIL +def test_spacing(): + for latex_str, sympy_expr in SPACING_RELATED_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_binomial_expressions(): + for latex_str, sympy_expr in UNEVALUATED_BINOMIAL_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_BINOMIAL_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_miscellaneous_expressions(): + for latex_str, sympy_expr in MISCELLANEOUS_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_literal_complex_number_expressions(): + for latex_str, sympy_expr in UNEVALUATED_LITERAL_COMPLEX_NUMBER_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_matrix_expressions(): + for latex_str, sympy_expr in UNEVALUATED_MATRIX_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_MATRIX_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_mathematica.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..df193b6d61f9c82778d8e0a40b893cbe6cb8f06a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_mathematica.py @@ -0,0 +1,280 @@ +from sympy import sin, Function, symbols, Dummy, Lambda, cos +from sympy.parsing.mathematica import parse_mathematica, MathematicaParser +from sympy.core.sympify import sympify +from sympy.abc import n, w, x, y, z +from sympy.testing.pytest import raises + + +def test_mathematica(): + d = { + '- 6x': '-6*x', + 'Sin[x]^2': 'sin(x)**2', + '2(x-1)': '2*(x-1)', + '3y+8': '3*y+8', + 'ArcSin[2x+9(4-x)^2]/x': 'asin(2*x+9*(4-x)**2)/x', + 'x+y': 'x+y', + '355/113': '355/113', + '2.718281828': '2.718281828', + 'Cos(1/2 * π)': 'Cos(π/2)', + 'Sin[12]': 'sin(12)', + 'Exp[Log[4]]': 'exp(log(4))', + '(x+1)(x+3)': '(x+1)*(x+3)', + 'Cos[ArcCos[3.6]]': 'cos(acos(3.6))', + 'Cos[x]==Sin[y]': 'Eq(cos(x), sin(y))', + '2*Sin[x+y]': '2*sin(x+y)', + 'Sin[x]+Cos[y]': 'sin(x)+cos(y)', + 'Sin[Cos[x]]': 'sin(cos(x))', + '2*Sqrt[x+y]': '2*sqrt(x+y)', # Test case from the issue 4259 + '+Sqrt[2]': 'sqrt(2)', + '-Sqrt[2]': '-sqrt(2)', + '-1/Sqrt[2]': '-1/sqrt(2)', + '-(1/Sqrt[3])': '-(1/sqrt(3))', + '1/(2*Sqrt[5])': '1/(2*sqrt(5))', + 'Mod[5,3]': 'Mod(5,3)', + '-Mod[5,3]': '-Mod(5,3)', + '(x+1)y': '(x+1)*y', + 'x(y+1)': 'x*(y+1)', + 'Sin[x]Cos[y]': 'sin(x)*cos(y)', + 'Sin[x]^2Cos[y]^2': 'sin(x)**2*cos(y)**2', + 'Cos[x]^2(1 - Cos[y]^2)': 'cos(x)**2*(1-cos(y)**2)', + 'x y': 'x*y', + 'x y': 'x*y', + '2 x': '2*x', + 'x 8': 'x*8', + '2 8': '2*8', + '4.x': '4.*x', + '4. 3': '4.*3', + '4. 3.': '4.*3.', + '1 2 3': '1*2*3', + ' - 2 * Sqrt[ 2 3 * ( 1 + 5 ) ] ': '-2*sqrt(2*3*(1+5))', + 'Log[2,4]': 'log(4,2)', + 'Log[Log[2,4],4]': 'log(4,log(4,2))', + 'Exp[Sqrt[2]^2Log[2, 8]]': 'exp(sqrt(2)**2*log(8,2))', + 'ArcSin[Cos[0]]': 'asin(cos(0))', + 'Log2[16]': 'log(16,2)', + 'Max[1,-2,3,-4]': 'Max(1,-2,3,-4)', + 'Min[1,-2,3]': 'Min(1,-2,3)', + 'Exp[I Pi/2]': 'exp(I*pi/2)', + 'ArcTan[x,y]': 'atan2(y,x)', + 'Pochhammer[x,y]': 'rf(x,y)', + 'ExpIntegralEi[x]': 'Ei(x)', + 'SinIntegral[x]': 'Si(x)', + 'CosIntegral[x]': 'Ci(x)', + 'AiryAi[x]': 'airyai(x)', + 'AiryAiPrime[5]': 'airyaiprime(5)', + 'AiryBi[x]': 'airybi(x)', + 'AiryBiPrime[7]': 'airybiprime(7)', + 'LogIntegral[4]': ' li(4)', + 'PrimePi[7]': 'primepi(7)', + 'Prime[5]': 'prime(5)', + 'PrimeQ[5]': 'isprime(5)', + 'Rational[2,19]': 'Rational(2,19)', # test case for issue 25716 + } + + for e in d: + assert parse_mathematica(e) == sympify(d[e]) + + # The parsed form of this expression should not evaluate the Lambda object: + assert parse_mathematica("Sin[#]^2 + Cos[#]^2 &[x]") == sin(x)**2 + cos(x)**2 + + d1, d2, d3 = symbols("d1:4", cls=Dummy) + assert parse_mathematica("Sin[#] + Cos[#3] &").dummy_eq(Lambda((d1, d2, d3), sin(d1) + cos(d3))) + assert parse_mathematica("Sin[#^2] &").dummy_eq(Lambda(d1, sin(d1**2))) + assert parse_mathematica("Function[x, x^3]") == Lambda(x, x**3) + assert parse_mathematica("Function[{x, y}, x^2 + y^2]") == Lambda((x, y), x**2 + y**2) + + +def test_parser_mathematica_tokenizer(): + parser = MathematicaParser() + + chain = lambda expr: parser._from_tokens_to_fullformlist(parser._from_mathematica_to_tokens(expr)) + + # Basic patterns + assert chain("x") == "x" + assert chain("42") == "42" + assert chain(".2") == ".2" + assert chain("+x") == "x" + assert chain("-1") == "-1" + assert chain("- 3") == "-3" + assert chain("α") == "α" + assert chain("+Sin[x]") == ["Sin", "x"] + assert chain("-Sin[x]") == ["Times", "-1", ["Sin", "x"]] + assert chain("x(a+1)") == ["Times", "x", ["Plus", "a", "1"]] + assert chain("(x)") == "x" + assert chain("(+x)") == "x" + assert chain("-a") == ["Times", "-1", "a"] + assert chain("(-x)") == ["Times", "-1", "x"] + assert chain("(x + y)") == ["Plus", "x", "y"] + assert chain("3 + 4") == ["Plus", "3", "4"] + assert chain("a - 3") == ["Plus", "a", "-3"] + assert chain("a - b") == ["Plus", "a", ["Times", "-1", "b"]] + assert chain("7 * 8") == ["Times", "7", "8"] + assert chain("a + b*c") == ["Plus", "a", ["Times", "b", "c"]] + assert chain("a + b* c* d + 2 * e") == ["Plus", "a", ["Times", "b", "c", "d"], ["Times", "2", "e"]] + assert chain("a / b") == ["Times", "a", ["Power", "b", "-1"]] + + # Missing asterisk (*) patterns: + assert chain("x y") == ["Times", "x", "y"] + assert chain("3 4") == ["Times", "3", "4"] + assert chain("a[b] c") == ["Times", ["a", "b"], "c"] + assert chain("(x) (y)") == ["Times", "x", "y"] + assert chain("3 (a)") == ["Times", "3", "a"] + assert chain("(a) b") == ["Times", "a", "b"] + assert chain("4.2") == "4.2" + assert chain("4 2") == ["Times", "4", "2"] + assert chain("4 2") == ["Times", "4", "2"] + assert chain("3 . 4") == ["Dot", "3", "4"] + assert chain("4. 2") == ["Times", "4.", "2"] + assert chain("x.y") == ["Dot", "x", "y"] + assert chain("4.y") == ["Times", "4.", "y"] + assert chain("4 .y") == ["Dot", "4", "y"] + assert chain("x.4") == ["Times", "x", ".4"] + assert chain("x0.3") == ["Times", "x0", ".3"] + assert chain("x. 4") == ["Dot", "x", "4"] + + # Comments + assert chain("a (* +b *) + c") == ["Plus", "a", "c"] + assert chain("a (* + b *) + (**)c (* +d *) + e") == ["Plus", "a", "c", "e"] + assert chain("""a + (* + + b + *) c + (* d + *) e + """) == ["Plus", "a", "c", "e"] + + # Operators couples + and -, * and / are mutually associative: + # (i.e. expression gets flattened when mixing these operators) + assert chain("a*b/c") == ["Times", "a", "b", ["Power", "c", "-1"]] + assert chain("a/b*c") == ["Times", "a", ["Power", "b", "-1"], "c"] + assert chain("a+b-c") == ["Plus", "a", "b", ["Times", "-1", "c"]] + assert chain("a-b+c") == ["Plus", "a", ["Times", "-1", "b"], "c"] + assert chain("-a + b -c ") == ["Plus", ["Times", "-1", "a"], "b", ["Times", "-1", "c"]] + assert chain("a/b/c*d") == ["Times", "a", ["Power", "b", "-1"], ["Power", "c", "-1"], "d"] + assert chain("a/b/c") == ["Times", "a", ["Power", "b", "-1"], ["Power", "c", "-1"]] + assert chain("a-b-c") == ["Plus", "a", ["Times", "-1", "b"], ["Times", "-1", "c"]] + assert chain("1/a") == ["Times", "1", ["Power", "a", "-1"]] + assert chain("1/a/b") == ["Times", "1", ["Power", "a", "-1"], ["Power", "b", "-1"]] + assert chain("-1/a*b") == ["Times", "-1", ["Power", "a", "-1"], "b"] + + # Enclosures of various kinds, i.e. ( ) [ ] [[ ]] { } + assert chain("(a + b) + c") == ["Plus", ["Plus", "a", "b"], "c"] + assert chain(" a + (b + c) + d ") == ["Plus", "a", ["Plus", "b", "c"], "d"] + assert chain("a * (b + c)") == ["Times", "a", ["Plus", "b", "c"]] + assert chain("a b (c d)") == ["Times", "a", "b", ["Times", "c", "d"]] + assert chain("{a, b, 2, c}") == ["List", "a", "b", "2", "c"] + assert chain("{a, {b, c}}") == ["List", "a", ["List", "b", "c"]] + assert chain("{{a}}") == ["List", ["List", "a"]] + assert chain("a[b, c]") == ["a", "b", "c"] + assert chain("a[[b, c]]") == ["Part", "a", "b", "c"] + assert chain("a[b[c]]") == ["a", ["b", "c"]] + assert chain("a[[b, c[[d, {e,f}]]]]") == ["Part", "a", "b", ["Part", "c", "d", ["List", "e", "f"]]] + assert chain("a[b[[c,d]]]") == ["a", ["Part", "b", "c", "d"]] + assert chain("a[[b[c]]]") == ["Part", "a", ["b", "c"]] + assert chain("a[[b[[c]]]]") == ["Part", "a", ["Part", "b", "c"]] + assert chain("a[[b[c[[d]]]]]") == ["Part", "a", ["b", ["Part", "c", "d"]]] + assert chain("a[b[[c[d]]]]") == ["a", ["Part", "b", ["c", "d"]]] + assert chain("x[[a+1, b+2, c+3]]") == ["Part", "x", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + assert chain("x[a+1, b+2, c+3]") == ["x", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + assert chain("{a+1, b+2, c+3}") == ["List", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + + # Flat operator: + assert chain("a*b*c*d*e") == ["Times", "a", "b", "c", "d", "e"] + assert chain("a +b + c+ d+e") == ["Plus", "a", "b", "c", "d", "e"] + + # Right priority operator: + assert chain("a^b") == ["Power", "a", "b"] + assert chain("a^b^c") == ["Power", "a", ["Power", "b", "c"]] + assert chain("a^b^c^d") == ["Power", "a", ["Power", "b", ["Power", "c", "d"]]] + + # Left priority operator: + assert chain("a/.b") == ["ReplaceAll", "a", "b"] + assert chain("a/.b/.c/.d") == ["ReplaceAll", ["ReplaceAll", ["ReplaceAll", "a", "b"], "c"], "d"] + + assert chain("a//b") == ["a", "b"] + assert chain("a//b//c") == [["a", "b"], "c"] + assert chain("a//b//c//d") == [[["a", "b"], "c"], "d"] + + # Compound expressions + assert chain("a;b") == ["CompoundExpression", "a", "b"] + assert chain("a;") == ["CompoundExpression", "a", "Null"] + assert chain("a;b;") == ["CompoundExpression", "a", "b", "Null"] + assert chain("a[b;c]") == ["a", ["CompoundExpression", "b", "c"]] + assert chain("a[b,c;d,e]") == ["a", "b", ["CompoundExpression", "c", "d"], "e"] + assert chain("a[b,c;,d]") == ["a", "b", ["CompoundExpression", "c", "Null"], "d"] + + # New lines + assert chain("a\nb\n") == ["CompoundExpression", "a", "b"] + assert chain("a\n\nb\n (c \nd) \n") == ["CompoundExpression", "a", "b", ["Times", "c", "d"]] + assert chain("\na; b\nc") == ["CompoundExpression", "a", "b", "c"] + assert chain("a + \nb\n") == ["Plus", "a", "b"] + assert chain("a\nb; c; d\n e; (f \n g); h + \n i") == ["CompoundExpression", "a", "b", "c", "d", "e", ["Times", "f", "g"], ["Plus", "h", "i"]] + assert chain("\n{\na\nb; c; d\n e (f \n g); h + \n i\n\n}\n") == ["List", ["CompoundExpression", ["Times", "a", "b"], "c", ["Times", "d", "e", ["Times", "f", "g"]], ["Plus", "h", "i"]]] + + # Patterns + assert chain("y_") == ["Pattern", "y", ["Blank"]] + assert chain("y_.") == ["Optional", ["Pattern", "y", ["Blank"]]] + assert chain("y__") == ["Pattern", "y", ["BlankSequence"]] + assert chain("y___") == ["Pattern", "y", ["BlankNullSequence"]] + assert chain("a[b_.,c_]") == ["a", ["Optional", ["Pattern", "b", ["Blank"]]], ["Pattern", "c", ["Blank"]]] + assert chain("b_. c") == ["Times", ["Optional", ["Pattern", "b", ["Blank"]]], "c"] + + # Slots for lambda functions + assert chain("#") == ["Slot", "1"] + assert chain("#3") == ["Slot", "3"] + assert chain("#n") == ["Slot", "n"] + assert chain("##") == ["SlotSequence", "1"] + assert chain("##a") == ["SlotSequence", "a"] + + # Lambda functions + assert chain("x&") == ["Function", "x"] + assert chain("#&") == ["Function", ["Slot", "1"]] + assert chain("#+3&") == ["Function", ["Plus", ["Slot", "1"], "3"]] + assert chain("#1 + #2&") == ["Function", ["Plus", ["Slot", "1"], ["Slot", "2"]]] + assert chain("# + #&") == ["Function", ["Plus", ["Slot", "1"], ["Slot", "1"]]] + assert chain("#&[x]") == [["Function", ["Slot", "1"]], "x"] + assert chain("#1 + #2 & [x, y]") == [["Function", ["Plus", ["Slot", "1"], ["Slot", "2"]]], "x", "y"] + assert chain("#1^2#2^3&") == ["Function", ["Times", ["Power", ["Slot", "1"], "2"], ["Power", ["Slot", "2"], "3"]]] + + # Strings inside Mathematica expressions: + assert chain('"abc"') == ["_Str", "abc"] + assert chain('"a\\"b"') == ["_Str", 'a"b'] + # This expression does not make sense mathematically, it's just testing the parser: + assert chain('x + "abc" ^ 3') == ["Plus", "x", ["Power", ["_Str", "abc"], "3"]] + assert chain('"a (* b *) c"') == ["_Str", "a (* b *) c"] + assert chain('"a" (* b *) ') == ["_Str", "a"] + assert chain('"a [ b] "') == ["_Str", "a [ b] "] + raises(SyntaxError, lambda: chain('"')) + raises(SyntaxError, lambda: chain('"\\"')) + raises(SyntaxError, lambda: chain('"abc')) + raises(SyntaxError, lambda: chain('"abc\\"def')) + + # Invalid expressions: + raises(SyntaxError, lambda: chain("(,")) + raises(SyntaxError, lambda: chain("()")) + raises(SyntaxError, lambda: chain("a (* b")) + + +def test_parser_mathematica_exp_alt(): + parser = MathematicaParser() + + convert_chain2 = lambda expr: parser._from_fullformlist_to_fullformsympy(parser._from_fullform_to_fullformlist(expr)) + convert_chain3 = lambda expr: parser._from_fullformsympy_to_sympy(convert_chain2(expr)) + + Sin, Times, Plus, Power = symbols("Sin Times Plus Power", cls=Function) + + full_form1 = "Sin[Times[x, y]]" + full_form2 = "Plus[Times[x, y], z]" + full_form3 = "Sin[Times[x, Plus[y, z], Power[w, n]]]]" + full_form4 = "Rational[Rational[x, y], z]" + + assert parser._from_fullform_to_fullformlist(full_form1) == ["Sin", ["Times", "x", "y"]] + assert parser._from_fullform_to_fullformlist(full_form2) == ["Plus", ["Times", "x", "y"], "z"] + assert parser._from_fullform_to_fullformlist(full_form3) == ["Sin", ["Times", "x", ["Plus", "y", "z"], ["Power", "w", "n"]]] + assert parser._from_fullform_to_fullformlist(full_form4) == ["Rational", ["Rational", "x", "y"], "z"] + + assert convert_chain2(full_form1) == Sin(Times(x, y)) + assert convert_chain2(full_form2) == Plus(Times(x, y), z) + assert convert_chain2(full_form3) == Sin(Times(x, Plus(y, z), Power(w, n))) + + assert convert_chain3(full_form1) == sin(x*y) + assert convert_chain3(full_form2) == x*y + z + assert convert_chain3(full_form3) == sin(x*(y + z)*w**n) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_maxima.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_maxima.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bc1db8f1385ed52e8c677a1bcc759f5118d01e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_maxima.py @@ -0,0 +1,50 @@ +from sympy.parsing.maxima import parse_maxima +from sympy.core.numbers import (E, Rational, oo) +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.abc import x + +n = Symbol('n', integer=True) + + +def test_parser(): + assert Abs(parse_maxima('float(1/3)') - 0.333333333) < 10**(-5) + assert parse_maxima('13^26') == 91733330193268616658399616009 + assert parse_maxima('sin(%pi/2) + cos(%pi/3)') == Rational(3, 2) + assert parse_maxima('log(%e)') == 1 + + +def test_injection(): + parse_maxima('c: x+1', globals=globals()) + # c created by parse_maxima + assert c == x + 1 # noqa:F821 + + parse_maxima('g: sqrt(81)', globals=globals()) + # g created by parse_maxima + assert g == 9 # noqa:F821 + + +def test_maxima_functions(): + assert parse_maxima('expand( (x+1)^2)') == x**2 + 2*x + 1 + assert parse_maxima('factor( x**2 + 2*x + 1)') == (x + 1)**2 + assert parse_maxima('2*cos(x)^2 + sin(x)^2') == 2*cos(x)**2 + sin(x)**2 + assert parse_maxima('trigexpand(sin(2*x)+cos(2*x))') == \ + -1 + 2*cos(x)**2 + 2*cos(x)*sin(x) + assert parse_maxima('solve(x^2-4,x)') == [-2, 2] + assert parse_maxima('limit((1+1/x)^x,x,inf)') == E + assert parse_maxima('limit(sqrt(-x)/x,x,0,minus)') is -oo + assert parse_maxima('diff(x^x, x)') == x**x*(1 + log(x)) + assert parse_maxima('sum(k, k, 1, n)', name_dict={ + "n": Symbol('n', integer=True), + "k": Symbol('k', integer=True) + }) == (n**2 + n)/2 + assert parse_maxima('product(k, k, 1, n)', name_dict={ + "n": Symbol('n', integer=True), + "k": Symbol('k', integer=True) + }) == factorial(n) + assert parse_maxima('ratsimp((x^2-1)/(x+1))') == x - 1 + assert Abs( parse_maxima( + 'float(sec(%pi/3) + csc(%pi/3))') - 3.154700538379252) < 10**(-5) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_sym_expr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_sym_expr.py new file mode 100644 index 0000000000000000000000000000000000000000..99912805db381b96e7f41a348fe6f90d71adf781 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_sym_expr.py @@ -0,0 +1,209 @@ +from sympy.parsing.sym_expr import SymPyExpression +from sympy.testing.pytest import raises +from sympy.external import import_module + +lfortran = import_module('lfortran') +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if lfortran and cin: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Declaration, FloatType) + from sympy.core import Integer, Float + from sympy.core.symbol import Symbol + + expr1 = SymPyExpression() + src = """\ + integer :: a, b, c, d + real :: p, q, r, s + """ + + def test_c_parse(): + src1 = """\ + int a, b = 4; + float c, d = 2.4; + """ + expr1.convert_to_expr(src1, 'c') + ls = expr1.return_expr() + + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.3999999999999999', precision=53) + ) + ) + + + def test_fortran_parse(): + expr = SymPyExpression(src, 'f') + ls = expr.return_expr() + + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[4] == Declaration( + Variable( + Symbol('p'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[5] == Declaration( + Variable( + Symbol('q'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[6] == Declaration( + Variable( + Symbol('r'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[7] == Declaration( + Variable( + Symbol('s'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + + + def test_convert_py(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_py = expr1.convert_to_python() + assert exp_py == [ + 'a = 0', + 'b = 0', + 'c = 0', + 'd = 0', + 'p = 0.0', + 'q = 0.0', + 'r = 0.0', + 's = 0.0', + 'a = b + c', + 's = p*q/r' + ] + + + def test_convert_fort(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_fort = expr1.convert_to_fortran() + assert exp_fort == [ + ' integer*4 a', + ' integer*4 b', + ' integer*4 c', + ' integer*4 d', + ' real*8 p', + ' real*8 q', + ' real*8 r', + ' real*8 s', + ' a = b + c', + ' s = p*q/r' + ] + + + def test_convert_c(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_c = expr1.convert_to_c() + assert exp_c == [ + 'int a = 0', + 'int b = 0', + 'int c = 0', + 'int d = 0', + 'double p = 0.0', + 'double q = 0.0', + 'double r = 0.0', + 'double s = 0.0', + 'a = b + c;', + 's = p*q/r;' + ] + + + def test_exceptions(): + src = 'int a;' + raises(ValueError, lambda: SymPyExpression(src)) + raises(ValueError, lambda: SymPyExpression(mode = 'c')) + raises(NotImplementedError, lambda: SymPyExpression(src, mode = 'd')) + +elif not lfortran and not cin: + def test_raise(): + raises(ImportError, lambda: SymPyExpression('int a;', 'c')) + raises(ImportError, lambda: SymPyExpression('integer :: a', 'f')) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_sympy_parser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_sympy_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..43ecccbe262ffb4093248d891aa7423c8f62c628 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/parsing/tests/test_sympy_parser.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- + + +import builtins +import types + +from sympy.assumptions import Q +from sympy.core import Symbol, Function, Float, Rational, Integer, I, Mul, Pow, Eq, Lt, Le, Gt, Ge, Ne +from sympy.functions import exp, factorial, factorial2, sin, Min, Max +from sympy.logic import And +from sympy.series import Limit +from sympy.testing.pytest import raises + +from sympy.parsing.sympy_parser import ( + parse_expr, standard_transformations, rationalize, TokenError, + split_symbols, implicit_multiplication, convert_equals_signs, + convert_xor, function_exponentiation, lambda_notation, auto_symbol, + repeated_decimals, implicit_multiplication_application, + auto_number, factorial_notation, implicit_application, + _transformation, T + ) + + +def test_sympy_parser(): + x = Symbol('x') + inputs = { + '2*x': 2 * x, + '3.00': Float(3), + '22/7': Rational(22, 7), + '2+3j': 2 + 3*I, + 'exp(x)': exp(x), + 'x!': factorial(x), + 'x!!': factorial2(x), + '(x + 1)! - 1': factorial(x + 1) - 1, + '3.[3]': Rational(10, 3), + '.0[3]': Rational(1, 30), + '3.2[3]': Rational(97, 30), + '1.3[12]': Rational(433, 330), + '1 + 3.[3]': Rational(13, 3), + '1 + .0[3]': Rational(31, 30), + '1 + 3.2[3]': Rational(127, 30), + '.[0011]': Rational(1, 909), + '0.1[00102] + 1': Rational(366697, 333330), + '1.[0191]': Rational(10190, 9999), + '10!': 3628800, + '-(2)': -Integer(2), + '[-1, -2, 3]': [Integer(-1), Integer(-2), Integer(3)], + 'Symbol("x").free_symbols': x.free_symbols, + "S('S(3).n(n=3)')": Float(3, 3), + 'factorint(12, visual=True)': Mul( + Pow(2, 2, evaluate=False), + Pow(3, 1, evaluate=False), + evaluate=False), + 'Limit(sin(x), x, 0, dir="-")': Limit(sin(x), x, 0, dir='-'), + 'Q.even(x)': Q.even(x), + + + } + for text, result in inputs.items(): + assert parse_expr(text) == result + + raises(TypeError, lambda: + parse_expr('x', standard_transformations)) + raises(TypeError, lambda: + parse_expr('x', transformations=lambda x,y: 1)) + raises(TypeError, lambda: + parse_expr('x', transformations=(lambda x,y: 1,))) + raises(TypeError, lambda: parse_expr('x', transformations=((),))) + raises(TypeError, lambda: parse_expr('x', {}, [], [])) + raises(TypeError, lambda: parse_expr('x', [], [], {})) + raises(TypeError, lambda: parse_expr('x', [], [], {})) + + +def test_rationalize(): + inputs = { + '0.123': Rational(123, 1000) + } + transformations = standard_transformations + (rationalize,) + for text, result in inputs.items(): + assert parse_expr(text, transformations=transformations) == result + + +def test_factorial_fail(): + inputs = ['x!!!', 'x!!!!', '(!)'] + + + for text in inputs: + try: + parse_expr(text) + assert False + except TokenError: + assert True + + +def test_repeated_fail(): + inputs = ['1[1]', '.1e1[1]', '0x1[1]', '1.1j[1]', '1.1[1 + 1]', + '0.1[[1]]', '0x1.1[1]'] + + + # All are valid Python, so only raise TypeError for invalid indexing + for text in inputs: + raises(TypeError, lambda: parse_expr(text)) + + + inputs = ['0.1[', '0.1[1', '0.1[]'] + for text in inputs: + raises((TokenError, SyntaxError), lambda: parse_expr(text)) + + +def test_repeated_dot_only(): + assert parse_expr('.[1]') == Rational(1, 9) + assert parse_expr('1 + .[1]') == Rational(10, 9) + + +def test_local_dict(): + local_dict = { + 'my_function': lambda x: x + 2 + } + inputs = { + 'my_function(2)': Integer(4) + } + for text, result in inputs.items(): + assert parse_expr(text, local_dict=local_dict) == result + + +def test_local_dict_split_implmult(): + t = standard_transformations + (split_symbols, implicit_multiplication,) + w = Symbol('w', real=True) + y = Symbol('y') + assert parse_expr('yx', local_dict={'x':w}, transformations=t) == y*w + + +def test_local_dict_symbol_to_fcn(): + x = Symbol('x') + d = {'foo': Function('bar')} + assert parse_expr('foo(x)', local_dict=d) == d['foo'](x) + d = {'foo': Symbol('baz')} + raises(TypeError, lambda: parse_expr('foo(x)', local_dict=d)) + + +def test_global_dict(): + global_dict = { + 'Symbol': Symbol + } + inputs = { + 'Q & S': And(Symbol('Q'), Symbol('S')) + } + for text, result in inputs.items(): + assert parse_expr(text, global_dict=global_dict) == result + + +def test_no_globals(): + + # Replicate creating the default global_dict: + default_globals = {} + exec('from sympy import *', default_globals) + builtins_dict = vars(builtins) + for name, obj in builtins_dict.items(): + if isinstance(obj, types.BuiltinFunctionType): + default_globals[name] = obj + default_globals['max'] = Max + default_globals['min'] = Min + + # Need to include Symbol or parse_expr will not work: + default_globals.pop('Symbol') + global_dict = {'Symbol':Symbol} + + for name in default_globals: + obj = parse_expr(name, global_dict=global_dict) + assert obj == Symbol(name) + + +def test_issue_2515(): + raises(TokenError, lambda: parse_expr('(()')) + raises(TokenError, lambda: parse_expr('"""')) + + +def test_issue_7663(): + x = Symbol('x') + e = '2*(x+1)' + assert parse_expr(e, evaluate=False) == parse_expr(e, evaluate=False) + assert parse_expr(e, evaluate=False).equals(2*(x+1)) + +def test_recursive_evaluate_false_10560(): + inputs = { + '4*-3' : '4*-3', + '-4*3' : '(-4)*3', + "-2*x*y": '(-2)*x*y', + "x*-4*x": "x*(-4)*x" + } + for text, result in inputs.items(): + assert parse_expr(text, evaluate=False) == parse_expr(result, evaluate=False) + + +def test_function_evaluate_false(): + inputs = [ + 'Abs(0)', 'im(0)', 're(0)', 'sign(0)', 'arg(0)', 'conjugate(0)', + 'acos(0)', 'acot(0)', 'acsc(0)', 'asec(0)', 'asin(0)', 'atan(0)', + 'acosh(0)', 'acoth(0)', 'acsch(0)', 'asech(0)', 'asinh(0)', 'atanh(0)', + 'cos(0)', 'cot(0)', 'csc(0)', 'sec(0)', 'sin(0)', 'tan(0)', + 'cosh(0)', 'coth(0)', 'csch(0)', 'sech(0)', 'sinh(0)', 'tanh(0)', + 'exp(0)', 'log(0)', 'sqrt(0)', + ] + for case in inputs: + expr = parse_expr(case, evaluate=False) + assert case == str(expr) != str(expr.doit()) + assert str(parse_expr('ln(0)', evaluate=False)) == 'log(0)' + assert str(parse_expr('cbrt(0)', evaluate=False)) == '0**(1/3)' + + +def test_issue_10773(): + inputs = { + '-10/5': '(-10)/5', + '-10/-5' : '(-10)/(-5)', + } + for text, result in inputs.items(): + assert parse_expr(text, evaluate=False) == parse_expr(result, evaluate=False) + + +def test_split_symbols(): + transformations = standard_transformations + \ + (split_symbols, implicit_multiplication,) + x = Symbol('x') + y = Symbol('y') + xy = Symbol('xy') + + + assert parse_expr("xy") == xy + assert parse_expr("xy", transformations=transformations) == x*y + + +def test_split_symbols_function(): + transformations = standard_transformations + \ + (split_symbols, implicit_multiplication,) + x = Symbol('x') + y = Symbol('y') + a = Symbol('a') + f = Function('f') + + + assert parse_expr("ay(x+1)", transformations=transformations) == a*y*(x+1) + assert parse_expr("af(x+1)", transformations=transformations, + local_dict={'f':f}) == a*f(x+1) + + +def test_functional_exponent(): + t = standard_transformations + (convert_xor, function_exponentiation) + x = Symbol('x') + y = Symbol('y') + a = Symbol('a') + yfcn = Function('y') + assert parse_expr("sin^2(x)", transformations=t) == (sin(x))**2 + assert parse_expr("sin^y(x)", transformations=t) == (sin(x))**y + assert parse_expr("exp^y(x)", transformations=t) == (exp(x))**y + assert parse_expr("E^y(x)", transformations=t) == exp(yfcn(x)) + assert parse_expr("a^y(x)", transformations=t) == a**(yfcn(x)) + + +def test_match_parentheses_implicit_multiplication(): + transformations = standard_transformations + \ + (implicit_multiplication,) + raises(TokenError, lambda: parse_expr('(1,2),(3,4]',transformations=transformations)) + + +def test_convert_equals_signs(): + transformations = standard_transformations + \ + (convert_equals_signs, ) + x = Symbol('x') + y = Symbol('y') + assert parse_expr("1*2=x", transformations=transformations) == Eq(2, x) + assert parse_expr("y = x", transformations=transformations) == Eq(y, x) + assert parse_expr("(2*y = x) = False", + transformations=transformations) == Eq(Eq(2*y, x), False) + + +def test_parse_function_issue_3539(): + x = Symbol('x') + f = Function('f') + assert parse_expr('f(x)') == f(x) + +def test_issue_24288(): + assert parse_expr("1 < 2", evaluate=False) == Lt(1, 2, evaluate=False) + assert parse_expr("1 <= 2", evaluate=False) == Le(1, 2, evaluate=False) + assert parse_expr("1 > 2", evaluate=False) == Gt(1, 2, evaluate=False) + assert parse_expr("1 >= 2", evaluate=False) == Ge(1, 2, evaluate=False) + assert parse_expr("1 != 2", evaluate=False) == Ne(1, 2, evaluate=False) + assert parse_expr("1 == 2", evaluate=False) == Eq(1, 2, evaluate=False) + assert parse_expr("1 < 2 < 3", evaluate=False) == And(Lt(1, 2, evaluate=False), Lt(2, 3, evaluate=False), evaluate=False) + assert parse_expr("1 <= 2 <= 3", evaluate=False) == And(Le(1, 2, evaluate=False), Le(2, 3, evaluate=False), evaluate=False) + assert parse_expr("1 < 2 <= 3 < 4", evaluate=False) == \ + And(Lt(1, 2, evaluate=False), Le(2, 3, evaluate=False), Lt(3, 4, evaluate=False), evaluate=False) + # Valid Python relational operators that SymPy does not decide how to handle them yet + raises(ValueError, lambda: parse_expr("1 in 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 is 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 not in 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 is not 2", evaluate=False)) + +def test_split_symbols_numeric(): + transformations = ( + standard_transformations + + (implicit_multiplication_application,)) + + n = Symbol('n') + expr1 = parse_expr('2**n * 3**n') + expr2 = parse_expr('2**n3**n', transformations=transformations) + assert expr1 == expr2 == 2**n*3**n + + expr1 = parse_expr('n12n34', transformations=transformations) + assert expr1 == n*12*n*34 + + +def test_unicode_names(): + assert parse_expr('α') == Symbol('α') + + +def test_python3_features(): + assert parse_expr("123_456") == 123456 + assert parse_expr("1.2[3_4]") == parse_expr("1.2[34]") == Rational(611, 495) + assert parse_expr("1.2[012_012]") == parse_expr("1.2[012012]") == Rational(400, 333) + assert parse_expr('.[3_4]') == parse_expr('.[34]') == Rational(34, 99) + assert parse_expr('.1[3_4]') == parse_expr('.1[34]') == Rational(133, 990) + assert parse_expr('123_123.123_123[3_4]') == parse_expr('123123.123123[34]') == Rational(12189189189211, 99000000) + + +def test_issue_19501(): + x = Symbol('x') + eq = parse_expr('E**x(1+x)', local_dict={'x': x}, transformations=( + standard_transformations + + (implicit_multiplication_application,))) + assert eq.free_symbols == {x} + + +def test_parsing_definitions(): + from sympy.abc import x + assert len(_transformation) == 12 # if this changes, extend below + assert _transformation[0] == lambda_notation + assert _transformation[1] == auto_symbol + assert _transformation[2] == repeated_decimals + assert _transformation[3] == auto_number + assert _transformation[4] == factorial_notation + assert _transformation[5] == implicit_multiplication_application + assert _transformation[6] == convert_xor + assert _transformation[7] == implicit_application + assert _transformation[8] == implicit_multiplication + assert _transformation[9] == convert_equals_signs + assert _transformation[10] == function_exponentiation + assert _transformation[11] == rationalize + assert T[:5] == T[0,1,2,3,4] == standard_transformations + t = _transformation + assert T[-1, 0] == (t[len(t) - 1], t[0]) + assert T[:5, 8] == standard_transformations + (t[8],) + assert parse_expr('0.3x^2', transformations='all') == 3*x**2/10 + assert parse_expr('sin 3x', transformations='implicit') == sin(3*x) + + +def test_builtins(): + cases = [ + ('abs(x)', 'Abs(x)'), + ('max(x, y)', 'Max(x, y)'), + ('min(x, y)', 'Min(x, y)'), + ('pow(x, y)', 'Pow(x, y)'), + ] + for built_in_func_call, sympy_func_call in cases: + assert parse_expr(built_in_func_call) == parse_expr(sympy_func_call) + assert str(parse_expr('pow(38, -1, 97)')) == '23' + + +def test_issue_22822(): + raises(ValueError, lambda: parse_expr('x', {'': 1})) + data = {'some_parameter': None} + assert parse_expr('some_parameter is None', data) is True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa4d4f8ba4a844eabf592683578ae2b87505f4a4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/aesaracode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/aesaracode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c9f22e4c541219c553ba55cda86f49015e6e6c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/aesaracode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/c.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/c.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6988d6df76ff0e4470c62582c8bf6294c9e1a279 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/c.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/codeprinter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/codeprinter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb66ec1c7ff51b3aca824b6c97815476e89153b1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/codeprinter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/conventions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/conventions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e4a76a8fce93003e988311298a027bdc9077608 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/conventions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/cxx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/cxx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad05c281967d1fe4b6143fa809ae8efa18c444c8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/cxx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/defaults.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/defaults.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60b0572763425d4e20caa7939cc1a7bfa3c53091 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/defaults.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/dot.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/dot.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cce6d3ef248f59138b6405b27c0b3493db8276b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/dot.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/fortran.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/fortran.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..373cec9bd3b5c8fe9e3d43b296d59bb17689b598 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/fortran.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/glsl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/glsl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9adab434d72b1751df56dd9213df1c84c1be93d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/glsl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/gtk.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/gtk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4965471903a98667f9b6ef57d0f6116aef93475c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/gtk.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/jscode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/jscode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad95c3be46a79187e94abf665cb3f452ca7b8630 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/jscode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/julia.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/julia.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f2ce646ea6964ec3815ed1a251ae12740b249ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/julia.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ddd80c41d39276240392dc2853fd0627d62004c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffb276f1c23c68a22cb4b0c800bff5ecfdc03d91 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/maple.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/maple.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2b88cb29ed3bd43ed4215abc953993801d2d454 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/maple.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/mathematica.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/mathematica.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..629c9d8abb6d43b08df048bbce3d376e730fd5eb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/mathematica.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/numpy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/numpy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ba3af47bf04a5b608ea5a43983500d5f16cd9e8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/numpy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/octave.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/octave.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4350af580965ec3f35b2fb62b51ee0a09767d135 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/octave.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/precedence.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/precedence.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1de6ed0a5c698e82d752b2345262242e3cc6358d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/precedence.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/preview.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/preview.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2a25a7b4c120f6d9bebdf8deda9232d16235a27 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/preview.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/printer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/printer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62b2908c01cc4630d822e57900a8238c6436da24 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/printer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/pycode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/pycode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52eabe20fa9f3ffa04a4c1065c878b4475c3b62d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/pycode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/python.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/python.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39e4d5a4feac0ae76a7a1a8c74d33f26b21aeb88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/python.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/pytorch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/pytorch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a79a2b678f8e3b7ef83c5a81468c4bb0e3c51ab7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/pytorch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/rcode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/rcode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0677e9b031438761ac65ab2dce128a42e7401ec1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/rcode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/repr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/repr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fba70ea10b3a96f24dbf00b2b43d98379b241a11 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/repr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/rust.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/rust.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e32c3b57a50efa747f93d539966922a4ce131447 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/rust.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/smtlib.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/smtlib.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b66e39a6d8a74235f8144866faf537863e7ac94 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/smtlib.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/str.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/str.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48f8413ef62337cebc49df77a3c7e6a90947272b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/str.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tableform.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tableform.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56094f523cd0cbbe0b39382f446c1c9893f9af32 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tableform.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tensorflow.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tensorflow.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0806e055e5cdec01856d65df8f12513ce18d3e04 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tensorflow.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/theanocode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/theanocode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ff06c99703156b5c7317353766550dcef48c789 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/theanocode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tree.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ead8144df043771c7cf26af3ad6e25900fd1e646 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/__pycache__/tree.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/dot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/dot.py new file mode 100644 index 0000000000000000000000000000000000000000..c968fee389c16108b757b8fcad531ac6fa4ddb2f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/dot.py @@ -0,0 +1,294 @@ +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.symbol import Symbol +from sympy.core.numbers import Integer, Rational, Float +from sympy.printing.repr import srepr + +__all__ = ['dotprint'] + +default_styles = ( + (Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'}) +) + +slotClasses = (Symbol, Integer, Rational, Float) +def purestr(x, with_args=False): + """A string that follows ```obj = type(obj)(*obj.args)``` exactly. + + Parameters + ========== + + with_args : boolean, optional + If ``True``, there will be a second argument for the return + value, which is a tuple containing ``purestr`` applied to each + of the subnodes. + + If ``False``, there will not be a second argument for the + return. + + Default is ``False`` + + Examples + ======== + + >>> from sympy import Float, Symbol, MatrixSymbol + >>> from sympy import Integer # noqa: F401 + >>> from sympy.core.symbol import Str # noqa: F401 + >>> from sympy.printing.dot import purestr + + Applying ``purestr`` for basic symbolic object: + >>> code = purestr(Symbol('x')) + >>> code + "Symbol('x')" + >>> eval(code) == Symbol('x') + True + + For basic numeric object: + >>> purestr(Float(2)) + "Float('2.0', precision=53)" + + For matrix symbol: + >>> code = purestr(MatrixSymbol('x', 2, 2)) + >>> code + "MatrixSymbol(Str('x'), Integer(2), Integer(2))" + >>> eval(code) == MatrixSymbol('x', 2, 2) + True + + With ``with_args=True``: + >>> purestr(Float(2), with_args=True) + ("Float('2.0', precision=53)", ()) + >>> purestr(MatrixSymbol('x', 2, 2), with_args=True) + ("MatrixSymbol(Str('x'), Integer(2), Integer(2))", + ("Str('x')", 'Integer(2)', 'Integer(2)')) + """ + sargs = () + if not isinstance(x, Basic): + rv = str(x) + elif not x.args: + rv = srepr(x) + else: + args = x.args + sargs = tuple(map(purestr, args)) + rv = "%s(%s)"%(type(x).__name__, ', '.join(sargs)) + if with_args: + rv = rv, sargs + return rv + + +def styleof(expr, styles=default_styles): + """ Merge style dictionaries in order + + Examples + ======== + + >>> from sympy import Symbol, Basic, Expr, S + >>> from sympy.printing.dot import styleof + >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}), + ... (Expr, {'color': 'black'})] + + >>> styleof(Basic(S(1)), styles) + {'color': 'blue', 'shape': 'ellipse'} + + >>> x = Symbol('x') + >>> styleof(x + 1, styles) # this is an Expr + {'color': 'black', 'shape': 'ellipse'} + """ + style = {} + for typ, sty in styles: + if isinstance(expr, typ): + style.update(sty) + return style + + +def attrprint(d, delimiter=', '): + """ Print a dictionary of attributes + + Examples + ======== + + >>> from sympy.printing.dot import attrprint + >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'})) + "color"="blue", "shape"="ellipse" + """ + return delimiter.join('"%s"="%s"'%item for item in sorted(d.items())) + + +def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True): + """ String defining a node + + Examples + ======== + + >>> from sympy.printing.dot import dotnode + >>> from sympy.abc import x + >>> print(dotnode(x)) + "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"]; + """ + style = styleof(expr, styles) + + if isinstance(expr, Basic) and not expr.is_Atom: + label = str(expr.__class__.__name__) + else: + label = labelfunc(expr) + style['label'] = label + expr_str = purestr(expr) + if repeat: + expr_str += '_%s' % str(pos) + return '"%s" [%s];' % (expr_str, attrprint(style)) + + +def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True): + """ List of strings for all expr->expr.arg pairs + + See the docstring of dotprint for explanations of the options. + + Examples + ======== + + >>> from sympy.printing.dot import dotedges + >>> from sympy.abc import x + >>> for e in dotedges(x+2): + ... print(e) + "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)"; + "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)"; + """ + if atom(expr): + return [] + else: + expr_str, arg_strs = purestr(expr, with_args=True) + if repeat: + expr_str += '_%s' % str(pos) + arg_strs = ['%s_%s' % (a, str(pos + (i,))) + for i, a in enumerate(arg_strs)] + return ['"%s" -> "%s";' % (expr_str, a) for a in arg_strs] + +template = \ +"""digraph{ + +# Graph style +%(graphstyle)s + +######### +# Nodes # +######### + +%(nodes)s + +######### +# Edges # +######### + +%(edges)s +}""" + +_graphstyle = {'rankdir': 'TD', 'ordering': 'out'} + +def dotprint(expr, + styles=default_styles, atom=lambda x: not isinstance(x, Basic), + maxdepth=None, repeat=True, labelfunc=str, **kwargs): + """DOT description of a SymPy expression tree + + Parameters + ========== + + styles : list of lists composed of (Class, mapping), optional + Styles for different classes. + + The default is + + .. code-block:: python + + ( + (Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'}) + ) + + atom : function, optional + Function used to determine if an arg is an atom. + + A good choice is ``lambda x: not x.args``. + + The default is ``lambda x: not isinstance(x, Basic)``. + + maxdepth : integer, optional + The maximum depth. + + The default is ``None``, meaning no limit. + + repeat : boolean, optional + Whether to use different nodes for common subexpressions. + + The default is ``True``. + + For example, for ``x + x*y`` with ``repeat=True``, it will have + two nodes for ``x``; with ``repeat=False``, it will have one + node. + + .. warning:: + Even if a node appears twice in the same object like ``x`` in + ``Pow(x, x)``, it will still only appear once. + Hence, with ``repeat=False``, the number of arrows out of an + object might not equal the number of args it has. + + labelfunc : function, optional + A function to create a label for a given leaf node. + + The default is ``str``. + + Another good option is ``srepr``. + + For example with ``str``, the leaf nodes of ``x + 1`` are labeled, + ``x`` and ``1``. With ``srepr``, they are labeled ``Symbol('x')`` + and ``Integer(1)``. + + **kwargs : optional + Additional keyword arguments are included as styles for the graph. + + Examples + ======== + + >>> from sympy import dotprint + >>> from sympy.abc import x + >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE + digraph{ + + # Graph style + "ordering"="out" + "rankdir"="TD" + + ######### + # Nodes # + ######### + + "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; + "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"]; + "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"]; + + ######### + # Edges # + ######### + + "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)"; + "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)"; + } + + """ + # repeat works by adding a signature tuple to the end of each node for its + # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the + # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0]. + graphstyle = _graphstyle.copy() + graphstyle.update(kwargs) + + nodes = [] + edges = [] + def traverse(e, depth, pos=()): + nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat)) + if maxdepth and depth >= maxdepth: + return + edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat)) + [traverse(arg, depth+1, pos + (i,)) for i, arg in enumerate(e.args) if not atom(arg)] + traverse(expr, 0) + + return template%{'graphstyle': attrprint(graphstyle, delimiter='\n'), + 'nodes': '\n'.join(nodes), + 'edges': '\n'.join(edges)} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/lambdarepr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/lambdarepr.py new file mode 100644 index 0000000000000000000000000000000000000000..87fa0988d138d54d68ab8aef1bbc0f27b243b472 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/lambdarepr.py @@ -0,0 +1,251 @@ +from .pycode import ( + PythonCodePrinter, + MpmathPrinter, +) +from .numpy import NumPyPrinter # NumPyPrinter is imported for backward compatibility +from sympy.core.sorting import default_sort_key + + +__all__ = [ + 'PythonCodePrinter', + 'MpmathPrinter', # MpmathPrinter is published for backward compatibility + 'NumPyPrinter', + 'LambdaPrinter', + 'NumPyPrinter', + 'IntervalPrinter', + 'lambdarepr', +] + + +class LambdaPrinter(PythonCodePrinter): + """ + This printer converts expressions into strings that can be used by + lambdify. + """ + printmethod = "_lambdacode" + + + def _print_And(self, expr): + result = ['('] + for arg in sorted(expr.args, key=default_sort_key): + result.extend(['(', self._print(arg), ')']) + result.append(' and ') + result = result[:-1] + result.append(')') + return ''.join(result) + + def _print_Or(self, expr): + result = ['('] + for arg in sorted(expr.args, key=default_sort_key): + result.extend(['(', self._print(arg), ')']) + result.append(' or ') + result = result[:-1] + result.append(')') + return ''.join(result) + + def _print_Not(self, expr): + result = ['(', 'not (', self._print(expr.args[0]), '))'] + return ''.join(result) + + def _print_BooleanTrue(self, expr): + return "True" + + def _print_BooleanFalse(self, expr): + return "False" + + def _print_ITE(self, expr): + result = [ + '((', self._print(expr.args[1]), + ') if (', self._print(expr.args[0]), + ') else (', self._print(expr.args[2]), '))' + ] + return ''.join(result) + + def _print_NumberSymbol(self, expr): + return str(expr) + + def _print_Pow(self, expr, **kwargs): + # XXX Temporary workaround. Should Python math printer be + # isolated from PythonCodePrinter? + return super(PythonCodePrinter, self)._print_Pow(expr, **kwargs) + + +# numexpr works by altering the string passed to numexpr.evaluate +# rather than by populating a namespace. Thus a special printer... +class NumExprPrinter(LambdaPrinter): + # key, value pairs correspond to SymPy name and numexpr name + # functions not appearing in this dict will raise a TypeError + printmethod = "_numexprcode" + + _numexpr_functions = { + 'sin' : 'sin', + 'cos' : 'cos', + 'tan' : 'tan', + 'asin': 'arcsin', + 'acos': 'arccos', + 'atan': 'arctan', + 'atan2' : 'arctan2', + 'sinh' : 'sinh', + 'cosh' : 'cosh', + 'tanh' : 'tanh', + 'asinh': 'arcsinh', + 'acosh': 'arccosh', + 'atanh': 'arctanh', + 'ln' : 'log', + 'log': 'log', + 'exp': 'exp', + 'sqrt' : 'sqrt', + 'Abs' : 'abs', + 'conjugate' : 'conj', + 'im' : 'imag', + 're' : 'real', + 'where' : 'where', + 'complex' : 'complex', + 'contains' : 'contains', + } + + module = 'numexpr' + + def _print_ImaginaryUnit(self, expr): + return '1j' + + def _print_seq(self, seq, delimiter=', '): + # simplified _print_seq taken from pretty.py + s = [self._print(item) for item in seq] + if s: + return delimiter.join(s) + else: + return "" + + def _print_Function(self, e): + func_name = e.func.__name__ + + nstr = self._numexpr_functions.get(func_name, None) + if nstr is None: + # check for implemented_function + if hasattr(e, '_imp_'): + return "(%s)" % self._print(e._imp_(*e.args)) + else: + raise TypeError("numexpr does not support function '%s'" % + func_name) + return "%s(%s)" % (nstr, self._print_seq(e.args)) + + def _print_Piecewise(self, expr): + "Piecewise function printer" + exprs = [self._print(arg.expr) for arg in expr.args] + conds = [self._print(arg.cond) for arg in expr.args] + # If [default_value, True] is a (expr, cond) sequence in a Piecewise object + # it will behave the same as passing the 'default' kwarg to select() + # *as long as* it is the last element in expr.args. + # If this is not the case, it may be triggered prematurely. + ans = [] + parenthesis_count = 0 + is_last_cond_True = False + for cond, expr in zip(conds, exprs): + if cond == 'True': + ans.append(expr) + is_last_cond_True = True + break + else: + ans.append('where(%s, %s, ' % (cond, expr)) + parenthesis_count += 1 + if not is_last_cond_True: + # See https://github.com/pydata/numexpr/issues/298 + # + # simplest way to put a nan but raises + # 'RuntimeWarning: invalid value encountered in log' + # + # There are other ways to do this such as + # + # >>> import numexpr as ne + # >>> nan = float('nan') + # >>> ne.evaluate('where(x < 0, -1, nan)', {'x': [-1, 2, 3], 'nan':nan}) + # array([-1., nan, nan]) + # + # That needs to be handled in the lambdified function though rather + # than here in the printer. + ans.append('log(-1)') + return ''.join(ans) + ')' * parenthesis_count + + def _print_ITE(self, expr): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(expr.rewrite(Piecewise)) + + def blacklisted(self, expr): + raise TypeError("numexpr cannot be used with %s" % + expr.__class__.__name__) + + # blacklist all Matrix printing + _print_SparseRepMatrix = \ + _print_MutableSparseMatrix = \ + _print_ImmutableSparseMatrix = \ + _print_Matrix = \ + _print_DenseMatrix = \ + _print_MutableDenseMatrix = \ + _print_ImmutableMatrix = \ + _print_ImmutableDenseMatrix = \ + blacklisted + # blacklist some Python expressions + _print_list = \ + _print_tuple = \ + _print_Tuple = \ + _print_dict = \ + _print_Dict = \ + blacklisted + + def _print_NumExprEvaluate(self, expr): + evaluate = self._module_format(self.module +".evaluate") + return "%s('%s', truediv=True)" % (evaluate, self._print(expr.expr)) + + def doprint(self, expr): + from sympy.codegen.ast import CodegenAST + from sympy.codegen.pynodes import NumExprEvaluate + if not isinstance(expr, CodegenAST): + expr = NumExprEvaluate(expr) + return super().doprint(expr) + + def _print_Return(self, expr): + from sympy.codegen.pynodes import NumExprEvaluate + r, = expr.args + if not isinstance(r, NumExprEvaluate): + expr = expr.func(NumExprEvaluate(r)) + return super()._print_Return(expr) + + def _print_Assignment(self, expr): + from sympy.codegen.pynodes import NumExprEvaluate + lhs, rhs, *args = expr.args + if not isinstance(rhs, NumExprEvaluate): + expr = expr.func(lhs, NumExprEvaluate(rhs), *args) + return super()._print_Assignment(expr) + + def _print_CodeBlock(self, expr): + from sympy.codegen.ast import CodegenAST + from sympy.codegen.pynodes import NumExprEvaluate + args = [ arg if isinstance(arg, CodegenAST) else NumExprEvaluate(arg) for arg in expr.args ] + return super()._print_CodeBlock(self, expr.func(*args)) + + +class IntervalPrinter(MpmathPrinter, LambdaPrinter): + """Use ``lambda`` printer but print numbers as ``mpi`` intervals. """ + + def _print_Integer(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Integer(expr) + + def _print_Rational(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr) + + def _print_Half(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr) + + def _print_Pow(self, expr): + return super(MpmathPrinter, self)._print_Pow(expr, rational=True) + + +for k in NumExprPrinter._numexpr_functions: + setattr(NumExprPrinter, '_print_%s' % k, NumExprPrinter._print_Function) + +def lambdarepr(expr, **settings): + """ + Returns a string usable for lambdifying. + """ + return LambdaPrinter(settings).doprint(expr) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbabc649152a3c353a37225d342064634fbb5805 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__init__.py @@ -0,0 +1,12 @@ +"""ASCII-ART 2D pretty-printer""" + +from .pretty import (pretty, pretty_print, pprint, pprint_use_unicode, + pprint_try_use_unicode, pager_print) + +# if unicode output is available -- let's use it +pprint_try_use_unicode() + +__all__ = [ + 'pretty', 'pretty_print', 'pprint', 'pprint_use_unicode', + 'pprint_try_use_unicode', 'pager_print', +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1ced20137fac96e6df6a487e101b127b733993a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd7d5ebe83ae765e7f903aba2eae1218b4609225 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e8228e0eae9f56d327aa261e3ba150e30f6161b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/pretty.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/pretty.py new file mode 100644 index 0000000000000000000000000000000000000000..b945f009119b24fc95e8452d91359957baba26a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/pretty.py @@ -0,0 +1,2937 @@ +import itertools + +from sympy.core import S +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import Number, Rational +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.core.sympify import SympifyError +from sympy.printing.conventions import requires_partial +from sympy.printing.precedence import PRECEDENCE, precedence, precedence_traditional +from sympy.printing.printer import Printer, print_function +from sympy.printing.str import sstr +from sympy.utilities.iterables import has_variety +from sympy.utilities.exceptions import sympy_deprecation_warning + +from sympy.printing.pretty.stringpict import prettyForm, stringPict +from sympy.printing.pretty.pretty_symbology import hobj, vobj, xobj, \ + xsym, pretty_symbol, pretty_atom, pretty_use_unicode, greek_unicode, U, \ + pretty_try_use_unicode, annotated, is_subscriptable_in_unicode, center_pad, root as nth_root + +# rename for usage from outside +pprint_use_unicode = pretty_use_unicode +pprint_try_use_unicode = pretty_try_use_unicode + + +class PrettyPrinter(Printer): + """Printer, which converts an expression into 2D ASCII-art figure.""" + printmethod = "_pretty" + + _default_settings = { + "order": None, + "full_prec": "auto", + "use_unicode": None, + "wrap_line": True, + "num_columns": None, + "use_unicode_sqrt_char": True, + "root_notation": True, + "mat_symbol_style": "plain", + "imaginary_unit": "i", + "perm_cyclic": True + } + + def __init__(self, settings=None): + Printer.__init__(self, settings) + + if not isinstance(self._settings['imaginary_unit'], str): + raise TypeError("'imaginary_unit' must a string, not {}".format(self._settings['imaginary_unit'])) + elif self._settings['imaginary_unit'] not in ("i", "j"): + raise ValueError("'imaginary_unit' must be either 'i' or 'j', not '{}'".format(self._settings['imaginary_unit'])) + + def emptyPrinter(self, expr): + return prettyForm(str(expr)) + + @property + def _use_unicode(self): + if self._settings['use_unicode']: + return True + else: + return pretty_use_unicode() + + def doprint(self, expr): + return self._print(expr).render(**self._settings) + + # empty op so _print(stringPict) returns the same + def _print_stringPict(self, e): + return e + + def _print_basestring(self, e): + return prettyForm(e) + + def _print_atan2(self, e): + pform = prettyForm(*self._print_seq(e.args).parens()) + pform = prettyForm(*pform.left('atan2')) + return pform + + def _print_Symbol(self, e, bold_name=False): + symb = pretty_symbol(e.name, bold_name) + return prettyForm(symb) + _print_RandomSymbol = _print_Symbol + def _print_MatrixSymbol(self, e): + return self._print_Symbol(e, self._settings['mat_symbol_style'] == "bold") + + def _print_Float(self, e): + # we will use StrPrinter's Float printer, but we need to handle the + # full_prec ourselves, according to the self._print_level + full_prec = self._settings["full_prec"] + if full_prec == "auto": + full_prec = self._print_level == 1 + return prettyForm(sstr(e, full_prec=full_prec)) + + def _print_Cross(self, e): + vec1 = e._expr1 + vec2 = e._expr2 + pform = self._print(vec2) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('MULTIPLICATION SIGN')))) + pform = prettyForm(*pform.left(')')) + pform = prettyForm(*pform.left(self._print(vec1))) + pform = prettyForm(*pform.left('(')) + return pform + + def _print_Curl(self, e): + vec = e._expr + pform = self._print(vec) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('MULTIPLICATION SIGN')))) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Divergence(self, e): + vec = e._expr + pform = self._print(vec) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('DOT OPERATOR')))) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Dot(self, e): + vec1 = e._expr1 + vec2 = e._expr2 + pform = self._print(vec2) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('DOT OPERATOR')))) + pform = prettyForm(*pform.left(')')) + pform = prettyForm(*pform.left(self._print(vec1))) + pform = prettyForm(*pform.left('(')) + return pform + + def _print_Gradient(self, e): + func = e._expr + pform = self._print(func) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Laplacian(self, e): + func = e._expr + pform = self._print(func) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('INCREMENT')))) + return pform + + def _print_Atom(self, e): + try: + # print atoms like Exp1 or Pi + return prettyForm(pretty_atom(e.__class__.__name__, printer=self)) + except KeyError: + return self.emptyPrinter(e) + + # Infinity inherits from Number, so we have to override _print_XXX order + _print_Infinity = _print_Atom + _print_NegativeInfinity = _print_Atom + _print_EmptySet = _print_Atom + _print_Naturals = _print_Atom + _print_Naturals0 = _print_Atom + _print_Integers = _print_Atom + _print_Rationals = _print_Atom + _print_Complexes = _print_Atom + + _print_EmptySequence = _print_Atom + + def _print_Reals(self, e): + if self._use_unicode: + return self._print_Atom(e) + else: + inf_list = ['-oo', 'oo'] + return self._print_seq(inf_list, '(', ')') + + def _print_subfactorial(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('!')) + return pform + + def _print_factorial(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right('!')) + return pform + + def _print_factorial2(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right('!!')) + return pform + + def _print_binomial(self, e): + n, k = e.args + + n_pform = self._print(n) + k_pform = self._print(k) + + bar = ' '*max(n_pform.width(), k_pform.width()) + + pform = prettyForm(*k_pform.above(bar)) + pform = prettyForm(*pform.above(n_pform)) + pform = prettyForm(*pform.parens('(', ')')) + + pform.baseline = (pform.baseline + 1)//2 + + return pform + + def _print_Relational(self, e): + op = prettyForm(' ' + xsym(e.rel_op) + ' ') + + l = self._print(e.lhs) + r = self._print(e.rhs) + pform = prettyForm(*stringPict.next(l, op, r), binding=prettyForm.OPEN) + return pform + + def _print_Not(self, e): + from sympy.logic.boolalg import (Equivalent, Implies) + if self._use_unicode: + arg = e.args[0] + pform = self._print(arg) + if isinstance(arg, Equivalent): + return self._print_Equivalent(arg, altchar=pretty_atom('NotEquiv')) + if isinstance(arg, Implies): + return self._print_Implies(arg, altchar=pretty_atom('NotArrow')) + + if arg.is_Boolean and not arg.is_Not: + pform = prettyForm(*pform.parens()) + + return prettyForm(*pform.left(pretty_atom('Not'))) + else: + return self._print_Function(e) + + def __print_Boolean(self, e, char, sort=True): + args = e.args + if sort: + args = sorted(e.args, key=default_sort_key) + arg = args[0] + pform = self._print(arg) + + if arg.is_Boolean and not arg.is_Not: + pform = prettyForm(*pform.parens()) + + for arg in args[1:]: + pform_arg = self._print(arg) + + if arg.is_Boolean and not arg.is_Not: + pform_arg = prettyForm(*pform_arg.parens()) + + pform = prettyForm(*pform.right(' %s ' % char)) + pform = prettyForm(*pform.right(pform_arg)) + + return pform + + def _print_And(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('And')) + else: + return self._print_Function(e, sort=True) + + def _print_Or(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Or')) + else: + return self._print_Function(e, sort=True) + + def _print_Xor(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom("Xor")) + else: + return self._print_Function(e, sort=True) + + def _print_Nand(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Nand')) + else: + return self._print_Function(e, sort=True) + + def _print_Nor(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Nor')) + else: + return self._print_Function(e, sort=True) + + def _print_Implies(self, e, altchar=None): + if self._use_unicode: + return self.__print_Boolean(e, altchar or pretty_atom('Arrow'), sort=False) + else: + return self._print_Function(e) + + def _print_Equivalent(self, e, altchar=None): + if self._use_unicode: + return self.__print_Boolean(e, altchar or pretty_atom('Equiv')) + else: + return self._print_Function(e, sort=True) + + def _print_conjugate(self, e): + pform = self._print(e.args[0]) + return prettyForm( *pform.above( hobj('_', pform.width())) ) + + def _print_Abs(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('|', '|')) + return pform + + def _print_floor(self, e): + if self._use_unicode: + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('lfloor', 'rfloor')) + return pform + else: + return self._print_Function(e) + + def _print_ceiling(self, e): + if self._use_unicode: + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('lceil', 'rceil')) + return pform + else: + return self._print_Function(e) + + def _print_Derivative(self, deriv): + if requires_partial(deriv.expr) and self._use_unicode: + deriv_symbol = U('PARTIAL DIFFERENTIAL') + else: + deriv_symbol = r'd' + x = None + count_total_deriv = 0 + + for sym, num in reversed(deriv.variable_count): + s = self._print(sym) + ds = prettyForm(*s.left(deriv_symbol)) + count_total_deriv += num + + if (not num.is_Integer) or (num > 1): + ds = ds**prettyForm(str(num)) + + if x is None: + x = ds + else: + x = prettyForm(*x.right(' ')) + x = prettyForm(*x.right(ds)) + + f = prettyForm( + binding=prettyForm.FUNC, *self._print(deriv.expr).parens()) + + pform = prettyForm(deriv_symbol) + + if (count_total_deriv > 1) != False: + pform = pform**prettyForm(str(count_total_deriv)) + + pform = prettyForm(*pform.below(stringPict.LINE, x)) + pform.baseline = pform.baseline + 1 + pform = prettyForm(*stringPict.next(pform, f)) + pform.binding = prettyForm.MUL + + return pform + + def _print_Cycle(self, dc): + from sympy.combinatorics.permutations import Permutation, Cycle + # for Empty Cycle + if dc == Cycle(): + cyc = stringPict('') + return prettyForm(*cyc.parens()) + + dc_list = Permutation(dc.list()).cyclic_form + # for Identity Cycle + if dc_list == []: + cyc = self._print(dc.size - 1) + return prettyForm(*cyc.parens()) + + cyc = stringPict('') + for i in dc_list: + l = self._print(str(tuple(i)).replace(',', '')) + cyc = prettyForm(*cyc.right(l)) + return cyc + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation, Cycle + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=7, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + return self._print_Cycle(Cycle(expr)) + + lower = expr.array_form + upper = list(range(len(lower))) + + result = stringPict('') + first = True + for u, l in zip(upper, lower): + s1 = self._print(u) + s2 = self._print(l) + col = prettyForm(*s1.below(s2)) + if first: + first = False + else: + col = prettyForm(*col.left(" ")) + result = prettyForm(*result.right(col)) + return prettyForm(*result.parens()) + + + def _print_Integral(self, integral): + f = integral.function + + # Add parentheses if arg involves addition of terms and + # create a pretty form for the argument + prettyF = self._print(f) + # XXX generalize parens + if f.is_Add: + prettyF = prettyForm(*prettyF.parens()) + + # dx dy dz ... + arg = prettyF + for x in integral.limits: + prettyArg = self._print(x[0]) + # XXX qparens (parens if needs-parens) + if prettyArg.width() > 1: + prettyArg = prettyForm(*prettyArg.parens()) + + arg = prettyForm(*arg.right(' d', prettyArg)) + + # \int \int \int ... + firstterm = True + s = None + for lim in integral.limits: + # Create bar based on the height of the argument + h = arg.height() + H = h + 2 + + # XXX hack! + ascii_mode = not self._use_unicode + if ascii_mode: + H += 2 + + vint = vobj('int', H) + + # Construct the pretty form with the integral sign and the argument + pform = prettyForm(vint) + pform.baseline = arg.baseline + ( + H - h)//2 # covering the whole argument + + if len(lim) > 1: + # Create pretty forms for endpoints, if definite integral. + # Do not print empty endpoints. + if len(lim) == 2: + prettyA = prettyForm("") + prettyB = self._print(lim[1]) + if len(lim) == 3: + prettyA = self._print(lim[1]) + prettyB = self._print(lim[2]) + + if ascii_mode: # XXX hack + # Add spacing so that endpoint can more easily be + # identified with the correct integral sign + spc = max(1, 3 - prettyB.width()) + prettyB = prettyForm(*prettyB.left(' ' * spc)) + + spc = max(1, 4 - prettyA.width()) + prettyA = prettyForm(*prettyA.right(' ' * spc)) + + pform = prettyForm(*pform.above(prettyB)) + pform = prettyForm(*pform.below(prettyA)) + + if not ascii_mode: # XXX hack + pform = prettyForm(*pform.right(' ')) + + if firstterm: + s = pform # first term + firstterm = False + else: + s = prettyForm(*s.left(pform)) + + pform = prettyForm(*arg.left(s)) + pform.binding = prettyForm.MUL + return pform + + def _print_Product(self, expr): + func = expr.term + pretty_func = self._print(func) + + horizontal_chr = xobj('_', 1) + corner_chr = xobj('_', 1) + vertical_chr = xobj('|', 1) + + if self._use_unicode: + # use unicode corners + horizontal_chr = xobj('-', 1) + corner_chr = xobj('UpTack', 1) + + func_height = pretty_func.height() + + first = True + max_upper = 0 + sign_height = 0 + + for lim in expr.limits: + pretty_lower, pretty_upper = self.__print_SumProduct_Limits(lim) + + width = (func_height + 2) * 5 // 3 - 2 + sign_lines = [horizontal_chr + corner_chr + (horizontal_chr * (width-2)) + corner_chr + horizontal_chr] + for _ in range(func_height + 1): + sign_lines.append(' ' + vertical_chr + (' ' * (width-2)) + vertical_chr + ' ') + + pretty_sign = stringPict('') + pretty_sign = prettyForm(*pretty_sign.stack(*sign_lines)) + + + max_upper = max(max_upper, pretty_upper.height()) + + if first: + sign_height = pretty_sign.height() + + pretty_sign = prettyForm(*pretty_sign.above(pretty_upper)) + pretty_sign = prettyForm(*pretty_sign.below(pretty_lower)) + + if first: + pretty_func.baseline = 0 + first = False + + height = pretty_sign.height() + padding = stringPict('') + padding = prettyForm(*padding.stack(*[' ']*(height - 1))) + pretty_sign = prettyForm(*pretty_sign.right(padding)) + + pretty_func = prettyForm(*pretty_sign.right(pretty_func)) + + pretty_func.baseline = max_upper + sign_height//2 + pretty_func.binding = prettyForm.MUL + return pretty_func + + def __print_SumProduct_Limits(self, lim): + def print_start(lhs, rhs): + op = prettyForm(' ' + xsym("==") + ' ') + l = self._print(lhs) + r = self._print(rhs) + pform = prettyForm(*stringPict.next(l, op, r)) + return pform + + prettyUpper = self._print(lim[2]) + prettyLower = print_start(lim[0], lim[1]) + return prettyLower, prettyUpper + + def _print_Sum(self, expr): + ascii_mode = not self._use_unicode + + def asum(hrequired, lower, upper, use_ascii): + def adjust(s, wid=None, how='<^>'): + if not wid or len(s) > wid: + return s + need = wid - len(s) + if how in ('<^>', "<") or how not in list('<^>'): + return s + ' '*need + half = need//2 + lead = ' '*half + if how == ">": + return " "*need + s + return lead + s + ' '*(need - len(lead)) + + h = max(hrequired, 2) + d = h//2 + w = d + 1 + more = hrequired % 2 + + lines = [] + if use_ascii: + lines.append("_"*(w) + ' ') + lines.append(r"\%s`" % (' '*(w - 1))) + for i in range(1, d): + lines.append('%s\\%s' % (' '*i, ' '*(w - i))) + if more: + lines.append('%s)%s' % (' '*(d), ' '*(w - d))) + for i in reversed(range(1, d)): + lines.append('%s/%s' % (' '*i, ' '*(w - i))) + lines.append("/" + "_"*(w - 1) + ',') + return d, h + more, lines, more + else: + w = w + more + d = d + more + vsum = vobj('sum', 4) + lines.append("_"*(w)) + for i in range(0, d): + lines.append('%s%s%s' % (' '*i, vsum[2], ' '*(w - i - 1))) + for i in reversed(range(0, d)): + lines.append('%s%s%s' % (' '*i, vsum[4], ' '*(w - i - 1))) + lines.append(vsum[8]*(w)) + return d, h + 2*more, lines, more + + f = expr.function + + prettyF = self._print(f) + + if f.is_Add: # add parens + prettyF = prettyForm(*prettyF.parens()) + + H = prettyF.height() + 2 + + # \sum \sum \sum ... + first = True + max_upper = 0 + sign_height = 0 + + for lim in expr.limits: + prettyLower, prettyUpper = self.__print_SumProduct_Limits(lim) + + max_upper = max(max_upper, prettyUpper.height()) + + # Create sum sign based on the height of the argument + d, h, slines, adjustment = asum( + H, prettyLower.width(), prettyUpper.width(), ascii_mode) + prettySign = stringPict('') + prettySign = prettyForm(*prettySign.stack(*slines)) + + if first: + sign_height = prettySign.height() + + prettySign = prettyForm(*prettySign.above(prettyUpper)) + prettySign = prettyForm(*prettySign.below(prettyLower)) + + if first: + # change F baseline so it centers on the sign + prettyF.baseline -= d - (prettyF.height()//2 - + prettyF.baseline) + first = False + + # put padding to the right + pad = stringPict('') + pad = prettyForm(*pad.stack(*[' ']*h)) + prettySign = prettyForm(*prettySign.right(pad)) + # put the present prettyF to the right + prettyF = prettyForm(*prettySign.right(prettyF)) + + # adjust baseline of ascii mode sigma with an odd height so that it is + # exactly through the center + ascii_adjustment = ascii_mode if not adjustment else 0 + prettyF.baseline = max_upper + sign_height//2 + ascii_adjustment + + prettyF.binding = prettyForm.MUL + return prettyF + + def _print_Limit(self, l): + e, z, z0, dir = l.args + + E = self._print(e) + if precedence(e) <= PRECEDENCE["Mul"]: + E = prettyForm(*E.parens('(', ')')) + Lim = prettyForm('lim') + + LimArg = self._print(z) + if self._use_unicode: + LimArg = prettyForm(*LimArg.right(f"{xobj('-', 1)}{pretty_atom('Arrow')}")) + else: + LimArg = prettyForm(*LimArg.right('->')) + LimArg = prettyForm(*LimArg.right(self._print(z0))) + + if str(dir) == '+-' or z0 in (S.Infinity, S.NegativeInfinity): + dir = "" + else: + if self._use_unicode: + dir = pretty_atom('SuperscriptPlus') if str(dir) == "+" else pretty_atom('SuperscriptMinus') + + LimArg = prettyForm(*LimArg.right(self._print(dir))) + + Lim = prettyForm(*Lim.below(LimArg)) + Lim = prettyForm(*Lim.right(E), binding=prettyForm.MUL) + + return Lim + + def _print_matrix_contents(self, e): + """ + This method factors out what is essentially grid printing. + """ + M = e # matrix + Ms = {} # i,j -> pretty(M[i,j]) + for i in range(M.rows): + for j in range(M.cols): + Ms[i, j] = self._print(M[i, j]) + + # h- and v- spacers + hsep = 2 + vsep = 1 + + # max width for columns + maxw = [-1] * M.cols + + for j in range(M.cols): + maxw[j] = max([Ms[i, j].width() for i in range(M.rows)] or [0]) + + # drawing result + D = None + + for i in range(M.rows): + + D_row = None + for j in range(M.cols): + s = Ms[i, j] + + # reshape s to maxw + # XXX this should be generalized, and go to stringPict.reshape ? + assert s.width() <= maxw[j] + + # hcenter it, +0.5 to the right 2 + # ( it's better to align formula starts for say 0 and r ) + # XXX this is not good in all cases -- maybe introduce vbaseline? + left, right = center_pad(s.width(), maxw[j]) + + s = prettyForm(*s.right(right)) + s = prettyForm(*s.left(left)) + + # we don't need vcenter cells -- this is automatically done in + # a pretty way because when their baselines are taking into + # account in .right() + + if D_row is None: + D_row = s # first box in a row + continue + + D_row = prettyForm(*D_row.right(' '*hsep)) # h-spacer + D_row = prettyForm(*D_row.right(s)) + + if D is None: + D = D_row # first row in a picture + continue + + # v-spacer + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + + D = prettyForm(*D.below(D_row)) + + if D is None: + D = prettyForm('') # Empty Matrix + + return D + + def _print_MatrixBase(self, e, lparens='[', rparens=']'): + D = self._print_matrix_contents(e) + D.baseline = D.height()//2 + D = prettyForm(*D.parens(lparens, rparens)) + return D + + def _print_Determinant(self, e): + mat = e.arg + if mat.is_MatrixExpr: + from sympy.matrices.expressions.blockmatrix import BlockMatrix + if isinstance(mat, BlockMatrix): + return self._print_MatrixBase(mat.blocks, lparens='|', rparens='|') + D = self._print(mat) + D.baseline = D.height()//2 + return prettyForm(*D.parens('|', '|')) + else: + return self._print_MatrixBase(mat, lparens='|', rparens='|') + + def _print_TensorProduct(self, expr): + # This should somehow share the code with _print_WedgeProduct: + if self._use_unicode: + circled_times = "\u2297" + else: + circled_times = ".*" + return self._print_seq(expr.args, None, None, circled_times, + parenthesize=lambda x: precedence_traditional(x) <= PRECEDENCE["Mul"]) + + def _print_WedgeProduct(self, expr): + # This should somehow share the code with _print_TensorProduct: + if self._use_unicode: + wedge_symbol = "\u2227" + else: + wedge_symbol = '/\\' + return self._print_seq(expr.args, None, None, wedge_symbol, + parenthesize=lambda x: precedence_traditional(x) <= PRECEDENCE["Mul"]) + + def _print_Trace(self, e): + D = self._print(e.arg) + D = prettyForm(*D.parens('(',')')) + D.baseline = D.height()//2 + D = prettyForm(*D.left('\n'*(0) + 'tr')) + return D + + + def _print_MatrixElement(self, expr): + from sympy.matrices import MatrixSymbol + if (isinstance(expr.parent, MatrixSymbol) + and expr.i.is_number and expr.j.is_number): + return self._print( + Symbol(expr.parent.name + '_%d%d' % (expr.i, expr.j))) + else: + prettyFunc = self._print(expr.parent) + prettyFunc = prettyForm(*prettyFunc.parens()) + prettyIndices = self._print_seq((expr.i, expr.j), delimiter=', ' + ).parens(left='[', right=']')[0] + pform = prettyForm(binding=prettyForm.FUNC, + *stringPict.next(prettyFunc, prettyIndices)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyIndices + + return pform + + + def _print_MatrixSlice(self, m): + # XXX works only for applied functions + from sympy.matrices import MatrixSymbol + prettyFunc = self._print(m.parent) + if not isinstance(m.parent, MatrixSymbol): + prettyFunc = prettyForm(*prettyFunc.parens()) + def ppslice(x, dim): + x = list(x) + if x[2] == 1: + del x[2] + if x[0] == 0: + x[0] = '' + if x[1] == dim: + x[1] = '' + return prettyForm(*self._print_seq(x, delimiter=':')) + prettyArgs = self._print_seq((ppslice(m.rowslice, m.parent.rows), + ppslice(m.colslice, m.parent.cols)), delimiter=', ').parens(left='[', right=']')[0] + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_Transpose(self, expr): + mat = expr.arg + pform = self._print(mat) + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + pform = prettyForm(*pform.parens()) + pform = pform**(prettyForm('T')) + return pform + + def _print_Adjoint(self, expr): + mat = expr.arg + pform = self._print(mat) + if self._use_unicode: + dag = prettyForm(pretty_atom('Dagger')) + else: + dag = prettyForm('+') + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + pform = prettyForm(*pform.parens()) + pform = pform**dag + return pform + + def _print_BlockMatrix(self, B): + if B.blocks.shape == (1, 1): + return self._print(B.blocks[0, 0]) + return self._print(B.blocks) + + def _print_MatAdd(self, expr): + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + coeff = item.as_coeff_mmul()[0] + if S(coeff).could_extract_minus_sign(): + s = prettyForm(*stringPict.next(s, ' ')) + pform = self._print(item) + else: + s = prettyForm(*stringPict.next(s, ' + ')) + s = prettyForm(*stringPict.next(s, pform)) + + return s + + def _print_MatMul(self, expr): + args = list(expr.args) + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.kronecker import KroneckerProduct + from sympy.matrices.expressions.matadd import MatAdd + for i, a in enumerate(args): + if (isinstance(a, (Add, MatAdd, HadamardProduct, KroneckerProduct)) + and len(expr.args) > 1): + args[i] = prettyForm(*self._print(a).parens()) + else: + args[i] = self._print(a) + + return prettyForm.__mul__(*args) + + def _print_Identity(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('IdentityMatrix')) + else: + return prettyForm('I') + + def _print_ZeroMatrix(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('ZeroMatrix')) + else: + return prettyForm('0') + + def _print_OneMatrix(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom("OneMatrix")) + else: + return prettyForm('1') + + def _print_DotProduct(self, expr): + args = list(expr.args) + + for i, a in enumerate(args): + args[i] = self._print(a) + return prettyForm.__mul__(*args) + + def _print_MatPow(self, expr): + pform = self._print(expr.base) + from sympy.matrices import MatrixSymbol + if not isinstance(expr.base, MatrixSymbol) and expr.base.is_MatrixExpr: + pform = prettyForm(*pform.parens()) + pform = pform**(self._print(expr.exp)) + return pform + + def _print_HadamardProduct(self, expr): + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.matadd import MatAdd + from sympy.matrices.expressions.matmul import MatMul + if self._use_unicode: + delim = pretty_atom('Ring') + else: + delim = '.*' + return self._print_seq(expr.args, None, None, delim, + parenthesize=lambda x: isinstance(x, (MatAdd, MatMul, HadamardProduct))) + + def _print_HadamardPower(self, expr): + # from sympy import MatAdd, MatMul + if self._use_unicode: + circ = pretty_atom('Ring') + else: + circ = self._print('.') + pretty_base = self._print(expr.base) + pretty_exp = self._print(expr.exp) + if precedence(expr.exp) < PRECEDENCE["Mul"]: + pretty_exp = prettyForm(*pretty_exp.parens()) + pretty_circ_exp = prettyForm( + binding=prettyForm.LINE, + *stringPict.next(circ, pretty_exp) + ) + return pretty_base**pretty_circ_exp + + def _print_KroneckerProduct(self, expr): + from sympy.matrices.expressions.matadd import MatAdd + from sympy.matrices.expressions.matmul import MatMul + if self._use_unicode: + delim = f" {pretty_atom('TensorProduct')} " + else: + delim = ' x ' + return self._print_seq(expr.args, None, None, delim, + parenthesize=lambda x: isinstance(x, (MatAdd, MatMul))) + + def _print_FunctionMatrix(self, X): + D = self._print(X.lamda.expr) + D = prettyForm(*D.parens('[', ']')) + return D + + def _print_TransferFunction(self, expr): + if not expr.num == 1: + num, den = expr.num, expr.den + res = Mul(num, Pow(den, -1, evaluate=False), evaluate=False) + return self._print_Mul(res) + else: + return self._print(1)/self._print(expr.den) + + def _print_Series(self, expr): + args = list(expr.args) + for i, a in enumerate(expr.args): + args[i] = prettyForm(*self._print(a).parens()) + return prettyForm.__mul__(*args) + + def _print_MIMOSeries(self, expr): + from sympy.physics.control.lti import MIMOParallel + args = list(expr.args) + pretty_args = [] + for a in reversed(args): + if (isinstance(a, MIMOParallel) and len(expr.args) > 1): + expression = self._print(a) + expression.baseline = expression.height()//2 + pretty_args.append(prettyForm(*expression.parens())) + else: + expression = self._print(a) + expression.baseline = expression.height()//2 + pretty_args.append(expression) + return prettyForm.__mul__(*pretty_args) + + def _print_Parallel(self, expr): + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + s = prettyForm(*stringPict.next(s)) + s.baseline = s.height()//2 + s = prettyForm(*stringPict.next(s, ' + ')) + s = prettyForm(*stringPict.next(s, pform)) + return s + + def _print_MIMOParallel(self, expr): + from sympy.physics.control.lti import TransferFunctionMatrix + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + s = prettyForm(*stringPict.next(s)) + s.baseline = s.height()//2 + s = prettyForm(*stringPict.next(s, ' + ')) + if isinstance(item, TransferFunctionMatrix): + s.baseline = s.height() - 1 + s = prettyForm(*stringPict.next(s, pform)) + # s.baseline = s.height()//2 + return s + + def _print_Feedback(self, expr): + from sympy.physics.control import TransferFunction, Series + + num, tf = expr.sys1, TransferFunction(1, 1, expr.var) + num_arg_list = list(num.args) if isinstance(num, Series) else [num] + den_arg_list = list(expr.sys2.args) if \ + isinstance(expr.sys2, Series) else [expr.sys2] + + if isinstance(num, Series) and isinstance(expr.sys2, Series): + den = Series(*num_arg_list, *den_arg_list) + elif isinstance(num, Series) and isinstance(expr.sys2, TransferFunction): + if expr.sys2 == tf: + den = Series(*num_arg_list) + else: + den = Series(*num_arg_list, expr.sys2) + elif isinstance(num, TransferFunction) and isinstance(expr.sys2, Series): + if num == tf: + den = Series(*den_arg_list) + else: + den = Series(num, *den_arg_list) + else: + if num == tf: + den = Series(*den_arg_list) + elif expr.sys2 == tf: + den = Series(*num_arg_list) + else: + den = Series(*num_arg_list, *den_arg_list) + + denom = prettyForm(*stringPict.next(self._print(tf))) + denom.baseline = denom.height()//2 + denom = prettyForm(*stringPict.next(denom, ' + ')) if expr.sign == -1 \ + else prettyForm(*stringPict.next(denom, ' - ')) + denom = prettyForm(*stringPict.next(denom, self._print(den))) + + return self._print(num)/denom + + def _print_MIMOFeedback(self, expr): + from sympy.physics.control import MIMOSeries, TransferFunctionMatrix + + inv_mat = self._print(MIMOSeries(expr.sys2, expr.sys1)) + plant = self._print(expr.sys1) + _feedback = prettyForm(*stringPict.next(inv_mat)) + _feedback = prettyForm(*stringPict.right("I + ", _feedback)) if expr.sign == -1 \ + else prettyForm(*stringPict.right("I - ", _feedback)) + _feedback = prettyForm(*stringPict.parens(_feedback)) + _feedback.baseline = 0 + _feedback = prettyForm(*stringPict.right(_feedback, '-1 ')) + _feedback.baseline = _feedback.height()//2 + _feedback = prettyForm.__mul__(_feedback, prettyForm(" ")) + if isinstance(expr.sys1, TransferFunctionMatrix): + _feedback.baseline = _feedback.height() - 1 + _feedback = prettyForm(*stringPict.next(_feedback, plant)) + return _feedback + + def _print_TransferFunctionMatrix(self, expr): + mat = self._print(expr._expr_mat) + mat.baseline = mat.height() - 1 + subscript = greek_unicode['tau'] if self._use_unicode else r'{t}' + mat = prettyForm(*mat.right(subscript)) + return mat + + def _print_StateSpace(self, expr): + from sympy.matrices.expressions.blockmatrix import BlockMatrix + A = expr._A + B = expr._B + C = expr._C + D = expr._D + mat = BlockMatrix([[A, B], [C, D]]) + return self._print(mat.blocks) + + def _print_BasisDependent(self, expr): + from sympy.vector import Vector + + if not self._use_unicode: + raise NotImplementedError("ASCII pretty printing of BasisDependent is not implemented") + + if expr == expr.zero: + return prettyForm(expr.zero._pretty_form) + o1 = [] + vectstrs = [] + if isinstance(expr, Vector): + items = expr.separate().items() + else: + items = [(0, expr)] + for system, vect in items: + inneritems = list(vect.components.items()) + inneritems.sort(key = lambda x: x[0].__str__()) + for k, v in inneritems: + #if the coef of the basis vector is 1 + #we skip the 1 + if v == 1: + o1.append("" + + k._pretty_form) + #Same for -1 + elif v == -1: + o1.append("(-1) " + + k._pretty_form) + #For a general expr + else: + #We always wrap the measure numbers in + #parentheses + arg_str = self._print( + v).parens()[0] + + o1.append(arg_str + ' ' + k._pretty_form) + vectstrs.append(k._pretty_form) + + #outstr = u("").join(o1) + if o1[0].startswith(" + "): + o1[0] = o1[0][3:] + elif o1[0].startswith(" "): + o1[0] = o1[0][1:] + #Fixing the newlines + lengths = [] + strs = [''] + flag = [] + for i, partstr in enumerate(o1): + flag.append(0) + # XXX: What is this hack? + if '\n' in partstr: + tempstr = partstr + tempstr = tempstr.replace(vectstrs[i], '') + if xobj(')_ext', 1) in tempstr: # If scalar is a fraction + for paren in range(len(tempstr)): + flag[i] = 1 + if tempstr[paren] == xobj(')_ext', 1) and tempstr[paren + 1] == '\n': + # We want to place the vector string after all the right parentheses, because + # otherwise, the vector will be in the middle of the string + tempstr = tempstr[:paren] + xobj(')_ext', 1)\ + + ' ' + vectstrs[i] + tempstr[paren + 1:] + break + elif xobj(')_lower_hook', 1) in tempstr: + # We want to place the vector string after all the right parentheses, because + # otherwise, the vector will be in the middle of the string. For this reason, + # we insert the vector string at the rightmost index. + index = tempstr.rfind(xobj(')_lower_hook', 1)) + if index != -1: # then this character was found in this string + flag[i] = 1 + tempstr = tempstr[:index] + xobj(')_lower_hook', 1)\ + + ' ' + vectstrs[i] + tempstr[index + 1:] + o1[i] = tempstr + + o1 = [x.split('\n') for x in o1] + n_newlines = max(len(x) for x in o1) # Width of part in its pretty form + + if 1 in flag: # If there was a fractional scalar + for i, parts in enumerate(o1): + if len(parts) == 1: # If part has no newline + parts.insert(0, ' ' * (len(parts[0]))) + flag[i] = 1 + + for i, parts in enumerate(o1): + lengths.append(len(parts[flag[i]])) + for j in range(n_newlines): + if j+1 <= len(parts): + if j >= len(strs): + strs.append(' ' * (sum(lengths[:-1]) + + 3*(len(lengths)-1))) + if j == flag[i]: + strs[flag[i]] += parts[flag[i]] + ' + ' + else: + strs[j] += parts[j] + ' '*(lengths[-1] - + len(parts[j])+ + 3) + else: + if j >= len(strs): + strs.append(' ' * (sum(lengths[:-1]) + + 3*(len(lengths)-1))) + strs[j] += ' '*(lengths[-1]+3) + + return prettyForm('\n'.join([s[:-3] for s in strs])) + + def _print_NDimArray(self, expr): + from sympy.matrices.immutable import ImmutableMatrix + + if expr.rank() == 0: + return self._print(expr[()]) + + level_str = [[]] + [[] for i in range(expr.rank())] + shape_ranges = [list(range(i)) for i in expr.shape] + # leave eventual matrix elements unflattened + mat = lambda x: ImmutableMatrix(x, evaluate=False) + for outer_i in itertools.product(*shape_ranges): + level_str[-1].append(expr[outer_i]) + even = True + for back_outer_i in range(expr.rank()-1, -1, -1): + if len(level_str[back_outer_i+1]) < expr.shape[back_outer_i]: + break + if even: + level_str[back_outer_i].append(level_str[back_outer_i+1]) + else: + level_str[back_outer_i].append(mat( + level_str[back_outer_i+1])) + if len(level_str[back_outer_i + 1]) == 1: + level_str[back_outer_i][-1] = mat( + [[level_str[back_outer_i][-1]]]) + even = not even + level_str[back_outer_i+1] = [] + + out_expr = level_str[0][0] + if expr.rank() % 2 == 1: + out_expr = mat([out_expr]) + + return self._print(out_expr) + + def _printer_tensor_indices(self, name, indices, index_map={}): + center = stringPict(name) + top = stringPict(" "*center.width()) + bot = stringPict(" "*center.width()) + + last_valence = None + prev_map = None + + for index in indices: + indpic = self._print(index.args[0]) + if ((index in index_map) or prev_map) and last_valence == index.is_up: + if index.is_up: + top = prettyForm(*stringPict.next(top, ",")) + else: + bot = prettyForm(*stringPict.next(bot, ",")) + if index in index_map: + indpic = prettyForm(*stringPict.next(indpic, "=")) + indpic = prettyForm(*stringPict.next(indpic, self._print(index_map[index]))) + prev_map = True + else: + prev_map = False + if index.is_up: + top = stringPict(*top.right(indpic)) + center = stringPict(*center.right(" "*indpic.width())) + bot = stringPict(*bot.right(" "*indpic.width())) + else: + bot = stringPict(*bot.right(indpic)) + center = stringPict(*center.right(" "*indpic.width())) + top = stringPict(*top.right(" "*indpic.width())) + last_valence = index.is_up + + pict = prettyForm(*center.above(top)) + pict = prettyForm(*pict.below(bot)) + return pict + + def _print_Tensor(self, expr): + name = expr.args[0].name + indices = expr.get_indices() + return self._printer_tensor_indices(name, indices) + + def _print_TensorElement(self, expr): + name = expr.expr.args[0].name + indices = expr.expr.get_indices() + index_map = expr.index_map + return self._printer_tensor_indices(name, indices, index_map) + + def _print_TensMul(self, expr): + sign, args = expr._get_args_for_traditional_printer() + args = [ + prettyForm(*self._print(i).parens()) if + precedence_traditional(i) < PRECEDENCE["Mul"] else self._print(i) + for i in args + ] + pform = prettyForm.__mul__(*args) + if sign: + return prettyForm(*pform.left(sign)) + else: + return pform + + def _print_TensAdd(self, expr): + args = [ + prettyForm(*self._print(i).parens()) if + precedence_traditional(i) < PRECEDENCE["Mul"] else self._print(i) + for i in expr.args + ] + return prettyForm.__add__(*args) + + def _print_TensorIndex(self, expr): + sym = expr.args[0] + if not expr.is_up: + sym = -sym + return self._print(sym) + + def _print_PartialDerivative(self, deriv): + if self._use_unicode: + deriv_symbol = U('PARTIAL DIFFERENTIAL') + else: + deriv_symbol = r'd' + x = None + + for variable in reversed(deriv.variables): + s = self._print(variable) + ds = prettyForm(*s.left(deriv_symbol)) + + if x is None: + x = ds + else: + x = prettyForm(*x.right(' ')) + x = prettyForm(*x.right(ds)) + + f = prettyForm( + binding=prettyForm.FUNC, *self._print(deriv.expr).parens()) + + pform = prettyForm(deriv_symbol) + + if len(deriv.variables) > 1: + pform = pform**self._print(len(deriv.variables)) + + pform = prettyForm(*pform.below(stringPict.LINE, x)) + pform.baseline = pform.baseline + 1 + pform = prettyForm(*stringPict.next(pform, f)) + pform.binding = prettyForm.MUL + + return pform + + def _print_Piecewise(self, pexpr): + + P = {} + for n, ec in enumerate(pexpr.args): + P[n, 0] = self._print(ec.expr) + if ec.cond == True: + P[n, 1] = prettyForm('otherwise') + else: + P[n, 1] = prettyForm( + *prettyForm('for ').right(self._print(ec.cond))) + hsep = 2 + vsep = 1 + len_args = len(pexpr.args) + + # max widths + maxw = [max(P[i, j].width() for i in range(len_args)) + for j in range(2)] + + # FIXME: Refactor this code and matrix into some tabular environment. + # drawing result + D = None + + for i in range(len_args): + D_row = None + for j in range(2): + p = P[i, j] + assert p.width() <= maxw[j] + + wdelta = maxw[j] - p.width() + wleft = wdelta // 2 + wright = wdelta - wleft + + p = prettyForm(*p.right(' '*wright)) + p = prettyForm(*p.left(' '*wleft)) + + if D_row is None: + D_row = p + continue + + D_row = prettyForm(*D_row.right(' '*hsep)) # h-spacer + D_row = prettyForm(*D_row.right(p)) + if D is None: + D = D_row # first row in a picture + continue + + # v-spacer + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + + D = prettyForm(*D.below(D_row)) + + D = prettyForm(*D.parens('{', '')) + D.baseline = D.height()//2 + D.binding = prettyForm.OPEN + return D + + def _print_ITE(self, ite): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(ite.rewrite(Piecewise)) + + def _hprint_vec(self, v): + D = None + + for a in v: + p = a + if D is None: + D = p + else: + D = prettyForm(*D.right(', ')) + D = prettyForm(*D.right(p)) + if D is None: + D = stringPict(' ') + + return D + + def _hprint_vseparator(self, p1, p2, left=None, right=None, delimiter='', ifascii_nougly=False): + if ifascii_nougly and not self._use_unicode: + return self._print_seq((p1, '|', p2), left=left, right=right, + delimiter=delimiter, ifascii_nougly=True) + tmp = self._print_seq((p1, p2,), left=left, right=right, delimiter=delimiter) + sep = stringPict(vobj('|', tmp.height()), baseline=tmp.baseline) + return self._print_seq((p1, sep, p2), left=left, right=right, + delimiter=delimiter) + + def _print_hyper(self, e): + # FIXME refactor Matrix, Piecewise, and this into a tabular environment + ap = [self._print(a) for a in e.ap] + bq = [self._print(b) for b in e.bq] + + P = self._print(e.argument) + P.baseline = P.height()//2 + + # Drawing result - first create the ap, bq vectors + D = None + for v in [ap, bq]: + D_row = self._hprint_vec(v) + if D is None: + D = D_row # first row in a picture + else: + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + + # make sure that the argument `z' is centred vertically + D.baseline = D.height()//2 + + # insert horizontal separator + P = prettyForm(*P.left(' ')) + D = prettyForm(*D.right(' ')) + + # insert separating `|` + D = self._hprint_vseparator(D, P) + + # add parens + D = prettyForm(*D.parens('(', ')')) + + # create the F symbol + above = D.height()//2 - 1 + below = D.height() - above - 1 + + sz, t, b, add, img = annotated('F') + F = prettyForm('\n' * (above - t) + img + '\n' * (below - b), + baseline=above + sz) + add = (sz + 1)//2 + + F = prettyForm(*F.left(self._print(len(e.ap)))) + F = prettyForm(*F.right(self._print(len(e.bq)))) + F.baseline = above + add + + D = prettyForm(*F.right(' ', D)) + + return D + + def _print_meijerg(self, e): + # FIXME refactor Matrix, Piecewise, and this into a tabular environment + + v = {} + v[(0, 0)] = [self._print(a) for a in e.an] + v[(0, 1)] = [self._print(a) for a in e.aother] + v[(1, 0)] = [self._print(b) for b in e.bm] + v[(1, 1)] = [self._print(b) for b in e.bother] + + P = self._print(e.argument) + P.baseline = P.height()//2 + + vp = {} + for idx in v: + vp[idx] = self._hprint_vec(v[idx]) + + for i in range(2): + maxw = max(vp[(0, i)].width(), vp[(1, i)].width()) + for j in range(2): + s = vp[(j, i)] + left = (maxw - s.width()) // 2 + right = maxw - left - s.width() + s = prettyForm(*s.left(' ' * left)) + s = prettyForm(*s.right(' ' * right)) + vp[(j, i)] = s + + D1 = prettyForm(*vp[(0, 0)].right(' ', vp[(0, 1)])) + D1 = prettyForm(*D1.below(' ')) + D2 = prettyForm(*vp[(1, 0)].right(' ', vp[(1, 1)])) + D = prettyForm(*D1.below(D2)) + + # make sure that the argument `z' is centred vertically + D.baseline = D.height()//2 + + # insert horizontal separator + P = prettyForm(*P.left(' ')) + D = prettyForm(*D.right(' ')) + + # insert separating `|` + D = self._hprint_vseparator(D, P) + + # add parens + D = prettyForm(*D.parens('(', ')')) + + # create the G symbol + above = D.height()//2 - 1 + below = D.height() - above - 1 + + sz, t, b, add, img = annotated('G') + F = prettyForm('\n' * (above - t) + img + '\n' * (below - b), + baseline=above + sz) + + pp = self._print(len(e.ap)) + pq = self._print(len(e.bq)) + pm = self._print(len(e.bm)) + pn = self._print(len(e.an)) + + def adjust(p1, p2): + diff = p1.width() - p2.width() + if diff == 0: + return p1, p2 + elif diff > 0: + return p1, prettyForm(*p2.left(' '*diff)) + else: + return prettyForm(*p1.left(' '*-diff)), p2 + pp, pm = adjust(pp, pm) + pq, pn = adjust(pq, pn) + pu = prettyForm(*pm.right(', ', pn)) + pl = prettyForm(*pp.right(', ', pq)) + + ht = F.baseline - above - 2 + if ht > 0: + pu = prettyForm(*pu.below('\n'*ht)) + p = prettyForm(*pu.below(pl)) + + F.baseline = above + F = prettyForm(*F.right(p)) + + F.baseline = above + add + + D = prettyForm(*F.right(' ', D)) + + return D + + def _print_ExpBase(self, e): + # TODO should exp_polar be printed differently? + # what about exp_polar(0), exp_polar(1)? + base = prettyForm(pretty_atom('Exp1', 'e')) + return base ** self._print(e.args[0]) + + def _print_Exp1(self, e): + return prettyForm(pretty_atom('Exp1', 'e')) + + def _print_Function(self, e, sort=False, func_name=None, left='(', + right=')'): + # optional argument func_name for supplying custom names + # XXX works only for applied functions + return self._helper_print_function(e.func, e.args, sort=sort, func_name=func_name, left=left, right=right) + + def _print_mathieuc(self, e): + return self._print_Function(e, func_name='C') + + def _print_mathieus(self, e): + return self._print_Function(e, func_name='S') + + def _print_mathieucprime(self, e): + return self._print_Function(e, func_name="C'") + + def _print_mathieusprime(self, e): + return self._print_Function(e, func_name="S'") + + def _helper_print_function(self, func, args, sort=False, func_name=None, + delimiter=', ', elementwise=False, left='(', + right=')'): + if sort: + args = sorted(args, key=default_sort_key) + + if not func_name and hasattr(func, "__name__"): + func_name = func.__name__ + + if func_name: + prettyFunc = self._print(Symbol(func_name)) + else: + prettyFunc = prettyForm(*self._print(func).parens()) + + if elementwise: + if self._use_unicode: + circ = pretty_atom('Modifier Letter Low Ring') + else: + circ = '.' + circ = self._print(circ) + prettyFunc = prettyForm( + binding=prettyForm.LINE, + *stringPict.next(prettyFunc, circ) + ) + + prettyArgs = prettyForm(*self._print_seq(args, delimiter=delimiter).parens( + left=left, right=right)) + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_ElementwiseApplyFunction(self, e): + func = e.function + arg = e.expr + args = [arg] + return self._helper_print_function(func, args, delimiter="", elementwise=True) + + @property + def _special_function_classes(self): + from sympy.functions.special.tensor_functions import KroneckerDelta + from sympy.functions.special.gamma_functions import gamma, lowergamma + from sympy.functions.special.zeta_functions import lerchphi + from sympy.functions.special.beta_functions import beta + from sympy.functions.special.delta_functions import DiracDelta + from sympy.functions.special.error_functions import Chi + return {KroneckerDelta: [greek_unicode['delta'], 'delta'], + gamma: [greek_unicode['Gamma'], 'Gamma'], + lerchphi: [greek_unicode['Phi'], 'lerchphi'], + lowergamma: [greek_unicode['gamma'], 'gamma'], + beta: [greek_unicode['Beta'], 'B'], + DiracDelta: [greek_unicode['delta'], 'delta'], + Chi: ['Chi', 'Chi']} + + def _print_FunctionClass(self, expr): + for cls in self._special_function_classes: + if issubclass(expr, cls) and expr.__name__ == cls.__name__: + if self._use_unicode: + return prettyForm(self._special_function_classes[cls][0]) + else: + return prettyForm(self._special_function_classes[cls][1]) + func_name = expr.__name__ + return prettyForm(pretty_symbol(func_name)) + + def _print_GeometryEntity(self, expr): + # GeometryEntity is based on Tuple but should not print like a Tuple + return self.emptyPrinter(expr) + + def _print_polylog(self, e): + subscript = self._print(e.args[0]) + if self._use_unicode and is_subscriptable_in_unicode(subscript): + return self._print_Function(Function('Li_%s' % subscript)(e.args[1])) + return self._print_Function(e) + + def _print_lerchphi(self, e): + func_name = greek_unicode['Phi'] if self._use_unicode else 'lerchphi' + return self._print_Function(e, func_name=func_name) + + def _print_dirichlet_eta(self, e): + func_name = greek_unicode['eta'] if self._use_unicode else 'dirichlet_eta' + return self._print_Function(e, func_name=func_name) + + def _print_Heaviside(self, e): + func_name = greek_unicode['theta'] if self._use_unicode else 'Heaviside' + if e.args[1] is S.Half: + pform = prettyForm(*self._print(e.args[0]).parens()) + pform = prettyForm(*pform.left(func_name)) + return pform + else: + return self._print_Function(e, func_name=func_name) + + def _print_fresnels(self, e): + return self._print_Function(e, func_name="S") + + def _print_fresnelc(self, e): + return self._print_Function(e, func_name="C") + + def _print_airyai(self, e): + return self._print_Function(e, func_name="Ai") + + def _print_airybi(self, e): + return self._print_Function(e, func_name="Bi") + + def _print_airyaiprime(self, e): + return self._print_Function(e, func_name="Ai'") + + def _print_airybiprime(self, e): + return self._print_Function(e, func_name="Bi'") + + def _print_LambertW(self, e): + return self._print_Function(e, func_name="W") + + def _print_Covariance(self, e): + return self._print_Function(e, func_name="Cov") + + def _print_Variance(self, e): + return self._print_Function(e, func_name="Var") + + def _print_Probability(self, e): + return self._print_Function(e, func_name="P") + + def _print_Expectation(self, e): + return self._print_Function(e, func_name="E", left='[', right=']') + + def _print_Lambda(self, e): + expr = e.expr + sig = e.signature + if self._use_unicode: + arrow = f" {pretty_atom('ArrowFromBar')} " + else: + arrow = " -> " + if len(sig) == 1 and sig[0].is_symbol: + sig = sig[0] + var_form = self._print(sig) + + return prettyForm(*stringPict.next(var_form, arrow, self._print(expr)), binding=8) + + def _print_Order(self, expr): + pform = self._print(expr.expr) + if (expr.point and any(p != S.Zero for p in expr.point)) or \ + len(expr.variables) > 1: + pform = prettyForm(*pform.right("; ")) + if len(expr.variables) > 1: + pform = prettyForm(*pform.right(self._print(expr.variables))) + elif len(expr.variables): + pform = prettyForm(*pform.right(self._print(expr.variables[0]))) + if self._use_unicode: + pform = prettyForm(*pform.right(f" {pretty_atom('Arrow')} ")) + else: + pform = prettyForm(*pform.right(" -> ")) + if len(expr.point) > 1: + pform = prettyForm(*pform.right(self._print(expr.point))) + else: + pform = prettyForm(*pform.right(self._print(expr.point[0]))) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left("O")) + return pform + + def _print_SingularityFunction(self, e): + if self._use_unicode: + shift = self._print(e.args[0]-e.args[1]) + n = self._print(e.args[2]) + base = prettyForm("<") + base = prettyForm(*base.right(shift)) + base = prettyForm(*base.right(">")) + pform = base**n + return pform + else: + n = self._print(e.args[2]) + shift = self._print(e.args[0]-e.args[1]) + base = self._print_seq(shift, "<", ">", ' ') + return base**n + + def _print_beta(self, e): + func_name = greek_unicode['Beta'] if self._use_unicode else 'B' + return self._print_Function(e, func_name=func_name) + + def _print_betainc(self, e): + func_name = "B'" + return self._print_Function(e, func_name=func_name) + + def _print_betainc_regularized(self, e): + func_name = 'I' + return self._print_Function(e, func_name=func_name) + + def _print_gamma(self, e): + func_name = greek_unicode['Gamma'] if self._use_unicode else 'Gamma' + return self._print_Function(e, func_name=func_name) + + def _print_uppergamma(self, e): + func_name = greek_unicode['Gamma'] if self._use_unicode else 'Gamma' + return self._print_Function(e, func_name=func_name) + + def _print_lowergamma(self, e): + func_name = greek_unicode['gamma'] if self._use_unicode else 'lowergamma' + return self._print_Function(e, func_name=func_name) + + def _print_DiracDelta(self, e): + if self._use_unicode: + if len(e.args) == 2: + a = prettyForm(greek_unicode['delta']) + b = self._print(e.args[1]) + b = prettyForm(*b.parens()) + c = self._print(e.args[0]) + c = prettyForm(*c.parens()) + pform = a**b + pform = prettyForm(*pform.right(' ')) + pform = prettyForm(*pform.right(c)) + return pform + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(greek_unicode['delta'])) + return pform + else: + return self._print_Function(e) + + def _print_expint(self, e): + subscript = self._print(e.args[0]) + if self._use_unicode and is_subscriptable_in_unicode(subscript): + return self._print_Function(Function('E_%s' % subscript)(e.args[1])) + return self._print_Function(e) + + def _print_Chi(self, e): + # This needs a special case since otherwise it comes out as greek + # letter chi... + prettyFunc = prettyForm("Chi") + prettyArgs = prettyForm(*self._print_seq(e.args).parens()) + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_elliptic_e(self, e): + pforma0 = self._print(e.args[0]) + if len(e.args) == 1: + pform = pforma0 + else: + pforma1 = self._print(e.args[1]) + pform = self._hprint_vseparator(pforma0, pforma1) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('E')) + return pform + + def _print_elliptic_k(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('K')) + return pform + + def _print_elliptic_f(self, e): + pforma0 = self._print(e.args[0]) + pforma1 = self._print(e.args[1]) + pform = self._hprint_vseparator(pforma0, pforma1) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('F')) + return pform + + def _print_elliptic_pi(self, e): + name = greek_unicode['Pi'] if self._use_unicode else 'Pi' + pforma0 = self._print(e.args[0]) + pforma1 = self._print(e.args[1]) + if len(e.args) == 2: + pform = self._hprint_vseparator(pforma0, pforma1) + else: + pforma2 = self._print(e.args[2]) + pforma = self._hprint_vseparator(pforma1, pforma2, ifascii_nougly=False) + pforma = prettyForm(*pforma.left('; ')) + pform = prettyForm(*pforma.left(pforma0)) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(name)) + return pform + + def _print_GoldenRatio(self, expr): + if self._use_unicode: + return prettyForm(pretty_symbol('phi')) + return self._print(Symbol("GoldenRatio")) + + def _print_EulerGamma(self, expr): + if self._use_unicode: + return prettyForm(pretty_symbol('gamma')) + return self._print(Symbol("EulerGamma")) + + def _print_Catalan(self, expr): + return self._print(Symbol("G")) + + def _print_Mod(self, expr): + pform = self._print(expr.args[0]) + if pform.binding > prettyForm.MUL: + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right(' mod ')) + pform = prettyForm(*pform.right(self._print(expr.args[1]))) + pform.binding = prettyForm.OPEN + return pform + + def _print_Add(self, expr, order=None): + terms = self._as_ordered_terms(expr, order=order) + pforms, indices = [], [] + + def pretty_negative(pform, index): + """Prepend a minus sign to a pretty form. """ + #TODO: Move this code to prettyForm + if index == 0: + if pform.height() > 1: + pform_neg = '- ' + else: + pform_neg = '-' + else: + pform_neg = ' - ' + + if (pform.binding > prettyForm.NEG + or pform.binding == prettyForm.ADD): + p = stringPict(*pform.parens()) + else: + p = pform + p = stringPict.next(pform_neg, p) + # Lower the binding to NEG, even if it was higher. Otherwise, it + # will print as a + ( - (b)), instead of a - (b). + return prettyForm(binding=prettyForm.NEG, *p) + + for i, term in enumerate(terms): + if term.is_Mul and term.could_extract_minus_sign(): + coeff, other = term.as_coeff_mul(rational=False) + if coeff == -1: + negterm = Mul(*other, evaluate=False) + else: + negterm = Mul(-coeff, *other, evaluate=False) + pform = self._print(negterm) + pforms.append(pretty_negative(pform, i)) + elif term.is_Rational and term.q > 1: + pforms.append(None) + indices.append(i) + elif term.is_Number and term < 0: + pform = self._print(-term) + pforms.append(pretty_negative(pform, i)) + elif term.is_Relational: + pforms.append(prettyForm(*self._print(term).parens())) + else: + pforms.append(self._print(term)) + + if indices: + large = True + + for pform in pforms: + if pform is not None and pform.height() > 1: + break + else: + large = False + + for i in indices: + term, negative = terms[i], False + + if term < 0: + term, negative = -term, True + + if large: + pform = prettyForm(str(term.p))/prettyForm(str(term.q)) + else: + pform = self._print(term) + + if negative: + pform = pretty_negative(pform, i) + + pforms[i] = pform + + return prettyForm.__add__(*pforms) + + def _print_Mul(self, product): + from sympy.physics.units import Quantity + + # Check for unevaluated Mul. In this case we need to make sure the + # identities are visible, multiple Rational factors are not combined + # etc so we display in a straight-forward form that fully preserves all + # args and their order. + args = product.args + if args[0] is S.One or any(isinstance(arg, Number) for arg in args[1:]): + strargs = list(map(self._print, args)) + # XXX: This is a hack to work around the fact that + # prettyForm.__mul__ absorbs a leading -1 in the args. Probably it + # would be better to fix this in prettyForm.__mul__ instead. + negone = strargs[0] == '-1' + if negone: + strargs[0] = prettyForm('1', 0, 0) + obj = prettyForm.__mul__(*strargs) + if negone: + obj = prettyForm('-' + obj.s, obj.baseline, obj.binding) + return obj + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + if self.order not in ('old', 'none'): + args = product.as_ordered_factors() + else: + args = list(product.args) + + # If quantities are present append them at the back + args = sorted(args, key=lambda x: isinstance(x, Quantity) or + (isinstance(x, Pow) and isinstance(x.base, Quantity))) + + # Gather terms for numerator/denominator + for item in args: + if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative: + if item.exp != -1: + b.append(Pow(item.base, -item.exp, evaluate=False)) + else: + b.append(Pow(item.base, -item.exp)) + elif item.is_Rational and item is not S.Infinity: + if item.p != 1: + a.append( Rational(item.p) ) + if item.q != 1: + b.append( Rational(item.q) ) + else: + a.append(item) + + # Convert to pretty forms. Parentheses are added by `__mul__`. + a = [self._print(ai) for ai in a] + b = [self._print(bi) for bi in b] + + # Construct a pretty form + if len(b) == 0: + return prettyForm.__mul__(*a) + else: + if len(a) == 0: + a.append( self._print(S.One) ) + return prettyForm.__mul__(*a)/prettyForm.__mul__(*b) + + # A helper function for _print_Pow to print x**(1/n) + def _print_nth_root(self, base, root): + bpretty = self._print(base) + + # In very simple cases, use a single-char root sign + if (self._settings['use_unicode_sqrt_char'] and self._use_unicode + and root == 2 and bpretty.height() == 1 + and (bpretty.width() == 1 + or (base.is_Integer and base.is_nonnegative))): + return prettyForm(*bpretty.left(nth_root[2])) + + # Construct root sign, start with the \/ shape + _zZ = xobj('/', 1) + rootsign = xobj('\\', 1) + _zZ + # Constructing the number to put on root + rpretty = self._print(root) + # roots look bad if they are not a single line + if rpretty.height() != 1: + return self._print(base)**self._print(1/root) + # If power is half, no number should appear on top of root sign + exp = '' if root == 2 else str(rpretty).ljust(2) + if len(exp) > 2: + rootsign = ' '*(len(exp) - 2) + rootsign + # Stack the exponent + rootsign = stringPict(exp + '\n' + rootsign) + rootsign.baseline = 0 + # Diagonal: length is one less than height of base + linelength = bpretty.height() - 1 + diagonal = stringPict('\n'.join( + ' '*(linelength - i - 1) + _zZ + ' '*i + for i in range(linelength) + )) + # Put baseline just below lowest line: next to exp + diagonal.baseline = linelength - 1 + # Make the root symbol + rootsign = prettyForm(*rootsign.right(diagonal)) + # Det the baseline to match contents to fix the height + # but if the height of bpretty is one, the rootsign must be one higher + rootsign.baseline = max(1, bpretty.baseline) + #build result + s = prettyForm(hobj('_', 2 + bpretty.width())) + s = prettyForm(*bpretty.above(s)) + s = prettyForm(*s.left(rootsign)) + return s + + def _print_Pow(self, power): + from sympy.simplify.simplify import fraction + b, e = power.as_base_exp() + if power.is_commutative: + if e is S.NegativeOne: + return prettyForm("1")/self._print(b) + n, d = fraction(e) + if n is S.One and d.is_Atom and not e.is_Integer and (e.is_Rational or d.is_Symbol) \ + and self._settings['root_notation']: + return self._print_nth_root(b, d) + if e.is_Rational and e < 0: + return prettyForm("1")/self._print(Pow(b, -e, evaluate=False)) + + if b.is_Relational: + return prettyForm(*self._print(b).parens()).__pow__(self._print(e)) + + return self._print(b)**self._print(e) + + def _print_UnevaluatedExpr(self, expr): + return self._print(expr.args[0]) + + def __print_numer_denom(self, p, q): + if q == 1: + if p < 0: + return prettyForm(str(p), binding=prettyForm.NEG) + else: + return prettyForm(str(p)) + elif abs(p) >= 10 and abs(q) >= 10: + # If more than one digit in numer and denom, print larger fraction + if p < 0: + return prettyForm(str(p), binding=prettyForm.NEG)/prettyForm(str(q)) + # Old printing method: + #pform = prettyForm(str(-p))/prettyForm(str(q)) + #return prettyForm(binding=prettyForm.NEG, *pform.left('- ')) + else: + return prettyForm(str(p))/prettyForm(str(q)) + else: + return None + + def _print_Rational(self, expr): + result = self.__print_numer_denom(expr.p, expr.q) + + if result is not None: + return result + else: + return self.emptyPrinter(expr) + + def _print_Fraction(self, expr): + result = self.__print_numer_denom(expr.numerator, expr.denominator) + + if result is not None: + return result + else: + return self.emptyPrinter(expr) + + def _print_ProductSet(self, p): + if len(p.sets) >= 1 and not has_variety(p.sets): + return self._print(p.sets[0]) ** self._print(len(p.sets)) + else: + prod_char = pretty_atom('Multiplication') if self._use_unicode else 'x' + return self._print_seq(p.sets, None, None, ' %s ' % prod_char, + parenthesize=lambda set: set.is_Union or + set.is_Intersection or set.is_ProductSet) + + def _print_FiniteSet(self, s): + items = sorted(s.args, key=default_sort_key) + return self._print_seq(items, '{', '}', ', ' ) + + def _print_Range(self, s): + + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + + if s.start.is_infinite and s.stop.is_infinite: + if s.step.is_positive: + printset = dots, -1, 0, 1, dots + else: + printset = dots, 1, 0, -1, dots + elif s.start.is_infinite: + printset = dots, s[-1] - s.step, s[-1] + elif s.stop.is_infinite: + it = iter(s) + printset = next(it), next(it), dots + elif len(s) > 4: + it = iter(s) + printset = next(it), next(it), dots, s[-1] + else: + printset = tuple(s) + + return self._print_seq(printset, '{', '}', ', ' ) + + def _print_Interval(self, i): + if i.start == i.end: + return self._print_seq(i.args[:1], '{', '}') + + else: + if i.left_open: + left = '(' + else: + left = '[' + + if i.right_open: + right = ')' + else: + right = ']' + + return self._print_seq(i.args[:2], left, right) + + def _print_AccumulationBounds(self, i): + left = '<' + right = '>' + + return self._print_seq(i.args[:2], left, right) + + def _print_Intersection(self, u): + + delimiter = ' %s ' % pretty_atom('Intersection', 'n') + + return self._print_seq(u.args, None, None, delimiter, + parenthesize=lambda set: set.is_ProductSet or + set.is_Union or set.is_Complement) + + def _print_Union(self, u): + + union_delimiter = ' %s ' % pretty_atom('Union', 'U') + + return self._print_seq(u.args, None, None, union_delimiter, + parenthesize=lambda set: set.is_ProductSet or + set.is_Intersection or set.is_Complement) + + def _print_SymmetricDifference(self, u): + if not self._use_unicode: + raise NotImplementedError("ASCII pretty printing of SymmetricDifference is not implemented") + + sym_delimeter = ' %s ' % pretty_atom('SymmetricDifference') + + return self._print_seq(u.args, None, None, sym_delimeter) + + def _print_Complement(self, u): + + delimiter = r' \ ' + + return self._print_seq(u.args, None, None, delimiter, + parenthesize=lambda set: set.is_ProductSet or set.is_Intersection + or set.is_Union) + + def _print_ImageSet(self, ts): + if self._use_unicode: + inn = pretty_atom("SmallElementOf") + else: + inn = 'in' + fun = ts.lamda + sets = ts.base_sets + signature = fun.signature + expr = self._print(fun.expr) + + # TODO: the stuff to the left of the | and the stuff to the right of + # the | should have independent baselines, that way something like + # ImageSet(Lambda(x, 1/x**2), S.Naturals) prints the "x in N" part + # centered on the right instead of aligned with the fraction bar on + # the left. The same also applies to ConditionSet and ComplexRegion + if len(signature) == 1: + S = self._print_seq((signature[0], inn, sets[0]), + delimiter=' ') + return self._hprint_vseparator(expr, S, + left='{', right='}', + ifascii_nougly=True, delimiter=' ') + else: + pargs = tuple(j for var, setv in zip(signature, sets) for j in + (var, ' ', inn, ' ', setv, ", ")) + S = self._print_seq(pargs[:-1], delimiter='') + return self._hprint_vseparator(expr, S, + left='{', right='}', + ifascii_nougly=True, delimiter=' ') + + def _print_ConditionSet(self, ts): + if self._use_unicode: + inn = pretty_atom('SmallElementOf') + # using _and because and is a keyword and it is bad practice to + # overwrite them + _and = pretty_atom('And') + else: + inn = 'in' + _and = 'and' + + variables = self._print_seq(Tuple(ts.sym)) + as_expr = getattr(ts.condition, 'as_expr', None) + if as_expr is not None: + cond = self._print(ts.condition.as_expr()) + else: + cond = self._print(ts.condition) + if self._use_unicode: + cond = self._print(cond) + cond = prettyForm(*cond.parens()) + + if ts.base_set is S.UniversalSet: + return self._hprint_vseparator(variables, cond, left="{", + right="}", ifascii_nougly=True, + delimiter=' ') + + base = self._print(ts.base_set) + C = self._print_seq((variables, inn, base, _and, cond), + delimiter=' ') + return self._hprint_vseparator(variables, C, left="{", right="}", + ifascii_nougly=True, delimiter=' ') + + def _print_ComplexRegion(self, ts): + if self._use_unicode: + inn = pretty_atom('SmallElementOf') + else: + inn = 'in' + variables = self._print_seq(ts.variables) + expr = self._print(ts.expr) + prodsets = self._print(ts.sets) + + C = self._print_seq((variables, inn, prodsets), + delimiter=' ') + return self._hprint_vseparator(expr, C, left="{", right="}", + ifascii_nougly=True, delimiter=' ') + + def _print_Contains(self, e): + var, set = e.args + if self._use_unicode: + el = f" {pretty_atom('ElementOf')} " + return prettyForm(*stringPict.next(self._print(var), + el, self._print(set)), binding=8) + else: + return prettyForm(sstr(e)) + + def _print_FourierSeries(self, s): + if s.an.formula is S.Zero and s.bn.formula is S.Zero: + return self._print(s.a0) + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + return self._print_Add(s.truncate()) + self._print(dots) + + def _print_FormalPowerSeries(self, s): + return self._print_Add(s.infinite) + + def _print_SetExpr(self, se): + pretty_set = prettyForm(*self._print(se.set).parens()) + pretty_name = self._print(Symbol("SetExpr")) + return prettyForm(*pretty_name.right(pretty_set)) + + def _print_SeqFormula(self, s): + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + + if len(s.start.free_symbols) > 0 or len(s.stop.free_symbols) > 0: + raise NotImplementedError("Pretty printing of sequences with symbolic bound not implemented") + + if s.start is S.NegativeInfinity: + stop = s.stop + printset = (dots, s.coeff(stop - 3), s.coeff(stop - 2), + s.coeff(stop - 1), s.coeff(stop)) + elif s.stop is S.Infinity or s.length > 4: + printset = s[:4] + printset.append(dots) + printset = tuple(printset) + else: + printset = tuple(s) + return self._print_list(printset) + + _print_SeqPer = _print_SeqFormula + _print_SeqAdd = _print_SeqFormula + _print_SeqMul = _print_SeqFormula + + def _print_seq(self, seq, left=None, right=None, delimiter=', ', + parenthesize=lambda x: False, ifascii_nougly=True): + + pforms = [] + for item in seq: + pform = self._print(item) + if parenthesize(item): + pform = prettyForm(*pform.parens()) + if pforms: + pforms.append(delimiter) + pforms.append(pform) + + if not pforms: + s = stringPict('') + else: + s = prettyForm(*stringPict.next(*pforms)) + + s = prettyForm(*s.parens(left, right, ifascii_nougly=ifascii_nougly)) + return s + + def join(self, delimiter, args): + pform = None + + for arg in args: + if pform is None: + pform = arg + else: + pform = prettyForm(*pform.right(delimiter)) + pform = prettyForm(*pform.right(arg)) + + if pform is None: + return prettyForm("") + else: + return pform + + def _print_list(self, l): + return self._print_seq(l, '[', ']') + + def _print_tuple(self, t): + if len(t) == 1: + ptuple = prettyForm(*stringPict.next(self._print(t[0]), ',')) + return prettyForm(*ptuple.parens('(', ')', ifascii_nougly=True)) + else: + return self._print_seq(t, '(', ')') + + def _print_Tuple(self, expr): + return self._print_tuple(expr) + + def _print_dict(self, d): + keys = sorted(d.keys(), key=default_sort_key) + items = [] + + for k in keys: + K = self._print(k) + V = self._print(d[k]) + s = prettyForm(*stringPict.next(K, ': ', V)) + + items.append(s) + + return self._print_seq(items, '{', '}') + + def _print_Dict(self, d): + return self._print_dict(d) + + def _print_set(self, s): + if not s: + return prettyForm('set()') + items = sorted(s, key=default_sort_key) + pretty = self._print_seq(items) + pretty = prettyForm(*pretty.parens('{', '}', ifascii_nougly=True)) + return pretty + + def _print_frozenset(self, s): + if not s: + return prettyForm('frozenset()') + items = sorted(s, key=default_sort_key) + pretty = self._print_seq(items) + pretty = prettyForm(*pretty.parens('{', '}', ifascii_nougly=True)) + pretty = prettyForm(*pretty.parens('(', ')', ifascii_nougly=True)) + pretty = prettyForm(*stringPict.next(type(s).__name__, pretty)) + return pretty + + def _print_UniversalSet(self, s): + if self._use_unicode: + return prettyForm(pretty_atom('Universe')) + else: + return prettyForm('UniversalSet') + + def _print_PolyRing(self, ring): + return prettyForm(sstr(ring)) + + def _print_FracField(self, field): + return prettyForm(sstr(field)) + + def _print_FreeGroupElement(self, elm): + return prettyForm(str(elm)) + + def _print_PolyElement(self, poly): + return prettyForm(sstr(poly)) + + def _print_FracElement(self, frac): + return prettyForm(sstr(frac)) + + def _print_AlgebraicNumber(self, expr): + if expr.is_aliased: + return self._print(expr.as_poly().as_expr()) + else: + return self._print(expr.as_expr()) + + def _print_ComplexRootOf(self, expr): + args = [self._print_Add(expr.expr, order='lex'), expr.index] + pform = prettyForm(*self._print_seq(args).parens()) + pform = prettyForm(*pform.left('CRootOf')) + return pform + + def _print_RootSum(self, expr): + args = [self._print_Add(expr.expr, order='lex')] + + if expr.fun is not S.IdentityFunction: + args.append(self._print(expr.fun)) + + pform = prettyForm(*self._print_seq(args).parens()) + pform = prettyForm(*pform.left('RootSum')) + + return pform + + def _print_FiniteField(self, expr): + if self._use_unicode: + form = f"{pretty_atom('Integers')}_%d" + else: + form = 'GF(%d)' + + return prettyForm(pretty_symbol(form % expr.mod)) + + def _print_IntegerRing(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('Integers')) + else: + return prettyForm('ZZ') + + def _print_RationalField(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('Rationals')) + else: + return prettyForm('QQ') + + def _print_RealField(self, domain): + if self._use_unicode: + prefix = pretty_atom("Reals") + else: + prefix = 'RR' + + if domain.has_default_precision: + return prettyForm(prefix) + else: + return self._print(pretty_symbol(prefix + "_" + str(domain.precision))) + + def _print_ComplexField(self, domain): + if self._use_unicode: + prefix = pretty_atom('Complexes') + else: + prefix = 'CC' + + if domain.has_default_precision: + return prettyForm(prefix) + else: + return self._print(pretty_symbol(prefix + "_" + str(domain.precision))) + + def _print_PolynomialRing(self, expr): + args = list(expr.symbols) + + if not expr.order.is_default: + order = prettyForm(*prettyForm("order=").right(self._print(expr.order))) + args.append(order) + + pform = self._print_seq(args, '[', ']') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_FractionField(self, expr): + args = list(expr.symbols) + + if not expr.order.is_default: + order = prettyForm(*prettyForm("order=").right(self._print(expr.order))) + args.append(order) + + pform = self._print_seq(args, '(', ')') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_PolynomialRingBase(self, expr): + g = expr.symbols + if str(expr.order) != str(expr.default_order): + g = g + ("order=" + str(expr.order),) + pform = self._print_seq(g, '[', ']') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_GroebnerBasis(self, basis): + exprs = [ self._print_Add(arg, order=basis.order) + for arg in basis.exprs ] + exprs = prettyForm(*self.join(", ", exprs).parens(left="[", right="]")) + + gens = [ self._print(gen) for gen in basis.gens ] + + domain = prettyForm( + *prettyForm("domain=").right(self._print(basis.domain))) + order = prettyForm( + *prettyForm("order=").right(self._print(basis.order))) + + pform = self.join(", ", [exprs] + gens + [domain, order]) + + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(basis.__class__.__name__)) + + return pform + + def _print_Subs(self, e): + pform = self._print(e.expr) + pform = prettyForm(*pform.parens()) + + h = pform.height() if pform.height() > 1 else 2 + rvert = stringPict(vobj('|', h), baseline=pform.baseline) + pform = prettyForm(*pform.right(rvert)) + + b = pform.baseline + pform.baseline = pform.height() - 1 + pform = prettyForm(*pform.right(self._print_seq([ + self._print_seq((self._print(v[0]), xsym('=='), self._print(v[1])), + delimiter='') for v in zip(e.variables, e.point) ]))) + + pform.baseline = b + return pform + + def _print_number_function(self, e, name): + # Print name_arg[0] for one argument or name_arg[0](arg[1]) + # for more than one argument + pform = prettyForm(name) + arg = self._print(e.args[0]) + pform_arg = prettyForm(" "*arg.width()) + pform_arg = prettyForm(*pform_arg.below(arg)) + pform = prettyForm(*pform.right(pform_arg)) + if len(e.args) == 1: + return pform + m, x = e.args + # TODO: copy-pasted from _print_Function: can we do better? + prettyFunc = pform + prettyArgs = prettyForm(*self._print_seq([x]).parens()) + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + return pform + + def _print_euler(self, e): + return self._print_number_function(e, "E") + + def _print_catalan(self, e): + return self._print_number_function(e, "C") + + def _print_bernoulli(self, e): + return self._print_number_function(e, "B") + + _print_bell = _print_bernoulli + + def _print_lucas(self, e): + return self._print_number_function(e, "L") + + def _print_fibonacci(self, e): + return self._print_number_function(e, "F") + + def _print_tribonacci(self, e): + return self._print_number_function(e, "T") + + def _print_stieltjes(self, e): + if self._use_unicode: + return self._print_number_function(e, greek_unicode['gamma']) + else: + return self._print_number_function(e, "stieltjes") + + def _print_KroneckerDelta(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(self._print(e.args[1]))) + if self._use_unicode: + a = stringPict(pretty_symbol('delta')) + else: + a = stringPict('d') + b = pform + top = stringPict(*b.left(' '*a.width())) + bot = stringPict(*a.right(' '*b.width())) + return prettyForm(binding=prettyForm.POW, *bot.below(top)) + + def _print_RandomDomain(self, d): + if hasattr(d, 'as_boolean'): + pform = self._print('Domain: ') + pform = prettyForm(*pform.right(self._print(d.as_boolean()))) + return pform + elif hasattr(d, 'set'): + pform = self._print('Domain: ') + pform = prettyForm(*pform.right(self._print(d.symbols))) + pform = prettyForm(*pform.right(self._print(' in '))) + pform = prettyForm(*pform.right(self._print(d.set))) + return pform + elif hasattr(d, 'symbols'): + pform = self._print('Domain on ') + pform = prettyForm(*pform.right(self._print(d.symbols))) + return pform + else: + return self._print(None) + + def _print_DMP(self, p): + try: + if p.ring is not None: + # TODO incorporate order + return self._print(p.ring.to_sympy(p)) + except SympifyError: + pass + return self._print(repr(p)) + + def _print_DMF(self, p): + return self._print_DMP(p) + + def _print_Object(self, object): + return self._print(pretty_symbol(object.name)) + + def _print_Morphism(self, morphism): + arrow = xsym("-->") + + domain = self._print(morphism.domain) + codomain = self._print(morphism.codomain) + tail = domain.right(arrow, codomain)[0] + + return prettyForm(tail) + + def _print_NamedMorphism(self, morphism): + pretty_name = self._print(pretty_symbol(morphism.name)) + pretty_morphism = self._print_Morphism(morphism) + return prettyForm(pretty_name.right(":", pretty_morphism)[0]) + + def _print_IdentityMorphism(self, morphism): + from sympy.categories import NamedMorphism + return self._print_NamedMorphism( + NamedMorphism(morphism.domain, morphism.codomain, "id")) + + def _print_CompositeMorphism(self, morphism): + + circle = xsym(".") + + # All components of the morphism have names and it is thus + # possible to build the name of the composite. + component_names_list = [pretty_symbol(component.name) for + component in morphism.components] + component_names_list.reverse() + component_names = circle.join(component_names_list) + ":" + + pretty_name = self._print(component_names) + pretty_morphism = self._print_Morphism(morphism) + return prettyForm(pretty_name.right(pretty_morphism)[0]) + + def _print_Category(self, category): + return self._print(pretty_symbol(category.name)) + + def _print_Diagram(self, diagram): + if not diagram.premises: + # This is an empty diagram. + return self._print(S.EmptySet) + + pretty_result = self._print(diagram.premises) + if diagram.conclusions: + results_arrow = " %s " % xsym("==>") + + pretty_conclusions = self._print(diagram.conclusions)[0] + pretty_result = pretty_result.right( + results_arrow, pretty_conclusions) + + return prettyForm(pretty_result[0]) + + def _print_DiagramGrid(self, grid): + from sympy.matrices import Matrix + matrix = Matrix([[grid[i, j] if grid[i, j] else Symbol(" ") + for j in range(grid.width)] + for i in range(grid.height)]) + return self._print_matrix_contents(matrix) + + def _print_FreeModuleElement(self, m): + # Print as row vector for convenience, for now. + return self._print_seq(m, '[', ']') + + def _print_SubModule(self, M): + gens = [[M.ring.to_sympy(g) for g in gen] for gen in M.gens] + return self._print_seq(gens, '<', '>') + + def _print_FreeModule(self, M): + return self._print(M.ring)**self._print(M.rank) + + def _print_ModuleImplementedIdeal(self, M): + sym = M.ring.to_sympy + return self._print_seq([sym(x) for [x] in M._module.gens], '<', '>') + + def _print_QuotientRing(self, R): + return self._print(R.ring) / self._print(R.base_ideal) + + def _print_QuotientRingElement(self, R): + return self._print(R.ring.to_sympy(R)) + self._print(R.ring.base_ideal) + + def _print_QuotientModuleElement(self, m): + return self._print(m.data) + self._print(m.module.killed_module) + + def _print_QuotientModule(self, M): + return self._print(M.base) / self._print(M.killed_module) + + def _print_MatrixHomomorphism(self, h): + matrix = self._print(h._sympy_matrix()) + matrix.baseline = matrix.height() // 2 + pform = prettyForm(*matrix.right(' : ', self._print(h.domain), + ' %s> ' % hobj('-', 2), self._print(h.codomain))) + return pform + + def _print_Manifold(self, manifold): + return self._print(manifold.name) + + def _print_Patch(self, patch): + return self._print(patch.name) + + def _print_CoordSystem(self, coords): + return self._print(coords.name) + + def _print_BaseScalarField(self, field): + string = field._coord_sys.symbols[field._index].name + return self._print(pretty_symbol(string)) + + def _print_BaseVectorField(self, field): + s = U('PARTIAL DIFFERENTIAL') + '_' + field._coord_sys.symbols[field._index].name + return self._print(pretty_symbol(s)) + + def _print_Differential(self, diff): + if self._use_unicode: + d = pretty_atom('Differential') + else: + d = 'd' + field = diff._form_field + if hasattr(field, '_coord_sys'): + string = field._coord_sys.symbols[field._index].name + return self._print(d + ' ' + pretty_symbol(string)) + else: + pform = self._print(field) + pform = prettyForm(*pform.parens()) + return prettyForm(*pform.left(d)) + + def _print_Tr(self, p): + #TODO: Handle indices + pform = self._print(p.args[0]) + pform = prettyForm(*pform.left('%s(' % (p.__class__.__name__))) + pform = prettyForm(*pform.right(')')) + return pform + + def _print_primenu(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + if self._use_unicode: + pform = prettyForm(*pform.left(greek_unicode['nu'])) + else: + pform = prettyForm(*pform.left('nu')) + return pform + + def _print_primeomega(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + if self._use_unicode: + pform = prettyForm(*pform.left(greek_unicode['Omega'])) + else: + pform = prettyForm(*pform.left('Omega')) + return pform + + def _print_Quantity(self, e): + if e.name.name == 'degree': + if self._use_unicode: + pform = self._print(pretty_atom('Degree')) + else: + pform = self._print(chr(176)) + return pform + else: + return self.emptyPrinter(e) + + def _print_AssignmentBase(self, e): + + op = prettyForm(' ' + xsym(e.op) + ' ') + + l = self._print(e.lhs) + r = self._print(e.rhs) + pform = prettyForm(*stringPict.next(l, op, r)) + return pform + + def _print_Str(self, s): + return self._print(s.name) + + +@print_function(PrettyPrinter) +def pretty(expr, **settings): + """Returns a string containing the prettified form of expr. + + For information on keyword arguments see pretty_print function. + + """ + pp = PrettyPrinter(settings) + + # XXX: this is an ugly hack, but at least it works + use_unicode = pp._settings['use_unicode'] + uflag = pretty_use_unicode(use_unicode) + + try: + return pp.doprint(expr) + finally: + pretty_use_unicode(uflag) + + +def pretty_print(expr, **kwargs): + """Prints expr in pretty form. + + pprint is just a shortcut for this function. + + Parameters + ========== + + expr : expression + The expression to print. + + wrap_line : bool, optional (default=True) + Line wrapping enabled/disabled. + + num_columns : int or None, optional (default=None) + Number of columns before line breaking (default to None which reads + the terminal width), useful when using SymPy without terminal. + + use_unicode : bool or None, optional (default=None) + Use unicode characters, such as the Greek letter pi instead of + the string pi. + + full_prec : bool or string, optional (default="auto") + Use full precision. + + order : bool or string, optional (default=None) + Set to 'none' for long expressions if slow; default is None. + + use_unicode_sqrt_char : bool, optional (default=True) + Use compact single-character square root symbol (when unambiguous). + + root_notation : bool, optional (default=True) + Set to 'False' for printing exponents of the form 1/n in fractional form. + By default exponent is printed in root form. + + mat_symbol_style : string, optional (default="plain") + Set to "bold" for printing MatrixSymbols using a bold mathematical symbol face. + By default the standard face is used. + + imaginary_unit : string, optional (default="i") + Letter to use for imaginary unit when use_unicode is True. + Can be "i" (default) or "j". + """ + print(pretty(expr, **kwargs)) + +pprint = pretty_print + + +def pager_print(expr, **settings): + """Prints expr using the pager, in pretty form. + + This invokes a pager command using pydoc. Lines are not wrapped + automatically. This routine is meant to be used with a pager that allows + sideways scrolling, like ``less -S``. + + Parameters are the same as for ``pretty_print``. If you wish to wrap lines, + pass ``num_columns=None`` to auto-detect the width of the terminal. + + """ + from pydoc import pager + from locale import getpreferredencoding + if 'num_columns' not in settings: + settings['num_columns'] = 500000 # disable line wrap + pager(pretty(expr, **settings).encode(getpreferredencoding())) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/pretty_symbology.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/pretty_symbology.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb6ec556c6ed7b15dfcddcfc3da189102d5395b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/pretty_symbology.py @@ -0,0 +1,731 @@ +"""Symbolic primitives + unicode/ASCII abstraction for pretty.py""" + +import sys +import warnings +from string import ascii_lowercase, ascii_uppercase +import unicodedata + +unicode_warnings = '' + +def U(name): + """ + Get a unicode character by name or, None if not found. + + This exists because older versions of Python use older unicode databases. + """ + try: + return unicodedata.lookup(name) + except KeyError: + global unicode_warnings + unicode_warnings += 'No \'%s\' in unicodedata\n' % name + return None + +from sympy.printing.conventions import split_super_sub +from sympy.core.alphabets import greeks +from sympy.utilities.exceptions import sympy_deprecation_warning + +# prefix conventions when constructing tables +# L - LATIN i +# G - GREEK beta +# D - DIGIT 0 +# S - SYMBOL + + + +__all__ = ['greek_unicode', 'sub', 'sup', 'xsym', 'vobj', 'hobj', 'pretty_symbol', + 'annotated', 'center_pad', 'center'] + + +_use_unicode = False + + +def pretty_use_unicode(flag=None): + """Set whether pretty-printer should use unicode by default""" + global _use_unicode, unicode_warnings + if flag is None: + return _use_unicode + + if flag and unicode_warnings: + # print warnings (if any) on first unicode usage + warnings.warn(unicode_warnings) + unicode_warnings = '' + + use_unicode_prev = _use_unicode + _use_unicode = flag + return use_unicode_prev + + +def pretty_try_use_unicode(): + """See if unicode output is available and leverage it if possible""" + + encoding = getattr(sys.stdout, 'encoding', None) + + # this happens when e.g. stdout is redirected through a pipe, or is + # e.g. a cStringIO.StringO + if encoding is None: + return # sys.stdout has no encoding + + symbols = [] + + # see if we can represent greek alphabet + symbols += greek_unicode.values() + + # and atoms + symbols += atoms_table.values() + + for s in symbols: + if s is None: + return # common symbols not present! + + try: + s.encode(encoding) + except UnicodeEncodeError: + return + + # all the characters were present and encodable + pretty_use_unicode(True) + + +def xstr(*args): + sympy_deprecation_warning( + """ + The sympy.printing.pretty.pretty_symbology.xstr() function is + deprecated. Use str() instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions" + ) + return str(*args) + +# GREEK +g = lambda l: U('GREEK SMALL LETTER %s' % l.upper()) +G = lambda l: U('GREEK CAPITAL LETTER %s' % l.upper()) + +greek_letters = list(greeks) # make a copy +# deal with Unicode's funny spelling of lambda +greek_letters[greek_letters.index('lambda')] = 'lamda' + +# {} greek letter -> (g,G) +greek_unicode = {L: g(L) for L in greek_letters} +greek_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_letters) + +# aliases +greek_unicode['lambda'] = greek_unicode['lamda'] +greek_unicode['Lambda'] = greek_unicode['Lamda'] +greek_unicode['varsigma'] = '\N{GREEK SMALL LETTER FINAL SIGMA}' + +# BOLD +b = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper()) +B = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper()) + +bold_unicode = {l: b(l) for l in ascii_lowercase} +bold_unicode.update((L, B(L)) for L in ascii_uppercase) + +# GREEK BOLD +gb = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper()) +GB = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper()) + +greek_bold_letters = list(greeks) # make a copy, not strictly required here +# deal with Unicode's funny spelling of lambda +greek_bold_letters[greek_bold_letters.index('lambda')] = 'lamda' + +# {} greek letter -> (g,G) +greek_bold_unicode = {L: g(L) for L in greek_bold_letters} +greek_bold_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_bold_letters) +greek_bold_unicode['lambda'] = greek_unicode['lamda'] +greek_bold_unicode['Lambda'] = greek_unicode['Lamda'] +greek_bold_unicode['varsigma'] = '\N{MATHEMATICAL BOLD SMALL FINAL SIGMA}' + +digit_2txt = { + '0': 'ZERO', + '1': 'ONE', + '2': 'TWO', + '3': 'THREE', + '4': 'FOUR', + '5': 'FIVE', + '6': 'SIX', + '7': 'SEVEN', + '8': 'EIGHT', + '9': 'NINE', +} + +symb_2txt = { + '+': 'PLUS SIGN', + '-': 'MINUS', + '=': 'EQUALS SIGN', + '(': 'LEFT PARENTHESIS', + ')': 'RIGHT PARENTHESIS', + '[': 'LEFT SQUARE BRACKET', + ']': 'RIGHT SQUARE BRACKET', + '{': 'LEFT CURLY BRACKET', + '}': 'RIGHT CURLY BRACKET', + + # non-std + '{}': 'CURLY BRACKET', + 'sum': 'SUMMATION', + 'int': 'INTEGRAL', +} + +# SUBSCRIPT & SUPERSCRIPT +LSUB = lambda letter: U('LATIN SUBSCRIPT SMALL LETTER %s' % letter.upper()) +GSUB = lambda letter: U('GREEK SUBSCRIPT SMALL LETTER %s' % letter.upper()) +DSUB = lambda digit: U('SUBSCRIPT %s' % digit_2txt[digit]) +SSUB = lambda symb: U('SUBSCRIPT %s' % symb_2txt[symb]) + +LSUP = lambda letter: U('SUPERSCRIPT LATIN SMALL LETTER %s' % letter.upper()) +DSUP = lambda digit: U('SUPERSCRIPT %s' % digit_2txt[digit]) +SSUP = lambda symb: U('SUPERSCRIPT %s' % symb_2txt[symb]) + +sub = {} # symb -> subscript symbol +sup = {} # symb -> superscript symbol + +# latin subscripts +for l in 'aeioruvxhklmnpst': + sub[l] = LSUB(l) + +for l in 'in': + sup[l] = LSUP(l) + +for gl in ['beta', 'gamma', 'rho', 'phi', 'chi']: + sub[gl] = GSUB(gl) + +for d in [str(i) for i in range(10)]: + sub[d] = DSUB(d) + sup[d] = DSUP(d) + +for s in '+-=()': + sub[s] = SSUB(s) + sup[s] = SSUP(s) + +# Variable modifiers +# TODO: Make brackets adjust to height of contents +modifier_dict = { + # Accents + 'mathring': lambda s: center_accent(s, '\N{COMBINING RING ABOVE}'), + 'ddddot': lambda s: center_accent(s, '\N{COMBINING FOUR DOTS ABOVE}'), + 'dddot': lambda s: center_accent(s, '\N{COMBINING THREE DOTS ABOVE}'), + 'ddot': lambda s: center_accent(s, '\N{COMBINING DIAERESIS}'), + 'dot': lambda s: center_accent(s, '\N{COMBINING DOT ABOVE}'), + 'check': lambda s: center_accent(s, '\N{COMBINING CARON}'), + 'breve': lambda s: center_accent(s, '\N{COMBINING BREVE}'), + 'acute': lambda s: center_accent(s, '\N{COMBINING ACUTE ACCENT}'), + 'grave': lambda s: center_accent(s, '\N{COMBINING GRAVE ACCENT}'), + 'tilde': lambda s: center_accent(s, '\N{COMBINING TILDE}'), + 'hat': lambda s: center_accent(s, '\N{COMBINING CIRCUMFLEX ACCENT}'), + 'bar': lambda s: center_accent(s, '\N{COMBINING OVERLINE}'), + 'vec': lambda s: center_accent(s, '\N{COMBINING RIGHT ARROW ABOVE}'), + 'prime': lambda s: s+'\N{PRIME}', + 'prm': lambda s: s+'\N{PRIME}', + # # Faces -- these are here for some compatibility with latex printing + # 'bold': lambda s: s, + # 'bm': lambda s: s, + # 'cal': lambda s: s, + # 'scr': lambda s: s, + # 'frak': lambda s: s, + # Brackets + 'norm': lambda s: '\N{DOUBLE VERTICAL LINE}'+s+'\N{DOUBLE VERTICAL LINE}', + 'avg': lambda s: '\N{MATHEMATICAL LEFT ANGLE BRACKET}'+s+'\N{MATHEMATICAL RIGHT ANGLE BRACKET}', + 'abs': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}', + 'mag': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}', +} + +# VERTICAL OBJECTS +HUP = lambda symb: U('%s UPPER HOOK' % symb_2txt[symb]) +CUP = lambda symb: U('%s UPPER CORNER' % symb_2txt[symb]) +MID = lambda symb: U('%s MIDDLE PIECE' % symb_2txt[symb]) +EXT = lambda symb: U('%s EXTENSION' % symb_2txt[symb]) +HLO = lambda symb: U('%s LOWER HOOK' % symb_2txt[symb]) +CLO = lambda symb: U('%s LOWER CORNER' % symb_2txt[symb]) +TOP = lambda symb: U('%s TOP' % symb_2txt[symb]) +BOT = lambda symb: U('%s BOTTOM' % symb_2txt[symb]) + +# {} '(' -> (extension, start, end, middle) 1-character +_xobj_unicode = { + + # vertical symbols + # (( ext, top, bot, mid ), c1) + '(': (( EXT('('), HUP('('), HLO('(') ), '('), + ')': (( EXT(')'), HUP(')'), HLO(')') ), ')'), + '[': (( EXT('['), CUP('['), CLO('[') ), '['), + ']': (( EXT(']'), CUP(']'), CLO(']') ), ']'), + '{': (( EXT('{}'), HUP('{'), HLO('{'), MID('{') ), '{'), + '}': (( EXT('{}'), HUP('}'), HLO('}'), MID('}') ), '}'), + '|': U('BOX DRAWINGS LIGHT VERTICAL'), + 'Tee': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'), + 'UpTack': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL'), + 'corner_up_centre' + '(_ext': U('LEFT PARENTHESIS EXTENSION'), + ')_ext': U('RIGHT PARENTHESIS EXTENSION'), + '(_lower_hook': U('LEFT PARENTHESIS LOWER HOOK'), + ')_lower_hook': U('RIGHT PARENTHESIS LOWER HOOK'), + '(_upper_hook': U('LEFT PARENTHESIS UPPER HOOK'), + ')_upper_hook': U('RIGHT PARENTHESIS UPPER HOOK'), + '<': ((U('BOX DRAWINGS LIGHT VERTICAL'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT')), '<'), + + '>': ((U('BOX DRAWINGS LIGHT VERTICAL'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), '>'), + + 'lfloor': (( EXT('['), EXT('['), CLO('[') ), U('LEFT FLOOR')), + 'rfloor': (( EXT(']'), EXT(']'), CLO(']') ), U('RIGHT FLOOR')), + 'lceil': (( EXT('['), CUP('['), EXT('[') ), U('LEFT CEILING')), + 'rceil': (( EXT(']'), CUP(']'), EXT(']') ), U('RIGHT CEILING')), + + 'int': (( EXT('int'), U('TOP HALF INTEGRAL'), U('BOTTOM HALF INTEGRAL') ), U('INTEGRAL')), + 'sum': (( U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), '_', U('OVERLINE'), U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), U('N-ARY SUMMATION')), + + # horizontal objects + #'-': '-', + '-': U('BOX DRAWINGS LIGHT HORIZONTAL'), + '_': U('LOW LINE'), + # We used to use this, but LOW LINE looks better for roots, as it's a + # little lower (i.e., it lines up with the / perfectly. But perhaps this + # one would still be wanted for some cases? + # '_': U('HORIZONTAL SCAN LINE-9'), + + # diagonal objects '\' & '/' ? + '/': U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'), + '\\': U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), +} + +_xobj_ascii = { + # vertical symbols + # (( ext, top, bot, mid ), c1) + '(': (( '|', '/', '\\' ), '('), + ')': (( '|', '\\', '/' ), ')'), + +# XXX this looks ugly +# '[': (( '|', '-', '-' ), '['), +# ']': (( '|', '-', '-' ), ']'), +# XXX not so ugly :( + '[': (( '[', '[', '[' ), '['), + ']': (( ']', ']', ']' ), ']'), + + '{': (( '|', '/', '\\', '<' ), '{'), + '}': (( '|', '\\', '/', '>' ), '}'), + '|': '|', + + '<': (( '|', '/', '\\' ), '<'), + '>': (( '|', '\\', '/' ), '>'), + + 'int': ( ' | ', ' /', '/ ' ), + + # horizontal objects + '-': '-', + '_': '_', + + # diagonal objects '\' & '/' ? + '/': '/', + '\\': '\\', +} + + +def xobj(symb, length): + """Construct spatial object of given length. + + return: [] of equal-length strings + """ + + if length <= 0: + raise ValueError("Length should be greater than 0") + + # TODO robustify when no unicodedat available + if _use_unicode: + _xobj = _xobj_unicode + else: + _xobj = _xobj_ascii + + vinfo = _xobj[symb] + + c1 = top = bot = mid = None + + if not isinstance(vinfo, tuple): # 1 entry + ext = vinfo + else: + if isinstance(vinfo[0], tuple): # (vlong), c1 + vlong = vinfo[0] + c1 = vinfo[1] + else: # (vlong), c1 + vlong = vinfo + + ext = vlong[0] + + try: + top = vlong[1] + bot = vlong[2] + mid = vlong[3] + except IndexError: + pass + + if c1 is None: + c1 = ext + if top is None: + top = ext + if bot is None: + bot = ext + if mid is not None: + if (length % 2) == 0: + # even height, but we have to print it somehow anyway... + # XXX is it ok? + length += 1 + + else: + mid = ext + + if length == 1: + return c1 + + res = [] + next = (length - 2)//2 + nmid = (length - 2) - next*2 + + res += [top] + res += [ext]*next + res += [mid]*nmid + res += [ext]*next + res += [bot] + + return res + + +def vobj(symb, height): + """Construct vertical object of a given height + + see: xobj + """ + return '\n'.join( xobj(symb, height) ) + + +def hobj(symb, width): + """Construct horizontal object of a given width + + see: xobj + """ + return ''.join( xobj(symb, width) ) + +# RADICAL +# n -> symbol +root = { + 2: U('SQUARE ROOT'), # U('RADICAL SYMBOL BOTTOM') + 3: U('CUBE ROOT'), + 4: U('FOURTH ROOT'), +} + + +# RATIONAL +VF = lambda txt: U('VULGAR FRACTION %s' % txt) + +# (p,q) -> symbol +frac = { + (1, 2): VF('ONE HALF'), + (1, 3): VF('ONE THIRD'), + (2, 3): VF('TWO THIRDS'), + (1, 4): VF('ONE QUARTER'), + (3, 4): VF('THREE QUARTERS'), + (1, 5): VF('ONE FIFTH'), + (2, 5): VF('TWO FIFTHS'), + (3, 5): VF('THREE FIFTHS'), + (4, 5): VF('FOUR FIFTHS'), + (1, 6): VF('ONE SIXTH'), + (5, 6): VF('FIVE SIXTHS'), + (1, 8): VF('ONE EIGHTH'), + (3, 8): VF('THREE EIGHTHS'), + (5, 8): VF('FIVE EIGHTHS'), + (7, 8): VF('SEVEN EIGHTHS'), +} + + +# atom symbols +_xsym = { + '==': ('=', '='), + '<': ('<', '<'), + '>': ('>', '>'), + '<=': ('<=', U('LESS-THAN OR EQUAL TO')), + '>=': ('>=', U('GREATER-THAN OR EQUAL TO')), + '!=': ('!=', U('NOT EQUAL TO')), + ':=': (':=', ':='), + '+=': ('+=', '+='), + '-=': ('-=', '-='), + '*=': ('*=', '*='), + '/=': ('/=', '/='), + '%=': ('%=', '%='), + '*': ('*', U('DOT OPERATOR')), + '-->': ('-->', U('EM DASH') + U('EM DASH') + + U('BLACK RIGHT-POINTING TRIANGLE') if U('EM DASH') + and U('BLACK RIGHT-POINTING TRIANGLE') else None), + '==>': ('==>', U('BOX DRAWINGS DOUBLE HORIZONTAL') + + U('BOX DRAWINGS DOUBLE HORIZONTAL') + + U('BLACK RIGHT-POINTING TRIANGLE') if + U('BOX DRAWINGS DOUBLE HORIZONTAL') and + U('BOX DRAWINGS DOUBLE HORIZONTAL') and + U('BLACK RIGHT-POINTING TRIANGLE') else None), + '.': ('*', U('RING OPERATOR')), +} + + +def xsym(sym): + """get symbology for a 'character'""" + op = _xsym[sym] + + if _use_unicode: + return op[1] + else: + return op[0] + + +# SYMBOLS + +atoms_table = { + # class how-to-display + 'Exp1': U('SCRIPT SMALL E'), + 'Pi': U('GREEK SMALL LETTER PI'), + 'Infinity': U('INFINITY'), + 'NegativeInfinity': U('INFINITY') and ('-' + U('INFINITY')), # XXX what to do here + #'ImaginaryUnit': U('GREEK SMALL LETTER IOTA'), + #'ImaginaryUnit': U('MATHEMATICAL ITALIC SMALL I'), + 'ImaginaryUnit': U('DOUBLE-STRUCK ITALIC SMALL I'), + 'EmptySet': U('EMPTY SET'), + 'Naturals': U('DOUBLE-STRUCK CAPITAL N'), + 'Naturals0': (U('DOUBLE-STRUCK CAPITAL N') and + (U('DOUBLE-STRUCK CAPITAL N') + + U('SUBSCRIPT ZERO'))), + 'Integers': U('DOUBLE-STRUCK CAPITAL Z'), + 'Rationals': U('DOUBLE-STRUCK CAPITAL Q'), + 'Reals': U('DOUBLE-STRUCK CAPITAL R'), + 'Complexes': U('DOUBLE-STRUCK CAPITAL C'), + 'Universe': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL U'), + 'IdentityMatrix': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL I'), + 'ZeroMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ZERO'), + 'OneMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ONE'), + 'Differential': U('DOUBLE-STRUCK ITALIC SMALL D'), + 'Union': U('UNION'), + 'ElementOf': U('ELEMENT OF'), + 'SmallElementOf': U('SMALL ELEMENT OF'), + 'SymmetricDifference': U('INCREMENT'), + 'Intersection': U('INTERSECTION'), + 'Ring': U('RING OPERATOR'), + 'Multiplication': U('MULTIPLICATION SIGN'), + 'TensorProduct': U('N-ARY CIRCLED TIMES OPERATOR'), + 'Dots': U('HORIZONTAL ELLIPSIS'), + 'Modifier Letter Low Ring':U('Modifier Letter Low Ring'), + 'EmptySequence': 'EmptySequence', + 'SuperscriptPlus': U('SUPERSCRIPT PLUS SIGN'), + 'SuperscriptMinus': U('SUPERSCRIPT MINUS'), + 'Dagger': U('DAGGER'), + 'Degree': U('DEGREE SIGN'), + #Logic Symbols + 'And': U('LOGICAL AND'), + 'Or': U('LOGICAL OR'), + 'Not': U('NOT SIGN'), + 'Nor': U('NOR'), + 'Nand': U('NAND'), + 'Xor': U('XOR'), + 'Equiv': U('LEFT RIGHT DOUBLE ARROW'), + 'NotEquiv': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'), + 'Implies': U('LEFT RIGHT DOUBLE ARROW'), + 'NotImplies': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'), + 'Arrow': U('RIGHTWARDS ARROW'), + 'ArrowFromBar': U('RIGHTWARDS ARROW FROM BAR'), + 'NotArrow': U('RIGHTWARDS ARROW WITH STROKE'), + 'Tautology': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'), + 'Contradiction': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL') +} + + +def pretty_atom(atom_name, default=None, printer=None): + """return pretty representation of an atom""" + if _use_unicode: + if printer is not None and atom_name == 'ImaginaryUnit' and printer._settings['imaginary_unit'] == 'j': + return U('DOUBLE-STRUCK ITALIC SMALL J') + else: + return atoms_table[atom_name] + else: + if default is not None: + return default + + raise KeyError('only unicode') # send it default printer + + +def pretty_symbol(symb_name, bold_name=False): + """return pretty representation of a symbol""" + # let's split symb_name into symbol + index + # UC: beta1 + # UC: f_beta + + if not _use_unicode: + return symb_name + + name, sups, subs = split_super_sub(symb_name) + + def translate(s, bold_name) : + if bold_name: + gG = greek_bold_unicode.get(s) + else: + gG = greek_unicode.get(s) + if gG is not None: + return gG + for key in sorted(modifier_dict.keys(), key=lambda k:len(k), reverse=True) : + if s.lower().endswith(key) and len(s)>len(key): + return modifier_dict[key](translate(s[:-len(key)], bold_name)) + if bold_name: + return ''.join([bold_unicode[c] for c in s]) + return s + + name = translate(name, bold_name) + + # Let's prettify sups/subs. If it fails at one of them, pretty sups/subs are + # not used at all. + def pretty_list(l, mapping): + result = [] + for s in l: + pretty = mapping.get(s) + if pretty is None: + try: # match by separate characters + pretty = ''.join([mapping[c] for c in s]) + except (TypeError, KeyError): + return None + result.append(pretty) + return result + + pretty_sups = pretty_list(sups, sup) + if pretty_sups is not None: + pretty_subs = pretty_list(subs, sub) + else: + pretty_subs = None + + # glue the results into one string + if pretty_subs is None: # nice formatting of sups/subs did not work + if subs: + name += '_'+'_'.join([translate(s, bold_name) for s in subs]) + if sups: + name += '__'+'__'.join([translate(s, bold_name) for s in sups]) + return name + else: + sups_result = ' '.join(pretty_sups) + subs_result = ' '.join(pretty_subs) + + return ''.join([name, sups_result, subs_result]) + + +def annotated(letter): + """ + Return a stylised drawing of the letter ``letter``, together with + information on how to put annotations (super- and subscripts to the + left and to the right) on it. + + See pretty.py functions _print_meijerg, _print_hyper on how to use this + information. + """ + ucode_pics = { + 'F': (2, 0, 2, 0, '\N{BOX DRAWINGS LIGHT DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n' + '\N{BOX DRAWINGS LIGHT VERTICAL AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n' + '\N{BOX DRAWINGS LIGHT UP}'), + 'G': (3, 0, 3, 1, '\N{BOX DRAWINGS LIGHT ARC DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC DOWN AND LEFT}\n' + '\N{BOX DRAWINGS LIGHT VERTICAL}\N{BOX DRAWINGS LIGHT RIGHT}\N{BOX DRAWINGS LIGHT DOWN AND LEFT}\n' + '\N{BOX DRAWINGS LIGHT ARC UP AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC UP AND LEFT}') + } + ascii_pics = { + 'F': (3, 0, 3, 0, ' _\n|_\n|\n'), + 'G': (3, 0, 3, 1, ' __\n/__\n\\_|') + } + + if _use_unicode: + return ucode_pics[letter] + else: + return ascii_pics[letter] + +_remove_combining = dict.fromkeys(list(range(ord('\N{COMBINING GRAVE ACCENT}'), ord('\N{COMBINING LATIN SMALL LETTER X}'))) + + list(range(ord('\N{COMBINING LEFT HARPOON ABOVE}'), ord('\N{COMBINING ASTERISK ABOVE}')))) + +def is_combining(sym): + """Check whether symbol is a unicode modifier. """ + + return ord(sym) in _remove_combining + + +def center_accent(string, accent): + """ + Returns a string with accent inserted on the middle character. Useful to + put combining accents on symbol names, including multi-character names. + + Parameters + ========== + + string : string + The string to place the accent in. + accent : string + The combining accent to insert + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Combining_character + .. [2] https://en.wikipedia.org/wiki/Combining_Diacritical_Marks + + """ + + # Accent is placed on the previous character, although it may not always look + # like that depending on console + midpoint = len(string) // 2 + 1 + firstpart = string[:midpoint] + secondpart = string[midpoint:] + return firstpart + accent + secondpart + + +def line_width(line): + """Unicode combining symbols (modifiers) are not ever displayed as + separate symbols and thus should not be counted + """ + return len(line.translate(_remove_combining)) + + +def is_subscriptable_in_unicode(subscript): + """ + Checks whether a string is subscriptable in unicode or not. + + Parameters + ========== + + subscript: the string which needs to be checked + + Examples + ======== + + >>> from sympy.printing.pretty.pretty_symbology import is_subscriptable_in_unicode + >>> is_subscriptable_in_unicode('abc') + False + >>> is_subscriptable_in_unicode('123') + True + + """ + return all(character in sub for character in subscript) + + +def center_pad(wstring, wtarget, fillchar=' '): + """ + Return the padding strings necessary to center a string of + wstring characters wide in a wtarget wide space. + + The line_width wstring should always be less or equal to wtarget + or else a ValueError will be raised. + """ + if wstring > wtarget: + raise ValueError('not enough space for string') + wdelta = wtarget - wstring + + wleft = wdelta // 2 # favor left '1 ' + wright = wdelta - wleft + + left = fillchar * wleft + right = fillchar * wright + + return left, right + + +def center(string, width, fillchar=' '): + """Return a centered string of length determined by `line_width` + that uses `fillchar` for padding. + """ + left, right = center_pad(line_width(string), width, fillchar) + return ''.join([left, string, right]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/stringpict.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/stringpict.py new file mode 100644 index 0000000000000000000000000000000000000000..b6055f09c83b2abbe0c492991aaee4dff5b34f49 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/stringpict.py @@ -0,0 +1,537 @@ +"""Prettyprinter by Jurjen Bos. +(I hate spammers: mail me at pietjepuk314 at the reverse of ku.oc.oohay). +All objects have a method that create a "stringPict", +that can be used in the str method for pretty printing. + +Updates by Jason Gedge (email at cs mun ca) + - terminal_string() method + - minor fixes and changes (mostly to prettyForm) + +TODO: + - Allow left/center/right alignment options for above/below and + top/center/bottom alignment options for left/right +""" + +import shutil + +from .pretty_symbology import hobj, vobj, xsym, xobj, pretty_use_unicode, line_width, center +from sympy.utilities.exceptions import sympy_deprecation_warning + +_GLOBAL_WRAP_LINE = None + +class stringPict: + """An ASCII picture. + The pictures are represented as a list of equal length strings. + """ + #special value for stringPict.below + LINE = 'line' + + def __init__(self, s, baseline=0): + """Initialize from string. + Multiline strings are centered. + """ + self.s = s + #picture is a string that just can be printed + self.picture = stringPict.equalLengths(s.splitlines()) + #baseline is the line number of the "base line" + self.baseline = baseline + self.binding = None + + @staticmethod + def equalLengths(lines): + # empty lines + if not lines: + return [''] + + width = max(line_width(line) for line in lines) + return [center(line, width) for line in lines] + + def height(self): + """The height of the picture in characters.""" + return len(self.picture) + + def width(self): + """The width of the picture in characters.""" + return line_width(self.picture[0]) + + @staticmethod + def next(*args): + """Put a string of stringPicts next to each other. + Returns string, baseline arguments for stringPict. + """ + #convert everything to stringPicts + objects = [] + for arg in args: + if isinstance(arg, str): + arg = stringPict(arg) + objects.append(arg) + + #make a list of pictures, with equal height and baseline + newBaseline = max(obj.baseline for obj in objects) + newHeightBelowBaseline = max( + obj.height() - obj.baseline + for obj in objects) + newHeight = newBaseline + newHeightBelowBaseline + + pictures = [] + for obj in objects: + oneEmptyLine = [' '*obj.width()] + basePadding = newBaseline - obj.baseline + totalPadding = newHeight - obj.height() + pictures.append( + oneEmptyLine * basePadding + + obj.picture + + oneEmptyLine * (totalPadding - basePadding)) + + result = [''.join(lines) for lines in zip(*pictures)] + return '\n'.join(result), newBaseline + + def right(self, *args): + r"""Put pictures next to this one. + Returns string, baseline arguments for stringPict. + (Multiline) strings are allowed, and are given a baseline of 0. + + Examples + ======== + + >>> from sympy.printing.pretty.stringpict import stringPict + >>> print(stringPict("10").right(" + ",stringPict("1\r-\r2",1))[0]) + 1 + 10 + - + 2 + + """ + return stringPict.next(self, *args) + + def left(self, *args): + """Put pictures (left to right) at left. + Returns string, baseline arguments for stringPict. + """ + return stringPict.next(*(args + (self,))) + + @staticmethod + def stack(*args): + """Put pictures on top of each other, + from top to bottom. + Returns string, baseline arguments for stringPict. + The baseline is the baseline of the second picture. + Everything is centered. + Baseline is the baseline of the second picture. + Strings are allowed. + The special value stringPict.LINE is a row of '-' extended to the width. + """ + #convert everything to stringPicts; keep LINE + objects = [] + for arg in args: + if arg is not stringPict.LINE and isinstance(arg, str): + arg = stringPict(arg) + objects.append(arg) + + #compute new width + newWidth = max( + obj.width() + for obj in objects + if obj is not stringPict.LINE) + + lineObj = stringPict(hobj('-', newWidth)) + + #replace LINE with proper lines + for i, obj in enumerate(objects): + if obj is stringPict.LINE: + objects[i] = lineObj + + #stack the pictures, and center the result + newPicture = [center(line, newWidth) for obj in objects for line in obj.picture] + newBaseline = objects[0].height() + objects[1].baseline + return '\n'.join(newPicture), newBaseline + + def below(self, *args): + """Put pictures under this picture. + Returns string, baseline arguments for stringPict. + Baseline is baseline of top picture + + Examples + ======== + + >>> from sympy.printing.pretty.stringpict import stringPict + >>> print(stringPict("x+3").below( + ... stringPict.LINE, '3')[0]) #doctest: +NORMALIZE_WHITESPACE + x+3 + --- + 3 + + """ + s, baseline = stringPict.stack(self, *args) + return s, self.baseline + + def above(self, *args): + """Put pictures above this picture. + Returns string, baseline arguments for stringPict. + Baseline is baseline of bottom picture. + """ + string, baseline = stringPict.stack(*(args + (self,))) + baseline = len(string.splitlines()) - self.height() + self.baseline + return string, baseline + + def parens(self, left='(', right=')', ifascii_nougly=False): + """Put parentheses around self. + Returns string, baseline arguments for stringPict. + + left or right can be None or empty string which means 'no paren from + that side' + """ + h = self.height() + b = self.baseline + + # XXX this is a hack -- ascii parens are ugly! + if ifascii_nougly and not pretty_use_unicode(): + h = 1 + b = 0 + + res = self + + if left: + lparen = stringPict(vobj(left, h), baseline=b) + res = stringPict(*lparen.right(self)) + if right: + rparen = stringPict(vobj(right, h), baseline=b) + res = stringPict(*res.right(rparen)) + + return ('\n'.join(res.picture), res.baseline) + + def leftslash(self): + """Precede object by a slash of the proper size. + """ + # XXX not used anywhere ? + height = max( + self.baseline, + self.height() - 1 - self.baseline)*2 + 1 + slash = '\n'.join( + ' '*(height - i - 1) + xobj('/', 1) + ' '*i + for i in range(height) + ) + return self.left(stringPict(slash, height//2)) + + def root(self, n=None): + """Produce a nice root symbol. + Produces ugly results for big n inserts. + """ + # XXX not used anywhere + # XXX duplicate of root drawing in pretty.py + #put line over expression + result = self.above('_'*self.width()) + #construct right half of root symbol + height = self.height() + slash = '\n'.join( + ' ' * (height - i - 1) + '/' + ' ' * i + for i in range(height) + ) + slash = stringPict(slash, height - 1) + #left half of root symbol + if height > 2: + downline = stringPict('\\ \n \\', 1) + else: + downline = stringPict('\\') + #put n on top, as low as possible + if n is not None and n.width() > downline.width(): + downline = downline.left(' '*(n.width() - downline.width())) + downline = downline.above(n) + #build root symbol + root = downline.right(slash) + #glue it on at the proper height + #normally, the root symbel is as high as self + #which is one less than result + #this moves the root symbol one down + #if the root became higher, the baseline has to grow too + root.baseline = result.baseline - result.height() + root.height() + return result.left(root) + + def render(self, * args, **kwargs): + """Return the string form of self. + + Unless the argument line_break is set to False, it will + break the expression in a form that can be printed + on the terminal without being broken up. + """ + if _GLOBAL_WRAP_LINE is not None: + kwargs["wrap_line"] = _GLOBAL_WRAP_LINE + + if kwargs["wrap_line"] is False: + return "\n".join(self.picture) + + if kwargs["num_columns"] is not None: + # Read the argument num_columns if it is not None + ncols = kwargs["num_columns"] + else: + # Attempt to get a terminal width + ncols = self.terminal_width() + + if ncols <= 0: + ncols = 80 + + # If smaller than the terminal width, no need to correct + if self.width() <= ncols: + return type(self.picture[0])(self) + + """ + Break long-lines in a visually pleasing format. + without overflow indicators | with overflow indicators + | 2 2 3 | | 2 2 3 ↪| + |6*x *y + 4*x*y + | |6*x *y + 4*x*y + ↪| + | | | | + | 3 4 4 | |↪ 3 4 4 | + |4*y*x + x + y | |↪ 4*y*x + x + y | + |a*c*e + a*c*f + a*d | |a*c*e + a*c*f + a*d ↪| + |*e + a*d*f + b*c*e | | | + |+ b*c*f + b*d*e + b | |↪ *e + a*d*f + b*c* ↪| + |*d*f | | | + | | |↪ e + b*c*f + b*d*e ↪| + | | | | + | | |↪ + b*d*f | + """ + + overflow_first = "" + if kwargs["use_unicode"] or pretty_use_unicode(): + overflow_start = "\N{RIGHTWARDS ARROW WITH HOOK} " + overflow_end = " \N{RIGHTWARDS ARROW WITH HOOK}" + else: + overflow_start = "> " + overflow_end = " >" + + def chunks(line): + """Yields consecutive chunks of line_width ncols""" + prefix = overflow_first + width, start = line_width(prefix + overflow_end), 0 + for i, x in enumerate(line): + wx = line_width(x) + # Only flush the screen when the current character overflows. + # This way, combining marks can be appended even when width == ncols. + if width + wx > ncols: + yield prefix + line[start:i] + overflow_end + prefix = overflow_start + width, start = line_width(prefix + overflow_end), i + width += wx + yield prefix + line[start:] + + # Concurrently assemble chunks of all lines into individual screens + pictures = zip(*map(chunks, self.picture)) + + # Join lines of each screen into sub-pictures + pictures = ["\n".join(picture) for picture in pictures] + + # Add spacers between sub-pictures + return "\n\n".join(pictures) + + def terminal_width(self): + """Return the terminal width if possible, otherwise return 0. + """ + size = shutil.get_terminal_size(fallback=(0, 0)) + return size.columns + + def __eq__(self, o): + if isinstance(o, str): + return '\n'.join(self.picture) == o + elif isinstance(o, stringPict): + return o.picture == self.picture + return False + + def __hash__(self): + return super().__hash__() + + def __str__(self): + return '\n'.join(self.picture) + + def __repr__(self): + return "stringPict(%r,%d)" % ('\n'.join(self.picture), self.baseline) + + def __getitem__(self, index): + return self.picture[index] + + def __len__(self): + return len(self.s) + + +class prettyForm(stringPict): + """ + Extension of the stringPict class that knows about basic math applications, + optimizing double minus signs. + + "Binding" is interpreted as follows:: + + ATOM this is an atom: never needs to be parenthesized + FUNC this is a function application: parenthesize if added (?) + DIV this is a division: make wider division if divided + POW this is a power: only parenthesize if exponent + MUL this is a multiplication: parenthesize if powered + ADD this is an addition: parenthesize if multiplied or powered + NEG this is a negative number: optimize if added, parenthesize if + multiplied or powered + OPEN this is an open object: parenthesize if added, multiplied, or + powered (example: Piecewise) + """ + ATOM, FUNC, DIV, POW, MUL, ADD, NEG, OPEN = range(8) + + def __init__(self, s, baseline=0, binding=0, unicode=None): + """Initialize from stringPict and binding power.""" + stringPict.__init__(self, s, baseline) + self.binding = binding + if unicode is not None: + sympy_deprecation_warning( + """ + The unicode argument to prettyForm is deprecated. Only the s + argument (the first positional argument) should be passed. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions") + self._unicode = unicode or s + + @property + def unicode(self): + sympy_deprecation_warning( + """ + The prettyForm.unicode attribute is deprecated. Use the + prettyForm.s attribute instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions") + return self._unicode + + # Note: code to handle subtraction is in _print_Add + + def __add__(self, *others): + """Make a pretty addition. + Addition of negative numbers is simplified. + """ + arg = self + if arg.binding > prettyForm.NEG: + arg = stringPict(*arg.parens()) + result = [arg] + for arg in others: + #add parentheses for weak binders + if arg.binding > prettyForm.NEG: + arg = stringPict(*arg.parens()) + #use existing minus sign if available + if arg.binding != prettyForm.NEG: + result.append(' + ') + result.append(arg) + return prettyForm(binding=prettyForm.ADD, *stringPict.next(*result)) + + def __truediv__(self, den, slashed=False): + """Make a pretty division; stacked or slashed. + """ + if slashed: + raise NotImplementedError("Can't do slashed fraction yet") + num = self + if num.binding == prettyForm.DIV: + num = stringPict(*num.parens()) + if den.binding == prettyForm.DIV: + den = stringPict(*den.parens()) + + if num.binding==prettyForm.NEG: + num = num.right(" ")[0] + + return prettyForm(binding=prettyForm.DIV, *stringPict.stack( + num, + stringPict.LINE, + den)) + + def __mul__(self, *others): + """Make a pretty multiplication. + Parentheses are needed around +, - and neg. + """ + quantity = { + 'degree': "\N{DEGREE SIGN}" + } + + if len(others) == 0: + return self # We aren't actually multiplying... So nothing to do here. + + # add parens on args that need them + arg = self + if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG: + arg = stringPict(*arg.parens()) + result = [arg] + for arg in others: + if arg.picture[0] not in quantity.values(): + result.append(xsym('*')) + #add parentheses for weak binders + if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG: + arg = stringPict(*arg.parens()) + result.append(arg) + + len_res = len(result) + for i in range(len_res): + if i < len_res - 1 and result[i] == '-1' and result[i + 1] == xsym('*'): + # substitute -1 by -, like in -1*x -> -x + result.pop(i) + result.pop(i) + result.insert(i, '-') + if result[0][0] == '-': + # if there is a - sign in front of all + # This test was failing to catch a prettyForm.__mul__(prettyForm("-1", 0, 6)) being negative + bin = prettyForm.NEG + if result[0] == '-': + right = result[1] + if right.picture[right.baseline][0] == '-': + result[0] = '- ' + else: + bin = prettyForm.MUL + return prettyForm(binding=bin, *stringPict.next(*result)) + + def __repr__(self): + return "prettyForm(%r,%d,%d)" % ( + '\n'.join(self.picture), + self.baseline, + self.binding) + + def __pow__(self, b): + """Make a pretty power. + """ + a = self + use_inline_func_form = False + if b.binding == prettyForm.POW: + b = stringPict(*b.parens()) + if a.binding > prettyForm.FUNC: + a = stringPict(*a.parens()) + elif a.binding == prettyForm.FUNC: + # heuristic for when to use inline power + if b.height() > 1: + a = stringPict(*a.parens()) + else: + use_inline_func_form = True + + if use_inline_func_form: + # 2 + # sin + + (x) + b.baseline = a.prettyFunc.baseline + b.height() + func = stringPict(*a.prettyFunc.right(b)) + return prettyForm(*func.right(a.prettyArgs)) + else: + # 2 <-- top + # (x+y) <-- bot + top = stringPict(*b.left(' '*a.width())) + bot = stringPict(*a.right(' '*b.width())) + + return prettyForm(binding=prettyForm.POW, *bot.above(top)) + + simpleFunctions = ["sin", "cos", "tan"] + + @staticmethod + def apply(function, *args): + """Functions of one or more variables. + """ + if function in prettyForm.simpleFunctions: + #simple function: use only space if possible + assert len( + args) == 1, "Simple function %s must have 1 argument" % function + arg = args[0].__pretty__() + if arg.binding <= prettyForm.DIV: + #optimization: no parentheses necessary + return prettyForm(binding=prettyForm.FUNC, *arg.left(function + ' ')) + argumentList = [] + for arg in args: + argumentList.append(',') + argumentList.append(arg.__pretty__()) + argumentList = stringPict(*stringPict.next(*argumentList[1:])) + argumentList = stringPict(*argumentList.parens()) + return prettyForm(binding=prettyForm.ATOM, *argumentList.left(function)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2ee4d1c7e7390e45b616f7bee2b31d7cfb8a01e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/tests/test_pretty.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/tests/test_pretty.py new file mode 100644 index 0000000000000000000000000000000000000000..1cca79bd1dc5c3ba81483c8fe2e87c35926d1b94 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/pretty/tests/test_pretty.py @@ -0,0 +1,7972 @@ +# -*- coding: utf-8 -*- +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.function import (Derivative, Function, Lambda, Subs) +from sympy.core.mul import Mul +from sympy.core import (EulerGamma, GoldenRatio, Catalan) +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import LambertW +from sympy.functions.special.bessel import (airyai, airyaiprime, airybi, airybiprime) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.error_functions import (fresnelc, fresnels) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.zeta_functions import dirichlet_eta +from sympy.geometry.line import (Ray, Segment) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, Equivalent, ITE, Implies, Nand, Nor, Not, Or, Xor) +from sympy.matrices.dense import (Matrix, diag) +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions.trace import Trace +from sympy.polys.domains.finitefield import FF +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.realfield import RR +from sympy.polys.orderings import (grlex, ilex) +from sympy.polys.polytools import groebner +from sympy.polys.rootoftools import (RootSum, rootof) +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.series.limits import Limit +from sympy.series.order import O +from sympy.series.sequences import (SeqAdd, SeqFormula, SeqMul, SeqPer) +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Complement, FiniteSet, Intersection, Interval, Union) +from sympy.codegen.ast import (Assignment, AddAugmentedAssignment, + SubAugmentedAssignment, MulAugmentedAssignment, DivAugmentedAssignment, ModAugmentedAssignment) +from sympy.core.expr import UnevaluatedExpr +from sympy.physics.quantum.trace import Tr + +from sympy.functions import (Abs, Chi, Ci, Ei, KroneckerDelta, + Piecewise, Shi, Si, atan2, beta, binomial, catalan, ceiling, cos, + euler, exp, expint, factorial, factorial2, floor, gamma, hyper, log, + meijerg, sin, sqrt, subfactorial, tan, uppergamma, lerchphi, polylog, + elliptic_k, elliptic_f, elliptic_e, elliptic_pi, DiracDelta, bell, + bernoulli, fibonacci, tribonacci, lucas, stieltjes, mathieuc, mathieus, + mathieusprime, mathieucprime) + +from sympy.matrices import (Adjoint, Inverse, MatrixSymbol, Transpose, + KroneckerProduct, BlockMatrix, OneMatrix, ZeroMatrix) +from sympy.matrices.expressions import hadamard_power + +from sympy.physics import mechanics +from sympy.physics.control.lti import (TransferFunction, Feedback, TransferFunctionMatrix, + Series, Parallel, MIMOSeries, MIMOParallel, MIMOFeedback, StateSpace) +from sympy.physics.units import joule, degree +from sympy.printing.pretty import pprint, pretty as xpretty +from sympy.printing.pretty.pretty_symbology import center_accent, is_combining, center +from sympy.sets.conditionset import ConditionSet + +from sympy.sets import ImageSet, ProductSet +from sympy.sets.setexpr import SetExpr +from sympy.stats.crv_types import Normal +from sympy.stats.symbolic_probability import (Covariance, Expectation, + Probability, Variance) +from sympy.tensor.array import (ImmutableDenseNDimArray, ImmutableSparseNDimArray, + MutableDenseNDimArray, MutableSparseNDimArray, tensorproduct) +from sympy.tensor.functions import TensorProduct +from sympy.tensor.tensor import (TensorIndexType, tensor_indices, TensorHead, + TensorElement, tensor_heads) + +from sympy.testing.pytest import raises, _both_exp_pow, warns_deprecated_sympy + +from sympy.vector import CoordSys3D, Gradient, Curl, Divergence, Dot, Cross, Laplacian + + + +import sympy as sym +class lowergamma(sym.lowergamma): + pass # testing notation inheritance by a subclass with same name + +a, b, c, d, x, y, z, k, n, s, p = symbols('a,b,c,d,x,y,z,k,n,s,p') +f = Function("f") +th = Symbol('theta') +ph = Symbol('phi') + +""" +Expressions whose pretty-printing is tested here: +(A '#' to the right of an expression indicates that its various acceptable +orderings are accounted for by the tests.) + + +BASIC EXPRESSIONS: + +oo +(x**2) +1/x +y*x**-2 +x**Rational(-5,2) +(-2)**x +Pow(3, 1, evaluate=False) +(x**2 + x + 1) # +1-x # +1-2*x # +x/y +-x/y +(x+2)/y # +(1+x)*y #3 +-5*x/(x+10) # correct placement of negative sign +1 - Rational(3,2)*(x+1) +-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5) # issue 5524 + + +ORDERING: + +x**2 + x + 1 +1 - x +1 - 2*x +2*x**4 + y**2 - x**2 + y**3 + + +RELATIONAL: + +Eq(x, y) +Lt(x, y) +Gt(x, y) +Le(x, y) +Ge(x, y) +Ne(x/(y+1), y**2) # + + +RATIONAL NUMBERS: + +y*x**-2 +y**Rational(3,2) * x**Rational(-5,2) +sin(x)**3/tan(x)**2 + + +FUNCTIONS (ABS, CONJ, EXP, FUNCTION BRACES, FACTORIAL, FLOOR, CEILING): + +(2*x + exp(x)) # +Abs(x) +Abs(x/(x**2+1)) # +Abs(1 / (y - Abs(x))) +factorial(n) +factorial(2*n) +subfactorial(n) +subfactorial(2*n) +factorial(factorial(factorial(n))) +factorial(n+1) # +conjugate(x) +conjugate(f(x+1)) # +f(x) +f(x, y) +f(x/(y+1), y) # +f(x**x**x**x**x**x) +sin(x)**2 +conjugate(a+b*I) +conjugate(exp(a+b*I)) +conjugate( f(1 + conjugate(f(x))) ) # +f(x/(y+1), y) # denom of first arg +floor(1 / (y - floor(x))) +ceiling(1 / (y - ceiling(x))) + + +SQRT: + +sqrt(2) +2**Rational(1,3) +2**Rational(1,1000) +sqrt(x**2 + 1) +(1 + sqrt(5))**Rational(1,3) +2**(1/x) +sqrt(2+pi) +(2+(1+x**2)/(2+x))**Rational(1,4)+(1+x**Rational(1,1000))/sqrt(3+x**2) + + +DERIVATIVES: + +Derivative(log(x), x, evaluate=False) +Derivative(log(x), x, evaluate=False) + x # +Derivative(log(x) + x**2, x, y, evaluate=False) +Derivative(2*x*y, y, x, evaluate=False) + x**2 # +beta(alpha).diff(alpha) + + +INTEGRALS: + +Integral(log(x), x) +Integral(x**2, x) +Integral((sin(x))**2 / (tan(x))**2) +Integral(x**(2**x), x) +Integral(x**2, (x,1,2)) +Integral(x**2, (x,Rational(1,2),10)) +Integral(x**2*y**2, x,y) +Integral(x**2, (x, None, 1)) +Integral(x**2, (x, 1, None)) +Integral(sin(th)/cos(ph), (th,0,pi), (ph, 0, 2*pi)) + + +MATRICES: + +Matrix([[x**2+1, 1], [y, x+y]]) # +Matrix([[x/y, y, th], [0, exp(I*k*ph), 1]]) + + +PIECEWISE: + +Piecewise((x,x<1),(x**2,True)) + +ITE: + +ITE(x, y, z) + +SEQUENCES (TUPLES, LISTS, DICTIONARIES): + +() +[] +{} +(1/x,) +[x**2, 1/x, x, y, sin(th)**2/cos(ph)**2] +(x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) +{x: sin(x)} +{1/x: 1/y, x: sin(x)**2} # +[x**2] +(x**2,) +{x**2: 1} + + +LIMITS: + +Limit(x, x, oo) +Limit(x**2, x, 0) +Limit(1/x, x, 0) +Limit(sin(x)/x, x, 0) + + +UNITS: + +joule => kg*m**2/s + + +SUBS: + +Subs(f(x), x, ph**2) +Subs(f(x).diff(x), x, 0) +Subs(f(x).diff(x)/y, (x, y), (0, Rational(1, 2))) + + +ORDER: + +O(1) +O(1/x) +O(x**2 + y**2) + +""" + + +def pretty(expr, order=None): + """ASCII pretty-printing""" + return xpretty(expr, order=order, use_unicode=False, wrap_line=False) + + +def upretty(expr, order=None): + """Unicode pretty-printing""" + return xpretty(expr, order=order, use_unicode=True, wrap_line=False) + + +def test_pretty_ascii_str(): + assert pretty( 'xxx' ) == 'xxx' + assert pretty( "xxx" ) == 'xxx' + assert pretty( 'xxx\'xxx' ) == 'xxx\'xxx' + assert pretty( 'xxx"xxx' ) == 'xxx\"xxx' + assert pretty( 'xxx\"xxx' ) == 'xxx\"xxx' + assert pretty( "xxx'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\"xxx" ) == 'xxx\"xxx' + assert pretty( "xxx\"xxx\'xxx" ) == 'xxx"xxx\'xxx' + assert pretty( "xxx\nxxx" ) == 'xxx\nxxx' + + +def test_pretty_unicode_str(): + assert pretty( 'xxx' ) == 'xxx' + assert pretty( 'xxx' ) == 'xxx' + assert pretty( 'xxx\'xxx' ) == 'xxx\'xxx' + assert pretty( 'xxx"xxx' ) == 'xxx\"xxx' + assert pretty( 'xxx\"xxx' ) == 'xxx\"xxx' + assert pretty( "xxx'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\"xxx" ) == 'xxx\"xxx' + assert pretty( "xxx\"xxx\'xxx" ) == 'xxx"xxx\'xxx' + assert pretty( "xxx\nxxx" ) == 'xxx\nxxx' + + +def test_upretty_greek(): + assert upretty( oo ) == '∞' + assert upretty( Symbol('alpha^+_1') ) == 'α⁺₁' + assert upretty( Symbol('beta') ) == 'β' + assert upretty(Symbol('lambda')) == 'λ' + + +def test_upretty_multiindex(): + assert upretty( Symbol('beta12') ) == 'β₁₂' + assert upretty( Symbol('Y00') ) == 'Y₀₀' + assert upretty( Symbol('Y_00') ) == 'Y₀₀' + assert upretty( Symbol('F^+-') ) == 'F⁺⁻' + + +def test_upretty_sub_super(): + assert upretty( Symbol('beta_1_2') ) == 'β₁ ₂' + assert upretty( Symbol('beta^1^2') ) == 'β¹ ²' + assert upretty( Symbol('beta_1^2') ) == 'β²₁' + assert upretty( Symbol('beta_10_20') ) == 'β₁₀ ₂₀' + assert upretty( Symbol('beta_ax_gamma^i') ) == 'βⁱₐₓ ᵧ' + assert upretty( Symbol("F^1^2_3_4") ) == 'F¹ ²₃ ₄' + assert upretty( Symbol("F_1_2^3^4") ) == 'F³ ⁴₁ ₂' + assert upretty( Symbol("F_1_2_3_4") ) == 'F₁ ₂ ₃ ₄' + assert upretty( Symbol("F^1^2^3^4") ) == 'F¹ ² ³ ⁴' + + +def test_upretty_subs_missing_in_24(): + assert upretty( Symbol('F_beta') ) == 'Fᵦ' + assert upretty( Symbol('F_gamma') ) == 'Fᵧ' + assert upretty( Symbol('F_rho') ) == 'Fᵨ' + assert upretty( Symbol('F_phi') ) == 'Fᵩ' + assert upretty( Symbol('F_chi') ) == 'Fᵪ' + + assert upretty( Symbol('F_a') ) == 'Fₐ' + assert upretty( Symbol('F_e') ) == 'Fₑ' + assert upretty( Symbol('F_i') ) == 'Fᵢ' + assert upretty( Symbol('F_o') ) == 'Fₒ' + assert upretty( Symbol('F_u') ) == 'Fᵤ' + assert upretty( Symbol('F_r') ) == 'Fᵣ' + assert upretty( Symbol('F_v') ) == 'Fᵥ' + assert upretty( Symbol('F_x') ) == 'Fₓ' + + +def test_missing_in_2X_issue_9047(): + assert upretty( Symbol('F_h') ) == 'Fₕ' + assert upretty( Symbol('F_k') ) == 'Fₖ' + assert upretty( Symbol('F_l') ) == 'Fₗ' + assert upretty( Symbol('F_m') ) == 'Fₘ' + assert upretty( Symbol('F_n') ) == 'Fₙ' + assert upretty( Symbol('F_p') ) == 'Fₚ' + assert upretty( Symbol('F_s') ) == 'Fₛ' + assert upretty( Symbol('F_t') ) == 'Fₜ' + + +def test_upretty_modifiers(): + # Accents + assert upretty( Symbol('Fmathring') ) == 'F̊' + assert upretty( Symbol('Fddddot') ) == 'F⃜' + assert upretty( Symbol('Fdddot') ) == 'F⃛' + assert upretty( Symbol('Fddot') ) == 'F̈' + assert upretty( Symbol('Fdot') ) == 'Ḟ' + assert upretty( Symbol('Fcheck') ) == 'F̌' + assert upretty( Symbol('Fbreve') ) == 'F̆' + assert upretty( Symbol('Facute') ) == 'F́' + assert upretty( Symbol('Fgrave') ) == 'F̀' + assert upretty( Symbol('Ftilde') ) == 'F̃' + assert upretty( Symbol('Fhat') ) == 'F̂' + assert upretty( Symbol('Fbar') ) == 'F̅' + assert upretty( Symbol('Fvec') ) == 'F⃗' + assert upretty( Symbol('Fprime') ) == 'F′' + assert upretty( Symbol('Fprm') ) == 'F′' + # No faces are actually implemented, but test to make sure the modifiers are stripped + assert upretty( Symbol('Fbold') ) == 'Fbold' + assert upretty( Symbol('Fbm') ) == 'Fbm' + assert upretty( Symbol('Fcal') ) == 'Fcal' + assert upretty( Symbol('Fscr') ) == 'Fscr' + assert upretty( Symbol('Ffrak') ) == 'Ffrak' + # Brackets + assert upretty( Symbol('Fnorm') ) == '‖F‖' + assert upretty( Symbol('Favg') ) == '⟨F⟩' + assert upretty( Symbol('Fabs') ) == '|F|' + assert upretty( Symbol('Fmag') ) == '|F|' + # Combinations + assert upretty( Symbol('xvecdot') ) == 'x⃗̇' + assert upretty( Symbol('xDotVec') ) == 'ẋ⃗' + assert upretty( Symbol('xHATNorm') ) == '‖x̂‖' + assert upretty( Symbol('xMathring_yCheckPRM__zbreveAbs') ) == 'x̊_y̌′__|z̆|' + assert upretty( Symbol('alphadothat_nVECDOT__tTildePrime') ) == 'α̇̂_n⃗̇__t̃′' + assert upretty( Symbol('x_dot') ) == 'x_dot' + assert upretty( Symbol('x__dot') ) == 'x__dot' + + +def test_pretty_Cycle(): + from sympy.combinatorics.permutations import Cycle + assert pretty(Cycle(1, 2)) == '(1 2)' + assert pretty(Cycle(2)) == '(2)' + assert pretty(Cycle(1, 3)(4, 5)) == '(1 3)(4 5)' + assert pretty(Cycle()) == '()' + + +def test_pretty_Permutation(): + from sympy.combinatorics.permutations import Permutation + p1 = Permutation(1, 2)(3, 4) + assert xpretty(p1, perm_cyclic=True, use_unicode=True) == "(1 2)(3 4)" + assert xpretty(p1, perm_cyclic=True, use_unicode=False) == "(1 2)(3 4)" + assert xpretty(p1, perm_cyclic=False, use_unicode=True) == \ + '⎛0 1 2 3 4⎞\n'\ + '⎝0 2 1 4 3⎠' + assert xpretty(p1, perm_cyclic=False, use_unicode=False) == \ + "/0 1 2 3 4\\\n"\ + "\\0 2 1 4 3/" + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert xpretty(p1, use_unicode=True) == \ + '⎛0 1 2 3 4⎞\n'\ + '⎝0 2 1 4 3⎠' + assert xpretty(p1, use_unicode=False) == \ + "/0 1 2 3 4\\\n"\ + "\\0 2 1 4 3/" + Permutation.print_cyclic = old_print_cyclic + + +def test_pretty_basic(): + assert pretty( -Rational(1)/2 ) == '-1/2' + assert pretty( -Rational(13)/22 ) == \ +"""\ +-13 \n\ +----\n\ + 22 \ +""" + expr = oo + ascii_str = \ +"""\ +oo\ +""" + ucode_str = \ +"""\ +∞\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2) + ascii_str = \ +"""\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 1/x + ascii_str = \ +"""\ +1\n\ +-\n\ +x\ +""" + ucode_str = \ +"""\ +1\n\ +─\n\ +x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # not the same as 1/x + expr = x**-1.0 + ascii_str = \ +"""\ + -1.0\n\ +x \ +""" + ucode_str = \ +"""\ + -1.0\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # see issue #2860 + expr = Pow(S(2), -1.0, evaluate=False) + ascii_str = \ +"""\ + -1.0\n\ +2 \ +""" + ucode_str = \ +"""\ + -1.0\n\ +2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = y*x**-2 + ascii_str = \ +"""\ +y \n\ +--\n\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ +y \n\ +──\n\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + #see issue #14033 + expr = x**Rational(1, 3) + ascii_str = \ +"""\ + 1/3\n\ +x \ +""" + ucode_str = \ +"""\ + 1/3\n\ +x \ +""" + assert xpretty(expr, use_unicode=False, wrap_line=False,\ + root_notation = False) == ascii_str + assert xpretty(expr, use_unicode=True, wrap_line=False,\ + root_notation = False) == ucode_str + + expr = x**Rational(-5, 2) + ascii_str = \ +"""\ + 1 \n\ +----\n\ + 5/2\n\ +x \ +""" + ucode_str = \ +"""\ + 1 \n\ +────\n\ + 5/2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (-2)**x + ascii_str = \ +"""\ + x\n\ +(-2) \ +""" + ucode_str = \ +"""\ + x\n\ +(-2) \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # See issue 4923 + expr = Pow(3, 1, evaluate=False) + ascii_str = \ +"""\ + 1\n\ +3 \ +""" + ucode_str = \ +"""\ + 1\n\ +3 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2 + x + 1) + ascii_str_1 = \ +"""\ + 2\n\ +1 + x + x \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ +x + x + 1\ +""" + ascii_str_3 = \ +"""\ + 2 \n\ +x + 1 + x\ +""" + ucode_str_1 = \ +"""\ + 2\n\ +1 + x + x \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ +x + x + 1\ +""" + ucode_str_3 = \ +"""\ + 2 \n\ +x + 1 + x\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = 1 - x + ascii_str_1 = \ +"""\ +1 - x\ +""" + ascii_str_2 = \ +"""\ +-x + 1\ +""" + ucode_str_1 = \ +"""\ +1 - x\ +""" + ucode_str_2 = \ +"""\ +-x + 1\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = 1 - 2*x + ascii_str_1 = \ +"""\ +1 - 2*x\ +""" + ascii_str_2 = \ +"""\ +-2*x + 1\ +""" + ucode_str_1 = \ +"""\ +1 - 2⋅x\ +""" + ucode_str_2 = \ +"""\ +-2⋅x + 1\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = x/y + ascii_str = \ +"""\ +x\n\ +-\n\ +y\ +""" + ucode_str = \ +"""\ +x\n\ +─\n\ +y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -x/y + ascii_str = \ +"""\ +-x \n\ +---\n\ + y \ +""" + ucode_str = \ +"""\ +-x \n\ +───\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x + 2)/y + ascii_str_1 = \ +"""\ +2 + x\n\ +-----\n\ + y \ +""" + ascii_str_2 = \ +"""\ +x + 2\n\ +-----\n\ + y \ +""" + ucode_str_1 = \ +"""\ +2 + x\n\ +─────\n\ + y \ +""" + ucode_str_2 = \ +"""\ +x + 2\n\ +─────\n\ + y \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = (1 + x)*y + ascii_str_1 = \ +"""\ +y*(1 + x)\ +""" + ascii_str_2 = \ +"""\ +(1 + x)*y\ +""" + ascii_str_3 = \ +"""\ +y*(x + 1)\ +""" + ucode_str_1 = \ +"""\ +y⋅(1 + x)\ +""" + ucode_str_2 = \ +"""\ +(1 + x)⋅y\ +""" + ucode_str_3 = \ +"""\ +y⋅(x + 1)\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + # Test for correct placement of the negative sign + expr = -5*x/(x + 10) + ascii_str_1 = \ +"""\ +-5*x \n\ +------\n\ +10 + x\ +""" + ascii_str_2 = \ +"""\ +-5*x \n\ +------\n\ +x + 10\ +""" + ucode_str_1 = \ +"""\ +-5⋅x \n\ +──────\n\ +10 + x\ +""" + ucode_str_2 = \ +"""\ +-5⋅x \n\ +──────\n\ +x + 10\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = -S.Half - 3*x + ascii_str = \ +"""\ +-3*x - 1/2\ +""" + ucode_str = \ +"""\ +-3⋅x - 1/2\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = S.Half - 3*x + ascii_str = \ +"""\ +1/2 - 3*x\ +""" + ucode_str = \ +"""\ +1/2 - 3⋅x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -S.Half - 3*x/2 + ascii_str = \ +"""\ + 3*x 1\n\ +- --- - -\n\ + 2 2\ +""" + ucode_str = \ +"""\ + 3⋅x 1\n\ +- ─── - ─\n\ + 2 2\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = S.Half - 3*x/2 + ascii_str = \ +"""\ +1 3*x\n\ +- - ---\n\ +2 2 \ +""" + ucode_str = \ +"""\ +1 3⋅x\n\ +─ - ───\n\ +2 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_negative_fractions(): + expr = -x/y + ascii_str =\ +"""\ +-x \n\ +---\n\ + y \ +""" + ucode_str =\ +"""\ +-x \n\ +───\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x*z/y + ascii_str =\ +"""\ +-x*z \n\ +-----\n\ + y \ +""" + ucode_str =\ +"""\ +-x⋅z \n\ +─────\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = x**2/y + ascii_str =\ +"""\ + 2\n\ +x \n\ +--\n\ +y \ +""" + ucode_str =\ +"""\ + 2\n\ +x \n\ +──\n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x**2/y + ascii_str =\ +"""\ + 2 \n\ +-x \n\ +----\n\ + y \ +""" + ucode_str =\ +"""\ + 2 \n\ +-x \n\ +────\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x/(y*z) + ascii_str =\ +"""\ +-x \n\ +---\n\ +y*z\ +""" + ucode_str =\ +"""\ +-x \n\ +───\n\ +y⋅z\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -a/y**2 + ascii_str =\ +"""\ +-a \n\ +---\n\ + 2 \n\ +y \ +""" + ucode_str =\ +"""\ +-a \n\ +───\n\ + 2 \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = y**(-a/b) + ascii_str =\ +"""\ + -a \n\ + ---\n\ + b \n\ +y \ +""" + ucode_str =\ +"""\ + -a \n\ + ───\n\ + b \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -1/y**2 + ascii_str =\ +"""\ +-1 \n\ +---\n\ + 2 \n\ +y \ +""" + ucode_str =\ +"""\ +-1 \n\ +───\n\ + 2 \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -10/b**2 + ascii_str =\ +"""\ +-10 \n\ +----\n\ + 2 \n\ + b \ +""" + ucode_str =\ +"""\ +-10 \n\ +────\n\ + 2 \n\ + b \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = Rational(-200, 37) + ascii_str =\ +"""\ +-200 \n\ +-----\n\ + 37 \ +""" + ucode_str =\ +"""\ +-200 \n\ +─────\n\ + 37 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_Mul(): + expr = Mul(0, 1, evaluate=False) + assert pretty(expr) == "0*1" + assert upretty(expr) == "0⋅1" + expr = Mul(1, 0, evaluate=False) + assert pretty(expr) == "1*0" + assert upretty(expr) == "1⋅0" + expr = Mul(1, 1, evaluate=False) + assert pretty(expr) == "1*1" + assert upretty(expr) == "1⋅1" + expr = Mul(1, 1, 1, evaluate=False) + assert pretty(expr) == "1*1*1" + assert upretty(expr) == "1⋅1⋅1" + expr = Mul(1, 2, evaluate=False) + assert pretty(expr) == "1*2" + assert upretty(expr) == "1⋅2" + expr = Add(0, 1, evaluate=False) + assert pretty(expr) == "0 + 1" + assert upretty(expr) == "0 + 1" + expr = Mul(1, 1, 2, evaluate=False) + assert pretty(expr) == "1*1*2" + assert upretty(expr) == "1⋅1⋅2" + expr = Add(0, 0, 1, evaluate=False) + assert pretty(expr) == "0 + 0 + 1" + assert upretty(expr) == "0 + 0 + 1" + expr = Mul(1, -1, evaluate=False) + assert pretty(expr) == "1*-1" + assert upretty(expr) == "1⋅-1" + expr = Mul(1.0, x, evaluate=False) + assert pretty(expr) == "1.0*x" + assert upretty(expr) == "1.0⋅x" + expr = Mul(1, 1, 2, 3, x, evaluate=False) + assert pretty(expr) == "1*1*2*3*x" + assert upretty(expr) == "1⋅1⋅2⋅3⋅x" + expr = Mul(-1, 1, evaluate=False) + assert pretty(expr) == "-1*1" + assert upretty(expr) == "-1⋅1" + expr = Mul(4, 3, 2, 1, 0, y, x, evaluate=False) + assert pretty(expr) == "4*3*2*1*0*y*x" + assert upretty(expr) == "4⋅3⋅2⋅1⋅0⋅y⋅x" + expr = Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False) + assert pretty(expr) == "4*3*2*(z + 1)*0*y*x" + assert upretty(expr) == "4⋅3⋅2⋅(z + 1)⋅0⋅y⋅x" + expr = Mul(Rational(2, 3), Rational(5, 7), evaluate=False) + assert pretty(expr) == "2/3*5/7" + assert upretty(expr) == "2/3⋅5/7" + expr = Mul(x + y, Rational(1, 2), evaluate=False) + assert pretty(expr) == "(x + y)*1/2" + assert upretty(expr) == "(x + y)⋅1/2" + expr = Mul(Rational(1, 2), x + y, evaluate=False) + assert pretty(expr) == "x + y\n-----\n 2 " + assert upretty(expr) == "x + y\n─────\n 2 " + expr = Mul(S.One, x + y, evaluate=False) + assert pretty(expr) == "1*(x + y)" + assert upretty(expr) == "1⋅(x + y)" + expr = Mul(x - y, S.One, evaluate=False) + assert pretty(expr) == "(x - y)*1" + assert upretty(expr) == "(x - y)⋅1" + expr = Mul(Rational(1, 2), x - y, S.One, x + y, evaluate=False) + assert pretty(expr) == "1/2*(x - y)*1*(x + y)" + assert upretty(expr) == "1/2⋅(x - y)⋅1⋅(x + y)" + expr = Mul(x + y, Rational(3, 4), S.One, y - z, evaluate=False) + assert pretty(expr) == "(x + y)*3/4*1*(y - z)" + assert upretty(expr) == "(x + y)⋅3/4⋅1⋅(y - z)" + expr = Mul(x + y, Rational(1, 1), Rational(3, 4), Rational(5, 6),evaluate=False) + assert pretty(expr) == "(x + y)*1*3/4*5/6" + assert upretty(expr) == "(x + y)⋅1⋅3/4⋅5/6" + expr = Mul(Rational(3, 4), x + y, S.One, y - z, evaluate=False) + assert pretty(expr) == "3/4*(x + y)*1*(y - z)" + assert upretty(expr) == "3/4⋅(x + y)⋅1⋅(y - z)" + + +def test_issue_5524(): + assert pretty(-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5)) == \ +"""\ + 2 / ___ \\\n\ +- (5 - y) + (x - 5)*\\-x - 2*\\/ 2 + 5/\ +""" + + assert upretty(-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5)) == \ +"""\ + 2 \n\ +- (5 - y) + (x - 5)⋅(-x - 2⋅√2 + 5)\ +""" + + +def test_pretty_ordering(): + assert pretty(x**2 + x + 1, order='lex') == \ +"""\ + 2 \n\ +x + x + 1\ +""" + assert pretty(x**2 + x + 1, order='rev-lex') == \ +"""\ + 2\n\ +1 + x + x \ +""" + assert pretty(1 - x, order='lex') == '-x + 1' + assert pretty(1 - x, order='rev-lex') == '1 - x' + + assert pretty(1 - 2*x, order='lex') == '-2*x + 1' + assert pretty(1 - 2*x, order='rev-lex') == '1 - 2*x' + + f = 2*x**4 + y**2 - x**2 + y**3 + assert pretty(f, order=None) == \ +"""\ + 4 2 3 2\n\ +2*x - x + y + y \ +""" + assert pretty(f, order='lex') == \ +"""\ + 4 2 3 2\n\ +2*x - x + y + y \ +""" + assert pretty(f, order='rev-lex') == \ +"""\ + 2 3 2 4\n\ +y + y - x + 2*x \ +""" + + expr = x - x**3/6 + x**5/120 + O(x**6) + ascii_str = \ +"""\ + 3 5 \n\ + x x / 6\\\n\ +x - -- + --- + O\\x /\n\ + 6 120 \ +""" + ucode_str = \ +"""\ + 3 5 \n\ + x x ⎛ 6⎞\n\ +x - ── + ─── + O⎝x ⎠\n\ + 6 120 \ +""" + assert pretty(expr, order=None) == ascii_str + assert upretty(expr, order=None) == ucode_str + + assert pretty(expr, order='lex') == ascii_str + assert upretty(expr, order='lex') == ucode_str + + assert pretty(expr, order='rev-lex') == ascii_str + assert upretty(expr, order='rev-lex') == ucode_str + + +def test_EulerGamma(): + assert pretty(EulerGamma) == str(EulerGamma) == "EulerGamma" + assert upretty(EulerGamma) == "γ" + + +def test_GoldenRatio(): + assert pretty(GoldenRatio) == str(GoldenRatio) == "GoldenRatio" + assert upretty(GoldenRatio) == "φ" + + +def test_Catalan(): + assert pretty(Catalan) == upretty(Catalan) == "G" + + +def test_pretty_relational(): + expr = Eq(x, y) + ascii_str = \ +"""\ +x = y\ +""" + ucode_str = \ +"""\ +x = y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lt(x, y) + ascii_str = \ +"""\ +x < y\ +""" + ucode_str = \ +"""\ +x < y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Gt(x, y) + ascii_str = \ +"""\ +x > y\ +""" + ucode_str = \ +"""\ +x > y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Le(x, y) + ascii_str = \ +"""\ +x <= y\ +""" + ucode_str = \ +"""\ +x ≤ y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Ge(x, y) + ascii_str = \ +"""\ +x >= y\ +""" + ucode_str = \ +"""\ +x ≥ y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Ne(x/(y + 1), y**2) + ascii_str_1 = \ +"""\ + x 2\n\ +----- != y \n\ +1 + y \ +""" + ascii_str_2 = \ +"""\ + x 2\n\ +----- != y \n\ +y + 1 \ +""" + ucode_str_1 = \ +"""\ + x 2\n\ +───── ≠ y \n\ +1 + y \ +""" + ucode_str_2 = \ +"""\ + x 2\n\ +───── ≠ y \n\ +y + 1 \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + +def test_Assignment(): + expr = Assignment(x, y) + ascii_str = \ +"""\ +x := y\ +""" + ucode_str = \ +"""\ +x := y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_AugmentedAssignment(): + expr = AddAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x += y\ +""" + ucode_str = \ +"""\ +x += y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = SubAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x -= y\ +""" + ucode_str = \ +"""\ +x -= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = MulAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x *= y\ +""" + ucode_str = \ +"""\ +x *= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = DivAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x /= y\ +""" + ucode_str = \ +"""\ +x /= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = ModAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x %= y\ +""" + ucode_str = \ +"""\ +x %= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_rational(): + expr = y*x**-2 + ascii_str = \ +"""\ +y \n\ +--\n\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ +y \n\ +──\n\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = y**Rational(3, 2) * x**Rational(-5, 2) + ascii_str = \ +"""\ + 3/2\n\ +y \n\ +----\n\ + 5/2\n\ +x \ +""" + ucode_str = \ +"""\ + 3/2\n\ +y \n\ +────\n\ + 5/2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sin(x)**3/tan(x)**2 + ascii_str = \ +"""\ + 3 \n\ +sin (x)\n\ +-------\n\ + 2 \n\ +tan (x)\ +""" + ucode_str = \ +"""\ + 3 \n\ +sin (x)\n\ +───────\n\ + 2 \n\ +tan (x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +@_both_exp_pow +def test_pretty_functions(): + """Tests for Abs, conjugate, exp, function braces, and factorial.""" + expr = (2*x + exp(x)) + ascii_str_1 = \ +"""\ + x\n\ +2*x + e \ +""" + ascii_str_2 = \ +"""\ + x \n\ +e + 2*x\ +""" + ucode_str_1 = \ +"""\ + x\n\ +2⋅x + ℯ \ +""" + ucode_str_2 = \ +"""\ + x \n\ +ℯ + 2⋅x\ +""" + ucode_str_3 = \ +"""\ + x \n\ +ℯ + 2⋅x\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Abs(x) + ascii_str = \ +"""\ +|x|\ +""" + ucode_str = \ +"""\ +│x│\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Abs(x/(x**2 + 1)) + ascii_str_1 = \ +"""\ +| x |\n\ +|------|\n\ +| 2|\n\ +|1 + x |\ +""" + ascii_str_2 = \ +"""\ +| x |\n\ +|------|\n\ +| 2 |\n\ +|x + 1|\ +""" + ucode_str_1 = \ +"""\ +│ x │\n\ +│──────│\n\ +│ 2│\n\ +│1 + x │\ +""" + ucode_str_2 = \ +"""\ +│ x │\n\ +│──────│\n\ +│ 2 │\n\ +│x + 1│\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = Abs(1 / (y - Abs(x))) + ascii_str = \ +"""\ + 1 \n\ +---------\n\ +|y - |x||\ +""" + ucode_str = \ +"""\ + 1 \n\ +─────────\n\ +│y - │x││\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + n = Symbol('n', integer=True) + expr = factorial(n) + ascii_str = \ +"""\ +n!\ +""" + ucode_str = \ +"""\ +n!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(2*n) + ascii_str = \ +"""\ +(2*n)!\ +""" + ucode_str = \ +"""\ +(2⋅n)!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(factorial(factorial(n))) + ascii_str = \ +"""\ +((n!)!)!\ +""" + ucode_str = \ +"""\ +((n!)!)!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(n + 1) + ascii_str_1 = \ +"""\ +(1 + n)!\ +""" + ascii_str_2 = \ +"""\ +(n + 1)!\ +""" + ucode_str_1 = \ +"""\ +(1 + n)!\ +""" + ucode_str_2 = \ +"""\ +(n + 1)!\ +""" + + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = subfactorial(n) + ascii_str = \ +"""\ +!n\ +""" + ucode_str = \ +"""\ +!n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = subfactorial(2*n) + ascii_str = \ +"""\ +!(2*n)\ +""" + ucode_str = \ +"""\ +!(2⋅n)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + n = Symbol('n', integer=True) + expr = factorial2(n) + ascii_str = \ +"""\ +n!!\ +""" + ucode_str = \ +"""\ +n!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(2*n) + ascii_str = \ +"""\ +(2*n)!!\ +""" + ucode_str = \ +"""\ +(2⋅n)!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(factorial2(factorial2(n))) + ascii_str = \ +"""\ +((n!!)!!)!!\ +""" + ucode_str = \ +"""\ +((n!!)!!)!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(n + 1) + ascii_str_1 = \ +"""\ +(1 + n)!!\ +""" + ascii_str_2 = \ +"""\ +(n + 1)!!\ +""" + ucode_str_1 = \ +"""\ +(1 + n)!!\ +""" + ucode_str_2 = \ +"""\ +(n + 1)!!\ +""" + + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = 2*binomial(n, k) + ascii_str = \ +"""\ + /n\\\n\ +2*| |\n\ + \\k/\ +""" + ucode_str = \ +"""\ + ⎛n⎞\n\ +2⋅⎜ ⎟\n\ + ⎝k⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*binomial(2*n, k) + ascii_str = \ +"""\ + /2*n\\\n\ +2*| |\n\ + \\ k /\ +""" + ucode_str = \ +"""\ + ⎛2⋅n⎞\n\ +2⋅⎜ ⎟\n\ + ⎝ k ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*binomial(n**2, k) + ascii_str = \ +"""\ + / 2\\\n\ + |n |\n\ +2*| |\n\ + \\k /\ +""" + ucode_str = \ +"""\ + ⎛ 2⎞\n\ + ⎜n ⎟\n\ +2⋅⎜ ⎟\n\ + ⎝k ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = catalan(n) + ascii_str = \ +"""\ +C \n\ + n\ +""" + ucode_str = \ +"""\ +C \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = catalan(n) + ascii_str = \ +"""\ +C \n\ + n\ +""" + ucode_str = \ +"""\ +C \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bell(n) + ascii_str = \ +"""\ +B \n\ + n\ +""" + ucode_str = \ +"""\ +B \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bernoulli(n) + ascii_str = \ +"""\ +B \n\ + n\ +""" + ucode_str = \ +"""\ +B \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bernoulli(n, x) + ascii_str = \ +"""\ +B (x)\n\ + n \ +""" + ucode_str = \ +"""\ +B (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = fibonacci(n) + ascii_str = \ +"""\ +F \n\ + n\ +""" + ucode_str = \ +"""\ +F \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = lucas(n) + ascii_str = \ +"""\ +L \n\ + n\ +""" + ucode_str = \ +"""\ +L \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = tribonacci(n) + ascii_str = \ +"""\ +T \n\ + n\ +""" + ucode_str = \ +"""\ +T \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = stieltjes(n) + ascii_str = \ +"""\ +stieltjes \n\ + n\ +""" + ucode_str = \ +"""\ +γ \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = stieltjes(n, x) + ascii_str = \ +"""\ +stieltjes (x)\n\ + n \ +""" + ucode_str = \ +"""\ +γ (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieuc(x, y, z) + ascii_str = 'C(x, y, z)' + ucode_str = 'C(x, y, z)' + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieus(x, y, z) + ascii_str = 'S(x, y, z)' + ucode_str = 'S(x, y, z)' + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieucprime(x, y, z) + ascii_str = "C'(x, y, z)" + ucode_str = "C'(x, y, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieusprime(x, y, z) + ascii_str = "S'(x, y, z)" + ucode_str = "S'(x, y, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(x) + ascii_str = \ +"""\ +_\n\ +x\ +""" + ucode_str = \ +"""\ +_\n\ +x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + f = Function('f') + expr = conjugate(f(x + 1)) + ascii_str_1 = \ +"""\ +________\n\ +f(1 + x)\ +""" + ascii_str_2 = \ +"""\ +________\n\ +f(x + 1)\ +""" + ucode_str_1 = \ +"""\ +________\n\ +f(1 + x)\ +""" + ucode_str_2 = \ +"""\ +________\n\ +f(x + 1)\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x) + ascii_str = \ +"""\ +f(x)\ +""" + ucode_str = \ +"""\ +f(x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = f(x, y) + ascii_str = \ +"""\ +f(x, y)\ +""" + ucode_str = \ +"""\ +f(x, y)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = f(x/(y + 1), y) + ascii_str_1 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\1 + y /\ +""" + ascii_str_2 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\y + 1 /\ +""" + ucode_str_1 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝1 + y ⎠\ +""" + ucode_str_2 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝y + 1 ⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x**x**x**x**x**x) + ascii_str = \ +"""\ + / / / / / x\\\\\\\\\\ + | | | | \\x /|||| + | | | \\x /||| + | | \\x /|| + | \\x /| +f\\x /\ +""" + ucode_str = \ +"""\ + ⎛ ⎛ ⎛ ⎛ ⎛ x⎞⎞⎞⎞⎞ + ⎜ ⎜ ⎜ ⎜ ⎝x ⎠⎟⎟⎟⎟ + ⎜ ⎜ ⎜ ⎝x ⎠⎟⎟⎟ + ⎜ ⎜ ⎝x ⎠⎟⎟ + ⎜ ⎝x ⎠⎟ +f⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sin(x)**2 + ascii_str = \ +"""\ + 2 \n\ +sin (x)\ +""" + ucode_str = \ +"""\ + 2 \n\ +sin (x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(a + b*I) + ascii_str = \ +"""\ +_ _\n\ +a - I*b\ +""" + ucode_str = \ +"""\ +_ _\n\ +a - ⅈ⋅b\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(exp(a + b*I)) + ascii_str = \ +"""\ + _ _\n\ + a - I*b\n\ +e \ +""" + ucode_str = \ +"""\ + _ _\n\ + a - ⅈ⋅b\n\ +ℯ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate( f(1 + conjugate(f(x))) ) + ascii_str_1 = \ +"""\ +___________\n\ + / ____\\\n\ +f\\1 + f(x)/\ +""" + ascii_str_2 = \ +"""\ +___________\n\ + /____ \\\n\ +f\\f(x) + 1/\ +""" + ucode_str_1 = \ +"""\ +___________\n\ + ⎛ ____⎞\n\ +f⎝1 + f(x)⎠\ +""" + ucode_str_2 = \ +"""\ +___________\n\ + ⎛____ ⎞\n\ +f⎝f(x) + 1⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x/(y + 1), y) + ascii_str_1 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\1 + y /\ +""" + ascii_str_2 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\y + 1 /\ +""" + ucode_str_1 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝1 + y ⎠\ +""" + ucode_str_2 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝y + 1 ⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = floor(1 / (y - floor(x))) + ascii_str = \ +"""\ + / 1 \\\n\ +floor|------------|\n\ + \\y - floor(x)/\ +""" + ucode_str = \ +"""\ +⎢ 1 ⎥\n\ +⎢───────⎥\n\ +⎣y - ⌊x⌋⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = ceiling(1 / (y - ceiling(x))) + ascii_str = \ +"""\ + / 1 \\\n\ +ceiling|--------------|\n\ + \\y - ceiling(x)/\ +""" + ucode_str = \ +"""\ +⎡ 1 ⎤\n\ +⎢───────⎥\n\ +⎢y - ⌈x⌉⎥\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n) + ascii_str = \ +"""\ +E \n\ + n\ +""" + ucode_str = \ +"""\ +E \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(1/(1 + 1/(1 + 1/n))) + ascii_str = \ +"""\ +E \n\ + 1 \n\ + ---------\n\ + 1 \n\ + 1 + -----\n\ + 1\n\ + 1 + -\n\ + n\ +""" + + ucode_str = \ +"""\ +E \n\ + 1 \n\ + ─────────\n\ + 1 \n\ + 1 + ─────\n\ + 1\n\ + 1 + ─\n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n, x) + ascii_str = \ +"""\ +E (x)\n\ + n \ +""" + ucode_str = \ +"""\ +E (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n, x/2) + ascii_str = \ +"""\ + /x\\\n\ +E |-|\n\ + n\\2/\ +""" + ucode_str = \ +"""\ + ⎛x⎞\n\ +E ⎜─⎟\n\ + n⎝2⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_sqrt(): + expr = sqrt(2) + ascii_str = \ +"""\ + ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"√2" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**Rational(1, 3) + ascii_str = \ +"""\ +3 ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"""\ +3 ___\n\ +╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**Rational(1, 1000) + ascii_str = \ +"""\ +1000___\n\ + \\/ 2 \ +""" + ucode_str = \ +"""\ +1000___\n\ + ╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sqrt(x**2 + 1) + ascii_str = \ +"""\ + ________\n\ + / 2 \n\ +\\/ x + 1 \ +""" + ucode_str = \ +"""\ + ________\n\ + ╱ 2 \n\ +╲╱ x + 1 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (1 + sqrt(5))**Rational(1, 3) + ascii_str = \ +"""\ + ___________\n\ +3 / ___ \n\ +\\/ 1 + \\/ 5 \ +""" + ucode_str = \ +"""\ +3 ________\n\ +╲╱ 1 + √5 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**(1/x) + ascii_str = \ +"""\ +x ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"""\ +x ___\n\ +╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sqrt(2 + pi) + ascii_str = \ +"""\ + ________\n\ +\\/ 2 + pi \ +""" + ucode_str = \ +"""\ + _______\n\ +╲╱ 2 + π \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (2 + ( + 1 + x**2)/(2 + x))**Rational(1, 4) + (1 + x**Rational(1, 1000))/sqrt(3 + x**2) + ascii_str = \ +"""\ + ____________ \n\ + / 2 1000___ \n\ + / x + 1 \\/ x + 1\n\ +4 / 2 + ------ + -----------\n\ +\\/ x + 2 ________\n\ + / 2 \n\ + \\/ x + 3 \ +""" + ucode_str = \ +"""\ + ____________ \n\ + ╱ 2 1000___ \n\ + ╱ x + 1 ╲╱ x + 1\n\ +4 ╱ 2 + ────── + ───────────\n\ +╲╱ x + 2 ________\n\ + ╱ 2 \n\ + ╲╱ x + 3 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_sqrt_char_knob(): + # See PR #9234. + expr = sqrt(2) + ucode_str1 = \ +"""\ + ___\n\ +╲╱ 2 \ +""" + ucode_str2 = \ +"√2" + assert xpretty(expr, use_unicode=True, + use_unicode_sqrt_char=False) == ucode_str1 + assert xpretty(expr, use_unicode=True, + use_unicode_sqrt_char=True) == ucode_str2 + + +def test_pretty_sqrt_longsymbol_no_sqrt_char(): + # Do not use unicode sqrt char for long symbols (see PR #9234). + expr = sqrt(Symbol('C1')) + ucode_str = \ +"""\ + ____\n\ +╲╱ C₁ \ +""" + assert upretty(expr) == ucode_str + + +def test_pretty_KroneckerDelta(): + x, y = symbols("x, y") + expr = KroneckerDelta(x, y) + ascii_str = \ +"""\ +d \n\ + x,y\ +""" + ucode_str = \ +"""\ +δ \n\ + x,y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_product(): + n, m, k, l = symbols('n m k l') + f = symbols('f', cls=Function) + expr = Product(f((n/3)**2), (n, k**2, l)) + + unicode_str = \ +"""\ + l \n\ +─┬──────┬─ \n\ + │ │ ⎛ 2⎞\n\ + │ │ ⎜n ⎟\n\ + │ │ f⎜──⎟\n\ + │ │ ⎝9 ⎠\n\ + │ │ \n\ + 2 \n\ + n = k """ + ascii_str = \ +"""\ + l \n\ +__________ \n\ + | | / 2\\\n\ + | | |n |\n\ + | | f|--|\n\ + | | \\9 /\n\ + | | \n\ + 2 \n\ + n = k """ + + expr = Product(f((n/3)**2), (n, k**2, l), (l, 1, m)) + + unicode_str = \ +"""\ + m l \n\ +─┬──────┬─ ─┬──────┬─ \n\ + │ │ │ │ ⎛ 2⎞\n\ + │ │ │ │ ⎜n ⎟\n\ + │ │ │ │ f⎜──⎟\n\ + │ │ │ │ ⎝9 ⎠\n\ + │ │ │ │ \n\ + l = 1 2 \n\ + n = k """ + ascii_str = \ +"""\ + m l \n\ +__________ __________ \n\ + | | | | / 2\\\n\ + | | | | |n |\n\ + | | | | f|--|\n\ + | | | | \\9 /\n\ + | | | | \n\ + l = 1 2 \n\ + n = k """ + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + +def test_pretty_Lambda(): + # S.IdentityFunction is a special case + expr = Lambda(y, y) + assert pretty(expr) == "x -> x" + assert upretty(expr) == "x ↦ x" + + expr = Lambda(x, x+1) + assert pretty(expr) == "x -> x + 1" + assert upretty(expr) == "x ↦ x + 1" + + expr = Lambda(x, x**2) + ascii_str = \ +"""\ + 2\n\ +x -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +x ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda(x, x**2)**2 + ascii_str = \ +"""\ + 2 +/ 2\\ \n\ +\\x -> x / \ +""" + ucode_str = \ +"""\ + 2 +⎛ 2⎞ \n\ +⎝x ↦ x ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda((x, y), x) + ascii_str = "(x, y) -> x" + ucode_str = "(x, y) ↦ x" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda((x, y), x**2) + ascii_str = \ +"""\ + 2\n\ +(x, y) -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +(x, y) ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda(((x, y),), x**2) + ascii_str = \ +"""\ + 2\n\ +((x, y),) -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +((x, y),) ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_TransferFunction(): + tf1 = TransferFunction(s - 1, s + 1, s) + assert upretty(tf1) == "s - 1\n─────\ns + 1" + tf2 = TransferFunction(2*s + 1, 3 - p, s) + assert upretty(tf2) == "2⋅s + 1\n───────\n 3 - p " + tf3 = TransferFunction(p, p + 1, p) + assert upretty(tf3) == " p \n─────\np + 1" + + +def test_pretty_Series(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(x**2 + y, y - x, y) + tf4 = TransferFunction(2, 3, y) + + tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm2 = TransferFunctionMatrix([[tf3], [-tf4]]) + tfm3 = TransferFunctionMatrix([[tf1, -tf2, -tf3], [tf3, -tf4, tf2]]) + tfm4 = TransferFunctionMatrix([[tf1, tf2], [tf3, -tf4], [-tf2, -tf1]]) + tfm5 = TransferFunctionMatrix([[-tf2, -tf1], [tf4, -tf3], [tf1, tf2]]) + + expected1 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y ⎞ ⎜x + y⎟\n\ +⎜───────⎟⋅⎜──────⎟\n\ +⎝x - 2⋅y⎠ ⎝-x + y⎠\ +""" + expected2 = \ +"""\ +⎛-x + y⎞ ⎛-x - y ⎞\n\ +⎜──────⎟⋅⎜───────⎟\n\ +⎝x + y ⎠ ⎝x - 2⋅y⎠\ +""" + expected3 = \ +"""\ +⎛ 2 ⎞ \n\ +⎜x + y⎟ ⎛ x + y ⎞ ⎛-x - y x - y⎞\n\ +⎜──────⎟⋅⎜───────⎟⋅⎜─────── + ─────⎟\n\ +⎝-x + y⎠ ⎝x - 2⋅y⎠ ⎝x - 2⋅y x + y⎠\ +""" + expected4 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y x - y⎞ ⎜x - y x + y⎟\n\ +⎜─────── + ─────⎟⋅⎜───── + ──────⎟\n\ +⎝x - 2⋅y x + y⎠ ⎝x + y -x + y⎠\ +""" + expected5 = \ +"""\ +⎡ x + y x - y⎤ ⎡ 2 ⎤ \n\ +⎢─────── ─────⎥ ⎢x + y⎥ \n\ +⎢x - 2⋅y x + y⎥ ⎢──────⎥ \n\ +⎢ ⎥ ⎢-x + y⎥ \n\ +⎢ 2 ⎥ ⋅⎢ ⎥ \n\ +⎢x + y 2 ⎥ ⎢ -2 ⎥ \n\ +⎢────── ─ ⎥ ⎢ ─── ⎥ \n\ +⎣-x + y 3 ⎦τ ⎣ 3 ⎦τ\ +""" + expected6 = \ +"""\ + ⎛⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎞\n\ + ⎜⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎟\n\ +⎡ x + y x - y⎤ ⎡ 2 ⎤ ⎜⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎟\n\ +⎢─────── ─────⎥ ⎢ x + y -x + y - x - y⎥ ⎜⎢ ⎥ ⎢ ⎥ ⎟\n\ +⎢x - 2⋅y x + y⎥ ⎢─────── ────── ────────⎥ ⎜⎢ 2 ⎥ ⎢ 2 ⎥ ⎟\n\ +⎢ ⎥ ⎢x - 2⋅y x + y -x + y ⎥ ⎜⎢x + y -2 ⎥ ⎢ -2 x + y ⎥ ⎟\n\ +⎢ 2 ⎥ ⋅⎢ ⎥ ⋅⎜⎢────── ─── ⎥ + ⎢ ─── ────── ⎥ ⎟\n\ +⎢x + y 2 ⎥ ⎢ 2 ⎥ ⎜⎢-x + y 3 ⎥ ⎢ 3 -x + y ⎥ ⎟\n\ +⎢────── ─ ⎥ ⎢x + y -2 x - y ⎥ ⎜⎢ ⎥ ⎢ ⎥ ⎟\n\ +⎣-x + y 3 ⎦τ ⎢────── ─── ───── ⎥ ⎜⎢-x + y -x - y ⎥ ⎢-x - y -x + y ⎥ ⎟\n\ + ⎣-x + y 3 x + y ⎦τ ⎜⎢────── ───────⎥ ⎢─────── ────── ⎥ ⎟\n\ + ⎝⎣x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ⎠\ +""" + + assert upretty(Series(tf1, tf3)) == expected1 + assert upretty(Series(-tf2, -tf1)) == expected2 + assert upretty(Series(tf3, tf1, Parallel(-tf1, tf2))) == expected3 + assert upretty(Series(Parallel(tf1, tf2), Parallel(tf2, tf3))) == expected4 + assert upretty(MIMOSeries(tfm2, tfm1)) == expected5 + assert upretty(MIMOSeries(MIMOParallel(tfm4, -tfm5), tfm3, tfm1)) == expected6 + + +def test_pretty_Parallel(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(x**2 + y, y - x, y) + tf4 = TransferFunction(y**2 - x, x**3 + x, y) + + tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf3, -tf4], [-tf2, -tf1]]) + tfm2 = TransferFunctionMatrix([[-tf2, -tf1], [tf4, -tf3], [tf1, tf2]]) + tfm3 = TransferFunctionMatrix([[-tf1, tf2], [-tf3, tf4], [tf2, tf1]]) + tfm4 = TransferFunctionMatrix([[-tf1, -tf2], [-tf3, -tf4]]) + + expected1 = \ +"""\ + x + y x - y\n\ +─────── + ─────\n\ +x - 2⋅y x + y\ +""" + expected2 = \ +"""\ +-x + y -x - y \n\ +────── + ─────── +x + y x - 2⋅y\ +""" + expected3 = \ +"""\ + 2 \n\ +x + y x + y ⎛-x - y ⎞ ⎛x - y⎞ +────── + ─────── + ⎜───────⎟⋅⎜─────⎟ +-x + y x - 2⋅y ⎝x - 2⋅y⎠ ⎝x + y⎠\ +""" + + expected4 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y ⎞ ⎛x - y⎞ ⎛x - y⎞ ⎜x + y⎟\n\ +⎜───────⎟⋅⎜─────⎟ + ⎜─────⎟⋅⎜──────⎟\n\ +⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x + y⎠ ⎝-x + y⎠\ +""" + expected5 = \ +"""\ +⎡ x + y -x + y ⎤ ⎡ x - y x + y ⎤ ⎡ x + y x - y ⎤ \n\ +⎢─────── ────── ⎥ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ \n\ +⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎢x - 2⋅y x + y ⎥ \n\ +⎢ ⎥ ⎢ ⎥ ⎢ ⎥ \n\ +⎢ 2 2 ⎥ ⎢ 2 2 ⎥ ⎢ 2 2 ⎥ \n\ +⎢x + y x - y ⎥ ⎢x - y x + y ⎥ ⎢x + y x - y ⎥ \n\ +⎢────── ────── ⎥ + ⎢────── ────── ⎥ + ⎢────── ────── ⎥ \n\ +⎢-x + y 3 ⎥ ⎢ 3 -x + y ⎥ ⎢-x + y 3 ⎥ \n\ +⎢ x + x ⎥ ⎢x + x ⎥ ⎢ x + x ⎥ \n\ +⎢ ⎥ ⎢ ⎥ ⎢ ⎥ \n\ +⎢-x + y -x - y ⎥ ⎢-x - y -x + y ⎥ ⎢-x + y -x - y ⎥ \n\ +⎢────── ───────⎥ ⎢─────── ────── ⎥ ⎢────── ───────⎥ \n\ +⎣x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ ⎣x + y x - 2⋅y⎦τ\ +""" + expected6 = \ +"""\ +⎡ x - y x + y ⎤ ⎡-x + y -x - y ⎤ \n\ +⎢ ───── ───────⎥ ⎢────── ─────── ⎥ \n\ +⎢ x + y x - 2⋅y⎥ ⎡-x - y -x + y⎤ ⎢x + y x - 2⋅y ⎥ \n\ +⎢ ⎥ ⎢─────── ──────⎥ ⎢ ⎥ \n\ +⎢ 2 2 ⎥ ⎢x - 2⋅y x + y ⎥ ⎢ 2 2 ⎥ \n\ +⎢x - y x + y ⎥ ⎢ ⎥ ⎢-x + y - x - y⎥ \n\ +⎢────── ────── ⎥ ⋅⎢ 2 2⎥ + ⎢─────── ────────⎥ \n\ +⎢ 3 -x + y ⎥ ⎢- x - y x - y ⎥ ⎢ 3 -x + y ⎥ \n\ +⎢x + x ⎥ ⎢──────── ──────⎥ ⎢x + x ⎥ \n\ +⎢ ⎥ ⎢ -x + y 3 ⎥ ⎢ ⎥ \n\ +⎢-x - y -x + y ⎥ ⎣ x + x⎦τ ⎢ x + y x - y ⎥ \n\ +⎢─────── ────── ⎥ ⎢─────── ───── ⎥ \n\ +⎣x - 2⋅y x + y ⎦τ ⎣x - 2⋅y x + y ⎦τ\ +""" + assert upretty(Parallel(tf1, tf2)) == expected1 + assert upretty(Parallel(-tf2, -tf1)) == expected2 + assert upretty(Parallel(tf3, tf1, Series(-tf1, tf2))) == expected3 + assert upretty(Parallel(Series(tf1, tf2), Series(tf2, tf3))) == expected4 + assert upretty(MIMOParallel(-tfm3, -tfm2, tfm1)) == expected5 + assert upretty(MIMOParallel(MIMOSeries(tfm4, -tfm2), tfm2)) == expected6 + + +def test_pretty_Feedback(): + tf = TransferFunction(1, 1, y) + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(y**2 - 2*y + 1, y + 5, y) + tf4 = TransferFunction(x - 2*y**3, x + y, x) + tf5 = TransferFunction(1 - x, x - y, y) + tf6 = TransferFunction(2, 2, x) + expected1 = \ +"""\ + ⎛1⎞ \n\ + ⎜─⎟ \n\ + ⎝1⎠ \n\ +─────────────\n\ +1 ⎛ x + y ⎞\n\ +─ + ⎜───────⎟\n\ +1 ⎝x - 2⋅y⎠\ +""" + expected2 = \ +"""\ + ⎛1⎞ \n\ + ⎜─⎟ \n\ + ⎝1⎠ \n\ +────────────────────────────────────\n\ + ⎛ 2 ⎞\n\ +1 ⎛x - y⎞ ⎛ x + y ⎞ ⎜y - 2⋅y + 1⎟\n\ +─ + ⎜─────⎟⋅⎜───────⎟⋅⎜────────────⎟\n\ +1 ⎝x + y⎠ ⎝x - 2⋅y⎠ ⎝ y + 5 ⎠\ +""" + expected3 = \ +"""\ + ⎛ x + y ⎞ \n\ + ⎜───────⎟ \n\ + ⎝x - 2⋅y⎠ \n\ +────────────────────────────────────────────\n\ + ⎛ 2 ⎞ \n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟⋅⎜────────────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝ y + 5 ⎠ ⎝x - y⎠\ +""" + expected4 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠\ +""" + expected5 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎛1 - x⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x - y⎠\ +""" + expected6 = \ +"""\ + ⎛ 2 ⎞ \n\ + ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞ \n\ + ⎜────────────⎟⋅⎜─────⎟ \n\ + ⎝ y + 5 ⎠ ⎝x - y⎠ \n\ +────────────────────────────────────────────\n\ + ⎛ 2 ⎞ \n\ +1 ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞ ⎛x - y⎞ ⎛ x + y ⎞\n\ +─ + ⎜────────────⎟⋅⎜─────⎟⋅⎜─────⎟⋅⎜───────⎟\n\ +1 ⎝ y + 5 ⎠ ⎝x - y⎠ ⎝x + y⎠ ⎝x - 2⋅y⎠\ +""" + expected7 = \ +"""\ + ⎛ 3⎞ \n\ + ⎜x - 2⋅y ⎟ \n\ + ⎜────────⎟ \n\ + ⎝ x + y ⎠ \n\ +──────────────────\n\ + ⎛ 3⎞ \n\ +1 ⎜x - 2⋅y ⎟ ⎛2⎞\n\ +─ + ⎜────────⎟⋅⎜─⎟\n\ +1 ⎝ x + y ⎠ ⎝2⎠\ +""" + expected8 = \ +"""\ + ⎛1 - x⎞ \n\ + ⎜─────⎟ \n\ + ⎝x - y⎠ \n\ +───────────\n\ +1 ⎛1 - x⎞\n\ +─ + ⎜─────⎟\n\ +1 ⎝x - y⎠\ +""" + expected9 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎛1 - x⎞\n\ +─ - ⎜───────⎟⋅⎜─────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x - y⎠\ +""" + expected10 = \ +"""\ + ⎛1 - x⎞ \n\ + ⎜─────⎟ \n\ + ⎝x - y⎠ \n\ +───────────\n\ +1 ⎛1 - x⎞\n\ +─ - ⎜─────⎟\n\ +1 ⎝x - y⎠\ +""" + assert upretty(Feedback(tf, tf1)) == expected1 + assert upretty(Feedback(tf, tf2*tf1*tf3)) == expected2 + assert upretty(Feedback(tf1, tf2*tf3*tf5)) == expected3 + assert upretty(Feedback(tf1*tf2, tf)) == expected4 + assert upretty(Feedback(tf1*tf2, tf5)) == expected5 + assert upretty(Feedback(tf3*tf5, tf2*tf1)) == expected6 + assert upretty(Feedback(tf4, tf6)) == expected7 + assert upretty(Feedback(tf5, tf)) == expected8 + + assert upretty(Feedback(tf1*tf2, tf5, 1)) == expected9 + assert upretty(Feedback(tf5, tf, 1)) == expected10 + + +def test_pretty_MIMOFeedback(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + tfm_3 = TransferFunctionMatrix([[tf1, tf1], [tf2, tf2]]) + + expected1 = \ +"""\ +⎛ ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎞-1 ⎡ x + y x - y ⎤ \n\ +⎜ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎟ ⎢─────── ───── ⎥ \n\ +⎜ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎟ ⎢x - 2⋅y x + y ⎥ \n\ +⎜I - ⎢ ⎥ ⋅⎢ ⎥ ⎟ ⋅ ⎢ ⎥ \n\ +⎜ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ ⎟ ⎢ x - y x + y ⎥ \n\ +⎜ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ ⎟ ⎢ ───── ───────⎥ \n\ +⎝ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ⎠ ⎣ x + y x - 2⋅y⎦τ\ +""" + expected2 = \ +"""\ +⎛ ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎡ x + y x + y ⎤ ⎞-1 ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ \n\ +⎜ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎢─────── ───────⎥ ⎟ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ \n\ +⎜ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎢x - 2⋅y x - 2⋅y⎥ ⎟ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ \n\ +⎜I + ⎢ ⎥ ⋅⎢ ⎥ ⋅⎢ ⎥ ⎟ ⋅ ⎢ ⎥ ⋅⎢ ⎥ \n\ +⎜ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ ⎢ x - y x - y ⎥ ⎟ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ \n\ +⎜ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ ⎢ ───── ───── ⎥ ⎟ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ \n\ +⎝ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ ⎣ x + y x + y ⎦τ⎠ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ\ +""" + + assert upretty(MIMOFeedback(tfm_1, tfm_2, 1)) == \ + expected1 # Positive MIMOFeedback + assert upretty(MIMOFeedback(tfm_1*tfm_2, tfm_3)) == \ + expected2 # Negative MIMOFeedback (Default) + + +def test_pretty_TransferFunctionMatrix(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(y**2 - 2*y + 1, y + 5, y) + tf4 = TransferFunction(y, x**2 + x + 1, y) + tf5 = TransferFunction(1 - x, x - y, y) + tf6 = TransferFunction(2, 2, y) + expected1 = \ +"""\ +⎡ x + y ⎤ \n\ +⎢───────⎥ \n\ +⎢x - 2⋅y⎥ \n\ +⎢ ⎥ \n\ +⎢ x - y ⎥ \n\ +⎢ ───── ⎥ \n\ +⎣ x + y ⎦τ\ +""" + expected2 = \ +"""\ +⎡ x + y ⎤ \n\ +⎢ ─────── ⎥ \n\ +⎢ x - 2⋅y ⎥ \n\ +⎢ ⎥ \n\ +⎢ x - y ⎥ \n\ +⎢ ───── ⎥ \n\ +⎢ x + y ⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢- y + 2⋅y - 1⎥ \n\ +⎢──────────────⎥ \n\ +⎣ y + 5 ⎦τ\ +""" + expected3 = \ +"""\ +⎡ x + y x - y ⎤ \n\ +⎢ ─────── ───── ⎥ \n\ +⎢ x - 2⋅y x + y ⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢y - 2⋅y + 1 y ⎥ \n\ +⎢──────────── ──────────⎥ \n\ +⎢ y + 5 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 1 - x 2 ⎥ \n\ +⎢ ───── ─ ⎥ \n\ +⎣ x - y 2 ⎦τ\ +""" + expected4 = \ +"""\ +⎡ x - y x + y y ⎤ \n\ +⎢ ───── ─────── ──────────⎥ \n\ +⎢ x + y x - 2⋅y 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢- y + 2⋅y - 1 x - 1 -2 ⎥ \n\ +⎢────────────── ───── ─── ⎥ \n\ +⎣ y + 5 x - y 2 ⎦τ\ +""" + expected5 = \ +"""\ +⎡ x + y x - y x + y y ⎤ \n\ +⎢───────⋅───── ─────── ──────────⎥ \n\ +⎢x - 2⋅y x + y x - 2⋅y 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 1 - x 2 x + y -2 ⎥ \n\ +⎢ ───── + ─ ─────── ─── ⎥ \n\ +⎣ x - y 2 x - 2⋅y 2 ⎦τ\ +""" + + assert upretty(TransferFunctionMatrix([[tf1], [tf2]])) == expected1 + assert upretty(TransferFunctionMatrix([[tf1], [tf2], [-tf3]])) == expected2 + assert upretty(TransferFunctionMatrix([[tf1, tf2], [tf3, tf4], [tf5, tf6]])) == expected3 + assert upretty(TransferFunctionMatrix([[tf2, tf1, tf4], [-tf3, -tf5, -tf6]])) == expected4 + assert upretty(TransferFunctionMatrix([[Series(tf2, tf1), tf1, tf4], [Parallel(tf6, tf5), tf1, -tf6]])) == \ + expected5 + + +def test_pretty_StateSpace(): + ss1 = StateSpace(Matrix([a]), Matrix([b]), Matrix([c]), Matrix([d])) + A = Matrix([[0, 1], [1, 0]]) + B = Matrix([1, 0]) + C = Matrix([[0, 1]]) + D = Matrix([0]) + ss2 = StateSpace(A, B, C, D) + ss3 = StateSpace(Matrix([[-1.5, -2], [1, 0]]), + Matrix([[0.5, 0], [0, 1]]), + Matrix([[0, 1], [0, 2]]), + Matrix([[2, 2], [1, 1]])) + + expected1 = \ +"""\ +⎡[a] [b]⎤\n\ +⎢ ⎥\n\ +⎣[c] [d]⎦\ +""" + expected2 = \ +"""\ +⎡⎡0 1⎤ ⎡1⎤⎤\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎣1 0⎦ ⎣0⎦⎥\n\ +⎢ ⎥\n\ +⎣[0 1] [0]⎦\ +""" + expected3 = \ +"""\ +⎡⎡-1.5 -2⎤ ⎡0.5 0⎤⎤\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎣ 1 0 ⎦ ⎣ 0 1⎦⎥\n\ +⎢ ⎥\n\ +⎢ ⎡0 1⎤ ⎡2 2⎤ ⎥\n\ +⎢ ⎢ ⎥ ⎢ ⎥ ⎥\n\ +⎣ ⎣0 2⎦ ⎣1 1⎦ ⎦\ +""" + + assert upretty(ss1) == expected1 + assert upretty(ss2) == expected2 + assert upretty(ss3) == expected3 + +def test_pretty_order(): + expr = O(1) + ascii_str = \ +"""\ +O(1)\ +""" + ucode_str = \ +"""\ +O(1)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1/x) + ascii_str = \ +"""\ + /1\\\n\ +O|-|\n\ + \\x/\ +""" + ucode_str = \ +"""\ + ⎛1⎞\n\ +O⎜─⎟\n\ + ⎝x⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(x**2 + y**2) + ascii_str = \ +"""\ + / 2 2 \\\n\ +O\\x + y ; (x, y) -> (0, 0)/\ +""" + ucode_str = \ +"""\ + ⎛ 2 2 ⎞\n\ +O⎝x + y ; (x, y) → (0, 0)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1, (x, oo)) + ascii_str = \ +"""\ +O(1; x -> oo)\ +""" + ucode_str = \ +"""\ +O(1; x → ∞)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1/x, (x, oo)) + ascii_str = \ +"""\ + /1 \\\n\ +O|-; x -> oo|\n\ + \\x /\ +""" + ucode_str = \ +"""\ + ⎛1 ⎞\n\ +O⎜─; x → ∞⎟\n\ + ⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(x**2 + y**2, (x, oo), (y, oo)) + ascii_str = \ +"""\ + / 2 2 \\\n\ +O\\x + y ; (x, y) -> (oo, oo)/\ +""" + ucode_str = \ +"""\ + ⎛ 2 2 ⎞\n\ +O⎝x + y ; (x, y) → (∞, ∞)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_derivatives(): + # Simple + expr = Derivative(log(x), x, evaluate=False) + ascii_str = \ +"""\ +d \n\ +--(log(x))\n\ +dx \ +""" + ucode_str = \ +"""\ +d \n\ +──(log(x))\n\ +dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(log(x), x, evaluate=False) + x + ascii_str_1 = \ +"""\ + d \n\ +x + --(log(x))\n\ + dx \ +""" + ascii_str_2 = \ +"""\ +d \n\ +--(log(x)) + x\n\ +dx \ +""" + ucode_str_1 = \ +"""\ + d \n\ +x + ──(log(x))\n\ + dx \ +""" + ucode_str_2 = \ +"""\ +d \n\ +──(log(x)) + x\n\ +dx \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + # basic partial derivatives + expr = Derivative(log(x + y) + x, x) + ascii_str_1 = \ +"""\ +d \n\ +--(log(x + y) + x)\n\ +dx \ +""" + ascii_str_2 = \ +"""\ +d \n\ +--(x + log(x + y))\n\ +dx \ +""" + ucode_str_1 = \ +"""\ +∂ \n\ +──(log(x + y) + x)\n\ +∂x \ +""" + ucode_str_2 = \ +"""\ +∂ \n\ +──(x + log(x + y))\n\ +∂x \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2], upretty(expr) + + # Multiple symbols + expr = Derivative(log(x) + x**2, x, y) + ascii_str_1 = \ +"""\ + 2 \n\ + d / 2\\\n\ +-----\\log(x) + x /\n\ +dy dx \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ + d / 2 \\\n\ +-----\\x + log(x)/\n\ +dy dx \ +""" + ascii_str_3 = \ +"""\ + 2 \n\ + d / 2 \\\n\ +-----\\x + log(x)/\n\ +dy dx \ +""" + ucode_str_1 = \ +"""\ + 2 \n\ + d ⎛ 2⎞\n\ +─────⎝log(x) + x ⎠\n\ +dy dx \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ + d ⎛ 2 ⎞\n\ +─────⎝x + log(x)⎠\n\ +dy dx \ +""" + ucode_str_3 = \ +"""\ + 2 \n\ + d ⎛ 2 ⎞\n\ +─────⎝x + log(x)⎠\n\ +dy dx \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Derivative(2*x*y, y, x) + x**2 + ascii_str_1 = \ +"""\ + 2 \n\ + d 2\n\ +-----(2*x*y) + x \n\ +dx dy \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ + 2 d \n\ +x + -----(2*x*y)\n\ + dx dy \ +""" + ascii_str_3 = \ +"""\ + 2 \n\ + 2 d \n\ +x + -----(2*x*y)\n\ + dx dy \ +""" + ucode_str_1 = \ +"""\ + 2 \n\ + ∂ 2\n\ +─────(2⋅x⋅y) + x \n\ +∂x ∂y \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ + 2 ∂ \n\ +x + ─────(2⋅x⋅y)\n\ + ∂x ∂y \ +""" + ucode_str_3 = \ +"""\ + 2 \n\ + 2 ∂ \n\ +x + ─────(2⋅x⋅y)\n\ + ∂x ∂y \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Derivative(2*x*y, x, x) + ascii_str = \ +"""\ + 2 \n\ +d \n\ +---(2*x*y)\n\ + 2 \n\ +dx \ +""" + ucode_str = \ +"""\ + 2 \n\ +∂ \n\ +───(2⋅x⋅y)\n\ + 2 \n\ +∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(2*x*y, x, 17) + ascii_str = \ +"""\ + 17 \n\ +d \n\ +----(2*x*y)\n\ + 17 \n\ +dx \ +""" + ucode_str = \ +"""\ + 17 \n\ +∂ \n\ +────(2⋅x⋅y)\n\ + 17 \n\ +∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(2*x*y, x, x, y) + ascii_str = \ +"""\ + 3 \n\ + d \n\ +------(2*x*y)\n\ + 2 \n\ +dy dx \ +""" + ucode_str = \ +"""\ + 3 \n\ + ∂ \n\ +──────(2⋅x⋅y)\n\ + 2 \n\ +∂y ∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # Greek letters + alpha = Symbol('alpha') + beta = Function('beta') + expr = beta(alpha).diff(alpha) + ascii_str = \ +"""\ + d \n\ +------(beta(alpha))\n\ +dalpha \ +""" + ucode_str = \ +"""\ +d \n\ +──(β(α))\n\ +dα \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(f(x), (x, n)) + + ascii_str = \ +"""\ + n \n\ +d \n\ +---(f(x))\n\ + n \n\ +dx \ +""" + ucode_str = \ +"""\ + n \n\ +d \n\ +───(f(x))\n\ + n \n\ +dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_integrals(): + expr = Integral(log(x), x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | log(x) dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ log(x) dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral((sin(x))**2 / (tan(x))**2) + ascii_str = \ +"""\ + / \n\ + | \n\ + | 2 \n\ + | sin (x) \n\ + | ------- dx\n\ + | 2 \n\ + | tan (x) \n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ 2 \n\ +⎮ sin (x) \n\ +⎮ ─────── dx\n\ +⎮ 2 \n\ +⎮ tan (x) \n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**(2**x), x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | / x\\ \n\ + | \\2 / \n\ + | x dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ ⎛ x⎞ \n\ +⎮ ⎝2 ⎠ \n\ +⎮ x dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, (x, 1, 2)) + ascii_str = \ +"""\ + 2 \n\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \n\ +1 \ +""" + ucode_str = \ +"""\ +2 \n\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \n\ +1 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, (x, Rational(1, 2), 10)) + ascii_str = \ +"""\ + 10 \n\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \n\ +1/2 \ +""" + ucode_str = \ +"""\ +10 \n\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \n\ +1/2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2*y**2, x, y) + ascii_str = \ +"""\ + / / \n\ + | | \n\ + | | 2 2 \n\ + | | x *y dx dy\n\ + | | \n\ +/ / \ +""" + ucode_str = \ +"""\ +⌠ ⌠ \n\ +⎮ ⎮ 2 2 \n\ +⎮ ⎮ x ⋅y dx dy\n\ +⌡ ⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(sin(th)/cos(ph), (th, 0, pi), (ph, 0, 2*pi)) + ascii_str = \ +"""\ + 2*pi pi \n\ + / / \n\ + | | \n\ + | | sin(theta) \n\ + | | ---------- d(theta) d(phi)\n\ + | | cos(phi) \n\ + | | \n\ + / / \n\ +0 0 \ +""" + ucode_str = \ +"""\ +2⋅π π \n\ + ⌠ ⌠ \n\ + ⎮ ⎮ sin(θ) \n\ + ⎮ ⎮ ────── dθ dφ\n\ + ⎮ ⎮ cos(φ) \n\ + ⌡ ⌡ \n\ + 0 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_matrix(): + # Empty Matrix + expr = Matrix() + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix(2, 0, lambda i, j: 0) + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix(0, 2, lambda i, j: 0) + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix([[x**2 + 1, 1], [y, x + y]]) + ascii_str_1 = \ +"""\ +[ 2 ] +[1 + x 1 ] +[ ] +[ y x + y]\ +""" + ascii_str_2 = \ +"""\ +[ 2 ] +[x + 1 1 ] +[ ] +[ y x + y]\ +""" + ucode_str_1 = \ +"""\ +⎡ 2 ⎤ +⎢1 + x 1 ⎥ +⎢ ⎥ +⎣ y x + y⎦\ +""" + ucode_str_2 = \ +"""\ +⎡ 2 ⎤ +⎢x + 1 1 ⎥ +⎢ ⎥ +⎣ y x + y⎦\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = Matrix([[x/y, y, th], [0, exp(I*k*ph), 1]]) + ascii_str = \ +"""\ +[x ] +[- y theta] +[y ] +[ ] +[ I*k*phi ] +[0 e 1 ]\ +""" + ucode_str = \ +"""\ +⎡x ⎤ +⎢─ y θ⎥ +⎢y ⎥ +⎢ ⎥ +⎢ ⅈ⋅k⋅φ ⎥ +⎣0 ℯ 1⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + unicode_str = \ +"""\ +⎡v̇_msc_00 0 0 ⎤ +⎢ ⎥ +⎢ 0 v̇_msc_01 0 ⎥ +⎢ ⎥ +⎣ 0 0 v̇_msc_02⎦\ +""" + + expr = diag(*MatrixSymbol('vdot_msc',1,3)) + assert upretty(expr) == unicode_str + + +def test_pretty_ndim_arrays(): + x, y, z, w = symbols("x y z w") + + for ArrayType in (ImmutableDenseNDimArray, ImmutableSparseNDimArray, MutableDenseNDimArray, MutableSparseNDimArray): + # Basic: scalar array + M = ArrayType(x) + + assert pretty(M) == "x" + assert upretty(M) == "x" + + M = ArrayType([[1/x, y], [z, w]]) + M1 = ArrayType([1/x, y, z]) + + M2 = tensorproduct(M1, M) + M3 = tensorproduct(M, M) + + ascii_str = \ +"""\ +[1 ]\n\ +[- y]\n\ +[x ]\n\ +[ ]\n\ +[z w]\ +""" + ucode_str = \ +"""\ +⎡1 ⎤\n\ +⎢─ y⎥\n\ +⎢x ⎥\n\ +⎢ ⎥\n\ +⎣z w⎦\ +""" + assert pretty(M) == ascii_str + assert upretty(M) == ucode_str + + ascii_str = \ +"""\ +[1 ]\n\ +[- y z]\n\ +[x ]\ +""" + ucode_str = \ +"""\ +⎡1 ⎤\n\ +⎢─ y z⎥\n\ +⎣x ⎦\ +""" + assert pretty(M1) == ascii_str + assert upretty(M1) == ucode_str + + ascii_str = \ +"""\ +[[1 y] ]\n\ +[[-- -] [z ]]\n\ +[[ 2 x] [ y 2 ] [- y*z]]\n\ +[[x ] [ - y ] [x ]]\n\ +[[ ] [ x ] [ ]]\n\ +[[z w] [ ] [ 2 ]]\n\ +[[- -] [y*z w*y] [z w*z]]\n\ +[[x x] ]\ +""" + ucode_str = \ +"""\ +⎡⎡1 y⎤ ⎤\n\ +⎢⎢── ─⎥ ⎡z ⎤⎥\n\ +⎢⎢ 2 x⎥ ⎡ y 2 ⎤ ⎢─ y⋅z⎥⎥\n\ +⎢⎢x ⎥ ⎢ ─ y ⎥ ⎢x ⎥⎥\n\ +⎢⎢ ⎥ ⎢ x ⎥ ⎢ ⎥⎥\n\ +⎢⎢z w⎥ ⎢ ⎥ ⎢ 2 ⎥⎥\n\ +⎢⎢─ ─⎥ ⎣y⋅z w⋅y⎦ ⎣z w⋅z⎦⎥\n\ +⎣⎣x x⎦ ⎦\ +""" + assert pretty(M2) == ascii_str + assert upretty(M2) == ucode_str + + ascii_str = \ +"""\ +[ [1 y] ]\n\ +[ [-- -] ]\n\ +[ [ 2 x] [ y 2 ]]\n\ +[ [x ] [ - y ]]\n\ +[ [ ] [ x ]]\n\ +[ [z w] [ ]]\n\ +[ [- -] [y*z w*y]]\n\ +[ [x x] ]\n\ +[ ]\n\ +[[z ] [ w ]]\n\ +[[- y*z] [ - w*y]]\n\ +[[x ] [ x ]]\n\ +[[ ] [ ]]\n\ +[[ 2 ] [ 2 ]]\n\ +[[z w*z] [w*z w ]]\ +""" + ucode_str = \ +"""\ +⎡ ⎡1 y⎤ ⎤\n\ +⎢ ⎢── ─⎥ ⎥\n\ +⎢ ⎢ 2 x⎥ ⎡ y 2 ⎤⎥\n\ +⎢ ⎢x ⎥ ⎢ ─ y ⎥⎥\n\ +⎢ ⎢ ⎥ ⎢ x ⎥⎥\n\ +⎢ ⎢z w⎥ ⎢ ⎥⎥\n\ +⎢ ⎢─ ─⎥ ⎣y⋅z w⋅y⎦⎥\n\ +⎢ ⎣x x⎦ ⎥\n\ +⎢ ⎥\n\ +⎢⎡z ⎤ ⎡ w ⎤⎥\n\ +⎢⎢─ y⋅z⎥ ⎢ ─ w⋅y⎥⎥\n\ +⎢⎢x ⎥ ⎢ x ⎥⎥\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎢ 2 ⎥ ⎢ 2 ⎥⎥\n\ +⎣⎣z w⋅z⎦ ⎣w⋅z w ⎦⎦\ +""" + assert pretty(M3) == ascii_str + assert upretty(M3) == ucode_str + + Mrow = ArrayType([[x, y, 1 / z]]) + Mcolumn = ArrayType([[x], [y], [1 / z]]) + Mcol2 = ArrayType([Mcolumn.tolist()]) + + ascii_str = \ +"""\ +[[ 1]]\n\ +[[x y -]]\n\ +[[ z]]\ +""" + ucode_str = \ +"""\ +⎡⎡ 1⎤⎤\n\ +⎢⎢x y ─⎥⎥\n\ +⎣⎣ z⎦⎦\ +""" + assert pretty(Mrow) == ascii_str + assert upretty(Mrow) == ucode_str + + ascii_str = \ +"""\ +[x]\n\ +[ ]\n\ +[y]\n\ +[ ]\n\ +[1]\n\ +[-]\n\ +[z]\ +""" + ucode_str = \ +"""\ +⎡x⎤\n\ +⎢ ⎥\n\ +⎢y⎥\n\ +⎢ ⎥\n\ +⎢1⎥\n\ +⎢─⎥\n\ +⎣z⎦\ +""" + assert pretty(Mcolumn) == ascii_str + assert upretty(Mcolumn) == ucode_str + + ascii_str = \ +"""\ +[[x]]\n\ +[[ ]]\n\ +[[y]]\n\ +[[ ]]\n\ +[[1]]\n\ +[[-]]\n\ +[[z]]\ +""" + ucode_str = \ +"""\ +⎡⎡x⎤⎤\n\ +⎢⎢ ⎥⎥\n\ +⎢⎢y⎥⎥\n\ +⎢⎢ ⎥⎥\n\ +⎢⎢1⎥⎥\n\ +⎢⎢─⎥⎥\n\ +⎣⎣z⎦⎦\ +""" + assert pretty(Mcol2) == ascii_str + assert upretty(Mcol2) == ucode_str + + +def test_tensor_TensorProduct(): + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + assert upretty(TensorProduct(A, B)) == "A\u2297B" + assert upretty(TensorProduct(A, B, A)) == "A\u2297B\u2297A" + + +def test_diffgeom_print_WedgeProduct(): + from sympy.diffgeom.rn import R2 + from sympy.diffgeom import WedgeProduct + wp = WedgeProduct(R2.dx, R2.dy) + assert upretty(wp) == "ⅆ x∧ⅆ y" + assert pretty(wp) == r"d x/\d y" + + +def test_Adjoint(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert pretty(Adjoint(X)) == " +\nX " + assert pretty(Adjoint(X + Y)) == " +\n(X + Y) " + assert pretty(Adjoint(X) + Adjoint(Y)) == " + +\nX + Y " + assert pretty(Adjoint(X*Y)) == " +\n(X*Y) " + assert pretty(Adjoint(Y)*Adjoint(X)) == " + +\nY *X " + assert pretty(Adjoint(X**2)) == " +\n/ 2\\ \n\\X / " + assert pretty(Adjoint(X)**2) == " 2\n/ +\\ \n\\X / " + assert pretty(Adjoint(Inverse(X))) == " +\n/ -1\\ \n\\X / " + assert pretty(Inverse(Adjoint(X))) == " -1\n/ +\\ \n\\X / " + assert pretty(Adjoint(Transpose(X))) == " +\n/ T\\ \n\\X / " + assert pretty(Transpose(Adjoint(X))) == " T\n/ +\\ \n\\X / " + assert upretty(Adjoint(X)) == " †\nX " + assert upretty(Adjoint(X + Y)) == " †\n(X + Y) " + assert upretty(Adjoint(X) + Adjoint(Y)) == " † †\nX + Y " + assert upretty(Adjoint(X*Y)) == " †\n(X⋅Y) " + assert upretty(Adjoint(Y)*Adjoint(X)) == " † †\nY ⋅X " + assert upretty(Adjoint(X**2)) == \ + " †\n⎛ 2⎞ \n⎝X ⎠ " + assert upretty(Adjoint(X)**2) == \ + " 2\n⎛ †⎞ \n⎝X ⎠ " + assert upretty(Adjoint(Inverse(X))) == \ + " †\n⎛ -1⎞ \n⎝X ⎠ " + assert upretty(Inverse(Adjoint(X))) == \ + " -1\n⎛ †⎞ \n⎝X ⎠ " + assert upretty(Adjoint(Transpose(X))) == \ + " †\n⎛ T⎞ \n⎝X ⎠ " + assert upretty(Transpose(Adjoint(X))) == \ + " T\n⎛ †⎞ \n⎝X ⎠ " + m = Matrix(((1, 2), (3, 4))) + assert upretty(Adjoint(m)) == \ + ' †\n'\ + '⎡1 2⎤ \n'\ + '⎢ ⎥ \n'\ + '⎣3 4⎦ ' + assert upretty(Adjoint(m+X)) == \ + ' †\n'\ + '⎛⎡1 2⎤ ⎞ \n'\ + '⎜⎢ ⎥ + X⎟ \n'\ + '⎝⎣3 4⎦ ⎠ ' + assert upretty(Adjoint(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + ' †\n'\ + '⎡ 𝟙 X⎤ \n'\ + '⎢ ⎥ \n'\ + '⎢⎡1 2⎤ ⎥ \n'\ + '⎢⎢ ⎥ 𝟘⎥ \n'\ + '⎣⎣3 4⎦ ⎦ ' + + +def test_Transpose(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert pretty(Transpose(X)) == " T\nX " + assert pretty(Transpose(X + Y)) == " T\n(X + Y) " + assert pretty(Transpose(X) + Transpose(Y)) == " T T\nX + Y " + assert pretty(Transpose(X*Y)) == " T\n(X*Y) " + assert pretty(Transpose(Y)*Transpose(X)) == " T T\nY *X " + assert pretty(Transpose(X**2)) == " T\n/ 2\\ \n\\X / " + assert pretty(Transpose(X)**2) == " 2\n/ T\\ \n\\X / " + assert pretty(Transpose(Inverse(X))) == " T\n/ -1\\ \n\\X / " + assert pretty(Inverse(Transpose(X))) == " -1\n/ T\\ \n\\X / " + assert upretty(Transpose(X)) == " T\nX " + assert upretty(Transpose(X + Y)) == " T\n(X + Y) " + assert upretty(Transpose(X) + Transpose(Y)) == " T T\nX + Y " + assert upretty(Transpose(X*Y)) == " T\n(X⋅Y) " + assert upretty(Transpose(Y)*Transpose(X)) == " T T\nY ⋅X " + assert upretty(Transpose(X**2)) == \ + " T\n⎛ 2⎞ \n⎝X ⎠ " + assert upretty(Transpose(X)**2) == \ + " 2\n⎛ T⎞ \n⎝X ⎠ " + assert upretty(Transpose(Inverse(X))) == \ + " T\n⎛ -1⎞ \n⎝X ⎠ " + assert upretty(Inverse(Transpose(X))) == \ + " -1\n⎛ T⎞ \n⎝X ⎠ " + m = Matrix(((1, 2), (3, 4))) + assert upretty(Transpose(m)) == \ + ' T\n'\ + '⎡1 2⎤ \n'\ + '⎢ ⎥ \n'\ + '⎣3 4⎦ ' + assert upretty(Transpose(m+X)) == \ + ' T\n'\ + '⎛⎡1 2⎤ ⎞ \n'\ + '⎜⎢ ⎥ + X⎟ \n'\ + '⎝⎣3 4⎦ ⎠ ' + assert upretty(Transpose(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + ' T\n'\ + '⎡ 𝟙 X⎤ \n'\ + '⎢ ⎥ \n'\ + '⎢⎡1 2⎤ ⎥ \n'\ + '⎢⎢ ⎥ 𝟘⎥ \n'\ + '⎣⎣3 4⎦ ⎦ ' + + +def test_pretty_Trace_issue_9044(): + X = Matrix([[1, 2], [3, 4]]) + Y = Matrix([[2, 4], [6, 8]]) + ascii_str_1 = \ +"""\ + /[1 2]\\ +tr|[ ]| + \\[3 4]/\ +""" + ucode_str_1 = \ +"""\ + ⎛⎡1 2⎤⎞ +tr⎜⎢ ⎥⎟ + ⎝⎣3 4⎦⎠\ +""" + ascii_str_2 = \ +"""\ + /[1 2]\\ /[2 4]\\ +tr|[ ]| + tr|[ ]| + \\[3 4]/ \\[6 8]/\ +""" + ucode_str_2 = \ +"""\ + ⎛⎡1 2⎤⎞ ⎛⎡2 4⎤⎞ +tr⎜⎢ ⎥⎟ + tr⎜⎢ ⎥⎟ + ⎝⎣3 4⎦⎠ ⎝⎣6 8⎦⎠\ +""" + assert pretty(Trace(X)) == ascii_str_1 + assert upretty(Trace(X)) == ucode_str_1 + + assert pretty(Trace(X) + Trace(Y)) == ascii_str_2 + assert upretty(Trace(X) + Trace(Y)) == ucode_str_2 + + +def test_MatrixSlice(): + n = Symbol('n', integer=True) + x, y, z, w, t, = symbols('x y z w t') + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + expr = MatrixSlice(X, (None, None, None), (None, None, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = X[x:x + 1, y:y + 1] + assert pretty(expr) == upretty(expr) == 'X[x:x + 1, y:y + 1]' + expr = X[x:x + 1:2, y:y + 1:2] + assert pretty(expr) == upretty(expr) == 'X[x:x + 1:2, y:y + 1:2]' + expr = X[:x, y:] + assert pretty(expr) == upretty(expr) == 'X[:x, y:]' + expr = X[:x, y:] + assert pretty(expr) == upretty(expr) == 'X[:x, y:]' + expr = X[x:, :y] + assert pretty(expr) == upretty(expr) == 'X[x:, :y]' + expr = X[x:y, z:w] + assert pretty(expr) == upretty(expr) == 'X[x:y, z:w]' + expr = X[x:y:t, w:t:x] + assert pretty(expr) == upretty(expr) == 'X[x:y:t, w:t:x]' + expr = X[x::y, t::w] + assert pretty(expr) == upretty(expr) == 'X[x::y, t::w]' + expr = X[:x:y, :t:w] + assert pretty(expr) == upretty(expr) == 'X[:x:y, :t:w]' + expr = X[::x, ::y] + assert pretty(expr) == upretty(expr) == 'X[::x, ::y]' + expr = MatrixSlice(X, (0, None, None), (0, None, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (None, n, None), (None, n, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (0, n, None), (0, n, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (0, n, 2), (0, n, 2)) + assert pretty(expr) == upretty(expr) == 'X[::2, ::2]' + expr = X[1:2:3, 4:5:6] + assert pretty(expr) == upretty(expr) == 'X[1:2:3, 4:5:6]' + expr = X[1:3:5, 4:6:8] + assert pretty(expr) == upretty(expr) == 'X[1:3:5, 4:6:8]' + expr = X[1:10:2] + assert pretty(expr) == upretty(expr) == 'X[1:10:2, :]' + expr = Y[:5, 1:9:2] + assert pretty(expr) == upretty(expr) == 'Y[:5, 1:9:2]' + expr = Y[:5, 1:10:2] + assert pretty(expr) == upretty(expr) == 'Y[:5, 1::2]' + expr = Y[5, :5:2] + assert pretty(expr) == upretty(expr) == 'Y[5:6, :5:2]' + expr = X[0:1, 0:1] + assert pretty(expr) == upretty(expr) == 'X[:1, :1]' + expr = X[0:1:2, 0:1:2] + assert pretty(expr) == upretty(expr) == 'X[:1:2, :1:2]' + expr = (Y + Z)[2:, 2:] + assert pretty(expr) == upretty(expr) == '(Y + Z)[2:, 2:]' + + +def test_MatrixExpressions(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + + assert pretty(X) == upretty(X) == "X" + + # Apply function elementwise (`ElementwiseApplyFunc`): + + expr = (X.T*X).applyfunc(sin) + + ascii_str = """\ + / T \\\n\ +(d -> sin(d)).\\X *X/\ +""" + ucode_str = """\ + ⎛ T ⎞\n\ +(d ↦ sin(d))˳⎝X ⋅X⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + lamda = Lambda(x, 1/x) + expr = (n*X).applyfunc(lamda) + ascii_str = """\ +/ 1\\ \n\ +|x -> -|.(n*X)\n\ +\\ x/ \ +""" + ucode_str = """\ +⎛ 1⎞ \n\ +⎜x ↦ ─⎟˳(n⋅X)\n\ +⎝ x⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_dotproduct(): + from sympy.matrices.expressions.dotproduct import DotProduct + n = symbols("n", integer=True) + A = MatrixSymbol('A', n, 1) + B = MatrixSymbol('B', n, 1) + C = Matrix(1, 3, [1, 2, 3]) + D = Matrix(1, 3, [1, 3, 4]) + + assert pretty(DotProduct(A, B)) == "A*B" + assert pretty(DotProduct(C, D)) == "[1 2 3]*[1 3 4]" + assert upretty(DotProduct(A, B)) == "A⋅B" + assert upretty(DotProduct(C, D)) == "[1 2 3]⋅[1 3 4]" + + +def test_pretty_Determinant(): + from sympy.matrices import Determinant, Inverse, BlockMatrix, OneMatrix, ZeroMatrix + m = Matrix(((1, 2), (3, 4))) + assert upretty(Determinant(m)) == '│1 2│\n│ │\n│3 4│' + assert upretty(Determinant(Inverse(m))) == \ + '│ -1│\n'\ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ │\n'\ + '│⎣3 4⎦ │' + X = MatrixSymbol('X', 2, 2) + assert upretty(Determinant(X)) == '│X│' + assert upretty(Determinant(X + m)) == \ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ + X│\n'\ + '│⎣3 4⎦ │' + assert upretty(Determinant(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '│ 𝟙 X│\n'\ + '│ │\n'\ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ 𝟘│\n'\ + '│⎣3 4⎦ │' + + +def test_pretty_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + ascii_str = \ +"""\ +/x for x < 1\n\ +| \n\ +< 2 \n\ +|x otherwise\n\ +\\ \ +""" + ucode_str = \ +"""\ +⎧x for x < 1\n\ +⎪ \n\ +⎨ 2 \n\ +⎪x otherwise\n\ +⎩ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -Piecewise((x, x < 1), (x**2, True)) + ascii_str = \ +"""\ + //x for x < 1\\\n\ + || |\n\ +-|< 2 |\n\ + ||x otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x for x < 1⎞\n\ + ⎜⎪ ⎟\n\ +-⎜⎨ 2 ⎟\n\ + ⎜⎪x otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x + Piecewise((x, x > 0), (y, True)) + Piecewise((x/y, x < 2), + (y**2, x > 2), (1, True)) + 1 + ascii_str = \ +"""\ + //x \\ \n\ + ||- for x < 2| \n\ + ||y | \n\ + //x for x > 0\\ || | \n\ +x + |< | + |< 2 | + 1\n\ + \\\\y otherwise/ ||y for x > 2| \n\ + || | \n\ + ||1 otherwise| \n\ + \\\\ / \ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞ \n\ + ⎜⎪─ for x < 2⎟ \n\ + ⎜⎪y ⎟ \n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟ \n\ +x + ⎜⎨ ⎟ + ⎜⎨ 2 ⎟ + 1\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟ \n\ + ⎜⎪ ⎟ \n\ + ⎜⎪1 otherwise⎟ \n\ + ⎝⎩ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x - Piecewise((x, x > 0), (y, True)) + Piecewise((x/y, x < 2), + (y**2, x > 2), (1, True)) + 1 + ascii_str = \ +"""\ + //x \\ \n\ + ||- for x < 2| \n\ + ||y | \n\ + //x for x > 0\\ || | \n\ +x - |< | + |< 2 | + 1\n\ + \\\\y otherwise/ ||y for x > 2| \n\ + || | \n\ + ||1 otherwise| \n\ + \\\\ / \ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞ \n\ + ⎜⎪─ for x < 2⎟ \n\ + ⎜⎪y ⎟ \n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟ \n\ +x - ⎜⎨ ⎟ + ⎜⎨ 2 ⎟ + 1\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟ \n\ + ⎜⎪ ⎟ \n\ + ⎜⎪1 otherwise⎟ \n\ + ⎝⎩ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x*Piecewise((x, x > 0), (y, True)) + ascii_str = \ +"""\ + //x for x > 0\\\n\ +x*|< |\n\ + \\\\y otherwise/\ +""" + ucode_str = \ +"""\ + ⎛⎧x for x > 0⎞\n\ +x⋅⎜⎨ ⎟\n\ + ⎝⎩y otherwise⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Piecewise((x, x > 0), (y, True))*Piecewise((x/y, x < 2), (y**2, x > + 2), (1, True)) + ascii_str = \ +"""\ + //x \\\n\ + ||- for x < 2|\n\ + ||y |\n\ +//x for x > 0\\ || |\n\ +|< |*|< 2 |\n\ +\\\\y otherwise/ ||y for x > 2|\n\ + || |\n\ + ||1 otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞\n\ + ⎜⎪─ for x < 2⎟\n\ + ⎜⎪y ⎟\n\ +⎛⎧x for x > 0⎞ ⎜⎪ ⎟\n\ +⎜⎨ ⎟⋅⎜⎨ 2 ⎟\n\ +⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟\n\ + ⎜⎪ ⎟\n\ + ⎜⎪1 otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -Piecewise((x, x > 0), (y, True))*Piecewise((x/y, x < 2), (y**2, x + > 2), (1, True)) + ascii_str = \ +"""\ + //x \\\n\ + ||- for x < 2|\n\ + ||y |\n\ + //x for x > 0\\ || |\n\ +-|< |*|< 2 |\n\ + \\\\y otherwise/ ||y for x > 2|\n\ + || |\n\ + ||1 otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞\n\ + ⎜⎪─ for x < 2⎟\n\ + ⎜⎪y ⎟\n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟\n\ +-⎜⎨ ⎟⋅⎜⎨ 2 ⎟\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟\n\ + ⎜⎪ ⎟\n\ + ⎜⎪1 otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Piecewise((0, Abs(1/y) < 1), (1, Abs(y) < 1), (y*meijerg(((2, 1), + ()), ((), (1, 0)), 1/y), True)) + ascii_str = \ +"""\ +/ 1 \n\ +| 0 for --- < 1\n\ +| |y| \n\ +| \n\ +< 1 for |y| < 1\n\ +| \n\ +| __0, 2 /1, 2 | 1\\ \n\ +|y*/__ | | -| otherwise \n\ +\\ \\_|2, 2 \\ 0, 1 | y/ \ +""" + ucode_str = \ +"""\ +⎧ 1 \n\ +⎪ 0 for ─── < 1\n\ +⎪ │y│ \n\ +⎪ \n\ +⎨ 1 for │y│ < 1\n\ +⎪ \n\ +⎪ ╭─╮0, 2 ⎛1, 2 │ 1⎞ \n\ +⎪y⋅│╶┐ ⎜ │ ─⎟ otherwise \n\ +⎩ ╰─╯2, 2 ⎝ 0, 1 │ y⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # XXX: We have to use evaluate=False here because Piecewise._eval_power + # denests the power. + expr = Pow(Piecewise((x, x > 0), (y, True)), 2, evaluate=False) + ascii_str = \ +"""\ + 2\n\ +//x for x > 0\\ \n\ +|< | \n\ +\\\\y otherwise/ \ +""" + ucode_str = \ +"""\ + 2\n\ +⎛⎧x for x > 0⎞ \n\ +⎜⎨ ⎟ \n\ +⎝⎩y otherwise⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_ITE(): + expr = ITE(x, y, z) + assert pretty(expr) == ( + '/y for x \n' + '< \n' + '\\z otherwise' + ) + assert upretty(expr) == """\ +⎧y for x \n\ +⎨ \n\ +⎩z otherwise\ +""" + + +def test_pretty_seq(): + expr = () + ascii_str = \ +"""\ +()\ +""" + ucode_str = \ +"""\ +()\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = [] + ascii_str = \ +"""\ +[]\ +""" + ucode_str = \ +"""\ +[]\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {} + expr_2 = {} + ascii_str = \ +"""\ +{}\ +""" + ucode_str = \ +"""\ +{}\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + expr = (1/x,) + ascii_str = \ +"""\ + 1 \n\ +(-,)\n\ + x \ +""" + ucode_str = \ +"""\ +⎛1 ⎞\n\ +⎜─,⎟\n\ +⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = [x**2, 1/x, x, y, sin(th)**2/cos(ph)**2] + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +[x , -, x, y, -----------]\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎡ 2 ⎤\n\ +⎢ 2 1 sin (θ)⎥\n\ +⎢x , ─, x, y, ───────⎥\n\ +⎢ x 2 ⎥\n\ +⎣ cos (φ)⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +(x , -, x, y, -----------)\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎜ 2 1 sin (θ)⎟\n\ +⎜x , ─, x, y, ───────⎟\n\ +⎜ x 2 ⎟\n\ +⎝ cos (φ)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Tuple(x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +(x , -, x, y, -----------)\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎜ 2 1 sin (θ)⎟\n\ +⎜x , ─, x, y, ───────⎟\n\ +⎜ x 2 ⎟\n\ +⎝ cos (φ)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {x: sin(x)} + expr_2 = Dict({x: sin(x)}) + ascii_str = \ +"""\ +{x: sin(x)}\ +""" + ucode_str = \ +"""\ +{x: sin(x)}\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + expr = {1/x: 1/y, x: sin(x)**2} + expr_2 = Dict({1/x: 1/y, x: sin(x)**2}) + ascii_str = \ +"""\ + 1 1 2 \n\ +{-: -, x: sin (x)}\n\ + x y \ +""" + ucode_str = \ +"""\ +⎧1 1 2 ⎫\n\ +⎨─: ─, x: sin (x)⎬\n\ +⎩x y ⎭\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + # There used to be a bug with pretty-printing sequences of even height. + expr = [x**2] + ascii_str = \ +"""\ + 2 \n\ +[x ]\ +""" + ucode_str = \ +"""\ +⎡ 2⎤\n\ +⎣x ⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2,) + ascii_str = \ +"""\ + 2 \n\ +(x ,)\ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎝x ,⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Tuple(x**2) + ascii_str = \ +"""\ + 2 \n\ +(x ,)\ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎝x ,⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {x**2: 1} + expr_2 = Dict({x**2: 1}) + ascii_str = \ +"""\ + 2 \n\ +{x : 1}\ +""" + ucode_str = \ +"""\ +⎧ 2 ⎫\n\ +⎨x : 1⎬\n\ +⎩ ⎭\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + +def test_any_object_in_sequence(): + # Cf. issue 5306 + b1 = Basic() + b2 = Basic(Basic()) + + expr = [b2, b1] + assert pretty(expr) == "[Basic(Basic()), Basic()]" + assert upretty(expr) == "[Basic(Basic()), Basic()]" + + expr = {b2, b1} + assert pretty(expr) == "{Basic(), Basic(Basic())}" + assert upretty(expr) == "{Basic(), Basic(Basic())}" + + expr = {b2: b1, b1: b2} + expr2 = Dict({b2: b1, b1: b2}) + assert pretty(expr) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert pretty( + expr2) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert upretty( + expr) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert upretty( + expr2) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + + +def test_print_builtin_set(): + assert pretty(set()) == 'set()' + assert upretty(set()) == 'set()' + + assert pretty(frozenset()) == 'frozenset()' + assert upretty(frozenset()) == 'frozenset()' + + s1 = {1/x, x} + s2 = frozenset(s1) + + assert pretty(s1) == \ +"""\ + 1 \n\ +{-, x} + x \ +""" + assert upretty(s1) == \ +"""\ +⎧1 ⎫ +⎨─, x⎬ +⎩x ⎭\ +""" + + assert pretty(s2) == \ +"""\ + 1 \n\ +frozenset({-, x}) + x \ +""" + assert upretty(s2) == \ +"""\ + ⎛⎧1 ⎫⎞ +frozenset⎜⎨─, x⎬⎟ + ⎝⎩x ⎭⎠\ +""" + + +def test_pretty_sets(): + s = FiniteSet + assert pretty(s(*[x*y, x**2])) == \ +"""\ + 2 \n\ +{x , x*y}\ +""" + assert pretty(s(*range(1, 6))) == "{1, 2, 3, 4, 5}" + assert pretty(s(*range(1, 13))) == "{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}" + + assert pretty({x*y, x**2}) == \ +"""\ + 2 \n\ +{x , x*y}\ +""" + assert pretty(set(range(1, 6))) == "{1, 2, 3, 4, 5}" + assert pretty(set(range(1, 13))) == \ + "{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}" + + assert pretty(frozenset([x*y, x**2])) == \ +"""\ + 2 \n\ +frozenset({x , x*y})\ +""" + assert pretty(frozenset(range(1, 6))) == "frozenset({1, 2, 3, 4, 5})" + assert pretty(frozenset(range(1, 13))) == \ + "frozenset({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})" + + assert pretty(Range(0, 3, 1)) == '{0, 1, 2}' + + ascii_str = '{0, 1, ..., 29}' + ucode_str = '{0, 1, …, 29}' + assert pretty(Range(0, 30, 1)) == ascii_str + assert upretty(Range(0, 30, 1)) == ucode_str + + ascii_str = '{30, 29, ..., 2}' + ucode_str = '{30, 29, …, 2}' + assert pretty(Range(30, 1, -1)) == ascii_str + assert upretty(Range(30, 1, -1)) == ucode_str + + ascii_str = '{0, 2, ...}' + ucode_str = '{0, 2, …}' + assert pretty(Range(0, oo, 2)) == ascii_str + assert upretty(Range(0, oo, 2)) == ucode_str + + ascii_str = '{..., 2, 0}' + ucode_str = '{…, 2, 0}' + assert pretty(Range(oo, -2, -2)) == ascii_str + assert upretty(Range(oo, -2, -2)) == ucode_str + + ascii_str = '{-2, -3, ...}' + ucode_str = '{-2, -3, …}' + assert pretty(Range(-2, -oo, -1)) == ascii_str + assert upretty(Range(-2, -oo, -1)) == ucode_str + + +def test_pretty_SetExpr(): + iv = Interval(1, 3) + se = SetExpr(iv) + ascii_str = "SetExpr([1, 3])" + ucode_str = "SetExpr([1, 3])" + assert pretty(se) == ascii_str + assert upretty(se) == ucode_str + + +def test_pretty_ImageSet(): + imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4}) + ascii_str = '{x + y | x in {1, 2, 3}, y in {3, 4}}' + ucode_str = '{x + y │ x ∊ {1, 2, 3}, y ∊ {3, 4}}' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda(((x, y),), x + y), ProductSet({1, 2, 3}, {3, 4})) + ascii_str = '{x + y | (x, y) in {1, 2, 3} x {3, 4}}' + ucode_str = '{x + y │ (x, y) ∊ {1, 2, 3} × {3, 4}}' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda(x, x**2), S.Naturals) + ascii_str = '''\ + 2 \n\ +{x | x in Naturals}''' + ucode_str = '''\ +⎧ 2 │ ⎫\n\ +⎨x │ x ∊ ℕ⎬\n\ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + # TODO: The "x in N" parts below should be centered independently of the + # 1/x**2 fraction + imgset = ImageSet(Lambda(x, 1/x**2), S.Naturals) + ascii_str = '''\ + 1 \n\ +{-- | x in Naturals} + 2 \n\ + x ''' + ucode_str = '''\ +⎧1 │ ⎫\n\ +⎪── │ x ∊ ℕ⎪\n\ +⎨ 2 │ ⎬\n\ +⎪x │ ⎪\n\ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda((x, y), 1/(x + y)**2), S.Naturals, S.Naturals) + ascii_str = '''\ + 1 \n\ +{-------- | x in Naturals, y in Naturals} + 2 \n\ + (x + y) ''' + ucode_str = '''\ +⎧ 1 │ ⎫ +⎪──────── │ x ∊ ℕ, y ∊ ℕ⎪ +⎨ 2 │ ⎬ +⎪(x + y) │ ⎪ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + # issue 23449 centering issue + assert upretty([Symbol("ihat") / (Symbol("i") + 1)]) == '''\ +⎡ î ⎤ +⎢─────⎥ +⎣i + 1⎦\ +''' + assert upretty(Matrix([Symbol("ihat"), Symbol("i") + 1])) == '''\ +⎡ î ⎤ +⎢ ⎥ +⎣i + 1⎦\ +''' + + +def test_pretty_ConditionSet(): + ascii_str = '{x | x in (-oo, oo) and sin(x) = 0}' + ucode_str = '{x │ x ∊ ℝ ∧ (sin(x) = 0)}' + assert pretty(ConditionSet(x, Eq(sin(x), 0), S.Reals)) == ascii_str + assert upretty(ConditionSet(x, Eq(sin(x), 0), S.Reals)) == ucode_str + + assert pretty(ConditionSet(x, Contains(x, S.Reals, evaluate=False), FiniteSet(1))) == '{1}' + assert upretty(ConditionSet(x, Contains(x, S.Reals, evaluate=False), FiniteSet(1))) == '{1}' + + assert pretty(ConditionSet(x, And(x > 1, x < -1), FiniteSet(1, 2, 3))) == "EmptySet" + assert upretty(ConditionSet(x, And(x > 1, x < -1), FiniteSet(1, 2, 3))) == "∅" + + assert pretty(ConditionSet(x, Or(x > 1, x < -1), FiniteSet(1, 2))) == '{2}' + assert upretty(ConditionSet(x, Or(x > 1, x < -1), FiniteSet(1, 2))) == '{2}' + + condset = ConditionSet(x, 1/x**2 > 0) + ascii_str = '''\ + 1 \n\ +{x | -- > 0} + 2 \n\ + x ''' + ucode_str = '''\ +⎧ │ ⎛1 ⎞⎫ +⎪x │ ⎜── > 0⎟⎪ +⎨ │ ⎜ 2 ⎟⎬ +⎪ │ ⎝x ⎠⎪ +⎩ │ ⎭''' + assert pretty(condset) == ascii_str + assert upretty(condset) == ucode_str + + condset = ConditionSet(x, 1/x**2 > 0, S.Reals) + ascii_str = '''\ + 1 \n\ +{x | x in (-oo, oo) and -- > 0} + 2 \n\ + x ''' + ucode_str = '''\ +⎧ │ ⎛1 ⎞⎫ +⎪x │ x ∊ ℝ ∧ ⎜── > 0⎟⎪ +⎨ │ ⎜ 2 ⎟⎬ +⎪ │ ⎝x ⎠⎪ +⎩ │ ⎭''' + assert pretty(condset) == ascii_str + assert upretty(condset) == ucode_str + + +def test_pretty_ComplexRegion(): + from sympy.sets.fancysets import ComplexRegion + cregion = ComplexRegion(Interval(3, 5)*Interval(4, 6)) + ascii_str = '{x + y*I | x, y in [3, 5] x [4, 6]}' + ucode_str = '{x + y⋅ⅈ │ x, y ∊ [3, 5] × [4, 6]}' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(0, 1)*Interval(0, 2*pi), polar=True) + ascii_str = '{r*(I*sin(theta) + cos(theta)) | r, theta in [0, 1] x [0, 2*pi)}' + ucode_str = '{r⋅(ⅈ⋅sin(θ) + cos(θ)) │ r, θ ∊ [0, 1] × [0, 2⋅π)}' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(3, 1/a**2)*Interval(4, 6)) + ascii_str = '''\ + 1 \n\ +{x + y*I | x, y in [3, --] x [4, 6]} + 2 \n\ + a ''' + ucode_str = '''\ +⎧ │ ⎡ 1 ⎤ ⎫ +⎪x + y⋅ⅈ │ x, y ∊ ⎢3, ──⎥ × [4, 6]⎪ +⎨ │ ⎢ 2⎥ ⎬ +⎪ │ ⎣ a ⎦ ⎪ +⎩ │ ⎭''' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(0, 1/a**2)*Interval(0, 2*pi), polar=True) + ascii_str = '''\ + 1 \n\ +{r*(I*sin(theta) + cos(theta)) | r, theta in [0, --] x [0, 2*pi)} + 2 \n\ + a ''' + ucode_str = '''\ +⎧ │ ⎡ 1 ⎤ ⎫ +⎪r⋅(ⅈ⋅sin(θ) + cos(θ)) │ r, θ ∊ ⎢0, ──⎥ × [0, 2⋅π)⎪ +⎨ │ ⎢ 2⎥ ⎬ +⎪ │ ⎣ a ⎦ ⎪ +⎩ │ ⎭''' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + +def test_pretty_Union_issue_10414(): + a, b = Interval(2, 3), Interval(4, 7) + ucode_str = '[2, 3] ∪ [4, 7]' + ascii_str = '[2, 3] U [4, 7]' + assert upretty(Union(a, b)) == ucode_str + assert pretty(Union(a, b)) == ascii_str + + +def test_pretty_Intersection_issue_10414(): + x, y, z, w = symbols('x, y, z, w') + a, b = Interval(x, y), Interval(z, w) + ucode_str = '[x, y] ∩ [z, w]' + ascii_str = '[x, y] n [z, w]' + assert upretty(Intersection(a, b)) == ucode_str + assert pretty(Intersection(a, b)) == ascii_str + + +def test_ProductSet_exponent(): + ucode_str = ' 1\n[0, 1] ' + assert upretty(Interval(0, 1)**1) == ucode_str + ucode_str = ' 2\n[0, 1] ' + assert upretty(Interval(0, 1)**2) == ucode_str + + +def test_ProductSet_parenthesis(): + ucode_str = '([4, 7] × {1, 2}) ∪ ([2, 3] × [4, 7])' + + a, b = Interval(2, 3), Interval(4, 7) + assert upretty(Union(a*b, b*FiniteSet(1, 2))) == ucode_str + + +def test_ProductSet_prod_char_issue_10413(): + ascii_str = '[2, 3] x [4, 7]' + ucode_str = '[2, 3] × [4, 7]' + + a, b = Interval(2, 3), Interval(4, 7) + assert pretty(a*b) == ascii_str + assert upretty(a*b) == ucode_str + + +def test_pretty_sequences(): + s1 = SeqFormula(a**2, (0, oo)) + s2 = SeqPer((1, 2)) + + ascii_str = '[0, 1, 4, 9, ...]' + ucode_str = '[0, 1, 4, 9, …]' + + assert pretty(s1) == ascii_str + assert upretty(s1) == ucode_str + + ascii_str = '[1, 2, 1, 2, ...]' + ucode_str = '[1, 2, 1, 2, …]' + assert pretty(s2) == ascii_str + assert upretty(s2) == ucode_str + + s3 = SeqFormula(a**2, (0, 2)) + s4 = SeqPer((1, 2), (0, 2)) + + ascii_str = '[0, 1, 4]' + ucode_str = '[0, 1, 4]' + + assert pretty(s3) == ascii_str + assert upretty(s3) == ucode_str + + ascii_str = '[1, 2, 1]' + ucode_str = '[1, 2, 1]' + assert pretty(s4) == ascii_str + assert upretty(s4) == ucode_str + + s5 = SeqFormula(a**2, (-oo, 0)) + s6 = SeqPer((1, 2), (-oo, 0)) + + ascii_str = '[..., 9, 4, 1, 0]' + ucode_str = '[…, 9, 4, 1, 0]' + + assert pretty(s5) == ascii_str + assert upretty(s5) == ucode_str + + ascii_str = '[..., 2, 1, 2, 1]' + ucode_str = '[…, 2, 1, 2, 1]' + assert pretty(s6) == ascii_str + assert upretty(s6) == ucode_str + + ascii_str = '[1, 3, 5, 11, ...]' + ucode_str = '[1, 3, 5, 11, …]' + + assert pretty(SeqAdd(s1, s2)) == ascii_str + assert upretty(SeqAdd(s1, s2)) == ucode_str + + ascii_str = '[1, 3, 5]' + ucode_str = '[1, 3, 5]' + + assert pretty(SeqAdd(s3, s4)) == ascii_str + assert upretty(SeqAdd(s3, s4)) == ucode_str + + ascii_str = '[..., 11, 5, 3, 1]' + ucode_str = '[…, 11, 5, 3, 1]' + + assert pretty(SeqAdd(s5, s6)) == ascii_str + assert upretty(SeqAdd(s5, s6)) == ucode_str + + ascii_str = '[0, 2, 4, 18, ...]' + ucode_str = '[0, 2, 4, 18, …]' + + assert pretty(SeqMul(s1, s2)) == ascii_str + assert upretty(SeqMul(s1, s2)) == ucode_str + + ascii_str = '[0, 2, 4]' + ucode_str = '[0, 2, 4]' + + assert pretty(SeqMul(s3, s4)) == ascii_str + assert upretty(SeqMul(s3, s4)) == ucode_str + + ascii_str = '[..., 18, 4, 2, 0]' + ucode_str = '[…, 18, 4, 2, 0]' + + assert pretty(SeqMul(s5, s6)) == ascii_str + assert upretty(SeqMul(s5, s6)) == ucode_str + + # Sequences with symbolic limits, issue 12629 + s7 = SeqFormula(a**2, (a, 0, x)) + raises(NotImplementedError, lambda: pretty(s7)) + raises(NotImplementedError, lambda: upretty(s7)) + + b = Symbol('b') + s8 = SeqFormula(b*a**2, (a, 0, 2)) + ascii_str = '[0, b, 4*b]' + ucode_str = '[0, b, 4⋅b]' + assert pretty(s8) == ascii_str + assert upretty(s8) == ucode_str + + +def test_pretty_FourierSeries(): + f = fourier_series(x, (x, -pi, pi)) + + ascii_str = \ +"""\ + 2*sin(3*x) \n\ +2*sin(x) - sin(2*x) + ---------- + ...\n\ + 3 \ +""" + + ucode_str = \ +"""\ + 2⋅sin(3⋅x) \n\ +2⋅sin(x) - sin(2⋅x) + ────────── + …\n\ + 3 \ +""" + + assert pretty(f) == ascii_str + assert upretty(f) == ucode_str + + +def test_pretty_FormalPowerSeries(): + f = fps(log(1 + x)) + + + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ -k k \n\ + \\ -(-1) *x \n\ + / -----------\n\ + / k \n\ +/___, \n\ +k = 1 \ +""" + + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ -k k \n\ + ╲ -(-1) ⋅x \n\ + ╱ ───────────\n\ + ╱ k \n\ +╱ \n\ +‾‾‾‾ \n\ +k = 1 \ +""" + + assert pretty(f) == ascii_str + assert upretty(f) == ucode_str + + +def test_pretty_limits(): + expr = Limit(x, x, oo) + ascii_str = \ +"""\ + lim x\n\ +x->oo \ +""" + ucode_str = \ +"""\ +lim x\n\ +x─→∞ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x**2, x, 0) + ascii_str = \ +"""\ + 2\n\ + lim x \n\ +x->0+ \ +""" + ucode_str = \ +"""\ + 2\n\ + lim x \n\ +x─→0⁺ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(1/x, x, 0) + ascii_str = \ +"""\ + 1\n\ + lim -\n\ +x->0+x\ +""" + ucode_str = \ +"""\ + 1\n\ + lim ─\n\ +x─→0⁺x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x)/x, x, 0) + ascii_str = \ +"""\ + /sin(x)\\\n\ + lim |------|\n\ +x->0+\\ x /\ +""" + ucode_str = \ +"""\ + ⎛sin(x)⎞\n\ + lim ⎜──────⎟\n\ +x─→0⁺⎝ x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x)/x, x, 0, "-") + ascii_str = \ +"""\ + /sin(x)\\\n\ + lim |------|\n\ +x->0-\\ x /\ +""" + ucode_str = \ +"""\ + ⎛sin(x)⎞\n\ + lim ⎜──────⎟\n\ +x─→0⁻⎝ x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x + sin(x), x, 0) + ascii_str = \ +"""\ + lim (x + sin(x))\n\ +x->0+ \ +""" + ucode_str = \ +"""\ + lim (x + sin(x))\n\ +x─→0⁺ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x, x, 0)**2 + ascii_str = \ +"""\ + 2\n\ +/ lim x\\ \n\ +\\x->0+ / \ +""" + ucode_str = \ +"""\ + 2\n\ +⎛ lim x⎞ \n\ +⎝x─→0⁺ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x*Limit(y/2,y,0), x, 0) + ascii_str = \ +"""\ + / /y\\\\\n\ + lim |x* lim |-||\n\ +x->0+\\ y->0+\\2//\ +""" + ucode_str = \ +"""\ + ⎛ ⎛y⎞⎞\n\ + lim ⎜x⋅ lim ⎜─⎟⎟\n\ +x─→0⁺⎝ y─→0⁺⎝2⎠⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*Limit(x*Limit(y/2,y,0), x, 0) + ascii_str = \ +"""\ + / /y\\\\\n\ +2* lim |x* lim |-||\n\ + x->0+\\ y->0+\\2//\ +""" + ucode_str = \ +"""\ + ⎛ ⎛y⎞⎞\n\ +2⋅ lim ⎜x⋅ lim ⎜─⎟⎟\n\ + x─→0⁺⎝ y─→0⁺⎝2⎠⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x), x, 0, dir='+-') + ascii_str = \ +"""\ +lim sin(x)\n\ +x->0 \ +""" + ucode_str = \ +"""\ +lim sin(x)\n\ +x─→0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_ComplexRootOf(): + expr = rootof(x**5 + 11*x - 2, 0) + ascii_str = \ +"""\ + / 5 \\\n\ +CRootOf\\x + 11*x - 2, 0/\ +""" + ucode_str = \ +"""\ + ⎛ 5 ⎞\n\ +CRootOf⎝x + 11⋅x - 2, 0⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_RootSum(): + expr = RootSum(x**5 + 11*x - 2, auto=False) + ascii_str = \ +"""\ + / 5 \\\n\ +RootSum\\x + 11*x - 2/\ +""" + ucode_str = \ +"""\ + ⎛ 5 ⎞\n\ +RootSum⎝x + 11⋅x - 2⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = RootSum(x**5 + 11*x - 2, Lambda(z, exp(z))) + ascii_str = \ +"""\ + / 5 z\\\n\ +RootSum\\x + 11*x - 2, z -> e /\ +""" + ucode_str = \ +"""\ + ⎛ 5 z⎞\n\ +RootSum⎝x + 11⋅x - 2, z ↦ ℯ ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_GroebnerBasis(): + expr = groebner([], x, y) + + ascii_str = \ +"""\ +GroebnerBasis([], x, y, domain=ZZ, order=lex)\ +""" + ucode_str = \ +"""\ +GroebnerBasis([], x, y, domain=ℤ, order=lex)\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1] + expr = groebner(F, x, y, order='grlex') + + ascii_str = \ +"""\ + /[ 2 2 ] \\\n\ +GroebnerBasis\\[x - x - 3*y + 1, y - 2*x + y - 1], x, y, domain=ZZ, order=grlex/\ +""" + ucode_str = \ +"""\ + ⎛⎡ 2 2 ⎤ ⎞\n\ +GroebnerBasis⎝⎣x - x - 3⋅y + 1, y - 2⋅x + y - 1⎦, x, y, domain=ℤ, order=grlex⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = expr.fglm('lex') + + ascii_str = \ +"""\ + /[ 2 4 3 2 ] \\\n\ +GroebnerBasis\\[2*x - y - y + 1, y + 2*y - 3*y - 16*y + 7], x, y, domain=ZZ, order=lex/\ +""" + ucode_str = \ +"""\ + ⎛⎡ 2 4 3 2 ⎤ ⎞\n\ +GroebnerBasis⎝⎣2⋅x - y - y + 1, y + 2⋅y - 3⋅y - 16⋅y + 7⎦, x, y, domain=ℤ, order=lex⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_UniversalSet(): + assert pretty(S.UniversalSet) == "UniversalSet" + assert upretty(S.UniversalSet) == '𝕌' + + +def test_pretty_Boolean(): + expr = Not(x, evaluate=False) + + assert pretty(expr) == "Not(x)" + assert upretty(expr) == "¬x" + + expr = And(x, y) + + assert pretty(expr) == "And(x, y)" + assert upretty(expr) == "x ∧ y" + + expr = Or(x, y) + + assert pretty(expr) == "Or(x, y)" + assert upretty(expr) == "x ∨ y" + + syms = symbols('a:f') + expr = And(*syms) + + assert pretty(expr) == "And(a, b, c, d, e, f)" + assert upretty(expr) == "a ∧ b ∧ c ∧ d ∧ e ∧ f" + + expr = Or(*syms) + + assert pretty(expr) == "Or(a, b, c, d, e, f)" + assert upretty(expr) == "a ∨ b ∨ c ∨ d ∨ e ∨ f" + + expr = Xor(x, y, evaluate=False) + + assert pretty(expr) == "Xor(x, y)" + assert upretty(expr) == "x ⊻ y" + + expr = Nand(x, y, evaluate=False) + + assert pretty(expr) == "Nand(x, y)" + assert upretty(expr) == "x ⊼ y" + + expr = Nor(x, y, evaluate=False) + + assert pretty(expr) == "Nor(x, y)" + assert upretty(expr) == "x ⊽ y" + + expr = Implies(x, y, evaluate=False) + + assert pretty(expr) == "Implies(x, y)" + assert upretty(expr) == "x → y" + + # don't sort args + expr = Implies(y, x, evaluate=False) + + assert pretty(expr) == "Implies(y, x)" + assert upretty(expr) == "y → x" + + expr = Equivalent(x, y, evaluate=False) + + assert pretty(expr) == "Equivalent(x, y)" + assert upretty(expr) == "x ⇔ y" + + expr = Equivalent(y, x, evaluate=False) + + assert pretty(expr) == "Equivalent(x, y)" + assert upretty(expr) == "x ⇔ y" + + +def test_pretty_Domain(): + expr = FF(23) + + assert pretty(expr) == "GF(23)" + assert upretty(expr) == "ℤ₂₃" + + expr = ZZ + + assert pretty(expr) == "ZZ" + assert upretty(expr) == "ℤ" + + expr = QQ + + assert pretty(expr) == "QQ" + assert upretty(expr) == "ℚ" + + expr = RR + + assert pretty(expr) == "RR" + assert upretty(expr) == "ℝ" + + expr = QQ[x] + + assert pretty(expr) == "QQ[x]" + assert upretty(expr) == "ℚ[x]" + + expr = QQ[x, y] + + assert pretty(expr) == "QQ[x, y]" + assert upretty(expr) == "ℚ[x, y]" + + expr = ZZ.frac_field(x) + + assert pretty(expr) == "ZZ(x)" + assert upretty(expr) == "ℤ(x)" + + expr = ZZ.frac_field(x, y) + + assert pretty(expr) == "ZZ(x, y)" + assert upretty(expr) == "ℤ(x, y)" + + expr = QQ.poly_ring(x, y, order=grlex) + + assert pretty(expr) == "QQ[x, y, order=grlex]" + assert upretty(expr) == "ℚ[x, y, order=grlex]" + + expr = QQ.poly_ring(x, y, order=ilex) + + assert pretty(expr) == "QQ[x, y, order=ilex]" + assert upretty(expr) == "ℚ[x, y, order=ilex]" + + +def test_pretty_prec(): + assert xpretty(S("0.3"), full_prec=True, wrap_line=False) == "0.300000000000000" + assert xpretty(S("0.3"), full_prec="auto", wrap_line=False) == "0.300000000000000" + assert xpretty(S("0.3"), full_prec=False, wrap_line=False) == "0.3" + assert xpretty(S("0.3")*x, full_prec=True, use_unicode=False, wrap_line=False) in [ + "0.300000000000000*x", + "x*0.300000000000000" + ] + assert xpretty(S("0.3")*x, full_prec="auto", use_unicode=False, wrap_line=False) in [ + "0.3*x", + "x*0.3" + ] + assert xpretty(S("0.3")*x, full_prec=False, use_unicode=False, wrap_line=False) in [ + "0.3*x", + "x*0.3" + ] + + +def test_pprint(): + import sys + from io import StringIO + fd = StringIO() + sso = sys.stdout + sys.stdout = fd + try: + pprint(pi, use_unicode=False, wrap_line=False) + finally: + sys.stdout = sso + assert fd.getvalue() == 'pi\n' + + +def test_pretty_class(): + """Test that the printer dispatcher correctly handles classes.""" + class C: + pass # C has no .__class__ and this was causing problems + + class D: + pass + + assert pretty( C ) == str( C ) + assert pretty( D ) == str( D ) + + +def test_pretty_no_wrap_line(): + huge_expr = 0 + for i in range(20): + huge_expr += i*sin(i + x) + assert xpretty(huge_expr ).find('\n') != -1 + assert xpretty(huge_expr, wrap_line=False).find('\n') == -1 + + +def test_settings(): + raises(TypeError, lambda: pretty(S(4), method="garbage")) + + +def test_pretty_sum(): + from sympy.abc import x, a, b, k, m, n + + expr = Sum(k**k, (k, 0, n)) + ascii_str = \ +"""\ + n \n\ +___ \n\ +\\ ` \n\ + \\ k\n\ + / k \n\ +/__, \n\ +k = 0 \ +""" + ucode_str = \ +"""\ + n \n\ + ___ \n\ + ╲ \n\ + ╲ k\n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾ \n\ +k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**k, (k, oo, n)) + ascii_str = \ +"""\ + n \n\ + ___ \n\ + \\ ` \n\ + \\ k\n\ + / k \n\ + /__, \n\ +k = oo \ +""" + ucode_str = \ +"""\ + n \n\ + ___ \n\ + ╲ \n\ + ╲ k\n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾ \n\ +k = ∞ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**(Integral(x**n, (x, -oo, oo))), (k, 0, n**n)) + ascii_str = \ +"""\ + n \n\ + n \n\ +______ \n\ +\\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ +/_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ + n \n\ + n \n\ +______ \n\ +╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ +╱ \n\ +‾‾‾‾‾‾ \n\ +k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**( + Integral(x**n, (x, -oo, oo))), (k, 0, Integral(x**x, (x, -oo, oo)))) + ascii_str = \ +"""\ + oo \n\ + / \n\ + | \n\ + | x \n\ + | x dx \n\ + | \n\ +/ \n\ +-oo \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ +∞ \n\ +⌠ \n\ +⎮ x \n\ +⎮ x dx \n\ +⌡ \n\ +-∞ \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**(Integral(x**n, (x, -oo, oo))), ( + k, x + n + x**2 + n**2 + (x/n) + (1/x), Integral(x**x, (x, -oo, oo)))) + ascii_str = \ +"""\ + oo \n\ + / \n\ + | \n\ + | x \n\ + | x dx \n\ + | \n\ + / \n\ + -oo \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + 2 2 1 x \n\ +k = n + n + x + x + - + - \n\ + x n \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ⌠ \n\ + ⎮ x \n\ + ⎮ x dx \n\ + ⌡ \n\ + -∞ \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + 2 2 1 x \n\ +k = n + n + x + x + ─ + ─ \n\ + x n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**( + Integral(x**n, (x, -oo, oo))), (k, 0, x + n + x**2 + n**2 + (x/n) + (1/x))) + ascii_str = \ +"""\ + 2 2 1 x \n\ +n + n + x + x + - + - \n\ + x n \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ + 2 2 1 x \n\ +n + n + x + x + ─ + ─ \n\ + x n \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ + __ \n\ + \\ ` \n\ + ) x\n\ + /_, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ___ \n\ + ╲ \n\ + ╲ \n\ + ╱ x\n\ + ╱ \n\ + ‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x**2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +___ \n\ +\\ ` \n\ + \\ 2\n\ + / x \n\ +/__, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ___ \n\ + ╲ \n\ + ╲ 2\n\ + ╱ x \n\ + ╱ \n\ + ‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x/2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +___ \n\ +\\ ` \n\ + \\ x\n\ + ) -\n\ + / 2\n\ +/__, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ \n\ + ╲ x\n\ + ╱ ─\n\ + ╱ 2\n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x**3/2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ 3\n\ + \\ x \n\ + / --\n\ + / 2 \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ 3\n\ + ╲ x \n\ + ╱ ──\n\ + ╱ 2 \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum((x**3*y**(x/2))**n, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ n\n\ + \\ / x\\ \n\ + ) | -| \n\ + / | 3 2| \n\ + / \\x *y / \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +_____ \n\ +╲ \n\ + ╲ \n\ + ╲ n\n\ + ╲ ⎛ x⎞ \n\ + ╱ ⎜ ─⎟ \n\ + ╱ ⎜ 3 2⎟ \n\ + ╱ ⎝x ⋅y ⎠ \n\ +╱ \n\ +‾‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/x**2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ 1 \n\ + \\ --\n\ + / 2\n\ + / x \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ 1 \n\ + ╲ ──\n\ + ╱ 2\n\ + ╱ x \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/y**(a/b), (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ -a \n\ + \\ ---\n\ + / b \n\ + / y \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ -a \n\ + ╲ ───\n\ + ╱ b \n\ + ╱ y \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/y**(a/b), (x, 0, oo), (y, 1, 2)) + ascii_str = \ +"""\ + 2 oo \n\ +____ ____ \n\ +\\ ` \\ ` \n\ + \\ \\ -a\n\ + \\ \\ --\n\ + / / b \n\ + / / y \n\ +/___, /___, \n\ +y = 1 x = 0 \ +""" + ucode_str = \ +"""\ + 2 ∞ \n\ +____ ____ \n\ +╲ ╲ \n\ + ╲ ╲ -a\n\ + ╲ ╲ ──\n\ + ╱ ╱ b \n\ + ╱ ╱ y \n\ +╱ ╱ \n\ +‾‾‾‾ ‾‾‾‾ \n\ +y = 1 x = 0 \ +""" + expr = Sum(1/(1 + 1/( + 1 + 1/k)) + 1, (k, 111, 1 + 1/n), (k, 1/(1 + m), oo)) + 1/(1 + 1/k) + ascii_str = \ +"""\ + 1 \n\ + 1 + - \n\ + oo n \n\ + _____ _____ \n\ + \\ ` \\ ` \n\ + \\ \\ / 1 \\ \n\ + \\ \\ |1 + ---------| \n\ + \\ \\ | 1 | 1 \n\ + ) ) | 1 + -----| + -----\n\ + / / | 1| 1\n\ + / / | 1 + -| 1 + -\n\ + / / \\ k/ k\n\ + /____, /____, \n\ + 1 k = 111 \n\ +k = ----- \n\ + m + 1 \ +""" + ucode_str = \ +"""\ + 1 \n\ + 1 + ─ \n\ + ∞ n \n\ + ______ ______ \n\ + ╲ ╲ \n\ + ╲ ╲ \n\ + ╲ ╲ ⎛ 1 ⎞ \n\ + ╲ ╲ ⎜1 + ─────────⎟ \n\ + ╲ ╲ ⎜ 1 ⎟ 1 \n\ + ╱ ╱ ⎜ 1 + ─────⎟ + ─────\n\ + ╱ ╱ ⎜ 1⎟ 1\n\ + ╱ ╱ ⎜ 1 + ─⎟ 1 + ─\n\ + ╱ ╱ ⎝ k⎠ k\n\ + ╱ ╱ \n\ + ‾‾‾‾‾‾ ‾‾‾‾‾‾ \n\ + 1 k = 111 \n\ +k = ───── \n\ + m + 1 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_units(): + expr = joule + ascii_str1 = \ +"""\ + 2\n\ +kilogram*meter \n\ +---------------\n\ + 2 \n\ + second \ +""" + unicode_str1 = \ +"""\ + 2\n\ +kilogram⋅meter \n\ +───────────────\n\ + 2 \n\ + second \ +""" + + ascii_str2 = \ +"""\ + 2\n\ +3*x*y*kilogram*meter \n\ +---------------------\n\ + 2 \n\ + second \ +""" + unicode_str2 = \ +"""\ + 2\n\ +3⋅x⋅y⋅kilogram⋅meter \n\ +─────────────────────\n\ + 2 \n\ + second \ +""" + + from sympy.physics.units import kg, m, s + assert upretty(expr) == "joule" + assert pretty(expr) == "joule" + assert upretty(expr.convert_to(kg*m**2/s**2)) == unicode_str1 + assert pretty(expr.convert_to(kg*m**2/s**2)) == ascii_str1 + assert upretty(3*kg*x*m**2*y/s**2) == unicode_str2 + assert pretty(3*kg*x*m**2*y/s**2) == ascii_str2 + + +def test_pretty_Subs(): + f = Function('f') + expr = Subs(f(x), x, ph**2) + ascii_str = \ +"""\ +(f(x))| 2\n\ + |x=phi \ +""" + unicode_str = \ +"""\ +(f(x))│ 2\n\ + │x=φ \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + expr = Subs(f(x).diff(x), x, 0) + ascii_str = \ +"""\ +/d \\| \n\ +|--(f(x))|| \n\ +\\dx /|x=0\ +""" + unicode_str = \ +"""\ +⎛d ⎞│ \n\ +⎜──(f(x))⎟│ \n\ +⎝dx ⎠│x=0\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + expr = Subs(f(x).diff(x)/y, (x, y), (0, Rational(1, 2))) + ascii_str = \ +"""\ +/d \\| \n\ +|--(f(x))|| \n\ +|dx || \n\ +|--------|| \n\ +\\ y /|x=0, y=1/2\ +""" + unicode_str = \ +"""\ +⎛d ⎞│ \n\ +⎜──(f(x))⎟│ \n\ +⎜dx ⎟│ \n\ +⎜────────⎟│ \n\ +⎝ y ⎠│x=0, y=1/2\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + +def test_gammas(): + assert upretty(lowergamma(x, y)) == "γ(x, y)" + assert upretty(uppergamma(x, y)) == "Γ(x, y)" + assert xpretty(gamma(x), use_unicode=True) == 'Γ(x)' + assert xpretty(gamma, use_unicode=True) == 'Γ' + assert xpretty(symbols('gamma', cls=Function)(x), use_unicode=True) == 'γ(x)' + assert xpretty(symbols('gamma', cls=Function), use_unicode=True) == 'γ' + + +def test_beta(): + assert xpretty(beta(x,y), use_unicode=True) == 'Β(x, y)' + assert xpretty(beta(x,y), use_unicode=False) == 'B(x, y)' + assert xpretty(beta, use_unicode=True) == 'Β' + assert xpretty(beta, use_unicode=False) == 'B' + mybeta = Function('beta') + assert xpretty(mybeta(x), use_unicode=True) == 'β(x)' + assert xpretty(mybeta(x, y, z), use_unicode=False) == 'beta(x, y, z)' + assert xpretty(mybeta, use_unicode=True) == 'β' + + +# test that notation passes to subclasses of the same name only +def test_function_subclass_different_name(): + class mygamma(gamma): + pass + assert xpretty(mygamma, use_unicode=True) == r"mygamma" + assert xpretty(mygamma(x), use_unicode=True) == r"mygamma(x)" + + +def test_SingularityFunction(): + assert xpretty(SingularityFunction(x, 0, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 1, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, -1, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, a, n), use_unicode=True) == ( +"""\ + n\n\ +<-a + x> \ +""") + assert xpretty(SingularityFunction(x, y, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 0, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 1, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, -1, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, a, n), use_unicode=False) == ( +"""\ + n\n\ +<-a + x> \ +""") + assert xpretty(SingularityFunction(x, y, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + + +def test_deltas(): + assert xpretty(DiracDelta(x), use_unicode=True) == 'δ(x)' + assert xpretty(DiracDelta(x, 1), use_unicode=True) == \ +"""\ + (1) \n\ +δ (x)\ +""" + assert xpretty(x*DiracDelta(x, 1), use_unicode=True) == \ +"""\ + (1) \n\ +x⋅δ (x)\ +""" + + +def test_hyper(): + expr = hyper((), (), z) + ucode_str = \ +"""\ + ┌─ ⎛ │ ⎞\n\ + ├─ ⎜ │ z⎟\n\ +0╵ 0 ⎝ │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ / | \\\n\ + | | | z|\n\ +0 0 \\ | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((), (1,), x) + ucode_str = \ +"""\ + ┌─ ⎛ │ ⎞\n\ + ├─ ⎜ │ x⎟\n\ +0╵ 1 ⎝1 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ / | \\\n\ + | | | x|\n\ +0 1 \\1 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper([2], [1], x) + ucode_str = \ +"""\ + ┌─ ⎛2 │ ⎞\n\ + ├─ ⎜ │ x⎟\n\ +1╵ 1 ⎝1 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ /2 | \\\n\ + | | | x|\n\ +1 1 \\1 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((pi/3, -2*k), (3, 4, 5, -3), x) + ucode_str = \ +"""\ + ⎛ π │ ⎞\n\ + ┌─ ⎜ ─, -2⋅k │ ⎟\n\ + ├─ ⎜ 3 │ x⎟\n\ +2╵ 4 ⎜ │ ⎟\n\ + ⎝-3, 3, 4, 5 │ ⎠\ +""" + ascii_str = \ +"""\ + \n\ + _ / pi | \\\n\ + |_ | --, -2*k | |\n\ + | | 3 | x|\n\ +2 4 | | |\n\ + \\-3, 3, 4, 5 | /\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((pi, S('2/3'), -2*k), (3, 4, 5, -3), x**2) + ucode_str = \ +"""\ + ┌─ ⎛2/3, π, -2⋅k │ 2⎞\n\ + ├─ ⎜ │ x ⎟\n\ +3╵ 4 ⎝-3, 3, 4, 5 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ /2/3, pi, -2*k | 2\\ + | | | x | +3 4 \\ -3, 3, 4, 5 | /""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper([1, 2], [3, 4], 1/(1/(1/(1/x + 1) + 1) + 1)) + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ + ⎜ │ ─────────────⎟\n\ + ⎜ │ 1 ⎟\n\ + ┌─ ⎜1, 2 │ 1 + ─────────⎟\n\ + ├─ ⎜ │ 1 ⎟\n\ +2╵ 2 ⎜3, 4 │ 1 + ─────⎟\n\ + ⎜ │ 1⎟\n\ + ⎜ │ 1 + ─⎟\n\ + ⎝ │ x⎠\ +""" + + ascii_str = \ +"""\ + \n\ + / | 1 \\\n\ + | | -------------|\n\ + _ | | 1 |\n\ + |_ |1, 2 | 1 + ---------|\n\ + | | | 1 |\n\ +2 2 |3, 4 | 1 + -----|\n\ + | | 1|\n\ + | | 1 + -|\n\ + \\ | x/\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_meijerg(): + expr = meijerg([pi, pi, x], [1], [0, 1], [1, 2, 3], z) + ucode_str = \ +"""\ +╭─╮2, 3 ⎛π, π, x 1 │ ⎞\n\ +│╶┐ ⎜ │ z⎟\n\ +╰─╯4, 5 ⎝ 0, 1 1, 2, 3 │ ⎠\ +""" + ascii_str = \ +"""\ + __2, 3 /pi, pi, x 1 | \\\n\ +/__ | | z|\n\ +\\_|4, 5 \\ 0, 1 1, 2, 3 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = meijerg([1, pi/7], [2, pi, 5], [], [], z**2) + ucode_str = \ +"""\ + ⎛ π │ ⎞\n\ +╭─╮0, 2 ⎜1, ─ 2, 5, π │ 2⎟\n\ +│╶┐ ⎜ 7 │ z ⎟\n\ +╰─╯5, 0 ⎜ │ ⎟\n\ + ⎝ │ ⎠\ +""" + ascii_str = \ +"""\ + / pi | \\\n\ + __0, 2 |1, -- 2, 5, pi | 2|\n\ +/__ | 7 | z |\n\ +\\_|5, 0 | | |\n\ + \\ | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ucode_str = \ +"""\ +╭─╮ 1, 10 ⎛1, 1, 1, 1, 1, 1, 1, 1, 1, 1 1 │ ⎞\n\ +│╶┐ ⎜ │ z⎟\n\ +╰─╯11, 2 ⎝ 1 1 │ ⎠\ +""" + ascii_str = \ +"""\ + __ 1, 10 /1, 1, 1, 1, 1, 1, 1, 1, 1, 1 1 | \\\n\ +/__ | | z|\n\ +\\_|11, 2 \\ 1 1 | /\ +""" + + expr = meijerg([1]*10, [1], [1], [1], z) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = meijerg([1, 2, ], [4, 3], [3], [4, 5], 1/(1/(1/(1/x + 1) + 1) + 1)) + + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ + ⎜ │ ─────────────⎟\n\ + ⎜ │ 1 ⎟\n\ +╭─╮1, 2 ⎜1, 2 3, 4 │ 1 + ─────────⎟\n\ +│╶┐ ⎜ │ 1 ⎟\n\ +╰─╯4, 3 ⎜ 3 4, 5 │ 1 + ─────⎟\n\ + ⎜ │ 1⎟\n\ + ⎜ │ 1 + ─⎟\n\ + ⎝ │ x⎠\ +""" + + ascii_str = \ +"""\ + / | 1 \\\n\ + | | -------------|\n\ + | | 1 |\n\ + __1, 2 |1, 2 3, 4 | 1 + ---------|\n\ +/__ | | 1 |\n\ +\\_|4, 3 | 3 4, 5 | 1 + -----|\n\ + | | 1|\n\ + | | 1 + -|\n\ + \\ | x/\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(expr, x) + + ucode_str = \ +"""\ +⌠ \n\ +⎮ ⎛ │ 1 ⎞ \n\ +⎮ ⎜ │ ─────────────⎟ \n\ +⎮ ⎜ │ 1 ⎟ \n\ +⎮ ╭─╮1, 2 ⎜1, 2 3, 4 │ 1 + ─────────⎟ \n\ +⎮ │╶┐ ⎜ │ 1 ⎟ dx\n\ +⎮ ╰─╯4, 3 ⎜ 3 4, 5 │ 1 + ─────⎟ \n\ +⎮ ⎜ │ 1⎟ \n\ +⎮ ⎜ │ 1 + ─⎟ \n\ +⎮ ⎝ │ x⎠ \n\ +⌡ \ +""" + + ascii_str = \ +"""\ + / \n\ + | \n\ + | / | 1 \\ \n\ + | | | -------------| \n\ + | | | 1 | \n\ + | __1, 2 |1, 2 3, 4 | 1 + ---------| \n\ + | /__ | | 1 | dx\n\ + | \\_|4, 3 | 3 4, 5 | 1 + -----| \n\ + | | | 1| \n\ + | | | 1 + -| \n\ + | \\ | x/ \n\ + | \n\ +/ \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + expr = A*B*C**-1 + ascii_str = \ +"""\ + -1\n\ +A*B*C \ +""" + ucode_str = \ +"""\ + -1\n\ +A⋅B⋅C \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = C**-1*A*B + ascii_str = \ +"""\ + -1 \n\ +C *A*B\ +""" + ucode_str = \ +"""\ + -1 \n\ +C ⋅A⋅B\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A*C**-1*B + ascii_str = \ +"""\ + -1 \n\ +A*C *B\ +""" + ucode_str = \ +"""\ + -1 \n\ +A⋅C ⋅B\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A*C**-1*B/x + ascii_str = \ +"""\ + -1 \n\ +A*C *B\n\ +-------\n\ + x \ +""" + ucode_str = \ +"""\ + -1 \n\ +A⋅C ⋅B\n\ +───────\n\ + x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_special_functions(): + x, y = symbols("x y") + + # atan2 + expr = atan2(y/sqrt(200), sqrt(x)) + ascii_str = \ +"""\ + / ___ \\\n\ + |\\/ 2 *y ___|\n\ +atan2|-------, \\/ x |\n\ + \\ 20 /\ +""" + ucode_str = \ +"""\ + ⎛√2⋅y ⎞\n\ +atan2⎜────, √x⎟\n\ + ⎝ 20 ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_geometry(): + e = Segment((0, 1), (0, 2)) + assert pretty(e) == 'Segment2D(Point2D(0, 1), Point2D(0, 2))' + e = Ray((1, 1), angle=4.02*pi) + assert pretty(e) == 'Ray2D(Point2D(1, 1), Point2D(2, tan(pi/50) + 1))' + + +def test_expint(): + expr = Ei(x) + string = 'Ei(x)' + assert pretty(expr) == string + assert upretty(expr) == string + + expr = expint(1, z) + ucode_str = "E₁(z)" + ascii_str = "expint(1, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + assert pretty(Shi(x)) == 'Shi(x)' + assert pretty(Si(x)) == 'Si(x)' + assert pretty(Ci(x)) == 'Ci(x)' + assert pretty(Chi(x)) == 'Chi(x)' + assert upretty(Shi(x)) == 'Shi(x)' + assert upretty(Si(x)) == 'Si(x)' + assert upretty(Ci(x)) == 'Ci(x)' + assert upretty(Chi(x)) == 'Chi(x)' + + +def test_elliptic_functions(): + ascii_str = \ +"""\ + / 1 \\\n\ +K|-----|\n\ + \\z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ 1 ⎞\n\ +K⎜─────⎟\n\ + ⎝z + 1⎠\ +""" + expr = elliptic_k(1/(z + 1)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / | 1 \\\n\ +F|1|-----|\n\ + \\ |z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ +F⎜1│─────⎟\n\ + ⎝ │z + 1⎠\ +""" + expr = elliptic_f(1, 1/(1 + z)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / 1 \\\n\ +E|-----|\n\ + \\z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ 1 ⎞\n\ +E⎜─────⎟\n\ + ⎝z + 1⎠\ +""" + expr = elliptic_e(1/(z + 1)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / | 1 \\\n\ +E|1|-----|\n\ + \\ |z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ +E⎜1│─────⎟\n\ + ⎝ │z + 1⎠\ +""" + expr = elliptic_e(1, 1/(1 + z)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / |4\\\n\ +Pi|3|-|\n\ + \\ |x/\ +""" + ucode_str = \ +"""\ + ⎛ │4⎞\n\ +Π⎜3│─⎟\n\ + ⎝ │x⎠\ +""" + expr = elliptic_pi(3, 4/x) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / 4| \\\n\ +Pi|3; -|6|\n\ + \\ x| /\ +""" + ucode_str = \ +"""\ + ⎛ 4│ ⎞\n\ +Π⎜3; ─│6⎟\n\ + ⎝ x│ ⎠\ +""" + expr = elliptic_pi(3, 4/x, 6) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + X = Normal('x1', 0, 1) + assert upretty(where(X > 0)) == "Domain: 0 < x₁ ∧ x₁ < ∞" + + D = Die('d1', 6) + assert upretty(where(D > 4)) == 'Domain: d₁ = 5 ∨ d₁ = 6' + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert upretty(pspace(Tuple(A, B)).domain) == \ + 'Domain: 0 ≤ a ∧ 0 ≤ b ∧ a < ∞ ∧ b < ∞' + + +def test_PrettyPoly(): + F = QQ.frac_field(x, y) + R = QQ.poly_ring(x, y) + + expr = F.convert(x/(x + y)) + assert pretty(expr) == "x/(x + y)" + assert upretty(expr) == "x/(x + y)" + + expr = R.convert(x + y) + assert pretty(expr) == "x + y" + assert upretty(expr) == "x + y" + + +def test_issue_6285(): + assert pretty(Pow(2, -5, evaluate=False)) == '1 \n--\n 5\n2 ' + assert pretty(Pow(x, (1/pi))) == \ + ' 1 \n'\ + ' --\n'\ + ' pi\n'\ + 'x ' + + +def test_issue_6359(): + assert pretty(Integral(x**2, x)**2) == \ +"""\ + 2 +/ / \\ \n\ +| | | \n\ +| | 2 | \n\ +| | x dx| \n\ +| | | \n\ +\\/ / \ +""" + assert upretty(Integral(x**2, x)**2) == \ +"""\ + 2 +⎛⌠ ⎞ \n\ +⎜⎮ 2 ⎟ \n\ +⎜⎮ x dx⎟ \n\ +⎝⌡ ⎠ \ +""" + + assert pretty(Sum(x**2, (x, 0, 1))**2) == \ +"""\ + 2\n\ +/ 1 \\ \n\ +|___ | \n\ +|\\ ` | \n\ +| \\ 2| \n\ +| / x | \n\ +|/__, | \n\ +\\x = 0 / \ +""" + assert upretty(Sum(x**2, (x, 0, 1))**2) == \ +"""\ + 2 +⎛ 1 ⎞ \n\ +⎜ ___ ⎟ \n\ +⎜ ╲ ⎟ \n\ +⎜ ╲ 2⎟ \n\ +⎜ ╱ x ⎟ \n\ +⎜ ╱ ⎟ \n\ +⎜ ‾‾‾ ⎟ \n\ +⎝x = 0 ⎠ \ +""" + + assert pretty(Product(x**2, (x, 1, 2))**2) == \ +"""\ + 2 +/ 2 \\ \n\ +|______ | \n\ +| | | 2| \n\ +| | | x | \n\ +| | | | \n\ +\\x = 1 / \ +""" + assert upretty(Product(x**2, (x, 1, 2))**2) == \ +"""\ + 2 +⎛ 2 ⎞ \n\ +⎜─┬──┬─ ⎟ \n\ +⎜ │ │ 2⎟ \n\ +⎜ │ │ x ⎟ \n\ +⎜ │ │ ⎟ \n\ +⎝x = 1 ⎠ \ +""" + + f = Function('f') + assert pretty(Derivative(f(x), x)**2) == \ +"""\ + 2 +/d \\ \n\ +|--(f(x))| \n\ +\\dx / \ +""" + assert upretty(Derivative(f(x), x)**2) == \ +"""\ + 2 +⎛d ⎞ \n\ +⎜──(f(x))⎟ \n\ +⎝dx ⎠ \ +""" + + +def test_issue_6739(): + ascii_str = \ +"""\ + 1 \n\ +-----\n\ + ___\n\ +\\/ x \ +""" + ucode_str = \ +"""\ +1 \n\ +──\n\ +√x\ +""" + assert pretty(1/sqrt(x)) == ascii_str + assert upretty(1/sqrt(x)) == ucode_str + + +def test_complicated_symbol_unchanged(): + for symb_name in ["dexpr2_d1tau", "dexpr2^d1tau"]: + assert pretty(Symbol(symb_name)) == symb_name + + +def test_categories(): + from sympy.categories import (Object, IdentityMorphism, + NamedMorphism, Category, Diagram, DiagramGrid) + + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A2, A3, "f2") + id_A1 = IdentityMorphism(A1) + + K1 = Category("K1") + + assert pretty(A1) == "A1" + assert upretty(A1) == "A₁" + + assert pretty(f1) == "f1:A1-->A2" + assert upretty(f1) == "f₁:A₁——▶A₂" + assert pretty(id_A1) == "id:A1-->A1" + assert upretty(id_A1) == "id:A₁——▶A₁" + + assert pretty(f2*f1) == "f2*f1:A1-->A3" + assert upretty(f2*f1) == "f₂∘f₁:A₁——▶A₃" + + assert pretty(K1) == "K1" + assert upretty(K1) == "K₁" + + # Test how diagrams are printed. + d = Diagram() + assert pretty(d) == "EmptySet" + assert upretty(d) == "∅" + + d = Diagram({f1: "unique", f2: S.EmptySet}) + assert pretty(d) == "{f2*f1:A1-->A3: EmptySet, id:A1-->A1: " \ + "EmptySet, id:A2-->A2: EmptySet, id:A3-->A3: " \ + "EmptySet, f1:A1-->A2: {unique}, f2:A2-->A3: EmptySet}" + + assert upretty(d) == "{f₂∘f₁:A₁——▶A₃: ∅, id:A₁——▶A₁: ∅, " \ + "id:A₂——▶A₂: ∅, id:A₃——▶A₃: ∅, f₁:A₁——▶A₂: {unique}, f₂:A₂——▶A₃: ∅}" + + d = Diagram({f1: "unique", f2: S.EmptySet}, {f2 * f1: "unique"}) + assert pretty(d) == "{f2*f1:A1-->A3: EmptySet, id:A1-->A1: " \ + "EmptySet, id:A2-->A2: EmptySet, id:A3-->A3: " \ + "EmptySet, f1:A1-->A2: {unique}, f2:A2-->A3: EmptySet}" \ + " ==> {f2*f1:A1-->A3: {unique}}" + assert upretty(d) == "{f₂∘f₁:A₁——▶A₃: ∅, id:A₁——▶A₁: ∅, id:A₂——▶A₂: " \ + "∅, id:A₃——▶A₃: ∅, f₁:A₁——▶A₂: {unique}, f₂:A₂——▶A₃: ∅}" \ + " ══▶ {f₂∘f₁:A₁——▶A₃: {unique}}" + + grid = DiagramGrid(d) + assert pretty(grid) == "A1 A2\n \nA3 " + assert upretty(grid) == "A₁ A₂\n \nA₃ " + + +def test_PrettyModules(): + R = QQ.old_poly_ring(x, y) + F = R.free_module(2) + M = F.submodule([x, y], [1, x**2]) + + ucode_str = \ +"""\ + 2\n\ +ℚ[x, y] \ +""" + ascii_str = \ +"""\ + 2\n\ +QQ[x, y] \ +""" + + assert upretty(F) == ucode_str + assert pretty(F) == ascii_str + + ucode_str = \ +"""\ +╱ ⎡ 2⎤╲\n\ +╲[x, y], ⎣1, x ⎦╱\ +""" + ascii_str = \ +"""\ + 2 \n\ +<[x, y], [1, x ]>\ +""" + + assert upretty(M) == ucode_str + assert pretty(M) == ascii_str + + I = R.ideal(x**2, y) + + ucode_str = \ +"""\ +╱ 2 ╲\n\ +╲x , y╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ +\ +""" + + assert upretty(I) == ucode_str + assert pretty(I) == ascii_str + + Q = F / M + + ucode_str = \ +"""\ + 2 \n\ + ℚ[x, y] \n\ +─────────────────\n\ +╱ ⎡ 2⎤╲\n\ +╲[x, y], ⎣1, x ⎦╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ + QQ[x, y] \n\ +-----------------\n\ + 2 \n\ +<[x, y], [1, x ]>\ +""" + + assert upretty(Q) == ucode_str + assert pretty(Q) == ascii_str + + ucode_str = \ +"""\ +╱⎡ 3⎤ ╲\n\ +│⎢ x ⎥ ╱ ⎡ 2⎤╲ ╱ ⎡ 2⎤╲│\n\ +│⎢1, ──⎥ + ╲[x, y], ⎣1, x ⎦╱, [2, y] + ╲[x, y], ⎣1, x ⎦╱│\n\ +╲⎣ 2 ⎦ ╱\ +""" + + ascii_str = \ +"""\ + 3 \n\ + x 2 2 \n\ +<[1, --] + <[x, y], [1, x ]>, [2, y] + <[x, y], [1, x ]>>\n\ + 2 \ +""" + + +def test_QuotientRing(): + R = QQ.old_poly_ring(x)/[x**2 + 1] + + ucode_str = \ +"""\ + ℚ[x] \n\ +────────\n\ +╱ 2 ╲\n\ +╲x + 1╱\ +""" + + ascii_str = \ +"""\ + QQ[x] \n\ +--------\n\ + 2 \n\ +\ +""" + + assert upretty(R) == ucode_str + assert pretty(R) == ascii_str + + ucode_str = \ +"""\ + ╱ 2 ╲\n\ +1 + ╲x + 1╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ +1 + \ +""" + + assert upretty(R.one) == ucode_str + assert pretty(R.one) == ascii_str + + +def test_Homomorphism(): + from sympy.polys.agca import homomorphism + + R = QQ.old_poly_ring(x) + + expr = homomorphism(R.free_module(1), R.free_module(1), [0]) + + ucode_str = \ +"""\ + 1 1\n\ +[0] : ℚ[x] ──> ℚ[x] \ +""" + + ascii_str = \ +"""\ + 1 1\n\ +[0] : QQ[x] --> QQ[x] \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + expr = homomorphism(R.free_module(2), R.free_module(2), [0, 0]) + + ucode_str = \ +"""\ +⎡0 0⎤ 2 2\n\ +⎢ ⎥ : ℚ[x] ──> ℚ[x] \n\ +⎣0 0⎦ \ +""" + + ascii_str = \ +"""\ +[0 0] 2 2\n\ +[ ] : QQ[x] --> QQ[x] \n\ +[0 0] \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + expr = homomorphism(R.free_module(1), R.free_module(1) / [[x]], [0]) + + ucode_str = \ +"""\ + 1\n\ + 1 ℚ[x] \n\ +[0] : ℚ[x] ──> ─────\n\ + <[x]>\ +""" + + ascii_str = \ +"""\ + 1\n\ + 1 QQ[x] \n\ +[0] : QQ[x] --> ------\n\ + <[x]> \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + +def test_Tr(): + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert pretty(t) == r'Tr(A*B)' + assert upretty(t) == 'Tr(A⋅B)' + + +def test_pretty_Add(): + eq = Mul(-2, x - 2, evaluate=False) + 5 + assert pretty(eq) == '5 - 2*(x - 2)' + + +def test_issue_7179(): + assert upretty(Not(Equivalent(x, y))) == 'x ⇎ y' + assert upretty(Not(Implies(x, y))) == 'x ↛ y' + + +def test_issue_7180(): + assert upretty(Equivalent(x, y)) == 'x ⇔ y' + + +def test_pretty_Complement(): + assert pretty(S.Reals - S.Naturals) == '(-oo, oo) \\ Naturals' + assert upretty(S.Reals - S.Naturals) == 'ℝ \\ ℕ' + assert pretty(S.Reals - S.Naturals0) == '(-oo, oo) \\ Naturals0' + assert upretty(S.Reals - S.Naturals0) == 'ℝ \\ ℕ₀' + + +def test_pretty_SymmetricDifference(): + from sympy.sets.sets import SymmetricDifference + assert upretty(SymmetricDifference(Interval(2,3), Interval(3,5), \ + evaluate = False)) == '[2, 3] ∆ [3, 5]' + with raises(NotImplementedError): + pretty(SymmetricDifference(Interval(2,3), Interval(3,5), evaluate = False)) + + +def test_pretty_Contains(): + assert pretty(Contains(x, S.Integers)) == 'Contains(x, Integers)' + assert upretty(Contains(x, S.Integers)) == 'x ∈ ℤ' + + +def test_issue_8292(): + from sympy.core import sympify + e = sympify('((x+x**4)/(x-1))-(2*(x-1)**4/(x-1)**4)', evaluate=False) + ucode_str = \ +"""\ + 4 4 \n\ + 2⋅(x - 1) x + x\n\ +- ────────── + ──────\n\ + 4 x - 1 \n\ + (x - 1) \ +""" + ascii_str = \ +"""\ + 4 4 \n\ + 2*(x - 1) x + x\n\ +- ---------- + ------\n\ + 4 x - 1 \n\ + (x - 1) \ +""" + assert pretty(e) == ascii_str + assert upretty(e) == ucode_str + + +def test_issue_4335(): + y = Function('y') + expr = -y(x).diff(x) + ucode_str = \ +"""\ + d \n\ +-──(y(x))\n\ + dx \ +""" + ascii_str = \ +"""\ + d \n\ +- --(y(x))\n\ + dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_8344(): + from sympy.core import sympify + e = sympify('2*x*y**2/1**2 + 1', evaluate=False) + ucode_str = \ +"""\ + 2 \n\ +2⋅x⋅y \n\ +────── + 1\n\ + 2 \n\ + 1 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_6324(): + x = Pow(2, 3, evaluate=False) + y = Pow(10, -2, evaluate=False) + e = Mul(x, y, evaluate=False) + ucode_str = \ +"""\ + 3 \n\ +2 \n\ +───\n\ + 2\n\ +10 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_7927(): + e = sin(x/2)**cos(x/2) + ucode_str = \ +"""\ + ⎛x⎞\n\ + cos⎜─⎟\n\ + ⎝2⎠\n\ +⎛ ⎛x⎞⎞ \n\ +⎜sin⎜─⎟⎟ \n\ +⎝ ⎝2⎠⎠ \ +""" + assert upretty(e) == ucode_str + e = sin(x)**(S(11)/13) + ucode_str = \ +"""\ + 11\n\ + ──\n\ + 13\n\ +(sin(x)) \ +""" + assert upretty(e) == ucode_str + + +def test_issue_6134(): + from sympy.abc import lamda, t + phi = Function('phi') + + e = lamda*x*Integral(phi(t)*pi*sin(pi*t), (t, 0, 1)) + lamda*x**2*Integral(phi(t)*2*pi*sin(2*pi*t), (t, 0, 1)) + ucode_str = \ +"""\ + 1 1 \n\ + 2 ⌠ ⌠ \n\ +λ⋅x ⋅⎮ 2⋅π⋅φ(t)⋅sin(2⋅π⋅t) dt + λ⋅x⋅⎮ π⋅φ(t)⋅sin(π⋅t) dt\n\ + ⌡ ⌡ \n\ + 0 0 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_9877(): + ucode_str1 = '(2, 3) ∪ ([1, 2] \\ {x})' + a, b, c = Interval(2, 3, True, True), Interval(1, 2), FiniteSet(x) + assert upretty(Union(a, Complement(b, c))) == ucode_str1 + + ucode_str2 = '{x} ∩ {y} ∩ ({z} \\ [1, 2])' + d, e, f, g = FiniteSet(x), FiniteSet(y), FiniteSet(z), Interval(1, 2) + assert upretty(Intersection(d, e, Complement(f, g))) == ucode_str2 + + +def test_issue_13651(): + expr1 = c + Mul(-1, a + b, evaluate=False) + assert pretty(expr1) == 'c - (a + b)' + expr2 = c + Mul(-1, a - b + d, evaluate=False) + assert pretty(expr2) == 'c - (a - b + d)' + + +def test_pretty_primenu(): + from sympy.functions.combinatorial.numbers import primenu + + ascii_str1 = "nu(n)" + ucode_str1 = "ν(n)" + + n = symbols('n', integer=True) + assert pretty(primenu(n)) == ascii_str1 + assert upretty(primenu(n)) == ucode_str1 + + +def test_pretty_primeomega(): + from sympy.functions.combinatorial.numbers import primeomega + + ascii_str1 = "Omega(n)" + ucode_str1 = "Ω(n)" + + n = symbols('n', integer=True) + assert pretty(primeomega(n)) == ascii_str1 + assert upretty(primeomega(n)) == ucode_str1 + + +def test_pretty_Mod(): + from sympy.core import Mod + + ascii_str1 = "x mod 7" + ucode_str1 = "x mod 7" + + ascii_str2 = "(x + 1) mod 7" + ucode_str2 = "(x + 1) mod 7" + + ascii_str3 = "2*x mod 7" + ucode_str3 = "2⋅x mod 7" + + ascii_str4 = "(x mod 7) + 1" + ucode_str4 = "(x mod 7) + 1" + + ascii_str5 = "2*(x mod 7)" + ucode_str5 = "2⋅(x mod 7)" + + x = symbols('x', integer=True) + assert pretty(Mod(x, 7)) == ascii_str1 + assert upretty(Mod(x, 7)) == ucode_str1 + assert pretty(Mod(x + 1, 7)) == ascii_str2 + assert upretty(Mod(x + 1, 7)) == ucode_str2 + assert pretty(Mod(2 * x, 7)) == ascii_str3 + assert upretty(Mod(2 * x, 7)) == ucode_str3 + assert pretty(Mod(x, 7) + 1) == ascii_str4 + assert upretty(Mod(x, 7) + 1) == ucode_str4 + assert pretty(2 * Mod(x, 7)) == ascii_str5 + assert upretty(2 * Mod(x, 7)) == ucode_str5 + + +def test_issue_11801(): + assert pretty(Symbol("")) == "" + assert upretty(Symbol("")) == "" + + +def test_pretty_UnevaluatedExpr(): + x = symbols('x') + he = UnevaluatedExpr(1/x) + + ucode_str = \ +"""\ +1\n\ +─\n\ +x\ +""" + + assert upretty(he) == ucode_str + + ucode_str = \ +"""\ + 2\n\ +⎛1⎞ \n\ +⎜─⎟ \n\ +⎝x⎠ \ +""" + + assert upretty(he**2) == ucode_str + + ucode_str = \ +"""\ + 1\n\ +1 + ─\n\ + x\ +""" + + assert upretty(he + 1) == ucode_str + + ucode_str = \ +('''\ + 1\n\ +x⋅─\n\ + x\ +''') + assert upretty(x*he) == ucode_str + + +def test_issue_10472(): + M = (Matrix([[0, 0], [0, 0]]), Matrix([0, 0])) + + ucode_str = \ +"""\ +⎛⎡0 0⎤ ⎡0⎤⎞ +⎜⎢ ⎥, ⎢ ⎥⎟ +⎝⎣0 0⎦ ⎣0⎦⎠\ +""" + assert upretty(M) == ucode_str + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + ascii_str1 = "A_00" + ucode_str1 = "A₀₀" + assert pretty(A[0, 0]) == ascii_str1 + assert upretty(A[0, 0]) == ucode_str1 + + ascii_str1 = "3*A_00" + ucode_str1 = "3⋅A₀₀" + assert pretty(3*A[0, 0]) == ascii_str1 + assert upretty(3*A[0, 0]) == ucode_str1 + + ascii_str1 = "(-B + A)[0, 0]" + ucode_str1 = "(-B + A)[0, 0]" + F = C[0, 0].subs(C, A - B) + assert pretty(F) == ascii_str1 + assert upretty(F) == ucode_str1 + + +def test_issue_12675(): + x, y, t, j = symbols('x y t j') + e = CoordSys3D('e') + + ucode_str = \ +"""\ +⎛ t⎞ \n\ +⎜⎛x⎞ ⎟ j_e\n\ +⎜⎜─⎟ ⎟ \n\ +⎝⎝y⎠ ⎠ \ +""" + assert upretty((x/y)**t*e.j) == ucode_str + ucode_str = \ +"""\ +⎛1⎞ \n\ +⎜─⎟ j_e\n\ +⎝y⎠ \ +""" + assert upretty((1/y)*e.j) == ucode_str + + +def test_MatrixSymbol_printing(): + # test cases for issue #14237 + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + assert pretty(-A*B*C) == "-A*B*C" + assert pretty(A - B) == "-B + A" + assert pretty(A*B*C - A*B - B*C) == "-A*B -B*C + A*B*C" + + # issue #14814 + x = MatrixSymbol('x', n, n) + y = MatrixSymbol('y*', n, n) + assert pretty(x + y) == "x + y*" + ascii_str = \ +"""\ + 2 \n\ +-2*y* -a*x\ +""" + assert pretty(-a*x + -2*y*y) == ascii_str + + +def test_degree_printing(): + expr1 = 90*degree + assert pretty(expr1) == '90°' + expr2 = x*degree + assert pretty(expr2) == 'x°' + expr3 = cos(x*degree + 90*degree) + assert pretty(expr3) == 'cos(x° + 90°)' + + +def test_vector_expr_pretty_printing(): + A = CoordSys3D('A') + + assert upretty(Cross(A.i, A.x*A.i+3*A.y*A.j)) == "(i_A)×((x_A) i_A + (3⋅y_A) j_A)" + assert upretty(x*Cross(A.i, A.j)) == 'x⋅(i_A)×(j_A)' + + assert upretty(Curl(A.x*A.i + 3*A.y*A.j)) == "∇×((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Divergence(A.x*A.i + 3*A.y*A.j)) == "∇⋅((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Dot(A.i, A.x*A.i+3*A.y*A.j)) == "(i_A)⋅((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Gradient(A.x+3*A.y)) == "∇(x_A + 3⋅y_A)" + assert upretty(Laplacian(A.x+3*A.y)) == "∆(x_A + 3⋅y_A)" + # TODO: add support for ASCII pretty. + + +def test_pretty_print_tensor_expr(): + L = TensorIndexType("L") + i, j, k = tensor_indices("i j k", L) + i0 = tensor_indices("i_0", L) + A, B, C, D = tensor_heads("A B C D", [L]) + H = TensorHead("H", [L, L]) + + expr = -i + ascii_str = \ +"""\ +-i\ +""" + ucode_str = \ +"""\ +-i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i) + ascii_str = \ +"""\ + i\n\ +A \n\ + \ +""" + ucode_str = \ +"""\ + i\n\ +A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i0) + ascii_str = \ +"""\ + i_0\n\ +A \n\ + \ +""" + ucode_str = \ +"""\ + i₀\n\ +A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(-i) + ascii_str = \ +"""\ + \n\ +A \n\ + i\ +""" + ucode_str = \ +"""\ + \n\ +A \n\ + i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -3*A(-i) + ascii_str = \ +"""\ + \n\ +-3*A \n\ + i\ +""" + ucode_str = \ +"""\ + \n\ +-3⋅A \n\ + i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -j) + ascii_str = \ +"""\ + i \n\ +H \n\ + j\ +""" + ucode_str = \ +"""\ + i \n\ +H \n\ + j\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -i) + ascii_str = \ +"""\ + L_0 \n\ +H \n\ + L_0\ +""" + ucode_str = \ +"""\ + L₀ \n\ +H \n\ + L₀\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -j)*A(j)*B(k) + ascii_str = \ +"""\ + i L_0 k\n\ +H *A *B \n\ + L_0 \ +""" + ucode_str = \ +"""\ + i L₀ k\n\ +H ⋅A ⋅B \n\ + L₀ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (1+x)*A(i) + ascii_str = \ +"""\ + i\n\ +(x + 1)*A \n\ + \ +""" + ucode_str = \ +"""\ + i\n\ +(x + 1)⋅A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i) + 3*B(i) + ascii_str = \ +"""\ + i i\n\ +3*B + A \n\ + \ +""" + ucode_str = \ +"""\ + i i\n\ +3⋅B + A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_print_tensor_partial_deriv(): + from sympy.tensor.toperators import PartialDerivative + + L = TensorIndexType("L") + i, j, k = tensor_indices("i j k", L) + + A, B, C, D = tensor_heads("A B C D", [L]) + + H = TensorHead("H", [L, L]) + + expr = PartialDerivative(A(i), A(j)) + ascii_str = \ +"""\ + d / i\\\n\ +---|A |\n\ + j\\ /\n\ +dA \n\ + \ +""" + ucode_str = \ +"""\ + ∂ ⎛ i⎞\n\ +───⎜A ⎟\n\ + j⎝ ⎠\n\ +∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i)*PartialDerivative(H(k, -i), A(j)) + ascii_str = \ +"""\ + L_0 d / k \\\n\ +A *---|H |\n\ + j\\ L_0/\n\ + dA \n\ + \ +""" + ucode_str = \ +"""\ + L₀ ∂ ⎛ k ⎞\n\ +A ⋅───⎜H ⎟\n\ + j⎝ L₀⎠\n\ + ∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i)*PartialDerivative(B(k)*C(-i) + 3*H(k, -i), A(j)) + ascii_str = \ +"""\ + L_0 d / k k \\\n\ +A *---|3*H + B *C |\n\ + j\\ L_0 L_0/\n\ + dA \n\ + \ +""" + ucode_str = \ +"""\ + L₀ ∂ ⎛ k k ⎞\n\ +A ⋅───⎜3⋅H + B ⋅C ⎟\n\ + j⎝ L₀ L₀⎠\n\ + ∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (A(i) + B(i))*PartialDerivative(C(j), D(j)) + ascii_str = \ +"""\ +/ i i\\ d / L_0\\\n\ +|A + B |*-----|C |\n\ +\\ / L_0\\ /\n\ + dD \n\ + \ +""" + ucode_str = \ +"""\ +⎛ i i⎞ ∂ ⎛ L₀⎞\n\ +⎜A + B ⎟⋅────⎜C ⎟\n\ +⎝ ⎠ L₀⎝ ⎠\n\ + ∂D \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (A(i) + B(i))*PartialDerivative(C(-i), D(j)) + ascii_str = \ +"""\ +/ L_0 L_0\\ d / \\\n\ +|A + B |*---|C |\n\ +\\ / j\\ L_0/\n\ + dD \n\ + \ +""" + ucode_str = \ +"""\ +⎛ L₀ L₀⎞ ∂ ⎛ ⎞\n\ +⎜A + B ⎟⋅───⎜C ⎟\n\ +⎝ ⎠ j⎝ L₀⎠\n\ + ∂D \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = PartialDerivative(B(-i) + A(-i), A(-j), A(-n)) + ucode_str = """\ + 2 \n\ + ∂ ⎛ ⎞\n\ +───────⎜A + B ⎟\n\ + ⎝ i i⎠\n\ +∂A ∂A \n\ + n j \ +""" + assert upretty(expr) == ucode_str + + expr = PartialDerivative(3*A(-i), A(-j), A(-n)) + ucode_str = """\ + 2 \n\ + ∂ ⎛ ⎞\n\ +───────⎜3⋅A ⎟\n\ + ⎝ i⎠\n\ +∂A ∂A \n\ + n j \ +""" + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {i:1}) + ascii_str = \ +"""\ + i=1,j\n\ +H \n\ + \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {i: 1, j: 1}) + ascii_str = \ +"""\ + i=1,j=1\n\ +H \n\ + \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {j: 1}) + ascii_str = \ +"""\ + i,j=1\n\ +H \n\ + \ +""" + ucode_str = ascii_str + + expr = TensorElement(H(-i, j), {-i: 1}) + ascii_str = \ +"""\ + j\n\ +H \n\ + i=1 \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_15560(): + a = MatrixSymbol('a', 1, 1) + e = pretty(a*(KroneckerProduct(a, a))) + result = 'a*(a x a)' + assert e == result + + +def test_print_polylog(): + # Part of issue 6013 + uresult = 'Li₂(3)' + aresult = 'polylog(2, 3)' + assert pretty(polylog(2, 3)) == aresult + assert upretty(polylog(2, 3)) == uresult + + +# Issue #25312 +def test_print_expint_polylog_symbolic_order(): + s, z = symbols("s, z") + uresult = 'Liₛ(z)' + aresult = 'polylog(s, z)' + assert pretty(polylog(s, z)) == aresult + assert upretty(polylog(s, z)) == uresult + # TODO: TBD polylog(s - 1, z) + uresult = 'Eₛ(z)' + aresult = 'expint(s, z)' + assert pretty(expint(s, z)) == aresult + assert upretty(expint(s, z)) == uresult + + + +def test_print_polylog_long_order_issue_25309(): + s, z = symbols("s, z") + ucode_str = \ +"""\ + ⎛ 2 ⎞\n\ +polylog⎝s , z⎠\ +""" + assert upretty(polylog(s**2, z)) == ucode_str + + +def test_print_lerchphi(): + # Part of issue 6013 + a = Symbol('a') + pretty(lerchphi(a, 1, 2)) + uresult = 'Φ(a, 1, 2)' + aresult = 'lerchphi(a, 1, 2)' + assert pretty(lerchphi(a, 1, 2)) == aresult + assert upretty(lerchphi(a, 1, 2)) == uresult + + +def test_issue_15583(): + + N = mechanics.ReferenceFrame('N') + result = '(n_x, n_y, n_z)' + e = pretty((N.x, N.y, N.z)) + assert e == result + + +def test_matrixSymbolBold(): + # Issue 15871 + def boldpretty(expr): + return xpretty(expr, use_unicode=True, wrap_line=False, mat_symbol_style="bold") + + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert boldpretty(trace(A)) == 'tr(𝐀)' + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert boldpretty(-A) == '-𝐀' + assert boldpretty(A - A*B - B) == '-𝐁 -𝐀⋅𝐁 + 𝐀' + assert boldpretty(-A*B - A*B*C - B) == '-𝐁 -𝐀⋅𝐁 -𝐀⋅𝐁⋅𝐂' + + A = MatrixSymbol("Addot", 3, 3) + assert boldpretty(A) == '𝐀̈' + omega = MatrixSymbol("omega", 3, 3) + assert boldpretty(omega) == 'ω' + omega = MatrixSymbol("omeganorm", 3, 3) + assert boldpretty(omega) == '‖ω‖' + + a = Symbol('alpha') + b = Symbol('b') + c = MatrixSymbol("c", 3, 1) + d = MatrixSymbol("d", 3, 1) + + assert boldpretty(a*B*c+b*d) == 'b⋅𝐝 + α⋅𝐁⋅𝐜' + + d = MatrixSymbol("delta", 3, 1) + B = MatrixSymbol("Beta", 3, 3) + + assert boldpretty(a*B*c+b*d) == 'b⋅δ + α⋅Β⋅𝐜' + + A = MatrixSymbol("A_2", 3, 3) + assert boldpretty(A) == '𝐀₂' + + +def test_center_accent(): + assert center_accent('a', '\N{COMBINING TILDE}') == 'ã' + assert center_accent('aa', '\N{COMBINING TILDE}') == 'aã' + assert center_accent('aaa', '\N{COMBINING TILDE}') == 'aãa' + assert center_accent('aaaa', '\N{COMBINING TILDE}') == 'aaãa' + assert center_accent('aaaaa', '\N{COMBINING TILDE}') == 'aaãaa' + assert center_accent('abcdefg', '\N{COMBINING FOUR DOTS ABOVE}') == 'abcd⃜efg' + + +def test_imaginary_unit(): + from sympy.printing.pretty import pretty # b/c it was redefined above + assert pretty(1 + I, use_unicode=False) == '1 + I' + assert pretty(1 + I, use_unicode=True) == '1 + ⅈ' + assert pretty(1 + I, use_unicode=False, imaginary_unit='j') == '1 + I' + assert pretty(1 + I, use_unicode=True, imaginary_unit='j') == '1 + ⅉ' + + raises(TypeError, lambda: pretty(I, imaginary_unit=I)) + raises(ValueError, lambda: pretty(I, imaginary_unit="kkk")) + + +def test_str_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert pretty(Identity(4)) == 'I' + assert upretty(Identity(4)) == '𝕀' + assert pretty(ZeroMatrix(2, 2)) == '0' + assert upretty(ZeroMatrix(2, 2)) == '𝟘' + assert pretty(OneMatrix(2, 2)) == '1' + assert upretty(OneMatrix(2, 2)) == '𝟙' + + +def test_pretty_misc_functions(): + assert pretty(LambertW(x)) == 'W(x)' + assert upretty(LambertW(x)) == 'W(x)' + assert pretty(LambertW(x, y)) == 'W(x, y)' + assert upretty(LambertW(x, y)) == 'W(x, y)' + assert pretty(airyai(x)) == 'Ai(x)' + assert upretty(airyai(x)) == 'Ai(x)' + assert pretty(airybi(x)) == 'Bi(x)' + assert upretty(airybi(x)) == 'Bi(x)' + assert pretty(airyaiprime(x)) == "Ai'(x)" + assert upretty(airyaiprime(x)) == "Ai'(x)" + assert pretty(airybiprime(x)) == "Bi'(x)" + assert upretty(airybiprime(x)) == "Bi'(x)" + assert pretty(fresnelc(x)) == 'C(x)' + assert upretty(fresnelc(x)) == 'C(x)' + assert pretty(fresnels(x)) == 'S(x)' + assert upretty(fresnels(x)) == 'S(x)' + assert pretty(Heaviside(x)) == 'Heaviside(x)' + assert upretty(Heaviside(x)) == 'θ(x)' + assert pretty(Heaviside(x, y)) == 'Heaviside(x, y)' + assert upretty(Heaviside(x, y)) == 'θ(x, y)' + assert pretty(dirichlet_eta(x)) == 'dirichlet_eta(x)' + assert upretty(dirichlet_eta(x)) == 'η(x)' + + +def test_hadamard_power(): + m, n, p = symbols('m, n, p', integer=True) + A = MatrixSymbol('A', m, n) + B = MatrixSymbol('B', m, n) + + # Testing printer: + expr = hadamard_power(A, n) + ascii_str = \ +"""\ + .n\n\ +A \ +""" + ucode_str = \ +"""\ + ∘n\n\ +A \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hadamard_power(A, 1+n) + ascii_str = \ +"""\ + .(n + 1)\n\ +A \ +""" + ucode_str = \ +"""\ + ∘(n + 1)\n\ +A \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hadamard_power(A*B.T, 1+n) + ascii_str = \ +"""\ + .(n + 1)\n\ +/ T\\ \n\ +\\A*B / \ +""" + ucode_str = \ +"""\ + ∘(n + 1)\n\ +⎛ T⎞ \n\ +⎝A⋅B ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_17258(): + n = Symbol('n', integer=True) + assert pretty(Sum(n, (n, -oo, 1))) == \ + ' 1 \n'\ + ' __ \n'\ + ' \\ ` \n'\ + ' ) n\n'\ + ' /_, \n'\ + 'n = -oo ' + + assert upretty(Sum(n, (n, -oo, 1))) == \ +"""\ + 1 \n\ + ___ \n\ + ╲ \n\ + ╲ \n\ + ╱ n\n\ + ╱ \n\ + ‾‾‾ \n\ +n = -∞ \ +""" + + +def test_is_combining(): + line = "v̇_m" + assert [is_combining(sym) for sym in line] == \ + [False, True, False, False] + + +def test_issue_17616(): + assert pretty(pi**(1/exp(1))) == \ + ' / -1\\\n'\ + ' \\e /\n'\ + 'pi ' + + assert upretty(pi**(1/exp(1))) == \ + ' ⎛ -1⎞\n'\ + ' ⎝ℯ ⎠\n'\ + 'π ' + + assert pretty(pi**(1/pi)) == \ + ' 1 \n'\ + ' --\n'\ + ' pi\n'\ + 'pi ' + + assert upretty(pi**(1/pi)) == \ + ' 1\n'\ + ' ─\n'\ + ' π\n'\ + 'π ' + + assert pretty(pi**(1/EulerGamma)) == \ + ' 1 \n'\ + ' ----------\n'\ + ' EulerGamma\n'\ + 'pi ' + + assert upretty(pi**(1/EulerGamma)) == \ + ' 1\n'\ + ' ─\n'\ + ' γ\n'\ + 'π ' + + z = Symbol("x_17") + assert upretty(7**(1/z)) == \ + 'x₁₇___\n'\ + ' ╲╱ 7 ' + + assert pretty(7**(1/z)) == \ + 'x_17___\n'\ + ' \\/ 7 ' + + +def test_issue_17857(): + assert pretty(Range(-oo, oo)) == '{..., -1, 0, 1, ...}' + assert pretty(Range(oo, -oo, -1)) == '{..., 1, 0, -1, ...}' + + +def test_issue_18272(): + x = Symbol('x') + n = Symbol('n') + + assert upretty(ConditionSet(x, Eq(-x + exp(x), 0), S.Complexes)) == \ + '⎧ │ ⎛ x ⎞⎫\n'\ + '⎨x │ x ∊ ℂ ∧ ⎝-x + ℯ = 0⎠⎬\n'\ + '⎩ │ ⎭' + assert upretty(ConditionSet(x, Contains(n/2, Interval(0, oo)), FiniteSet(-n/2, n/2))) == \ + '⎧ │ ⎧-n n⎫ ⎛n ⎞⎫\n'\ + '⎨x │ x ∊ ⎨───, ─⎬ ∧ ⎜─ ∈ [0, ∞)⎟⎬\n'\ + '⎩ │ ⎩ 2 2⎭ ⎝2 ⎠⎭' + assert upretty(ConditionSet(x, Eq(Piecewise((1, x >= 3), (x/2 - 1/2, x >= 2), (1/2, x >= 1), + (x/2, True)) - 1/2, 0), Interval(0, 3))) == \ + '⎧ │ ⎛⎛⎧ 1 for x ≥ 3⎞ ⎞⎫\n'\ + '⎪ │ ⎜⎜⎪ ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪x ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪─ - 0.5 for x ≥ 2⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪2 ⎟ ⎟⎪\n'\ + '⎨x │ x ∊ [0, 3] ∧ ⎜⎜⎨ ⎟ - 0.5 = 0⎟⎬\n'\ + '⎪ │ ⎜⎜⎪ 0.5 for x ≥ 1⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ x ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ ─ otherwise⎟ ⎟⎪\n'\ + '⎩ │ ⎝⎝⎩ 2 ⎠ ⎠⎭' + + +def test_Str(): + from sympy.core.symbol import Str + assert pretty(Str('x')) == 'x' + + +def test_symbolic_probability(): + mu = symbols("mu") + sigma = symbols("sigma", positive=True) + X = Normal("X", mu, sigma) + assert pretty(Expectation(X)) == r'E[X]' + assert pretty(Variance(X)) == r'Var(X)' + assert pretty(Probability(X > 0)) == r'P(X > 0)' + Y = Normal("Y", mu, sigma) + assert pretty(Covariance(X, Y)) == 'Cov(X, Y)' + + +def test_issue_21758(): + from sympy.functions.elementary.piecewise import piecewise_fold + from sympy.series.fourier import FourierSeries + x = Symbol('x') + k, n = symbols('k n') + fo = FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), SeqFormula( + Piecewise((-2*pi*cos(n*pi)/n + 2*sin(n*pi)/n**2, (n > -oo) & (n < oo) & Ne(n, 0)), + (0, True))*sin(n*x)/pi, (n, 1, oo)))) + assert upretty(piecewise_fold(fo)) == \ + '⎧ 2⋅sin(3⋅x) \n'\ + '⎪2⋅sin(x) - sin(2⋅x) + ────────── + … for n > -∞ ∧ n < ∞ ∧ n ≠ 0\n'\ + '⎨ 3 \n'\ + '⎪ \n'\ + '⎩ 0 otherwise ' + assert pretty(FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), + SeqFormula(0, (n, 1, oo))))) == '0' + + +def test_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert pretty(m) == 'M' + p = Patch('P', m) + assert pretty(p) == "P" + rect = CoordSystem('rect', p, [x, y]) + assert pretty(rect) == "rect" + b = BaseScalarField(rect, 0) + assert pretty(b) == "x" + + +def test_deprecated_prettyForm(): + with warns_deprecated_sympy(): + from sympy.printing.pretty.pretty_symbology import xstr + assert xstr(1) == '1' + + with warns_deprecated_sympy(): + from sympy.printing.pretty.stringpict import prettyForm + p = prettyForm('s', unicode='s') + + with warns_deprecated_sympy(): + assert p.unicode == p.s == 's' + + +def test_center(): + assert center('1', 2) == '1 ' + assert center('1', 3) == ' 1 ' + assert center('1', 3, '-') == '-1-' + assert center('1', 5, '-') == '--1--' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/python.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/python.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6862574d99db90f289de65144c7122ed2d731a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/python.py @@ -0,0 +1,92 @@ +import keyword as kw +import sympy +from .repr import ReprPrinter +from .str import StrPrinter + +# A list of classes that should be printed using StrPrinter +STRPRINT = ("Add", "Infinity", "Integer", "Mul", "NegativeInfinity", "Pow") + + +class PythonPrinter(ReprPrinter, StrPrinter): + """A printer which converts an expression into its Python interpretation.""" + + def __init__(self, settings=None): + super().__init__(settings) + self.symbols = [] + self.functions = [] + + # Create print methods for classes that should use StrPrinter instead + # of ReprPrinter. + for name in STRPRINT: + f_name = "_print_%s" % name + f = getattr(StrPrinter, f_name) + setattr(PythonPrinter, f_name, f) + + def _print_Function(self, expr): + func = expr.func.__name__ + if not hasattr(sympy, func) and func not in self.functions: + self.functions.append(func) + return StrPrinter._print_Function(self, expr) + + # procedure (!) for defining symbols which have be defined in print_python() + def _print_Symbol(self, expr): + symbol = self._str(expr) + if symbol not in self.symbols: + self.symbols.append(symbol) + return StrPrinter._print_Symbol(self, expr) + + def _print_module(self, expr): + raise ValueError('Modules in the expression are unacceptable') + + +def python(expr, **settings): + """Return Python interpretation of passed expression + (can be passed to the exec() function without any modifications)""" + + printer = PythonPrinter(settings) + exprp = printer.doprint(expr) + + result = '' + # Returning found symbols and functions + renamings = {} + for symbolname in printer.symbols: + # Remove curly braces from subscripted variables + if '{' in symbolname: + newsymbolname = symbolname.replace('{', '').replace('}', '') + renamings[sympy.Symbol(symbolname)] = newsymbolname + else: + newsymbolname = symbolname + + # Escape symbol names that are reserved Python keywords + if kw.iskeyword(newsymbolname): + while True: + newsymbolname += "_" + if (newsymbolname not in printer.symbols and + newsymbolname not in printer.functions): + renamings[sympy.Symbol( + symbolname)] = sympy.Symbol(newsymbolname) + break + result += newsymbolname + ' = Symbol(\'' + symbolname + '\')\n' + + for functionname in printer.functions: + newfunctionname = functionname + # Escape function names that are reserved Python keywords + if kw.iskeyword(newfunctionname): + while True: + newfunctionname += "_" + if (newfunctionname not in printer.symbols and + newfunctionname not in printer.functions): + renamings[sympy.Function( + functionname)] = sympy.Function(newfunctionname) + break + result += newfunctionname + ' = Function(\'' + functionname + '\')\n' + + if renamings: + exprp = expr.subs(renamings) + result += 'e = ' + printer._str(exprp) + return result + + +def print_python(expr, **settings): + """Print output of python() function""" + print(python(expr, **settings)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/repr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/repr.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4b756abbab77c3eb0fd77ee1f0bd97382c36fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/repr.py @@ -0,0 +1,339 @@ +""" +A Printer for generating executable code. + +The most important function here is srepr that returns a string so that the +relation eval(srepr(expr))=expr holds in an appropriate environment. +""" + +from __future__ import annotations +from typing import Any + +from sympy.core.function import AppliedUndef +from sympy.core.mul import Mul +from mpmath.libmp import repr_dps, to_str as mlib_to_str + +from .printer import Printer, print_function + + +class ReprPrinter(Printer): + printmethod = "_sympyrepr" + + _default_settings: dict[str, Any] = { + "order": None, + "perm_cyclic" : True, + } + + def reprify(self, args, sep): + """ + Prints each item in `args` and joins them with `sep`. + """ + return sep.join([self.doprint(item) for item in args]) + + def emptyPrinter(self, expr): + """ + The fallback printer. + """ + if isinstance(expr, str): + return expr + elif hasattr(expr, "__srepr__"): + return expr.__srepr__() + elif hasattr(expr, "args") and hasattr(expr.args, "__iter__"): + l = [] + for o in expr.args: + l.append(self._print(o)) + return expr.__class__.__name__ + '(%s)' % ', '.join(l) + elif hasattr(expr, "__module__") and hasattr(expr, "__name__"): + return "<'%s.%s'>" % (expr.__module__, expr.__name__) + else: + return str(expr) + + def _print_Add(self, expr, order=None): + args = self._as_ordered_terms(expr, order=order) + args = map(self._print, args) + clsname = type(expr).__name__ + return clsname + "(%s)" % ", ".join(args) + + def _print_Cycle(self, expr): + return expr.__repr__() + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation, Cycle + from sympy.utilities.exceptions import sympy_deprecation_warning + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=7, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + if not expr.size: + return 'Permutation()' + # before taking Cycle notation, see if the last element is + # a singleton and move it to the head of the string + s = Cycle(expr)(expr.size - 1).__repr__()[len('Cycle'):] + last = s.rfind('(') + if not last == 0 and ',' not in s[last:]: + s = s[last:] + s[:last] + return 'Permutation%s' %s + else: + s = expr.support() + if not s: + if expr.size < 5: + return 'Permutation(%s)' % str(expr.array_form) + return 'Permutation([], size=%s)' % expr.size + trim = str(expr.array_form[:s[-1] + 1]) + ', size=%s' % expr.size + use = full = str(expr.array_form) + if len(trim) < len(full): + use = trim + return 'Permutation(%s)' % use + + def _print_Function(self, expr): + r = self._print(expr.func) + r += '(%s)' % ', '.join([self._print(a) for a in expr.args]) + return r + + def _print_Heaviside(self, expr): + # Same as _print_Function but uses pargs to suppress default value for + # 2nd arg. + r = self._print(expr.func) + r += '(%s)' % ', '.join([self._print(a) for a in expr.pargs]) + return r + + def _print_FunctionClass(self, expr): + if issubclass(expr, AppliedUndef): + return 'Function(%r)' % (expr.__name__) + else: + return expr.__name__ + + def _print_Half(self, expr): + return 'Rational(1, 2)' + + def _print_RationalConstant(self, expr): + return str(expr) + + def _print_AtomicExpr(self, expr): + return str(expr) + + def _print_NumberSymbol(self, expr): + return str(expr) + + def _print_Integer(self, expr): + return 'Integer(%i)' % expr.p + + def _print_Complexes(self, expr): + return 'Complexes' + + def _print_Integers(self, expr): + return 'Integers' + + def _print_Naturals(self, expr): + return 'Naturals' + + def _print_Naturals0(self, expr): + return 'Naturals0' + + def _print_Rationals(self, expr): + return 'Rationals' + + def _print_Reals(self, expr): + return 'Reals' + + def _print_EmptySet(self, expr): + return 'EmptySet' + + def _print_UniversalSet(self, expr): + return 'UniversalSet' + + def _print_EmptySequence(self, expr): + return 'EmptySequence' + + def _print_list(self, expr): + return "[%s]" % self.reprify(expr, ", ") + + def _print_dict(self, expr): + sep = ", " + dict_kvs = ["%s: %s" % (self.doprint(key), self.doprint(value)) for key, value in expr.items()] + return "{%s}" % sep.join(dict_kvs) + + def _print_set(self, expr): + if not expr: + return "set()" + return "{%s}" % self.reprify(expr, ", ") + + def _print_MatrixBase(self, expr): + # special case for some empty matrices + if (expr.rows == 0) ^ (expr.cols == 0): + return '%s(%s, %s, %s)' % (expr.__class__.__name__, + self._print(expr.rows), + self._print(expr.cols), + self._print([])) + l = [] + for i in range(expr.rows): + l.append([]) + for j in range(expr.cols): + l[-1].append(expr[i, j]) + return '%s(%s)' % (expr.__class__.__name__, self._print(l)) + + def _print_BooleanTrue(self, expr): + return "true" + + def _print_BooleanFalse(self, expr): + return "false" + + def _print_NaN(self, expr): + return "nan" + + def _print_Mul(self, expr, order=None): + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + args = map(self._print, args) + clsname = type(expr).__name__ + return clsname + "(%s)" % ", ".join(args) + + def _print_Rational(self, expr): + return 'Rational(%s, %s)' % (self._print(expr.p), self._print(expr.q)) + + def _print_PythonRational(self, expr): + return "%s(%d, %d)" % (expr.__class__.__name__, expr.p, expr.q) + + def _print_Fraction(self, expr): + return 'Fraction(%s, %s)' % (self._print(expr.numerator), self._print(expr.denominator)) + + def _print_Float(self, expr): + r = mlib_to_str(expr._mpf_, repr_dps(expr._prec)) + return "%s('%s', precision=%i)" % (expr.__class__.__name__, r, expr._prec) + + def _print_Sum2(self, expr): + return "Sum2(%s, (%s, %s, %s))" % (self._print(expr.f), self._print(expr.i), + self._print(expr.a), self._print(expr.b)) + + def _print_Str(self, s): + return "%s(%s)" % (s.__class__.__name__, self._print(s.name)) + + def _print_Symbol(self, expr): + d = expr._assumptions_orig + # print the dummy_index like it was an assumption + if expr.is_Dummy: + d = d.copy() + d['dummy_index'] = expr.dummy_index + + if d == {}: + return "%s(%s)" % (expr.__class__.__name__, self._print(expr.name)) + else: + attr = ['%s=%s' % (k, v) for k, v in d.items()] + return "%s(%s, %s)" % (expr.__class__.__name__, + self._print(expr.name), ', '.join(attr)) + + def _print_CoordinateSymbol(self, expr): + d = expr._assumptions.generator + + if d == {}: + return "%s(%s, %s)" % ( + expr.__class__.__name__, + self._print(expr.coord_sys), + self._print(expr.index) + ) + else: + attr = ['%s=%s' % (k, v) for k, v in d.items()] + return "%s(%s, %s, %s)" % ( + expr.__class__.__name__, + self._print(expr.coord_sys), + self._print(expr.index), + ', '.join(attr) + ) + + def _print_Predicate(self, expr): + return "Q.%s" % expr.name + + def _print_AppliedPredicate(self, expr): + # will be changed to just expr.args when args overriding is removed + args = expr._args + return "%s(%s)" % (expr.__class__.__name__, self.reprify(args, ", ")) + + def _print_str(self, expr): + return repr(expr) + + def _print_tuple(self, expr): + if len(expr) == 1: + return "(%s,)" % self._print(expr[0]) + else: + return "(%s)" % self.reprify(expr, ", ") + + def _print_WildFunction(self, expr): + return "%s('%s')" % (expr.__class__.__name__, expr.name) + + def _print_AlgebraicNumber(self, expr): + return "%s(%s, %s)" % (expr.__class__.__name__, + self._print(expr.root), self._print(expr.coeffs())) + + def _print_PolyRing(self, ring): + return "%s(%s, %s, %s)" % (ring.__class__.__name__, + self._print(ring.symbols), self._print(ring.domain), self._print(ring.order)) + + def _print_FracField(self, field): + return "%s(%s, %s, %s)" % (field.__class__.__name__, + self._print(field.symbols), self._print(field.domain), self._print(field.order)) + + def _print_PolyElement(self, poly): + terms = list(poly.terms()) + terms.sort(key=poly.ring.order, reverse=True) + return "%s(%s, %s)" % (poly.__class__.__name__, self._print(poly.ring), self._print(terms)) + + def _print_FracElement(self, frac): + numer_terms = list(frac.numer.terms()) + numer_terms.sort(key=frac.field.order, reverse=True) + denom_terms = list(frac.denom.terms()) + denom_terms.sort(key=frac.field.order, reverse=True) + numer = self._print(numer_terms) + denom = self._print(denom_terms) + return "%s(%s, %s, %s)" % (frac.__class__.__name__, self._print(frac.field), numer, denom) + + def _print_FractionField(self, domain): + cls = domain.__class__.__name__ + field = self._print(domain.field) + return "%s(%s)" % (cls, field) + + def _print_PolynomialRingBase(self, ring): + cls = ring.__class__.__name__ + dom = self._print(ring.domain) + gens = ', '.join(map(self._print, ring.gens)) + order = str(ring.order) + if order != ring.default_order: + orderstr = ", order=" + order + else: + orderstr = "" + return "%s(%s, %s%s)" % (cls, dom, gens, orderstr) + + def _print_DMP(self, p): + cls = p.__class__.__name__ + rep = self._print(p.to_list()) + dom = self._print(p.dom) + return "%s(%s, %s)" % (cls, rep, dom) + + def _print_MonogenicFiniteExtension(self, ext): + # The expanded tree shown by srepr(ext.modulus) + # is not practical. + return "FiniteExtension(%s)" % str(ext.modulus) + + def _print_ExtensionElement(self, f): + rep = self._print(f.rep) + ext = self._print(f.ext) + return "ExtElem(%s, %s)" % (rep, ext) + +@print_function(ReprPrinter) +def srepr(expr, **settings): + """return expr in repr form""" + return ReprPrinter(settings).doprint(expr) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d38061a31b4f391927f2a342377754f26ffe03bb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff0801e39cf2c517c6fd1496ed09ce5974db58ef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bb62013f9de3e79a402c5a6dfdf543ff876742d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59c61793b1c7b9c300dafc8f7e84713584f9cde1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dabfd1b7d7948f6334dba380d7a63c0373fc99b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..051a2bc9116aae3efe81a5dca54c32b52a0a484a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77767933125f5dc08fbe65b8f2304bc008920e47 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1d57851d91b0adc1b1253baf60af74950b8101a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69af7986bd0a5565ac1b488865b5d33b85092342 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f47bf0e5a4596f4fd13e844c60415eec84f8cce4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53367892c2fcef82121a4590d97b94f967b0ec00 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f688c0792e7d9a4d7eb15acfe3f5d74bafade5e5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9725a8cdda6bb88368f9a8904ac1d85fa376c3a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad908d89e3a8c69df35deddae25a198dbe0b7758 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a4a4723f07b0aaa88c0abc251e28dcf4f1f2550 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f0660ddb7ad2bdf47feeb5ac5bf28345e87011a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a0bc59bf7643836519cecad09b77e043b925e57 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73f0d5db3d2880e916a5f41bb66e6c8fbe5f56e0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..407b2dcf05a8a72df8e96f727a31a5d574e9c44e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baaff821504439fed735bfe72b4f69acd977b500 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3238febbd84f0be595332d911731686f75e5a0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3966c4f8b567e30cd67adc7daaabb395ba58b51d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afa478c1bc91a8384b86a99544dc5e3e8c0e970c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c77d92e411ff598721a9de66c22b7236e07b8e9b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cab3d83fd9b4c00d0cf86929da45dd5fac694985 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16e64175e9742f0fc68cd9e34d516711206b2cbb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8643f65d7e041fec5e25e0c446508c64a31bc95d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e078fa117d9880dc10da197680e258f3faa9d6ed Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7af0dee345e6eb611af113598b0b334ece4803d4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1784ced26a4fdba3e1e9e6e7ff1dc30b4e5a6813 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc98d7191811efea157e26f2c7ab2d8fbf14c790 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..750812b820b6c8e423d71edb3a3f4780dc95d6f0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_torch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_torch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f623d1ebae0f163976b5be44d90819216e93b1d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_torch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33d7ed59e84294866e1163c0480bd02eb6e3f6c5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_aesaracode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_aesaracode.py new file mode 100644 index 0000000000000000000000000000000000000000..13308af65b382e77de33302bcd75344d2b00adbf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_aesaracode.py @@ -0,0 +1,633 @@ +""" +Important note on tests in this module - the Aesara printing functions use a +global cache by default, which means that tests using it will modify global +state and thus not be independent from each other. Instead of using the "cache" +keyword argument each time, this module uses the aesara_code_ and +aesara_function_ functions defined below which default to using a new, empty +cache instead. +""" + +import logging + +from sympy.external import import_module +from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy + +from sympy.utilities.exceptions import ignore_warnings + + +aesaralogger = logging.getLogger('aesara.configdefaults') +aesaralogger.setLevel(logging.CRITICAL) +aesara = import_module('aesara') +aesaralogger.setLevel(logging.WARNING) + + +if aesara: + import numpy as np + aet = aesara.tensor + from aesara.scalar.basic import ScalarType + from aesara.graph.basic import Variable + from aesara.tensor.var import TensorVariable + from aesara.tensor.elemwise import Elemwise, DimShuffle + from aesara.tensor.math import Dot + + from sympy.printing.aesaracode import true_divide + + xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz'] + Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ'] +else: + #bin/test will not execute any tests now + disabled = True + +import sympy as sy +from sympy.core.singleton import S +from sympy.abc import x, y, z, t +from sympy.printing.aesaracode import (aesara_code, dim_handling, + aesara_function) + + +# Default set of matrix symbols for testing - make square so we can both +# multiply and perform elementwise operations between them. +X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ'] + +# For testing AppliedUndef +f_t = sy.Function('f')(t) + + +def aesara_code_(expr, **kwargs): + """ Wrapper for aesara_code that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return aesara_code(expr, **kwargs) + +def aesara_function_(inputs, outputs, **kwargs): + """ Wrapper for aesara_function that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return aesara_function(inputs, outputs, **kwargs) + + +def fgraph_of(*exprs): + """ Transform SymPy expressions into Aesara Computation. + + Parameters + ========== + exprs + SymPy expressions + + Returns + ======= + aesara.graph.fg.FunctionGraph + """ + outs = list(map(aesara_code_, exprs)) + ins = list(aesara.graph.basic.graph_inputs(outs)) + ins, outs = aesara.graph.basic.clone(ins, outs) + return aesara.graph.fg.FunctionGraph(ins, outs) + + +def aesara_simplify(fgraph): + """ Simplify a Aesara Computation. + + Parameters + ========== + fgraph : aesara.graph.fg.FunctionGraph + + Returns + ======= + aesara.graph.fg.FunctionGraph + """ + mode = aesara.compile.get_default_mode().excluding("fusion") + fgraph = fgraph.clone() + mode.optimizer.rewrite(fgraph) + return fgraph + + +def theq(a, b): + """ Test two Aesara objects for equality. + + Also accepts numeric types and lists/tuples of supported types. + + Note - debugprint() has a bug where it will accept numeric types but does + not respect the "file" argument and in this case and instead prints the number + to stdout and returns an empty string. This can lead to tests passing where + they should fail because any two numbers will always compare as equal. To + prevent this we treat numbers as a separate case. + """ + numeric_types = (int, float, np.number) + a_is_num = isinstance(a, numeric_types) + b_is_num = isinstance(b, numeric_types) + + # Compare numeric types using regular equality + if a_is_num or b_is_num: + if not (a_is_num and b_is_num): + return False + + return a == b + + # Compare sequences element-wise + a_is_seq = isinstance(a, (tuple, list)) + b_is_seq = isinstance(b, (tuple, list)) + + if a_is_seq or b_is_seq: + if not (a_is_seq and b_is_seq) or type(a) != type(b): + return False + + return list(map(theq, a)) == list(map(theq, b)) + + # Otherwise, assume debugprint() can handle it + astr = aesara.printing.debugprint(a, file='str') + bstr = aesara.printing.debugprint(b, file='str') + + # Check for bug mentioned above + for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]: + if argstr == '': + raise TypeError( + 'aesara.printing.debugprint(%s) returned empty string ' + '(%s is instance of %r)' + % (argname, argname, type(argval)) + ) + + return astr == bstr + + +def test_example_symbols(): + """ + Check that the example symbols in this module print to their Aesara + equivalents, as many of the other tests depend on this. + """ + assert theq(xt, aesara_code_(x)) + assert theq(yt, aesara_code_(y)) + assert theq(zt, aesara_code_(z)) + assert theq(Xt, aesara_code_(X)) + assert theq(Yt, aesara_code_(Y)) + assert theq(Zt, aesara_code_(Z)) + + +def test_Symbol(): + """ Test printing a Symbol to a aesara variable. """ + xx = aesara_code_(x) + assert isinstance(xx, Variable) + assert xx.broadcastable == () + assert xx.name == x.name + + xx2 = aesara_code_(x, broadcastables={x: (False,)}) + assert xx2.broadcastable == (False,) + assert xx2.name == x.name + +def test_MatrixSymbol(): + """ Test printing a MatrixSymbol to a aesara variable. """ + XX = aesara_code_(X) + assert isinstance(XX, TensorVariable) + assert XX.broadcastable == (False, False) + +@SKIP # TODO - this is currently not checked but should be implemented +def test_MatrixSymbol_wrong_dims(): + """ Test MatrixSymbol with invalid broadcastable. """ + bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)] + for bc in bcs: + with raises(ValueError): + aesara_code_(X, broadcastables={X: bc}) + +def test_AppliedUndef(): + """ Test printing AppliedUndef instance, which works similarly to Symbol. """ + ftt = aesara_code_(f_t) + assert isinstance(ftt, TensorVariable) + assert ftt.broadcastable == () + assert ftt.name == 'f_t' + + +def test_add(): + expr = x + y + comp = aesara_code_(expr) + assert comp.owner.op == aesara.tensor.add + +def test_trig(): + assert theq(aesara_code_(sy.sin(x)), aet.sin(xt)) + assert theq(aesara_code_(sy.tan(x)), aet.tan(xt)) + +def test_many(): + """ Test printing a complex expression with multiple symbols. """ + expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z) + comp = aesara_code_(expr) + expected = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt) + assert theq(comp, expected) + + +def test_dtype(): + """ Test specifying specific data types through the dtype argument. """ + for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']: + assert aesara_code_(x, dtypes={x: dtype}).type.dtype == dtype + + # "floatX" type + assert aesara_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64') + + # Type promotion + assert aesara_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32' + assert aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64' + + +def test_broadcastables(): + """ Test the "broadcastables" argument when printing symbol-like objects. """ + + # No restrictions on shape + for s in [x, f_t]: + for bc in [(), (False,), (True,), (False, False), (True, False)]: + assert aesara_code_(s, broadcastables={s: bc}).broadcastable == bc + + # TODO - matrix broadcasting? + +def test_broadcasting(): + """ Test "broadcastable" attribute after applying element-wise binary op. """ + + expr = x + y + + cases = [ + [(), (), ()], + [(False,), (False,), (False,)], + [(True,), (False,), (False,)], + [(False, True), (False, False), (False, False)], + [(True, False), (False, False), (False, False)], + ] + + for bc1, bc2, bc3 in cases: + comp = aesara_code_(expr, broadcastables={x: bc1, y: bc2}) + assert comp.broadcastable == bc3 + + +def test_MatMul(): + expr = X*Y*Z + expr_t = aesara_code_(expr) + assert isinstance(expr_t.owner.op, Dot) + assert theq(expr_t, Xt.dot(Yt).dot(Zt)) + +def test_Transpose(): + assert isinstance(aesara_code_(X.T).owner.op, DimShuffle) + +def test_MatAdd(): + expr = X+Y+Z + assert isinstance(aesara_code_(expr).owner.op, Elemwise) + + +def test_Rationals(): + assert theq(aesara_code_(sy.Integer(2) / 3), true_divide(2, 3)) + assert theq(aesara_code_(S.Half), true_divide(1, 2)) + +def test_Integers(): + assert aesara_code_(sy.Integer(3)) == 3 + +def test_factorial(): + n = sy.Symbol('n') + assert aesara_code_(sy.factorial(n)) + +def test_Derivative(): + with ignore_warnings(UserWarning): + simp = lambda expr: aesara_simplify(fgraph_of(expr)) + assert theq(simp(aesara_code_(sy.Derivative(sy.sin(x), x, evaluate=False))), + simp(aesara.grad(aet.sin(xt), xt))) + + +def test_aesara_function_simple(): + """ Test aesara_function() with single output. """ + f = aesara_function_([x, y], [x+y]) + assert f(2, 3) == 5 + +def test_aesara_function_multi(): + """ Test aesara_function() with multiple outputs. """ + f = aesara_function_([x, y], [x+y, x-y]) + o1, o2 = f(2, 3) + assert o1 == 5 + assert o2 == -1 + +def test_aesara_function_numpy(): + """ Test aesara_function() vs Numpy implementation. """ + f = aesara_function_([x, y], [x+y], dim=1, + dtypes={x: 'float64', y: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9 + + f = aesara_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'}, + dim=1) + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9 + + +def test_aesara_function_matrix(): + m = sy.Matrix([[x, y], [z, x + y + z]]) + expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]]) + f = aesara_function_([x, y, z], [m]) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = aesara_function_([x, y, z], [m], scalar=True) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = aesara_function_([x, y, z], [m, m]) + assert isinstance(f(1.0, 2.0, 3.0), type([])) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected) + +def test_dim_handling(): + assert dim_handling([x], dim=2) == {x: (False, False)} + assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True), + y: (False, False)} + assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)} + +def test_aesara_function_kwargs(): + """ + Test passing additional kwargs from aesara_function() to aesara.function(). + """ + import numpy as np + f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore', + dtypes={x: 'float64', y: 'float64', z: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9 + + f = aesara_function_([x, y, z], [x+y], + dtypes={x: 'float64', y: 'float64', z: 'float64'}, + dim=1, on_unused_input='ignore') + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + zz = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9 + +def test_aesara_function_scalar(): + """ Test the "scalar" argument to aesara_function(). """ + from aesara.compile.function.types import Function + + args = [ + ([x, y], [x + y], None, [0]), # Single 0d output + ([X, Y], [X + Y], None, [2]), # Single 2d output + ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output + ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs + ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d + ] + + # Create and test functions with and without the scalar setting + for inputs, outputs, in_dims, out_dims in args: + for scalar in [False, True]: + + f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar) + + # Check the aesara_function attribute is set whether wrapped or not + assert isinstance(f.aesara_function, Function) + + # Feed in inputs of the appropriate size and get outputs + in_values = [ + np.ones([1 if bc else 5 for bc in i.type.broadcastable]) + for i in f.aesara_function.input_storage + ] + out_values = f(*in_values) + if not isinstance(out_values, list): + out_values = [out_values] + + # Check output types and shapes + assert len(out_dims) == len(out_values) + for d, value in zip(out_dims, out_values): + + if scalar and d == 0: + # Should have been converted to a scalar value + assert isinstance(value, np.number) + + else: + # Otherwise should be an array + assert isinstance(value, np.ndarray) + assert value.ndim == d + +def test_aesara_function_bad_kwarg(): + """ + Passing an unknown keyword argument to aesara_function() should raise an + exception. + """ + raises(Exception, lambda : aesara_function_([x], [x+1], foobar=3)) + + +def test_slice(): + assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3) + + def theq_slice(s1, s2): + for attr in ['start', 'stop', 'step']: + a1 = getattr(s1, attr) + a2 = getattr(s2, attr) + if a1 is None or a2 is None: + if not (a1 is None or a2 is None): + return False + elif not theq(a1, a2): + return False + return True + + dtypes = {x: 'int32', y: 'int32'} + assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt)) + assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3)) + +def test_MatrixSlice(): + cache = {} + + n = sy.Symbol('n', integer=True) + X = sy.MatrixSymbol('X', n, n) + + Y = X[1:2:3, 4:5:6] + Yt = aesara_code_(Y, cache=cache) + + s = ScalarType('int64') + assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s)) + assert Yt.owner.inputs[0] == aesara_code_(X, cache=cache) + # == doesn't work in Aesara like it does in SymPy. You have to use + # equals. + assert all(Yt.owner.inputs[i].data == i for i in range(1, 7)) + + k = sy.Symbol('k') + aesara_code_(k, dtypes={k: 'int32'}) + start, stop, step = 4, k, 2 + Y = X[start:stop:step] + Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'}) + # assert Yt.owner.op.idx_list[0].stop == kt + +def test_BlockMatrix(): + n = sy.Symbol('n', integer=True) + A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD'] + At, Bt, Ct, Dt = map(aesara_code_, (A, B, C, D)) + Block = sy.BlockMatrix([[A, B], [C, D]]) + Blockt = aesara_code_(Block) + solutions = [aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt)), + aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))] + assert any(theq(Blockt, solution) for solution in solutions) + +@SKIP +def test_BlockMatrix_Inverse_execution(): + k, n = 2, 4 + dtype = 'float32' + A = sy.MatrixSymbol('A', n, k) + B = sy.MatrixSymbol('B', n, n) + inputs = A, B + output = B.I*A + + cutsizes = {A: [(n//2, n//2), (k//2, k//2)], + B: [(n//2, n//2), (n//2, n//2)]} + cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs] + cutoutput = output.subs(dict(zip(inputs, cutinputs))) + + dtypes = dict(zip(inputs, [dtype]*len(inputs))) + f = aesara_function_(inputs, [output], dtypes=dtypes, cache={}) + fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)], + dtypes=dtypes, cache={}) + + ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs] + ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype), + np.eye(n).astype(dtype)] + ninputs[1] += np.ones(B.shape)*1e-5 + + assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5) + +def test_DenseMatrix(): + from aesara.tensor.basic import Join + + t = sy.Symbol('theta') + for MatrixType in [sy.Matrix, sy.ImmutableMatrix]: + X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]]) + tX = aesara_code_(X) + assert isinstance(tX, TensorVariable) + assert isinstance(tX.owner.op, Join) + + +def test_cache_basic(): + """ Test single symbol-like objects are cached when printed by themselves. """ + + # Pairs of objects which should be considered equivalent with respect to caching + pairs = [ + (x, sy.Symbol('x')), + (X, sy.MatrixSymbol('X', *X.shape)), + (f_t, sy.Function('f')(sy.Symbol('t'))), + ] + + for s1, s2 in pairs: + cache = {} + st = aesara_code_(s1, cache=cache) + + # Test hit with same instance + assert aesara_code_(s1, cache=cache) is st + + # Test miss with same instance but new cache + assert aesara_code_(s1, cache={}) is not st + + # Test hit with different but equivalent instance + assert aesara_code_(s2, cache=cache) is st + +def test_global_cache(): + """ Test use of the global cache. """ + from sympy.printing.aesaracode import global_cache + + backup = dict(global_cache) + try: + # Temporarily empty global cache + global_cache.clear() + + for s in [x, X, f_t]: + with warns_deprecated_sympy(): + st = aesara_code(s) + assert aesara_code(s) is st + + finally: + # Restore global cache + global_cache.update(backup) + +def test_cache_types_distinct(): + """ + Test that symbol-like objects of different types (Symbol, MatrixSymbol, + AppliedUndef) are distinguished by the cache even if they have the same + name. + """ + symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t] + + cache = {} # Single shared cache + printed = {} + + for s in symbols: + st = aesara_code_(s, cache=cache) + assert st not in printed.values() + printed[s] = st + + # Check all printed objects are distinct + assert len(set(map(id, printed.values()))) == len(symbols) + + # Check retrieving + for s, st in printed.items(): + with warns_deprecated_sympy(): + assert aesara_code(s, cache=cache) is st + +def test_symbols_are_created_once(): + """ + Test that a symbol is cached and reused when it appears in an expression + more than once. + """ + expr = sy.Add(x, x, evaluate=False) + comp = aesara_code_(expr) + + assert theq(comp, xt + xt) + assert not theq(comp, xt + aesara_code_(x)) + +def test_cache_complex(): + """ + Test caching on a complicated expression with multiple symbols appearing + multiple times. + """ + expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y) + symbol_names = {s.name for s in expr.free_symbols} + expr_t = aesara_code_(expr) + + # Iterate through variables in the Aesara computational graph that the + # printed expression depends on + seen = set() + for v in aesara.graph.basic.ancestors([expr_t]): + # Owner-less, non-constant variables should be our symbols + if v.owner is None and not isinstance(v, aesara.graph.basic.Constant): + # Check it corresponds to a symbol and appears only once + assert v.name in symbol_names + assert v.name not in seen + seen.add(v.name) + + # Check all were present + assert seen == symbol_names + + +def test_Piecewise(): + # A piecewise linear + expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III + result = aesara_code_(expr) + assert result.owner.op == aet.switch + + expected = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1)) + assert theq(result, expected) + + expr = sy.Piecewise((x, x < 0)) + result = aesara_code_(expr) + expected = aet.switch(xt < 0, xt, np.nan) + assert theq(result, expected) + + expr = sy.Piecewise((0, sy.And(x>0, x<2)), \ + (x, sy.Or(x>2, x<0))) + result = aesara_code_(expr) + expected = aet.switch(aet.and_(xt>0,xt<2), 0, \ + aet.switch(aet.or_(xt>2, xt<0), xt, np.nan)) + assert theq(result, expected) + + +def test_Relationals(): + assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt)) + # assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt)) # TODO - implement + assert theq(aesara_code_(x > y), xt > yt) + assert theq(aesara_code_(x < y), xt < yt) + assert theq(aesara_code_(x >= y), xt >= yt) + assert theq(aesara_code_(x <= y), xt <= yt) + + +def test_complexfunctions(): + dtypes = {x:'complex128', y:'complex128'} + with warns_deprecated_sympy(): + xt, yt = aesara_code(x, dtypes=dtypes), aesara_code(y, dtypes=dtypes) + from sympy.functions.elementary.complexes import conjugate + from aesara.tensor import as_tensor_variable as atv + from aesara.tensor import complex as cplx + with warns_deprecated_sympy(): + assert theq(aesara_code(y*conjugate(x), dtypes=dtypes), yt*(xt.conj())) + assert theq(aesara_code((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1))) + + +def test_constantfunctions(): + with warns_deprecated_sympy(): + tf = aesara_function([],[1+1j]) + assert(tf()==1+1j) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_c.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_c.py new file mode 100644 index 0000000000000000000000000000000000000000..626e7b6f244ea3227b886cd897d327f5d7bf66ec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_c.py @@ -0,0 +1,888 @@ +from sympy.core import ( + S, pi, oo, Symbol, symbols, Rational, Integer, Float, Function, Mod, GoldenRatio, EulerGamma, Catalan, + Lambda, Dummy, nan, Mul, Pow, UnevaluatedExpr +) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.functions import ( + Abs, acos, acosh, asin, asinh, atan, atanh, atan2, ceiling, cos, cosh, erf, + erfc, exp, floor, gamma, log, loggamma, Max, Min, Piecewise, sign, sin, sinh, + sqrt, tan, tanh, fibonacci, lucas +) +from sympy.sets import Range +from sympy.logic import ITE, Implies, Equivalent +from sympy.codegen import For, aug_assign, Assignment +from sympy.testing.pytest import raises, XFAIL +from sympy.printing.codeprinter import PrintMethodNotImplementedError +from sympy.printing.c import C89CodePrinter, C99CodePrinter, get_math_macros +from sympy.codegen.ast import ( + AddAugmentedAssignment, Element, Type, FloatType, Declaration, Pointer, Variable, value_const, pointer_const, + While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall, Return, + real, float32, float64, float80, float128, intc, Comment, CodeBlock, stderr, QuotedString +) +from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, fma, log10, Cbrt, hypot, Sqrt, isnan, isinf +from sympy.codegen.cnodes import restrict +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix + +from sympy.printing.codeprinter import ccode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + class fabs(Abs): + def _ccode(self, printer): + return "fabs(%s)" % printer._print(self.args[0]) + + assert ccode(fabs(x)) == "fabs(x)" + + +def test_ccode_sqrt(): + assert ccode(sqrt(x)) == "sqrt(x)" + assert ccode(x**0.5) == "sqrt(x)" + assert ccode(sqrt(x)) == "sqrt(x)" + + +def test_ccode_Pow(): + assert ccode(x**3) == "pow(x, 3)" + assert ccode(x**(y**3)) == "pow(x, pow(y, 3))" + g = implemented_function('g', Lambda(x, 2*x)) + assert ccode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2) + y)" + assert ccode(x**-1.0) == '1.0/x' + assert ccode(x**Rational(2, 3)) == 'pow(x, 2.0/3.0)' + assert ccode(x**Rational(2, 3), type_aliases={real: float80}) == 'powl(x, 2.0L/3.0L)' + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"), + (lambda base, exp: not exp.is_integer, "pow")] + assert ccode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)' + assert ccode(x**0.5, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 0.5)' + assert ccode(x**Rational(16, 5), user_functions={'Pow': _cond_cfunc}) == 'pow(x, 16.0/5.0)' + _cond_cfunc2 = [(lambda base, exp: base == 2, lambda base, exp: 'exp2(%s)' % exp), + (lambda base, exp: base != 2, 'pow')] + # Related to gh-11353 + assert ccode(2**x, user_functions={'Pow': _cond_cfunc2}) == 'exp2(x)' + assert ccode(x**2, user_functions={'Pow': _cond_cfunc2}) == 'pow(x, 2)' + # For issue 14160 + assert ccode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + + +def test_ccode_Max(): + # Test for gh-11926 + assert ccode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))' + + +def test_ccode_Min_performance(): + #Shouldn't take more than a few seconds + big_min = Min(*symbols('a[0:50]')) + for curr_standard in ('c89', 'c99', 'c11'): + output = ccode(big_min, standard=curr_standard) + assert output.count('(') == output.count(')') + + +def test_ccode_constants_mathh(): + assert ccode(exp(1)) == "M_E" + assert ccode(pi) == "M_PI" + assert ccode(oo, standard='c89') == "HUGE_VAL" + assert ccode(-oo, standard='c89') == "-HUGE_VAL" + assert ccode(oo) == "INFINITY" + assert ccode(-oo, standard='c99') == "-INFINITY" + assert ccode(pi, type_aliases={real: float80}) == "M_PIl" + + +def test_ccode_constants_other(): + assert ccode(2*GoldenRatio) == "const double GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert ccode( + 2*Catalan) == "const double Catalan = %s;\n2*Catalan" % Catalan.evalf(17) + assert ccode(2*EulerGamma) == "const double EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_ccode_Rational(): + assert ccode(Rational(3, 7)) == "3.0/7.0" + assert ccode(Rational(3, 7), type_aliases={real: float80}) == "3.0L/7.0L" + assert ccode(Rational(18, 9)) == "2" + assert ccode(Rational(3, -7)) == "-3.0/7.0" + assert ccode(Rational(3, -7), type_aliases={real: float80}) == "-3.0L/7.0L" + assert ccode(Rational(-3, -7)) == "3.0/7.0" + assert ccode(Rational(-3, -7), type_aliases={real: float80}) == "3.0L/7.0L" + assert ccode(x + Rational(3, 7)) == "x + 3.0/7.0" + assert ccode(x + Rational(3, 7), type_aliases={real: float80}) == "x + 3.0L/7.0L" + assert ccode(Rational(3, 7)*x) == "(3.0/7.0)*x" + assert ccode(Rational(3, 7)*x, type_aliases={real: float80}) == "(3.0L/7.0L)*x" + + +def test_ccode_Integer(): + assert ccode(Integer(67)) == "67" + assert ccode(Integer(-1)) == "-1" + + +def test_ccode_functions(): + assert ccode(sin(x) ** cos(x)) == "pow(sin(x), cos(x))" + + +def test_ccode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert ccode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert ccode( + g(x)) == "const double Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert ccode(g(A[i]), assign_to=A[i]) == ( + "for (int i=0; i y" + assert ccode(Ge(x, y)) == "x >= y" + + +def test_ccode_Piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " x\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + "))") + assert ccode(expr, assign_to="c") == ( + "if (x < 1) {\n" + " c = x;\n" + "}\n" + "else {\n" + " c = pow(x, 2);\n" + "}") + expr = Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " x\n" + ")\n" + ": ((x < 2) ? (\n" + " x + 1\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + ")))") + assert ccode(expr, assign_to='c') == ( + "if (x < 1) {\n" + " c = x;\n" + "}\n" + "else if (x < 2) {\n" + " c = x + 1;\n" + "}\n" + "else {\n" + " c = pow(x, 2);\n" + "}") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: ccode(expr)) + + +def test_ccode_sinc(): + from sympy.functions.elementary.trigonometric import sinc + expr = sinc(x) + assert ccode(expr) == ( + "(((x != 0) ? (\n" + " sin(x)/x\n" + ")\n" + ": (\n" + " 1\n" + ")))") + + +def test_ccode_Piecewise_deep(): + p = ccode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))) + assert p == ( + "2*((x < 1) ? (\n" + " x\n" + ")\n" + ": ((x < 2) ? (\n" + " x + 1\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + ")))") + expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1 + assert ccode(expr) == ( + "pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n" + " 0\n" + ")\n" + ": (\n" + " 1\n" + ")) + cos(z) - 1") + assert ccode(expr, assign_to='c') == ( + "c = pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n" + " 0\n" + ")\n" + ": (\n" + " 1\n" + ")) + cos(z) - 1;") + + +def test_ccode_ITE(): + expr = ITE(x < 1, y, z) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " y\n" + ")\n" + ": (\n" + " z\n" + "))") + + +def test_ccode_settings(): + raises(TypeError, lambda: ccode(sin(x), method="garbage")) + + +def test_ccode_Indexed(): + s, n, m, o = symbols('s n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + + x = IndexedBase('x')[j] + A = IndexedBase('A')[i, j] + B = IndexedBase('B')[i, j, k] + + p = C99CodePrinter() + + assert p._print_Indexed(x) == 'x[j]' + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + A = IndexedBase('A', shape=(5,3))[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (3*i + j) + + A = IndexedBase('A', shape=(5,3), strides='F')[i, j] + assert ccode(A) == 'A[%s]' % (i + 5*j) + + A = IndexedBase('A', shape=(29,29), strides=(1, s), offset=o)[i, j] + assert ccode(A) == 'A[o + s*j + i]' + + Abase = IndexedBase('A', strides=(s, m, n), offset=o) + assert ccode(Abase[i, j, k]) == 'A[m*j + n*k + o + s*i]' + assert ccode(Abase[2, 3, k]) == 'A[3*m + n*k + o + 2*s]' + + +def test_Element(): + assert ccode(Element('x', 'ij')) == 'x[i][j]' + assert ccode(Element('x', 'ij', strides='kl', offset='o')) == 'x[i*k + j*l + o]' + assert ccode(Element('x', (3,))) == 'x[3]' + assert ccode(Element('x', (3,4,5))) == 'x[3][4][5]' + + +def test_ccode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = ccode(e.rhs, assign_to=e.lhs, contract=False) + assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1) + + +def test_ccode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (int i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert ccode(mat, A) == ( + "A[0] = x*y;\n" + "if (y > 0) {\n" + " A[1] = x + 2;\n" + "}\n" + "else {\n" + " A[1] = y;\n" + "}\n" + "A[2] = sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert ccode(expr) == ( + "((x > 0) ? (\n" + " 2*A[2]\n" + ")\n" + ": (\n" + " A[2]\n" + ")) + sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert ccode(m, M) == ( + "M[0] = sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_sparse_matrix(): + # gh-15791 + with raises(PrintMethodNotImplementedError): + ccode(SparseMatrix([[1, 2, 3]])) + + assert 'Not supported in C' in C89CodePrinter({'strict': False}).doprint(SparseMatrix([[1, 2, 3]])) + + + +def test_ccode_reserved_words(): + x, y = symbols('x, if') + with raises(ValueError): + ccode(y**2, error_on_reserved=True, standard='C99') + assert ccode(y**2) == 'pow(if_, 2)' + assert ccode(x * y**2, dereference=[y]) == 'pow((*if_), 2)*x' + assert ccode(y**2, reserved_word_suffix='_unreserved') == 'pow(if_unreserved, 2)' + + +def test_ccode_sign(): + expr1, ref1 = sign(x) * y, 'y*(((x) > 0) - ((x) < 0))' + expr2, ref2 = sign(cos(x)), '(((cos(x)) > 0) - ((cos(x)) < 0))' + expr3, ref3 = sign(2 * x + x**2) * x + x**2, 'pow(x, 2) + x*(((pow(x, 2) + 2*x) > 0) - ((pow(x, 2) + 2*x) < 0))' + assert ccode(expr1) == ref1 + assert ccode(expr1, 'z') == 'z = %s;' % ref1 + assert ccode(expr2) == ref2 + assert ccode(expr3) == ref3 + +def test_ccode_Assignment(): + assert ccode(Assignment(x, y + z)) == 'x = y + z;' + assert ccode(aug_assign(x, '+', y + z)) == 'x += y + z;' + + +def test_ccode_For(): + f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)]) + assert ccode(f) == ("for (x = 0; x < 10; x += 2) {\n" + " y *= x;\n" + "}") + +def test_ccode_Max_Min(): + assert ccode(Max(x, 0), standard='C89') == '((0 > x) ? 0 : x)' + assert ccode(Max(x, 0), standard='C99') == 'fmax(0, x)' + assert ccode(Min(x, 0, sqrt(x)), standard='c89') == ( + '((0 < ((x < sqrt(x)) ? x : sqrt(x))) ? 0 : ((x < sqrt(x)) ? x : sqrt(x)))' + ) + +def test_ccode_standard(): + assert ccode(expm1(x), standard='c99') == 'expm1(x)' + assert ccode(nan, standard='c99') == 'NAN' + assert ccode(float('nan'), standard='c99') == 'NAN' + + +def test_C89CodePrinter(): + c89printer = C89CodePrinter() + assert c89printer.language == 'C' + assert c89printer.standard == 'C89' + assert 'void' in c89printer.reserved_words + assert 'template' not in c89printer.reserved_words + assert c89printer.doprint(log10(x)) == 'log10(x)' + + +def test_C99CodePrinter(): + assert C99CodePrinter().doprint(expm1(x)) == 'expm1(x)' + assert C99CodePrinter().doprint(log1p(x)) == 'log1p(x)' + assert C99CodePrinter().doprint(exp2(x)) == 'exp2(x)' + assert C99CodePrinter().doprint(log2(x)) == 'log2(x)' + assert C99CodePrinter().doprint(fma(x, y, -z)) == 'fma(x, y, -z)' + assert C99CodePrinter().doprint(log10(x)) == 'log10(x)' + assert C99CodePrinter().doprint(Cbrt(x)) == 'cbrt(x)' # note Cbrt due to cbrt already taken. + assert C99CodePrinter().doprint(hypot(x, y)) == 'hypot(x, y)' + assert C99CodePrinter().doprint(loggamma(x)) == 'lgamma(x)' + assert C99CodePrinter().doprint(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))' + assert C99CodePrinter().doprint(Min(x, 3)) == 'fmin(3, x)' + c99printer = C99CodePrinter() + assert c99printer.language == 'C' + assert c99printer.standard == 'C99' + assert 'restrict' in c99printer.reserved_words + assert 'using' not in c99printer.reserved_words + + +@XFAIL +def test_C99CodePrinter__precision_f80(): + f80_printer = C99CodePrinter({"type_aliases": {real: float80}}) + assert f80_printer.doprint(sin(x + Float('2.1'))) == 'sinl(x + 2.1L)' + + +def test_C99CodePrinter__precision(): + n = symbols('n', integer=True) + p = symbols('p', integer=True, positive=True) + f32_printer = C99CodePrinter({"type_aliases": {real: float32}}) + f64_printer = C99CodePrinter({"type_aliases": {real: float64}}) + f80_printer = C99CodePrinter({"type_aliases": {real: float80}}) + assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)' + assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)' + assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)' + + for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']): + def check(expr, ref): + assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper()) + check(Abs(n), 'abs(n)') + check(Abs(x + 2.0), 'fabs{s}(x + 2.0{S})') + check(sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))') + check(exp(x*8.0), 'exp{s}(8.0{S}*x)') + check(exp2(x), 'exp2{s}(x)') + check(expm1(x*4.0), 'expm1{s}(4.0{S}*x)') + check(Mod(p, 2), 'p % 2') + check(Mod(2*p + 3, 3*p + 5, evaluate=False), '(2*p + 3) % (3*p + 5)') + check(Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})') + check(Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})') + check(log(x/2), 'log{s}((1.0{S}/2.0{S})*x)') + check(log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)') + check(log2(x*8.0), 'log2{s}(8.0{S}*x)') + check(log1p(x), 'log1p{s}(x)') + check(2**x, 'pow{s}(2, x)') + check(2.0**x, 'pow{s}(2.0{S}, x)') + check(x**3, 'pow{s}(x, 3)') + check(x**4.0, 'pow{s}(x, 4.0{S})') + check(sqrt(3+x), 'sqrt{s}(x + 3)') + check(Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})') + check(hypot(x, y), 'hypot{s}(x, y)') + check(sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})') + check(cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})') + check(tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})') + check(asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})') + check(acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})') + check(atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})') + check(atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)') + + check(sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})') + check(cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})') + check(tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})') + check(asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})') + check(acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})') + check(atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})') + check(erf(42.*x), 'erf{s}(42.0{S}*x)') + check(erfc(42.*x), 'erfc{s}(42.0{S}*x)') + check(gamma(x), 'tgamma{s}(x)') + check(loggamma(x), 'lgamma{s}(x)') + + check(ceiling(x + 2.), "ceil{s}(x) + 2") + check(floor(x + 2.), "floor{s}(x) + 2") + check(fma(x, y, -z), 'fma{s}(x, y, -z)') + check(Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))') + check(Min(x, 2.0), 'fmin{s}(2.0{S}, x)') + + +def test_get_math_macros(): + macros = get_math_macros() + assert macros[exp(1)] == 'M_E' + assert macros[1/Sqrt(2)] == 'M_SQRT1_2' + + +def test_ccode_Declaration(): + i = symbols('i', integer=True) + var1 = Variable(i, type=Type.from_expr(i)) + dcl1 = Declaration(var1) + assert ccode(dcl1) == 'int i' + + var2 = Variable(x, type=float32, attrs={value_const}) + dcl2a = Declaration(var2) + assert ccode(dcl2a) == 'const float x' + dcl2b = var2.as_Declaration(value=pi) + assert ccode(dcl2b) == 'const float x = M_PI' + + var3 = Variable(y, type=Type('bool')) + dcl3 = Declaration(var3) + printer = C89CodePrinter() + assert 'stdbool.h' not in printer.headers + assert printer.doprint(dcl3) == 'bool y' + assert 'stdbool.h' in printer.headers + + u = symbols('u', real=True) + ptr4 = Pointer.deduced(u, attrs={pointer_const, restrict}) + dcl4 = Declaration(ptr4) + assert ccode(dcl4) == 'double * const restrict u' + + var5 = Variable(x, Type('__float128'), attrs={value_const}) + dcl5a = Declaration(var5) + assert ccode(dcl5a) == 'const __float128 x' + var5b = Variable(var5.symbol, var5.type, pi, attrs=var5.attrs) + dcl5b = Declaration(var5b) + assert ccode(dcl5b) == 'const __float128 x = M_PI' + + +def test_C99CodePrinter_custom_type(): + # We will look at __float128 (new in glibc 2.26) + f128 = FloatType('_Float128', float128.nbits, float128.nmant, float128.nexp) + p128 = C99CodePrinter({ + "type_aliases": {real: f128}, + "type_literal_suffixes": {f128: 'Q'}, + "type_func_suffixes": {f128: 'f128'}, + "type_math_macro_suffixes": { + real: 'f128', + f128: 'f128' + }, + "type_macros": { + f128: ('__STDC_WANT_IEC_60559_TYPES_EXT__',) + } + }) + assert p128.doprint(x) == 'x' + assert not p128.headers + assert not p128.libraries + assert not p128.macros + assert p128.doprint(2.0) == '2.0Q' + assert not p128.headers + assert not p128.libraries + assert p128.macros == {'__STDC_WANT_IEC_60559_TYPES_EXT__'} + + assert p128.doprint(Rational(1, 2)) == '1.0Q/2.0Q' + assert p128.doprint(sin(x)) == 'sinf128(x)' + assert p128.doprint(cos(2., evaluate=False)) == 'cosf128(2.0Q)' + assert p128.doprint(x**-1.0) == '1.0Q/x' + + var5 = Variable(x, f128, attrs={value_const}) + + dcl5a = Declaration(var5) + assert ccode(dcl5a) == 'const _Float128 x' + var5b = Variable(x, f128, pi, attrs={value_const}) + dcl5b = Declaration(var5b) + assert p128.doprint(dcl5b) == 'const _Float128 x = M_PIf128' + var5b = Variable(x, f128, value=Catalan.evalf(38), attrs={value_const}) + dcl5c = Declaration(var5b) + assert p128.doprint(dcl5c) == 'const _Float128 x = %sQ' % Catalan.evalf(f128.decimal_dig) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(ccode(A[0, 0]) == "A[0]") + assert(ccode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(ccode(F) == "(A - B)[0]") + +def test_ccode_math_macros(): + assert ccode(z + exp(1)) == 'z + M_E' + assert ccode(z + log2(exp(1))) == 'z + M_LOG2E' + assert ccode(z + 1/log(2)) == 'z + M_LOG2E' + assert ccode(z + log(2)) == 'z + M_LN2' + assert ccode(z + log(10)) == 'z + M_LN10' + assert ccode(z + pi) == 'z + M_PI' + assert ccode(z + pi/2) == 'z + M_PI_2' + assert ccode(z + pi/4) == 'z + M_PI_4' + assert ccode(z + 1/pi) == 'z + M_1_PI' + assert ccode(z + 2/pi) == 'z + M_2_PI' + assert ccode(z + 2/sqrt(pi)) == 'z + M_2_SQRTPI' + assert ccode(z + 2/Sqrt(pi)) == 'z + M_2_SQRTPI' + assert ccode(z + sqrt(2)) == 'z + M_SQRT2' + assert ccode(z + Sqrt(2)) == 'z + M_SQRT2' + assert ccode(z + 1/sqrt(2)) == 'z + M_SQRT1_2' + assert ccode(z + 1/Sqrt(2)) == 'z + M_SQRT1_2' + + +def test_ccode_Type(): + assert ccode(Type('float')) == 'float' + assert ccode(intc) == 'int' + + +def test_ccode_codegen_ast(): + # Note that C only allows comments of the form /* ... */, double forward + # slash is not standard C, and some C compilers will grind to a halt upon + # encountering them. + assert ccode(Comment("this is a comment")) == "/* this is a comment */" # not // + assert ccode(While(abs(x) > 1, [aug_assign(x, '-', 1)])) == ( + 'while (fabs(x) > 1) {\n' + ' x -= 1;\n' + '}' + ) + assert ccode(Scope([AddAugmentedAssignment(x, 1)])) == ( + '{\n' + ' x += 1;\n' + '}' + ) + inp_x = Declaration(Variable(x, type=real)) + assert ccode(FunctionPrototype(real, 'pwer', [inp_x])) == 'double pwer(double x)' + assert ccode(FunctionDefinition(real, 'pwer', [inp_x], [Assignment(x, x**2)])) == ( + 'double pwer(double x){\n' + ' x = pow(x, 2);\n' + '}' + ) + + # Elements of CodeBlock are formatted as statements: + block = CodeBlock( + x, + Print([x, y], "%d %d"), + Print([QuotedString('hello'), y], "%s %d", file=stderr), + FunctionCall('pwer', [x]), + Return(x), + ) + assert ccode(block) == '\n'.join([ + 'x;', + 'printf("%d %d", x, y);', + 'fprintf(stderr, "%s %d", "hello", y);', + 'pwer(x);', + 'return x;', + ]) + +def test_ccode_UnevaluatedExpr(): + assert ccode(UnevaluatedExpr(y * x) + z) == "z + x*y" + assert ccode(UnevaluatedExpr(y + x) + z) == "z + (x + y)" # gh-21955 + w = symbols('w') + assert ccode(UnevaluatedExpr(y + x) + UnevaluatedExpr(z + w)) == "(w + z) + (x + y)" + + p, q, r = symbols("p q r", real=True) + q_r = UnevaluatedExpr(q + r) + expr = abs(exp(p+q_r)) + assert ccode(expr) == "exp(p + (q + r))" + + +def test_ccode_array_like_containers(): + assert ccode([2,3,4]) == "{2, 3, 4}" + assert ccode((2,3,4)) == "{2, 3, 4}" + +def test_ccode__isinf_isnan(): + assert ccode(isinf(x)) == 'isinf(x)' + assert ccode(isnan(x)) == 'isnan(x)' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_codeprinter.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_codeprinter.py new file mode 100644 index 0000000000000000000000000000000000000000..4b077037eb84e218fcfd4a05fc03e40b211e45b9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_codeprinter.py @@ -0,0 +1,77 @@ +from sympy.printing.codeprinter import CodePrinter, PrintMethodNotImplementedError +from sympy.core import symbols +from sympy.core.symbol import Dummy +from sympy.testing.pytest import raises +from sympy import cos +from sympy.utilities.lambdify import lambdify +from math import cos as math_cos +from sympy.printing.lambdarepr import LambdaPrinter + + +def setup_test_printer(**kwargs): + p = CodePrinter(settings=kwargs) + p._not_supported = set() + p._number_symbols = set() + return p + + +def test_print_Dummy(): + d = Dummy('d') + p = setup_test_printer() + assert p._print_Dummy(d) == "d_%i" % d.dummy_index + +def test_print_Symbol(): + + x, y = symbols('x, if') + + p = setup_test_printer() + assert p._print(x) == 'x' + assert p._print(y) == 'if' + + p.reserved_words.update(['if']) + assert p._print(y) == 'if_' + + p = setup_test_printer(error_on_reserved=True) + p.reserved_words.update(['if']) + with raises(ValueError): + p._print(y) + + p = setup_test_printer(reserved_word_suffix='_He_Man') + p.reserved_words.update(['if']) + assert p._print(y) == 'if_He_Man' + + +def test_lambdify_LaTeX_symbols_issue_23374(): + # Create symbols with Latex style names + x1, x2 = symbols("x_{1} x_2") + + # Lambdify the function + f1 = lambdify([x1, x2], cos(x1 ** 2 + x2 ** 2)) + + # Test that the function works correctly (numerically) + assert f1(1, 2) == math_cos(1 ** 2 + 2 ** 2) + + # Explicitly generate a custom printer to verify the naming convention + p = LambdaPrinter() + expr_str = p.doprint(cos(x1 ** 2 + x2 ** 2)) + assert 'x_1' in expr_str + assert 'x_2' in expr_str + + +def test_issue_15791(): + class CrashingCodePrinter(CodePrinter): + def emptyPrinter(self, obj): + raise NotImplementedError + + from sympy.matrices import ( + MutableSparseMatrix, + ImmutableSparseMatrix, + ) + + c = CrashingCodePrinter() + + # these should not silently succeed + with raises(PrintMethodNotImplementedError): + c.doprint(ImmutableSparseMatrix(2, 2, {})) + with raises(PrintMethodNotImplementedError): + c.doprint(MutableSparseMatrix(2, 2, {})) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_conventions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_conventions.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f1fa8532f96130828b89d1ba5ba11fd5bed7a4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_conventions.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import cos +from sympy.integrals.integrals import Integral +from sympy.functions.special.bessel import besselj +from sympy.functions.special.polynomials import legendre +from sympy.functions.combinatorial.numbers import bell +from sympy.printing.conventions import split_super_sub, requires_partial +from sympy.testing.pytest import XFAIL + +def test_super_sub(): + assert split_super_sub("beta_13_2") == ("beta", [], ["13", "2"]) + assert split_super_sub("beta_132_20") == ("beta", [], ["132", "20"]) + assert split_super_sub("beta_13") == ("beta", [], ["13"]) + assert split_super_sub("x_a_b") == ("x", [], ["a", "b"]) + assert split_super_sub("x_1_2_3") == ("x", [], ["1", "2", "3"]) + assert split_super_sub("x_a_b1") == ("x", [], ["a", "b1"]) + assert split_super_sub("x_a_1") == ("x", [], ["a", "1"]) + assert split_super_sub("x_1_a") == ("x", [], ["1", "a"]) + assert split_super_sub("x_1^aa") == ("x", ["aa"], ["1"]) + assert split_super_sub("x_1__aa") == ("x", ["aa"], ["1"]) + assert split_super_sub("x_11^a") == ("x", ["a"], ["11"]) + assert split_super_sub("x_11__a") == ("x", ["a"], ["11"]) + assert split_super_sub("x_a_b_c_d") == ("x", [], ["a", "b", "c", "d"]) + assert split_super_sub("x_a_b^c^d") == ("x", ["c", "d"], ["a", "b"]) + assert split_super_sub("x_a_b__c__d") == ("x", ["c", "d"], ["a", "b"]) + assert split_super_sub("x_a^b_c^d") == ("x", ["b", "d"], ["a", "c"]) + assert split_super_sub("x_a__b_c__d") == ("x", ["b", "d"], ["a", "c"]) + assert split_super_sub("x^a^b_c_d") == ("x", ["a", "b"], ["c", "d"]) + assert split_super_sub("x__a__b_c_d") == ("x", ["a", "b"], ["c", "d"]) + assert split_super_sub("x^a^b^c^d") == ("x", ["a", "b", "c", "d"], []) + assert split_super_sub("x__a__b__c__d") == ("x", ["a", "b", "c", "d"], []) + assert split_super_sub("alpha_11") == ("alpha", [], ["11"]) + assert split_super_sub("alpha_11_11") == ("alpha", [], ["11", "11"]) + assert split_super_sub("w1") == ("w", [], ["1"]) + assert split_super_sub("w𝟙") == ("w", [], ["𝟙"]) + assert split_super_sub("w11") == ("w", [], ["11"]) + assert split_super_sub("w𝟙𝟙") == ("w", [], ["𝟙𝟙"]) + assert split_super_sub("w𝟙2𝟙") == ("w", [], ["𝟙2𝟙"]) + assert split_super_sub("w1^a") == ("w", ["a"], ["1"]) + assert split_super_sub("ω1") == ("ω", [], ["1"]) + assert split_super_sub("ω11") == ("ω", [], ["11"]) + assert split_super_sub("ω1^a") == ("ω", ["a"], ["1"]) + assert split_super_sub("ω𝟙^α") == ("ω", ["α"], ["𝟙"]) + assert split_super_sub("ω𝟙2^3α") == ("ω", ["3α"], ["𝟙2"]) + assert split_super_sub("") == ("", [], []) + + +def test_requires_partial(): + x, y, z, t, nu = symbols('x y z t nu') + n = symbols('n', integer=True) + + f = x * y + assert requires_partial(Derivative(f, x)) is True + assert requires_partial(Derivative(f, y)) is True + + ## integrating out one of the variables + assert requires_partial(Derivative(Integral(exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False + + ## bessel function with smooth parameter + f = besselj(nu, x) + assert requires_partial(Derivative(f, x)) is True + assert requires_partial(Derivative(f, nu)) is True + + ## bessel function with integer parameter + f = besselj(n, x) + assert requires_partial(Derivative(f, x)) is False + # this is not really valid (differentiating with respect to an integer) + # but there's no reason to use the partial derivative symbol there. make + # sure we don't throw an exception here, though + assert requires_partial(Derivative(f, n)) is False + + ## bell polynomial + f = bell(n, x) + assert requires_partial(Derivative(f, x)) is False + # again, invalid + assert requires_partial(Derivative(f, n)) is False + + ## legendre polynomial + f = legendre(0, x) + assert requires_partial(Derivative(f, x)) is False + + f = legendre(n, x) + assert requires_partial(Derivative(f, x)) is False + # again, invalid + assert requires_partial(Derivative(f, n)) is False + + f = x ** n + assert requires_partial(Derivative(f, x)) is False + + assert requires_partial(Derivative(Integral((x*y) ** n * exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False + + # parametric equation + f = (exp(t), cos(t)) + g = sum(f) + assert requires_partial(Derivative(g, t)) is False + + f = symbols('f', cls=Function) + assert requires_partial(Derivative(f(x), x)) is False + assert requires_partial(Derivative(f(x), y)) is False + assert requires_partial(Derivative(f(x, y), x)) is True + assert requires_partial(Derivative(f(x, y), y)) is True + assert requires_partial(Derivative(f(x, y), z)) is True + assert requires_partial(Derivative(f(x, y), x, y)) is True + +@XFAIL +def test_requires_partial_unspecified_variables(): + x, y = symbols('x y') + # function of unspecified variables + f = symbols('f', cls=Function) + assert requires_partial(Derivative(f, x)) is False + assert requires_partial(Derivative(f, x, y)) is True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_cupy.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_cupy.py new file mode 100644 index 0000000000000000000000000000000000000000..cf111ec1623390a3dbbf489235d2ed387624a36c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_cupy.py @@ -0,0 +1,56 @@ +from sympy.concrete.summations import Sum +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.utilities.lambdify import lambdify +from sympy.abc import x, i, a, b +from sympy.codegen.numpy_nodes import logaddexp +from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +cp = import_module('cupy') + +def test_cupy_print(): + prntr = CuPyPrinter() + assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)' + assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)' + assert prntr.doprint(log(x)) == 'cupy.log(x)' + assert prntr.doprint("acos(x)") == 'cupy.arccos(x)' + assert prntr.doprint("exp(x)") == 'cupy.exp(x)' + assert prntr.doprint("Abs(x)") == 'abs(x)' + +def test_not_cupy_print(): + prntr = CuPyPrinter() + with raises(NotImplementedError): + prntr.doprint("abcd(x)") + +def test_cupy_sum(): + if not cp: + skip("CuPy not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'cupy') + + a_, b_ = 0, 10 + x_ = cp.linspace(-1, +1, 10) + assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = cp.linspace(-1, +1, 10) + assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + +def test_cupy_known_funcs_consts(): + assert _cupy_known_constants['NaN'] == 'cupy.nan' + assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma' + + assert _cupy_known_functions['acos'] == 'cupy.arccos' + assert _cupy_known_functions['log'] == 'cupy.log' + +def test_cupy_print_methods(): + prntr = CuPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_cxx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_cxx.py new file mode 100644 index 0000000000000000000000000000000000000000..d84ec75cbf0eeb60a1176b9cb3b401a3384454e7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_cxx.py @@ -0,0 +1,86 @@ +from sympy.core.numbers import Float, Integer, Rational +from sympy.core.symbol import symbols +from sympy.functions import beta, Ei, zeta, Max, Min, sqrt, riemann_xi, frac +from sympy.printing.cxx import CXX98CodePrinter, CXX11CodePrinter, CXX17CodePrinter, cxxcode +from sympy.codegen.cfunctions import log1p + + +x, y, u, v = symbols('x y u v') + + +def test_CXX98CodePrinter(): + assert CXX98CodePrinter().doprint(Max(x, 3)) in ('std::max(x, 3)', 'std::max(3, x)') + assert CXX98CodePrinter().doprint(Min(x, 3, sqrt(x))) == 'std::min(3, std::min(x, std::sqrt(x)))' + cxx98printer = CXX98CodePrinter() + assert cxx98printer.language == 'C++' + assert cxx98printer.standard == 'C++98' + assert 'template' in cxx98printer.reserved_words + assert 'alignas' not in cxx98printer.reserved_words + + +def test_CXX11CodePrinter(): + assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)' + + cxx11printer = CXX11CodePrinter() + assert cxx11printer.language == 'C++' + assert cxx11printer.standard == 'C++11' + assert 'operator' in cxx11printer.reserved_words + assert 'noexcept' in cxx11printer.reserved_words + assert 'concept' not in cxx11printer.reserved_words + + +def test_subclass_print_method(): + class MyPrinter(CXX11CodePrinter): + def _print_log1p(self, expr): + return 'my_library::log1p(%s)' % ', '.join(map(self._print, expr.args)) + + assert MyPrinter().doprint(log1p(x)) == 'my_library::log1p(x)' + + +def test_subclass_print_method__ns(): + class MyPrinter(CXX11CodePrinter): + _ns = 'my_library::' + + p = CXX11CodePrinter() + myp = MyPrinter() + + assert p.doprint(log1p(x)) == 'std::log1p(x)' + assert myp.doprint(log1p(x)) == 'my_library::log1p(x)' + + +def test_CXX17CodePrinter(): + assert CXX17CodePrinter().doprint(beta(x, y)) == 'std::beta(x, y)' + assert CXX17CodePrinter().doprint(Ei(x)) == 'std::expint(x)' + assert CXX17CodePrinter().doprint(zeta(x)) == 'std::riemann_zeta(x)' + + # Automatic rewrite + assert CXX17CodePrinter().doprint(frac(x)) == '(x - std::floor(x))' + assert CXX17CodePrinter().doprint(riemann_xi(x)) == '((1.0/2.0)*std::pow(M_PI, -1.0/2.0*x)*x*(x - 1)*std::tgamma((1.0/2.0)*x)*std::riemann_zeta(x))' + + +def test_cxxcode(): + assert sorted(cxxcode(sqrt(x)*.5).split('*')) == sorted(['0.5', 'std::sqrt(x)']) + +def test_cxxcode_nested_minmax(): + assert cxxcode(Max(Min(x, y), Min(u, v))) \ + == 'std::max(std::min(u, v), std::min(x, y))' + assert cxxcode(Min(Max(x, y), Max(u, v))) \ + == 'std::min(std::max(u, v), std::max(x, y))' + +def test_subclass_Integer_Float(): + class MyPrinter(CXX17CodePrinter): + def _print_Integer(self, arg): + return 'bigInt("%s")' % super()._print_Integer(arg) + + def _print_Float(self, arg): + rat = Rational(arg) + return 'bigFloat(%s, %s)' % ( + self._print(Integer(rat.p)), + self._print(Integer(rat.q)) + ) + + p = MyPrinter() + for i in range(13): + assert p.doprint(i) == 'bigInt("%d")' % i + assert p.doprint(Float(0.5)) == 'bigFloat(bigInt("1"), bigInt("2"))' + assert p.doprint(x**-1.0) == 'bigFloat(bigInt("1"), bigInt("1"))/x' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_dot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_dot.py new file mode 100644 index 0000000000000000000000000000000000000000..6213e237fb7aac6460a956b4c9fc1f7c8710fec6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_dot.py @@ -0,0 +1,134 @@ +from sympy.printing.dot import (purestr, styleof, attrprint, dotnode, + dotedges, dotprint) +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.numbers import (Float, Integer) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.printing.repr import srepr +from sympy.abc import x + + +def test_purestr(): + assert purestr(Symbol('x')) == "Symbol('x')" + assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))" + assert purestr(Float(2)) == "Float('2.0', precision=53)" + + assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ()) + assert purestr(Basic(S(1), S(2)), with_args=True) == \ + ('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)')) + assert purestr(Float(2), with_args=True) == \ + ("Float('2.0', precision=53)", ()) + + +def test_styleof(): + styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'})] + assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'} + + assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'} + + +def test_attrprint(): + assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \ + '"color"="blue", "shape"="ellipse"' + +def test_dotnode(): + + assert dotnode(x, repeat=False) == \ + '"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];' + assert dotnode(x+2, repeat=False) == \ + '"Add(Integer(2), Symbol(\'x\'))" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];', \ + dotnode(x+2,repeat=0) + + assert dotnode(x + x**2, repeat=False) == \ + '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];' + assert dotnode(x + x**2, repeat=True) == \ + '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];' + +def test_dotedges(): + assert sorted(dotedges(x+2, repeat=False)) == [ + '"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";', + '"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";' + ] + assert sorted(dotedges(x + 2, repeat=True)) == [ + '"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";', + '"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";' + ] + +def test_dotprint(): + text = dotprint(x+2, repeat=False) + assert all(e in text for e in dotedges(x+2, repeat=False)) + assert all( + n in text for n in [dotnode(expr, repeat=False) + for expr in (x, Integer(2), x+2)]) + assert 'digraph' in text + + text = dotprint(x+x**2, repeat=False) + assert all(e in text for e in dotedges(x+x**2, repeat=False)) + assert all( + n in text for n in [dotnode(expr, repeat=False) + for expr in (x, Integer(2), x**2)]) + assert 'digraph' in text + + text = dotprint(x+x**2, repeat=True) + assert all(e in text for e in dotedges(x+x**2, repeat=True)) + assert all( + n in text for n in [dotnode(expr, pos=()) + for expr in [x + x**2]]) + + text = dotprint(x**x, repeat=True) + assert all(e in text for e in dotedges(x**x, repeat=True)) + assert all( + n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))]) + assert 'digraph' in text + +def test_dotprint_depth(): + text = dotprint(3*x+2, depth=1) + assert dotnode(3*x+2) in text + assert dotnode(x) not in text + text = dotprint(3*x+2) + assert "depth" not in text + +def test_Matrix_and_non_basics(): + from sympy.matrices.expressions.matexpr import MatrixSymbol + n = Symbol('n') + assert dotprint(MatrixSymbol('X', n, n)) == \ +"""digraph{ + +# Graph style +"ordering"="out" +"rankdir"="TD" + +######### +# Nodes # +######### + +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"]; +"Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"]; +"Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"]; +"Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"]; + +######### +# Edges # +######### + +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)"; +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)"; +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)"; +}""" + + +def test_labelfunc(): + text = dotprint(x + 2, labelfunc=srepr) + assert "Symbol('x')" in text + assert "Integer(2)" in text + + +def test_commutative(): + x, y = symbols('x y', commutative=False) + assert dotprint(x + y) == dotprint(y + x) + assert dotprint(x*y) != dotprint(y*x) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_glsl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_glsl.py new file mode 100644 index 0000000000000000000000000000000000000000..86ec1dfe4a37d141e8435c369cb692d3a9a3b7bc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_glsl.py @@ -0,0 +1,998 @@ +from sympy.core import (pi, symbols, Rational, Integer, GoldenRatio, EulerGamma, + Catalan, Lambda, Dummy, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.functions import Piecewise, sin, cos, Abs, exp, ceiling, sqrt +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.printing.glsl import GLSLPrinter +from sympy.printing.str import StrPrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol +from sympy.core import Tuple +from sympy.printing.glsl import glsl_code +import textwrap + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + assert glsl_code(Abs(x)) == "abs(x)" + +def test_print_without_operators(): + assert glsl_code(x*y,use_operators = False) == 'mul(x, y)' + assert glsl_code(x**y+z,use_operators = False) == 'add(pow(x, y), z)' + assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))' + assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))' + assert glsl_code(x*(y+z**y**0.5),use_operators = False) == 'mul(x, add(y, pow(z, sqrt(y))))' + assert glsl_code(-x-y, use_operators=False, zero='zero()') == 'sub(zero(), add(x, y))' + assert glsl_code(-x-y, use_operators=False) == 'sub(0.0, add(x, y))' + +def test_glsl_code_sqrt(): + assert glsl_code(sqrt(x)) == "sqrt(x)" + assert glsl_code(x**0.5) == "sqrt(x)" + assert glsl_code(sqrt(x)) == "sqrt(x)" + + +def test_glsl_code_Pow(): + g = implemented_function('g', Lambda(x, 2*x)) + assert glsl_code(x**3) == "pow(x, 3.0)" + assert glsl_code(x**(y**3)) == "pow(x, pow(y, 3.0))" + assert glsl_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2.0) + y)" + assert glsl_code(x**-1.0) == '1.0/x' + + +def test_glsl_code_Relational(): + assert glsl_code(Eq(x, y)) == "x == y" + assert glsl_code(Ne(x, y)) == "x != y" + assert glsl_code(Le(x, y)) == "x <= y" + assert glsl_code(Lt(x, y)) == "x < y" + assert glsl_code(Gt(x, y)) == "x > y" + assert glsl_code(Ge(x, y)) == "x >= y" + + +def test_glsl_code_constants_mathh(): + assert glsl_code(exp(1)) == "float E = 2.71828183;\nE" + assert glsl_code(pi) == "float pi = 3.14159265;\npi" + # assert glsl_code(oo) == "Number.POSITIVE_INFINITY" + # assert glsl_code(-oo) == "Number.NEGATIVE_INFINITY" + + +def test_glsl_code_constants_other(): + assert glsl_code(2*GoldenRatio) == "float GoldenRatio = 1.61803399;\n2*GoldenRatio" + assert glsl_code(2*Catalan) == "float Catalan = 0.915965594;\n2*Catalan" + assert glsl_code(2*EulerGamma) == "float EulerGamma = 0.577215665;\n2*EulerGamma" + + +def test_glsl_code_Rational(): + assert glsl_code(Rational(3, 7)) == "3.0/7.0" + assert glsl_code(Rational(18, 9)) == "2" + assert glsl_code(Rational(3, -7)) == "-3.0/7.0" + assert glsl_code(Rational(-3, -7)) == "3.0/7.0" + + +def test_glsl_code_Integer(): + assert glsl_code(Integer(67)) == "67" + assert glsl_code(Integer(-1)) == "-1" + + +def test_glsl_code_functions(): + assert glsl_code(sin(x) ** cos(x)) == "pow(sin(x), cos(x))" + + +def test_glsl_code_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert glsl_code(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert glsl_code(g(x)) == "float Catalan = 0.915965594;\n2*x/Catalan" + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert glsl_code(g(A[i]), assign_to=A[i]) == ( + "for (int i=0; i 1), (sin(x), x > 0)) + raises(ValueError, lambda: glsl_code(expr)) + + +def test_glsl_code_Piecewise_deep(): + p = glsl_code(2*Piecewise((x, x < 1), (x**2, True))) + s = \ +"""\ +2*((x < 1) ? ( + x +) +: ( + pow(x, 2.0) +))\ +""" + assert p == s + + +def test_glsl_code_settings(): + raises(TypeError, lambda: glsl_code(sin(x), method="garbage")) + + +def test_glsl_code_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = GLSLPrinter() + p._not_c = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + assert p._not_c == set() + +def test_glsl_code_list_tuple_Tuple(): + assert glsl_code([1,2,3,4]) == 'vec4(1, 2, 3, 4)' + assert glsl_code([1,2,3],glsl_types=False) == 'float[3](1, 2, 3)' + assert glsl_code([1,2,3]) == glsl_code((1,2,3)) + assert glsl_code([1,2,3]) == glsl_code(Tuple(1,2,3)) + + m = MatrixSymbol('A',3,4) + assert glsl_code([m[0],m[1]]) + +def test_glsl_code_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (int i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert glsl_code(mat, assign_to=A) == ( +'''A[0][0] = x*y; +if (y > 0) { + A[1][0] = x + 2; +} +else { + A[1][0] = y; +} +A[2][0] = sin(z);''' ) + assert glsl_code(Matrix([A[0],A[1]])) + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert glsl_code(expr) == ( +'''((x > 0) ? ( + 2*A[2][0] +) +: ( + A[2][0] +)) + sin(A[1][0]) + A[0][0]''' ) + + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert glsl_code(m,M) == ( +'''M[0][0] = sin(q[1]); +M[0][1] = 0; +M[0][2] = cos(q[2]); +M[1][0] = q[1] + q[2]; +M[1][1] = q[3]; +M[1][2] = 5; +M[2][0] = 2*q[4]/q[1]; +M[2][1] = sqrt(q[0]) + 4; +M[2][2] = 0;''' + ) + +def test_Matrices_1x7(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assert gl(A) == 'float[7](1, 2, 3, 4, 5, 6, 7)' + assert gl(A.transpose()) == 'float[7](1, 2, 3, 4, 5, 6, 7)' + +def test_Matrices_1x7_array_type_int(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assert gl(A, array_type='int') == 'int[7](1, 2, 3, 4, 5, 6, 7)' + +def test_Tuple_array_type_custom(): + gl = glsl_code + A = symbols('a b c') + assert gl(A, array_type='AbcType', glsl_types=False) == 'AbcType[3](a, b, c)' + +def test_Matrices_1x7_spread_assign_to_symbols(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assign_to = symbols('x.a x.b x.c x.d x.e x.f x.g') + assert gl(A, assign_to=assign_to) == textwrap.dedent('''\ + x.a = 1; + x.b = 2; + x.c = 3; + x.d = 4; + x.e = 5; + x.f = 6; + x.g = 7;''' + ) + +def test_spread_assign_to_nested_symbols(): + gl = glsl_code + expr = ((1,2,3), (1,2,3)) + assign_to = (symbols('a b c'), symbols('x y z')) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + x = 1; + y = 2; + z = 3;''' + ) + +def test_spread_assign_to_deeply_nested_symbols(): + gl = glsl_code + a, b, c, x, y, z = symbols('a b c x y z') + expr = (((1,2),3), ((1,2),3)) + assign_to = (((a, b), c), ((x, y), z)) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + x = 1; + y = 2; + z = 3;''' + ) + +def test_matrix_of_tuples_spread_assign_to_symbols(): + gl = glsl_code + with warns_deprecated_sympy(): + expr = Matrix([[(1,2),(3,4)],[(5,6),(7,8)]]) + assign_to = (symbols('a b'), symbols('c d'), symbols('e f'), symbols('g h')) + assert gl(expr, assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + d = 4; + e = 5; + f = 6; + g = 7; + h = 8;''' + ) + +def test_cannot_assign_to_cause_mismatched_length(): + expr = (1, 2) + assign_to = symbols('x y z') + raises(ValueError, lambda: glsl_code(expr, assign_to)) + +def test_matrix_4x4_assign(): + gl = glsl_code + expr = MatrixSymbol('A',4,4) * MatrixSymbol('B',4,4) + MatrixSymbol('C',4,4) + assign_to = MatrixSymbol('X',4,4) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + X[0][0] = A[0][0]*B[0][0] + A[0][1]*B[1][0] + A[0][2]*B[2][0] + A[0][3]*B[3][0] + C[0][0]; + X[0][1] = A[0][0]*B[0][1] + A[0][1]*B[1][1] + A[0][2]*B[2][1] + A[0][3]*B[3][1] + C[0][1]; + X[0][2] = A[0][0]*B[0][2] + A[0][1]*B[1][2] + A[0][2]*B[2][2] + A[0][3]*B[3][2] + C[0][2]; + X[0][3] = A[0][0]*B[0][3] + A[0][1]*B[1][3] + A[0][2]*B[2][3] + A[0][3]*B[3][3] + C[0][3]; + X[1][0] = A[1][0]*B[0][0] + A[1][1]*B[1][0] + A[1][2]*B[2][0] + A[1][3]*B[3][0] + C[1][0]; + X[1][1] = A[1][0]*B[0][1] + A[1][1]*B[1][1] + A[1][2]*B[2][1] + A[1][3]*B[3][1] + C[1][1]; + X[1][2] = A[1][0]*B[0][2] + A[1][1]*B[1][2] + A[1][2]*B[2][2] + A[1][3]*B[3][2] + C[1][2]; + X[1][3] = A[1][0]*B[0][3] + A[1][1]*B[1][3] + A[1][2]*B[2][3] + A[1][3]*B[3][3] + C[1][3]; + X[2][0] = A[2][0]*B[0][0] + A[2][1]*B[1][0] + A[2][2]*B[2][0] + A[2][3]*B[3][0] + C[2][0]; + X[2][1] = A[2][0]*B[0][1] + A[2][1]*B[1][1] + A[2][2]*B[2][1] + A[2][3]*B[3][1] + C[2][1]; + X[2][2] = A[2][0]*B[0][2] + A[2][1]*B[1][2] + A[2][2]*B[2][2] + A[2][3]*B[3][2] + C[2][2]; + X[2][3] = A[2][0]*B[0][3] + A[2][1]*B[1][3] + A[2][2]*B[2][3] + A[2][3]*B[3][3] + C[2][3]; + X[3][0] = A[3][0]*B[0][0] + A[3][1]*B[1][0] + A[3][2]*B[2][0] + A[3][3]*B[3][0] + C[3][0]; + X[3][1] = A[3][0]*B[0][1] + A[3][1]*B[1][1] + A[3][2]*B[2][1] + A[3][3]*B[3][1] + C[3][1]; + X[3][2] = A[3][0]*B[0][2] + A[3][1]*B[1][2] + A[3][2]*B[2][2] + A[3][3]*B[3][2] + C[3][2]; + X[3][3] = A[3][0]*B[0][3] + A[3][1]*B[1][3] + A[3][2]*B[2][3] + A[3][3]*B[3][3] + C[3][3];''' + ) + +def test_1xN_vecs(): + gl = glsl_code + for i in range(1,10): + A = Matrix(range(i)) + assert gl(A.transpose()) == gl(A) + assert gl(A,mat_transpose=True) == gl(A) + if i > 1: + if i <= 4: + assert gl(A) == 'vec%s(%s)' % (i,', '.join(str(s) for s in range(i))) + else: + assert gl(A) == 'float[%s](%s)' % (i,', '.join(str(s) for s in range(i))) + +def test_MxN_mats(): + generatedAssertions='def test_misc_mats():\n' + for i in range(1,6): + for j in range(1,6): + A = Matrix([[x + y*j for x in range(j)] for y in range(i)]) + gl = glsl_code(A) + glTransposed = glsl_code(A,mat_transpose=True) + generatedAssertions+=' mat = '+StrPrinter()._print(A)+'\n\n' + generatedAssertions+=' gl = \'\'\''+gl+'\'\'\'\n' + generatedAssertions+=' glTransposed = \'\'\''+glTransposed+'\'\'\'\n\n' + generatedAssertions+=' assert glsl_code(mat) == gl\n' + generatedAssertions+=' assert glsl_code(mat,mat_transpose=True) == glTransposed\n' + if i == 1 and j == 1: + assert gl == '0' + elif i <= 4 and j <= 4 and i>1 and j>1: + assert gl.startswith('mat%s' % j) + assert glTransposed.startswith('mat%s' % i) + elif i == 1 and j <= 4: + assert gl.startswith('vec') + elif j == 1 and i <= 4: + assert gl.startswith('vec') + elif i == 1: + assert gl.startswith('float[%s]('% j*i) + assert glTransposed.startswith('float[%s]('% j*i) + elif j == 1: + assert gl.startswith('float[%s]('% i*j) + assert glTransposed.startswith('float[%s]('% i*j) + else: + assert gl.startswith('float[%s](' % (i*j)) + assert glTransposed.startswith('float[%s](' % (i*j)) + glNested = glsl_code(A,mat_nested=True) + glNestedTransposed = glsl_code(A,mat_transpose=True,mat_nested=True) + assert glNested.startswith('float[%s][%s]' % (i,j)) + assert glNestedTransposed.startswith('float[%s][%s]' % (j,i)) + generatedAssertions+=' glNested = \'\'\''+glNested+'\'\'\'\n' + generatedAssertions+=' glNestedTransposed = \'\'\''+glNestedTransposed+'\'\'\'\n\n' + generatedAssertions+=' assert glsl_code(mat,mat_nested=True) == glNested\n' + generatedAssertions+=' assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed\n\n' + generateAssertions = False # set this to true to write bake these generated tests to a file + if generateAssertions: + gen = open('test_glsl_generated_matrices.py','w') + gen.write(generatedAssertions) + gen.close() + + +# these assertions were generated from the previous function +# glsl has complicated rules and this makes it easier to look over all the cases +def test_misc_mats(): + + mat = Matrix([[0]]) + + gl = '''0''' + glTransposed = '''0''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1]]) + + gl = '''vec2(0, 1)''' + glTransposed = '''vec2(0, 1)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2]]) + + gl = '''vec3(0, 1, 2)''' + glTransposed = '''vec3(0, 1, 2)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2, 3]]) + + gl = '''vec4(0, 1, 2, 3)''' + glTransposed = '''vec4(0, 1, 2, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2, 3, 4]]) + + gl = '''float[5](0, 1, 2, 3, 4)''' + glTransposed = '''float[5](0, 1, 2, 3, 4)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0], +[1]]) + + gl = '''vec2(0, 1)''' + glTransposed = '''vec2(0, 1)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3]]) + + gl = '''mat2(0, 1, 2, 3)''' + glTransposed = '''mat2(0, 2, 1, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5]]) + + gl = '''mat3x2(0, 1, 2, 3, 4, 5)''' + glTransposed = '''mat2x3(0, 3, 1, 4, 2, 5)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7]]) + + gl = '''mat4x2(0, 1, 2, 3, 4, 5, 6, 7)''' + glTransposed = '''mat2x4(0, 4, 1, 5, 2, 6, 3, 7)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3, 4], +[5, 6, 7, 8, 9]]) + + gl = '''float[10]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9 +) /* a 2x5 matrix */''' + glTransposed = '''float[10]( + 0, 5, + 1, 6, + 2, 7, + 3, 8, + 4, 9 +) /* a 5x2 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[2][5]( + float[](0, 1, 2, 3, 4), + float[](5, 6, 7, 8, 9) +)''' + glNestedTransposed = '''float[5][2]( + float[](0, 5), + float[](1, 6), + float[](2, 7), + float[](3, 8), + float[](4, 9) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2]]) + + gl = '''vec3(0, 1, 2)''' + glTransposed = '''vec3(0, 1, 2)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5]]) + + gl = '''mat2x3(0, 1, 2, 3, 4, 5)''' + glTransposed = '''mat3x2(0, 2, 4, 1, 3, 5)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5], +[6, 7, 8]]) + + gl = '''mat3(0, 1, 2, 3, 4, 5, 6, 7, 8)''' + glTransposed = '''mat3(0, 3, 6, 1, 4, 7, 2, 5, 8)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7], +[8, 9, 10, 11]]) + + gl = '''mat4x3(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)''' + glTransposed = '''mat3x4(0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14]]) + + gl = '''float[15]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14 +) /* a 3x5 matrix */''' + glTransposed = '''float[15]( + 0, 5, 10, + 1, 6, 11, + 2, 7, 12, + 3, 8, 13, + 4, 9, 14 +) /* a 5x3 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[3][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14) +)''' + glNestedTransposed = '''float[5][3]( + float[](0, 5, 10), + float[](1, 6, 11), + float[](2, 7, 12), + float[](3, 8, 13), + float[](4, 9, 14) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2], +[3]]) + + gl = '''vec4(0, 1, 2, 3)''' + glTransposed = '''vec4(0, 1, 2, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5], +[6, 7]]) + + gl = '''mat2x4(0, 1, 2, 3, 4, 5, 6, 7)''' + glTransposed = '''mat4x2(0, 2, 4, 6, 1, 3, 5, 7)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5], +[6, 7, 8], +[9, 10, 11]]) + + gl = '''mat3x4(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)''' + glTransposed = '''mat4x3(0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3], +[ 4, 5, 6, 7], +[ 8, 9, 10, 11], +[12, 13, 14, 15]]) + + gl = '''mat4( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)''' + glTransposed = '''mat4(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14], +[15, 16, 17, 18, 19]]) + + gl = '''float[20]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19 +) /* a 4x5 matrix */''' + glTransposed = '''float[20]( + 0, 5, 10, 15, + 1, 6, 11, 16, + 2, 7, 12, 17, + 3, 8, 13, 18, + 4, 9, 14, 19 +) /* a 5x4 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[4][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14), + float[](15, 16, 17, 18, 19) +)''' + glNestedTransposed = '''float[5][4]( + float[](0, 5, 10, 15), + float[](1, 6, 11, 16), + float[](2, 7, 12, 17), + float[](3, 8, 13, 18), + float[](4, 9, 14, 19) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2], +[3], +[4]]) + + gl = '''float[5](0, 1, 2, 3, 4)''' + glTransposed = '''float[5](0, 1, 2, 3, 4)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5], +[6, 7], +[8, 9]]) + + gl = '''float[10]( + 0, 1, + 2, 3, + 4, 5, + 6, 7, + 8, 9 +) /* a 5x2 matrix */''' + glTransposed = '''float[10]( + 0, 2, 4, 6, 8, + 1, 3, 5, 7, 9 +) /* a 2x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][2]( + float[](0, 1), + float[](2, 3), + float[](4, 5), + float[](6, 7), + float[](8, 9) +)''' + glNestedTransposed = '''float[2][5]( + float[](0, 2, 4, 6, 8), + float[](1, 3, 5, 7, 9) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2], +[ 3, 4, 5], +[ 6, 7, 8], +[ 9, 10, 11], +[12, 13, 14]]) + + gl = '''float[15]( + 0, 1, 2, + 3, 4, 5, + 6, 7, 8, + 9, 10, 11, + 12, 13, 14 +) /* a 5x3 matrix */''' + glTransposed = '''float[15]( + 0, 3, 6, 9, 12, + 1, 4, 7, 10, 13, + 2, 5, 8, 11, 14 +) /* a 3x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][3]( + float[]( 0, 1, 2), + float[]( 3, 4, 5), + float[]( 6, 7, 8), + float[]( 9, 10, 11), + float[](12, 13, 14) +)''' + glNestedTransposed = '''float[3][5]( + float[](0, 3, 6, 9, 12), + float[](1, 4, 7, 10, 13), + float[](2, 5, 8, 11, 14) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2, 3], +[ 4, 5, 6, 7], +[ 8, 9, 10, 11], +[12, 13, 14, 15], +[16, 17, 18, 19]]) + + gl = '''float[20]( + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19 +) /* a 5x4 matrix */''' + glTransposed = '''float[20]( + 0, 4, 8, 12, 16, + 1, 5, 9, 13, 17, + 2, 6, 10, 14, 18, + 3, 7, 11, 15, 19 +) /* a 4x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][4]( + float[]( 0, 1, 2, 3), + float[]( 4, 5, 6, 7), + float[]( 8, 9, 10, 11), + float[](12, 13, 14, 15), + float[](16, 17, 18, 19) +)''' + glNestedTransposed = '''float[4][5]( + float[](0, 4, 8, 12, 16), + float[](1, 5, 9, 13, 17), + float[](2, 6, 10, 14, 18), + float[](3, 7, 11, 15, 19) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14], +[15, 16, 17, 18, 19], +[20, 21, 22, 23, 24]]) + + gl = '''float[25]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24 +) /* a 5x5 matrix */''' + glTransposed = '''float[25]( + 0, 5, 10, 15, 20, + 1, 6, 11, 16, 21, + 2, 7, 12, 17, 22, + 3, 8, 13, 18, 23, + 4, 9, 14, 19, 24 +) /* a 5x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14), + float[](15, 16, 17, 18, 19), + float[](20, 21, 22, 23, 24) +)''' + glNestedTransposed = '''float[5][5]( + float[](0, 5, 10, 15, 20), + float[](1, 6, 11, 16, 21), + float[](2, 7, 12, 17, 22), + float[](3, 8, 13, 18, 23), + float[](4, 9, 14, 19, 24) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_gtk.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_gtk.py new file mode 100644 index 0000000000000000000000000000000000000000..5a595ab04d3a29d23e06ec12207bf917392aebce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_gtk.py @@ -0,0 +1,18 @@ +from sympy.functions.elementary.trigonometric import sin +from sympy.printing.gtk import print_gtk +from sympy.testing.pytest import XFAIL, raises + +# this test fails if python-lxml isn't installed. We don't want to depend on +# anything with SymPy + + +@XFAIL +def test_1(): + from sympy.abc import x + print_gtk(x**2, start_viewer=False) + print_gtk(x**2 + sin(x)/4, start_viewer=False) + + +def test_settings(): + from sympy.abc import x + raises(TypeError, lambda: print_gtk(x, method="garbage")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_jax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_jax.py new file mode 100644 index 0000000000000000000000000000000000000000..365d87c5b91fdd49a8e46cfde9c2b5792c23a03c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_jax.py @@ -0,0 +1,370 @@ +from sympy.concrete.summations import Sum +from sympy.core.mod import Mod +from sympy.core.relational import (Equality, Unequality) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.utilities.lambdify import lambdify + +from sympy.abc import x, i, j, a, b, c, d +from sympy.core import Function, Pow, Symbol +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt +from sympy.tensor.array import Array +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.printing.numpy import JaxPrinter, _jax_known_constants, _jax_known_functions +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +# Unlike NumPy which will aggressively promote operands to double precision, +# jax always uses single precision. Double precision in jax can be +# configured before the call to `import jax`, however this must be explicitly +# configured and is not fully supported. Thus, the tests here have been modified +# from the tests in test_numpy.py, only in the fact that they assert lambdify +# function accuracy to only single precision accuracy. +# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + +jax = import_module('jax') + +if jax: + deafult_float_info = jax.numpy.finfo(jax.numpy.array([]).dtype) + JAX_DEFAULT_EPSILON = deafult_float_info.eps + + +def test_jax_piecewise_regression(): + """ + NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid + breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+. + See gh-9747 and gh-9749 for details. + """ + printer = JaxPrinter() + p = Piecewise((1, x < 0), (0, True)) + assert printer.doprint(p) == \ + 'jax.numpy.select([jax.numpy.less(x, 0),True], [1,0], default=jax.numpy.nan)' + assert printer.module_imports == {'jax.numpy': {'select', 'less', 'nan'}} + + +def test_jax_logaddexp(): + lae = logaddexp(a, b) + assert JaxPrinter().doprint(lae) == 'jax.numpy.logaddexp(a, b)' + lae2 = logaddexp2(a, b) + assert JaxPrinter().doprint(lae2) == 'jax.numpy.logaddexp2(a, b)' + + +def test_jax_sum(): + if not jax: + skip("JAX not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'jax') + + a_, b_ = 0, 10 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'jax') + + a_, b_ = 0, 10 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + + +def test_jax_multiple_sums(): + if not jax: + skip("JAX not installed") + + s = Sum((x + j) * i, (i, a, b), (j, c, d)) + f = lambdify((a, b, c, d, x), s, 'jax') + + a_, b_ = 0, 10 + c_, d_ = 11, 21 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, c_, d_, x_), + sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1))) + + +def test_jax_codegen_einsum(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'jax') + + ma = jax.numpy.array([[1, 2], [3, 4]]) + mb = jax.numpy.array([[1,-2], [-1, 3]]) + assert (f(ma, mb) == jax.numpy.matmul(ma, mb)).all() + + +def test_jax_codegen_extra(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = jax.numpy.array([[1, 2], [3, 4]]) + mb = jax.numpy.array([[1,-2], [-1, 3]]) + mc = jax.numpy.array([[2, 0], [1, 2]]) + md = jax.numpy.array([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.einsum(ma, [0, 1], mb, [2, 3])).all() + + cg = ArrayAdd(M, N) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == ma+mb).all() + + cg = ArrayAdd(M, N, P) + f = lambdify((M, N, P), cg, 'jax') + assert (f(ma, mb, mc) == ma+mb+mc).all() + + cg = ArrayAdd(M, N, P, Q) + f = lambdify((M, N, P, Q), cg, 'jax') + assert (f(ma, mb, mc, md) == ma+mb+mc+md).all() + + cg = PermuteDims(M, [1, 0]) + f = lambdify((M,), cg, 'jax') + assert (f(ma) == ma.T).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.transpose(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.diagonal(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all() + + +def test_jax_relational(): + if not jax: + skip("JAX not installed") + + e = Equality(x, 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, False]) + + e = Unequality(x, 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, False, True]) + + e = (x < 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, False, False]) + + e = (x <= 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, True, False]) + + e = (x > 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, False, True]) + + e = (x >= 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, True]) + + # Multi-condition expressions + e = (x >= 1) & (x < 2) + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, False]) + + e = (x >= 1) | (x < 2) + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, True, True]) + +def test_jax_mod(): + if not jax: + skip("JAX not installed") + + e = Mod(a, b) + f = lambdify((a, b), e, 'jax') + + a_ = jax.numpy.array([0, 1, 2, 3]) + b_ = 2 + assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = jax.numpy.array([0, 1, 2, 3]) + b_ = jax.numpy.array([2, 2, 2, 2]) + assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = jax.numpy.array([2, 3, 4, 5]) + b_ = jax.numpy.array([2, 3, 4, 5]) + assert jax.numpy.array_equal(f(a_, b_), [0, 0, 0, 0]) + + +def test_jax_pow(): + if not jax: + skip('JAX not installed') + + expr = Pow(2, -1, evaluate=False) + f = lambdify([], expr, 'jax') + assert f() == 0.5 + + +def test_jax_expm1(): + if not jax: + skip("JAX not installed") + + f = lambdify((a,), expm1(a), 'jax') + assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * JAX_DEFAULT_EPSILON + + +def test_jax_log1p(): + if not jax: + skip("JAX not installed") + + f = lambdify((a,), log1p(a), 'jax') + assert abs(f(1e-99) - 1e-99) <= 1e-99 * JAX_DEFAULT_EPSILON + +def test_jax_hypot(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a, b), hypot(a, b), 'jax')(3, 4) - 5) <= JAX_DEFAULT_EPSILON + +def test_jax_log10(): + if not jax: + skip("JAX not installed") + + assert abs(lambdify((a,), log10(a), 'jax')(100) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_exp2(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), exp2(a), 'jax')(5) - 32) <= JAX_DEFAULT_EPSILON + + +def test_jax_log2(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), log2(a), 'jax')(256) - 8) <= JAX_DEFAULT_EPSILON + + +def test_jax_Sqrt(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), Sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_sqrt(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_matsolve(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 3, 3) + x = MatrixSymbol("x", 3, 1) + + expr = M**(-1) * x + x + matsolve_expr = MatrixSolve(M, x) + x + + f = lambdify((M, x), expr, 'jax') + f_matsolve = lambdify((M, x), matsolve_expr, 'jax') + + m0 = jax.numpy.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]]) + assert jax.numpy.linalg.matrix_rank(m0) == 3 + + x0 = jax.numpy.array([3, 4, 5]) + + assert jax.numpy.allclose(f_matsolve(m0, x0), f(m0, x0)) + + +def test_16857(): + if not jax: + skip("JAX not installed") + + a_1 = MatrixSymbol('a_1', 10, 3) + a_2 = MatrixSymbol('a_2', 10, 3) + a_3 = MatrixSymbol('a_3', 10, 3) + a_4 = MatrixSymbol('a_4', 10, 3) + A = BlockMatrix([[a_1, a_2], [a_3, a_4]]) + assert A.shape == (20, 6) + + printer = JaxPrinter() + assert printer.doprint(A) == 'jax.numpy.block([[a_1, a_2], [a_3, a_4]])' + + +def test_issue_17006(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + + f = lambdify(M, M + Identity(2), 'jax') + ma = jax.numpy.array([[1, 2], [3, 4]]) + mr = jax.numpy.array([[2, 2], [3, 5]]) + + assert (f(ma) == mr).all() + + from sympy.core.symbol import symbols + n = symbols('n', integer=True) + N = MatrixSymbol("M", n, n) + raises(NotImplementedError, lambda: lambdify(N, N + Identity(n), 'jax')) + + +def test_jax_array(): + assert JaxPrinter().doprint(Array(((1, 2), (3, 5)))) == 'jax.numpy.array([[1, 2], [3, 5]])' + assert JaxPrinter().doprint(Array((1, 2))) == 'jax.numpy.array([1, 2])' + + +def test_jax_known_funcs_consts(): + assert _jax_known_constants['NaN'] == 'jax.numpy.nan' + assert _jax_known_constants['EulerGamma'] == 'jax.numpy.euler_gamma' + + assert _jax_known_functions['acos'] == 'jax.numpy.arccos' + assert _jax_known_functions['log'] == 'jax.numpy.log' + + +def test_jax_print_methods(): + prntr = JaxPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + + +def test_jax_printmethod(): + printer = JaxPrinter() + assert hasattr(printer, 'printmethod') + assert printer.printmethod == '_jaxcode' + + +def test_jax_custom_print_method(): + + class expm1(Function): + + def _jaxcode(self, printer): + x, = self.args + function = f'expm1({printer._print(x)})' + return printer._module_format(printer._module + '.' + function) + + printer = JaxPrinter() + assert printer.doprint(expm1(Symbol('x'))) == 'jax.numpy.expm1(x)' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_jscode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_jscode.py new file mode 100644 index 0000000000000000000000000000000000000000..9199a8e0d62e87f2e964cb1712726a21c894fd20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_jscode.py @@ -0,0 +1,396 @@ +from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio, + EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le, + Lt, Gt, Ge, Mod) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + sinh, cosh, tanh, asin, acos, acosh, Max, Min) +from sympy.testing.pytest import raises +from sympy.printing.jscode import JavascriptCodePrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol + +from sympy.printing.jscode import jscode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + assert jscode(Abs(x)) == "Math.abs(x)" + + +def test_jscode_sqrt(): + assert jscode(sqrt(x)) == "Math.sqrt(x)" + assert jscode(x**0.5) == "Math.sqrt(x)" + assert jscode(x**(S.One/3)) == "Math.cbrt(x)" + + +def test_jscode_Pow(): + g = implemented_function('g', Lambda(x, 2*x)) + assert jscode(x**3) == "Math.pow(x, 3)" + assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))" + assert jscode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)" + assert jscode(x**-1.0) == '1/x' + + +def test_jscode_constants_mathh(): + assert jscode(exp(1)) == "Math.E" + assert jscode(pi) == "Math.PI" + assert jscode(oo) == "Number.POSITIVE_INFINITY" + assert jscode(-oo) == "Number.NEGATIVE_INFINITY" + + +def test_jscode_constants_other(): + assert jscode( + 2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17) + assert jscode( + 2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_jscode_Rational(): + assert jscode(Rational(3, 7)) == "3/7" + assert jscode(Rational(18, 9)) == "2" + assert jscode(Rational(3, -7)) == "-3/7" + assert jscode(Rational(-3, -7)) == "3/7" + + +def test_Relational(): + assert jscode(Eq(x, y)) == "x == y" + assert jscode(Ne(x, y)) == "x != y" + assert jscode(Le(x, y)) == "x <= y" + assert jscode(Lt(x, y)) == "x < y" + assert jscode(Gt(x, y)) == "x > y" + assert jscode(Ge(x, y)) == "x >= y" + + +def test_Mod(): + assert jscode(Mod(x, y)) == '((x % y) + y) % y' + assert jscode(Mod(x, x + y)) == '((x % (x + y)) + (x + y)) % (x + y)' + p1, p2 = symbols('p1 p2', positive=True) + assert jscode(Mod(p1, p2)) == 'p1 % p2' + assert jscode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)' + assert jscode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)' + assert jscode(-Mod(p1, p2)) == '-(p1 % p2)' + assert jscode(x*Mod(p1, p2)) == 'x*(p1 % p2)' + + +def test_jscode_Integer(): + assert jscode(Integer(67)) == "67" + assert jscode(Integer(-1)) == "-1" + + +def test_jscode_functions(): + assert jscode(sin(x) ** cos(x)) == "Math.pow(Math.sin(x), Math.cos(x))" + assert jscode(sinh(x) * cosh(x)) == "Math.sinh(x)*Math.cosh(x)" + assert jscode(Max(x, y) + Min(x, y)) == "Math.max(x, y) + Math.min(x, y)" + assert jscode(tanh(x)*acosh(y)) == "Math.tanh(x)*Math.acosh(y)" + assert jscode(asin(x)-acos(y)) == "-Math.acos(y) + Math.asin(x)" + + +def test_jscode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert jscode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert jscode(g(x)) == "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert jscode(g(A[i]), assign_to=A[i]) == ( + "for (var i=0; i 1), (sin(x), x > 0)) + raises(ValueError, lambda: jscode(expr)) + + +def test_jscode_Piecewise_deep(): + p = jscode(2*Piecewise((x, x < 1), (x**2, True))) + s = \ +"""\ +2*((x < 1) ? ( + x +) +: ( + Math.pow(x, 2) +))\ +""" + assert p == s + + +def test_jscode_settings(): + raises(TypeError, lambda: jscode(sin(x), method="garbage")) + + +def test_jscode_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = JavascriptCodePrinter() + p._not_c = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + assert p._not_c == set() + + +def test_jscode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (var i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert jscode(mat, A) == ( + "A[0] = x*y;\n" + "if (y > 0) {\n" + " A[1] = x + 2;\n" + "}\n" + "else {\n" + " A[1] = y;\n" + "}\n" + "A[2] = Math.sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert jscode(expr) == ( + "((x > 0) ? (\n" + " 2*A[2]\n" + ")\n" + ": (\n" + " A[2]\n" + ")) + Math.sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert jscode(m, M) == ( + "M[0] = Math.sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = Math.cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = Math.sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(jscode(A[0, 0]) == "A[0]") + assert(jscode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(jscode(F) == "(A - B)[0]") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_julia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_julia.py new file mode 100644 index 0000000000000000000000000000000000000000..b19c7b4fd4f21d8402ca2f577605322b3ec10f5b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_julia.py @@ -0,0 +1,390 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow +from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix) +from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli, + besselk, hankel1, hankel2, airyai, + airybi, airyaiprime, airybiprime) +from sympy.testing.pytest import XFAIL + +from sympy.printing.julia import julia_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert julia_code(Integer(67)) == "67" + assert julia_code(Integer(-1)) == "-1" + + +def test_Rational(): + assert julia_code(Rational(3, 7)) == "3 // 7" + assert julia_code(Rational(18, 9)) == "2" + assert julia_code(Rational(3, -7)) == "-3 // 7" + assert julia_code(Rational(-3, -7)) == "3 // 7" + assert julia_code(x + Rational(3, 7)) == "x + 3 // 7" + assert julia_code(Rational(3, 7)*x) == "(3 // 7) * x" + + +def test_Relational(): + assert julia_code(Eq(x, y)) == "x == y" + assert julia_code(Ne(x, y)) == "x != y" + assert julia_code(Le(x, y)) == "x <= y" + assert julia_code(Lt(x, y)) == "x < y" + assert julia_code(Gt(x, y)) == "x > y" + assert julia_code(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert julia_code(sin(x) ** cos(x)) == "sin(x) .^ cos(x)" + assert julia_code(abs(x)) == "abs(x)" + assert julia_code(ceiling(x)) == "ceil(x)" + + +def test_Pow(): + assert julia_code(x**3) == "x .^ 3" + assert julia_code(x**(y**3)) == "x .^ (y .^ 3)" + assert julia_code(x**Rational(2, 3)) == 'x .^ (2 // 3)' + g = implemented_function('g', Lambda(x, 2*x)) + assert julia_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5 * 2 * x) .^ (-x + y .^ x) ./ (x .^ 2 + y)" + # For issue 14160 + assert julia_code(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2 * x ./ (y .* y)' + + +def test_basic_ops(): + assert julia_code(x*y) == "x .* y" + assert julia_code(x + y) == "x + y" + assert julia_code(x - y) == "x - y" + assert julia_code(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert julia_code(1/x) == '1 ./ x' + assert julia_code(x**-1) == julia_code(x**-1.0) == '1 ./ x' + assert julia_code(1/sqrt(x)) == '1 ./ sqrt(x)' + assert julia_code(x**-S.Half) == julia_code(x**-0.5) == '1 ./ sqrt(x)' + assert julia_code(sqrt(x)) == 'sqrt(x)' + assert julia_code(x**S.Half) == julia_code(x**0.5) == 'sqrt(x)' + assert julia_code(1/pi) == '1 / pi' + assert julia_code(pi**-1) == julia_code(pi**-1.0) == '1 / pi' + assert julia_code(pi**-0.5) == '1 / sqrt(pi)' + + +def test_mix_number_mult_symbols(): + assert julia_code(3*x) == "3 * x" + assert julia_code(pi*x) == "pi * x" + assert julia_code(3/x) == "3 ./ x" + assert julia_code(pi/x) == "pi ./ x" + assert julia_code(x/3) == "x / 3" + assert julia_code(x/pi) == "x / pi" + assert julia_code(x*y) == "x .* y" + assert julia_code(3*x*y) == "3 * x .* y" + assert julia_code(3*pi*x*y) == "3 * pi * x .* y" + assert julia_code(x/y) == "x ./ y" + assert julia_code(3*x/y) == "3 * x ./ y" + assert julia_code(x*y/z) == "x .* y ./ z" + assert julia_code(x/y*z) == "x .* z ./ y" + assert julia_code(1/x/y) == "1 ./ (x .* y)" + assert julia_code(2*pi*x/y/z) == "2 * pi * x ./ (y .* z)" + assert julia_code(3*pi/x) == "3 * pi ./ x" + assert julia_code(S(3)/5) == "3 // 5" + assert julia_code(S(3)/5*x) == "(3 // 5) * x" + assert julia_code(x/y/z) == "x ./ (y .* z)" + assert julia_code((x+y)/z) == "(x + y) ./ z" + assert julia_code((x+y)/(z+x)) == "(x + y) ./ (x + z)" + assert julia_code((x+y)/EulerGamma) == "(x + y) / eulergamma" + assert julia_code(x/3/pi) == "x / (3 * pi)" + assert julia_code(S(3)/5*x*y/pi) == "(3 // 5) * x .* y / pi" + + +def test_mix_number_pow_symbols(): + assert julia_code(pi**3) == 'pi ^ 3' + assert julia_code(x**2) == 'x .^ 2' + assert julia_code(x**(pi**3)) == 'x .^ (pi ^ 3)' + assert julia_code(x**y) == 'x .^ y' + assert julia_code(x**(y**z)) == 'x .^ (y .^ z)' + assert julia_code((x**y)**z) == '(x .^ y) .^ z' + + +def test_imag(): + I = S('I') + assert julia_code(I) == "im" + assert julia_code(5*I) == "5im" + assert julia_code((S(3)/2)*I) == "(3 // 2) * im" + assert julia_code(3+4*I) == "3 + 4im" + + +def test_constants(): + assert julia_code(pi) == "pi" + assert julia_code(oo) == "Inf" + assert julia_code(-oo) == "-Inf" + assert julia_code(S.NegativeInfinity) == "-Inf" + assert julia_code(S.NaN) == "NaN" + assert julia_code(S.Exp1) == "e" + assert julia_code(exp(1)) == "e" + + +def test_constants_other(): + assert julia_code(2*GoldenRatio) == "2 * golden" + assert julia_code(2*Catalan) == "2 * catalan" + assert julia_code(2*EulerGamma) == "2 * eulergamma" + + +def test_boolean(): + assert julia_code(x & y) == "x && y" + assert julia_code(x | y) == "x || y" + assert julia_code(~x) == "!x" + assert julia_code(x & y & z) == "x && y && z" + assert julia_code(x | y | z) == "x || y || z" + assert julia_code((x & y) | z) == "z || x && y" + assert julia_code((x | y) & z) == "z && (x || y)" + +def test_sinc(): + assert julia_code(sinc(x)) == 'sinc(x / pi)' + assert julia_code(sinc(x + 3)) == 'sinc((x + 3) / pi)' + assert julia_code(sinc(pi * (x + 3))) == 'sinc(x + 3)' + +def test_Matrices(): + assert julia_code(Matrix(1, 1, [10])) == "[10]" + A = Matrix([[1, sin(x/2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = ("[1 sin(x / 2) abs(x);\n" + "0 1 pi;\n" + "0 e ceil(x)]") + assert julia_code(A) == expected + # row and columns + assert julia_code(A[:,0]) == "[1, 0, 0]" + assert julia_code(A[0,:]) == "[1 sin(x / 2) abs(x)]" + # empty matrices + assert julia_code(Matrix(0, 0, [])) == 'zeros(0, 0)' + assert julia_code(Matrix(0, 3, [])) == 'zeros(0, 3)' + # annoying to read but correct + assert julia_code(Matrix([[x, x - y, -y]])) == "[x x - y -y]" + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2/x), 3*pi/x/5]]) + assert julia_code(A) == "[1 sin(2 ./ x) (3 // 5) * pi ./ x]" + assert julia_code(A.T) == "[1, sin(2 ./ x), (3 // 5) * pi ./ x]" + + +@XFAIL +def test_Matrices_entries_not_hadamard(): + # For Matrix with col >= 2, row >= 2, they need to be scalars + # FIXME: is it worth worrying about this? Its not wrong, just + # leave it user's responsibility to put scalar data for x. + A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]]) + expected = ("[1 sin(2/x) 3*pi/(5*x);\n" + "1 2 x*y]") # <- we give x.*y + assert julia_code(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert julia_code(A*B) == "A * B" + assert julia_code(B*A) == "B * A" + assert julia_code(2*A*B) == "2 * A * B" + assert julia_code(B*2*A) == "2 * B * A" + assert julia_code(A*(B + 3*Identity(n))) == "A * (3 * eye(n) + B)" + assert julia_code(A**(x**2)) == "A ^ (x .^ 2)" + assert julia_code(A**3) == "A ^ 3" + assert julia_code(A**S.Half) == "A ^ (1 // 2)" + + +def test_special_matrices(): + assert julia_code(6*Identity(3)) == "6 * eye(3)" + + +def test_containers(): + assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]" + assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))" + assert julia_code([1]) == "Any[1]" + assert julia_code((1,)) == "(1,)" + assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)" + assert julia_code((1, x*y, (3, x**2))) == "(1, x .* y, (3, x .^ 2))" + # scalar, matrix, empty matrix and empty list + assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])" + + +def test_julia_noninline(): + source = julia_code((x+y)/Catalan, assign_to='me', inline=False) + expected = ( + "const Catalan = %s\n" + "me = (x + y) / Catalan" + ) % Catalan.evalf(17) + assert source == expected + + +def test_julia_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert julia_code(expr) == "((x < 1) ? (x) : (x .^ 2))" + assert julia_code(expr, assign_to="r") == ( + "r = ((x < 1) ? (x) : (x .^ 2))") + assert julia_code(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x\n" + "else\n" + " r = x .^ 2\n" + "end") + expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True)) + expected = ("((x < 1) ? (x .^ 2) :\n" + "(x < 2) ? (x .^ 3) :\n" + "(x < 3) ? (x .^ 4) : (x .^ 5))") + assert julia_code(expr) == expected + assert julia_code(expr, assign_to="r") == "r = " + expected + assert julia_code(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x .^ 2\n" + "elseif (x < 2)\n" + " r = x .^ 3\n" + "elseif (x < 3)\n" + " r = x .^ 4\n" + "else\n" + " r = x .^ 5\n" + "end") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: julia_code(expr)) + + +def test_julia_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x**2, True)) + assert julia_code(2*pw) == "2 * ((x < 1) ? (x) : (x .^ 2))" + assert julia_code(pw/x) == "((x < 1) ? (x) : (x .^ 2)) ./ x" + assert julia_code(pw/(x*y)) == "((x < 1) ? (x) : (x .^ 2)) ./ (x .* y)" + assert julia_code(pw/3) == "((x < 1) ? (x) : (x .^ 2)) / 3" + + +def test_julia_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert julia_code(A, assign_to='a') == "a = [1 2 3]" + A = Matrix([[1, 2], [3, 4]]) + assert julia_code(A, assign_to='A') == "A = [1 2;\n3 4]" + + +def test_julia_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert julia_code(A, assign_to=B) == "B = [1 2 3]" + raises(ValueError, lambda: julia_code(A, assign_to=x)) + raises(ValueError, lambda: julia_code(A, assign_to=C)) + + +def test_julia_matrix_1x1(): + A = Matrix([[3]]) + B = MatrixSymbol('B', 1, 1) + C = MatrixSymbol('C', 1, 2) + assert julia_code(A, assign_to=B) == "B = [3]" + # FIXME? + #assert julia_code(A, assign_to=x) == "x = [3]" + raises(ValueError, lambda: julia_code(A, assign_to=C)) + + +def test_julia_matrix_elements(): + A = Matrix([[x, 2, x*y]]) + assert julia_code(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2" + A = MatrixSymbol('AA', 1, 3) + assert julia_code(A) == "AA" + assert julia_code(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \ + "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]" + assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]" + + +def test_julia_boolean(): + assert julia_code(True) == "true" + assert julia_code(S.true) == "true" + assert julia_code(False) == "false" + assert julia_code(S.false) == "false" + + +def test_julia_not_supported(): + with raises(NotImplementedError): + julia_code(S.ComplexInfinity) + + f = Function('f') + assert julia_code(f(x).diff(x), strict=False) == ( + "# Not supported in Julia:\n" + "# Derivative\n" + "Derivative(f(x), x)" + ) + + +def test_trick_indent_with_end_else_words(): + # words starting with "end" or "else" do not confuse the indenter + t1 = S('endless') + t2 = S('elsewhere') + pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True)) + assert julia_code(pw, inline=False) == ( + "if (x < 0)\n" + " endless\n" + "elseif (x <= 1)\n" + " elsewhere\n" + "else\n" + " 1\n" + "end") + + +def test_haramard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + assert julia_code(C) == "A .* B" + assert julia_code(C*v) == "(A .* B) * v" + assert julia_code(h*C*v) == "h * (A .* B) * v" + assert julia_code(C*A) == "(A .* B) * A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + assert julia_code(C*x*y) == "(x .* y) * (A .* B)" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x*y + assert julia_code(M) == ( + "sparse([4, 2, 3, 1, 2], [1, 3, 3, 4, 4], [x .* y, 20, 10, 30, 22], 5, 6)" + ) + + +def test_specfun(): + n = Symbol('n') + for f in [besselj, bessely, besseli, besselk]: + assert julia_code(f(n, x)) == f.__name__ + '(n, x)' + for f in [airyai, airyaiprime, airybi, airybiprime]: + assert julia_code(f(x)) == f.__name__ + '(x)' + assert julia_code(hankel1(n, x)) == 'hankelh1(n, x)' + assert julia_code(hankel2(n, x)) == 'hankelh2(n, x)' + assert julia_code(jn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* besselj(n + 1 // 2, x) / 2' + assert julia_code(yn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* bessely(n + 1 // 2, x) / 2' + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(julia_code(A[0, 0]) == "A[1,1]") + assert(julia_code(3 * A[0, 0]) == "3 * A[1,1]") + + F = C[0, 0].subs(C, A - B) + assert(julia_code(F) == "(A - B)[1,1]") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_lambdarepr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_lambdarepr.py new file mode 100644 index 0000000000000000000000000000000000000000..94e09ada7a9ce7d01667edd8fc6ec35ebfbb9639 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_lambdarepr.py @@ -0,0 +1,246 @@ +from sympy.concrete.summations import Sum +from sympy.core.expr import Expr +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.sets.sets import Interval +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import raises + +from sympy.printing.tensorflow import TensorflowPrinter +from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter + + +x, y, z = symbols("x,y,z") +i, a, b = symbols("i,a,b") +j, c, d = symbols("j,c,d") + + +def test_basic(): + assert lambdarepr(x*y) == "x*y" + assert lambdarepr(x + y) in ["y + x", "x + y"] + assert lambdarepr(x**y) == "x**y" + + +def test_matrix(): + # Test printing a Matrix that has an element that is printed differently + # with the LambdaPrinter than with the StrPrinter. + e = x % 2 + assert lambdarepr(e) != str(e) + assert lambdarepr(Matrix([e])) == 'ImmutableDenseMatrix([[x % 2]])' + + +def test_piecewise(): + # In each case, test eval() the lambdarepr() to make sure there are a + # correct number of parentheses. It will give a SyntaxError if there aren't. + + h = "lambda x: " + + p = Piecewise((x, x < 0)) + l = lambdarepr(p) + eval(h + l) + assert l == "((x) if (x < 0) else None)" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + (0, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else (0))" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else None)" + + p = Piecewise( + (x, x < 1), + (x**2, Interval(3, 4, True, False).contains(x)), + (0, True), + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))" + + p = Piecewise( + (x**2, x < 0), + (x, x < 1), + (2 - x, x >= 1), + (0, True), evaluate=False + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\ + " else (2 - x) if (x >= 1) else (0))" + + p = Piecewise( + (x**2, x < 0), + (x, x < 1), + (2 - x, x >= 1), evaluate=False + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\ + " else (2 - x) if (x >= 1) else None)" + + p = Piecewise( + (1, x >= 1), + (2, x >= 2), + (3, x >= 3), + (4, x >= 4), + (5, x >= 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"\ + " else (4) if (x >= 4) else (5) if (x >= 5) else (6))" + + p = Piecewise( + (1, x <= 1), + (2, x <= 2), + (3, x <= 3), + (4, x <= 4), + (5, x <= 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"\ + " else (4) if (x <= 4) else (5) if (x <= 5) else (6))" + + p = Piecewise( + (1, x > 1), + (2, x > 2), + (3, x > 3), + (4, x > 4), + (5, x > 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l =="((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"\ + " else (4) if (x > 4) else (5) if (x > 5) else (6))" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + (3, x < 3), + (4, x < 4), + (5, x < 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"\ + " else (4) if (x < 4) else (5) if (x < 5) else (6))" + + p = Piecewise( + (Piecewise( + (1, x > 0), + (2, True) + ), y > 0), + (3, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((((1) if (x > 0) else (2))) if (y > 0) else (3))" + + +def test_sum__1(): + # In each case, test eval() the lambdarepr() to make sure that + # it evaluates to the same results as the symbolic expression + s = Sum(x ** i, (i, a, b)) + l = lambdarepr(s) + assert l == "(builtins.sum(x**i for i in range(a, b+1)))" + + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() + +def test_sum__2(): + s = Sum(i * x, (i, a, b)) + l = lambdarepr(s) + assert l == "(builtins.sum(i*x for i in range(a, b+1)))" + + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() + + +def test_multiple_sums(): + s = Sum(i * x + j, (i, a, b), (j, c, d)) + + l = lambdarepr(s) + assert l == "(builtins.sum(i*x + j for j in range(c, d+1) for i in range(a, b+1)))" + + args = x, a, b, c, d + f = lambdify(args, s) + vals = 2, 3, 4, 5, 6 + f_ref = s.subs(zip(args, vals)).doit() + f_res = f(*vals) + assert f_res == f_ref + + +def test_sqrt(): + prntr = LambdaPrinter({'standard' : 'python3'}) + assert prntr._print_Pow(sqrt(x), rational=False) == 'sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + +def test_settings(): + raises(TypeError, lambda: lambdarepr(sin(x), method="garbage")) + + +def test_numexpr(): + # test ITE rewrite as Piecewise + from sympy.logic.boolalg import ITE + expr = ITE(x > 0, True, False, evaluate=False) + assert NumExprPrinter().doprint(expr) == \ + "numexpr.evaluate('where((x > 0), True, False)', truediv=True)" + + from sympy.codegen.ast import Return, FunctionDefinition, Variable, Assignment + func_def = FunctionDefinition(None, 'foo', [Variable(x)], [Assignment(y,x), Return(y**2)]) + expected = "def foo(x):\n"\ + " y = numexpr.evaluate('x', truediv=True)\n"\ + " return numexpr.evaluate('y**2', truediv=True)" + assert NumExprPrinter().doprint(func_def) == expected + + +class CustomPrintedObject(Expr): + def _lambdacode(self, printer): + return 'lambda' + + def _tensorflowcode(self, printer): + return 'tensorflow' + + def _numpycode(self, printer): + return 'numpy' + + def _numexprcode(self, printer): + return 'numexpr' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + # In each case, printmethod is called to test + # its working + + obj = CustomPrintedObject() + assert LambdaPrinter().doprint(obj) == 'lambda' + assert TensorflowPrinter().doprint(obj) == 'tensorflow' + assert NumExprPrinter().doprint(obj) == "numexpr.evaluate('numexpr', truediv=True)" + + assert NumExprPrinter().doprint(Piecewise((y, x >= 0), (z, x < 0))) == \ + "numexpr.evaluate('where((x >= 0), y, z)', truediv=True)" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_latex.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..063611d09a923881cd94bd693f3f3f721535fd0c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_latex.py @@ -0,0 +1,3164 @@ +from sympy import MatAdd, MatMul, Array +from sympy.algebras.quaternion import Quaternion +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.combinatorics.permutations import Cycle, Permutation, AppliedPermutation +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple, Dict +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import (Derivative, Function, Lambda, Subs, diff) +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core.numbers import (AlgebraicNumber, Float, I, Integer, Rational, oo, pi) +from sympy.core.parameters import evaluate +from sympy.core.power import Pow +from sympy.core.relational import Eq, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial, factorial2, subfactorial) +from sympy.functions.combinatorial.numbers import (bernoulli, bell, catalan, euler, genocchi, + lucas, fibonacci, tribonacci, divisor_sigma, udivisor_sigma, + mobius, primenu, primeomega, + totient, reduced_totient) +from sympy.functions.elementary.complexes import (Abs, arg, conjugate, im, polar_lift, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (asinh, coth) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (Max, Min, root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acsc, asin, cos, cot, sin, tan) +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.functions.special.elliptic_integrals import (elliptic_e, elliptic_f, elliptic_k, elliptic_pi) +from sympy.functions.special.error_functions import (Chi, Ci, Ei, Shi, Si, expint) +from sympy.functions.special.gamma_functions import (gamma, uppergamma) +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.functions.special.mathieu_functions import (mathieuc, mathieucprime, mathieus, mathieusprime) +from sympy.functions.special.polynomials import (assoc_laguerre, assoc_legendre, chebyshevt, chebyshevu, gegenbauer, hermite, jacobi, laguerre, legendre) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.spherical_harmonics import (Ynm, Znm) +from sympy.functions.special.tensor_functions import (KroneckerDelta, LeviCivita) +from sympy.functions.special.zeta_functions import (dirichlet_eta, lerchphi, polylog, stieltjes, zeta) +from sympy.integrals.integrals import Integral +from sympy.integrals.transforms import (CosineTransform, FourierTransform, InverseCosineTransform, InverseFourierTransform, InverseLaplaceTransform, InverseMellinTransform, InverseSineTransform, LaplaceTransform, MellinTransform, SineTransform) +from sympy.logic import Implies +from sympy.logic.boolalg import (And, Or, Xor, Equivalent, false, Not, true) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.kronecker import KroneckerProduct +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.permutation import PermutationMatrix +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions.dotproduct import DotProduct +from sympy.physics.control.lti import TransferFunction, Series, Parallel, Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback +from sympy.physics.quantum import Commutator, Operator +from sympy.physics.quantum.trace import Tr +from sympy.physics.units import meter, gibibyte, gram, microgram, second, milli, micro +from sympy.polys.domains.integerring import ZZ +from sympy.polys.fields import field +from sympy.polys.polytools import Poly +from sympy.polys.rings import ring +from sympy.polys.rootoftools import (RootSum, rootof) +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.series.limits import Limit +from sympy.series.order import Order +from sympy.series.sequences import (SeqAdd, SeqFormula, SeqMul, SeqPer) +from sympy.sets.conditionset import ConditionSet +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ComplexRegion, ImageSet, Range) +from sympy.sets.ordinals import Ordinal, OrdinalOmega, OmegaPower +from sympy.sets.powerset import PowerSet +from sympy.sets.sets import (FiniteSet, Interval, Union, Intersection, Complement, SymmetricDifference, ProductSet) +from sympy.sets.setexpr import SetExpr +from sympy.stats.crv_types import Normal +from sympy.stats.symbolic_probability import (Covariance, Expectation, + Probability, Variance) +from sympy.tensor.array import (ImmutableDenseNDimArray, + ImmutableSparseNDimArray, + MutableSparseNDimArray, + MutableDenseNDimArray, + tensorproduct) +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement +from sympy.tensor.indexed import (Idx, Indexed, IndexedBase) +from sympy.tensor.toperators import PartialDerivative +from sympy.vector import CoordSys3D, Cross, Curl, Dot, Divergence, Gradient, Laplacian + + +from sympy.testing.pytest import (XFAIL, raises, _both_exp_pow, + warns_deprecated_sympy) +from sympy.printing.latex import (latex, translate, greek_letters_set, + tex_greek_dictionary, multiline_latex, + latex_escape, LatexPrinter) + +import sympy as sym + +from sympy.abc import mu, tau + + +class lowergamma(sym.lowergamma): + pass # testing notation inheritance by a subclass with same name + + +x, y, z, t, w, a, b, c, s, p = symbols('x y z t w a b c s p') +k, m, n = symbols('k m n', integer=True) + + +def test_printmethod(): + class R(Abs): + def _latex(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert latex(R(x)) == r"foo(x)" + + class R(Abs): + def _latex(self, printer): + return "foo" + assert latex(R(x)) == r"foo" + + +def test_latex_basic(): + assert latex(1 + x) == r"x + 1" + assert latex(x**2) == r"x^{2}" + assert latex(x**(1 + x)) == r"x^{x + 1}" + assert latex(x**3 + x + 1 + x**2) == r"x^{3} + x^{2} + x + 1" + + assert latex(2*x*y) == r"2 x y" + assert latex(2*x*y, mul_symbol='dot') == r"2 \cdot x \cdot y" + assert latex(3*x**2*y, mul_symbol='\\,') == r"3\,x^{2}\,y" + assert latex(1.5*3**x, mul_symbol='\\,') == r"1.5 \cdot 3^{x}" + + assert latex(x**S.Half**5) == r"\sqrt[32]{x}" + assert latex(Mul(S.Half, x**2, -5, evaluate=False)) == r"\frac{1}{2} x^{2} \left(-5\right)" + assert latex(Mul(S.Half, x**2, 5, evaluate=False)) == r"\frac{1}{2} x^{2} \cdot 5" + assert latex(Mul(-5, -5, evaluate=False)) == r"\left(-5\right) \left(-5\right)" + assert latex(Mul(5, -5, evaluate=False)) == r"5 \left(-5\right)" + assert latex(Mul(S.Half, -5, S.Half, evaluate=False)) == r"\frac{1}{2} \left(-5\right) \frac{1}{2}" + assert latex(Mul(5, I, 5, evaluate=False)) == r"5 i 5" + assert latex(Mul(5, I, -5, evaluate=False)) == r"5 i \left(-5\right)" + assert latex(Mul(Pow(x, 2), S.Half*x + 1)) == r"x^{2} \left(\frac{x}{2} + 1\right)" + assert latex(Mul(Pow(x, 3), Rational(2, 3)*x + 1)) == r"x^{3} \left(\frac{2 x}{3} + 1\right)" + assert latex(Mul(Pow(x, 11), 2*x + 1)) == r"x^{11} \left(2 x + 1\right)" + + assert latex(Mul(0, 1, evaluate=False)) == r'0 \cdot 1' + assert latex(Mul(1, 0, evaluate=False)) == r'1 \cdot 0' + assert latex(Mul(1, 1, evaluate=False)) == r'1 \cdot 1' + assert latex(Mul(-1, 1, evaluate=False)) == r'\left(-1\right) 1' + assert latex(Mul(1, 1, 1, evaluate=False)) == r'1 \cdot 1 \cdot 1' + assert latex(Mul(1, 2, evaluate=False)) == r'1 \cdot 2' + assert latex(Mul(1, S.Half, evaluate=False)) == r'1 \cdot \frac{1}{2}' + assert latex(Mul(1, 1, S.Half, evaluate=False)) == \ + r'1 \cdot 1 \cdot \frac{1}{2}' + assert latex(Mul(1, 1, 2, 3, x, evaluate=False)) == \ + r'1 \cdot 1 \cdot 2 \cdot 3 x' + assert latex(Mul(1, -1, evaluate=False)) == r'1 \left(-1\right)' + assert latex(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == \ + r'4 \cdot 3 \cdot 2 \cdot 1 \cdot 0 y x' + assert latex(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == \ + r'4 \cdot 3 \cdot 2 \left(z + 1\right) 0 y x' + assert latex(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == \ + r'\frac{2}{3} \cdot \frac{5}{7}' + + assert latex(1/x) == r"\frac{1}{x}" + assert latex(1/x, fold_short_frac=True) == r"1 / x" + assert latex(-S(3)/2) == r"- \frac{3}{2}" + assert latex(-S(3)/2, fold_short_frac=True) == r"- 3 / 2" + assert latex(1/x**2) == r"\frac{1}{x^{2}}" + assert latex(1/(x + y)/2) == r"\frac{1}{2 \left(x + y\right)}" + assert latex(x/2) == r"\frac{x}{2}" + assert latex(x/2, fold_short_frac=True) == r"x / 2" + assert latex((x + y)/(2*x)) == r"\frac{x + y}{2 x}" + assert latex((x + y)/(2*x), fold_short_frac=True) == \ + r"\left(x + y\right) / 2 x" + assert latex((x + y)/(2*x), long_frac_ratio=0) == \ + r"\frac{1}{2 x} \left(x + y\right)" + assert latex((x + y)/x) == r"\frac{x + y}{x}" + assert latex((x + y)/x, long_frac_ratio=3) == r"\frac{x + y}{x}" + assert latex((2*sqrt(2)*x)/3) == r"\frac{2 \sqrt{2} x}{3}" + assert latex((2*sqrt(2)*x)/3, long_frac_ratio=2) == \ + r"\frac{2 x}{3} \sqrt{2}" + assert latex(binomial(x, y)) == r"{\binom{x}{y}}" + + x_star = Symbol('x^*') + f = Function('f') + assert latex(x_star**2) == r"\left(x^{*}\right)^{2}" + assert latex(x_star**2, parenthesize_super=False) == r"{x^{*}}^{2}" + assert latex(Derivative(f(x_star), x_star,2)) == r"\frac{d^{2}}{d \left(x^{*}\right)^{2}} f{\left(x^{*} \right)}" + assert latex(Derivative(f(x_star), x_star,2), parenthesize_super=False) == r"\frac{d^{2}}{d {x^{*}}^{2}} f{\left(x^{*} \right)}" + + assert latex(2*Integral(x, x)/3) == r"\frac{2 \int x\, dx}{3}" + assert latex(2*Integral(x, x)/3, fold_short_frac=True) == \ + r"\left(2 \int x\, dx\right) / 3" + + assert latex(sqrt(x)) == r"\sqrt{x}" + assert latex(x**Rational(1, 3)) == r"\sqrt[3]{x}" + assert latex(x**Rational(1, 3), root_notation=False) == r"x^{\frac{1}{3}}" + assert latex(sqrt(x)**3) == r"x^{\frac{3}{2}}" + assert latex(sqrt(x), itex=True) == r"\sqrt{x}" + assert latex(x**Rational(1, 3), itex=True) == r"\root{3}{x}" + assert latex(sqrt(x)**3, itex=True) == r"x^{\frac{3}{2}}" + assert latex(x**Rational(3, 4)) == r"x^{\frac{3}{4}}" + assert latex(x**Rational(3, 4), fold_frac_powers=True) == r"x^{3/4}" + assert latex((x + 1)**Rational(3, 4)) == \ + r"\left(x + 1\right)^{\frac{3}{4}}" + assert latex((x + 1)**Rational(3, 4), fold_frac_powers=True) == \ + r"\left(x + 1\right)^{3/4}" + assert latex(AlgebraicNumber(sqrt(2))) == r"\sqrt{2}" + assert latex(AlgebraicNumber(sqrt(2), [3, -7])) == r"-7 + 3 \sqrt{2}" + assert latex(AlgebraicNumber(sqrt(2), alias='alpha')) == r"\alpha" + assert latex(AlgebraicNumber(sqrt(2), [3, -7], alias='alpha')) == \ + r"3 \alpha - 7" + assert latex(AlgebraicNumber(2**(S(1)/3), [1, 3, -7], alias='beta')) == \ + r"\beta^{2} + 3 \beta - 7" + + k = ZZ.cyclotomic_field(5) + assert latex(k.ext.field_element([1, 2, 3, 4])) == \ + r"\zeta^{3} + 2 \zeta^{2} + 3 \zeta + 4" + assert latex(k.ext.field_element([1, 2, 3, 4]), order='old') == \ + r"4 + 3 \zeta + 2 \zeta^{2} + \zeta^{3}" + assert latex(k.primes_above(19)[0]) == \ + r"\left(19, \zeta^{2} + 5 \zeta + 1\right)" + assert latex(k.primes_above(19)[0], order='old') == \ + r"\left(19, 1 + 5 \zeta + \zeta^{2}\right)" + assert latex(k.primes_above(7)[0]) == r"\left(7\right)" + + assert latex(1.5e20*x) == r"1.5 \cdot 10^{20} x" + assert latex(1.5e20*x, mul_symbol='dot') == r"1.5 \cdot 10^{20} \cdot x" + assert latex(1.5e20*x, mul_symbol='times') == \ + r"1.5 \times 10^{20} \times x" + + assert latex(1/sin(x)) == r"\frac{1}{\sin{\left(x \right)}}" + assert latex(sin(x)**-1) == r"\frac{1}{\sin{\left(x \right)}}" + assert latex(sin(x)**Rational(3, 2)) == \ + r"\sin^{\frac{3}{2}}{\left(x \right)}" + assert latex(sin(x)**Rational(3, 2), fold_frac_powers=True) == \ + r"\sin^{3/2}{\left(x \right)}" + + assert latex(~x) == r"\neg x" + assert latex(x & y) == r"x \wedge y" + assert latex(x & y & z) == r"x \wedge y \wedge z" + assert latex(x | y) == r"x \vee y" + assert latex(x | y | z) == r"x \vee y \vee z" + assert latex((x & y) | z) == r"z \vee \left(x \wedge y\right)" + assert latex(Implies(x, y)) == r"x \Rightarrow y" + assert latex(~(x >> ~y)) == r"x \not\Rightarrow \neg y" + assert latex(Implies(Or(x,y), z)) == r"\left(x \vee y\right) \Rightarrow z" + assert latex(Implies(z, Or(x,y))) == r"z \Rightarrow \left(x \vee y\right)" + assert latex(~(x & y)) == r"\neg \left(x \wedge y\right)" + + assert latex(~x, symbol_names={x: "x_i"}) == r"\neg x_i" + assert latex(x & y, symbol_names={x: "x_i", y: "y_i"}) == \ + r"x_i \wedge y_i" + assert latex(x & y & z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"x_i \wedge y_i \wedge z_i" + assert latex(x | y, symbol_names={x: "x_i", y: "y_i"}) == r"x_i \vee y_i" + assert latex(x | y | z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"x_i \vee y_i \vee z_i" + assert latex((x & y) | z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"z_i \vee \left(x_i \wedge y_i\right)" + assert latex(Implies(x, y), symbol_names={x: "x_i", y: "y_i"}) == \ + r"x_i \Rightarrow y_i" + assert latex(Pow(Rational(1, 3), -1, evaluate=False)) == r"\frac{1}{\frac{1}{3}}" + assert latex(Pow(Rational(1, 3), -2, evaluate=False)) == r"\frac{1}{(\frac{1}{3})^{2}}" + assert latex(Pow(Integer(1)/100, -1, evaluate=False)) == r"\frac{1}{\frac{1}{100}}" + + p = Symbol('p', positive=True) + assert latex(exp(-p)*log(p)) == r"e^{- p} \log{\left(p \right)}" + + assert latex(Pow(Rational(2, 3), -1, evaluate=False)) == r'\frac{1}{\frac{2}{3}}' + assert latex(Pow(Rational(4, 3), -1, evaluate=False)) == r'\frac{1}{\frac{4}{3}}' + assert latex(Pow(Rational(-3, 4), -1, evaluate=False)) == r'\frac{1}{- \frac{3}{4}}' + assert latex(Pow(Rational(-4, 4), -1, evaluate=False)) == r'\frac{1}{-1}' + assert latex(Pow(Rational(1, 3), -1, evaluate=False)) == r'\frac{1}{\frac{1}{3}}' + assert latex(Pow(Rational(-1, 3), -1, evaluate=False)) == r'\frac{1}{- \frac{1}{3}}' + + +def test_latex_builtins(): + assert latex(True) == r"\text{True}" + assert latex(False) == r"\text{False}" + assert latex(None) == r"\text{None}" + assert latex(true) == r"\text{True}" + assert latex(false) == r'\text{False}' + + +def test_latex_SingularityFunction(): + assert latex(SingularityFunction(x, 4, 5)) == \ + r"{\left\langle x - 4 \right\rangle}^{5}" + assert latex(SingularityFunction(x, -3, 4)) == \ + r"{\left\langle x + 3 \right\rangle}^{4}" + assert latex(SingularityFunction(x, 0, 4)) == \ + r"{\left\langle x \right\rangle}^{4}" + assert latex(SingularityFunction(x, a, n)) == \ + r"{\left\langle - a + x \right\rangle}^{n}" + assert latex(SingularityFunction(x, 4, -2)) == \ + r"{\left\langle x - 4 \right\rangle}^{-2}" + assert latex(SingularityFunction(x, 4, -1)) == \ + r"{\left\langle x - 4 \right\rangle}^{-1}" + + assert latex(SingularityFunction(x, 4, 5)**3) == \ + r"{\left({\langle x - 4 \rangle}^{5}\right)}^{3}" + assert latex(SingularityFunction(x, -3, 4)**3) == \ + r"{\left({\langle x + 3 \rangle}^{4}\right)}^{3}" + assert latex(SingularityFunction(x, 0, 4)**3) == \ + r"{\left({\langle x \rangle}^{4}\right)}^{3}" + assert latex(SingularityFunction(x, a, n)**3) == \ + r"{\left({\langle - a + x \rangle}^{n}\right)}^{3}" + assert latex(SingularityFunction(x, 4, -2)**3) == \ + r"{\left({\langle x - 4 \rangle}^{-2}\right)}^{3}" + assert latex((SingularityFunction(x, 4, -1)**3)**3) == \ + r"{\left({\langle x - 4 \rangle}^{-1}\right)}^{9}" + + +def test_latex_cycle(): + assert latex(Cycle(1, 2, 4)) == r"\left( 1\; 2\; 4\right)" + assert latex(Cycle(1, 2)(4, 5, 6)) == \ + r"\left( 1\; 2\right)\left( 4\; 5\; 6\right)" + assert latex(Cycle()) == r"\left( \right)" + + +def test_latex_permutation(): + assert latex(Permutation(1, 2, 4)) == r"\left( 1\; 2\; 4\right)" + assert latex(Permutation(1, 2)(4, 5, 6)) == \ + r"\left( 1\; 2\right)\left( 4\; 5\; 6\right)" + assert latex(Permutation()) == r"\left( \right)" + assert latex(Permutation(2, 4)*Permutation(5)) == \ + r"\left( 2\; 4\right)\left( 5\right)" + assert latex(Permutation(5)) == r"\left( 5\right)" + + assert latex(Permutation(0, 1), perm_cyclic=False) == \ + r"\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}" + assert latex(Permutation(0, 1)(2, 3), perm_cyclic=False) == \ + r"\begin{pmatrix} 0 & 1 & 2 & 3 \\ 1 & 0 & 3 & 2 \end{pmatrix}" + assert latex(Permutation(), perm_cyclic=False) == \ + r"\left( \right)" + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert latex(Permutation(0, 1)(2, 3)) == \ + r"\begin{pmatrix} 0 & 1 & 2 & 3 \\ 1 & 0 & 3 & 2 \end{pmatrix}" + Permutation.print_cyclic = old_print_cyclic + +def test_latex_Float(): + assert latex(Float(1.0e100)) == r"1.0 \cdot 10^{100}" + assert latex(Float(1.0e-100)) == r"1.0 \cdot 10^{-100}" + assert latex(Float(1.0e-100), mul_symbol="times") == \ + r"1.0 \times 10^{-100}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=2) == \ + r"1.0 \cdot 10^{4}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=4) == \ + r"1.0 \cdot 10^{4}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=5) == \ + r"10000.0" + assert latex(Float('0.099999'), full_prec=True, min=-2, max=5) == \ + r"9.99990000000000 \cdot 10^{-2}" + + +def test_latex_vector_expressions(): + A = CoordSys3D('A') + + assert latex(Cross(A.i, A.j*A.x*3+A.k)) == \ + r"\mathbf{\hat{i}_{A}} \times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}} + \mathbf{\hat{k}_{A}}\right)" + assert latex(Cross(A.i, A.j)) == \ + r"\mathbf{\hat{i}_{A}} \times \mathbf{\hat{j}_{A}}" + assert latex(x*Cross(A.i, A.j)) == \ + r"x \left(\mathbf{\hat{i}_{A}} \times \mathbf{\hat{j}_{A}}\right)" + assert latex(Cross(x*A.i, A.j)) == \ + r'- \mathbf{\hat{j}_{A}} \times \left(\left(x\right)\mathbf{\hat{i}_{A}}\right)' + + assert latex(Curl(3*A.x*A.j)) == \ + r"\nabla\times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Curl(3*A.x*A.j+A.i)) == \ + r"\nabla\times \left(\mathbf{\hat{i}_{A}} + \left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Curl(3*x*A.x*A.j)) == \ + r"\nabla\times \left(\left(3 \mathbf{{x}_{A}} x\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(x*Curl(3*A.x*A.j)) == \ + r"x \left(\nabla\times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)\right)" + + assert latex(Divergence(3*A.x*A.j+A.i)) == \ + r"\nabla\cdot \left(\mathbf{\hat{i}_{A}} + \left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Divergence(3*A.x*A.j)) == \ + r"\nabla\cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(x*Divergence(3*A.x*A.j)) == \ + r"x \left(\nabla\cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)\right)" + + assert latex(Dot(A.i, A.j*A.x*3+A.k)) == \ + r"\mathbf{\hat{i}_{A}} \cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}} + \mathbf{\hat{k}_{A}}\right)" + assert latex(Dot(A.i, A.j)) == \ + r"\mathbf{\hat{i}_{A}} \cdot \mathbf{\hat{j}_{A}}" + assert latex(Dot(x*A.i, A.j)) == \ + r"\mathbf{\hat{j}_{A}} \cdot \left(\left(x\right)\mathbf{\hat{i}_{A}}\right)" + assert latex(x*Dot(A.i, A.j)) == \ + r"x \left(\mathbf{\hat{i}_{A}} \cdot \mathbf{\hat{j}_{A}}\right)" + + assert latex(Gradient(A.x)) == r"\nabla \mathbf{{x}_{A}}" + assert latex(Gradient(A.x + 3*A.y)) == \ + r"\nabla \left(\mathbf{{x}_{A}} + 3 \mathbf{{y}_{A}}\right)" + assert latex(x*Gradient(A.x)) == r"x \left(\nabla \mathbf{{x}_{A}}\right)" + assert latex(Gradient(x*A.x)) == r"\nabla \left(\mathbf{{x}_{A}} x\right)" + + assert latex(Laplacian(A.x)) == r"\Delta \mathbf{{x}_{A}}" + assert latex(Laplacian(A.x + 3*A.y)) == \ + r"\Delta \left(\mathbf{{x}_{A}} + 3 \mathbf{{y}_{A}}\right)" + assert latex(x*Laplacian(A.x)) == r"x \left(\Delta \mathbf{{x}_{A}}\right)" + assert latex(Laplacian(x*A.x)) == r"\Delta \left(\mathbf{{x}_{A}} x\right)" + +def test_latex_symbols(): + Gamma, lmbda, rho = symbols('Gamma, lambda, rho') + tau, Tau, TAU, taU = symbols('tau, Tau, TAU, taU') + assert latex(tau) == r"\tau" + assert latex(Tau) == r"\mathrm{T}" + assert latex(TAU) == r"\tau" + assert latex(taU) == r"\tau" + # Check that all capitalized greek letters are handled explicitly + capitalized_letters = {l.capitalize() for l in greek_letters_set} + assert len(capitalized_letters - set(tex_greek_dictionary.keys())) == 0 + assert latex(Gamma + lmbda) == r"\Gamma + \lambda" + assert latex(Gamma * lmbda) == r"\Gamma \lambda" + assert latex(Symbol('q1')) == r"q_{1}" + assert latex(Symbol('q21')) == r"q_{21}" + assert latex(Symbol('epsilon0')) == r"\epsilon_{0}" + assert latex(Symbol('omega1')) == r"\omega_{1}" + assert latex(Symbol('91')) == r"91" + assert latex(Symbol('alpha_new')) == r"\alpha_{new}" + assert latex(Symbol('C^orig')) == r"C^{orig}" + assert latex(Symbol('x^alpha')) == r"x^{\alpha}" + assert latex(Symbol('beta^alpha')) == r"\beta^{\alpha}" + assert latex(Symbol('e^Alpha')) == r"e^{\mathrm{A}}" + assert latex(Symbol('omega_alpha^beta')) == r"\omega^{\beta}_{\alpha}" + assert latex(Symbol('omega') ** Symbol('beta')) == r"\omega^{\beta}" + + +@XFAIL +def test_latex_symbols_failing(): + rho, mass, volume = symbols('rho, mass, volume') + assert latex( + volume * rho == mass) == r"\rho \mathrm{volume} = \mathrm{mass}" + assert latex(volume / mass * rho == 1) == \ + r"\rho \mathrm{volume} {\mathrm{mass}}^{(-1)} = 1" + assert latex(mass**3 * volume**3) == \ + r"{\mathrm{mass}}^{3} \cdot {\mathrm{volume}}^{3}" + + +@_both_exp_pow +def test_latex_functions(): + assert latex(exp(x)) == r"e^{x}" + assert latex(exp(1) + exp(2)) == r"e + e^{2}" + + f = Function('f') + assert latex(f(x)) == r'f{\left(x \right)}' + assert latex(f) == r'f' + + g = Function('g') + assert latex(g(x, y)) == r'g{\left(x,y \right)}' + assert latex(g) == r'g' + + h = Function('h') + assert latex(h(x, y, z)) == r'h{\left(x,y,z \right)}' + assert latex(h) == r'h' + + Li = Function('Li') + assert latex(Li) == r'\operatorname{Li}' + assert latex(Li(x)) == r'\operatorname{Li}{\left(x \right)}' + + mybeta = Function('beta') + # not to be confused with the beta function + assert latex(mybeta(x, y, z)) == r"\beta{\left(x,y,z \right)}" + assert latex(beta(x, y)) == r'\operatorname{B}\left(x, y\right)' + assert latex(beta(x, evaluate=False)) == r'\operatorname{B}\left(x, x\right)' + assert latex(beta(x, y)**2) == r'\operatorname{B}^{2}\left(x, y\right)' + assert latex(mybeta(x)) == r"\beta{\left(x \right)}" + assert latex(mybeta) == r"\beta" + + g = Function('gamma') + # not to be confused with the gamma function + assert latex(g(x, y, z)) == r"\gamma{\left(x,y,z \right)}" + assert latex(g(x)) == r"\gamma{\left(x \right)}" + assert latex(g) == r"\gamma" + + a_1 = Function('a_1') + assert latex(a_1) == r"a_{1}" + assert latex(a_1(x)) == r"a_{1}{\left(x \right)}" + assert latex(Function('a_1')) == r"a_{1}" + + # Issue #16925 + # multi letter function names + # > simple + assert latex(Function('ab')) == r"\operatorname{ab}" + assert latex(Function('ab1')) == r"\operatorname{ab}_{1}" + assert latex(Function('ab12')) == r"\operatorname{ab}_{12}" + assert latex(Function('ab_1')) == r"\operatorname{ab}_{1}" + assert latex(Function('ab_12')) == r"\operatorname{ab}_{12}" + assert latex(Function('ab_c')) == r"\operatorname{ab}_{c}" + assert latex(Function('ab_cd')) == r"\operatorname{ab}_{cd}" + # > with argument + assert latex(Function('ab')(Symbol('x'))) == r"\operatorname{ab}{\left(x \right)}" + assert latex(Function('ab1')(Symbol('x'))) == r"\operatorname{ab}_{1}{\left(x \right)}" + assert latex(Function('ab12')(Symbol('x'))) == r"\operatorname{ab}_{12}{\left(x \right)}" + assert latex(Function('ab_1')(Symbol('x'))) == r"\operatorname{ab}_{1}{\left(x \right)}" + assert latex(Function('ab_c')(Symbol('x'))) == r"\operatorname{ab}_{c}{\left(x \right)}" + assert latex(Function('ab_cd')(Symbol('x'))) == r"\operatorname{ab}_{cd}{\left(x \right)}" + + # > with power + # does not work on functions without brackets + + # > with argument and power combined + assert latex(Function('ab')()**2) == r"\operatorname{ab}^{2}{\left( \right)}" + assert latex(Function('ab1')()**2) == r"\operatorname{ab}_{1}^{2}{\left( \right)}" + assert latex(Function('ab12')()**2) == r"\operatorname{ab}_{12}^{2}{\left( \right)}" + assert latex(Function('ab_1')()**2) == r"\operatorname{ab}_{1}^{2}{\left( \right)}" + assert latex(Function('ab_12')()**2) == r"\operatorname{ab}_{12}^{2}{\left( \right)}" + assert latex(Function('ab')(Symbol('x'))**2) == r"\operatorname{ab}^{2}{\left(x \right)}" + assert latex(Function('ab1')(Symbol('x'))**2) == r"\operatorname{ab}_{1}^{2}{\left(x \right)}" + assert latex(Function('ab12')(Symbol('x'))**2) == r"\operatorname{ab}_{12}^{2}{\left(x \right)}" + assert latex(Function('ab_1')(Symbol('x'))**2) == r"\operatorname{ab}_{1}^{2}{\left(x \right)}" + assert latex(Function('ab_12')(Symbol('x'))**2) == \ + r"\operatorname{ab}_{12}^{2}{\left(x \right)}" + + # single letter function names + # > simple + assert latex(Function('a')) == r"a" + assert latex(Function('a1')) == r"a_{1}" + assert latex(Function('a12')) == r"a_{12}" + assert latex(Function('a_1')) == r"a_{1}" + assert latex(Function('a_12')) == r"a_{12}" + + # > with argument + assert latex(Function('a')()) == r"a{\left( \right)}" + assert latex(Function('a1')()) == r"a_{1}{\left( \right)}" + assert latex(Function('a12')()) == r"a_{12}{\left( \right)}" + assert latex(Function('a_1')()) == r"a_{1}{\left( \right)}" + assert latex(Function('a_12')()) == r"a_{12}{\left( \right)}" + + # > with power + # does not work on functions without brackets + + # > with argument and power combined + assert latex(Function('a')()**2) == r"a^{2}{\left( \right)}" + assert latex(Function('a1')()**2) == r"a_{1}^{2}{\left( \right)}" + assert latex(Function('a12')()**2) == r"a_{12}^{2}{\left( \right)}" + assert latex(Function('a_1')()**2) == r"a_{1}^{2}{\left( \right)}" + assert latex(Function('a_12')()**2) == r"a_{12}^{2}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**2) == r"a^{2}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**2) == r"a_{1}^{2}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**2) == r"a_{12}^{2}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**2) == r"a_{1}^{2}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**2) == r"a_{12}^{2}{\left(x \right)}" + + assert latex(Function('a')()**32) == r"a^{32}{\left( \right)}" + assert latex(Function('a1')()**32) == r"a_{1}^{32}{\left( \right)}" + assert latex(Function('a12')()**32) == r"a_{12}^{32}{\left( \right)}" + assert latex(Function('a_1')()**32) == r"a_{1}^{32}{\left( \right)}" + assert latex(Function('a_12')()**32) == r"a_{12}^{32}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**32) == r"a^{32}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**32) == r"a_{1}^{32}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**32) == r"a_{12}^{32}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**32) == r"a_{1}^{32}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**32) == r"a_{12}^{32}{\left(x \right)}" + + assert latex(Function('a')()**a) == r"a^{a}{\left( \right)}" + assert latex(Function('a1')()**a) == r"a_{1}^{a}{\left( \right)}" + assert latex(Function('a12')()**a) == r"a_{12}^{a}{\left( \right)}" + assert latex(Function('a_1')()**a) == r"a_{1}^{a}{\left( \right)}" + assert latex(Function('a_12')()**a) == r"a_{12}^{a}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**a) == r"a^{a}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**a) == r"a_{1}^{a}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**a) == r"a_{12}^{a}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**a) == r"a_{1}^{a}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**a) == r"a_{12}^{a}{\left(x \right)}" + + ab = Symbol('ab') + assert latex(Function('a')()**ab) == r"a^{ab}{\left( \right)}" + assert latex(Function('a1')()**ab) == r"a_{1}^{ab}{\left( \right)}" + assert latex(Function('a12')()**ab) == r"a_{12}^{ab}{\left( \right)}" + assert latex(Function('a_1')()**ab) == r"a_{1}^{ab}{\left( \right)}" + assert latex(Function('a_12')()**ab) == r"a_{12}^{ab}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**ab) == r"a^{ab}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**ab) == r"a_{1}^{ab}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**ab) == r"a_{12}^{ab}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**ab) == r"a_{1}^{ab}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**ab) == r"a_{12}^{ab}{\left(x \right)}" + + assert latex(Function('a^12')(x)) == R"a^{12}{\left(x \right)}" + assert latex(Function('a^12')(x) ** ab) == R"\left(a^{12}\right)^{ab}{\left(x \right)}" + assert latex(Function('a__12')(x)) == R"a^{12}{\left(x \right)}" + assert latex(Function('a__12')(x) ** ab) == R"\left(a^{12}\right)^{ab}{\left(x \right)}" + assert latex(Function('a_1__1_2')(x)) == R"a^{1}_{1 2}{\left(x \right)}" + + # issue 5868 + omega1 = Function('omega1') + assert latex(omega1) == r"\omega_{1}" + assert latex(omega1(x)) == r"\omega_{1}{\left(x \right)}" + + assert latex(sin(x)) == r"\sin{\left(x \right)}" + assert latex(sin(x), fold_func_brackets=True) == r"\sin {x}" + assert latex(sin(2*x**2), fold_func_brackets=True) == \ + r"\sin {2 x^{2}}" + assert latex(sin(x**2), fold_func_brackets=True) == \ + r"\sin {x^{2}}" + + assert latex(asin(x)**2) == r"\operatorname{asin}^{2}{\left(x \right)}" + assert latex(asin(x)**2, inv_trig_style="full") == \ + r"\arcsin^{2}{\left(x \right)}" + assert latex(asin(x)**2, inv_trig_style="power") == \ + r"\sin^{-1}{\left(x \right)}^{2}" + assert latex(asin(x**2), inv_trig_style="power", + fold_func_brackets=True) == \ + r"\sin^{-1} {x^{2}}" + assert latex(acsc(x), inv_trig_style="full") == \ + r"\operatorname{arccsc}{\left(x \right)}" + assert latex(asinh(x), inv_trig_style="full") == \ + r"\operatorname{arsinh}{\left(x \right)}" + + assert latex(factorial(k)) == r"k!" + assert latex(factorial(-k)) == r"\left(- k\right)!" + assert latex(factorial(k)**2) == r"k!^{2}" + + assert latex(subfactorial(k)) == r"!k" + assert latex(subfactorial(-k)) == r"!\left(- k\right)" + assert latex(subfactorial(k)**2) == r"\left(!k\right)^{2}" + + assert latex(factorial2(k)) == r"k!!" + assert latex(factorial2(-k)) == r"\left(- k\right)!!" + assert latex(factorial2(k)**2) == r"k!!^{2}" + + assert latex(binomial(2, k)) == r"{\binom{2}{k}}" + assert latex(binomial(2, k)**2) == r"{\binom{2}{k}}^{2}" + + assert latex(FallingFactorial(3, k)) == r"{\left(3\right)}_{k}" + assert latex(RisingFactorial(3, k)) == r"{3}^{\left(k\right)}" + + assert latex(floor(x)) == r"\left\lfloor{x}\right\rfloor" + assert latex(ceiling(x)) == r"\left\lceil{x}\right\rceil" + assert latex(frac(x)) == r"\operatorname{frac}{\left(x\right)}" + assert latex(floor(x)**2) == r"\left\lfloor{x}\right\rfloor^{2}" + assert latex(ceiling(x)**2) == r"\left\lceil{x}\right\rceil^{2}" + assert latex(frac(x)**2) == r"\operatorname{frac}{\left(x\right)}^{2}" + + assert latex(Min(x, 2, x**3)) == r"\min\left(2, x, x^{3}\right)" + assert latex(Min(x, y)**2) == r"\min\left(x, y\right)^{2}" + assert latex(Max(x, 2, x**3)) == r"\max\left(2, x, x^{3}\right)" + assert latex(Max(x, y)**2) == r"\max\left(x, y\right)^{2}" + assert latex(Abs(x)) == r"\left|{x}\right|" + assert latex(Abs(x)**2) == r"\left|{x}\right|^{2}" + assert latex(re(x)) == r"\operatorname{re}{\left(x\right)}" + assert latex(re(x + y)) == \ + r"\operatorname{re}{\left(x\right)} + \operatorname{re}{\left(y\right)}" + assert latex(im(x)) == r"\operatorname{im}{\left(x\right)}" + assert latex(conjugate(x)) == r"\overline{x}" + assert latex(conjugate(x)**2) == r"\overline{x}^{2}" + assert latex(conjugate(x**2)) == r"\overline{x}^{2}" + assert latex(gamma(x)) == r"\Gamma\left(x\right)" + w = Wild('w') + assert latex(gamma(w)) == r"\Gamma\left(w\right)" + assert latex(Order(x)) == r"O\left(x\right)" + assert latex(Order(x, x)) == r"O\left(x\right)" + assert latex(Order(x, (x, 0))) == r"O\left(x\right)" + assert latex(Order(x, (x, oo))) == r"O\left(x; x\rightarrow \infty\right)" + assert latex(Order(x - y, (x, y))) == \ + r"O\left(x - y; x\rightarrow y\right)" + assert latex(Order(x, x, y)) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( 0, \ 0\right)\right)" + assert latex(Order(x, x, y)) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( 0, \ 0\right)\right)" + assert latex(Order(x, (x, oo), (y, oo))) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( \infty, \ \infty\right)\right)" + assert latex(lowergamma(x, y)) == r'\gamma\left(x, y\right)' + assert latex(lowergamma(x, y)**2) == r'\gamma^{2}\left(x, y\right)' + assert latex(uppergamma(x, y)) == r'\Gamma\left(x, y\right)' + assert latex(uppergamma(x, y)**2) == r'\Gamma^{2}\left(x, y\right)' + + assert latex(cot(x)) == r'\cot{\left(x \right)}' + assert latex(coth(x)) == r'\coth{\left(x \right)}' + assert latex(re(x)) == r'\operatorname{re}{\left(x\right)}' + assert latex(im(x)) == r'\operatorname{im}{\left(x\right)}' + assert latex(root(x, y)) == r'x^{\frac{1}{y}}' + assert latex(arg(x)) == r'\arg{\left(x \right)}' + + assert latex(zeta(x)) == r"\zeta\left(x\right)" + assert latex(zeta(x)**2) == r"\zeta^{2}\left(x\right)" + assert latex(zeta(x, y)) == r"\zeta\left(x, y\right)" + assert latex(zeta(x, y)**2) == r"\zeta^{2}\left(x, y\right)" + assert latex(dirichlet_eta(x)) == r"\eta\left(x\right)" + assert latex(dirichlet_eta(x)**2) == r"\eta^{2}\left(x\right)" + assert latex(polylog(x, y)) == r"\operatorname{Li}_{x}\left(y\right)" + assert latex( + polylog(x, y)**2) == r"\operatorname{Li}_{x}^{2}\left(y\right)" + assert latex(lerchphi(x, y, n)) == r"\Phi\left(x, y, n\right)" + assert latex(lerchphi(x, y, n)**2) == r"\Phi^{2}\left(x, y, n\right)" + assert latex(stieltjes(x)) == r"\gamma_{x}" + assert latex(stieltjes(x)**2) == r"\gamma_{x}^{2}" + assert latex(stieltjes(x, y)) == r"\gamma_{x}\left(y\right)" + assert latex(stieltjes(x, y)**2) == r"\gamma_{x}\left(y\right)^{2}" + + assert latex(elliptic_k(z)) == r"K\left(z\right)" + assert latex(elliptic_k(z)**2) == r"K^{2}\left(z\right)" + assert latex(elliptic_f(x, y)) == r"F\left(x\middle| y\right)" + assert latex(elliptic_f(x, y)**2) == r"F^{2}\left(x\middle| y\right)" + assert latex(elliptic_e(x, y)) == r"E\left(x\middle| y\right)" + assert latex(elliptic_e(x, y)**2) == r"E^{2}\left(x\middle| y\right)" + assert latex(elliptic_e(z)) == r"E\left(z\right)" + assert latex(elliptic_e(z)**2) == r"E^{2}\left(z\right)" + assert latex(elliptic_pi(x, y, z)) == r"\Pi\left(x; y\middle| z\right)" + assert latex(elliptic_pi(x, y, z)**2) == \ + r"\Pi^{2}\left(x; y\middle| z\right)" + assert latex(elliptic_pi(x, y)) == r"\Pi\left(x\middle| y\right)" + assert latex(elliptic_pi(x, y)**2) == r"\Pi^{2}\left(x\middle| y\right)" + + assert latex(Ei(x)) == r'\operatorname{Ei}{\left(x \right)}' + assert latex(Ei(x)**2) == r'\operatorname{Ei}^{2}{\left(x \right)}' + assert latex(expint(x, y)) == r'\operatorname{E}_{x}\left(y\right)' + assert latex(expint(x, y)**2) == r'\operatorname{E}_{x}^{2}\left(y\right)' + assert latex(Shi(x)**2) == r'\operatorname{Shi}^{2}{\left(x \right)}' + assert latex(Si(x)**2) == r'\operatorname{Si}^{2}{\left(x \right)}' + assert latex(Ci(x)**2) == r'\operatorname{Ci}^{2}{\left(x \right)}' + assert latex(Chi(x)**2) == r'\operatorname{Chi}^{2}\left(x\right)' + assert latex(Chi(x)) == r'\operatorname{Chi}\left(x\right)' + assert latex(jacobi(n, a, b, x)) == \ + r'P_{n}^{\left(a,b\right)}\left(x\right)' + assert latex(jacobi(n, a, b, x)**2) == \ + r'\left(P_{n}^{\left(a,b\right)}\left(x\right)\right)^{2}' + assert latex(gegenbauer(n, a, x)) == \ + r'C_{n}^{\left(a\right)}\left(x\right)' + assert latex(gegenbauer(n, a, x)**2) == \ + r'\left(C_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(chebyshevt(n, x)) == r'T_{n}\left(x\right)' + assert latex(chebyshevt(n, x)**2) == \ + r'\left(T_{n}\left(x\right)\right)^{2}' + assert latex(chebyshevu(n, x)) == r'U_{n}\left(x\right)' + assert latex(chebyshevu(n, x)**2) == \ + r'\left(U_{n}\left(x\right)\right)^{2}' + assert latex(legendre(n, x)) == r'P_{n}\left(x\right)' + assert latex(legendre(n, x)**2) == r'\left(P_{n}\left(x\right)\right)^{2}' + assert latex(assoc_legendre(n, a, x)) == \ + r'P_{n}^{\left(a\right)}\left(x\right)' + assert latex(assoc_legendre(n, a, x)**2) == \ + r'\left(P_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(laguerre(n, x)) == r'L_{n}\left(x\right)' + assert latex(laguerre(n, x)**2) == r'\left(L_{n}\left(x\right)\right)^{2}' + assert latex(assoc_laguerre(n, a, x)) == \ + r'L_{n}^{\left(a\right)}\left(x\right)' + assert latex(assoc_laguerre(n, a, x)**2) == \ + r'\left(L_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(hermite(n, x)) == r'H_{n}\left(x\right)' + assert latex(hermite(n, x)**2) == r'\left(H_{n}\left(x\right)\right)^{2}' + + theta = Symbol("theta", real=True) + phi = Symbol("phi", real=True) + assert latex(Ynm(n, m, theta, phi)) == r'Y_{n}^{m}\left(\theta,\phi\right)' + assert latex(Ynm(n, m, theta, phi)**3) == \ + r'\left(Y_{n}^{m}\left(\theta,\phi\right)\right)^{3}' + assert latex(Znm(n, m, theta, phi)) == r'Z_{n}^{m}\left(\theta,\phi\right)' + assert latex(Znm(n, m, theta, phi)**3) == \ + r'\left(Z_{n}^{m}\left(\theta,\phi\right)\right)^{3}' + + # Test latex printing of function names with "_" + assert latex(polar_lift(0)) == \ + r"\operatorname{polar\_lift}{\left(0 \right)}" + assert latex(polar_lift(0)**3) == \ + r"\operatorname{polar\_lift}^{3}{\left(0 \right)}" + + assert latex(totient(n)) == r'\phi\left(n\right)' + assert latex(totient(n) ** 2) == r'\left(\phi\left(n\right)\right)^{2}' + + assert latex(reduced_totient(n)) == r'\lambda\left(n\right)' + assert latex(reduced_totient(n) ** 2) == \ + r'\left(\lambda\left(n\right)\right)^{2}' + + assert latex(divisor_sigma(x)) == r"\sigma\left(x\right)" + assert latex(divisor_sigma(x)**2) == r"\sigma^{2}\left(x\right)" + assert latex(divisor_sigma(x, y)) == r"\sigma_y\left(x\right)" + assert latex(divisor_sigma(x, y)**2) == r"\sigma^{2}_y\left(x\right)" + + assert latex(udivisor_sigma(x)) == r"\sigma^*\left(x\right)" + assert latex(udivisor_sigma(x)**2) == r"\sigma^*^{2}\left(x\right)" + assert latex(udivisor_sigma(x, y)) == r"\sigma^*_y\left(x\right)" + assert latex(udivisor_sigma(x, y)**2) == r"\sigma^*^{2}_y\left(x\right)" + + assert latex(primenu(n)) == r'\nu\left(n\right)' + assert latex(primenu(n) ** 2) == r'\left(\nu\left(n\right)\right)^{2}' + + assert latex(primeomega(n)) == r'\Omega\left(n\right)' + assert latex(primeomega(n) ** 2) == \ + r'\left(\Omega\left(n\right)\right)^{2}' + + assert latex(LambertW(n)) == r'W\left(n\right)' + assert latex(LambertW(n, -1)) == r'W_{-1}\left(n\right)' + assert latex(LambertW(n, k)) == r'W_{k}\left(n\right)' + assert latex(LambertW(n) * LambertW(n)) == r"W^{2}\left(n\right)" + assert latex(Pow(LambertW(n), 2)) == r"W^{2}\left(n\right)" + assert latex(LambertW(n)**k) == r"W^{k}\left(n\right)" + assert latex(LambertW(n, k)**p) == r"W^{p}_{k}\left(n\right)" + + assert latex(Mod(x, 7)) == r'x \bmod 7' + assert latex(Mod(x + 1, 7)) == r'\left(x + 1\right) \bmod 7' + assert latex(Mod(7, x + 1)) == r'7 \bmod \left(x + 1\right)' + assert latex(Mod(2 * x, 7)) == r'2 x \bmod 7' + assert latex(Mod(7, 2 * x)) == r'7 \bmod 2 x' + assert latex(Mod(x, 7) + 1) == r'\left(x \bmod 7\right) + 1' + assert latex(2 * Mod(x, 7)) == r'2 \left(x \bmod 7\right)' + assert latex(Mod(7, 2 * x)**n) == r'\left(7 \bmod 2 x\right)^{n}' + + # some unknown function name should get rendered with \operatorname + fjlkd = Function('fjlkd') + assert latex(fjlkd(x)) == r'\operatorname{fjlkd}{\left(x \right)}' + # even when it is referred to without an argument + assert latex(fjlkd) == r'\operatorname{fjlkd}' + + +# test that notation passes to subclasses of the same name only +def test_function_subclass_different_name(): + class mygamma(gamma): + pass + assert latex(mygamma) == r"\operatorname{mygamma}" + assert latex(mygamma(x)) == r"\operatorname{mygamma}{\left(x \right)}" + + +def test_hyper_printing(): + from sympy.abc import x, z + + assert latex(meijerg(Tuple(pi, pi, x), Tuple(1), + (0, 1), Tuple(1, 2, 3/pi), z)) == \ + r'{G_{4, 5}^{2, 3}\left(\begin{matrix} \pi, \pi, x & 1 \\0, 1 & 1, 2, '\ + r'\frac{3}{\pi} \end{matrix} \middle| {z} \right)}' + assert latex(meijerg(Tuple(), Tuple(1), (0,), Tuple(), z)) == \ + r'{G_{1, 1}^{1, 0}\left(\begin{matrix} & 1 \\0 & \end{matrix} \middle| {z} \right)}' + assert latex(hyper((x, 2), (3,), z)) == \ + r'{{}_{2}F_{1}\left(\begin{matrix} 2, x ' \ + r'\\ 3 \end{matrix}\middle| {z} \right)}' + assert latex(hyper(Tuple(), Tuple(1), z)) == \ + r'{{}_{0}F_{1}\left(\begin{matrix} ' \ + r'\\ 1 \end{matrix}\middle| {z} \right)}' + + +def test_latex_bessel(): + from sympy.functions.special.bessel import (besselj, bessely, besseli, + besselk, hankel1, hankel2, + jn, yn, hn1, hn2) + from sympy.abc import z + assert latex(besselj(n, z**2)**k) == r'J^{k}_{n}\left(z^{2}\right)' + assert latex(bessely(n, z)) == r'Y_{n}\left(z\right)' + assert latex(besseli(n, z)) == r'I_{n}\left(z\right)' + assert latex(besselk(n, z)) == r'K_{n}\left(z\right)' + assert latex(hankel1(n, z**2)**2) == \ + r'\left(H^{(1)}_{n}\left(z^{2}\right)\right)^{2}' + assert latex(hankel2(n, z)) == r'H^{(2)}_{n}\left(z\right)' + assert latex(jn(n, z)) == r'j_{n}\left(z\right)' + assert latex(yn(n, z)) == r'y_{n}\left(z\right)' + assert latex(hn1(n, z)) == r'h^{(1)}_{n}\left(z\right)' + assert latex(hn2(n, z)) == r'h^{(2)}_{n}\left(z\right)' + + +def test_latex_fresnel(): + from sympy.functions.special.error_functions import (fresnels, fresnelc) + from sympy.abc import z + assert latex(fresnels(z)) == r'S\left(z\right)' + assert latex(fresnelc(z)) == r'C\left(z\right)' + assert latex(fresnels(z)**2) == r'S^{2}\left(z\right)' + assert latex(fresnelc(z)**2) == r'C^{2}\left(z\right)' + + +def test_latex_brackets(): + assert latex((-1)**x) == r"\left(-1\right)^{x}" + + +def test_latex_indexed(): + Psi_symbol = Symbol('Psi_0', complex=True, real=False) + Psi_indexed = IndexedBase(Symbol('Psi', complex=True, real=False)) + symbol_latex = latex(Psi_symbol * conjugate(Psi_symbol)) + indexed_latex = latex(Psi_indexed[0] * conjugate(Psi_indexed[0])) + # \\overline{{\\Psi}_{0}} {\\Psi}_{0} vs. \\Psi_{0} \\overline{\\Psi_{0}} + assert symbol_latex == r'\Psi_{0} \overline{\Psi_{0}}' + assert indexed_latex == r'\overline{{\Psi}_{0}} {\Psi}_{0}' + + # Symbol('gamma') gives r'\gamma' + interval = '\\mathrel{..}\\nobreak ' + assert latex(Indexed('x1', Symbol('i'))) == r'{x_{1}}_{i}' + assert latex(Indexed('x2', Idx('i'))) == r'{x_{2}}_{i}' + assert latex(Indexed('x3', Idx('i', Symbol('N')))) == r'{x_{3}}_{{i}_{0'+interval+'N - 1}}' + assert latex(Indexed('x3', Idx('i', Symbol('N')+1))) == r'{x_{3}}_{{i}_{0'+interval+'N}}' + assert latex(Indexed('x4', Idx('i', (Symbol('a'),Symbol('b'))))) == r'{x_{4}}_{{i}_{a'+interval+'b}}' + assert latex(IndexedBase('gamma')) == r'\gamma' + assert latex(IndexedBase('a b')) == r'a b' + assert latex(IndexedBase('a_b')) == r'a_{b}' + + +def test_latex_derivatives(): + # regular "d" for ordinary derivatives + assert latex(diff(x**3, x, evaluate=False)) == \ + r"\frac{d}{d x} x^{3}" + assert latex(diff(sin(x) + x**2, x, evaluate=False)) == \ + r"\frac{d}{d x} \left(x^{2} + \sin{\left(x \right)}\right)" + assert latex(diff(diff(sin(x) + x**2, x, evaluate=False), evaluate=False))\ + == \ + r"\frac{d^{2}}{d x^{2}} \left(x^{2} + \sin{\left(x \right)}\right)" + assert latex(diff(diff(diff(sin(x) + x**2, x, evaluate=False), evaluate=False), evaluate=False)) == \ + r"\frac{d^{3}}{d x^{3}} \left(x^{2} + \sin{\left(x \right)}\right)" + + # \partial for partial derivatives + assert latex(diff(sin(x * y), x, evaluate=False)) == \ + r"\frac{\partial}{\partial x} \sin{\left(x y \right)}" + assert latex(diff(sin(x * y) + x**2, x, evaluate=False)) == \ + r"\frac{\partial}{\partial x} \left(x^{2} + \sin{\left(x y \right)}\right)" + assert latex(diff(diff(sin(x*y) + x**2, x, evaluate=False), x, evaluate=False)) == \ + r"\frac{\partial^{2}}{\partial x^{2}} \left(x^{2} + \sin{\left(x y \right)}\right)" + assert latex(diff(diff(diff(sin(x*y) + x**2, x, evaluate=False), x, evaluate=False), x, evaluate=False)) == \ + r"\frac{\partial^{3}}{\partial x^{3}} \left(x^{2} + \sin{\left(x y \right)}\right)" + + # mixed partial derivatives + f = Function("f") + assert latex(diff(diff(f(x, y), x, evaluate=False), y, evaluate=False)) == \ + r"\frac{\partial^{2}}{\partial y\partial x} " + latex(f(x, y)) + + assert latex(diff(diff(diff(f(x, y), x, evaluate=False), x, evaluate=False), y, evaluate=False)) == \ + r"\frac{\partial^{3}}{\partial y\partial x^{2}} " + latex(f(x, y)) + + # for negative nested Derivative + assert latex(diff(-diff(y**2,x,evaluate=False),x,evaluate=False)) == r'\frac{d}{d x} \left(- \frac{d}{d x} y^{2}\right)' + assert latex(diff(diff(-diff(diff(y,x,evaluate=False),x,evaluate=False),x,evaluate=False),x,evaluate=False)) == \ + r'\frac{d^{2}}{d x^{2}} \left(- \frac{d^{2}}{d x^{2}} y\right)' + + # use ordinary d when one of the variables has been integrated out + assert latex(diff(Integral(exp(-x*y), (x, 0, oo)), y, evaluate=False)) == \ + r"\frac{d}{d y} \int\limits_{0}^{\infty} e^{- x y}\, dx" + + # Derivative wrapped in power: + assert latex(diff(x, x, evaluate=False)**2) == \ + r"\left(\frac{d}{d x} x\right)^{2}" + + assert latex(diff(f(x), x)**2) == \ + r"\left(\frac{d}{d x} f{\left(x \right)}\right)^{2}" + + assert latex(diff(f(x), (x, n))) == \ + r"\frac{d^{n}}{d x^{n}} f{\left(x \right)}" + + x1 = Symbol('x1') + x2 = Symbol('x2') + assert latex(diff(f(x1, x2), x1)) == r'\frac{\partial}{\partial x_{1}} f{\left(x_{1},x_{2} \right)}' + + n1 = Symbol('n1') + assert latex(diff(f(x), (x, n1))) == r'\frac{d^{n_{1}}}{d x^{n_{1}}} f{\left(x \right)}' + + n2 = Symbol('n2') + assert latex(diff(f(x), (x, Max(n1, n2)))) == \ + r'\frac{d^{\max\left(n_{1}, n_{2}\right)}}{d x^{\max\left(n_{1}, n_{2}\right)}} f{\left(x \right)}' + + # set diff operator + assert latex(diff(f(x), x), diff_operator="rd") == r'\frac{\mathrm{d}}{\mathrm{d} x} f{\left(x \right)}' + + +def test_latex_subs(): + assert latex(Subs(x*y, (x, y), (1, 2))) == r'\left. x y \right|_{\substack{ x=1\\ y=2 }}' + + +def test_latex_integrals(): + assert latex(Integral(log(x), x)) == r"\int \log{\left(x \right)}\, dx" + assert latex(Integral(x**2, (x, 0, 1))) == \ + r"\int\limits_{0}^{1} x^{2}\, dx" + assert latex(Integral(x**2, (x, 10, 20))) == \ + r"\int\limits_{10}^{20} x^{2}\, dx" + assert latex(Integral(y*x**2, (x, 0, 1), y)) == \ + r"\int\int\limits_{0}^{1} x^{2} y\, dx\, dy" + assert latex(Integral(y*x**2, (x, 0, 1), y), mode='equation*') == \ + r"\begin{equation*}\int\int\limits_{0}^{1} x^{2} y\, dx\, dy\end{equation*}" + assert latex(Integral(y*x**2, (x, 0, 1), y), mode='equation*', itex=True) \ + == r"$$\int\int_{0}^{1} x^{2} y\, dx\, dy$$" + assert latex(Integral(x, (x, 0))) == r"\int\limits^{0} x\, dx" + assert latex(Integral(x*y, x, y)) == r"\iint x y\, dx\, dy" + assert latex(Integral(x*y*z, x, y, z)) == r"\iiint x y z\, dx\, dy\, dz" + assert latex(Integral(x*y*z*t, x, y, z, t)) == \ + r"\iiiint t x y z\, dx\, dy\, dz\, dt" + assert latex(Integral(x, x, x, x, x, x, x)) == \ + r"\int\int\int\int\int\int x\, dx\, dx\, dx\, dx\, dx\, dx" + assert latex(Integral(x, x, y, (z, 0, 1))) == \ + r"\int\limits_{0}^{1}\int\int x\, dx\, dy\, dz" + + # for negative nested Integral + assert latex(Integral(-Integral(y**2,x),x)) == \ + r'\int \left(- \int y^{2}\, dx\right)\, dx' + assert latex(Integral(-Integral(-Integral(y,x),x),x)) == \ + r'\int \left(- \int \left(- \int y\, dx\right)\, dx\right)\, dx' + + # fix issue #10806 + assert latex(Integral(z, z)**2) == r"\left(\int z\, dz\right)^{2}" + assert latex(Integral(x + z, z)) == r"\int \left(x + z\right)\, dz" + assert latex(Integral(x+z/2, z)) == \ + r"\int \left(x + \frac{z}{2}\right)\, dz" + assert latex(Integral(x**y, z)) == r"\int x^{y}\, dz" + + # set diff operator + assert latex(Integral(x, x), diff_operator="rd") == r'\int x\, \mathrm{d}x' + assert latex(Integral(x, (x, 0, 1)), diff_operator="rd") == r'\int\limits_{0}^{1} x\, \mathrm{d}x' + + +def test_latex_sets(): + for s in (frozenset, set): + assert latex(s([x*y, x**2])) == r"\left\{x^{2}, x y\right\}" + assert latex(s(range(1, 6))) == r"\left\{1, 2, 3, 4, 5\right\}" + assert latex(s(range(1, 13))) == \ + r"\left\{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12\right\}" + + s = FiniteSet + assert latex(s(*[x*y, x**2])) == r"\left\{x^{2}, x y\right\}" + assert latex(s(*range(1, 6))) == r"\left\{1, 2, 3, 4, 5\right\}" + assert latex(s(*range(1, 13))) == \ + r"\left\{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12\right\}" + + +def test_latex_SetExpr(): + iv = Interval(1, 3) + se = SetExpr(iv) + assert latex(se) == r"SetExpr\left(\left[1, 3\right]\right)" + + +def test_latex_Range(): + assert latex(Range(1, 51)) == r'\left\{1, 2, \ldots, 50\right\}' + assert latex(Range(1, 4)) == r'\left\{1, 2, 3\right\}' + assert latex(Range(0, 3, 1)) == r'\left\{0, 1, 2\right\}' + assert latex(Range(0, 30, 1)) == r'\left\{0, 1, \ldots, 29\right\}' + assert latex(Range(30, 1, -1)) == r'\left\{30, 29, \ldots, 2\right\}' + assert latex(Range(0, oo, 2)) == r'\left\{0, 2, \ldots\right\}' + assert latex(Range(oo, -2, -2)) == r'\left\{\ldots, 2, 0\right\}' + assert latex(Range(-2, -oo, -1)) == r'\left\{-2, -3, \ldots\right\}' + assert latex(Range(-oo, oo)) == r'\left\{\ldots, -1, 0, 1, \ldots\right\}' + assert latex(Range(oo, -oo, -1)) == r'\left\{\ldots, 1, 0, -1, \ldots\right\}' + + a, b, c = symbols('a:c') + assert latex(Range(a, b, c)) == r'\text{Range}\left(a, b, c\right)' + assert latex(Range(a, 10, 1)) == r'\text{Range}\left(a, 10\right)' + assert latex(Range(0, b, 1)) == r'\text{Range}\left(b\right)' + assert latex(Range(0, 10, c)) == r'\text{Range}\left(0, 10, c\right)' + + i = Symbol('i', integer=True) + n = Symbol('n', negative=True, integer=True) + p = Symbol('p', positive=True, integer=True) + + assert latex(Range(i, i + 3)) == r'\left\{i, i + 1, i + 2\right\}' + assert latex(Range(-oo, n, 2)) == r'\left\{\ldots, n - 4, n - 2\right\}' + assert latex(Range(p, oo)) == r'\left\{p, p + 1, \ldots\right\}' + # The following will work if __iter__ is improved + # assert latex(Range(-3, p + 7)) == r'\left\{-3, -2, \ldots, p + 6\right\}' + # Must have integer assumptions + assert latex(Range(a, a + 3)) == r'\text{Range}\left(a, a + 3\right)' + + +def test_latex_sequences(): + s1 = SeqFormula(a**2, (0, oo)) + s2 = SeqPer((1, 2)) + + latex_str = r'\left[0, 1, 4, 9, \ldots\right]' + assert latex(s1) == latex_str + + latex_str = r'\left[1, 2, 1, 2, \ldots\right]' + assert latex(s2) == latex_str + + s3 = SeqFormula(a**2, (0, 2)) + s4 = SeqPer((1, 2), (0, 2)) + + latex_str = r'\left[0, 1, 4\right]' + assert latex(s3) == latex_str + + latex_str = r'\left[1, 2, 1\right]' + assert latex(s4) == latex_str + + s5 = SeqFormula(a**2, (-oo, 0)) + s6 = SeqPer((1, 2), (-oo, 0)) + + latex_str = r'\left[\ldots, 9, 4, 1, 0\right]' + assert latex(s5) == latex_str + + latex_str = r'\left[\ldots, 2, 1, 2, 1\right]' + assert latex(s6) == latex_str + + latex_str = r'\left[1, 3, 5, 11, \ldots\right]' + assert latex(SeqAdd(s1, s2)) == latex_str + + latex_str = r'\left[1, 3, 5\right]' + assert latex(SeqAdd(s3, s4)) == latex_str + + latex_str = r'\left[\ldots, 11, 5, 3, 1\right]' + assert latex(SeqAdd(s5, s6)) == latex_str + + latex_str = r'\left[0, 2, 4, 18, \ldots\right]' + assert latex(SeqMul(s1, s2)) == latex_str + + latex_str = r'\left[0, 2, 4\right]' + assert latex(SeqMul(s3, s4)) == latex_str + + latex_str = r'\left[\ldots, 18, 4, 2, 0\right]' + assert latex(SeqMul(s5, s6)) == latex_str + + # Sequences with symbolic limits, issue 12629 + s7 = SeqFormula(a**2, (a, 0, x)) + latex_str = r'\left\{a^{2}\right\}_{a=0}^{x}' + assert latex(s7) == latex_str + + b = Symbol('b') + s8 = SeqFormula(b*a**2, (a, 0, 2)) + latex_str = r'\left[0, b, 4 b\right]' + assert latex(s8) == latex_str + + +def test_latex_FourierSeries(): + latex_str = \ + r'2 \sin{\left(x \right)} - \sin{\left(2 x \right)} + \frac{2 \sin{\left(3 x \right)}}{3} + \ldots' + assert latex(fourier_series(x, (x, -pi, pi))) == latex_str + + +def test_latex_FormalPowerSeries(): + latex_str = r'\sum_{k=1}^{\infty} - \frac{\left(-1\right)^{- k} x^{k}}{k}' + assert latex(fps(log(1 + x))) == latex_str + + +def test_latex_intervals(): + a = Symbol('a', real=True) + assert latex(Interval(0, 0)) == r"\left\{0\right\}" + assert latex(Interval(0, a)) == r"\left[0, a\right]" + assert latex(Interval(0, a, False, False)) == r"\left[0, a\right]" + assert latex(Interval(0, a, True, False)) == r"\left(0, a\right]" + assert latex(Interval(0, a, False, True)) == r"\left[0, a\right)" + assert latex(Interval(0, a, True, True)) == r"\left(0, a\right)" + + +def test_latex_AccumuBounds(): + a = Symbol('a', real=True) + assert latex(AccumBounds(0, 1)) == r"\left\langle 0, 1\right\rangle" + assert latex(AccumBounds(0, a)) == r"\left\langle 0, a\right\rangle" + assert latex(AccumBounds(a + 1, a + 2)) == \ + r"\left\langle a + 1, a + 2\right\rangle" + + +def test_latex_emptyset(): + assert latex(S.EmptySet) == r"\emptyset" + + +def test_latex_universalset(): + assert latex(S.UniversalSet) == r"\mathbb{U}" + + +def test_latex_commutator(): + A = Operator('A') + B = Operator('B') + comm = Commutator(B, A) + assert latex(comm.doit()) == r"- (A B - B A)" + + +def test_latex_union(): + assert latex(Union(Interval(0, 1), Interval(2, 3))) == \ + r"\left[0, 1\right] \cup \left[2, 3\right]" + assert latex(Union(Interval(1, 1), Interval(2, 2), Interval(3, 4))) == \ + r"\left\{1, 2\right\} \cup \left[3, 4\right]" + + +def test_latex_intersection(): + assert latex(Intersection(Interval(0, 1), Interval(x, y))) == \ + r"\left[0, 1\right] \cap \left[x, y\right]" + + +def test_latex_symmetric_difference(): + assert latex(SymmetricDifference(Interval(2, 5), Interval(4, 7), + evaluate=False)) == \ + r'\left[2, 5\right] \triangle \left[4, 7\right]' + + +def test_latex_Complement(): + assert latex(Complement(S.Reals, S.Naturals)) == \ + r"\mathbb{R} \setminus \mathbb{N}" + + +def test_latex_productset(): + line = Interval(0, 1) + bigline = Interval(0, 10) + fset = FiniteSet(1, 2, 3) + assert latex(line**2) == r"%s^{2}" % latex(line) + assert latex(line**10) == r"%s^{10}" % latex(line) + assert latex((line * bigline * fset).flatten()) == r"%s \times %s \times %s" % ( + latex(line), latex(bigline), latex(fset)) + + +def test_latex_powerset(): + fset = FiniteSet(1, 2, 3) + assert latex(PowerSet(fset)) == r'\mathcal{P}\left(\left\{1, 2, 3\right\}\right)' + + +def test_latex_ordinals(): + w = OrdinalOmega() + assert latex(w) == r"\omega" + wp = OmegaPower(2, 3) + assert latex(wp) == r'3 \omega^{2}' + assert latex(Ordinal(wp, OmegaPower(1, 1))) == r'3 \omega^{2} + \omega' + assert latex(Ordinal(OmegaPower(2, 1), OmegaPower(1, 2))) == r'\omega^{2} + 2 \omega' + + +def test_set_operators_parenthesis(): + a, b, c, d = symbols('a:d') + A = FiniteSet(a) + B = FiniteSet(b) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(A, B, evaluate=False) + U2 = Union(C, D, evaluate=False) + I1 = Intersection(A, B, evaluate=False) + I2 = Intersection(C, D, evaluate=False) + C1 = Complement(A, B, evaluate=False) + C2 = Complement(C, D, evaluate=False) + D1 = SymmetricDifference(A, B, evaluate=False) + D2 = SymmetricDifference(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(A, B) + P2 = ProductSet(C, D) + + assert latex(Intersection(A, U2, evaluate=False)) == \ + r'\left\{a\right\} \cap ' \ + r'\left(\left\{c\right\} \cup \left\{d\right\}\right)' + assert latex(Intersection(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\cap \left(\left\{c\right\} \cup \left\{d\right\}\right)' + assert latex(Intersection(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \cap \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Intersection(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \cap \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(Intersection(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\cap \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + assert latex(Union(A, I2, evaluate=False)) == \ + r'\left\{a\right\} \cup ' \ + r'\left(\left\{c\right\} \cap \left\{d\right\}\right)' + assert latex(Union(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\cup \left(\left\{c\right\} \cap \left\{d\right\}\right)' + assert latex(Union(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \cup \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Union(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \cup \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(Union(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\cup \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + assert latex(Complement(A, C2, evaluate=False)) == \ + r'\left\{a\right\} \setminus \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Complement(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\setminus \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(Complement(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\setminus \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(Complement(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \setminus ' \ + r'\left(\left\{c\right\} \triangle \left\{d\right\}\right)' + assert latex(Complement(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) '\ + r'\setminus \left(\left\{c\right\} \times '\ + r'\left\{d\right\}\right)' + + assert latex(SymmetricDifference(A, D2, evaluate=False)) == \ + r'\left\{a\right\} \triangle \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(SymmetricDifference(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(SymmetricDifference(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(SymmetricDifference(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \triangle ' \ + r'\left(\left\{c\right\} \setminus \left\{d\right\}\right)' + assert latex(SymmetricDifference(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + # XXX This can be incorrect since cartesian product is not associative + assert latex(ProductSet(A, P2).flatten()) == \ + r'\left\{a\right\} \times \left\{c\right\} \times ' \ + r'\left\{d\right\}' + assert latex(ProductSet(U1, U2)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\times \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(ProductSet(I1, I2)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\times \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(ProductSet(C1, C2)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \times \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(ProductSet(D1, D2)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \times \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + + +def test_latex_Complexes(): + assert latex(S.Complexes) == r"\mathbb{C}" + + +def test_latex_Naturals(): + assert latex(S.Naturals) == r"\mathbb{N}" + + +def test_latex_Naturals0(): + assert latex(S.Naturals0) == r"\mathbb{N}_0" + + +def test_latex_Integers(): + assert latex(S.Integers) == r"\mathbb{Z}" + + +def test_latex_ImageSet(): + x = Symbol('x') + assert latex(ImageSet(Lambda(x, x**2), S.Naturals)) == \ + r"\left\{x^{2}\; \middle|\; x \in \mathbb{N}\right\}" + + y = Symbol('y') + imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4}) + assert latex(imgset) == \ + r"\left\{x + y\; \middle|\; x \in \left\{1, 2, 3\right\}, y \in \left\{3, 4\right\}\right\}" + + imgset = ImageSet(Lambda(((x, y),), x + y), ProductSet({1, 2, 3}, {3, 4})) + assert latex(imgset) == \ + r"\left\{x + y\; \middle|\; \left( x, \ y\right) \in \left\{1, 2, 3\right\} \times \left\{3, 4\right\}\right\}" + + +def test_latex_ConditionSet(): + x = Symbol('x') + assert latex(ConditionSet(x, Eq(x**2, 1), S.Reals)) == \ + r"\left\{x\; \middle|\; x \in \mathbb{R} \wedge x^{2} = 1 \right\}" + assert latex(ConditionSet(x, Eq(x**2, 1), S.UniversalSet)) == \ + r"\left\{x\; \middle|\; x^{2} = 1 \right\}" + + +def test_latex_ComplexRegion(): + assert latex(ComplexRegion(Interval(3, 5)*Interval(4, 6))) == \ + r"\left\{x + y i\; \middle|\; x, y \in \left[3, 5\right] \times \left[4, 6\right] \right\}" + assert latex(ComplexRegion(Interval(0, 1)*Interval(0, 2*pi), polar=True)) == \ + r"\left\{r \left(i \sin{\left(\theta \right)} + \cos{\left(\theta "\ + r"\right)}\right)\; \middle|\; r, \theta \in \left[0, 1\right] \times \left[0, 2 \pi\right) \right\}" + + +def test_latex_Contains(): + x = Symbol('x') + assert latex(Contains(x, S.Naturals)) == r"x \in \mathbb{N}" + + +def test_latex_sum(): + assert latex(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + r"\sum_{\substack{-2 \leq x \leq 2\\-5 \leq y \leq 5}} x y^{2}" + assert latex(Sum(x**2, (x, -2, 2))) == \ + r"\sum_{x=-2}^{2} x^{2}" + assert latex(Sum(x**2 + y, (x, -2, 2))) == \ + r"\sum_{x=-2}^{2} \left(x^{2} + y\right)" + assert latex(Sum(x**2 + y, (x, -2, 2))**2) == \ + r"\left(\sum_{x=-2}^{2} \left(x^{2} + y\right)\right)^{2}" + + +def test_latex_product(): + assert latex(Product(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + r"\prod_{\substack{-2 \leq x \leq 2\\-5 \leq y \leq 5}} x y^{2}" + assert latex(Product(x**2, (x, -2, 2))) == \ + r"\prod_{x=-2}^{2} x^{2}" + assert latex(Product(x**2 + y, (x, -2, 2))) == \ + r"\prod_{x=-2}^{2} \left(x^{2} + y\right)" + + assert latex(Product(x, (x, -2, 2))**2) == \ + r"\left(\prod_{x=-2}^{2} x\right)^{2}" + + +def test_latex_limits(): + assert latex(Limit(x, x, oo)) == r"\lim_{x \to \infty} x" + + # issue 8175 + f = Function('f') + assert latex(Limit(f(x), x, 0)) == r"\lim_{x \to 0^+} f{\left(x \right)}" + assert latex(Limit(f(x), x, 0, "-")) == \ + r"\lim_{x \to 0^-} f{\left(x \right)}" + + # issue #10806 + assert latex(Limit(f(x), x, 0)**2) == \ + r"\left(\lim_{x \to 0^+} f{\left(x \right)}\right)^{2}" + # bi-directional limit + assert latex(Limit(f(x), x, 0, dir='+-')) == \ + r"\lim_{x \to 0} f{\left(x \right)}" + + +def test_latex_log(): + assert latex(log(x)) == r"\log{\left(x \right)}" + assert latex(log(x), ln_notation=True) == r"\ln{\left(x \right)}" + assert latex(log(x) + log(y)) == \ + r"\log{\left(x \right)} + \log{\left(y \right)}" + assert latex(log(x) + log(y), ln_notation=True) == \ + r"\ln{\left(x \right)} + \ln{\left(y \right)}" + assert latex(pow(log(x), x)) == r"\log{\left(x \right)}^{x}" + assert latex(pow(log(x), x), ln_notation=True) == \ + r"\ln{\left(x \right)}^{x}" + + +def test_issue_3568(): + beta = Symbol(r'\beta') + y = beta + x + assert latex(y) in [r'\beta + x', r'x + \beta'] + + beta = Symbol(r'beta') + y = beta + x + assert latex(y) in [r'\beta + x', r'x + \beta'] + + +def test_latex(): + assert latex((2*tau)**Rational(7, 2)) == r"8 \sqrt{2} \tau^{\frac{7}{2}}" + assert latex((2*mu)**Rational(7, 2), mode='equation*') == \ + r"\begin{equation*}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation*}" + assert latex((2*mu)**Rational(7, 2), mode='equation', itex=True) == \ + r"$$8 \sqrt{2} \mu^{\frac{7}{2}}$$" + assert latex([2/x, y]) == r"\left[ \frac{2}{x}, \ y\right]" + + +def test_latex_dict(): + d = {Rational(1): 1, x**2: 2, x: 3, x**3: 4} + assert latex(d) == \ + r'\left\{ 1 : 1, \ x : 3, \ x^{2} : 2, \ x^{3} : 4\right\}' + D = Dict(d) + assert latex(D) == \ + r'\left\{ 1 : 1, \ x : 3, \ x^{2} : 2, \ x^{3} : 4\right\}' + + +def test_latex_list(): + ll = [Symbol('omega1'), Symbol('a'), Symbol('alpha')] + assert latex(ll) == r'\left[ \omega_{1}, \ a, \ \alpha\right]' + + +def test_latex_NumberSymbols(): + assert latex(S.Catalan) == "G" + assert latex(S.EulerGamma) == r"\gamma" + assert latex(S.Exp1) == "e" + assert latex(S.GoldenRatio) == r"\phi" + assert latex(S.Pi) == r"\pi" + assert latex(S.TribonacciConstant) == r"\text{TribonacciConstant}" + + +def test_latex_rational(): + # tests issue 3973 + assert latex(-Rational(1, 2)) == r"- \frac{1}{2}" + assert latex(Rational(-1, 2)) == r"- \frac{1}{2}" + assert latex(Rational(1, -2)) == r"- \frac{1}{2}" + assert latex(-Rational(-1, 2)) == r"\frac{1}{2}" + assert latex(-Rational(1, 2)*x) == r"- \frac{x}{2}" + assert latex(-Rational(1, 2)*x + Rational(-2, 3)*y) == \ + r"- \frac{x}{2} - \frac{2 y}{3}" + + +def test_latex_inverse(): + # tests issue 4129 + assert latex(1/x) == r"\frac{1}{x}" + assert latex(1/(x + y)) == r"\frac{1}{x + y}" + + +def test_latex_DiracDelta(): + assert latex(DiracDelta(x)) == r"\delta\left(x\right)" + assert latex(DiracDelta(x)**2) == r"\left(\delta\left(x\right)\right)^{2}" + assert latex(DiracDelta(x, 0)) == r"\delta\left(x\right)" + assert latex(DiracDelta(x, 5)) == \ + r"\delta^{\left( 5 \right)}\left( x \right)" + assert latex(DiracDelta(x, 5)**2) == \ + r"\left(\delta^{\left( 5 \right)}\left( x \right)\right)^{2}" + + +def test_latex_Heaviside(): + assert latex(Heaviside(x)) == r"\theta\left(x\right)" + assert latex(Heaviside(x)**2) == r"\left(\theta\left(x\right)\right)^{2}" + + +def test_latex_KroneckerDelta(): + assert latex(KroneckerDelta(x, y)) == r"\delta_{x y}" + assert latex(KroneckerDelta(x, y + 1)) == r"\delta_{x, y + 1}" + # issue 6578 + assert latex(KroneckerDelta(x + 1, y)) == r"\delta_{y, x + 1}" + assert latex(Pow(KroneckerDelta(x, y), 2, evaluate=False)) == \ + r"\left(\delta_{x y}\right)^{2}" + + +def test_latex_LeviCivita(): + assert latex(LeviCivita(x, y, z)) == r"\varepsilon_{x y z}" + assert latex(LeviCivita(x, y, z)**2) == \ + r"\left(\varepsilon_{x y z}\right)^{2}" + assert latex(LeviCivita(x, y, z + 1)) == r"\varepsilon_{x, y, z + 1}" + assert latex(LeviCivita(x, y + 1, z)) == r"\varepsilon_{x, y + 1, z}" + assert latex(LeviCivita(x + 1, y, z)) == r"\varepsilon_{x + 1, y, z}" + + +def test_mode(): + expr = x + y + assert latex(expr) == r'x + y' + assert latex(expr, mode='plain') == r'x + y' + assert latex(expr, mode='inline') == r'$x + y$' + assert latex( + expr, mode='equation*') == r'\begin{equation*}x + y\end{equation*}' + assert latex( + expr, mode='equation') == r'\begin{equation}x + y\end{equation}' + raises(ValueError, lambda: latex(expr, mode='foo')) + + +def test_latex_mathieu(): + assert latex(mathieuc(x, y, z)) == r"C\left(x, y, z\right)" + assert latex(mathieus(x, y, z)) == r"S\left(x, y, z\right)" + assert latex(mathieuc(x, y, z)**2) == r"C\left(x, y, z\right)^{2}" + assert latex(mathieus(x, y, z)**2) == r"S\left(x, y, z\right)^{2}" + assert latex(mathieucprime(x, y, z)) == r"C^{\prime}\left(x, y, z\right)" + assert latex(mathieusprime(x, y, z)) == r"S^{\prime}\left(x, y, z\right)" + assert latex(mathieucprime(x, y, z)**2) == r"C^{\prime}\left(x, y, z\right)^{2}" + assert latex(mathieusprime(x, y, z)**2) == r"S^{\prime}\left(x, y, z\right)^{2}" + +def test_latex_Piecewise(): + p = Piecewise((x, x < 1), (x**2, True)) + assert latex(p) == r"\begin{cases} x & \text{for}\: x < 1 \\x^{2} &" \ + r" \text{otherwise} \end{cases}" + assert latex(p, itex=True) == \ + r"\begin{cases} x & \text{for}\: x \lt 1 \\x^{2} &" \ + r" \text{otherwise} \end{cases}" + p = Piecewise((x, x < 0), (0, x >= 0)) + assert latex(p) == r'\begin{cases} x & \text{for}\: x < 0 \\0 &' \ + r' \text{otherwise} \end{cases}' + A, B = symbols("A B", commutative=False) + p = Piecewise((A**2, Eq(A, B)), (A*B, True)) + s = r"\begin{cases} A^{2} & \text{for}\: A = B \\A B & \text{otherwise} \end{cases}" + assert latex(p) == s + assert latex(A*p) == r"A \left(%s\right)" % s + assert latex(p*A) == r"\left(%s\right) A" % s + assert latex(Piecewise((x, x < 1), (x**2, x < 2))) == \ + r'\begin{cases} x & ' \ + r'\text{for}\: x < 1 \\x^{2} & \text{for}\: x < 2 \end{cases}' + + +def test_latex_Matrix(): + M = Matrix([[1 + x, y], [y, x - 1]]) + assert latex(M) == \ + r'\left[\begin{matrix}x + 1 & y\\y & x - 1\end{matrix}\right]' + assert latex(M, mode='inline') == \ + r'$\left[\begin{smallmatrix}x + 1 & y\\' \ + r'y & x - 1\end{smallmatrix}\right]$' + assert latex(M, mat_str='array') == \ + r'\left[\begin{array}{cc}x + 1 & y\\y & x - 1\end{array}\right]' + assert latex(M, mat_str='bmatrix') == \ + r'\left[\begin{bmatrix}x + 1 & y\\y & x - 1\end{bmatrix}\right]' + assert latex(M, mat_delim=None, mat_str='bmatrix') == \ + r'\begin{bmatrix}x + 1 & y\\y & x - 1\end{bmatrix}' + + M2 = Matrix(1, 11, range(11)) + assert latex(M2) == \ + r'\left[\begin{array}{ccccccccccc}' \ + r'0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}\right]' + + +def test_latex_matrix_with_functions(): + t = symbols('t') + theta1 = symbols('theta1', cls=Function) + + M = Matrix([[sin(theta1(t)), cos(theta1(t))], + [cos(theta1(t).diff(t)), sin(theta1(t).diff(t))]]) + + expected = (r'\left[\begin{matrix}\sin{\left(' + r'\theta_{1}{\left(t \right)} \right)} & ' + r'\cos{\left(\theta_{1}{\left(t \right)} \right)' + r'}\\\cos{\left(\frac{d}{d t} \theta_{1}{\left(t ' + r'\right)} \right)} & \sin{\left(\frac{d}{d t} ' + r'\theta_{1}{\left(t \right)} \right' + r')}\end{matrix}\right]') + + assert latex(M) == expected + + +def test_latex_NDimArray(): + x, y, z, w = symbols("x y z w") + + for ArrayType in (ImmutableDenseNDimArray, ImmutableSparseNDimArray, + MutableDenseNDimArray, MutableSparseNDimArray): + # Basic: scalar array + M = ArrayType(x) + + assert latex(M) == r"x" + + M = ArrayType([[1 / x, y], [z, w]]) + M1 = ArrayType([1 / x, y, z]) + + M2 = tensorproduct(M1, M) + M3 = tensorproduct(M, M) + + assert latex(M) == \ + r'\left[\begin{matrix}\frac{1}{x} & y\\z & w\end{matrix}\right]' + assert latex(M1) == \ + r"\left[\begin{matrix}\frac{1}{x} & y & z\end{matrix}\right]" + assert latex(M2) == \ + r"\left[\begin{matrix}" \ + r"\left[\begin{matrix}\frac{1}{x^{2}} & \frac{y}{x}\\\frac{z}{x} & \frac{w}{x}\end{matrix}\right] & " \ + r"\left[\begin{matrix}\frac{y}{x} & y^{2}\\y z & w y\end{matrix}\right] & " \ + r"\left[\begin{matrix}\frac{z}{x} & y z\\z^{2} & w z\end{matrix}\right]" \ + r"\end{matrix}\right]" + assert latex(M3) == \ + r"""\left[\begin{matrix}"""\ + r"""\left[\begin{matrix}\frac{1}{x^{2}} & \frac{y}{x}\\\frac{z}{x} & \frac{w}{x}\end{matrix}\right] & """\ + r"""\left[\begin{matrix}\frac{y}{x} & y^{2}\\y z & w y\end{matrix}\right]\\"""\ + r"""\left[\begin{matrix}\frac{z}{x} & y z\\z^{2} & w z\end{matrix}\right] & """\ + r"""\left[\begin{matrix}\frac{w}{x} & w y\\w z & w^{2}\end{matrix}\right]"""\ + r"""\end{matrix}\right]""" + + Mrow = ArrayType([[x, y, 1/z]]) + Mcolumn = ArrayType([[x], [y], [1/z]]) + Mcol2 = ArrayType([Mcolumn.tolist()]) + + assert latex(Mrow) == \ + r"\left[\left[\begin{matrix}x & y & \frac{1}{z}\end{matrix}\right]\right]" + assert latex(Mcolumn) == \ + r"\left[\begin{matrix}x\\y\\\frac{1}{z}\end{matrix}\right]" + assert latex(Mcol2) == \ + r'\left[\begin{matrix}\left[\begin{matrix}x\\y\\\frac{1}{z}\end{matrix}\right]\end{matrix}\right]' + + +def test_latex_mul_symbol(): + assert latex(4*4**x, mul_symbol='times') == r"4 \times 4^{x}" + assert latex(4*4**x, mul_symbol='dot') == r"4 \cdot 4^{x}" + assert latex(4*4**x, mul_symbol='ldot') == r"4 \,.\, 4^{x}" + + assert latex(4*x, mul_symbol='times') == r"4 \times x" + assert latex(4*x, mul_symbol='dot') == r"4 \cdot x" + assert latex(4*x, mul_symbol='ldot') == r"4 \,.\, x" + + +def test_latex_issue_4381(): + y = 4*4**log(2) + assert latex(y) == r'4 \cdot 4^{\log{\left(2 \right)}}' + assert latex(1/y) == r'\frac{1}{4 \cdot 4^{\log{\left(2 \right)}}}' + + +def test_latex_issue_4576(): + assert latex(Symbol("beta_13_2")) == r"\beta_{13 2}" + assert latex(Symbol("beta_132_20")) == r"\beta_{132 20}" + assert latex(Symbol("beta_13")) == r"\beta_{13}" + assert latex(Symbol("x_a_b")) == r"x_{a b}" + assert latex(Symbol("x_1_2_3")) == r"x_{1 2 3}" + assert latex(Symbol("x_a_b1")) == r"x_{a b1}" + assert latex(Symbol("x_a_1")) == r"x_{a 1}" + assert latex(Symbol("x_1_a")) == r"x_{1 a}" + assert latex(Symbol("x_1^aa")) == r"x^{aa}_{1}" + assert latex(Symbol("x_1__aa")) == r"x^{aa}_{1}" + assert latex(Symbol("x_11^a")) == r"x^{a}_{11}" + assert latex(Symbol("x_11__a")) == r"x^{a}_{11}" + assert latex(Symbol("x_a_a_a_a")) == r"x_{a a a a}" + assert latex(Symbol("x_a_a^a^a")) == r"x^{a a}_{a a}" + assert latex(Symbol("x_a_a__a__a")) == r"x^{a a}_{a a}" + assert latex(Symbol("alpha_11")) == r"\alpha_{11}" + assert latex(Symbol("alpha_11_11")) == r"\alpha_{11 11}" + assert latex(Symbol("alpha_alpha")) == r"\alpha_{\alpha}" + assert latex(Symbol("alpha^aleph")) == r"\alpha^{\aleph}" + assert latex(Symbol("alpha__aleph")) == r"\alpha^{\aleph}" + + +def test_latex_pow_fraction(): + x = Symbol('x') + # Testing exp + assert r'e^{-x}' in latex(exp(-x)/2).replace(' ', '') # Remove Whitespace + + # Testing e^{-x} in case future changes alter behavior of muls or fracs + # In particular current output is \frac{1}{2}e^{- x} but perhaps this will + # change to \frac{e^{-x}}{2} + + # Testing general, non-exp, power + assert r'3^{-x}' in latex(3**-x/2).replace(' ', '') + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + assert latex(A*B*C**-1) == r"A B C^{-1}" + assert latex(C**-1*A*B) == r"C^{-1} A B" + assert latex(A*C**-1*B) == r"A C^{-1} B" + + +def test_latex_order(): + expr = x**3 + x**2*y + y**4 + 3*x*y**3 + + assert latex(expr, order='lex') == r"x^{3} + x^{2} y + 3 x y^{3} + y^{4}" + assert latex( + expr, order='rev-lex') == r"y^{4} + 3 x y^{3} + x^{2} y + x^{3}" + assert latex(expr, order='none') == r"x^{3} + y^{4} + y x^{2} + 3 x y^{3}" + + +def test_latex_Lambda(): + assert latex(Lambda(x, x + 1)) == r"\left( x \mapsto x + 1 \right)" + assert latex(Lambda((x, y), x + 1)) == r"\left( \left( x, \ y\right) \mapsto x + 1 \right)" + assert latex(Lambda(x, x)) == r"\left( x \mapsto x \right)" + +def test_latex_PolyElement(): + Ruv, u, v = ring("u,v", ZZ) + Rxyz, x, y, z = ring("x,y,z", Ruv) + + assert latex(x - x) == r"0" + assert latex(x - 1) == r"x - 1" + assert latex(x + 1) == r"x + 1" + + assert latex((u**2 + 3*u*v + 1)*x**2*y + u + 1) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + u + 1" + assert latex((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + \left(u + 1\right) x" + assert latex((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + \left(u + 1\right) x + 1" + assert latex((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == \ + r"-\left({u}^{2} - 3 u v + 1\right) {x}^{2} y - \left(u + 1\right) x - 1" + + assert latex(-(v**2 + v + 1)*x + 3*u*v + 1) == \ + r"-\left({v}^{2} + v + 1\right) x + 3 u v + 1" + assert latex(-(v**2 + v + 1)*x - 3*u*v + 1) == \ + r"-\left({v}^{2} + v + 1\right) x - 3 u v + 1" + + +def test_latex_FracElement(): + Fuv, u, v = field("u,v", ZZ) + Fxyzt, x, y, z, t = field("x,y,z,t", Fuv) + + assert latex(x - x) == r"0" + assert latex(x - 1) == r"x - 1" + assert latex(x + 1) == r"x + 1" + + assert latex(x/3) == r"\frac{x}{3}" + assert latex(x/z) == r"\frac{x}{z}" + assert latex(x*y/z) == r"\frac{x y}{z}" + assert latex(x/(z*t)) == r"\frac{x}{z t}" + assert latex(x*y/(z*t)) == r"\frac{x y}{z t}" + + assert latex((x - 1)/y) == r"\frac{x - 1}{y}" + assert latex((x + 1)/y) == r"\frac{x + 1}{y}" + assert latex((-x - 1)/y) == r"\frac{-x - 1}{y}" + assert latex((x + 1)/(y*z)) == r"\frac{x + 1}{y z}" + assert latex(-y/(x + 1)) == r"\frac{-y}{x + 1}" + assert latex(y*z/(x + 1)) == r"\frac{y z}{x + 1}" + + assert latex(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == \ + r"\frac{\left(u + 1\right) x y + 1}{\left(v - 1\right) z - 1}" + assert latex(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == \ + r"\frac{\left(u + 1\right) x y + 1}{\left(v - 1\right) z - u v t - 1}" + + +def test_latex_Poly(): + assert latex(Poly(x**2 + 2 * x, x)) == \ + r"\operatorname{Poly}{\left( x^{2} + 2 x, x, domain=\mathbb{Z} \right)}" + assert latex(Poly(x/y, x)) == \ + r"\operatorname{Poly}{\left( \frac{1}{y} x, x, domain=\mathbb{Z}\left(y\right) \right)}" + assert latex(Poly(2.0*x + y)) == \ + r"\operatorname{Poly}{\left( 2.0 x + 1.0 y, x, y, domain=\mathbb{R} \right)}" + + +def test_latex_Poly_order(): + assert latex(Poly([a, 1, b, 2, c, 3], x)) == \ + r'\operatorname{Poly}{\left( a x^{5} + x^{4} + b x^{3} + 2 x^{2} + c'\ + r' x + 3, x, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + assert latex(Poly([a, 1, b+c, 2, 3], x)) == \ + r'\operatorname{Poly}{\left( a x^{4} + x^{3} + \left(b + c\right) '\ + r'x^{2} + 2 x + 3, x, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + assert latex(Poly(a*x**3 + x**2*y - x*y - c*y**3 - b*x*y**2 + y - a*x + b, + (x, y))) == \ + r'\operatorname{Poly}{\left( a x^{3} + x^{2}y - b xy^{2} - xy - '\ + r'a x - c y^{3} + y + b, x, y, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + + +def test_latex_ComplexRootOf(): + assert latex(rootof(x**5 + x + 3, 0)) == \ + r"\operatorname{CRootOf} {\left(x^{5} + x + 3, 0\right)}" + + +def test_latex_RootSum(): + assert latex(RootSum(x**5 + x + 3, sin)) == \ + r"\operatorname{RootSum} {\left(x^{5} + x + 3, \left( x \mapsto \sin{\left(x \right)} \right)\right)}" + + +def test_settings(): + raises(TypeError, lambda: latex(x*y, method="garbage")) + + +def test_latex_numbers(): + assert latex(catalan(n)) == r"C_{n}" + assert latex(catalan(n)**2) == r"C_{n}^{2}" + assert latex(bernoulli(n)) == r"B_{n}" + assert latex(bernoulli(n, x)) == r"B_{n}\left(x\right)" + assert latex(bernoulli(n)**2) == r"B_{n}^{2}" + assert latex(bernoulli(n, x)**2) == r"B_{n}^{2}\left(x\right)" + assert latex(genocchi(n)) == r"G_{n}" + assert latex(genocchi(n, x)) == r"G_{n}\left(x\right)" + assert latex(genocchi(n)**2) == r"G_{n}^{2}" + assert latex(genocchi(n, x)**2) == r"G_{n}^{2}\left(x\right)" + assert latex(bell(n)) == r"B_{n}" + assert latex(bell(n, x)) == r"B_{n}\left(x\right)" + assert latex(bell(n, m, (x, y))) == r"B_{n, m}\left(x, y\right)" + assert latex(bell(n)**2) == r"B_{n}^{2}" + assert latex(bell(n, x)**2) == r"B_{n}^{2}\left(x\right)" + assert latex(bell(n, m, (x, y))**2) == r"B_{n, m}^{2}\left(x, y\right)" + assert latex(fibonacci(n)) == r"F_{n}" + assert latex(fibonacci(n, x)) == r"F_{n}\left(x\right)" + assert latex(fibonacci(n)**2) == r"F_{n}^{2}" + assert latex(fibonacci(n, x)**2) == r"F_{n}^{2}\left(x\right)" + assert latex(lucas(n)) == r"L_{n}" + assert latex(lucas(n)**2) == r"L_{n}^{2}" + assert latex(tribonacci(n)) == r"T_{n}" + assert latex(tribonacci(n, x)) == r"T_{n}\left(x\right)" + assert latex(tribonacci(n)**2) == r"T_{n}^{2}" + assert latex(tribonacci(n, x)**2) == r"T_{n}^{2}\left(x\right)" + assert latex(mobius(n)) == r"\mu\left(n\right)" + assert latex(mobius(n)**2) == r"\mu^{2}\left(n\right)" + + +def test_latex_euler(): + assert latex(euler(n)) == r"E_{n}" + assert latex(euler(n, x)) == r"E_{n}\left(x\right)" + assert latex(euler(n, x)**2) == r"E_{n}^{2}\left(x\right)" + + +def test_lamda(): + assert latex(Symbol('lamda')) == r"\lambda" + assert latex(Symbol('Lamda')) == r"\Lambda" + + +def test_custom_symbol_names(): + x = Symbol('x') + y = Symbol('y') + assert latex(x) == r"x" + assert latex(x, symbol_names={x: "x_i"}) == r"x_i" + assert latex(x + y, symbol_names={x: "x_i"}) == r"x_i + y" + assert latex(x**2, symbol_names={x: "x_i"}) == r"x_i^{2}" + assert latex(x + y, symbol_names={x: "x_i", y: "y_j"}) == r"x_i + y_j" + + +def test_matAdd(): + C = MatrixSymbol('C', 5, 5) + B = MatrixSymbol('B', 5, 5) + + n = symbols("n") + h = MatrixSymbol("h", 1, 1) + + assert latex(C - 2*B) in [r'- 2 B + C', r'C -2 B'] + assert latex(C + 2*B) in [r'2 B + C', r'C + 2 B'] + assert latex(B - 2*C) in [r'B - 2 C', r'- 2 C + B'] + assert latex(B + 2*C) in [r'B + 2 C', r'2 C + B'] + + assert latex(n * h - (-h + h.T) * (h + h.T)) == 'n h - \\left(- h + h^{T}\\right) \\left(h + h^{T}\\right)' + assert latex(MatAdd(MatAdd(h, h), MatAdd(h, h))) == '\\left(h + h\\right) + \\left(h + h\\right)' + assert latex(MatMul(MatMul(h, h), MatMul(h, h))) == '\\left(h h\\right) \\left(h h\\right)' + + +def test_matMul(): + A = MatrixSymbol('A', 5, 5) + B = MatrixSymbol('B', 5, 5) + x = Symbol('x') + assert latex(2*A) == r'2 A' + assert latex(2*x*A) == r'2 x A' + assert latex(-2*A) == r'- 2 A' + assert latex(1.5*A) == r'1.5 A' + assert latex(sqrt(2)*A) == r'\sqrt{2} A' + assert latex(-sqrt(2)*A) == r'- \sqrt{2} A' + assert latex(2*sqrt(2)*x*A) == r'2 \sqrt{2} x A' + assert latex(-2*A*(A + 2*B)) in [r'- 2 A \left(A + 2 B\right)', + r'- 2 A \left(2 B + A\right)'] + + +def test_latex_MatrixSlice(): + n = Symbol('n', integer=True) + x, y, z, w, t, = symbols('x y z w t') + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + assert latex(MatrixSlice(X, (None, None, None), (None, None, None))) == r'X\left[:, :\right]' + assert latex(X[x:x + 1, y:y + 1]) == r'X\left[x:x + 1, y:y + 1\right]' + assert latex(X[x:x + 1:2, y:y + 1:2]) == r'X\left[x:x + 1:2, y:y + 1:2\right]' + assert latex(X[:x, y:]) == r'X\left[:x, y:\right]' + assert latex(X[:x, y:]) == r'X\left[:x, y:\right]' + assert latex(X[x:, :y]) == r'X\left[x:, :y\right]' + assert latex(X[x:y, z:w]) == r'X\left[x:y, z:w\right]' + assert latex(X[x:y:t, w:t:x]) == r'X\left[x:y:t, w:t:x\right]' + assert latex(X[x::y, t::w]) == r'X\left[x::y, t::w\right]' + assert latex(X[:x:y, :t:w]) == r'X\left[:x:y, :t:w\right]' + assert latex(X[::x, ::y]) == r'X\left[::x, ::y\right]' + assert latex(MatrixSlice(X, (0, None, None), (0, None, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (None, n, None), (None, n, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (0, n, None), (0, n, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (0, n, 2), (0, n, 2))) == r'X\left[::2, ::2\right]' + assert latex(X[1:2:3, 4:5:6]) == r'X\left[1:2:3, 4:5:6\right]' + assert latex(X[1:3:5, 4:6:8]) == r'X\left[1:3:5, 4:6:8\right]' + assert latex(X[1:10:2]) == r'X\left[1:10:2, :\right]' + assert latex(Y[:5, 1:9:2]) == r'Y\left[:5, 1:9:2\right]' + assert latex(Y[:5, 1:10:2]) == r'Y\left[:5, 1::2\right]' + assert latex(Y[5, :5:2]) == r'Y\left[5:6, :5:2\right]' + assert latex(X[0:1, 0:1]) == r'X\left[:1, :1\right]' + assert latex(X[0:1:2, 0:1:2]) == r'X\left[:1:2, :1:2\right]' + assert latex((Y + Z)[2:, 2:]) == r'\left(Y + Z\right)\left[2:, 2:\right]' + + +def test_latex_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + from sympy.stats.rv import RandomDomain + + X = Normal('x1', 0, 1) + assert latex(where(X > 0)) == r"\text{Domain: }0 < x_{1} \wedge x_{1} < \infty" + + D = Die('d1', 6) + assert latex(where(D > 4)) == r"\text{Domain: }d_{1} = 5 \vee d_{1} = 6" + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert latex( + pspace(Tuple(A, B)).domain) == \ + r"\text{Domain: }0 \leq a \wedge 0 \leq b \wedge a < \infty \wedge b < \infty" + + assert latex(RandomDomain(FiniteSet(x), FiniteSet(1, 2))) == \ + r'\text{Domain: }\left\{x\right\} \in \left\{1, 2\right\}' + +def test_PrettyPoly(): + from sympy.polys.domains import QQ + F = QQ.frac_field(x, y) + R = QQ[x, y] + + assert latex(F.convert(x/(x + y))) == latex(x/(x + y)) + assert latex(R.convert(x + y)) == latex(x + y) + + +def test_integral_transforms(): + x = Symbol("x") + k = Symbol("k") + f = Function("f") + a = Symbol("a") + b = Symbol("b") + + assert latex(MellinTransform(f(x), x, k)) == \ + r"\mathcal{M}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseMellinTransform(f(k), k, x, a, b)) == \ + r"\mathcal{M}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(LaplaceTransform(f(x), x, k)) == \ + r"\mathcal{L}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseLaplaceTransform(f(k), k, x, (a, b))) == \ + r"\mathcal{L}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(FourierTransform(f(x), x, k)) == \ + r"\mathcal{F}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseFourierTransform(f(k), k, x)) == \ + r"\mathcal{F}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(CosineTransform(f(x), x, k)) == \ + r"\mathcal{COS}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseCosineTransform(f(k), k, x)) == \ + r"\mathcal{COS}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(SineTransform(f(x), x, k)) == \ + r"\mathcal{SIN}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseSineTransform(f(k), k, x)) == \ + r"\mathcal{SIN}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + +def test_PolynomialRingBase(): + from sympy.polys.domains import QQ + assert latex(QQ.old_poly_ring(x, y)) == r"\mathbb{Q}\left[x, y\right]" + assert latex(QQ.old_poly_ring(x, y, order="ilex")) == \ + r"S_<^{-1}\mathbb{Q}\left[x, y\right]" + + +def test_categories(): + from sympy.categories import (Object, IdentityMorphism, + NamedMorphism, Category, Diagram, + DiagramGrid) + + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A2, A3, "f2") + id_A1 = IdentityMorphism(A1) + + K1 = Category("K1") + + assert latex(A1) == r"A_{1}" + assert latex(f1) == r"f_{1}:A_{1}\rightarrow A_{2}" + assert latex(id_A1) == r"id:A_{1}\rightarrow A_{1}" + assert latex(f2*f1) == r"f_{2}\circ f_{1}:A_{1}\rightarrow A_{3}" + + assert latex(K1) == r"\mathbf{K_{1}}" + + d = Diagram() + assert latex(d) == r"\emptyset" + + d = Diagram({f1: "unique", f2: S.EmptySet}) + assert latex(d) == r"\left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \emptyset, \ id:A_{1}\rightarrow " \ + r"A_{1} : \emptyset, \ id:A_{2}\rightarrow A_{2} : " \ + r"\emptyset, \ id:A_{3}\rightarrow A_{3} : \emptyset, " \ + r"\ f_{1}:A_{1}\rightarrow A_{2} : \left\{unique\right\}, " \ + r"\ f_{2}:A_{2}\rightarrow A_{3} : \emptyset\right\}" + + d = Diagram({f1: "unique", f2: S.EmptySet}, {f2 * f1: "unique"}) + assert latex(d) == r"\left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \emptyset, \ id:A_{1}\rightarrow " \ + r"A_{1} : \emptyset, \ id:A_{2}\rightarrow A_{2} : " \ + r"\emptyset, \ id:A_{3}\rightarrow A_{3} : \emptyset, " \ + r"\ f_{1}:A_{1}\rightarrow A_{2} : \left\{unique\right\}," \ + r" \ f_{2}:A_{2}\rightarrow A_{3} : \emptyset\right\}" \ + r"\Longrightarrow \left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \left\{unique\right\}\right\}" + + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert latex(grid) == r"\begin{array}{cc}" + "\n" \ + r"A & B \\" + "\n" \ + r" & C " + "\n" \ + r"\end{array}" + "\n" + + +def test_Modules(): + from sympy.polys.domains import QQ + from sympy.polys.agca import homomorphism + + R = QQ.old_poly_ring(x, y) + F = R.free_module(2) + M = F.submodule([x, y], [1, x**2]) + + assert latex(F) == r"{\mathbb{Q}\left[x, y\right]}^{2}" + assert latex(M) == \ + r"\left\langle {\left[ {x},{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle" + + I = R.ideal(x**2, y) + assert latex(I) == r"\left\langle {x^{2}},{y} \right\rangle" + + Q = F / M + assert latex(Q) == \ + r"\frac{{\mathbb{Q}\left[x, y\right]}^{2}}{\left\langle {\left[ {x},"\ + r"{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle}" + assert latex(Q.submodule([1, x**3/2], [2, y])) == \ + r"\left\langle {{\left[ {1},{\frac{x^{3}}{2}} \right]} + {\left"\ + r"\langle {\left[ {x},{y} \right]},{\left[ {1},{x^{2}} \right]} "\ + r"\right\rangle}},{{\left[ {2},{y} \right]} + {\left\langle {\left[ "\ + r"{x},{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle}} \right\rangle" + + h = homomorphism(QQ.old_poly_ring(x).free_module(2), + QQ.old_poly_ring(x).free_module(2), [0, 0]) + + assert latex(h) == \ + r"{\left[\begin{matrix}0 & 0\\0 & 0\end{matrix}\right]} : "\ + r"{{\mathbb{Q}\left[x\right]}^{2}} \to {{\mathbb{Q}\left[x\right]}^{2}}" + + +def test_QuotientRing(): + from sympy.polys.domains import QQ + R = QQ.old_poly_ring(x)/[x**2 + 1] + + assert latex(R) == \ + r"\frac{\mathbb{Q}\left[x\right]}{\left\langle {x^{2} + 1} \right\rangle}" + assert latex(R.one) == r"{1} + {\left\langle {x^{2} + 1} \right\rangle}" + + +def test_Tr(): + #TODO: Handle indices + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert latex(t) == r'\operatorname{tr}\left(A B\right)' + + +def test_Determinant(): + from sympy.matrices import Determinant, Inverse, BlockMatrix, OneMatrix, ZeroMatrix + m = Matrix(((1, 2), (3, 4))) + assert latex(Determinant(m)) == '\\left|{\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}}\\right|' + assert latex(Determinant(Inverse(m))) == \ + '\\left|{\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{-1}}\\right|' + X = MatrixSymbol('X', 2, 2) + assert latex(Determinant(X)) == '\\left|{X}\\right|' + assert latex(Determinant(X + m)) == \ + '\\left|{\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X}\\right|' + assert latex(Determinant(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left|{\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}}\\right|' + + +def test_Adjoint(): + from sympy.matrices import Adjoint, Inverse, Transpose + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(Adjoint(X)) == r'X^{\dagger}' + assert latex(Adjoint(X + Y)) == r'\left(X + Y\right)^{\dagger}' + assert latex(Adjoint(X) + Adjoint(Y)) == r'X^{\dagger} + Y^{\dagger}' + assert latex(Adjoint(X*Y)) == r'\left(X Y\right)^{\dagger}' + assert latex(Adjoint(Y)*Adjoint(X)) == r'Y^{\dagger} X^{\dagger}' + assert latex(Adjoint(X**2)) == r'\left(X^{2}\right)^{\dagger}' + assert latex(Adjoint(X)**2) == r'\left(X^{\dagger}\right)^{2}' + assert latex(Adjoint(Inverse(X))) == r'\left(X^{-1}\right)^{\dagger}' + assert latex(Inverse(Adjoint(X))) == r'\left(X^{\dagger}\right)^{-1}' + assert latex(Adjoint(Transpose(X))) == r'\left(X^{T}\right)^{\dagger}' + assert latex(Transpose(Adjoint(X))) == r'\left(X^{\dagger}\right)^{T}' + assert latex(Transpose(Adjoint(X) + Y)) == r'\left(X^{\dagger} + Y\right)^{T}' + m = Matrix(((1, 2), (3, 4))) + assert latex(Adjoint(m)) == '\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{\\dagger}' + assert latex(Adjoint(m+X)) == \ + '\\left(\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X\\right)^{\\dagger}' + from sympy.matrices import BlockMatrix, OneMatrix, ZeroMatrix + assert latex(Adjoint(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left[\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}\\right]^{\\dagger}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(Adjoint(Mx)) == r'\left(M^{x}\right)^{\dagger}' + + # adjoint style + assert latex(Adjoint(X), adjoint_style="star") == r'X^{\ast}' + assert latex(Adjoint(X + Y), adjoint_style="hermitian") == r'\left(X + Y\right)^{\mathsf{H}}' + assert latex(Adjoint(X) + Adjoint(Y), adjoint_style="dagger") == r'X^{\dagger} + Y^{\dagger}' + assert latex(Adjoint(Y)*Adjoint(X)) == r'Y^{\dagger} X^{\dagger}' + assert latex(Adjoint(X**2), adjoint_style="star") == r'\left(X^{2}\right)^{\ast}' + assert latex(Adjoint(X)**2, adjoint_style="hermitian") == r'\left(X^{\mathsf{H}}\right)^{2}' + +def test_Transpose(): + from sympy.matrices import Transpose, MatPow, HadamardPower + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(Transpose(X)) == r'X^{T}' + assert latex(Transpose(X + Y)) == r'\left(X + Y\right)^{T}' + + assert latex(Transpose(HadamardPower(X, 2))) == r'\left(X^{\circ {2}}\right)^{T}' + assert latex(HadamardPower(Transpose(X), 2)) == r'\left(X^{T}\right)^{\circ {2}}' + assert latex(Transpose(MatPow(X, 2))) == r'\left(X^{2}\right)^{T}' + assert latex(MatPow(Transpose(X), 2)) == r'\left(X^{T}\right)^{2}' + m = Matrix(((1, 2), (3, 4))) + assert latex(Transpose(m)) == '\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{T}' + assert latex(Transpose(m+X)) == \ + '\\left(\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X\\right)^{T}' + from sympy.matrices import BlockMatrix, OneMatrix, ZeroMatrix + assert latex(Transpose(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left[\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}\\right]^{T}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(Transpose(Mx)) == r'\left(M^{x}\right)^{T}' + + +def test_Hadamard(): + from sympy.matrices import HadamardProduct, HadamardPower + from sympy.matrices.expressions import MatAdd, MatMul, MatPow + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(HadamardProduct(X, Y*Y)) == r'X \circ Y^{2}' + assert latex(HadamardProduct(X, Y)*Y) == r'\left(X \circ Y\right) Y' + + assert latex(HadamardPower(X, 2)) == r'X^{\circ {2}}' + assert latex(HadamardPower(X, -1)) == r'X^{\circ \left({-1}\right)}' + assert latex(HadamardPower(MatAdd(X, Y), 2)) == \ + r'\left(X + Y\right)^{\circ {2}}' + assert latex(HadamardPower(MatMul(X, Y), 2)) == \ + r'\left(X Y\right)^{\circ {2}}' + + assert latex(HadamardPower(MatPow(X, -1), -1)) == \ + r'\left(X^{-1}\right)^{\circ \left({-1}\right)}' + assert latex(MatPow(HadamardPower(X, -1), -1)) == \ + r'\left(X^{\circ \left({-1}\right)}\right)^{-1}' + + assert latex(HadamardPower(X, n+1)) == \ + r'X^{\circ \left({n + 1}\right)}' + + +def test_MatPow(): + from sympy.matrices.expressions import MatPow + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(MatPow(X, 2)) == 'X^{2}' + assert latex(MatPow(X*X, 2)) == '\\left(X^{2}\\right)^{2}' + assert latex(MatPow(X*Y, 2)) == '\\left(X Y\\right)^{2}' + assert latex(MatPow(X + Y, 2)) == '\\left(X + Y\\right)^{2}' + assert latex(MatPow(X + X, 2)) == '\\left(2 X\\right)^{2}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(MatPow(Mx, 2)) == r'\left(M^{x}\right)^{2}' + + +def test_ElementwiseApplyFunction(): + X = MatrixSymbol('X', 2, 2) + expr = (X.T*X).applyfunc(sin) + assert latex(expr) == r"{\left( d \mapsto \sin{\left(d \right)} \right)}_{\circ}\left({X^{T} X}\right)" + expr = X.applyfunc(Lambda(x, 1/x)) + assert latex(expr) == r'{\left( x \mapsto \frac{1}{x} \right)}_{\circ}\left({X}\right)' + + +def test_ZeroMatrix(): + from sympy.matrices.expressions.special import ZeroMatrix + assert latex(ZeroMatrix(1, 1), mat_symbol_style='plain') == r"0" + assert latex(ZeroMatrix(1, 1), mat_symbol_style='bold') == r"\mathbf{0}" + + +def test_OneMatrix(): + from sympy.matrices.expressions.special import OneMatrix + assert latex(OneMatrix(3, 4), mat_symbol_style='plain') == r"1" + assert latex(OneMatrix(3, 4), mat_symbol_style='bold') == r"\mathbf{1}" + + +def test_Identity(): + from sympy.matrices.expressions.special import Identity + assert latex(Identity(1), mat_symbol_style='plain') == r"\mathbb{I}" + assert latex(Identity(1), mat_symbol_style='bold') == r"\mathbf{I}" + + +def test_latex_DFT_IDFT(): + from sympy.matrices.expressions.fourier import DFT, IDFT + assert latex(DFT(13)) == r"\text{DFT}_{13}" + assert latex(IDFT(x)) == r"\text{IDFT}_{x}" + + +def test_boolean_args_order(): + syms = symbols('a:f') + + expr = And(*syms) + assert latex(expr) == r'a \wedge b \wedge c \wedge d \wedge e \wedge f' + + expr = Or(*syms) + assert latex(expr) == r'a \vee b \vee c \vee d \vee e \vee f' + + expr = Equivalent(*syms) + assert latex(expr) == \ + r'a \Leftrightarrow b \Leftrightarrow c \Leftrightarrow d \Leftrightarrow e \Leftrightarrow f' + + expr = Xor(*syms) + assert latex(expr) == \ + r'a \veebar b \veebar c \veebar d \veebar e \veebar f' + + +def test_imaginary(): + i = sqrt(-1) + assert latex(i) == r'i' + + +def test_builtins_without_args(): + assert latex(sin) == r'\sin' + assert latex(cos) == r'\cos' + assert latex(tan) == r'\tan' + assert latex(log) == r'\log' + assert latex(Ei) == r'\operatorname{Ei}' + assert latex(zeta) == r'\zeta' + + +def test_latex_greek_functions(): + # bug because capital greeks that have roman equivalents should not use + # \Alpha, \Beta, \Eta, etc. + s = Function('Alpha') + assert latex(s) == r'\mathrm{A}' + assert latex(s(x)) == r'\mathrm{A}{\left(x \right)}' + s = Function('Beta') + assert latex(s) == r'\mathrm{B}' + s = Function('Eta') + assert latex(s) == r'\mathrm{H}' + assert latex(s(x)) == r'\mathrm{H}{\left(x \right)}' + + # bug because sympy.core.numbers.Pi is special + p = Function('Pi') + # assert latex(p(x)) == r'\Pi{\left(x \right)}' + assert latex(p) == r'\Pi' + + # bug because not all greeks are included + c = Function('chi') + assert latex(c(x)) == r'\chi{\left(x \right)}' + assert latex(c) == r'\chi' + + +def test_translate(): + s = 'Alpha' + assert translate(s) == r'\mathrm{A}' + s = 'Beta' + assert translate(s) == r'\mathrm{B}' + s = 'Eta' + assert translate(s) == r'\mathrm{H}' + s = 'omicron' + assert translate(s) == r'o' + s = 'Pi' + assert translate(s) == r'\Pi' + s = 'pi' + assert translate(s) == r'\pi' + s = 'LamdaHatDOT' + assert translate(s) == r'\dot{\hat{\Lambda}}' + + +def test_other_symbols(): + from sympy.printing.latex import other_symbols + for s in other_symbols: + assert latex(symbols(s)) == r"" "\\" + s + + +def test_modifiers(): + # Test each modifier individually in the simplest case + # (with funny capitalizations) + assert latex(symbols("xMathring")) == r"\mathring{x}" + assert latex(symbols("xCheck")) == r"\check{x}" + assert latex(symbols("xBreve")) == r"\breve{x}" + assert latex(symbols("xAcute")) == r"\acute{x}" + assert latex(symbols("xGrave")) == r"\grave{x}" + assert latex(symbols("xTilde")) == r"\tilde{x}" + assert latex(symbols("xPrime")) == r"{x}'" + assert latex(symbols("xddDDot")) == r"\ddddot{x}" + assert latex(symbols("xDdDot")) == r"\dddot{x}" + assert latex(symbols("xDDot")) == r"\ddot{x}" + assert latex(symbols("xBold")) == r"\boldsymbol{x}" + assert latex(symbols("xnOrM")) == r"\left\|{x}\right\|" + assert latex(symbols("xAVG")) == r"\left\langle{x}\right\rangle" + assert latex(symbols("xHat")) == r"\hat{x}" + assert latex(symbols("xDot")) == r"\dot{x}" + assert latex(symbols("xBar")) == r"\bar{x}" + assert latex(symbols("xVec")) == r"\vec{x}" + assert latex(symbols("xAbs")) == r"\left|{x}\right|" + assert latex(symbols("xMag")) == r"\left|{x}\right|" + assert latex(symbols("xPrM")) == r"{x}'" + assert latex(symbols("xBM")) == r"\boldsymbol{x}" + # Test strings that are *only* the names of modifiers + assert latex(symbols("Mathring")) == r"Mathring" + assert latex(symbols("Check")) == r"Check" + assert latex(symbols("Breve")) == r"Breve" + assert latex(symbols("Acute")) == r"Acute" + assert latex(symbols("Grave")) == r"Grave" + assert latex(symbols("Tilde")) == r"Tilde" + assert latex(symbols("Prime")) == r"Prime" + assert latex(symbols("DDot")) == r"\dot{D}" + assert latex(symbols("Bold")) == r"Bold" + assert latex(symbols("NORm")) == r"NORm" + assert latex(symbols("AVG")) == r"AVG" + assert latex(symbols("Hat")) == r"Hat" + assert latex(symbols("Dot")) == r"Dot" + assert latex(symbols("Bar")) == r"Bar" + assert latex(symbols("Vec")) == r"Vec" + assert latex(symbols("Abs")) == r"Abs" + assert latex(symbols("Mag")) == r"Mag" + assert latex(symbols("PrM")) == r"PrM" + assert latex(symbols("BM")) == r"BM" + assert latex(symbols("hbar")) == r"\hbar" + # Check a few combinations + assert latex(symbols("xvecdot")) == r"\dot{\vec{x}}" + assert latex(symbols("xDotVec")) == r"\vec{\dot{x}}" + assert latex(symbols("xHATNorm")) == r"\left\|{\hat{x}}\right\|" + # Check a couple big, ugly combinations + assert latex(symbols('xMathringBm_yCheckPRM__zbreveAbs')) == \ + r"\boldsymbol{\mathring{x}}^{\left|{\breve{z}}\right|}_{{\check{y}}'}" + assert latex(symbols('alphadothat_nVECDOT__tTildePrime')) == \ + r"\hat{\dot{\alpha}}^{{\tilde{t}}'}_{\dot{\vec{n}}}" + + +def test_greek_symbols(): + assert latex(Symbol('alpha')) == r'\alpha' + assert latex(Symbol('beta')) == r'\beta' + assert latex(Symbol('gamma')) == r'\gamma' + assert latex(Symbol('delta')) == r'\delta' + assert latex(Symbol('epsilon')) == r'\epsilon' + assert latex(Symbol('zeta')) == r'\zeta' + assert latex(Symbol('eta')) == r'\eta' + assert latex(Symbol('theta')) == r'\theta' + assert latex(Symbol('iota')) == r'\iota' + assert latex(Symbol('kappa')) == r'\kappa' + assert latex(Symbol('lambda')) == r'\lambda' + assert latex(Symbol('mu')) == r'\mu' + assert latex(Symbol('nu')) == r'\nu' + assert latex(Symbol('xi')) == r'\xi' + assert latex(Symbol('omicron')) == r'o' + assert latex(Symbol('pi')) == r'\pi' + assert latex(Symbol('rho')) == r'\rho' + assert latex(Symbol('sigma')) == r'\sigma' + assert latex(Symbol('tau')) == r'\tau' + assert latex(Symbol('upsilon')) == r'\upsilon' + assert latex(Symbol('phi')) == r'\phi' + assert latex(Symbol('chi')) == r'\chi' + assert latex(Symbol('psi')) == r'\psi' + assert latex(Symbol('omega')) == r'\omega' + + assert latex(Symbol('Alpha')) == r'\mathrm{A}' + assert latex(Symbol('Beta')) == r'\mathrm{B}' + assert latex(Symbol('Gamma')) == r'\Gamma' + assert latex(Symbol('Delta')) == r'\Delta' + assert latex(Symbol('Epsilon')) == r'\mathrm{E}' + assert latex(Symbol('Zeta')) == r'\mathrm{Z}' + assert latex(Symbol('Eta')) == r'\mathrm{H}' + assert latex(Symbol('Theta')) == r'\Theta' + assert latex(Symbol('Iota')) == r'\mathrm{I}' + assert latex(Symbol('Kappa')) == r'\mathrm{K}' + assert latex(Symbol('Lambda')) == r'\Lambda' + assert latex(Symbol('Mu')) == r'\mathrm{M}' + assert latex(Symbol('Nu')) == r'\mathrm{N}' + assert latex(Symbol('Xi')) == r'\Xi' + assert latex(Symbol('Omicron')) == r'\mathrm{O}' + assert latex(Symbol('Pi')) == r'\Pi' + assert latex(Symbol('Rho')) == r'\mathrm{P}' + assert latex(Symbol('Sigma')) == r'\Sigma' + assert latex(Symbol('Tau')) == r'\mathrm{T}' + assert latex(Symbol('Upsilon')) == r'\Upsilon' + assert latex(Symbol('Phi')) == r'\Phi' + assert latex(Symbol('Chi')) == r'\mathrm{X}' + assert latex(Symbol('Psi')) == r'\Psi' + assert latex(Symbol('Omega')) == r'\Omega' + + assert latex(Symbol('varepsilon')) == r'\varepsilon' + assert latex(Symbol('varkappa')) == r'\varkappa' + assert latex(Symbol('varphi')) == r'\varphi' + assert latex(Symbol('varpi')) == r'\varpi' + assert latex(Symbol('varrho')) == r'\varrho' + assert latex(Symbol('varsigma')) == r'\varsigma' + assert latex(Symbol('vartheta')) == r'\vartheta' + + +def test_fancyset_symbols(): + assert latex(S.Rationals) == r'\mathbb{Q}' + assert latex(S.Naturals) == r'\mathbb{N}' + assert latex(S.Naturals0) == r'\mathbb{N}_0' + assert latex(S.Integers) == r'\mathbb{Z}' + assert latex(S.Reals) == r'\mathbb{R}' + assert latex(S.Complexes) == r'\mathbb{C}' + + +@XFAIL +def test_builtin_without_args_mismatched_names(): + assert latex(CosineTransform) == r'\mathcal{COS}' + + +def test_builtin_no_args(): + assert latex(Chi) == r'\operatorname{Chi}' + assert latex(beta) == r'\operatorname{B}' + assert latex(gamma) == r'\Gamma' + assert latex(KroneckerDelta) == r'\delta' + assert latex(DiracDelta) == r'\delta' + assert latex(lowergamma) == r'\gamma' + + +def test_issue_6853(): + p = Function('Pi') + assert latex(p(x)) == r"\Pi{\left(x \right)}" + + +def test_Mul(): + e = Mul(-2, x + 1, evaluate=False) + assert latex(e) == r'- 2 \left(x + 1\right)' + e = Mul(2, x + 1, evaluate=False) + assert latex(e) == r'2 \left(x + 1\right)' + e = Mul(S.Half, x + 1, evaluate=False) + assert latex(e) == r'\frac{x + 1}{2}' + e = Mul(y, x + 1, evaluate=False) + assert latex(e) == r'y \left(x + 1\right)' + e = Mul(-y, x + 1, evaluate=False) + assert latex(e) == r'- y \left(x + 1\right)' + e = Mul(-2, x + 1) + assert latex(e) == r'- 2 x - 2' + e = Mul(2, x + 1) + assert latex(e) == r'2 x + 2' + + +def test_Pow(): + e = Pow(2, 2, evaluate=False) + assert latex(e) == r'2^{2}' + assert latex(x**(Rational(-1, 3))) == r'\frac{1}{\sqrt[3]{x}}' + x2 = Symbol(r'x^2') + assert latex(x2**2) == r'\left(x^{2}\right)^{2}' + # Issue 11011 + assert latex(S('1.453e4500')**x) == r'{1.453 \cdot 10^{4500}}^{x}' + + +def test_issue_7180(): + assert latex(Equivalent(x, y)) == r"x \Leftrightarrow y" + assert latex(Not(Equivalent(x, y))) == r"x \not\Leftrightarrow y" + + +def test_issue_8409(): + assert latex(S.Half**n) == r"\left(\frac{1}{2}\right)^{n}" + + +def test_issue_8470(): + from sympy.parsing.sympy_parser import parse_expr + e = parse_expr("-B*A", evaluate=False) + assert latex(e) == r"A \left(- B\right)" + + +def test_issue_15439(): + x = MatrixSymbol('x', 2, 2) + y = MatrixSymbol('y', 2, 2) + assert latex((x * y).subs(y, -y)) == r"x \left(- y\right)" + assert latex((x * y).subs(y, -2*y)) == r"x \left(- 2 y\right)" + assert latex((x * y).subs(x, -x)) == r"\left(- x\right) y" + + +def test_issue_2934(): + assert latex(Symbol(r'\frac{a_1}{b_1}')) == r'\frac{a_1}{b_1}' + + +def test_issue_10489(): + latexSymbolWithBrace = r'C_{x_{0}}' + s = Symbol(latexSymbolWithBrace) + assert latex(s) == latexSymbolWithBrace + assert latex(cos(s)) == r'\cos{\left(C_{x_{0}} \right)}' + + +def test_issue_12886(): + m__1, l__1 = symbols('m__1, l__1') + assert latex(m__1**2 + l__1**2) == \ + r'\left(l^{1}\right)^{2} + \left(m^{1}\right)^{2}' + + +def test_issue_13559(): + from sympy.parsing.sympy_parser import parse_expr + expr = parse_expr('5/1', evaluate=False) + assert latex(expr) == r"\frac{5}{1}" + + +def test_issue_13651(): + expr = c + Mul(-1, a + b, evaluate=False) + assert latex(expr) == r"c - \left(a + b\right)" + + +def test_latex_UnevaluatedExpr(): + x = symbols("x") + he = UnevaluatedExpr(1/x) + assert latex(he) == latex(1/x) == r"\frac{1}{x}" + assert latex(he**2) == r"\left(\frac{1}{x}\right)^{2}" + assert latex(he + 1) == r"1 + \frac{1}{x}" + assert latex(x*he) == r"x \frac{1}{x}" + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert latex(A[0, 0]) == r"{A}_{0,0}" + assert latex(3 * A[0, 0]) == r"3 {A}_{0,0}" + + F = C[0, 0].subs(C, A - B) + assert latex(F) == r"{\left(A - B\right)}_{0,0}" + + i, j, k = symbols("i j k") + M = MatrixSymbol("M", k, k) + N = MatrixSymbol("N", k, k) + assert latex((M*N)[i, j]) == \ + r'\sum_{i_{1}=0}^{k - 1} {M}_{i,i_{1}} {N}_{i_{1},j}' + + X_a = MatrixSymbol('X_a', 3, 3) + assert latex(X_a[0, 0]) == r"{X_{a}}_{0,0}" + + +def test_MatrixSymbol_printing(): + # test cases for issue #14237 + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert latex(-A) == r"- A" + assert latex(A - A*B - B) == r"A - A B - B" + assert latex(-A*B - A*B*C - B) == r"- A B - A B C - B" + + +def test_DotProduct_printing(): + X = MatrixSymbol('X', 3, 1) + Y = MatrixSymbol('Y', 3, 1) + a = Symbol('a') + assert latex(DotProduct(X, Y)) == r"X \cdot Y" + assert latex(DotProduct(a * X, Y)) == r"a X \cdot Y" + assert latex(a * DotProduct(X, Y)) == r"a \left(X \cdot Y\right)" + + +def test_KroneckerProduct_printing(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 2, 2) + assert latex(KroneckerProduct(A, B)) == r'A \otimes B' + + +def test_Series_printing(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert latex(Series(tf1, tf2)) == \ + r'\left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right) \left(\frac{x - y}{x + y}\right)' + assert latex(Series(tf1, tf2, tf3)) == \ + r'\left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right) \left(\frac{x - y}{x + y}\right) \left(\frac{t x^{2} - t^{w} x + w}{t - y}\right)' + assert latex(Series(-tf2, tf1)) == \ + r'\left(\frac{- x + y}{x + y}\right) \left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right)' + + M_1 = Matrix([[5/s], [5/(2*s)]]) + T_1 = TransferFunctionMatrix.from_Matrix(M_1, s) + M_2 = Matrix([[5, 6*s**3]]) + T_2 = TransferFunctionMatrix.from_Matrix(M_2, s) + # Brackets + assert latex(T_1*(T_2 + T_2)) == \ + r'\left[\begin{matrix}\frac{5}{s}\\\frac{5}{2 s}\end{matrix}\right]_\tau\cdot\left(\left[\begin{matrix}\frac{5}{1} &' \ + r' \frac{6 s^{3}}{1}\end{matrix}\right]_\tau + \left[\begin{matrix}\frac{5}{1} & \frac{6 s^{3}}{1}\end{matrix}\right]_\tau\right)' \ + == latex(MIMOSeries(MIMOParallel(T_2, T_2), T_1)) + # No Brackets + M_3 = Matrix([[5, 6], [6, 5/s]]) + T_3 = TransferFunctionMatrix.from_Matrix(M_3, s) + assert latex(T_1*T_2 + T_3) == r'\left[\begin{matrix}\frac{5}{s}\\\frac{5}{2 s}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}' \ + r'\frac{5}{1} & \frac{6 s^{3}}{1}\end{matrix}\right]_\tau + \left[\begin{matrix}\frac{5}{1} & \frac{6}{1}\\\frac{6}{1} & ' \ + r'\frac{5}{s}\end{matrix}\right]_\tau' == latex(MIMOParallel(MIMOSeries(T_2, T_1), T_3)) + + +def test_TransferFunction_printing(): + tf1 = TransferFunction(x - 1, x + 1, x) + assert latex(tf1) == r"\frac{x - 1}{x + 1}" + tf2 = TransferFunction(x + 1, 2 - y, x) + assert latex(tf2) == r"\frac{x + 1}{2 - y}" + tf3 = TransferFunction(y, y**2 + 2*y + 3, y) + assert latex(tf3) == r"\frac{y}{y^{2} + 2 y + 3}" + + +def test_Parallel_printing(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + assert latex(Parallel(tf1, tf2)) == \ + r'\frac{x y^{2} - z}{- t^{3} + y^{3}} + \frac{x - y}{x + y}' + assert latex(Parallel(-tf2, tf1)) == \ + r'\frac{- x + y}{x + y} + \frac{x y^{2} - z}{- t^{3} + y^{3}}' + + M_1 = Matrix([[5, 6], [6, 5/s]]) + T_1 = TransferFunctionMatrix.from_Matrix(M_1, s) + M_2 = Matrix([[5/s, 6], [6, 5/(s - 1)]]) + T_2 = TransferFunctionMatrix.from_Matrix(M_2, s) + M_3 = Matrix([[6, 5/(s*(s - 1))], [5, 6]]) + T_3 = TransferFunctionMatrix.from_Matrix(M_3, s) + assert latex(T_1 + T_2 + T_3) == r'\left[\begin{matrix}\frac{5}{1} & \frac{6}{1}\\\frac{6}{1} & \frac{5}{s}\end{matrix}\right]' \ + r'_\tau + \left[\begin{matrix}\frac{5}{s} & \frac{6}{1}\\\frac{6}{1} & \frac{5}{s - 1}\end{matrix}\right]_\tau + \left[\begin{matrix}' \ + r'\frac{6}{1} & \frac{5}{s \left(s - 1\right)}\\\frac{5}{1} & \frac{6}{1}\end{matrix}\right]_\tau' \ + == latex(MIMOParallel(T_1, T_2, T_3)) == latex(MIMOParallel(T_1, MIMOParallel(T_2, T_3))) == latex(MIMOParallel(MIMOParallel(T_1, T_2), T_3)) + + +def test_TransferFunctionMatrix_printing(): + tf1 = TransferFunction(p, p + x, p) + tf2 = TransferFunction(-s + p, p + s, p) + tf3 = TransferFunction(p, y**2 + 2*y + 3, p) + assert latex(TransferFunctionMatrix([[tf1], [tf2]])) == \ + r'\left[\begin{matrix}\frac{p}{p + x}\\\frac{p - s}{p + s}\end{matrix}\right]_\tau' + assert latex(TransferFunctionMatrix([[tf1, tf2], [tf3, -tf1]])) == \ + r'\left[\begin{matrix}\frac{p}{p + x} & \frac{p - s}{p + s}\\\frac{p}{y^{2} + 2 y + 3} & \frac{\left(-1\right) p}{p + x}\end{matrix}\right]_\tau' + + +def test_Feedback_printing(): + tf1 = TransferFunction(p, p + x, p) + tf2 = TransferFunction(-s + p, p + s, p) + # Negative Feedback (Default) + assert latex(Feedback(tf1, tf2)) == \ + r'\frac{\frac{p}{p + x}}{\frac{1}{1} + \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + assert latex(Feedback(tf1*tf2, TransferFunction(1, 1, p))) == \ + r'\frac{\left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}{\frac{1}{1} + \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + # Positive Feedback + assert latex(Feedback(tf1, tf2, 1)) == \ + r'\frac{\frac{p}{p + x}}{\frac{1}{1} - \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + assert latex(Feedback(tf1*tf2, sign=1)) == \ + r'\frac{\left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}{\frac{1}{1} - \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + + +def test_MIMOFeedback_printing(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s**2 - 1, s) + tf3 = TransferFunction(s, s - 1, s) + tf4 = TransferFunction(s**2, s**2 - 1, s) + + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm_2 = TransferFunctionMatrix([[tf4, tf3], [tf2, tf1]]) + + # Negative Feedback (Default) + assert latex(MIMOFeedback(tfm_1, tfm_2)) == \ + r'\left(I_{\tau} + \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left[' \ + r'\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1} & \frac{1}{s}\end{matrix}\right]_\tau\right)^{-1} \cdot \left[\begin{matrix}' \ + r'\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau' + + # Positive Feedback + assert latex(MIMOFeedback(tfm_1*tfm_2, tfm_1, 1)) == \ + r'\left(I_{\tau} - \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left' \ + r'[\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1} & \frac{1}{s}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}' \ + r'\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\right)^{-1} \cdot \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}' \ + r'\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1}' \ + r' & \frac{1}{s}\end{matrix}\right]_\tau' + + +def test_Quaternion_latex_printing(): + q = Quaternion(x, y, z, t) + assert latex(q) == r"x + y i + z j + t k" + q = Quaternion(x, y, z, x*t) + assert latex(q) == r"x + y i + z j + t x k" + q = Quaternion(x, y, z, x + t) + assert latex(q) == r"x + y i + z j + \left(t + x\right) k" + + +def test_TensorProduct_printing(): + from sympy.tensor.functions import TensorProduct + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + assert latex(TensorProduct(A, B)) == r"A \otimes B" + + +def test_WedgeProduct_printing(): + from sympy.diffgeom.rn import R2 + from sympy.diffgeom import WedgeProduct + wp = WedgeProduct(R2.dx, R2.dy) + assert latex(wp) == r"\operatorname{d}x \wedge \operatorname{d}y" + + +def test_issue_9216(): + expr_1 = Pow(1, -1, evaluate=False) + assert latex(expr_1) == r"1^{-1}" + + expr_2 = Pow(1, Pow(1, -1, evaluate=False), evaluate=False) + assert latex(expr_2) == r"1^{1^{-1}}" + + expr_3 = Pow(3, -2, evaluate=False) + assert latex(expr_3) == r"\frac{1}{9}" + + expr_4 = Pow(1, -2, evaluate=False) + assert latex(expr_4) == r"1^{-2}" + + +def test_latex_printer_tensor(): + from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, tensor_heads + L = TensorIndexType("L") + i, j, k, l = tensor_indices("i j k l", L) + i0 = tensor_indices("i_0", L) + A, B, C, D = tensor_heads("A B C D", [L]) + H = TensorHead("H", [L, L]) + K = TensorHead("K", [L, L, L, L]) + + assert latex(i) == r"{}^{i}" + assert latex(-i) == r"{}_{i}" + + expr = A(i) + assert latex(expr) == r"A{}^{i}" + + expr = A(i0) + assert latex(expr) == r"A{}^{i_{0}}" + + expr = A(-i) + assert latex(expr) == r"A{}_{i}" + + expr = -3*A(i) + assert latex(expr) == r"-3A{}^{i}" + + expr = K(i, j, -k, -i0) + assert latex(expr) == r"K{}^{ij}{}_{ki_{0}}" + + expr = K(i, -j, -k, i0) + assert latex(expr) == r"K{}^{i}{}_{jk}{}^{i_{0}}" + + expr = K(i, -j, k, -i0) + assert latex(expr) == r"K{}^{i}{}_{j}{}^{k}{}_{i_{0}}" + + expr = H(i, -j) + assert latex(expr) == r"H{}^{i}{}_{j}" + + expr = H(i, j) + assert latex(expr) == r"H{}^{ij}" + + expr = H(-i, -j) + assert latex(expr) == r"H{}_{ij}" + + expr = (1+x)*A(i) + assert latex(expr) == r"\left(x + 1\right)A{}^{i}" + + expr = H(i, -i) + assert latex(expr) == r"H{}^{L_{0}}{}_{L_{0}}" + + expr = H(i, -j)*A(j)*B(k) + assert latex(expr) == r"H{}^{i}{}_{L_{0}}A{}^{L_{0}}B{}^{k}" + + expr = A(i) + 3*B(i) + assert latex(expr) == r"3B{}^{i} + A{}^{i}" + + # Test ``TensorElement``: + from sympy.tensor.tensor import TensorElement + + expr = TensorElement(K(i, j, k, l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3,j,k=2,l}' + + expr = TensorElement(K(i, j, k, l), {i: 3}) + assert latex(expr) == r'K{}^{i=3,jkl}' + + expr = TensorElement(K(i, -j, k, l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3}{}_{j}{}^{k=2,l}' + + expr = TensorElement(K(i, -j, k, -l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3}{}_{j}{}^{k=2}{}_{l}' + + expr = TensorElement(K(i, j, -k, -l), {i: 3, -k: 2}) + assert latex(expr) == r'K{}^{i=3,j}{}_{k=2,l}' + + expr = TensorElement(K(i, j, -k, -l), {i: 3}) + assert latex(expr) == r'K{}^{i=3,j}{}_{kl}' + + expr = PartialDerivative(A(i), A(i)) + assert latex(expr) == r"\frac{\partial}{\partial {A{}^{L_{0}}}}{A{}^{L_{0}}}" + + expr = PartialDerivative(A(-i), A(-j)) + assert latex(expr) == r"\frac{\partial}{\partial {A{}_{j}}}{A{}_{i}}" + + expr = PartialDerivative(K(i, j, -k, -l), A(m), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}^{m}} \partial {A{}_{n}}}{K{}^{ij}{}_{kl}}" + + expr = PartialDerivative(B(-i) + A(-i), A(-j), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}_{j}} \partial {A{}_{n}}}{\left(A{}_{i} + B{}_{i}\right)}" + + expr = PartialDerivative(3*A(-i), A(-j), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}_{j}} \partial {A{}_{n}}}{\left(3A{}_{i}\right)}" + + +def test_multiline_latex(): + a, b, c, d, e, f = symbols('a b c d e f') + expr = -a + 2*b -3*c +4*d -5*e + expected = r"\begin{eqnarray}" + "\n"\ + r"f & = &- a \nonumber\\" + "\n"\ + r"& & + 2 b \nonumber\\" + "\n"\ + r"& & - 3 c \nonumber\\" + "\n"\ + r"& & + 4 d \nonumber\\" + "\n"\ + r"& & - 5 e " + "\n"\ + r"\end{eqnarray}" + assert multiline_latex(f, expr, environment="eqnarray") == expected + + expected2 = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b \nonumber\\' + '\n'\ + r'& & - 3 c + 4 d \nonumber\\' + '\n'\ + r'& & - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 2, environment="eqnarray") == expected2 + + expected3 = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b - 3 c \nonumber\\'+ '\n'\ + r'& & + 4 d - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 3, environment="eqnarray") == expected3 + + expected3dots = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b - 3 c \dots\nonumber\\'+ '\n'\ + r'& & + 4 d - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 3, environment="eqnarray", use_dots=True) == expected3dots + + expected3align = r'\begin{align*}' + '\n'\ + r'f = &- a + 2 b - 3 c \\'+ '\n'\ + r'& + 4 d - 5 e ' + '\n'\ + r'\end{align*}' + + assert multiline_latex(f, expr, 3) == expected3align + assert multiline_latex(f, expr, 3, environment='align*') == expected3align + + expected2ieee = r'\begin{IEEEeqnarray}{rCl}' + '\n'\ + r'f & = &- a + 2 b \nonumber\\' + '\n'\ + r'& & - 3 c + 4 d \nonumber\\' + '\n'\ + r'& & - 5 e ' + '\n'\ + r'\end{IEEEeqnarray}' + + assert multiline_latex(f, expr, 2, environment="IEEEeqnarray") == expected2ieee + + raises(ValueError, lambda: multiline_latex(f, expr, environment="foo")) + +def test_issue_15353(): + a, x = symbols('a x') + # Obtained from nonlinsolve([(sin(a*x)),cos(a*x)],[x,a]) + sol = ConditionSet( + Tuple(x, a), Eq(sin(a*x), 0) & Eq(cos(a*x), 0), S.Complexes**2) + assert latex(sol) == \ + r'\left\{\left( x, \ a\right)\; \middle|\; \left( x, \ a\right) \in ' \ + r'\mathbb{C}^{2} \wedge \sin{\left(a x \right)} = 0 \wedge ' \ + r'\cos{\left(a x \right)} = 0 \right\}' + + +def test_latex_symbolic_probability(): + mu = symbols("mu") + sigma = symbols("sigma", positive=True) + X = Normal("X", mu, sigma) + assert latex(Expectation(X)) == r'\operatorname{E}\left[X\right]' + assert latex(Variance(X)) == r'\operatorname{Var}\left(X\right)' + assert latex(Probability(X > 0)) == r'\operatorname{P}\left(X > 0\right)' + Y = Normal("Y", mu, sigma) + assert latex(Covariance(X, Y)) == r'\operatorname{Cov}\left(X, Y\right)' + + +def test_trace(): + # Issue 15303 + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert latex(trace(A)) == r"\operatorname{tr}\left(A \right)" + assert latex(trace(A**2)) == r"\operatorname{tr}\left(A^{2} \right)" + + +def test_print_basic(): + # Issue 15303 + from sympy.core.basic import Basic + from sympy.core.expr import Expr + + # dummy class for testing printing where the function is not + # implemented in latex.py + class UnimplementedExpr(Expr): + def __new__(cls, e): + return Basic.__new__(cls, e) + + # dummy function for testing + def unimplemented_expr(expr): + return UnimplementedExpr(expr).doit() + + # override class name to use superscript / subscript + def unimplemented_expr_sup_sub(expr): + result = UnimplementedExpr(expr) + result.__class__.__name__ = 'UnimplementedExpr_x^1' + return result + + assert latex(unimplemented_expr(x)) == r'\operatorname{UnimplementedExpr}\left(x\right)' + assert latex(unimplemented_expr(x**2)) == \ + r'\operatorname{UnimplementedExpr}\left(x^{2}\right)' + assert latex(unimplemented_expr_sup_sub(x)) == \ + r'\operatorname{UnimplementedExpr^{1}_{x}}\left(x\right)' + + +def test_MatrixSymbol_bold(): + # Issue #15871 + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert latex(trace(A), mat_symbol_style='bold') == \ + r"\operatorname{tr}\left(\mathbf{A} \right)" + assert latex(trace(A), mat_symbol_style='plain') == \ + r"\operatorname{tr}\left(A \right)" + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert latex(-A, mat_symbol_style='bold') == r"- \mathbf{A}" + assert latex(A - A*B - B, mat_symbol_style='bold') == \ + r"\mathbf{A} - \mathbf{A} \mathbf{B} - \mathbf{B}" + assert latex(-A*B - A*B*C - B, mat_symbol_style='bold') == \ + r"- \mathbf{A} \mathbf{B} - \mathbf{A} \mathbf{B} \mathbf{C} - \mathbf{B}" + + A_k = MatrixSymbol("A_k", 3, 3) + assert latex(A_k, mat_symbol_style='bold') == r"\mathbf{A}_{k}" + + A = MatrixSymbol(r"\nabla_k", 3, 3) + assert latex(A, mat_symbol_style='bold') == r"\mathbf{\nabla}_{k}" + +def test_AppliedPermutation(): + p = Permutation(0, 1, 2) + x = Symbol('x') + assert latex(AppliedPermutation(p, x)) == \ + r'\sigma_{\left( 0\; 1\; 2\right)}(x)' + + +def test_PermutationMatrix(): + p = Permutation(0, 1, 2) + assert latex(PermutationMatrix(p)) == r'P_{\left( 0\; 1\; 2\right)}' + p = Permutation(0, 3)(1, 2) + assert latex(PermutationMatrix(p)) == \ + r'P_{\left( 0\; 3\right)\left( 1\; 2\right)}' + + +def test_issue_21758(): + from sympy.functions.elementary.piecewise import piecewise_fold + from sympy.series.fourier import FourierSeries + x = Symbol('x') + k, n = symbols('k n') + fo = FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), SeqFormula( + Piecewise((-2*pi*cos(n*pi)/n + 2*sin(n*pi)/n**2, (n > -oo) & (n < oo) & Ne(n, 0)), + (0, True))*sin(n*x)/pi, (n, 1, oo)))) + assert latex(piecewise_fold(fo)) == '\\begin{cases} 2 \\sin{\\left(x \\right)}' \ + ' - \\sin{\\left(2 x \\right)} + \\frac{2 \\sin{\\left(3 x \\right)}}{3} +' \ + ' \\ldots & \\text{for}\\: n > -\\infty \\wedge n < \\infty \\wedge ' \ + 'n \\neq 0 \\\\0 & \\text{otherwise} \\end{cases}' + assert latex(FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), + SeqFormula(0, (n, 1, oo))))) == '0' + + +def test_imaginary_unit(): + assert latex(1 + I) == r'1 + i' + assert latex(1 + I, imaginary_unit='i') == r'1 + i' + assert latex(1 + I, imaginary_unit='j') == r'1 + j' + assert latex(1 + I, imaginary_unit='foo') == r'1 + foo' + assert latex(I, imaginary_unit="ti") == r'\text{i}' + assert latex(I, imaginary_unit="tj") == r'\text{j}' + + +def test_text_re_im(): + assert latex(im(x), gothic_re_im=True) == r'\Im{\left(x\right)}' + assert latex(im(x), gothic_re_im=False) == r'\operatorname{im}{\left(x\right)}' + assert latex(re(x), gothic_re_im=True) == r'\Re{\left(x\right)}' + assert latex(re(x), gothic_re_im=False) == r'\operatorname{re}{\left(x\right)}' + + +def test_latex_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField, Differential + from sympy.diffgeom.rn import R2 + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert latex(m) == r'\text{M}' + p = Patch('P', m) + assert latex(p) == r'\text{P}_{\text{M}}' + rect = CoordSystem('rect', p, [x, y]) + assert latex(rect) == r'\text{rect}^{\text{P}}_{\text{M}}' + b = BaseScalarField(rect, 0) + assert latex(b) == r'\mathbf{x}' + + g = Function('g') + s_field = g(R2.x, R2.y) + assert latex(Differential(s_field)) == \ + r'\operatorname{d}\left(g{\left(\mathbf{x},\mathbf{y} \right)}\right)' + + +def test_unit_printing(): + assert latex(5*meter) == r'5 \text{m}' + assert latex(3*gibibyte) == r'3 \text{gibibyte}' + assert latex(4*microgram/second) == r'\frac{4 \mu\text{g}}{\text{s}}' + assert latex(4*micro*gram/second) == r'\frac{4 \mu \text{g}}{\text{s}}' + assert latex(5*milli*meter) == r'5 \text{m} \text{m}' + assert latex(milli) == r'\text{m}' + + +def test_issue_17092(): + x_star = Symbol('x^*') + assert latex(Derivative(x_star, x_star,2)) == r'\frac{d^{2}}{d \left(x^{*}\right)^{2}} x^{*}' + + +def test_latex_decimal_separator(): + + x, y, z, t = symbols('x y z t') + k, m, n = symbols('k m n', integer=True) + f, g, h = symbols('f g h', cls=Function) + + # comma decimal_separator + assert(latex([1, 2.3, 4.5], decimal_separator='comma') == r'\left[ 1; \ 2{,}3; \ 4{,}5\right]') + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='comma') == r'\left\{1; 2{,}3; 4{,}5\right\}') + assert(latex((1, 2.3, 4.6), decimal_separator = 'comma') == r'\left( 1; \ 2{,}3; \ 4{,}6\right)') + assert(latex((1,), decimal_separator='comma') == r'\left( 1;\right)') + + # period decimal_separator + assert(latex([1, 2.3, 4.5], decimal_separator='period') == r'\left[ 1, \ 2.3, \ 4.5\right]' ) + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}') + assert(latex((1, 2.3, 4.6), decimal_separator = 'period') == r'\left( 1, \ 2.3, \ 4.6\right)') + assert(latex((1,), decimal_separator='period') == r'\left( 1,\right)') + + # default decimal_separator + assert(latex([1, 2.3, 4.5]) == r'\left[ 1, \ 2.3, \ 4.5\right]') + assert(latex(FiniteSet(1, 2.3, 4.5)) == r'\left\{1, 2.3, 4.5\right\}') + assert(latex((1, 2.3, 4.6)) == r'\left( 1, \ 2.3, \ 4.6\right)') + assert(latex((1,)) == r'\left( 1,\right)') + + assert(latex(Mul(3.4,5.3), decimal_separator = 'comma') == r'18{,}02') + assert(latex(3.4*5.3, decimal_separator = 'comma') == r'18{,}02') + x = symbols('x') + y = symbols('y') + z = symbols('z') + assert(latex(x*5.3 + 2**y**3.4 + 4.5 + z, decimal_separator = 'comma') == r'2^{y^{3{,}4}} + 5{,}3 x + z + 4{,}5') + + assert(latex(0.987, decimal_separator='comma') == r'0{,}987') + assert(latex(S(0.987), decimal_separator='comma') == r'0{,}987') + assert(latex(.3, decimal_separator='comma') == r'0{,}3') + assert(latex(S(.3), decimal_separator='comma') == r'0{,}3') + + + assert(latex(5.8*10**(-7), decimal_separator='comma') == r'5{,}8 \cdot 10^{-7}') + assert(latex(S(5.7)*10**(-7), decimal_separator='comma') == r'5{,}7 \cdot 10^{-7}') + assert(latex(S(5.7*10**(-7)), decimal_separator='comma') == r'5{,}7 \cdot 10^{-7}') + + x = symbols('x') + assert(latex(1.2*x+3.4, decimal_separator='comma') == r'1{,}2 x + 3{,}4') + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}') + + # Error Handling tests + raises(ValueError, lambda: latex([1,2.3,4.5], decimal_separator='non_existing_decimal_separator_in_list')) + raises(ValueError, lambda: latex(FiniteSet(1,2.3,4.5), decimal_separator='non_existing_decimal_separator_in_set')) + raises(ValueError, lambda: latex((1,2.3,4.5), decimal_separator='non_existing_decimal_separator_in_tuple')) + +def test_Str(): + from sympy.core.symbol import Str + assert str(Str('x')) == r'x' + +def test_latex_escape(): + assert latex_escape(r"~^\&%$#_{}") == "".join([ + r'\textasciitilde', + r'\textasciicircum', + r'\textbackslash', + r'\&', + r'\%', + r'\$', + r'\#', + r'\_', + r'\{', + r'\}', + ]) + +def test_emptyPrinter(): + class MyObject: + def __repr__(self): + return "" + + # unknown objects are monospaced + assert latex(MyObject()) == r"\mathtt{\text{}}" + + # even if they are nested within other objects + assert latex((MyObject(),)) == r"\left( \mathtt{\text{}},\right)" + +def test_global_settings(): + import inspect + + # settings should be visible in the signature of `latex` + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'i' + assert latex(I) == r'i' + try: + # but changing the defaults... + LatexPrinter.set_global_settings(imaginary_unit='j') + # ... should change the signature + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'j' + assert latex(I) == r'j' + finally: + # there's no public API to undo this, but we need to make sure we do + # so as not to impact other tests + del LatexPrinter._global_settings['imaginary_unit'] + + # check we really did undo it + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'i' + assert latex(I) == r'i' + +def test_pickleable(): + # this tests that the _PrintFunction instance is pickleable + import pickle + assert pickle.loads(pickle.dumps(latex)) is latex + +def test_printing_latex_array_expressions(): + assert latex(ArraySymbol("A", (2, 3, 4))) == "A" + assert latex(ArrayElement("A", (2, 1/(1-x), 0))) == "{{A}_{2, \\frac{1}{1 - x}, 0}}" + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + assert latex(ArrayElement(M*N, [x, 0])) == "{{\\left(M N\\right)}_{x, 0}}" + +def test_Array(): + arr = Array(range(10)) + assert latex(arr) == r'\left[\begin{matrix}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9\end{matrix}\right]' + + arr = Array(range(11)) + # fill the empty argument with a bunch of 'c' to avoid latex errors + assert latex(arr) == r'\left[\begin{array}{ccccccccccc}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}\right]' + +def test_latex_with_unevaluated(): + with evaluate(False): + assert latex(a * a) == r"a a" + + +def test_latex_disable_split_super_sub(): + assert latex(Symbol('u^a_b')) == 'u^{a}_{b}' + assert latex(Symbol('u^a_b'), disable_split_super_sub=False) == 'u^{a}_{b}' + assert latex(Symbol('u^a_b'), disable_split_super_sub=True) == 'u\\^a\\_b' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_llvmjit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_llvmjit.py new file mode 100644 index 0000000000000000000000000000000000000000..709476f1d7517dc629210341594a70dc6f41808f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_llvmjit.py @@ -0,0 +1,224 @@ +from sympy.external import import_module +from sympy.testing.pytest import raises +import ctypes + + +if import_module('llvmlite'): + import sympy.printing.llvmjitcode as g +else: + disabled = True + +import sympy +from sympy.abc import a, b, n + + +# copied from numpy.isclose documentation +def isclose(a, b): + rtol = 1e-5 + atol = 1e-8 + return abs(a-b) <= atol + rtol*abs(b) + + +def test_simple_expr(): + e = a + 1.0 + f = g.llvm_callable([a], e) + res = float(e.subs({a: 4.0}).evalf()) + jit_res = f(4.0) + + assert isclose(jit_res, res) + + +def test_two_arg(): + e = 4.0*a + b + 3.0 + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 4.0, b: 3.0}).evalf()) + jit_res = f(4.0, 3.0) + + assert isclose(jit_res, res) + + +def test_func(): + e = 4.0*sympy.exp(-a) + f = g.llvm_callable([a], e) + res = float(e.subs({a: 1.5}).evalf()) + jit_res = f(1.5) + + assert isclose(jit_res, res) + + +def test_two_func(): + e = 4.0*sympy.exp(-a) + sympy.exp(b) + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_two_sqrt(): + e = 4.0*sympy.sqrt(a) + sympy.sqrt(b) + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_two_pow(): + e = a**1.5 + b**7 + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_callback(): + e = a + 1.2 + f = g.llvm_callable([a], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + inp = {a: 2.2} + array = array_type(inp[a]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_callback_cubature(): + e = a + 1.2 + f = g.llvm_callable([a], e, callback_type='cubature') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + inp = {a: 2.2} + array = array_type(inp[a]) + out_array = array_type(0.0) + jit_ret = f(m, array, None, m, out_array) + + assert jit_ret == 0 + + res = float(e.subs(inp).evalf()) + + assert isclose(out_array[0], res) + + +def test_callback_two(): + e = 3*a*b + f = g.llvm_callable([a, b], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(2) + array_type = ctypes.c_double * 2 + inp = {a: 0.2, b: 1.7} + array = array_type(inp[a], inp[b]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_callback_alt_two(): + d = sympy.IndexedBase('d') + e = 3*d[0]*d[1] + f = g.llvm_callable([n, d], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(2) + array_type = ctypes.c_double * 2 + inp = {d[0]: 0.2, d[1]: 1.7} + array = array_type(inp[d[0]], inp[d[1]]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_multiple_statements(): + # Match return from CSE + e = [[(b, 4.0*a)], [b + 5]] + f = g.llvm_callable([a], e) + b_val = e[0][0][1].subs({a: 1.5}) + res = float(e[1][0].subs({b: b_val}).evalf()) + jit_res = f(1.5) + assert isclose(jit_res, res) + + f_callback = g.llvm_callable([a], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + array = array_type(1.5) + jit_callback_res = f_callback(m, array) + assert isclose(jit_callback_res, res) + + +def test_cse(): + e = a*a + b*b + sympy.exp(-a*a - b*b) + e2 = sympy.cse(e) + f = g.llvm_callable([a, b], e2) + res = float(e.subs({a: 2.3, b: 0.1}).evalf()) + jit_res = f(2.3, 0.1) + + assert isclose(jit_res, res) + + +def eval_cse(e, sub_dict): + tmp_dict = {} + for tmp_name, tmp_expr in e[0]: + e2 = tmp_expr.subs(sub_dict) + e3 = e2.subs(tmp_dict) + tmp_dict[tmp_name] = e3 + return [e.subs(sub_dict).subs(tmp_dict) for e in e[1]] + + +def test_cse_multiple(): + e1 = a*a + e2 = a*a + b*b + e3 = sympy.cse([e1, e2]) + + raises(NotImplementedError, + lambda: g.llvm_callable([a, b], e3, callback_type='scipy.integrate')) + + f = g.llvm_callable([a, b], e3) + jit_res = f(0.1, 1.5) + assert len(jit_res) == 2 + res = eval_cse(e3, {a: 0.1, b: 1.5}) + assert isclose(res[0], jit_res[0]) + assert isclose(res[1], jit_res[1]) + + +def test_callback_cubature_multiple(): + e1 = a*a + e2 = a*a + b*b + e3 = sympy.cse([e1, e2, 4*e2]) + f = g.llvm_callable([a, b], e3, callback_type='cubature') + + # Number of input variables + ndim = 2 + # Number of output expression values + outdim = 3 + + m = ctypes.c_int(ndim) + fdim = ctypes.c_int(outdim) + array_type = ctypes.c_double * ndim + out_array_type = ctypes.c_double * outdim + inp = {a: 0.2, b: 1.5} + array = array_type(inp[a], inp[b]) + out_array = out_array_type() + jit_ret = f(m, array, None, fdim, out_array) + + assert jit_ret == 0 + + res = eval_cse(e3, inp) + + assert isclose(out_array[0], res[0]) + assert isclose(out_array[1], res[1]) + assert isclose(out_array[2], res[2]) + + +def test_symbol_not_found(): + e = a*a + b + raises(LookupError, lambda: g.llvm_callable([a], e)) + + +def test_bad_callback(): + e = a + raises(ValueError, lambda: g.llvm_callable([a], e, callback_type='bad_callback')) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_maple.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_maple.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb4c512ad3203bd64ae56b350e15734b3a6afb0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_maple.py @@ -0,0 +1,381 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow +from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc, lucas +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix) +from sympy.functions.special.bessel import besseli + +from sympy.printing.maple import maple_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert maple_code(Integer(67)) == "67" + assert maple_code(Integer(-1)) == "-1" + + +def test_Rational(): + assert maple_code(Rational(3, 7)) == "3/7" + assert maple_code(Rational(18, 9)) == "2" + assert maple_code(Rational(3, -7)) == "-3/7" + assert maple_code(Rational(-3, -7)) == "3/7" + assert maple_code(x + Rational(3, 7)) == "x + 3/7" + assert maple_code(Rational(3, 7) * x) == '(3/7)*x' + + +def test_Relational(): + assert maple_code(Eq(x, y)) == "x = y" + assert maple_code(Ne(x, y)) == "x <> y" + assert maple_code(Le(x, y)) == "x <= y" + assert maple_code(Lt(x, y)) == "x < y" + assert maple_code(Gt(x, y)) == "x > y" + assert maple_code(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert maple_code(sin(x) ** cos(x)) == "sin(x)^cos(x)" + assert maple_code(abs(x)) == "abs(x)" + assert maple_code(ceiling(x)) == "ceil(x)" + + +def test_Pow(): + assert maple_code(x ** 3) == "x^3" + assert maple_code(x ** (y ** 3)) == "x^(y^3)" + + assert maple_code((x ** 3) ** y) == "(x^3)^y" + assert maple_code(x ** Rational(2, 3)) == 'x^(2/3)' + + g = implemented_function('g', Lambda(x, 2 * x)) + assert maple_code(1 / (g(x) * 3.5) ** (x - y ** x) / (x ** 2 + y)) == \ + "(3.5*2*x)^(-x + y^x)/(x^2 + y)" + # For issue 14160 + assert maple_code(Mul(-2, x, Pow(Mul(y, y, evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + + +def test_basic_ops(): + assert maple_code(x * y) == "x*y" + assert maple_code(x + y) == "x + y" + assert maple_code(x - y) == "x - y" + assert maple_code(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert maple_code(1 / x) == '1/x' + assert maple_code(x ** -1) == maple_code(x ** -1.0) == '1/x' + assert maple_code(1 / sqrt(x)) == '1/sqrt(x)' + assert maple_code(x ** -S.Half) == maple_code(x ** -0.5) == '1/sqrt(x)' + assert maple_code(sqrt(x)) == 'sqrt(x)' + assert maple_code(x ** S.Half) == maple_code(x ** 0.5) == 'sqrt(x)' + assert maple_code(1 / pi) == '1/Pi' + assert maple_code(pi ** -1) == maple_code(pi ** -1.0) == '1/Pi' + assert maple_code(pi ** -0.5) == '1/sqrt(Pi)' + + +def test_mix_number_mult_symbols(): + assert maple_code(3 * x) == "3*x" + assert maple_code(pi * x) == "Pi*x" + assert maple_code(3 / x) == "3/x" + assert maple_code(pi / x) == "Pi/x" + assert maple_code(x / 3) == '(1/3)*x' + assert maple_code(x / pi) == "x/Pi" + assert maple_code(x * y) == "x*y" + assert maple_code(3 * x * y) == "3*x*y" + assert maple_code(3 * pi * x * y) == "3*Pi*x*y" + assert maple_code(x / y) == "x/y" + assert maple_code(3 * x / y) == "3*x/y" + assert maple_code(x * y / z) == "x*y/z" + assert maple_code(x / y * z) == "x*z/y" + assert maple_code(1 / x / y) == "1/(x*y)" + assert maple_code(2 * pi * x / y / z) == "2*Pi*x/(y*z)" + assert maple_code(3 * pi / x) == "3*Pi/x" + assert maple_code(S(3) / 5) == "3/5" + assert maple_code(S(3) / 5 * x) == '(3/5)*x' + assert maple_code(x / y / z) == "x/(y*z)" + assert maple_code((x + y) / z) == "(x + y)/z" + assert maple_code((x + y) / (z + x)) == "(x + y)/(x + z)" + assert maple_code((x + y) / EulerGamma) == '(x + y)/gamma' + assert maple_code(x / 3 / pi) == '(1/3)*x/Pi' + assert maple_code(S(3) / 5 * x * y / pi) == '(3/5)*x*y/Pi' + + +def test_mix_number_pow_symbols(): + assert maple_code(pi ** 3) == 'Pi^3' + assert maple_code(x ** 2) == 'x^2' + + assert maple_code(x ** (pi ** 3)) == 'x^(Pi^3)' + assert maple_code(x ** y) == 'x^y' + + assert maple_code(x ** (y ** z)) == 'x^(y^z)' + assert maple_code((x ** y) ** z) == '(x^y)^z' + + +def test_imag(): + I = S('I') + assert maple_code(I) == "I" + assert maple_code(5 * I) == "5*I" + + assert maple_code((S(3) / 2) * I) == "(3/2)*I" + assert maple_code(3 + 4 * I) == "3 + 4*I" + + +def test_constants(): + assert maple_code(pi) == "Pi" + assert maple_code(oo) == "infinity" + assert maple_code(-oo) == "-infinity" + assert maple_code(S.NegativeInfinity) == "-infinity" + assert maple_code(S.NaN) == "undefined" + assert maple_code(S.Exp1) == "exp(1)" + assert maple_code(exp(1)) == "exp(1)" + + +def test_constants_other(): + assert maple_code(2 * GoldenRatio) == '2*(1/2 + (1/2)*sqrt(5))' + assert maple_code(2 * Catalan) == '2*Catalan' + assert maple_code(2 * EulerGamma) == "2*gamma" + + +def test_boolean(): + assert maple_code(x & y) == "x and y" + assert maple_code(x | y) == "x or y" + assert maple_code(~x) == "not x" + assert maple_code(x & y & z) == "x and y and z" + assert maple_code(x | y | z) == "x or y or z" + assert maple_code((x & y) | z) == "z or x and y" + assert maple_code((x | y) & z) == "z and (x or y)" + + +def test_Matrices(): + assert maple_code(Matrix(1, 1, [10])) == \ + 'Matrix([[10]], storage = rectangular)' + + A = Matrix([[1, sin(x / 2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = \ + 'Matrix(' \ + '[[1, sin((1/2)*x), abs(x)],' \ + ' [0, 1, Pi],' \ + ' [0, exp(1), ceil(x)]], ' \ + 'storage = rectangular)' + assert maple_code(A) == expected + + # row and columns + assert maple_code(A[:, 0]) == \ + 'Matrix([[1], [0], [0]], storage = rectangular)' + assert maple_code(A[0, :]) == \ + 'Matrix([[1, sin((1/2)*x), abs(x)]], storage = rectangular)' + assert maple_code(Matrix([[x, x - y, -y]])) == \ + 'Matrix([[x, x - y, -y]], storage = rectangular)' + + # empty matrices + assert maple_code(Matrix(0, 0, [])) == \ + 'Matrix([], storage = rectangular)' + assert maple_code(Matrix(0, 3, [])) == \ + 'Matrix([], storage = rectangular)' + +def test_SparseMatrices(): + assert maple_code(SparseMatrix(Identity(2))) == 'Matrix([[1, 0], [0, 1]], storage = sparse)' + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2 / x), 3 * pi / x / 5]]) + assert maple_code(A) == \ + 'Matrix([[1, sin(2/x), (3/5)*Pi/x]], storage = rectangular)' + assert maple_code(A.T) == \ + 'Matrix([[1], [sin(2/x)], [(3/5)*Pi/x]], storage = rectangular)' + + +def test_Matrices_entries_not_hadamard(): + A = Matrix([[1, sin(2 / x), 3 * pi / x / 5], [1, 2, x * y]]) + expected = \ + 'Matrix([[1, sin(2/x), (3/5)*Pi/x], [1, 2, x*y]], ' \ + 'storage = rectangular)' + assert maple_code(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert maple_code(A * B) == "A.B" + assert maple_code(B * A) == "B.A" + assert maple_code(2 * A * B) == "2*A.B" + assert maple_code(B * 2 * A) == "2*B.A" + + assert maple_code( + A * (B + 3 * Identity(n))) == "A.(3*Matrix(n, shape = identity) + B)" + + assert maple_code(A ** (x ** 2)) == "MatrixPower(A, x^2)" + assert maple_code(A ** 3) == "MatrixPower(A, 3)" + assert maple_code(A ** (S.Half)) == "MatrixPower(A, 1/2)" + + +def test_special_matrices(): + assert maple_code(6 * Identity(3)) == "6*Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = sparse)" + assert maple_code(Identity(x)) == 'Matrix(x, shape = identity)' + + +def test_containers(): + assert maple_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "[1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]" + + assert maple_code((1, 2, (3, 4))) == "[1, 2, [3, 4]]" + assert maple_code([1]) == "[1]" + assert maple_code((1,)) == "[1]" + assert maple_code(Tuple(*[1, 2, 3])) == "[1, 2, 3]" + assert maple_code((1, x * y, (3, x ** 2))) == "[1, x*y, [3, x^2]]" + # scalar, matrix, empty matrix and empty list + + assert maple_code((1, eye(3), Matrix(0, 0, []), [])) == \ + "[1, Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = rectangular), Matrix([], storage = rectangular), []]" + + +def test_maple_noninline(): + source = maple_code((x + y)/Catalan, assign_to='me', inline=False) + expected = "me := (x + y)/Catalan" + + assert source == expected + + +def test_maple_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert maple_code(A, assign_to='a') == "a := Matrix([[1, 2, 3]], storage = rectangular)" + A = Matrix([[1, 2], [3, 4]]) + assert maple_code(A, assign_to='A') == "A := Matrix([[1, 2], [3, 4]], storage = rectangular)" + + +def test_maple_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert maple_code(A, assign_to=B) == "B := Matrix([[1, 2, 3]], storage = rectangular)" + raises(ValueError, lambda: maple_code(A, assign_to=x)) + raises(ValueError, lambda: maple_code(A, assign_to=C)) + + +def test_maple_matrix_1x1(): + A = Matrix([[3]]) + assert maple_code(A, assign_to='B') == "B := Matrix([[3]], storage = rectangular)" + + +def test_maple_matrix_elements(): + A = Matrix([[x, 2, x * y]]) + + assert maple_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x^2 + x*y + 2" + AA = MatrixSymbol('AA', 1, 3) + assert maple_code(AA) == "AA" + + assert maple_code(AA[0, 0] ** 2 + sin(AA[0, 1]) + AA[0, 2]) == \ + "sin(AA[1, 2]) + AA[1, 1]^2 + AA[1, 3]" + assert maple_code(sum(AA)) == "AA[1, 1] + AA[1, 2] + AA[1, 3]" + + +def test_maple_boolean(): + assert maple_code(True) == "true" + assert maple_code(S.true) == "true" + assert maple_code(False) == "false" + assert maple_code(S.false) == "false" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x * y + assert maple_code(M) == \ + 'Matrix([[0, 0, 0, 30, 0, 0],' \ + ' [0, 0, 20, 22, 0, 0],' \ + ' [0, 0, 10, 0, 0, 0],' \ + ' [x*y, 0, 0, 0, 0, 0],' \ + ' [0, 0, 0, 0, 0, 0]], ' \ + 'storage = sparse)' + +# Not an important point. +def test_maple_not_supported(): + with raises(NotImplementedError): + maple_code(S.ComplexInfinity) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + + assert (maple_code(A[0, 0]) == "A[1, 1]") + assert (maple_code(3 * A[0, 0]) == "3*A[1, 1]") + + F = A-B + + assert (maple_code(F[0,0]) == "A[1, 1] - B[1, 1]") + + +def test_hadamard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + assert maple_code(C) == "A*B" + + assert maple_code(C * v) == "(A*B).v" + # HadamardProduct is higher than dot product. + + assert maple_code(h * C * v) == "h.(A*B).v" + + assert maple_code(C * A) == "(A*B).A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + + assert maple_code(C * x * y) == "x*y*(A*B)" + + +def test_maple_piecewise(): + expr = Piecewise((x, x < 1), (x ** 2, True)) + + assert maple_code(expr) == "piecewise(x < 1, x, x^2)" + assert maple_code(expr, assign_to="r") == ( + "r := piecewise(x < 1, x, x^2)") + + expr = Piecewise((x ** 2, x < 1), (x ** 3, x < 2), (x ** 4, x < 3), (x ** 5, True)) + expected = "piecewise(x < 1, x^2, x < 2, x^3, x < 3, x^4, x^5)" + assert maple_code(expr) == expected + assert maple_code(expr, assign_to="r") == "r := " + expected + + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: maple_code(expr)) + + +def test_maple_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x ** 2, True)) + + assert maple_code(2 * pw) == "2*piecewise(x < 1, x, x^2)" + assert maple_code(pw / x) == "piecewise(x < 1, x, x^2)/x" + assert maple_code(pw / (x * y)) == "piecewise(x < 1, x, x^2)/(x*y)" + assert maple_code(pw / 3) == "(1/3)*piecewise(x < 1, x, x^2)" + + +def test_maple_derivatives(): + f = Function('f') + assert maple_code(f(x).diff(x)) == 'diff(f(x), x)' + assert maple_code(f(x).diff(x, 2)) == 'diff(f(x), x$2)' + + +def test_automatic_rewrites(): + assert maple_code(lucas(x)) == '(2^(-x)*((1 - sqrt(5))^x + (1 + sqrt(5))^x))' + assert maple_code(sinc(x)) == '(piecewise(x <> 0, sin(x)/x, 1))' + + +def test_specfun(): + assert maple_code('asin(x)') == 'arcsin(x)' + assert maple_code(besseli(x, y)) == 'BesselI(x, y)' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_mathematica.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf6b537677442ae59a4f1bbd2b5774d6646f4e2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_mathematica.py @@ -0,0 +1,287 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, Tuple, + Derivative, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.integrals import Integral +from sympy.concrete import Sum +from sympy.functions import (exp, sin, cos, fresnelc, fresnels, conjugate, Max, + Min, gamma, polygamma, loggamma, erf, erfi, erfc, + erf2, expint, erfinv, erfcinv, Ei, Si, Ci, li, + Shi, Chi, uppergamma, beta, subfactorial, erf2inv, + factorial, factorial2, catalan, RisingFactorial, + FallingFactorial, harmonic, atan2, sec, acsc, + hermite, laguerre, assoc_laguerre, jacobi, + gegenbauer, chebyshevt, chebyshevu, legendre, + assoc_legendre, Li, LambertW) + +from sympy.printing.mathematica import mathematica_code as mcode + +x, y, z, w = symbols('x,y,z,w') +f = Function('f') + + +def test_Integer(): + assert mcode(Integer(67)) == "67" + assert mcode(Integer(-1)) == "-1" + + +def test_Rational(): + assert mcode(Rational(3, 7)) == "3/7" + assert mcode(Rational(18, 9)) == "2" + assert mcode(Rational(3, -7)) == "-3/7" + assert mcode(Rational(-3, -7)) == "3/7" + assert mcode(x + Rational(3, 7)) == "x + 3/7" + assert mcode(Rational(3, 7)*x) == "(3/7)*x" + + +def test_Relational(): + assert mcode(Eq(x, y)) == "x == y" + assert mcode(Ne(x, y)) == "x != y" + assert mcode(Le(x, y)) == "x <= y" + assert mcode(Lt(x, y)) == "x < y" + assert mcode(Gt(x, y)) == "x > y" + assert mcode(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert mcode(f(x, y, z)) == "f[x, y, z]" + assert mcode(sin(x) ** cos(x)) == "Sin[x]^Cos[x]" + assert mcode(sec(x) * acsc(x)) == "ArcCsc[x]*Sec[x]" + assert mcode(atan2(y, x)) == "ArcTan[x, y]" + assert mcode(conjugate(x)) == "Conjugate[x]" + assert mcode(Max(x, y, z)*Min(y, z)) == "Max[x, y, z]*Min[y, z]" + assert mcode(fresnelc(x)) == "FresnelC[x]" + assert mcode(fresnels(x)) == "FresnelS[x]" + assert mcode(gamma(x)) == "Gamma[x]" + assert mcode(uppergamma(x, y)) == "Gamma[x, y]" + assert mcode(polygamma(x, y)) == "PolyGamma[x, y]" + assert mcode(loggamma(x)) == "LogGamma[x]" + assert mcode(erf(x)) == "Erf[x]" + assert mcode(erfc(x)) == "Erfc[x]" + assert mcode(erfi(x)) == "Erfi[x]" + assert mcode(erf2(x, y)) == "Erf[x, y]" + assert mcode(expint(x, y)) == "ExpIntegralE[x, y]" + assert mcode(erfcinv(x)) == "InverseErfc[x]" + assert mcode(erfinv(x)) == "InverseErf[x]" + assert mcode(erf2inv(x, y)) == "InverseErf[x, y]" + assert mcode(Ei(x)) == "ExpIntegralEi[x]" + assert mcode(Ci(x)) == "CosIntegral[x]" + assert mcode(li(x)) == "LogIntegral[x]" + assert mcode(Si(x)) == "SinIntegral[x]" + assert mcode(Shi(x)) == "SinhIntegral[x]" + assert mcode(Chi(x)) == "CoshIntegral[x]" + assert mcode(beta(x, y)) == "Beta[x, y]" + assert mcode(factorial(x)) == "Factorial[x]" + assert mcode(factorial2(x)) == "Factorial2[x]" + assert mcode(subfactorial(x)) == "Subfactorial[x]" + assert mcode(FallingFactorial(x, y)) == "FactorialPower[x, y]" + assert mcode(RisingFactorial(x, y)) == "Pochhammer[x, y]" + assert mcode(catalan(x)) == "CatalanNumber[x]" + assert mcode(harmonic(x)) == "HarmonicNumber[x]" + assert mcode(harmonic(x, y)) == "HarmonicNumber[x, y]" + assert mcode(Li(x)) == "LogIntegral[x] - LogIntegral[2]" + assert mcode(LambertW(x)) == "ProductLog[x]" + assert mcode(LambertW(x, -1)) == "ProductLog[-1, x]" + assert mcode(LambertW(x, y)) == "ProductLog[y, x]" + + +def test_special_polynomials(): + assert mcode(hermite(x, y)) == "HermiteH[x, y]" + assert mcode(laguerre(x, y)) == "LaguerreL[x, y]" + assert mcode(assoc_laguerre(x, y, z)) == "LaguerreL[x, y, z]" + assert mcode(jacobi(x, y, z, w)) == "JacobiP[x, y, z, w]" + assert mcode(gegenbauer(x, y, z)) == "GegenbauerC[x, y, z]" + assert mcode(chebyshevt(x, y)) == "ChebyshevT[x, y]" + assert mcode(chebyshevu(x, y)) == "ChebyshevU[x, y]" + assert mcode(legendre(x, y)) == "LegendreP[x, y]" + assert mcode(assoc_legendre(x, y, z)) == "LegendreP[x, y, z]" + + +def test_Pow(): + assert mcode(x**3) == "x^3" + assert mcode(x**(y**3)) == "x^(y^3)" + assert mcode(1/(f(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*f[x])^(-x + y^x)/(x^2 + y)" + assert mcode(x**-1.0) == 'x^(-1.0)' + assert mcode(x**Rational(2, 3)) == 'x^(2/3)' + + +def test_Mul(): + A, B, C, D = symbols('A B C D', commutative=False) + assert mcode(x*y*z) == "x*y*z" + assert mcode(x*y*A) == "x*y*A" + assert mcode(x*y*A*B) == "x*y*A**B" + assert mcode(x*y*A*B*C) == "x*y*A**B**C" + assert mcode(x*A*B*(C + D)*A*y) == "x*y*A**B**(C + D)**A" + + +def test_constants(): + assert mcode(S.Zero) == "0" + assert mcode(S.One) == "1" + assert mcode(S.NegativeOne) == "-1" + assert mcode(S.Half) == "1/2" + assert mcode(S.ImaginaryUnit) == "I" + + assert mcode(oo) == "Infinity" + assert mcode(S.NegativeInfinity) == "-Infinity" + assert mcode(S.ComplexInfinity) == "ComplexInfinity" + assert mcode(S.NaN) == "Indeterminate" + + assert mcode(S.Exp1) == "E" + assert mcode(pi) == "Pi" + assert mcode(S.GoldenRatio) == "GoldenRatio" + assert mcode(S.TribonacciConstant) == \ + "(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \ + "(1/3)*(3*33^(1/2) + 19)^(1/3))" + assert mcode(2*S.TribonacciConstant) == \ + "2*(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \ + "(1/3)*(3*33^(1/2) + 19)^(1/3))" + assert mcode(S.EulerGamma) == "EulerGamma" + assert mcode(S.Catalan) == "Catalan" + + +def test_containers(): + assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}" + assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}" + assert mcode([1]) == "{1}" + assert mcode((1,)) == "{1}" + assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}" + + +def test_matrices(): + from sympy.matrices import MutableDenseMatrix, MutableSparseMatrix, \ + ImmutableDenseMatrix, ImmutableSparseMatrix + A = MutableDenseMatrix( + [[1, -1, 0, 0], + [0, 1, -1, 0], + [0, 0, 1, -1], + [0, 0, 0, 1]] + ) + B = MutableSparseMatrix(A) + C = ImmutableDenseMatrix(A) + D = ImmutableSparseMatrix(A) + + assert mcode(C) == mcode(A) == \ + "{{1, -1, 0, 0}, " \ + "{0, 1, -1, 0}, " \ + "{0, 0, 1, -1}, " \ + "{0, 0, 0, 1}}" + + assert mcode(D) == mcode(B) == \ + "SparseArray[{" \ + "{1, 1} -> 1, {1, 2} -> -1, {2, 2} -> 1, {2, 3} -> -1, " \ + "{3, 3} -> 1, {3, 4} -> -1, {4, 4} -> 1" \ + "}, {4, 4}]" + + # Trivial cases of matrices + assert mcode(MutableDenseMatrix(0, 0, [])) == '{}' + assert mcode(MutableSparseMatrix(0, 0, [])) == 'SparseArray[{}, {0, 0}]' + assert mcode(MutableDenseMatrix(0, 3, [])) == '{}' + assert mcode(MutableSparseMatrix(0, 3, [])) == 'SparseArray[{}, {0, 3}]' + assert mcode(MutableDenseMatrix(3, 0, [])) == '{{}, {}, {}}' + assert mcode(MutableSparseMatrix(3, 0, [])) == 'SparseArray[{}, {3, 0}]' + +def test_NDArray(): + from sympy.tensor.array import ( + MutableDenseNDimArray, ImmutableDenseNDimArray, + MutableSparseNDimArray, ImmutableSparseNDimArray) + + example = MutableDenseNDimArray( + [[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]], + [[13, 14, 15, 16], + [17, 18, 19, 20], + [21, 22, 23, 24]]] + ) + + assert mcode(example) == \ + "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \ + "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}" + + example = ImmutableDenseNDimArray(example) + + assert mcode(example) == \ + "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \ + "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}" + + example = MutableSparseNDimArray(example) + + assert mcode(example) == \ + "SparseArray[{" \ + "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \ + "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \ + "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \ + "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \ + "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \ + "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \ + "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \ + "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \ + "}, {2, 3, 4}]" + + example = ImmutableSparseNDimArray(example) + + assert mcode(example) == \ + "SparseArray[{" \ + "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \ + "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \ + "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \ + "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \ + "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \ + "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \ + "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \ + "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \ + "}, {2, 3, 4}]" + + +def test_Integral(): + assert mcode(Integral(sin(sin(x)), x)) == "Hold[Integrate[Sin[Sin[x]], x]]" + assert mcode(Integral(exp(-x**2 - y**2), + (x, -oo, oo), + (y, -oo, oo))) == \ + "Hold[Integrate[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \ + "{y, -Infinity, Infinity}]]" + + +def test_Derivative(): + assert mcode(Derivative(sin(x), x)) == "Hold[D[Sin[x], x]]" + assert mcode(Derivative(x, x)) == "Hold[D[x, x]]" + assert mcode(Derivative(sin(x)*y**4, x, 2)) == "Hold[D[y^4*Sin[x], {x, 2}]]" + assert mcode(Derivative(sin(x)*y**4, x, y, x)) == "Hold[D[y^4*Sin[x], x, y, x]]" + assert mcode(Derivative(sin(x)*y**4, x, y, 3, x)) == "Hold[D[y^4*Sin[x], x, {y, 3}, x]]" + + +def test_Sum(): + assert mcode(Sum(sin(x), (x, 0, 10))) == "Hold[Sum[Sin[x], {x, 0, 10}]]" + assert mcode(Sum(exp(-x**2 - y**2), + (x, -oo, oo), + (y, -oo, oo))) == \ + "Hold[Sum[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \ + "{y, -Infinity, Infinity}]]" + + +def test_comment(): + from sympy.printing.mathematica import MCodePrinter + assert MCodePrinter()._get_comment("Hello World") == \ + "(* Hello World *)" + + +def test_userfuncs(): + # Dictionary mutation test + some_function = symbols("some_function", cls=Function) + my_user_functions = {"some_function": "SomeFunction"} + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeFunction[z]' + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeFunction[z]' + + # List argument test + my_user_functions = \ + {"some_function": [(lambda x: True, "SomeOtherFunction")]} + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeOtherFunction[z]' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_numpy.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..fee1c6bd95e54790a048220f37b8e5de79017d2f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_numpy.py @@ -0,0 +1,381 @@ +from sympy.concrete.summations import Sum +from sympy.core.mod import Mod +from sympy.core.relational import (Equality, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.gamma_functions import polygamma +from sympy.functions.special.error_functions import (Si, Ci) +from sympy.matrices import Matrix +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.utilities.lambdify import lambdify +from sympy import symbols, Min, Max + +from sympy.abc import x, i, j, a, b, c, d +from sympy.core import Pow +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt +from sympy.tensor.array import Array +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.printing.numpy import NumPyPrinter, SciPyPrinter, _numpy_known_constants, \ + _numpy_known_functions, _scipy_known_constants, _scipy_known_functions +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +np = import_module('numpy') +jax = import_module('jax') + +if np: + deafult_float_info = np.finfo(np.array([]).dtype) + NUMPY_DEFAULT_EPSILON = deafult_float_info.eps + +def test_numpy_piecewise_regression(): + """ + NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid + breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+. + See gh-9747 and gh-9749 for details. + """ + printer = NumPyPrinter() + p = Piecewise((1, x < 0), (0, True)) + assert printer.doprint(p) == \ + 'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)' + assert printer.module_imports == {'numpy': {'select', 'less', 'nan'}} + +def test_numpy_logaddexp(): + lae = logaddexp(a, b) + assert NumPyPrinter().doprint(lae) == 'numpy.logaddexp(a, b)' + lae2 = logaddexp2(a, b) + assert NumPyPrinter().doprint(lae2) == 'numpy.logaddexp2(a, b)' + + +def test_sum(): + if not np: + skip("NumPy not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + + +def test_multiple_sums(): + if not np: + skip("NumPy not installed") + + s = Sum((x + j) * i, (i, a, b), (j, c, d)) + f = lambdify((a, b, c, d, x), s, 'numpy') + + a_, b_ = 0, 10 + c_, d_ = 11, 21 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, c_, d_, x_), + sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1))) + + +def test_codegen_einsum(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'numpy') + + ma = np.array([[1, 2], [3, 4]]) + mb = np.array([[1,-2], [-1, 3]]) + assert (f(ma, mb) == np.matmul(ma, mb)).all() + + +def test_codegen_extra(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = np.array([[1, 2], [3, 4]]) + mb = np.array([[1,-2], [-1, 3]]) + mc = np.array([[2, 0], [1, 2]]) + md = np.array([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.einsum(ma, [0, 1], mb, [2, 3])).all() + + cg = ArrayAdd(M, N) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == ma+mb).all() + + cg = ArrayAdd(M, N, P) + f = lambdify((M, N, P), cg, 'numpy') + assert (f(ma, mb, mc) == ma+mb+mc).all() + + cg = ArrayAdd(M, N, P, Q) + f = lambdify((M, N, P, Q), cg, 'numpy') + assert (f(ma, mb, mc, md) == ma+mb+mc+md).all() + + cg = PermuteDims(M, [1, 0]) + f = lambdify((M,), cg, 'numpy') + assert (f(ma) == ma.T).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.transpose(np.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.diagonal(np.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all() + + +def test_relational(): + if not np: + skip("NumPy not installed") + + e = Equality(x, 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, True, False]) + + e = Unequality(x, 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, False, True]) + + e = (x < 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, False, False]) + + e = (x <= 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, True, False]) + + e = (x > 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, False, True]) + + e = (x >= 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, True, True]) + + +def test_mod(): + if not np: + skip("NumPy not installed") + + e = Mod(a, b) + f = lambdify((a, b), e) + + a_ = np.array([0, 1, 2, 3]) + b_ = 2 + assert np.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = np.array([0, 1, 2, 3]) + b_ = np.array([2, 2, 2, 2]) + assert np.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = np.array([2, 3, 4, 5]) + b_ = np.array([2, 3, 4, 5]) + assert np.array_equal(f(a_, b_), [0, 0, 0, 0]) + + +def test_pow(): + if not np: + skip('NumPy not installed') + + expr = Pow(2, -1, evaluate=False) + f = lambdify([], expr, 'numpy') + assert f() == 0.5 + + +def test_expm1(): + if not np: + skip("NumPy not installed") + + f = lambdify((a,), expm1(a), 'numpy') + assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * NUMPY_DEFAULT_EPSILON + + +def test_log1p(): + if not np: + skip("NumPy not installed") + + f = lambdify((a,), log1p(a), 'numpy') + assert abs(f(1e-99) - 1e-99) <= 1e-99 * NUMPY_DEFAULT_EPSILON + +def test_hypot(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a, b), hypot(a, b), 'numpy')(3, 4) - 5) <= NUMPY_DEFAULT_EPSILON + +def test_log10(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), log10(a), 'numpy')(100) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_exp2(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), exp2(a), 'numpy')(5) - 32) <= NUMPY_DEFAULT_EPSILON + + +def test_log2(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), log2(a), 'numpy')(256) - 8) <= NUMPY_DEFAULT_EPSILON + + +def test_Sqrt(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), Sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_sqrt(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_matsolve(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 3, 3) + x = MatrixSymbol("x", 3, 1) + + expr = M**(-1) * x + x + matsolve_expr = MatrixSolve(M, x) + x + + f = lambdify((M, x), expr) + f_matsolve = lambdify((M, x), matsolve_expr) + + m0 = np.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]]) + assert np.linalg.matrix_rank(m0) == 3 + + x0 = np.array([3, 4, 5]) + + assert np.allclose(f_matsolve(m0, x0), f(m0, x0)) + + +def test_16857(): + if not np: + skip("NumPy not installed") + + a_1 = MatrixSymbol('a_1', 10, 3) + a_2 = MatrixSymbol('a_2', 10, 3) + a_3 = MatrixSymbol('a_3', 10, 3) + a_4 = MatrixSymbol('a_4', 10, 3) + A = BlockMatrix([[a_1, a_2], [a_3, a_4]]) + assert A.shape == (20, 6) + + printer = NumPyPrinter() + assert printer.doprint(A) == 'numpy.block([[a_1, a_2], [a_3, a_4]])' + + +def test_issue_17006(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + + f = lambdify(M, M + Identity(2)) + ma = np.array([[1, 2], [3, 4]]) + mr = np.array([[2, 2], [3, 5]]) + + assert (f(ma) == mr).all() + + from sympy.core.symbol import symbols + n = symbols('n', integer=True) + N = MatrixSymbol("M", n, n) + raises(NotImplementedError, lambda: lambdify(N, N + Identity(n))) + +def test_jax_tuple_compatibility(): + if not jax: + skip("Jax not installed") + + x, y, z = symbols('x y z') + expr = Max(x, y, z) + Min(x, y, z) + func = lambdify((x, y, z), expr, 'jax') + input_tuple1, input_tuple2 = (1, 2, 3), (4, 5, 6) + input_array1, input_array2 = jax.numpy.asarray(input_tuple1), jax.numpy.asarray(input_tuple2) + assert np.allclose(func(*input_tuple1), func(*input_array1)) + assert np.allclose(func(*input_tuple2), func(*input_array2)) + +def test_numpy_array(): + p = NumPyPrinter() + assert p.doprint(Array([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])' + assert p.doprint(Array([1, 2])) == 'numpy.array([1, 2])' + assert p.doprint(Array([[[1, 2, 3]]])) == 'numpy.array([[[1, 2, 3]]])' + assert p.doprint(Array([], (0,))) == 'numpy.zeros((0,))' + assert p.doprint(Array([], (0, 0))) == 'numpy.zeros((0, 0))' + assert p.doprint(Array([], (0, 1))) == 'numpy.zeros((0, 1))' + assert p.doprint(Array([], (1, 0))) == 'numpy.zeros((1, 0))' + assert p.doprint(Array([1], ())) == 'numpy.array(1)' + +def test_numpy_matrix(): + p = NumPyPrinter() + assert p.doprint(Matrix([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])' + assert p.doprint(Matrix([1, 2])) == 'numpy.array([[1], [2]])' + assert p.doprint(Matrix(0, 0, [])) == 'numpy.zeros((0, 0))' + assert p.doprint(Matrix(0, 1, [])) == 'numpy.zeros((0, 1))' + assert p.doprint(Matrix(1, 0, [])) == 'numpy.zeros((1, 0))' + +def test_numpy_known_funcs_consts(): + assert _numpy_known_constants['NaN'] == 'numpy.nan' + assert _numpy_known_constants['EulerGamma'] == 'numpy.euler_gamma' + + assert _numpy_known_functions['acos'] == 'numpy.arccos' + assert _numpy_known_functions['log'] == 'numpy.log' + +def test_scipy_known_funcs_consts(): + assert _scipy_known_constants['GoldenRatio'] == 'scipy.constants.golden_ratio' + assert _scipy_known_constants['Pi'] == 'scipy.constants.pi' + + assert _scipy_known_functions['erf'] == 'scipy.special.erf' + assert _scipy_known_functions['factorial'] == 'scipy.special.factorial' + +def test_numpy_print_methods(): + prntr = NumPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + +def test_scipy_print_methods(): + prntr = SciPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + assert hasattr(prntr, '_print_erf') + assert hasattr(prntr, '_print_factorial') + assert hasattr(prntr, '_print_chebyshevt') + k = Symbol('k', integer=True, nonnegative=True) + x = Symbol('x', real=True) + assert prntr.doprint(polygamma(k, x)) == "scipy.special.polygamma(k, x)" + assert prntr.doprint(Si(x)) == "scipy.special.sici(x)[0]" + assert prntr.doprint(Ci(x)) == "scipy.special.sici(x)[1]" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_octave.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_octave.py new file mode 100644 index 0000000000000000000000000000000000000000..1aba318f873c48ec702f1b4e3a6cc047f75d647d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_octave.py @@ -0,0 +1,515 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, EulerGamma, GoldenRatio, Catalan, + Lambda, Mul, Pow, Mod, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.functions import (arg, atan2, bernoulli, beta, ceiling, chebyshevu, + chebyshevt, conjugate, DiracDelta, exp, expint, + factorial, floor, harmonic, Heaviside, im, + laguerre, LambertW, log, Max, Min, Piecewise, + polylog, re, RisingFactorial, sign, sinc, sqrt, + zeta, binomial, legendre, dirichlet_eta, + riemann_xi) +from sympy.functions import (sin, cos, tan, cot, sec, csc, asin, acos, acot, + atan, asec, acsc, sinh, cosh, tanh, coth, csch, + sech, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.testing.pytest import raises, XFAIL +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix, HadamardPower) +from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli, + besselk, hankel1, hankel2, airyai, + airybi, airyaiprime, airybiprime) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, + uppergamma, loggamma, + polygamma) +from sympy.functions.special.error_functions import (Chi, Ci, erf, erfc, erfi, + erfcinv, erfinv, fresnelc, + fresnels, li, Shi, Si, Li, + erf2, Ei) +from sympy.printing.octave import octave_code, octave_code as mcode + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert mcode(Integer(67)) == "67" + assert mcode(Integer(-1)) == "-1" + + +def test_Rational(): + assert mcode(Rational(3, 7)) == "3/7" + assert mcode(Rational(18, 9)) == "2" + assert mcode(Rational(3, -7)) == "-3/7" + assert mcode(Rational(-3, -7)) == "3/7" + assert mcode(x + Rational(3, 7)) == "x + 3/7" + assert mcode(Rational(3, 7)*x) == "3*x/7" + + +def test_Relational(): + assert mcode(Eq(x, y)) == "x == y" + assert mcode(Ne(x, y)) == "x != y" + assert mcode(Le(x, y)) == "x <= y" + assert mcode(Lt(x, y)) == "x < y" + assert mcode(Gt(x, y)) == "x > y" + assert mcode(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert mcode(sin(x) ** cos(x)) == "sin(x).^cos(x)" + assert mcode(sign(x)) == "sign(x)" + assert mcode(exp(x)) == "exp(x)" + assert mcode(log(x)) == "log(x)" + assert mcode(factorial(x)) == "factorial(x)" + assert mcode(floor(x)) == "floor(x)" + assert mcode(atan2(y, x)) == "atan2(y, x)" + assert mcode(beta(x, y)) == 'beta(x, y)' + assert mcode(polylog(x, y)) == 'polylog(x, y)' + assert mcode(harmonic(x)) == 'harmonic(x)' + assert mcode(bernoulli(x)) == "bernoulli(x)" + assert mcode(bernoulli(x, y)) == "bernoulli(x, y)" + assert mcode(legendre(x, y)) == "legendre(x, y)" + + +def test_Function_change_name(): + assert mcode(abs(x)) == "abs(x)" + assert mcode(ceiling(x)) == "ceil(x)" + assert mcode(arg(x)) == "angle(x)" + assert mcode(im(x)) == "imag(x)" + assert mcode(re(x)) == "real(x)" + assert mcode(conjugate(x)) == "conj(x)" + assert mcode(chebyshevt(y, x)) == "chebyshevT(y, x)" + assert mcode(chebyshevu(y, x)) == "chebyshevU(y, x)" + assert mcode(laguerre(x, y)) == "laguerreL(x, y)" + assert mcode(Chi(x)) == "coshint(x)" + assert mcode(Shi(x)) == "sinhint(x)" + assert mcode(Ci(x)) == "cosint(x)" + assert mcode(Si(x)) == "sinint(x)" + assert mcode(li(x)) == "logint(x)" + assert mcode(loggamma(x)) == "gammaln(x)" + assert mcode(polygamma(x, y)) == "psi(x, y)" + assert mcode(RisingFactorial(x, y)) == "pochhammer(x, y)" + assert mcode(DiracDelta(x)) == "dirac(x)" + assert mcode(DiracDelta(x, 3)) == "dirac(3, x)" + assert mcode(Heaviside(x)) == "heaviside(x, 1/2)" + assert mcode(Heaviside(x, y)) == "heaviside(x, y)" + assert mcode(binomial(x, y)) == "bincoeff(x, y)" + assert mcode(Mod(x, y)) == "mod(x, y)" + + +def test_minmax(): + assert mcode(Max(x, y) + Min(x, y)) == "max(x, y) + min(x, y)" + assert mcode(Max(x, y, z)) == "max(x, max(y, z))" + assert mcode(Min(x, y, z)) == "min(x, min(y, z))" + + +def test_Pow(): + assert mcode(x**3) == "x.^3" + assert mcode(x**(y**3)) == "x.^(y.^3)" + assert mcode(x**Rational(2, 3)) == 'x.^(2/3)' + g = implemented_function('g', Lambda(x, 2*x)) + assert mcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2*x).^(-x + y.^x)./(x.^2 + y)" + # For issue 14160 + assert mcode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x./(y.*y)' + + +def test_basic_ops(): + assert mcode(x*y) == "x.*y" + assert mcode(x + y) == "x + y" + assert mcode(x - y) == "x - y" + assert mcode(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert mcode(1/x) == '1./x' + assert mcode(x**-1) == mcode(x**-1.0) == '1./x' + assert mcode(1/sqrt(x)) == '1./sqrt(x)' + assert mcode(x**-S.Half) == mcode(x**-0.5) == '1./sqrt(x)' + assert mcode(sqrt(x)) == 'sqrt(x)' + assert mcode(x**S.Half) == mcode(x**0.5) == 'sqrt(x)' + assert mcode(1/pi) == '1/pi' + assert mcode(pi**-1) == mcode(pi**-1.0) == '1/pi' + assert mcode(pi**-0.5) == '1/sqrt(pi)' + + +def test_mix_number_mult_symbols(): + assert mcode(3*x) == "3*x" + assert mcode(pi*x) == "pi*x" + assert mcode(3/x) == "3./x" + assert mcode(pi/x) == "pi./x" + assert mcode(x/3) == "x/3" + assert mcode(x/pi) == "x/pi" + assert mcode(x*y) == "x.*y" + assert mcode(3*x*y) == "3*x.*y" + assert mcode(3*pi*x*y) == "3*pi*x.*y" + assert mcode(x/y) == "x./y" + assert mcode(3*x/y) == "3*x./y" + assert mcode(x*y/z) == "x.*y./z" + assert mcode(x/y*z) == "x.*z./y" + assert mcode(1/x/y) == "1./(x.*y)" + assert mcode(2*pi*x/y/z) == "2*pi*x./(y.*z)" + assert mcode(3*pi/x) == "3*pi./x" + assert mcode(S(3)/5) == "3/5" + assert mcode(S(3)/5*x) == "3*x/5" + assert mcode(x/y/z) == "x./(y.*z)" + assert mcode((x+y)/z) == "(x + y)./z" + assert mcode((x+y)/(z+x)) == "(x + y)./(x + z)" + assert mcode((x+y)/EulerGamma) == "(x + y)/%s" % EulerGamma.evalf(17) + assert mcode(x/3/pi) == "x/(3*pi)" + assert mcode(S(3)/5*x*y/pi) == "3*x.*y/(5*pi)" + + +def test_mix_number_pow_symbols(): + assert mcode(pi**3) == 'pi^3' + assert mcode(x**2) == 'x.^2' + assert mcode(x**(pi**3)) == 'x.^(pi^3)' + assert mcode(x**y) == 'x.^y' + assert mcode(x**(y**z)) == 'x.^(y.^z)' + assert mcode((x**y)**z) == '(x.^y).^z' + + +def test_imag(): + I = S('I') + assert mcode(I) == "1i" + assert mcode(5*I) == "5i" + assert mcode((S(3)/2)*I) == "3*1i/2" + assert mcode(3+4*I) == "3 + 4i" + assert mcode(sqrt(3)*I) == "sqrt(3)*1i" + + +def test_constants(): + assert mcode(pi) == "pi" + assert mcode(oo) == "inf" + assert mcode(-oo) == "-inf" + assert mcode(S.NegativeInfinity) == "-inf" + assert mcode(S.NaN) == "NaN" + assert mcode(S.Exp1) == "exp(1)" + assert mcode(exp(1)) == "exp(1)" + + +def test_constants_other(): + assert mcode(2*GoldenRatio) == "2*(1+sqrt(5))/2" + assert mcode(2*Catalan) == "2*%s" % Catalan.evalf(17) + assert mcode(2*EulerGamma) == "2*%s" % EulerGamma.evalf(17) + + +def test_boolean(): + assert mcode(x & y) == "x & y" + assert mcode(x | y) == "x | y" + assert mcode(~x) == "~x" + assert mcode(x & y & z) == "x & y & z" + assert mcode(x | y | z) == "x | y | z" + assert mcode((x & y) | z) == "z | x & y" + assert mcode((x | y) & z) == "z & (x | y)" + + +def test_KroneckerDelta(): + from sympy.functions import KroneckerDelta + assert mcode(KroneckerDelta(x, y)) == "double(x == y)" + assert mcode(KroneckerDelta(x, y + 1)) == "double(x == (y + 1))" + assert mcode(KroneckerDelta(2**x, y)) == "double((2.^x) == y)" + + +def test_Matrices(): + assert mcode(Matrix(1, 1, [10])) == "10" + A = Matrix([[1, sin(x/2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = "[1 sin(x/2) abs(x); 0 1 pi; 0 exp(1) ceil(x)]" + assert mcode(A) == expected + # row and columns + assert mcode(A[:,0]) == "[1; 0; 0]" + assert mcode(A[0,:]) == "[1 sin(x/2) abs(x)]" + # empty matrices + assert mcode(Matrix(0, 0, [])) == '[]' + assert mcode(Matrix(0, 3, [])) == 'zeros(0, 3)' + # annoying to read but correct + assert mcode(Matrix([[x, x - y, -y]])) == "[x x - y -y]" + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2/x), 3*pi/x/5]]) + assert mcode(A) == "[1 sin(2./x) 3*pi./(5*x)]" + assert mcode(A.T) == "[1; sin(2./x); 3*pi./(5*x)]" + + +@XFAIL +def test_Matrices_entries_not_hadamard(): + # For Matrix with col >= 2, row >= 2, they need to be scalars + # FIXME: is it worth worrying about this? Its not wrong, just + # leave it user's responsibility to put scalar data for x. + A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]]) + expected = ("[1 sin(2/x) 3*pi/(5*x);\n" + "1 2 x*y]") # <- we give x.*y + assert mcode(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert mcode(A*B) == "A*B" + assert mcode(B*A) == "B*A" + assert mcode(2*A*B) == "2*A*B" + assert mcode(B*2*A) == "2*B*A" + assert mcode(A*(B + 3*Identity(n))) == "A*(3*eye(n) + B)" + assert mcode(A**(x**2)) == "A^(x.^2)" + assert mcode(A**3) == "A^3" + assert mcode(A**S.Half) == "A^(1/2)" + + +def test_MatrixSolve(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + x = MatrixSymbol('x', n, 1) + assert mcode(MatrixSolve(A, x)) == "A \\ x" + +def test_special_matrices(): + assert mcode(6*Identity(3)) == "6*eye(3)" + + +def test_containers(): + assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}" + assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}" + assert mcode([1]) == "{1}" + assert mcode((1,)) == "{1}" + assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}" + assert mcode((1, x*y, (3, x**2))) == "{1, x.*y, {3, x.^2}}" + # scalar, matrix, empty matrix and empty list + assert mcode((1, eye(3), Matrix(0, 0, []), [])) == "{1, [1 0 0; 0 1 0; 0 0 1], [], {}}" + + +def test_octave_noninline(): + source = mcode((x+y)/Catalan, assign_to='me', inline=False) + expected = ( + "Catalan = %s;\n" + "me = (x + y)/Catalan;" + ) % Catalan.evalf(17) + assert source == expected + + +def test_octave_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert mcode(expr) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))" + assert mcode(expr, assign_to="r") == ( + "r = ((x < 1).*(x) + (~(x < 1)).*(x.^2));") + assert mcode(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x;\n" + "else\n" + " r = x.^2;\n" + "end") + expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True)) + expected = ("((x < 1).*(x.^2) + (~(x < 1)).*( ...\n" + "(x < 2).*(x.^3) + (~(x < 2)).*( ...\n" + "(x < 3).*(x.^4) + (~(x < 3)).*(x.^5))))") + assert mcode(expr) == expected + assert mcode(expr, assign_to="r") == "r = " + expected + ";" + assert mcode(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x.^2;\n" + "elseif (x < 2)\n" + " r = x.^3;\n" + "elseif (x < 3)\n" + " r = x.^4;\n" + "else\n" + " r = x.^5;\n" + "end") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: mcode(expr)) + + +def test_octave_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x**2, True)) + assert mcode(2*pw) == "2*((x < 1).*(x) + (~(x < 1)).*(x.^2))" + assert mcode(pw/x) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./x" + assert mcode(pw/(x*y)) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./(x.*y)" + assert mcode(pw/3) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))/3" + + +def test_octave_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert mcode(A, assign_to='a') == "a = [1 2 3];" + A = Matrix([[1, 2], [3, 4]]) + assert mcode(A, assign_to='A') == "A = [1 2; 3 4];" + + +def test_octave_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert mcode(A, assign_to=B) == "B = [1 2 3];" + raises(ValueError, lambda: mcode(A, assign_to=x)) + raises(ValueError, lambda: mcode(A, assign_to=C)) + + +def test_octave_matrix_1x1(): + A = Matrix([[3]]) + B = MatrixSymbol('B', 1, 1) + C = MatrixSymbol('C', 1, 2) + assert mcode(A, assign_to=B) == "B = 3;" + # FIXME? + #assert mcode(A, assign_to=x) == "x = 3;" + raises(ValueError, lambda: mcode(A, assign_to=C)) + + +def test_octave_matrix_elements(): + A = Matrix([[x, 2, x*y]]) + assert mcode(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x.^2 + x.*y + 2" + A = MatrixSymbol('AA', 1, 3) + assert mcode(A) == "AA" + assert mcode(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \ + "sin(AA(1, 2)) + AA(1, 1).^2 + AA(1, 3)" + assert mcode(sum(A)) == "AA(1, 1) + AA(1, 2) + AA(1, 3)" + + +def test_octave_boolean(): + assert mcode(True) == "true" + assert mcode(S.true) == "true" + assert mcode(False) == "false" + assert mcode(S.false) == "false" + + +def test_octave_not_supported(): + with raises(NotImplementedError): + mcode(S.ComplexInfinity) + f = Function('f') + assert mcode(f(x).diff(x), strict=False) == ( + "% Not supported in Octave:\n" + "% Derivative\n" + "Derivative(f(x), x)" + ) + + +def test_octave_not_supported_not_on_whitelist(): + from sympy.functions.special.polynomials import assoc_laguerre + with raises(NotImplementedError): + mcode(assoc_laguerre(x, y, z)) + + +def test_octave_expint(): + assert mcode(expint(1, x)) == "expint(x)" + with raises(NotImplementedError): + mcode(expint(2, x)) + assert mcode(expint(y, x), strict=False) == ( + "% Not supported in Octave:\n" + "% expint\n" + "expint(y, x)" + ) + + +def test_trick_indent_with_end_else_words(): + # words starting with "end" or "else" do not confuse the indenter + t1 = S('endless') + t2 = S('elsewhere') + pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True)) + assert mcode(pw, inline=False) == ( + "if (x < 0)\n" + " endless\n" + "elseif (x <= 1)\n" + " elsewhere\n" + "else\n" + " 1\n" + "end") + + +def test_hadamard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + n = Symbol('n') + assert mcode(C) == "A.*B" + assert mcode(C*v) == "(A.*B)*v" + assert mcode(h*C*v) == "h*(A.*B)*v" + assert mcode(C*A) == "(A.*B)*A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + assert mcode(C*x*y) == "(x.*y)*(A.*B)" + + # Testing HadamardPower: + assert mcode(HadamardPower(A, n)) == "A.**n" + assert mcode(HadamardPower(A, 1+n)) == "A.**(n + 1)" + assert mcode(HadamardPower(A*B.T, 1+n)) == "(A*B.T).**(n + 1)" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x*y + assert mcode(M) == ( + "sparse([4 2 3 1 2], [1 3 3 4 4], [x.*y 20 10 30 22], 5, 6)" + ) + + +def test_sinc(): + assert mcode(sinc(x)) == 'sinc(x/pi)' + assert mcode(sinc(x + 3)) == 'sinc((x + 3)/pi)' + assert mcode(sinc(pi*(x + 3))) == 'sinc(x + 3)' + + +def test_trigfun(): + for f in (sin, cos, tan, cot, sec, csc, asin, acos, acot, atan, asec, acsc, + sinh, cosh, tanh, coth, csch, sech, asinh, acosh, atanh, acoth, + asech, acsch): + assert octave_code(f(x) == f.__name__ + '(x)') + + +def test_specfun(): + n = Symbol('n') + for f in [besselj, bessely, besseli, besselk]: + assert octave_code(f(n, x)) == f.__name__ + '(n, x)' + for f in (erfc, erfi, erf, erfinv, erfcinv, fresnelc, fresnels, gamma): + assert octave_code(f(x)) == f.__name__ + '(x)' + assert octave_code(hankel1(n, x)) == 'besselh(n, 1, x)' + assert octave_code(hankel2(n, x)) == 'besselh(n, 2, x)' + assert octave_code(airyai(x)) == 'airy(0, x)' + assert octave_code(airyaiprime(x)) == 'airy(1, x)' + assert octave_code(airybi(x)) == 'airy(2, x)' + assert octave_code(airybiprime(x)) == 'airy(3, x)' + assert octave_code(uppergamma(n, x)) == '(gammainc(x, n, \'upper\').*gamma(n))' + assert octave_code(lowergamma(n, x)) == '(gammainc(x, n).*gamma(n))' + assert octave_code(z**lowergamma(n, x)) == 'z.^(gammainc(x, n).*gamma(n))' + assert octave_code(jn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*besselj(n + 1/2, x)/2' + assert octave_code(yn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*bessely(n + 1/2, x)/2' + assert octave_code(LambertW(x)) == 'lambertw(x)' + assert octave_code(LambertW(x, n)) == 'lambertw(n, x)' + + # Automatic rewrite + assert octave_code(Ei(x)) == '(logint(exp(x)))' + assert octave_code(dirichlet_eta(x)) == '(((x == 1).*(log(2)) + (~(x == 1)).*((1 - 2.^(1 - x)).*zeta(x))))' + assert octave_code(riemann_xi(x)) == '(pi.^(-x/2).*x.*(x - 1).*gamma(x/2).*zeta(x)/2)' + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert mcode(A[0, 0]) == "A(1, 1)" + assert mcode(3 * A[0, 0]) == "3*A(1, 1)" + + F = C[0, 0].subs(C, A - B) + assert mcode(F) == "(A - B)(1, 1)" + + +def test_zeta_printing_issue_14820(): + assert octave_code(zeta(x)) == 'zeta(x)' + with raises(NotImplementedError): + octave_code(zeta(x, y)) + + +def test_automatic_rewrite(): + assert octave_code(Li(x)) == '(logint(x) - logint(2))' + assert octave_code(erf2(x, y)) == '(-erf(x) + erf(y))' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_precedence.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_precedence.py new file mode 100644 index 0000000000000000000000000000000000000000..d08ea07483857e8c2ee7f930aa53d2dacdc58193 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_precedence.py @@ -0,0 +1,128 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import Derivative, Function +from sympy.core.numbers import Integer, Rational, Float, oo +from sympy.core.relational import Rel +from sympy.core.symbol import symbols +from sympy.functions import sin +from sympy.integrals.integrals import Integral +from sympy.series.order import Order + +from sympy.printing.precedence import precedence, PRECEDENCE + +x, y = symbols("x,y") + + +def test_Add(): + assert precedence(x + y) == PRECEDENCE["Add"] + assert precedence(x*y + 1) == PRECEDENCE["Add"] + + +def test_Function(): + assert precedence(sin(x)) == PRECEDENCE["Func"] + +def test_Derivative(): + assert precedence(Derivative(x, y)) == PRECEDENCE["Atom"] + +def test_Integral(): + assert precedence(Integral(x, y)) == PRECEDENCE["Atom"] + + +def test_Mul(): + assert precedence(x*y) == PRECEDENCE["Mul"] + assert precedence(-x*y) == PRECEDENCE["Add"] + + +def test_Number(): + assert precedence(Integer(0)) == PRECEDENCE["Atom"] + assert precedence(Integer(1)) == PRECEDENCE["Atom"] + assert precedence(Integer(-1)) == PRECEDENCE["Add"] + assert precedence(Integer(10)) == PRECEDENCE["Atom"] + assert precedence(Rational(5, 2)) == PRECEDENCE["Mul"] + assert precedence(Rational(-5, 2)) == PRECEDENCE["Add"] + assert precedence(Float(5)) == PRECEDENCE["Atom"] + assert precedence(Float(-5)) == PRECEDENCE["Add"] + assert precedence(oo) == PRECEDENCE["Atom"] + assert precedence(-oo) == PRECEDENCE["Add"] + + +def test_Order(): + assert precedence(Order(x)) == PRECEDENCE["Atom"] + + +def test_Pow(): + assert precedence(x**y) == PRECEDENCE["Pow"] + assert precedence(-x**y) == PRECEDENCE["Add"] + assert precedence(x**-y) == PRECEDENCE["Pow"] + + +def test_Product(): + assert precedence(Product(x, (x, y, y + 1))) == PRECEDENCE["Atom"] + + +def test_Relational(): + assert precedence(Rel(x + y, y, "<")) == PRECEDENCE["Relational"] + + +def test_Sum(): + assert precedence(Sum(x, (x, y, y + 1))) == PRECEDENCE["Atom"] + + +def test_Symbol(): + assert precedence(x) == PRECEDENCE["Atom"] + + +def test_And_Or(): + # precedence relations between logical operators, ... + assert precedence(x & y) > precedence(x | y) + assert precedence(~y) > precedence(x & y) + # ... and with other operators (cfr. other programming languages) + assert precedence(x + y) > precedence(x | y) + assert precedence(x + y) > precedence(x & y) + assert precedence(x*y) > precedence(x | y) + assert precedence(x*y) > precedence(x & y) + assert precedence(~y) > precedence(x*y) + assert precedence(~y) > precedence(x - y) + # double checks + assert precedence(x & y) == PRECEDENCE["And"] + assert precedence(x | y) == PRECEDENCE["Or"] + assert precedence(~y) == PRECEDENCE["Not"] + + +def test_custom_function_precedence_comparison(): + """ + Test cases for custom functions with different precedence values, + specifically handling: + 1. Functions with precedence < PRECEDENCE["Mul"] (50) + 2. Functions with precedence = Func (70) + + Key distinction: + 1. Lower precedence functions (45) need parentheses: -2*(x F y) + 2. Higher precedence functions (70) don't: -2*x F y + """ + class LowPrecedenceF(Function): + precedence = PRECEDENCE["Mul"] - 5 + def _sympystr(self, printer): + return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}" + + class HighPrecedenceF(Function): + precedence = PRECEDENCE["Func"] + def _sympystr(self, printer): + return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}" + + def test_low_precedence(): + expr1 = 2 * LowPrecedenceF(x, y) + assert str(expr1) == "2*(x F y)" + + expr2 = -2 * LowPrecedenceF(x, y) + assert str(expr2) == "-2*(x F y)" + + def test_high_precedence(): + expr1 = 2 * HighPrecedenceF(x, y) + assert str(expr1) == "2*x F y" + + expr2 = -2 * HighPrecedenceF(x, y) + assert str(expr2) == "-2*x F y" + + test_low_precedence() + test_high_precedence() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_preview.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_preview.py new file mode 100644 index 0000000000000000000000000000000000000000..91771ceb0466d6b0fee00570426713d02da14872 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_preview.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol +from sympy.functions.elementary.piecewise import Piecewise +from sympy.printing.preview import preview + +from io import BytesIO + + +def test_preview(): + x = Symbol('x') + obj = BytesIO() + try: + preview(x, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server + + +def test_preview_unicode_symbol(): + # issue 9107 + a = Symbol('α') + obj = BytesIO() + try: + preview(a, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server + + +def test_preview_latex_construct_in_expr(): + # see PR 9801 + x = Symbol('x') + pw = Piecewise((1, Eq(x, 0)), (0, True)) + obj = BytesIO() + try: + preview(pw, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_pycode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_pycode.py new file mode 100644 index 0000000000000000000000000000000000000000..2c38fe81d830149cdce6b55f15e6e07513fdd146 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_pycode.py @@ -0,0 +1,493 @@ +from sympy import Not +from sympy.codegen import Assignment +from sympy.codegen.ast import none +from sympy.codegen.cfunctions import expm1, log1p +from sympy.codegen.scipy_nodes import cosm1 +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow +from sympy.core.function import Derivative +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth, sec, log, sin, cos, tan, asin, atan, sinh, cosh, tanh, asinh, acosh, atanh +from sympy.functions.elementary.trigonometric import atan2 +from sympy.logic import And, Or +from sympy.matrices import SparseMatrix, MatrixSymbol, Identity +from sympy.printing.codeprinter import PrintMethodNotImplementedError +from sympy.printing.pycode import ( + MpmathPrinter, CmathPrinter, PythonCodePrinter, pycode, SymPyPrinter +) +from sympy.printing.tensorflow import TensorflowPrinter +from sympy.printing.numpy import NumPyPrinter, SciPyPrinter +from sympy.testing.pytest import raises, skip +from sympy.tensor import IndexedBase, Idx +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray +from sympy.external import import_module +from sympy.functions.special.gamma_functions import loggamma + + + +x, y, z = symbols('x y z') +p = IndexedBase("p") + + +def test_PythonCodePrinter(): + prntr = PythonCodePrinter() + + assert not prntr.module_imports + + assert prntr.doprint(x**y) == 'x**y' + assert prntr.doprint(Mod(x, 2)) == 'x % 2' + assert prntr.doprint(-Mod(x, y)) == '-(x % y)' + assert prntr.doprint(Mod(-x, y)) == '(-x) % y' + assert prntr.doprint(And(x, y)) == 'x and y' + assert prntr.doprint(Or(x, y)) == 'x or y' + assert prntr.doprint(1/(x+y)) == '1/(x + y)' + assert prntr.doprint(Not(x)) == 'not x' + assert not prntr.module_imports + + assert prntr.doprint(pi) == 'math.pi' + assert prntr.module_imports == {'math': {'pi'}} + + assert prntr.doprint(x**Rational(1, 2)) == 'math.sqrt(x)' + assert prntr.doprint(sqrt(x)) == 'math.sqrt(x)' + assert prntr.module_imports == {'math': {'pi', 'sqrt'}} + + assert prntr.doprint(acos(x)) == 'math.acos(x)' + assert prntr.doprint(cot(x)) == '(1/math.tan(x))' + assert prntr.doprint(coth(x)) == '((math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x)))' + assert prntr.doprint(asec(x)) == '(math.acos(1/x))' + assert prntr.doprint(acsch(x)) == '(math.log(math.sqrt(1 + x**(-2)) + 1/x))' + + assert prntr.doprint(Assignment(x, 2)) == 'x = 2' + assert prntr.doprint(Piecewise((1, Eq(x, 0)), + (2, x>6))) == '((1) if (x == 0) else (2) if (x > 6) else None)' + assert prntr.doprint(Piecewise((2, Le(x, 0)), + (3, Gt(x, 0)), evaluate=False)) == '((2) if (x <= 0) else'\ + ' (3) if (x > 0) else None)' + assert prntr.doprint(sign(x)) == '(0.0 if x == 0 else math.copysign(1, x))' + assert prntr.doprint(p[0, 1]) == 'p[0, 1]' + assert prntr.doprint(KroneckerDelta(x,y)) == '(1 if x == y else 0)' + + assert prntr.doprint((2,3)) == "(2, 3)" + assert prntr.doprint([2,3]) == "[2, 3]" + + assert prntr.doprint(Min(x, y)) == "min(x, y)" + assert prntr.doprint(Max(x, y)) == "max(x, y)" + + +def test_PythonCodePrinter_standard(): + prntr = PythonCodePrinter() + + assert prntr.standard == 'python3' + + raises(ValueError, lambda: PythonCodePrinter({'standard':'python4'})) + + +def test_CmathPrinter(): + p = CmathPrinter() + + assert p.doprint(sqrt(x)) == 'cmath.sqrt(x)' + assert p.doprint(log(x)) == 'cmath.log(x)' + + assert p.doprint(sin(x)) == 'cmath.sin(x)' + assert p.doprint(cos(x)) == 'cmath.cos(x)' + assert p.doprint(tan(x)) == 'cmath.tan(x)' + + assert p.doprint(asin(x)) == 'cmath.asin(x)' + assert p.doprint(acos(x)) == 'cmath.acos(x)' + assert p.doprint(atan(x)) == 'cmath.atan(x)' + + assert p.doprint(sinh(x)) == 'cmath.sinh(x)' + assert p.doprint(cosh(x)) == 'cmath.cosh(x)' + assert p.doprint(tanh(x)) == 'cmath.tanh(x)' + + assert p.doprint(asinh(x)) == 'cmath.asinh(x)' + assert p.doprint(acosh(x)) == 'cmath.acosh(x)' + assert p.doprint(atanh(x)) == 'cmath.atanh(x)' + + +def test_MpmathPrinter(): + p = MpmathPrinter() + assert p.doprint(sign(x)) == 'mpmath.sign(x)' + assert p.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)' + + assert p.doprint(S.Exp1) == 'mpmath.e' + assert p.doprint(S.Pi) == 'mpmath.pi' + assert p.doprint(S.GoldenRatio) == 'mpmath.phi' + assert p.doprint(S.EulerGamma) == 'mpmath.euler' + assert p.doprint(S.NaN) == 'mpmath.nan' + assert p.doprint(S.Infinity) == 'mpmath.inf' + assert p.doprint(S.NegativeInfinity) == 'mpmath.ninf' + assert p.doprint(loggamma(x)) == 'mpmath.loggamma(x)' + + +def test_NumPyPrinter(): + from sympy.core.function import Lambda + from sympy.matrices.expressions.adjoint import Adjoint + from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf) + from sympy.matrices.expressions.funcmatrix import FunctionMatrix + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.kronecker import KroneckerProduct + from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix) + from sympy.abc import a, b + p = NumPyPrinter() + assert p.doprint(sign(x)) == 'numpy.sign(x)' + A = MatrixSymbol("A", 2, 2) + B = MatrixSymbol("B", 2, 2) + C = MatrixSymbol("C", 1, 5) + D = MatrixSymbol("D", 3, 4) + assert p.doprint(A**(-1)) == "numpy.linalg.inv(A)" + assert p.doprint(A**5) == "numpy.linalg.matrix_power(A, 5)" + assert p.doprint(Identity(3)) == "numpy.eye(3)" + + u = MatrixSymbol('x', 2, 1) + v = MatrixSymbol('y', 2, 1) + assert p.doprint(MatrixSolve(A, u)) == 'numpy.linalg.solve(A, x)' + assert p.doprint(MatrixSolve(A, u) + v) == 'numpy.linalg.solve(A, x) + y' + + assert p.doprint(ZeroMatrix(2, 3)) == "numpy.zeros((2, 3))" + assert p.doprint(OneMatrix(2, 3)) == "numpy.ones((2, 3))" + assert p.doprint(FunctionMatrix(4, 5, Lambda((a, b), a + b))) == \ + "numpy.fromfunction(lambda a, b: a + b, (4, 5))" + assert p.doprint(HadamardProduct(A, B)) == "numpy.multiply(A, B)" + assert p.doprint(KroneckerProduct(A, B)) == "numpy.kron(A, B)" + assert p.doprint(Adjoint(A)) == "numpy.conjugate(numpy.transpose(A))" + assert p.doprint(DiagonalOf(A)) == "numpy.reshape(numpy.diag(A), (-1, 1))" + assert p.doprint(DiagMatrix(C)) == "numpy.diagflat(C)" + assert p.doprint(DiagonalMatrix(D)) == "numpy.multiply(D, numpy.eye(3, 4))" + + # Workaround for numpy negative integer power errors + assert p.doprint(x**-1) == 'x**(-1.0)' + assert p.doprint(x**-2) == 'x**(-2.0)' + + expr = Pow(2, -1, evaluate=False) + assert p.doprint(expr) == "2**(-1.0)" + + assert p.doprint(S.Exp1) == 'numpy.e' + assert p.doprint(S.Pi) == 'numpy.pi' + assert p.doprint(S.EulerGamma) == 'numpy.euler_gamma' + assert p.doprint(S.NaN) == 'numpy.nan' + assert p.doprint(S.Infinity) == 'numpy.inf' + assert p.doprint(S.NegativeInfinity) == '-numpy.inf' + + # Function rewriting operator precedence fix + assert p.doprint(sec(x)**2) == '(numpy.cos(x)**(-1.0))**2' + + +def test_issue_18770(): + numpy = import_module('numpy') + if not numpy: + skip("numpy not installed.") + + from sympy.functions.elementary.miscellaneous import (Max, Min) + from sympy.utilities.lambdify import lambdify + + expr1 = Min(0.1*x + 3, x + 1, 0.5*x + 1) + func = lambdify(x, expr1, "numpy") + assert (func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all() + assert func(4) == 3 + + expr1 = Max(x**2, x**3) + func = lambdify(x,expr1, "numpy") + assert (func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all() + assert func(4) == 64 + + +def test_SciPyPrinter(): + p = SciPyPrinter() + expr = acos(x) + assert 'numpy' not in p.module_imports + assert p.doprint(expr) == 'numpy.arccos(x)' + assert 'numpy' in p.module_imports + assert not any(m.startswith('scipy') for m in p.module_imports) + smat = SparseMatrix(2, 5, {(0, 1): 3}) + assert p.doprint(smat) == \ + 'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))' + assert 'scipy.sparse' in p.module_imports + + assert p.doprint(S.GoldenRatio) == 'scipy.constants.golden_ratio' + assert p.doprint(S.Pi) == 'scipy.constants.pi' + assert p.doprint(S.Exp1) == 'numpy.e' + + +def test_pycode_reserved_words(): + s1, s2 = symbols('if else') + raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True)) + py_str = pycode(s1 + s2) + assert py_str in ('else_ + if_', 'if_ + else_') + + +def test_issue_20762(): + # Make sure pycode removes curly braces from subscripted variables + a_b, b, a_11 = symbols('a_{b} b a_{11}') + expr = a_b*b + assert pycode(expr) == 'a_b*b' + expr = a_11*b + assert pycode(expr) == 'a_11*b' + + +def test_sqrt(): + prntr = PythonCodePrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'math.sqrt(x)' + assert prntr._print_Pow(1/sqrt(x), rational=False) == '1/math.sqrt(x)' + + prntr = PythonCodePrinter({'standard' : 'python3'}) + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + assert prntr._print_Pow(1/sqrt(x), rational=True) == 'x**(-1/2)' + + prntr = MpmathPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'mpmath.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == \ + "x**(mpmath.mpf(1)/mpmath.mpf(2))" + + prntr = NumPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + prntr = SciPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + prntr = SymPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'sympy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + +def test_frac(): + from sympy.functions.elementary.integers import frac + + expr = frac(x) + prntr = NumPyPrinter() + assert prntr.doprint(expr) == 'numpy.mod(x, 1)' + + prntr = SciPyPrinter() + assert prntr.doprint(expr) == 'numpy.mod(x, 1)' + + prntr = PythonCodePrinter() + assert prntr.doprint(expr) == 'x % 1' + + prntr = MpmathPrinter() + assert prntr.doprint(expr) == 'mpmath.frac(x)' + + prntr = SymPyPrinter() + assert prntr.doprint(expr) == 'sympy.functions.elementary.integers.frac(x)' + + +class CustomPrintedObject(Expr): + def _numpycode(self, printer): + return 'numpy' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + obj = CustomPrintedObject() + assert NumPyPrinter().doprint(obj) == 'numpy' + assert MpmathPrinter().doprint(obj) == 'mpmath' + + +def test_codegen_ast_nodes(): + assert pycode(none) == 'None' + + +def test_issue_14283(): + prntr = PythonCodePrinter() + + assert prntr.doprint(zoo) == "math.nan" + assert prntr.doprint(-oo) == "float('-inf')" + + +def test_NumPyPrinter_print_seq(): + n = NumPyPrinter() + + assert n._print_seq(range(2)) == '(0, 1,)' + + +def test_issue_16535_16536(): + from sympy.functions.special.gamma_functions import (lowergamma, uppergamma) + + a = symbols('a') + expr1 = lowergamma(a, x) + expr2 = uppergamma(a, x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)' + assert prntr.doprint(expr2) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)' + + p_numpy = NumPyPrinter() + p_pycode = PythonCodePrinter({'strict': False}) + + for expr in [expr1, expr2]: + with raises(NotImplementedError): + p_numpy.doprint(expr1) + assert "Not supported" in p_pycode.doprint(expr) + + +def test_Integral(): + from sympy.functions.elementary.exponential import exp + from sympy.integrals.integrals import Integral + + single = Integral(exp(-x), (x, 0, oo)) + double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z)) + indefinite = Integral(x**2, x) + evaluateat = Integral(x**2, (x, 1)) + + prntr = SciPyPrinter() + assert prntr.doprint(single) == 'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.inf)[0]' + assert prntr.doprint(double) == 'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]' + raises(NotImplementedError, lambda: prntr.doprint(indefinite)) + raises(NotImplementedError, lambda: prntr.doprint(evaluateat)) + + prntr = MpmathPrinter() + assert prntr.doprint(single) == 'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))' + assert prntr.doprint(double) == 'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))' + raises(NotImplementedError, lambda: prntr.doprint(indefinite)) + raises(NotImplementedError, lambda: prntr.doprint(evaluateat)) + + +def test_fresnel_integrals(): + from sympy.functions.special.error_functions import (fresnelc, fresnels) + + expr1 = fresnelc(x) + expr2 = fresnels(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.fresnel(x)[1]' + assert prntr.doprint(expr2) == 'scipy.special.fresnel(x)[0]' + + p_numpy = NumPyPrinter() + p_pycode = PythonCodePrinter() + p_mpmath = MpmathPrinter() + for expr in [expr1, expr2]: + with raises(NotImplementedError): + p_numpy.doprint(expr) + with raises(NotImplementedError): + p_pycode.doprint(expr) + + assert p_mpmath.doprint(expr1) == 'mpmath.fresnelc(x)' + assert p_mpmath.doprint(expr2) == 'mpmath.fresnels(x)' + + +def test_beta(): + from sympy.functions.special.beta_functions import beta + + expr = beta(x, y) + + prntr = SciPyPrinter() + assert prntr.doprint(expr) == 'scipy.special.beta(x, y)' + + prntr = NumPyPrinter() + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = PythonCodePrinter() + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = PythonCodePrinter({'allow_unknown_functions': True}) + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = MpmathPrinter() + assert prntr.doprint(expr) == 'mpmath.beta(x, y)' + +def test_airy(): + from sympy.functions.special.bessel import (airyai, airybi) + + expr1 = airyai(x) + expr2 = airybi(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.airy(x)[0]' + assert prntr.doprint(expr2) == 'scipy.special.airy(x)[2]' + + prntr = NumPyPrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + prntr = PythonCodePrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + +def test_airy_prime(): + from sympy.functions.special.bessel import (airyaiprime, airybiprime) + + expr1 = airyaiprime(x) + expr2 = airybiprime(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.airy(x)[1]' + assert prntr.doprint(expr2) == 'scipy.special.airy(x)[3]' + + prntr = NumPyPrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + prntr = PythonCodePrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + +def test_numerical_accuracy_functions(): + prntr = SciPyPrinter() + assert prntr.doprint(expm1(x)) == 'numpy.expm1(x)' + assert prntr.doprint(log1p(x)) == 'numpy.log1p(x)' + assert prntr.doprint(cosm1(x)) == 'scipy.special.cosm1(x)' + +def test_array_printer(): + A = ArraySymbol('A', (4,4,6,6,6)) + I = IndexedBase('I') + i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5)) + + prntr = NumPyPrinter() + assert prntr.doprint(ZeroArray(5)) == 'numpy.zeros((5,))' + assert prntr.doprint(OneArray(5)) == 'numpy.ones((5,))' + assert prntr.doprint(ArrayContraction(A, [2,3])) == 'numpy.einsum("abccd->abd", A)' + assert prntr.doprint(I) == 'I' + assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'numpy.einsum("abccc->abc", A)' + assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'numpy.einsum("aabbc->cab", A)' + assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'numpy.einsum("abcde->abe", A)' + assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I' + + prntr = TensorflowPrinter() + assert prntr.doprint(ZeroArray(5)) == 'tensorflow.zeros((5,))' + assert prntr.doprint(OneArray(5)) == 'tensorflow.ones((5,))' + assert prntr.doprint(ArrayContraction(A, [2,3])) == 'tensorflow.linalg.einsum("abccd->abd", A)' + assert prntr.doprint(I) == 'I' + assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'tensorflow.linalg.einsum("abccc->abc", A)' + assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'tensorflow.linalg.einsum("aabbc->cab", A)' + assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'tensorflow.linalg.einsum("abcde->abe", A)' + assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I' + + +def test_custom_Derivative_methods(): + class MyPrinter(SciPyPrinter): + def _print_Derivative_cosm1(self, args, seq_orders): + arg, = args + order, = seq_orders + return 'my_custom_cosm1(%s, deriv_order=%d)' % (self._print(arg), order) + + def _print_Derivative_atan2(self, args, seq_orders): + arg1, arg2 = args + ord1, ord2 = seq_orders + return 'my_custom_atan2(%s, %s, deriv1=%d, deriv2=%d)' % ( + self._print(arg1), self._print(arg2), ord1, ord2 + ) + + p = MyPrinter() + cosm1_1 = cosm1(x).diff(x, evaluate=False) + assert p.doprint(cosm1_1) == 'my_custom_cosm1(x, deriv_order=1)' + atan2_2_3 = atan2(x, y).diff(x, 2, y, 3, evaluate=False) + assert p.doprint(atan2_2_3) == 'my_custom_atan2(x, y, deriv1=2, deriv2=3)' + + try: + p.doprint(expm1(x).diff(x, evaluate=False)) + except PrintMethodNotImplementedError as e: + assert '_print_Derivative_expm1' in repr(e) + else: + assert False # should have thrown + + try: + p.doprint(Derivative(cosm1(x**2),x)) + except ValueError as e: + assert '_print_Derivative(' in repr(e) + else: + assert False # should have thrown diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_python.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_python.py new file mode 100644 index 0000000000000000000000000000000000000000..fb94a662be90934a672d08b3de44a22e2580d8b6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_python.py @@ -0,0 +1,203 @@ +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (Abs, conjugate) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from sympy.series.limits import limit + +from sympy.printing.python import python + +from sympy.testing.pytest import raises, XFAIL + +x, y = symbols('x,y') +th = Symbol('theta') +ph = Symbol('phi') + + +def test_python_basic(): + # Simple numbers/symbols + assert python(-Rational(1)/2) == "e = Rational(-1, 2)" + assert python(-Rational(13)/22) == "e = Rational(-13, 22)" + assert python(oo) == "e = oo" + + # Powers + assert python(x**2) == "x = Symbol(\'x\')\ne = x**2" + assert python(1/x) == "x = Symbol('x')\ne = 1/x" + assert python(y*x**-2) == "y = Symbol('y')\nx = Symbol('x')\ne = y/x**2" + assert python( + x**Rational(-5, 2)) == "x = Symbol('x')\ne = x**Rational(-5, 2)" + + # Sums of terms + assert python(x**2 + x + 1) in [ + "x = Symbol('x')\ne = 1 + x + x**2", + "x = Symbol('x')\ne = x + x**2 + 1", + "x = Symbol('x')\ne = x**2 + x + 1", ] + assert python(1 - x) in [ + "x = Symbol('x')\ne = 1 - x", + "x = Symbol('x')\ne = -x + 1"] + assert python(1 - 2*x) in [ + "x = Symbol('x')\ne = 1 - 2*x", + "x = Symbol('x')\ne = -2*x + 1"] + assert python(1 - Rational(3, 2)*y/x) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3/2*y/x", + "y = Symbol('y')\nx = Symbol('x')\ne = -3/2*y/x + 1", + "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3*y/(2*x)"] + + # Multiplication + assert python(x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = x/y" + assert python(-x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = -x/y" + assert python((x + 2)/y) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(2 + x)", + "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(x + 2)", + "x = Symbol('x')\ny = Symbol('y')\ne = 1/y*(2 + x)", + "x = Symbol('x')\ny = Symbol('y')\ne = (2 + x)/y", + "x = Symbol('x')\ny = Symbol('y')\ne = (x + 2)/y"] + assert python((1 + x)*y) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = y*(1 + x)", + "y = Symbol('y')\nx = Symbol('x')\ne = y*(x + 1)", ] + + # Check for proper placement of negative sign + assert python(-5*x/(x + 10)) == "x = Symbol('x')\ne = -5*x/(x + 10)" + assert python(1 - Rational(3, 2)*(x + 1)) in [ + "x = Symbol('x')\ne = Rational(-3, 2)*x + Rational(-1, 2)", + "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)", + "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)" + ] + + +def test_python_keyword_symbol_name_escaping(): + # Check for escaping of keywords + assert python( + 5*Symbol("lambda")) == "lambda_ = Symbol('lambda')\ne = 5*lambda_" + assert (python(5*Symbol("lambda") + 7*Symbol("lambda_")) == + "lambda__ = Symbol('lambda')\nlambda_ = Symbol('lambda_')\ne = 7*lambda_ + 5*lambda__") + assert (python(5*Symbol("for") + Function("for_")(8)) == + "for__ = Symbol('for')\nfor_ = Function('for_')\ne = 5*for__ + for_(8)") + + +def test_python_keyword_function_name_escaping(): + assert python( + 5*Function("for")(8)) == "for_ = Function('for')\ne = 5*for_(8)" + + +def test_python_relational(): + assert python(Eq(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = Eq(x, y)" + assert python(Ge(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x >= y" + assert python(Le(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x <= y" + assert python(Gt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x > y" + assert python(Lt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x < y" + assert python(Ne(x/(y + 1), y**2)) in [ + "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(1 + y), y**2)", + "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(y + 1), y**2)"] + + +def test_python_functions(): + # Simple + assert python(2*x + exp(x)) in "x = Symbol('x')\ne = 2*x + exp(x)" + assert python(sqrt(2)) == 'e = sqrt(2)' + assert python(2**Rational(1, 3)) == 'e = 2**Rational(1, 3)' + assert python(sqrt(2 + pi)) == 'e = sqrt(2 + pi)' + assert python((2 + pi)**Rational(1, 3)) == 'e = (2 + pi)**Rational(1, 3)' + assert python(2**Rational(1, 4)) == 'e = 2**Rational(1, 4)' + assert python(Abs(x)) == "x = Symbol('x')\ne = Abs(x)" + assert python( + Abs(x/(x**2 + 1))) in ["x = Symbol('x')\ne = Abs(x/(1 + x**2))", + "x = Symbol('x')\ne = Abs(x/(x**2 + 1))"] + + # Univariate/Multivariate functions + f = Function('f') + assert python(f(x)) == "x = Symbol('x')\nf = Function('f')\ne = f(x)" + assert python(f(x, y)) == "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x, y)" + assert python(f(x/(y + 1), y)) in [ + "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(1 + y), y)", + "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(y + 1), y)"] + + # Nesting of square roots + assert python(sqrt((sqrt(x + 1)) + 1)) in [ + "x = Symbol('x')\ne = sqrt(1 + sqrt(1 + x))", + "x = Symbol('x')\ne = sqrt(sqrt(x + 1) + 1)"] + + # Nesting of powers + assert python((((x + 1)**Rational(1, 3)) + 1)**Rational(1, 3)) in [ + "x = Symbol('x')\ne = (1 + (1 + x)**Rational(1, 3))**Rational(1, 3)", + "x = Symbol('x')\ne = ((x + 1)**Rational(1, 3) + 1)**Rational(1, 3)"] + + # Function powers + assert python(sin(x)**2) == "x = Symbol('x')\ne = sin(x)**2" + + +@XFAIL +def test_python_functions_conjugates(): + a, b = map(Symbol, 'ab') + assert python( conjugate(a + b*I) ) == '_ _\na - I*b' + assert python( conjugate(exp(a + b*I)) ) == ' _ _\n a - I*b\ne ' + + +def test_python_derivatives(): + # Simple + f_1 = Derivative(log(x), x, evaluate=False) + assert python(f_1) == "x = Symbol('x')\ne = Derivative(log(x), x)" + + f_2 = Derivative(log(x), x, evaluate=False) + x + assert python(f_2) == "x = Symbol('x')\ne = x + Derivative(log(x), x)" + + # Multiple symbols + f_3 = Derivative(log(x) + x**2, x, y, evaluate=False) + assert python(f_3) == \ + "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(x**2 + log(x), x, y)" + + f_4 = Derivative(2*x*y, y, x, evaluate=False) + x**2 + assert python(f_4) in [ + "x = Symbol('x')\ny = Symbol('y')\ne = x**2 + Derivative(2*x*y, y, x)", + "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(2*x*y, y, x) + x**2"] + + +def test_python_integrals(): + # Simple + f_1 = Integral(log(x), x) + assert python(f_1) == "x = Symbol('x')\ne = Integral(log(x), x)" + + f_2 = Integral(x**2, x) + assert python(f_2) == "x = Symbol('x')\ne = Integral(x**2, x)" + + # Double nesting of pow + f_3 = Integral(x**(2**x), x) + assert python(f_3) == "x = Symbol('x')\ne = Integral(x**(2**x), x)" + + # Definite integrals + f_4 = Integral(x**2, (x, 1, 2)) + assert python(f_4) == "x = Symbol('x')\ne = Integral(x**2, (x, 1, 2))" + + f_5 = Integral(x**2, (x, Rational(1, 2), 10)) + assert python( + f_5) == "x = Symbol('x')\ne = Integral(x**2, (x, Rational(1, 2), 10))" + + # Nested integrals + f_6 = Integral(x**2*y**2, x, y) + assert python(f_6) == "x = Symbol('x')\ny = Symbol('y')\ne = Integral(x**2*y**2, x, y)" + + +def test_python_matrix(): + p = python(Matrix([[x**2+1, 1], [y, x+y]])) + s = "x = Symbol('x')\ny = Symbol('y')\ne = MutableDenseMatrix([[x**2 + 1, 1], [y, x + y]])" + assert p == s + +def test_python_limits(): + assert python(limit(x, x, oo)) == 'e = oo' + assert python(limit(x**2, x, 0)) == 'e = 0' + +def test_issue_20762(): + # Make sure Python removes curly braces from subscripted variables + a_b = Symbol('a_{b}') + b = Symbol('b') + expr = a_b*b + assert python(expr) == "a_b = Symbol('a_{b}')\nb = Symbol('b')\ne = a_b*b" + + +def test_settings(): + raises(TypeError, lambda: python(x, method="garbage")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_rcode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_rcode.py new file mode 100644 index 0000000000000000000000000000000000000000..a83235b0654c6bf24c30846dbf68678d29cd3c80 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_rcode.py @@ -0,0 +1,476 @@ +from sympy.core import (S, pi, oo, Symbol, symbols, Rational, Integer, + GoldenRatio, EulerGamma, Catalan, Lambda, Dummy) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + gamma, sign, Max, Min, factorial, beta) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.sets import Range +from sympy.logic import ITE +from sympy.codegen import For, aug_assign, Assignment +from sympy.testing.pytest import raises +from sympy.printing.rcode import RCodePrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol + +from sympy.printing.rcode import rcode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + class fabs(Abs): + def _rcode(self, printer): + return "abs(%s)" % printer._print(self.args[0]) + + assert rcode(fabs(x)) == "abs(x)" + + +def test_rcode_sqrt(): + assert rcode(sqrt(x)) == "sqrt(x)" + assert rcode(x**0.5) == "sqrt(x)" + assert rcode(sqrt(x)) == "sqrt(x)" + + +def test_rcode_Pow(): + assert rcode(x**3) == "x^3" + assert rcode(x**(y**3)) == "x^(y^3)" + g = implemented_function('g', Lambda(x, 2*x)) + assert rcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2*x)^(-x + y^x)/(x^2 + y)" + assert rcode(x**-1.0) == '1.0/x' + assert rcode(x**Rational(2, 3)) == 'x^(2.0/3.0)' + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"), + (lambda base, exp: not exp.is_integer, "pow")] + assert rcode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)' + assert rcode(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 3.2)' + + +def test_rcode_Max(): + # Test for gh-11926 + assert rcode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))' + + +def test_rcode_constants_mathh(): + assert rcode(exp(1)) == "exp(1)" + assert rcode(pi) == "pi" + assert rcode(oo) == "Inf" + assert rcode(-oo) == "-Inf" + + +def test_rcode_constants_other(): + assert rcode(2*GoldenRatio) == "GoldenRatio = 1.61803398874989;\n2*GoldenRatio" + assert rcode( + 2*Catalan) == "Catalan = 0.915965594177219;\n2*Catalan" + assert rcode(2*EulerGamma) == "EulerGamma = 0.577215664901533;\n2*EulerGamma" + + +def test_rcode_Rational(): + assert rcode(Rational(3, 7)) == "3.0/7.0" + assert rcode(Rational(18, 9)) == "2" + assert rcode(Rational(3, -7)) == "-3.0/7.0" + assert rcode(Rational(-3, -7)) == "3.0/7.0" + assert rcode(x + Rational(3, 7)) == "x + 3.0/7.0" + assert rcode(Rational(3, 7)*x) == "(3.0/7.0)*x" + + +def test_rcode_Integer(): + assert rcode(Integer(67)) == "67" + assert rcode(Integer(-1)) == "-1" + + +def test_rcode_functions(): + assert rcode(sin(x) ** cos(x)) == "sin(x)^cos(x)" + assert rcode(factorial(x) + gamma(y)) == "factorial(x) + gamma(y)" + assert rcode(beta(Min(x, y), Max(x, y))) == "beta(min(x, y), max(x, y))" + + +def test_rcode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert rcode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert rcode( + g(x)) == "Catalan = %s;\n2*x/Catalan" % Catalan.n() + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + res=rcode(g(A[i]), assign_to=A[i]) + ref=( + "for (i in 1:n){\n" + " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n" + "}" + ) + assert res == ref + + +def test_rcode_exceptions(): + assert rcode(ceiling(x)) == "ceiling(x)" + assert rcode(Abs(x)) == "abs(x)" + assert rcode(gamma(x)) == "gamma(x)" + + +def test_rcode_user_functions(): + x = symbols('x', integer=False) + n = symbols('n', integer=True) + custom_functions = { + "ceiling": "myceil", + "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")], + } + assert rcode(ceiling(x), user_functions=custom_functions) == "myceil(x)" + assert rcode(Abs(x), user_functions=custom_functions) == "fabs(x)" + assert rcode(Abs(n), user_functions=custom_functions) == "abs(n)" + + +def test_rcode_boolean(): + assert rcode(True) == "True" + assert rcode(S.true) == "True" + assert rcode(False) == "False" + assert rcode(S.false) == "False" + assert rcode(x & y) == "x & y" + assert rcode(x | y) == "x | y" + assert rcode(~x) == "!x" + assert rcode(x & y & z) == "x & y & z" + assert rcode(x | y | z) == "x | y | z" + assert rcode((x & y) | z) == "z | x & y" + assert rcode((x | y) & z) == "z & (x | y)" + +def test_rcode_Relational(): + assert rcode(Eq(x, y)) == "x == y" + assert rcode(Ne(x, y)) == "x != y" + assert rcode(Le(x, y)) == "x <= y" + assert rcode(Lt(x, y)) == "x < y" + assert rcode(Gt(x, y)) == "x > y" + assert rcode(Ge(x, y)) == "x >= y" + + +def test_rcode_Piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + res=rcode(expr) + ref="ifelse(x < 1,x,x^2)" + assert res == ref + tau=Symbol("tau") + res=rcode(expr,tau) + ref="tau = ifelse(x < 1,x,x^2);" + assert res == ref + + expr = 2*Piecewise((x, x < 1), (x**2, x<2), (x**3,True)) + assert rcode(expr) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3))" + res = rcode(expr, assign_to='c') + assert res == "c = 2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3));" + + # Check that Piecewise without a True (default) condition error + #expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + #raises(ValueError, lambda: rcode(expr)) + expr = 2*Piecewise((x, x < 1), (x**2, x<2)) + assert(rcode(expr))== "2*ifelse(x < 1,x,ifelse(x < 2,x^2,NA))" + + +def test_rcode_sinc(): + from sympy.functions.elementary.trigonometric import sinc + expr = sinc(x) + res = rcode(expr) + ref = "(ifelse(x != 0,sin(x)/x,1))" + assert res == ref + + +def test_rcode_Piecewise_deep(): + p = rcode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))) + assert p == "2*ifelse(x < 1,x,ifelse(x < 2,x + 1,x^2))" + expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1 + p = rcode(expr) + ref="x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1" + assert p == ref + + ref="c = x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1;" + p = rcode(expr, assign_to='c') + assert p == ref + + +def test_rcode_ITE(): + expr = ITE(x < 1, y, z) + p = rcode(expr) + ref="ifelse(x < 1,y,z)" + assert p == ref + + +def test_rcode_settings(): + raises(TypeError, lambda: rcode(sin(x), method="garbage")) + + +def test_rcode_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = RCodePrinter() + p._not_r = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[i, j]' + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[i, j, k]' + + assert p._not_r == set() + +def test_rcode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = rcode(e.rhs, assign_to=e.lhs, contract=False) + assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1) + + +def test_rcode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = A[i, j]*x[j] + y[i];\n' + ' }\n' + '}' + ) + c = rcode(A[i, j]*x[j], assign_to=y[i]) + assert c == s + + +def test_dummy_loops(): + # the following line could also be + # [Dummy(s, integer=True) for s in 'im'] + # or [Dummy(integer=True) for s in 'im'] + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + expected = ( + 'for (i_%(icount)i in 1:m_%(mcount)i){\n' + ' y[i_%(icount)i] = x[i_%(icount)i];\n' + '}' + ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index} + code = rcode(x[i], assign_to=y[i]) + assert code == expected + + +def test_rcode_loops_add(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = x[i] + z[i];\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = A[i, j]*x[j] + y[i];\n' + ' }\n' + '}' + ) + c = rcode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_multiple_contractions(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' for (l in 1:p){\n' + ' y[i] = a[i, j, k, l]*b[j, k, l] + y[i];\n' + ' }\n' + ' }\n' + ' }\n' + '}' + ) + c = rcode(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_addfactor(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' for (l in 1:p){\n' + ' y[i] = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l] + y[i];\n' + ' }\n' + ' }\n' + ' }\n' + '}' + ) + c = rcode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_multiple_terms(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + + s0 = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + ) + s1 = ( + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' y[i] = b[j]*b[k]*c[i, j, k] + y[i];\n' + ' }\n' + ' }\n' + '}\n' + ) + s2 = ( + 'for (i in 1:m){\n' + ' for (k in 1:o){\n' + ' y[i] = a[i, k]*b[k] + y[i];\n' + ' }\n' + '}\n' + ) + s3 = ( + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = a[i, j]*b[j] + y[i];\n' + ' }\n' + '}\n' + ) + c = rcode( + b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i]) + + ref={} + ref[0] = s0 + s1 + s2 + s3[:-1] + ref[1] = s0 + s1 + s3 + s2[:-1] + ref[2] = s0 + s2 + s1 + s3[:-1] + ref[3] = s0 + s2 + s3 + s1[:-1] + ref[4] = s0 + s3 + s1 + s2[:-1] + ref[5] = s0 + s3 + s2 + s1[:-1] + + assert (c == ref[0] or + c == ref[1] or + c == ref[2] or + c == ref[3] or + c == ref[4] or + c == ref[5]) + + +def test_dereference_printing(): + expr = x + y + sin(z) + z + assert rcode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))" + + +def test_Matrix_printing(): + # Test returning a Matrix + mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + p = rcode(mat, A) + assert p == ( + "A[0] = x*y;\n" + "A[1] = ifelse(y > 0,x + 2,y);\n" + "A[2] = sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + p = rcode(expr) + assert p == ("ifelse(x > 0,2*A[2],A[2]) + sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert rcode(m, M) == ( + "M[0] = sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_rcode_sgn(): + + expr = sign(x) * y + assert rcode(expr) == 'y*sign(x)' + p = rcode(expr, 'z') + assert p == 'z = y*sign(x);' + + p = rcode(sign(2 * x + x**2) * x + x**2) + assert p == "x^2 + x*sign(x^2 + 2*x)" + + expr = sign(cos(x)) + p = rcode(expr) + assert p == 'sign(cos(x))' + +def test_rcode_Assignment(): + assert rcode(Assignment(x, y + z)) == 'x = y + z;' + assert rcode(aug_assign(x, '+', y + z)) == 'x += y + z;' + + +def test_rcode_For(): + f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)]) + sol = rcode(f) + assert sol == ("for(x in seq(from=0, to=9, by=2){\n" + " y *= x;\n" + "}") + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(rcode(A[0, 0]) == "A[0]") + assert(rcode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(rcode(F) == "(A - B)[0]") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_repr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_repr.py new file mode 100644 index 0000000000000000000000000000000000000000..da58883b4fb027ed82db842a0a1ce5f76a49a8bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_repr.py @@ -0,0 +1,382 @@ +from __future__ import annotations +from typing import Any + +from sympy.external.gmpy import GROUND_TYPES +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.assumptions.ask import Q +from sympy.core.function import (Function, WildFunction) +from sympy.core.numbers import (AlgebraicNumber, Float, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import sin +from sympy.functions.special.delta_functions import Heaviside +from sympy.logic.boolalg import (false, true) +from sympy.matrices.dense import (Matrix, ones) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.combinatorics import Cycle, Permutation +from sympy.core.symbol import Str +from sympy.geometry import Point, Ellipse +from sympy.printing import srepr +from sympy.polys import ring, field, ZZ, QQ, lex, grlex, Poly +from sympy.polys.polyclasses import DMP +from sympy.polys.agca.extensions import FiniteExtension + +x, y = symbols('x,y') + +# eval(srepr(expr)) == expr has to succeed in the right environment. The right +# environment is the scope of "from sympy import *" for most cases. +ENV: dict[str, Any] = {"Str": Str} +exec("from sympy import *", ENV) + + +def sT(expr, string, import_stmt=None, **kwargs): + """ + sT := sreprTest + + Tests that srepr delivers the expected string and that + the condition eval(srepr(expr))==expr holds. + """ + if import_stmt is None: + ENV2 = ENV + else: + ENV2 = ENV.copy() + exec(import_stmt, ENV2) + + assert srepr(expr, **kwargs) == string + assert eval(string, ENV2) == expr + + +def test_printmethod(): + class R(Abs): + def _sympyrepr(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert srepr(R(x)) == "foo(Symbol('x'))" + + +def test_Add(): + sT(x + y, "Add(Symbol('x'), Symbol('y'))") + assert srepr(x**2 + 1, order='lex') == "Add(Pow(Symbol('x'), Integer(2)), Integer(1))" + assert srepr(x**2 + 1, order='old') == "Add(Integer(1), Pow(Symbol('x'), Integer(2)))" + assert srepr(sympify('x + 3 - 2', evaluate=False), order='none') == "Add(Symbol('x'), Integer(3), Mul(Integer(-1), Integer(2)))" + + +def test_more_than_255_args_issue_10259(): + from sympy.core.add import Add + from sympy.core.mul import Mul + for op in (Add, Mul): + expr = op(*symbols('x:256')) + assert eval(srepr(expr)) == expr + + +def test_Function(): + sT(Function("f")(x), "Function('f')(Symbol('x'))") + # test unapplied Function + sT(Function('f'), "Function('f')") + + sT(sin(x), "sin(Symbol('x'))") + sT(sin, "sin") + + +def test_Heaviside(): + sT(Heaviside(x), "Heaviside(Symbol('x'))") + sT(Heaviside(x, 1), "Heaviside(Symbol('x'), Integer(1))") + + +def test_Geometry(): + sT(Point(0, 0), "Point2D(Integer(0), Integer(0))") + sT(Ellipse(Point(0, 0), 5, 1), + "Ellipse(Point2D(Integer(0), Integer(0)), Integer(5), Integer(1))") + # TODO more tests + + +def test_Singletons(): + sT(S.Catalan, 'Catalan') + sT(S.ComplexInfinity, 'zoo') + sT(S.EulerGamma, 'EulerGamma') + sT(S.Exp1, 'E') + sT(S.GoldenRatio, 'GoldenRatio') + sT(S.TribonacciConstant, 'TribonacciConstant') + sT(S.Half, 'Rational(1, 2)') + sT(S.ImaginaryUnit, 'I') + sT(S.Infinity, 'oo') + sT(S.NaN, 'nan') + sT(S.NegativeInfinity, '-oo') + sT(S.NegativeOne, 'Integer(-1)') + sT(S.One, 'Integer(1)') + sT(S.Pi, 'pi') + sT(S.Zero, 'Integer(0)') + sT(S.Complexes, 'Complexes') + sT(S.EmptySequence, 'EmptySequence') + sT(S.EmptySet, 'EmptySet') + # sT(S.IdentityFunction, 'Lambda(_x, _x)') + sT(S.Naturals, 'Naturals') + sT(S.Naturals0, 'Naturals0') + sT(S.Rationals, 'Rationals') + sT(S.Reals, 'Reals') + sT(S.UniversalSet, 'UniversalSet') + + +def test_Integer(): + sT(Integer(4), "Integer(4)") + + +def test_list(): + sT([x, Integer(4)], "[Symbol('x'), Integer(4)]") + + +def test_Matrix(): + for cls, name in [(Matrix, "MutableDenseMatrix"), (ImmutableDenseMatrix, "ImmutableDenseMatrix")]: + sT(cls([[x**+1, 1], [y, x + y]]), + "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name) + + sT(cls(), "%s([])" % name) + + sT(cls([[x**+1, 1], [y, x + y]]), "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name) + + +def test_empty_Matrix(): + sT(ones(0, 3), "MutableDenseMatrix(0, 3, [])") + sT(ones(4, 0), "MutableDenseMatrix(4, 0, [])") + sT(ones(0, 0), "MutableDenseMatrix([])") + + +def test_Rational(): + sT(Rational(1, 3), "Rational(1, 3)") + sT(Rational(-1, 3), "Rational(-1, 3)") + + +def test_Float(): + sT(Float('1.23', dps=3), "Float('1.22998', precision=13)") + sT(Float('1.23456789', dps=9), "Float('1.23456788994', precision=33)") + sT(Float('1.234567890123456789', dps=19), + "Float('1.234567890123456789013', precision=66)") + sT(Float('0.60038617995049726', dps=15), + "Float('0.60038617995049726', precision=53)") + + sT(Float('1.23', precision=13), "Float('1.22998', precision=13)") + sT(Float('1.23456789', precision=33), + "Float('1.23456788994', precision=33)") + sT(Float('1.234567890123456789', precision=66), + "Float('1.234567890123456789013', precision=66)") + sT(Float('0.60038617995049726', precision=53), + "Float('0.60038617995049726', precision=53)") + + sT(Float('0.60038617995049726', 15), + "Float('0.60038617995049726', precision=53)") + + +def test_Symbol(): + sT(x, "Symbol('x')") + sT(y, "Symbol('y')") + sT(Symbol('x', negative=True), "Symbol('x', negative=True)") + + +def test_Symbol_two_assumptions(): + x = Symbol('x', negative=0, integer=1) + # order could vary + s1 = "Symbol('x', integer=True, negative=False)" + s2 = "Symbol('x', negative=False, integer=True)" + assert srepr(x) in (s1, s2) + assert eval(srepr(x), ENV) == x + + +def test_Symbol_no_special_commutative_treatment(): + sT(Symbol('x'), "Symbol('x')") + sT(Symbol('x', commutative=False), "Symbol('x', commutative=False)") + sT(Symbol('x', commutative=0), "Symbol('x', commutative=False)") + sT(Symbol('x', commutative=True), "Symbol('x', commutative=True)") + sT(Symbol('x', commutative=1), "Symbol('x', commutative=True)") + + +def test_Wild(): + sT(Wild('x', even=True), "Wild('x', even=True)") + + +def test_Dummy(): + d = Dummy('d') + sT(d, "Dummy('d', dummy_index=%s)" % str(d.dummy_index)) + + +def test_Dummy_assumption(): + d = Dummy('d', nonzero=True) + assert d == eval(srepr(d)) + s1 = "Dummy('d', dummy_index=%s, nonzero=True)" % str(d.dummy_index) + s2 = "Dummy('d', nonzero=True, dummy_index=%s)" % str(d.dummy_index) + assert srepr(d) in (s1, s2) + + +def test_Dummy_from_Symbol(): + # should not get the full dictionary of assumptions + n = Symbol('n', integer=True) + d = n.as_dummy() + assert srepr(d + ) == "Dummy('n', dummy_index=%s)" % str(d.dummy_index) + + +def test_tuple(): + sT((x,), "(Symbol('x'),)") + sT((x, y), "(Symbol('x'), Symbol('y'))") + + +def test_WildFunction(): + sT(WildFunction('w'), "WildFunction('w')") + + +def test_settins(): + raises(TypeError, lambda: srepr(x, method="garbage")) + + +def test_Mul(): + sT(3*x**3*y, "Mul(Integer(3), Pow(Symbol('x'), Integer(3)), Symbol('y'))") + assert srepr(3*x**3*y, order='old') == "Mul(Integer(3), Symbol('y'), Pow(Symbol('x'), Integer(3)))" + assert srepr(sympify('(x+4)*2*x*7', evaluate=False), order='none') == "Mul(Add(Symbol('x'), Integer(4)), Integer(2), Symbol('x'), Integer(7))" + + +def test_AlgebraicNumber(): + a = AlgebraicNumber(sqrt(2)) + sT(a, "AlgebraicNumber(Pow(Integer(2), Rational(1, 2)), [Integer(1), Integer(0)])") + a = AlgebraicNumber(root(-2, 3)) + sT(a, "AlgebraicNumber(Pow(Integer(-2), Rational(1, 3)), [Integer(1), Integer(0)])") + + +def test_PolyRing(): + assert srepr(ring("x", ZZ, lex)[0]) == "PolyRing((Symbol('x'),), ZZ, lex)" + assert srepr(ring("x,y", QQ, grlex)[0]) == "PolyRing((Symbol('x'), Symbol('y')), QQ, grlex)" + assert srepr(ring("x,y,z", ZZ["t"], lex)[0]) == "PolyRing((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)" + + +def test_FracField(): + assert srepr(field("x", ZZ, lex)[0]) == "FracField((Symbol('x'),), ZZ, lex)" + assert srepr(field("x,y", QQ, grlex)[0]) == "FracField((Symbol('x'), Symbol('y')), QQ, grlex)" + assert srepr(field("x,y,z", ZZ["t"], lex)[0]) == "FracField((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)" + + +def test_PolyElement(): + R, x, y = ring("x,y", ZZ) + assert srepr(3*x**2*y + 1) == "PolyElement(PolyRing((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)])" + + +def test_FracElement(): + F, x, y = field("x,y", ZZ) + assert srepr((3*x**2*y + 1)/(x - y**2)) == "FracElement(FracField((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)], [((1, 0), 1), ((0, 2), -1)])" + + +def test_FractionField(): + assert srepr(QQ.frac_field(x)) == \ + "FractionField(FracField((Symbol('x'),), QQ, lex))" + assert srepr(QQ.frac_field(x, y, order=grlex)) == \ + "FractionField(FracField((Symbol('x'), Symbol('y')), QQ, grlex))" + + +def test_PolynomialRingBase(): + assert srepr(ZZ.old_poly_ring(x)) == \ + "GlobalPolynomialRing(ZZ, Symbol('x'))" + assert srepr(ZZ[x].old_poly_ring(y)) == \ + "GlobalPolynomialRing(ZZ[x], Symbol('y'))" + assert srepr(QQ.frac_field(x).old_poly_ring(y)) == \ + "GlobalPolynomialRing(FractionField(FracField((Symbol('x'),), QQ, lex)), Symbol('y'))" + + +def test_DMP(): + p1 = DMP([1, 2], ZZ) + p2 = ZZ.old_poly_ring(x)([1, 2]) + if GROUND_TYPES != 'flint': + assert srepr(p1) == "DMP_Python([1, 2], ZZ)" + assert srepr(p2) == "DMP_Python([1, 2], ZZ)" + else: + assert srepr(p1) == "DUP_Flint([1, 2], ZZ)" + assert srepr(p2) == "DUP_Flint([1, 2], ZZ)" + + +def test_FiniteExtension(): + assert srepr(FiniteExtension(Poly(x**2 + 1, x))) == \ + "FiniteExtension(Poly(x**2 + 1, x, domain='ZZ'))" + + +def test_ExtensionElement(): + A = FiniteExtension(Poly(x**2 + 1, x)) + if GROUND_TYPES != 'flint': + ans = "ExtElem(DMP_Python([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))" + else: + ans = "ExtElem(DUP_Flint([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))" + assert srepr(A.generator) == ans + +def test_BooleanAtom(): + assert srepr(true) == "true" + assert srepr(false) == "false" + + +def test_Integers(): + sT(S.Integers, "Integers") + + +def test_Naturals(): + sT(S.Naturals, "Naturals") + + +def test_Naturals0(): + sT(S.Naturals0, "Naturals0") + + +def test_Reals(): + sT(S.Reals, "Reals") + + +def test_matrix_expressions(): + n = symbols('n', integer=True) + A = MatrixSymbol("A", n, n) + B = MatrixSymbol("B", n, n) + sT(A, "MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True))") + sT(A*B, "MatMul(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))") + sT(A + B, "MatAdd(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))") + + +def test_Cycle(): + # FIXME: sT fails because Cycle is not immutable and calling srepr(Cycle(1, 2)) + # adds keys to the Cycle dict (GH-17661) + #import_stmt = "from sympy.combinatorics import Cycle" + #sT(Cycle(1, 2), "Cycle(1, 2)", import_stmt) + assert srepr(Cycle(1, 2)) == "Cycle(1, 2)" + + +def test_Permutation(): + import_stmt = "from sympy.combinatorics import Permutation" + sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt, perm_cyclic=False) + sT(Permutation(1, 2)(3, 4), "Permutation(1, 2)(3, 4)", import_stmt, perm_cyclic=True) + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt) + Permutation.print_cyclic = old_print_cyclic + +def test_dict(): + from sympy.abc import x, y, z + d = {} + assert srepr(d) == "{}" + d = {x: y} + assert srepr(d) == "{Symbol('x'): Symbol('y')}" + d = {x: y, y: z} + assert srepr(d) in ( + "{Symbol('x'): Symbol('y'), Symbol('y'): Symbol('z')}", + "{Symbol('y'): Symbol('z'), Symbol('x'): Symbol('y')}", + ) + d = {x: {y: z}} + assert srepr(d) == "{Symbol('x'): {Symbol('y'): Symbol('z')}}" + +def test_set(): + from sympy.abc import x, y + s = set() + assert srepr(s) == "set()" + s = {x, y} + assert srepr(s) in ("{Symbol('x'), Symbol('y')}", "{Symbol('y'), Symbol('x')}") + +def test_Predicate(): + sT(Q.even, "Q.even") + +def test_AppliedPredicate(): + sT(Q.even(Symbol('z')), "AppliedPredicate(Q.even, Symbol('z'))") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_rust.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_rust.py new file mode 100644 index 0000000000000000000000000000000000000000..c81d592faca0d4a31e5a9618a48d67cb19ca94d8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_rust.py @@ -0,0 +1,363 @@ +from sympy.core import (S, pi, oo, symbols, Rational, Integer, + GoldenRatio, EulerGamma, Catalan, Lambda, Dummy, + Eq, Ne, Le, Lt, Gt, Ge, Mod) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + sign, floor) +from sympy.logic import ITE +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix + +from sympy.printing.codeprinter import rust_code + +x, y, z = symbols('x,y,z', integer=False, real=True) +k, m, n = symbols('k,m,n', integer=True) + + +def test_Integer(): + assert rust_code(Integer(42)) == "42" + assert rust_code(Integer(-56)) == "-56" + + +def test_Relational(): + assert rust_code(Eq(x, y)) == "x == y" + assert rust_code(Ne(x, y)) == "x != y" + assert rust_code(Le(x, y)) == "x <= y" + assert rust_code(Lt(x, y)) == "x < y" + assert rust_code(Gt(x, y)) == "x > y" + assert rust_code(Ge(x, y)) == "x >= y" + + +def test_Rational(): + assert rust_code(Rational(3, 7)) == "3_f64/7.0" + assert rust_code(Rational(18, 9)) == "2" + assert rust_code(Rational(3, -7)) == "-3_f64/7.0" + assert rust_code(Rational(-3, -7)) == "3_f64/7.0" + assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0" + assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x" + + +def test_basic_ops(): + assert rust_code(x + y) == "x + y" + assert rust_code(x - y) == "x - y" + assert rust_code(x * y) == "x*y" + assert rust_code(x / y) == "x*y.recip()" + assert rust_code(-x) == "-x" + assert rust_code(2 * x) == "2.0*x" + assert rust_code(y + 2) == "y + 2.0" + assert rust_code(x + n) == "n as f64 + x" + +def test_printmethod(): + class fabs(Abs): + def _rust_code(self, printer): + return "%s.fabs()" % printer._print(self.args[0]) + assert rust_code(fabs(x)) == "x.fabs()" + a = MatrixSymbol("a", 1, 3) + assert rust_code(a[0,0]) == 'a[0]' + + +def test_Functions(): + assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())" + assert rust_code(abs(x)) == "x.abs()" + assert rust_code(ceiling(x)) == "x.ceil()" + assert rust_code(floor(x)) == "x.floor()" + + # Automatic rewrite + assert rust_code(Mod(x, 3)) == 'x - 3.0*((1_f64/3.0)*x).floor()' + + +def test_Pow(): + assert rust_code(1/x) == "x.recip()" + assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()" + assert rust_code(sqrt(x)) == "x.sqrt()" + assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()" + + assert rust_code(1/sqrt(x)) == "x.sqrt().recip()" + assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()" + + assert rust_code(1/pi) == "PI.recip()" + assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()" + assert rust_code(pi**-0.5) == "PI.sqrt().recip()" + + assert rust_code(x**Rational(1, 3)) == "x.cbrt()" + assert rust_code(2**x) == "x.exp2()" + assert rust_code(exp(x)) == "x.exp()" + assert rust_code(x**3) == "x.powi(3)" + assert rust_code(x**(y**3)) == "x.powf(y.powi(3))" + assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)" + + g = implemented_function('g', Lambda(x, 2*x)) + assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2.0*x).powf(-x + y.powf(x))/(x.powi(2) + y)" + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1), + (lambda base, exp: not exp.is_integer, "pow", 1)] + assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)' + assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)' + + +def test_constants(): + assert rust_code(pi) == "PI" + assert rust_code(oo) == "INFINITY" + assert rust_code(S.Infinity) == "INFINITY" + assert rust_code(-oo) == "NEG_INFINITY" + assert rust_code(S.NegativeInfinity) == "NEG_INFINITY" + assert rust_code(S.NaN) == "NAN" + assert rust_code(exp(1)) == "E" + assert rust_code(S.Exp1) == "E" + + +def test_constants_other(): + assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2.0*GoldenRatio" % GoldenRatio.evalf(17) + assert rust_code( + 2*Catalan) == "const Catalan: f64 = %s;\n2.0*Catalan" % Catalan.evalf(17) + assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2.0*EulerGamma" % EulerGamma.evalf(17) + + +def test_boolean(): + assert rust_code(True) == "true" + assert rust_code(S.true) == "true" + assert rust_code(False) == "false" + assert rust_code(S.false) == "false" + assert rust_code(k & m) == "k && m" + assert rust_code(k | m) == "k || m" + assert rust_code(~k) == "!k" + assert rust_code(k & m & n) == "k && m && n" + assert rust_code(k | m | n) == "k || m || n" + assert rust_code((k & m) | n) == "n || k && m" + assert rust_code((k | m) & n) == "n && (k || m)" + + +def test_Piecewise(): + expr = Piecewise((x, x < 1), (x + 2, True)) + assert rust_code(expr) == ( + "if (x < 1.0) {\n" + " x\n" + "} else {\n" + " x + 2.0\n" + "}") + assert rust_code(expr, assign_to="r") == ( + "r = if (x < 1.0) {\n" + " x\n" + "} else {\n" + " x + 2.0\n" + "};") + assert rust_code(expr, assign_to="r", inline=True) == ( + "r = if (x < 1.0) { x } else { x + 2.0 };") + expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) + assert rust_code(expr, inline=True) == ( + "if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }") + assert rust_code(expr, assign_to="r", inline=True) == ( + "r = if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 };") + assert rust_code(expr, assign_to="r") == ( + "r = if (x < 1.0) {\n" + " x\n" + "} else if (x < 5.0) {\n" + " x + 1.0\n" + "} else {\n" + " x + 2.0\n" + "};") + expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) + assert rust_code(expr, inline=True) == ( + "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }") + expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42 + assert rust_code(expr, inline=True) == ( + "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 } - 42.0") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: rust_code(expr)) + + +def test_dereference_printing(): + expr = x + y + sin(z) + z + assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()" + + +def test_sign(): + expr = sign(x) * y + assert rust_code(expr) == "y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64" + assert rust_code(expr, assign_to='r') == "r = y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64;" + + expr = sign(x + y) + 42 + assert rust_code(expr) == "(if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42" + assert rust_code(expr, assign_to='r') == "r = (if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42;" + + expr = sign(cos(x)) + assert rust_code(expr) == "(if (x.cos() == 0.0) { 0.0 } else { (x.cos()).signum() })" + + +def test_reserved_words(): + + x, y = symbols("x if") + + expr = sin(y) + assert rust_code(expr) == "if_.sin()" + assert rust_code(expr, dereference=[y]) == "(*if_).sin()" + assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()" + + with raises(ValueError): + rust_code(expr, error_on_reserved=True) + + +def test_ITE(): + ekpr = ITE(k < 1, m, n) + assert rust_code(ekpr) == ( + "if (k < 1) {\n" + " m\n" + "} else {\n" + " n\n" + "}") + + +def test_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + + x = IndexedBase('x')[j] + assert rust_code(x) == "x[j]" + + A = IndexedBase('A')[i, j] + assert rust_code(A) == "A[m*i + j]" + + B = IndexedBase('B')[i, j, k] + assert rust_code(B) == "B[m*o*i + o*j + k]" + + +def test_dummy_loops(): + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + assert rust_code(x[i], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = x[i];\n" + "}") + + +def test_loops(): + m, n = symbols('m n', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i = Idx('i', m) + j = Idx('j', n) + + assert rust_code(A[i, j]*x[j], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " y[i] = A[n*i + j]*x[j] + y[i];\n" + " }\n" + "}") + + assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = x[i] + z[i];\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " y[i] = A[n*i + j]*x[j] + y[i];\n" + " }\n" + "}") + + +def test_loops_multiple_contractions(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " for k in 0..o {\n" + " for l in 0..p {\n" + " y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ + " }\n" + " }\n" + " }\n" + "}") + + +def test_loops_addfactor(): + m, n, o, p = symbols('m n o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) + assert code == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " for k in 0..o {\n" + " for l in 0..p {\n" + " y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ + " }\n" + " }\n" + " }\n" + "}") + + +def test_settings(): + raises(TypeError, lambda: rust_code(sin(x), method="garbage")) + + +def test_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert rust_code(g(x)) == "2*x" + + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert rust_code(g(x)) == ( + "const Catalan: f64 = %s;\n2.0*x/Catalan" % Catalan.evalf(17)) + + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert rust_code(g(A[i]), assign_to=A[i]) == ( + "for i in 0..n {\n" + " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n" + "}") + + +def test_user_functions(): + x = symbols('x', integer=False) + n = symbols('n', integer=True) + custom_functions = { + "ceiling": "ceil", + "Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)], + } + assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()" + assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)" + assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)" + + +def test_matrix(): + assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]' + with raises(ValueError): + rust_code(Matrix([[1, 2, 3]])) + + +def test_sparse_matrix(): + # gh-15791 + with raises(NotImplementedError): + rust_code(SparseMatrix([[1, 2, 3]])) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_smtlib.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_smtlib.py new file mode 100644 index 0000000000000000000000000000000000000000..48ff3d432d9042bf178f4e52dc46c787059937a3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_smtlib.py @@ -0,0 +1,553 @@ +import contextlib +import itertools +import re +import typing +from enum import Enum +from typing import Callable + +import sympy +from sympy import Add, Implies, sqrt +from sympy.core import Mul, Pow +from sympy.core import (S, pi, symbols, Function, Rational, Integer, + Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.functions import Piecewise, exp, sin, cos +from sympy.assumptions.ask import Q +from sympy.printing.smtlib import smtlib_code +from sympy.testing.pytest import raises, Failed + +x, y, z = symbols('x,y,z') + + +class _W(Enum): + DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.IGNORECASE) + WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.IGNORECASE) + WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.IGNORECASE) + + +@contextlib.contextmanager +def _check_warns(expected: typing.Iterable[_W]): + warns: typing.List[str] = [] + log_warn = warns.append + yield log_warn + + errors = [] + for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)): + if not e: + errors += [f"[{i}] Received unexpected warning `{w}`."] + elif not w: + errors += [f"[{i}] Did not receive expected warning `{e.name}`."] + elif not e.value.match(w): + errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."] + + if errors: raise Failed('\n'.join(errors)) + + +def test_Integer(): + with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w: + assert smtlib_code(Integer(67), log_warn=w) == "67" + assert smtlib_code(Integer(-1), log_warn=w) == "-1" + with _check_warns([]) as w: + assert smtlib_code(Integer(67)) == "67" + assert smtlib_code(Integer(-1)) == "-1" + + +def test_Rational(): + with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w: + assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)" + assert smtlib_code(Rational(18, 9), log_warn=w) == "2" + assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)" + assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w: + assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)" + assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \ + "(* (/ 3 7) x)" + + +def test_Relational(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w: + assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))" + assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))" + assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))" + assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))" + assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))" + assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))" + + +def test_AppliedBinaryRelation(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w: + assert smtlib_code(Q.eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))" + assert smtlib_code(Q.ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))" + assert smtlib_code(Q.lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))" + assert smtlib_code(Q.le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))" + assert smtlib_code(Q.gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))" + assert smtlib_code(Q.ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))" + + raises(ValueError, lambda: smtlib_code(Q.complex(x), log_warn=w)) + + +def test_AppliedPredicate(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 6) as w: + assert smtlib_code(Q.positive(x), auto_declare=False, log_warn=w) == "(assert (> x 0))" + assert smtlib_code(Q.negative(x), auto_declare=False, log_warn=w) == "(assert (< x 0))" + assert smtlib_code(Q.zero(x), auto_declare=False, log_warn=w) == "(assert (= x 0))" + assert smtlib_code(Q.nonpositive(x), auto_declare=False, log_warn=w) == "(assert (<= x 0))" + assert smtlib_code(Q.nonnegative(x), auto_declare=False, log_warn=w) == "(assert (>= x 0))" + assert smtlib_code(Q.nonzero(x), auto_declare=False, log_warn=w) == "(assert (not (= x 0)))" + +def test_Function(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))" + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + abs(x), + symbol_table={x: int, y: bool}, + known_types={int: "INTEGER_TYPE"}, + known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"}, + log_warn=w + ) == "(declare-const x INTEGER_TYPE)\n" \ + "(ABSOLUTE_VALUE_OF x)" + + my_fun1 = Function('f1') + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + my_fun1(x), + symbol_table={my_fun1: Callable[[bool], float]}, + log_warn=w + ) == "(declare-const x Bool)\n" \ + "(declare-fun f1 (Bool) Real)\n" \ + "(f1 x)" + + with _check_warns([]) as w: + assert smtlib_code( + my_fun1(x), + symbol_table={my_fun1: Callable[[bool], bool]}, + log_warn=w + ) == "(declare-const x Bool)\n" \ + "(declare-fun f1 (Bool) Bool)\n" \ + "(assert (f1 x))" + + assert smtlib_code( + Eq(my_fun1(x, z), y), + symbol_table={my_fun1: Callable[[int, bool], bool]}, + log_warn=w + ) == "(declare-const x Int)\n" \ + "(declare-const y Bool)\n" \ + "(declare-const z Bool)\n" \ + "(declare-fun f1 (Int Bool) Bool)\n" \ + "(assert (= (f1 x z) y))" + + assert smtlib_code( + Eq(my_fun1(x, z), y), + symbol_table={my_fun1: Callable[[int, bool], bool]}, + known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='}, + log_warn=w + ) == "(declare-const x Int)\n" \ + "(declare-const y Bool)\n" \ + "(declare-const z Bool)\n" \ + "(assert (== (MY_KNOWN_FUN x z) y))" + + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w: + assert smtlib_code( + Eq(my_fun1(x, z), y), + known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='}, + log_warn=w + ) == "(declare-const x Real)\n" \ + "(declare-const y Real)\n" \ + "(declare-const z Real)\n" \ + "(assert (== (MY_KNOWN_FUN x z) y))" + + +def test_Pow(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)" + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))" + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))' + + a = Symbol('a', integer=True) + b = Symbol('b', real=True) + c = Symbol('c') + + def g(x): return 2 * x + + # if x=1, y=2, then expr=2.333... + expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b) + + with _check_warns([]) as w: + assert smtlib_code( + [ + Eq(a < 2, c), + Eq(b > a, c), + c & True, + Eq(expr, 2 + Rational(1, 3)) + ], + log_warn=w + ) == '(declare-const a Int)\n' \ + '(declare-const b Real)\n' \ + '(declare-const c Bool)\n' \ + '(assert (= (< a 2) c))\n' \ + '(assert (= (> b a) c))\n' \ + '(assert c)\n' \ + '(assert (= ' \ + '(* (pow (* 7.0 a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \ + '(/ 7 3)' \ + '))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False), + log_warn=w + ) == '(declare-const b Real)\n' \ + '(declare-const c Real)\n' \ + '(* -2 c (pow (* b b) -1))' + + +def test_basic_ops(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)" + + # with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w: + # todo: implement re-write, currently does '(+ x (* -1 y))' instead + # assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)" + + +def test_quantifier_extensions(): + from sympy.logic.boolalg import Boolean + from sympy import Interval, Tuple, sympify + + # start For-all quantifier class example + class ForAll(Boolean): + def _smtlib(self, printer): + bound_symbol_declarations = [ + printer._s_expr(sym.name, [ + printer._known_types[printer.symbol_table[sym]], + Interval(start, end) + ]) for sym, start, end in self.limits + ] + return printer._s_expr('forall', [ + printer._s_expr('', bound_symbol_declarations), + self.function + ]) + + @property + def bound_symbols(self): + return {s for s, _, _ in self.limits} + + @property + def free_symbols(self): + bound_symbol_names = {s.name for s in self.bound_symbols} + return { + s for s in self.function.free_symbols + if s.name not in bound_symbol_names + } + + def __new__(cls, *args): + limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))] + function = [sympify(a) for a in args if isinstance(a, Boolean)] + assert len(limits) + len(function) == len(args) + assert len(function) == 1 + function = function[0] + + if isinstance(function, ForAll): return ForAll.__new__( + ForAll, *(limits + function.limits), function.function + ) + inst = Boolean.__new__(cls) + inst._args = tuple(limits + [function]) + inst.limits = limits + inst.function = function + return inst + + # end For-All Quantifier class example + + f = Function('f') + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code( + ForAll((x, -42, +21), Eq(f(x), f(x))), + symbol_table={f: Callable[[float], float]}, + log_warn=w + ) == '(assert (forall ( (x Real [-42, 21])) true))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w: + assert smtlib_code( + ForAll( + (x, -42, +21), (y, -100, 3), + Implies(Eq(x, y), Eq(f(x), f(y))) + ), + symbol_table={f: Callable[[float], float]}, + log_warn=w + ) == '(declare-fun f (Real) Real)\n' \ + '(assert (' \ + 'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \ + '(=> (= x y) (= (f x) (f y)))' \ + '))' + + a = Symbol('a', integer=True) + b = Symbol('b', real=True) + c = Symbol('c') + + with _check_warns([]) as w: + assert smtlib_code( + ForAll( + (a, 2, 100), ForAll( + (b, 2, 100), + Implies(a < b, sqrt(a) < b) | c + )), + log_warn=w + ) == '(declare-const c Bool)\n' \ + '(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \ + '(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \ + '))' + + +def test_mix_number_mult_symbols(): + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + 1 / pi, + known_constants={pi: "MY_PI"}, + log_warn=w + ) == '(pow MY_PI -1)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + [ + Eq(pi, 3.14, evaluate=False), + 1 / pi, + ], + known_constants={pi: "MY_PI"}, + log_warn=w + ) == '(assert (= MY_PI 3.14))\n' \ + '(pow MY_PI -1)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={ + S.Pi: 'p', S.GoldenRatio: 'g', + S.Exp1: 'e' + }, + known_functions={ + Add: 'plus', + exp: 'exp' + }, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={ + S.Pi: 'p' + }, + known_functions={ + Add: 'plus', + exp: 'exp' + }, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_functions={Add: 'plus'}, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={S.Exp1: 'e'}, + known_functions={Add: 'plus'}, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)' + + +def test_boolean(): + with _check_warns([]) as w: + assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(assert (and x y))' + assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(assert (or x y))' + assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \ + '(assert (not x))' + assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(declare-const z Bool)\n' \ + '(assert (and x y z))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(declare-const z Real)\n' \ + '(assert (or (> z 3) (and x (not y))))' + + f = Function('f') + g = Function('g') + h = Function('h') + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code( + [Gt(f(x), y), + Lt(y, g(z))], + symbol_table={ + f: Callable[[bool], int], g: Callable[[bool], int], + }, log_warn=w + ) == '(declare-const x Bool)\n' \ + '(declare-const y Real)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Bool) Int)\n' \ + '(declare-fun g (Bool) Int)\n' \ + '(assert (> (f x) y))\n' \ + '(assert (< y (g z)))' + + with _check_warns([]) as w: + assert smtlib_code( + [Eq(f(x), y), + Lt(y, g(z))], + symbol_table={ + f: Callable[[bool], int], g: Callable[[bool], int], + }, log_warn=w + ) == '(declare-const x Bool)\n' \ + '(declare-const y Int)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Bool) Int)\n' \ + '(declare-fun g (Bool) Int)\n' \ + '(assert (= (f x) y))\n' \ + '(assert (< y (g z)))' + + with _check_warns([]) as w: + assert smtlib_code( + [Eq(f(x), y), + Eq(g(f(x)), z), + Eq(h(g(f(x))), x)], + symbol_table={ + f: Callable[[float], int], + g: Callable[[int], bool], + h: Callable[[bool], float] + }, + log_warn=w + ) == '(declare-const x Real)\n' \ + '(declare-const y Int)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Real) Int)\n' \ + '(declare-fun g (Int) Bool)\n' \ + '(declare-fun h (Bool) Real)\n' \ + '(assert (= (f x) y))\n' \ + '(assert (= (g (f x)) z))\n' \ + '(assert (= (h (g (f x))) x))' + + +# todo: make smtlib_code support arrays +# def test_containers(): +# assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ +# "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]" +# assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))" +# assert julia_code([1]) == "Any[1]" +# assert julia_code((1,)) == "(1,)" +# assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)" +# assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))" +# # scalar, matrix, empty matrix and empty list +# assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])" + +def test_smtlib_piecewise(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Piecewise((x, x < 1), + (x ** 2, True)), + auto_declare=False, + log_warn=w + ) == '(ite (< x 1) x (pow x 2))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Piecewise((x ** 2, x < 1), + (x ** 3, x < 2), + (x ** 4, x < 3), + (x ** 5, True)), + auto_declare=False, + log_warn=w + ) == '(ite (< x 1) (pow x 2) ' \ + '(ite (< x 2) (pow x 3) ' \ + '(ite (< x 3) (pow x 4) ' \ + '(pow x 5))))' + + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0)) + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + raises(AssertionError, lambda: smtlib_code(expr, log_warn=w)) + + +def test_smtlib_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x ** 2, True)) + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))' + + +# todo: make smtlib_code support arrays / matrices ? +# def test_smtlib_matrix_assign_to(): +# A = Matrix([[1, 2, 3]]) +# assert smtlib_code(A, assign_to='a') == "a = [1 2 3]" +# A = Matrix([[1, 2], [3, 4]]) +# assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]" + +# def test_julia_matrix_1x1(): +# A = Matrix([[3]]) +# B = MatrixSymbol('B', 1, 1) +# C = MatrixSymbol('C', 1, 2) +# assert julia_code(A, assign_to=B) == "B = [3]" +# raises(ValueError, lambda: julia_code(A, assign_to=C)) + +# def test_julia_matrix_elements(): +# A = Matrix([[x, 2, x * y]]) +# assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2" +# A = MatrixSymbol('AA', 1, 3) +# assert julia_code(A) == "AA" +# assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \ +# "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]" +# assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]" + +def test_smtlib_boolean(): + with _check_warns([]) as w: + assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true' + assert smtlib_code(True, log_warn=w) == '(assert true)' + assert smtlib_code(S.true, log_warn=w) == '(assert true)' + assert smtlib_code(S.false, log_warn=w) == '(assert false)' + assert smtlib_code(False, log_warn=w) == '(assert false)' + assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false' + + +def test_not_supported(): + f = Function('f') + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w)) + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w)) + + +def test_Float(): + assert smtlib_code(0.0) == "0.0" + assert smtlib_code(0.000000000000000003) == '(* 3.0 (pow 10 -18))' + assert smtlib_code(5.3) == "5.3" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_str.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_str.py new file mode 100644 index 0000000000000000000000000000000000000000..675212964b03bf9a9806088225c28d7f70971ca7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_str.py @@ -0,0 +1,1206 @@ +from sympy import MatAdd +from sympy.algebras.quaternion import Quaternion +from sympy.assumptions.ask import Q +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.combinatorics.partitions import Partition +from sympy.concrete.summations import (Sum, summation) +from sympy.core.add import Add +from sympy.core.containers import (Dict, Tuple) +from sympy.core.expr import UnevaluatedExpr, Expr +from sympy.core.function import (Derivative, Function, Lambda, Subs, WildFunction) +from sympy.core.mul import Mul +from sympy.core import (Catalan, EulerGamma, GoldenRatio, TribonacciConstant) +from sympy.core.numbers import (E, Float, I, Integer, Rational, nan, oo, pi, zoo) +from sympy.core.parameters import _exp_is_pow +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Rel, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.functions.combinatorial.factorials import (factorial, factorial2, subfactorial) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.zeta_functions import zeta +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (Equivalent, false, true, Xor) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions import Identity +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices import SparseMatrix +from sympy.polys.polytools import factor +from sympy.series.limits import Limit +from sympy.series.order import O +from sympy.sets.sets import (Complement, FiniteSet, Interval, SymmetricDifference) +from sympy.stats import (Covariance, Expectation, Probability, Variance) +from sympy.stats.rv import RandomSymbol +from sympy.external import import_module +from sympy.physics.control.lti import TransferFunction, Series, Parallel, \ + Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback +from sympy.physics.units import second, joule +from sympy.polys import (Poly, rootof, RootSum, groebner, ring, field, ZZ, QQ, + ZZ_I, QQ_I, lex, grlex) +from sympy.geometry import Point, Circle, Polygon, Ellipse, Triangle +from sympy.tensor import NDimArray +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement + +from sympy.testing.pytest import raises, warns_deprecated_sympy + +from sympy.printing import sstr, sstrrepr, StrPrinter +from sympy.physics.quantum.trace import Tr + +x, y, z, w, t = symbols('x,y,z,w,t') +d = Dummy('d') + + +def test_printmethod(): + class R(Abs): + def _sympystr(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert sstr(R(x)) == "foo(x)" + + class R(Abs): + def _sympystr(self, printer): + return "foo" + assert sstr(R(x)) == "foo" + + +def test_Abs(): + assert str(Abs(x)) == "Abs(x)" + assert str(Abs(Rational(1, 6))) == "1/6" + assert str(Abs(Rational(-1, 6))) == "1/6" + + +def test_Add(): + assert str(x + y) == "x + y" + assert str(x + 1) == "x + 1" + assert str(x + x**2) == "x**2 + x" + assert str(Add(0, 1, evaluate=False)) == "0 + 1" + assert str(Add(0, 0, 1, evaluate=False)) == "0 + 0 + 1" + assert str(1.0*x) == "1.0*x" + assert str(5 + x + y + x*y + x**2 + y**2) == "x**2 + x*y + x + y**2 + y + 5" + assert str(1 + x + x**2/2 + x**3/3) == "x**3/3 + x**2/2 + x + 1" + assert str(2*x - 7*x**2 + 2 + 3*y) == "-7*x**2 + 2*x + 3*y + 2" + assert str(x - y) == "x - y" + assert str(2 - x) == "2 - x" + assert str(x - 2) == "x - 2" + assert str(x - y - z - w) == "-w + x - y - z" + assert str(x - z*y**2*z*w) == "-w*y**2*z**2 + x" + assert str(x - 1*y*x*y) == "-x*y**2 + x" + assert str(sin(x).series(x, 0, 15)) == "x - x**3/6 + x**5/120 - x**7/5040 + x**9/362880 - x**11/39916800 + x**13/6227020800 + O(x**15)" + assert str(Add(Add(-w, x, evaluate=False), Add(-y, z, evaluate=False), evaluate=False)) == "(-w + x) + (-y + z)" + assert str(Add(Add(-x, -y, evaluate=False), -z, evaluate=False)) == "-z + (-x - y)" + assert str(Add(Add(Add(-x, -y, evaluate=False), -z, evaluate=False), -t, evaluate=False)) == "-t + (-z + (-x - y))" + + +def test_Catalan(): + assert str(Catalan) == "Catalan" + + +def test_ComplexInfinity(): + assert str(zoo) == "zoo" + + +def test_Derivative(): + assert str(Derivative(x, y)) == "Derivative(x, y)" + assert str(Derivative(x**2, x, evaluate=False)) == "Derivative(x**2, x)" + assert str(Derivative( + x**2/y, x, y, evaluate=False)) == "Derivative(x**2/y, x, y)" + + +def test_dict(): + assert str({1: 1 + x}) == sstr({1: 1 + x}) == "{1: x + 1}" + assert str({1: x**2, 2: y*x}) in ("{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}") + assert sstr({1: x**2, 2: y*x}) == "{1: x**2, 2: x*y}" + + +def test_Dict(): + assert str(Dict({1: 1 + x})) == sstr({1: 1 + x}) == "{1: x + 1}" + assert str(Dict({1: x**2, 2: y*x})) in ( + "{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}") + assert sstr(Dict({1: x**2, 2: y*x})) == "{1: x**2, 2: x*y}" + + +def test_Dummy(): + assert str(d) == "_d" + assert str(d + x) == "_d + x" + + +def test_EulerGamma(): + assert str(EulerGamma) == "EulerGamma" + + +def test_Exp(): + assert str(E) == "E" + with _exp_is_pow(True): + assert str(exp(x)) == "E**x" + + +def test_factorial(): + n = Symbol('n', integer=True) + assert str(factorial(-2)) == "zoo" + assert str(factorial(0)) == "1" + assert str(factorial(7)) == "5040" + assert str(factorial(n)) == "factorial(n)" + assert str(factorial(2*n)) == "factorial(2*n)" + assert str(factorial(factorial(n))) == 'factorial(factorial(n))' + assert str(factorial(factorial2(n))) == 'factorial(factorial2(n))' + assert str(factorial2(factorial(n))) == 'factorial2(factorial(n))' + assert str(factorial2(factorial2(n))) == 'factorial2(factorial2(n))' + assert str(subfactorial(3)) == "2" + assert str(subfactorial(n)) == "subfactorial(n)" + assert str(subfactorial(2*n)) == "subfactorial(2*n)" + + +def test_Function(): + f = Function('f') + fx = f(x) + w = WildFunction('w') + assert str(f) == "f" + assert str(fx) == "f(x)" + assert str(w) == "w_" + + +def test_Geometry(): + assert sstr(Point(0, 0)) == 'Point2D(0, 0)' + assert sstr(Circle(Point(0, 0), 3)) == 'Circle(Point2D(0, 0), 3)' + assert sstr(Ellipse(Point(1, 2), 3, 4)) == 'Ellipse(Point2D(1, 2), 3, 4)' + assert sstr(Triangle(Point(1, 1), Point(7, 8), Point(0, -1))) == \ + 'Triangle(Point2D(1, 1), Point2D(7, 8), Point2D(0, -1))' + assert sstr(Polygon(Point(5, 6), Point(-2, -3), Point(0, 0), Point(4, 7))) == \ + 'Polygon(Point2D(5, 6), Point2D(-2, -3), Point2D(0, 0), Point2D(4, 7))' + assert sstr(Triangle(Point(0, 0), Point(1, 0), Point(0, 1)), sympy_integers=True) == \ + 'Triangle(Point2D(S(0), S(0)), Point2D(S(1), S(0)), Point2D(S(0), S(1)))' + assert sstr(Ellipse(Point(1, 2), 3, 4), sympy_integers=True) == \ + 'Ellipse(Point2D(S(1), S(2)), S(3), S(4))' + + +def test_GoldenRatio(): + assert str(GoldenRatio) == "GoldenRatio" + + +def test_Heaviside(): + assert str(Heaviside(x)) == str(Heaviside(x, S.Half)) == "Heaviside(x)" + assert str(Heaviside(x, 1)) == "Heaviside(x, 1)" + + +def test_TribonacciConstant(): + assert str(TribonacciConstant) == "TribonacciConstant" + + +def test_ImaginaryUnit(): + assert str(I) == "I" + + +def test_Infinity(): + assert str(oo) == "oo" + assert str(oo*I) == "oo*I" + + +def test_Integer(): + assert str(Integer(-1)) == "-1" + assert str(Integer(1)) == "1" + assert str(Integer(-3)) == "-3" + assert str(Integer(0)) == "0" + assert str(Integer(25)) == "25" + + +def test_Integral(): + assert str(Integral(sin(x), y)) == "Integral(sin(x), y)" + assert str(Integral(sin(x), (y, 0, 1))) == "Integral(sin(x), (y, 0, 1))" + + +def test_Interval(): + n = (S.NegativeInfinity, 1, 2, S.Infinity) + for i in range(len(n)): + for j in range(i + 1, len(n)): + for l in (True, False): + for r in (True, False): + ival = Interval(n[i], n[j], l, r) + assert S(str(ival)) == ival + + +def test_AccumBounds(): + a = Symbol('a', real=True) + assert str(AccumBounds(0, a)) == "AccumBounds(0, a)" + assert str(AccumBounds(0, 1)) == "AccumBounds(0, 1)" + + +def test_Lambda(): + assert str(Lambda(d, d**2)) == "Lambda(_d, _d**2)" + # issue 2908 + assert str(Lambda((), 1)) == "Lambda((), 1)" + assert str(Lambda((), x)) == "Lambda((), x)" + assert str(Lambda((x, y), x+y)) == "Lambda((x, y), x + y)" + assert str(Lambda(((x, y),), x+y)) == "Lambda(((x, y),), x + y)" + + +def test_Limit(): + assert str(Limit(sin(x)/x, x, y)) == "Limit(sin(x)/x, x, y, dir='+')" + assert str(Limit(1/x, x, 0)) == "Limit(1/x, x, 0, dir='+')" + assert str( + Limit(sin(x)/x, x, y, dir="-")) == "Limit(sin(x)/x, x, y, dir='-')" + + +def test_list(): + assert str([x]) == sstr([x]) == "[x]" + assert str([x**2, x*y + 1]) == sstr([x**2, x*y + 1]) == "[x**2, x*y + 1]" + assert str([x**2, [y + x]]) == sstr([x**2, [y + x]]) == "[x**2, [x + y]]" + + +def test_Matrix_str(): + M = Matrix([[x**+1, 1], [y, x + y]]) + assert str(M) == "Matrix([[x, 1], [y, x + y]])" + assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])" + M = Matrix([[1]]) + assert str(M) == sstr(M) == "Matrix([[1]])" + M = Matrix([[1, 2]]) + assert str(M) == sstr(M) == "Matrix([[1, 2]])" + M = Matrix() + assert str(M) == sstr(M) == "Matrix(0, 0, [])" + M = Matrix(0, 1, lambda i, j: 0) + assert str(M) == sstr(M) == "Matrix(0, 1, [])" + + +def test_Mul(): + assert str(x/y) == "x/y" + assert str(y/x) == "y/x" + assert str(x/y/z) == "x/(y*z)" + assert str((x + 1)/(y + 2)) == "(x + 1)/(y + 2)" + assert str(2*x/3) == '2*x/3' + assert str(-2*x/3) == '-2*x/3' + assert str(-1.0*x) == '-1.0*x' + assert str(1.0*x) == '1.0*x' + assert str(Mul(0, 1, evaluate=False)) == '0*1' + assert str(Mul(1, 0, evaluate=False)) == '1*0' + assert str(Mul(1, 1, evaluate=False)) == '1*1' + assert str(Mul(1, 1, 1, evaluate=False)) == '1*1*1' + assert str(Mul(1, 2, evaluate=False)) == '1*2' + assert str(Mul(1, S.Half, evaluate=False)) == '1*(1/2)' + assert str(Mul(1, 1, S.Half, evaluate=False)) == '1*1*(1/2)' + assert str(Mul(1, 1, 2, 3, x, evaluate=False)) == '1*1*2*3*x' + assert str(Mul(1, -1, evaluate=False)) == '1*(-1)' + assert str(Mul(-1, 1, evaluate=False)) == '-1*1' + assert str(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == '4*3*2*1*0*y*x' + assert str(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == '4*3*2*(z + 1)*0*y*x' + assert str(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == '(2/3)*(5/7)' + # For issue 14160 + assert str(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + # issue 21537 + assert str(Mul(x, Pow(1/y, -1, evaluate=False), evaluate=False)) == 'x/(1/y)' + + # Issue 24108 + from sympy.core.parameters import evaluate + with evaluate(False): + assert str(Mul(Pow(Integer(2), Integer(-1)), Add(Integer(-1), Mul(Integer(-1), Integer(1))))) == "(-1 - 1*1)/2" + + class CustomClass1(Expr): + is_commutative = True + + class CustomClass2(Expr): + is_commutative = True + cc1 = CustomClass1() + cc2 = CustomClass2() + assert str(Rational(2)*cc1) == '2*CustomClass1()' + assert str(cc1*Rational(2)) == '2*CustomClass1()' + assert str(cc1*Float("1.5")) == '1.5*CustomClass1()' + assert str(cc2*Rational(2)) == '2*CustomClass2()' + assert str(cc2*Rational(2)*cc1) == '2*CustomClass1()*CustomClass2()' + assert str(cc1*Rational(2)*cc2) == '2*CustomClass1()*CustomClass2()' + + +def test_NaN(): + assert str(nan) == "nan" + + +def test_NegativeInfinity(): + assert str(-oo) == "-oo" + +def test_Order(): + assert str(O(x)) == "O(x)" + assert str(O(x**2)) == "O(x**2)" + assert str(O(x*y)) == "O(x*y, x, y)" + assert str(O(x, x)) == "O(x)" + assert str(O(x, (x, 0))) == "O(x)" + assert str(O(x, (x, oo))) == "O(x, (x, oo))" + assert str(O(x, x, y)) == "O(x, x, y)" + assert str(O(x, x, y)) == "O(x, x, y)" + assert str(O(x, (x, oo), (y, oo))) == "O(x, (x, oo), (y, oo))" + + +def test_Permutation_Cycle(): + from sympy.combinatorics import Permutation, Cycle + + # general principle: economically, canonically show all moved elements + # and the size of the permutation. + + for p, s in [ + (Cycle(), + '()'), + (Cycle(2), + '(2)'), + (Cycle(2, 1), + '(1 2)'), + (Cycle(1, 2)(5)(6, 7)(10), + '(1 2)(6 7)(10)'), + (Cycle(3, 4)(1, 2)(3, 4), + '(1 2)(4)'), + ]: + assert sstr(p) == s + + for p, s in [ + (Permutation([]), + 'Permutation([])'), + (Permutation([], size=1), + 'Permutation([0])'), + (Permutation([], size=2), + 'Permutation([0, 1])'), + (Permutation([], size=10), + 'Permutation([], size=10)'), + (Permutation([1, 0, 2]), + 'Permutation([1, 0, 2])'), + (Permutation([1, 0, 2, 3, 4, 5]), + 'Permutation([1, 0], size=6)'), + (Permutation([1, 0, 2, 3, 4, 5], size=10), + 'Permutation([1, 0], size=10)'), + ]: + assert sstr(p, perm_cyclic=False) == s + + for p, s in [ + (Permutation([]), + '()'), + (Permutation([], size=1), + '(0)'), + (Permutation([], size=2), + '(1)'), + (Permutation([], size=10), + '(9)'), + (Permutation([1, 0, 2]), + '(2)(0 1)'), + (Permutation([1, 0, 2, 3, 4, 5]), + '(5)(0 1)'), + (Permutation([1, 0, 2, 3, 4, 5], size=10), + '(9)(0 1)'), + (Permutation([0, 1, 3, 2, 4, 5], size=10), + '(9)(2 3)'), + ]: + assert sstr(p) == s + + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert sstr(Permutation([1, 0, 2])) == 'Permutation([1, 0, 2])' + Permutation.print_cyclic = old_print_cyclic + +def test_Pi(): + assert str(pi) == "pi" + + +def test_Poly(): + assert str(Poly(0, x)) == "Poly(0, x, domain='ZZ')" + assert str(Poly(1, x)) == "Poly(1, x, domain='ZZ')" + assert str(Poly(x, x)) == "Poly(x, x, domain='ZZ')" + + assert str(Poly(2*x + 1, x)) == "Poly(2*x + 1, x, domain='ZZ')" + assert str(Poly(2*x - 1, x)) == "Poly(2*x - 1, x, domain='ZZ')" + + assert str(Poly(-1, x)) == "Poly(-1, x, domain='ZZ')" + assert str(Poly(-x, x)) == "Poly(-x, x, domain='ZZ')" + + assert str(Poly(-2*x + 1, x)) == "Poly(-2*x + 1, x, domain='ZZ')" + assert str(Poly(-2*x - 1, x)) == "Poly(-2*x - 1, x, domain='ZZ')" + + assert str(Poly(x - 1, x)) == "Poly(x - 1, x, domain='ZZ')" + assert str(Poly(2*x + x**5, x)) == "Poly(x**5 + 2*x, x, domain='ZZ')" + + assert str(Poly(3**(2*x), 3**x)) == "Poly((3**x)**2, 3**x, domain='ZZ')" + assert str(Poly((x**2)**x)) == "Poly(((x**2)**x), (x**2)**x, domain='ZZ')" + + assert str(Poly((x + y)**3, (x + y), expand=False) + ) == "Poly((x + y)**3, x + y, domain='ZZ')" + assert str(Poly((x - 1)**2, (x - 1), expand=False) + ) == "Poly((x - 1)**2, x - 1, domain='ZZ')" + + assert str( + Poly(x**2 + 1 + y, x)) == "Poly(x**2 + y + 1, x, domain='ZZ[y]')" + assert str( + Poly(x**2 - 1 + y, x)) == "Poly(x**2 + y - 1, x, domain='ZZ[y]')" + + assert str(Poly(x**2 + I*x, x)) == "Poly(x**2 + I*x, x, domain='ZZ_I')" + assert str(Poly(x**2 - I*x, x)) == "Poly(x**2 - I*x, x, domain='ZZ_I')" + + assert str(Poly(-x*y*z + x*y - 1, x, y, z) + ) == "Poly(-x*y*z + x*y - 1, x, y, z, domain='ZZ')" + assert str(Poly(-w*x**21*y**7*z + (1 + w)*z**3 - 2*x*z + 1, x, y, z)) == \ + "Poly(-w*x**21*y**7*z - 2*x*z + (w + 1)*z**3 + 1, x, y, z, domain='ZZ[w]')" + + assert str(Poly(x**2 + 1, x, modulus=2)) == "Poly(x**2 + 1, x, modulus=2)" + assert str(Poly(2*x**2 + 3*x + 4, x, modulus=17)) == "Poly(2*x**2 + 3*x + 4, x, modulus=17)" + + +def test_PolyRing(): + assert str(ring("x", ZZ, lex)[0]) == "Polynomial ring in x over ZZ with lex order" + assert str(ring("x,y", QQ, grlex)[0]) == "Polynomial ring in x, y over QQ with grlex order" + assert str(ring("x,y,z", ZZ["t"], lex)[0]) == "Polynomial ring in x, y, z over ZZ[t] with lex order" + + +def test_FracField(): + assert str(field("x", ZZ, lex)[0]) == "Rational function field in x over ZZ with lex order" + assert str(field("x,y", QQ, grlex)[0]) == "Rational function field in x, y over QQ with grlex order" + assert str(field("x,y,z", ZZ["t"], lex)[0]) == "Rational function field in x, y, z over ZZ[t] with lex order" + + +def test_PolyElement(): + Ruv, u,v = ring("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Ruv) + Rx_zzi, xz = ring("x", ZZ_I) + + assert str(x - x) == "0" + assert str(x - 1) == "x - 1" + assert str(x + 1) == "x + 1" + assert str(x**2) == "x**2" + + assert str((u**2 + 3*u*v + 1)*x**2*y + u + 1) == "(u**2 + 3*u*v + 1)*x**2*y + u + 1" + assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x" + assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1" + assert str((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == "-(u**2 - 3*u*v + 1)*x**2*y - (u + 1)*x - 1" + + assert str(-(v**2 + v + 1)*x + 3*u*v + 1) == "-(v**2 + v + 1)*x + 3*u*v + 1" + assert str(-(v**2 + v + 1)*x - 3*u*v + 1) == "-(v**2 + v + 1)*x - 3*u*v + 1" + + assert str((1+I)*xz + 2) == "(1 + 1*I)*x + (2 + 0*I)" + + +def test_FracElement(): + Fuv, u,v = field("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Fuv) + Rx_zzi, xz = field("x", QQ_I) + i = QQ_I(0, 1) + + assert str(x - x) == "0" + assert str(x - 1) == "x - 1" + assert str(x + 1) == "x + 1" + + assert str(x/3) == "x/3" + assert str(x/z) == "x/z" + assert str(x*y/z) == "x*y/z" + assert str(x/(z*t)) == "x/(z*t)" + assert str(x*y/(z*t)) == "x*y/(z*t)" + + assert str((x - 1)/y) == "(x - 1)/y" + assert str((x + 1)/y) == "(x + 1)/y" + assert str((-x - 1)/y) == "(-x - 1)/y" + assert str((x + 1)/(y*z)) == "(x + 1)/(y*z)" + assert str(-y/(x + 1)) == "-y/(x + 1)" + assert str(y*z/(x + 1)) == "y*z/(x + 1)" + + assert str(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - 1)" + assert str(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - u*v*t - 1)" + + assert str((1+i)/xz) == "(1 + 1*I)/x" + assert str(((1+i)*xz - i)/xz) == "((1 + 1*I)*x + (0 + -1*I))/x" + + +def test_GaussianInteger(): + assert str(ZZ_I(1, 0)) == "1" + assert str(ZZ_I(-1, 0)) == "-1" + assert str(ZZ_I(0, 1)) == "I" + assert str(ZZ_I(0, -1)) == "-I" + assert str(ZZ_I(0, 2)) == "2*I" + assert str(ZZ_I(0, -2)) == "-2*I" + assert str(ZZ_I(1, 1)) == "1 + I" + assert str(ZZ_I(-1, -1)) == "-1 - I" + assert str(ZZ_I(-1, -2)) == "-1 - 2*I" + + +def test_GaussianRational(): + assert str(QQ_I(1, 0)) == "1" + assert str(QQ_I(QQ(2, 3), 0)) == "2/3" + assert str(QQ_I(0, QQ(2, 3))) == "2*I/3" + assert str(QQ_I(QQ(1, 2), QQ(-2, 3))) == "1/2 - 2*I/3" + + +def test_Pow(): + assert str(x**-1) == "1/x" + assert str(x**-2) == "x**(-2)" + assert str(x**2) == "x**2" + assert str((x + y)**-1) == "1/(x + y)" + assert str((x + y)**-2) == "(x + y)**(-2)" + assert str((x + y)**2) == "(x + y)**2" + assert str((x + y)**(1 + x)) == "(x + y)**(x + 1)" + assert str(x**Rational(1, 3)) == "x**(1/3)" + assert str(1/x**Rational(1, 3)) == "x**(-1/3)" + assert str(sqrt(sqrt(x))) == "x**(1/4)" + # not the same as x**-1 + assert str(x**-1.0) == 'x**(-1.0)' + # see issue #2860 + assert str(Pow(S(2), -1.0, evaluate=False)) == '2**(-1.0)' + + +def test_sqrt(): + assert str(sqrt(x)) == "sqrt(x)" + assert str(sqrt(x**2)) == "sqrt(x**2)" + assert str(1/sqrt(x)) == "1/sqrt(x)" + assert str(1/sqrt(x**2)) == "1/sqrt(x**2)" + assert str(y/sqrt(x)) == "y/sqrt(x)" + assert str(x**0.5) == "x**0.5" + assert str(1/x**0.5) == "x**(-0.5)" + + +def test_Rational(): + n1 = Rational(1, 4) + n2 = Rational(1, 3) + n3 = Rational(2, 4) + n4 = Rational(2, -4) + n5 = Rational(0) + n7 = Rational(3) + n8 = Rational(-3) + assert str(n1*n2) == "1/12" + assert str(n1*n2) == "1/12" + assert str(n3) == "1/2" + assert str(n1*n3) == "1/8" + assert str(n1 + n3) == "3/4" + assert str(n1 + n2) == "7/12" + assert str(n1 + n4) == "-1/4" + assert str(n4*n4) == "1/4" + assert str(n4 + n2) == "-1/6" + assert str(n4 + n5) == "-1/2" + assert str(n4*n5) == "0" + assert str(n3 + n4) == "0" + assert str(n1**n7) == "1/64" + assert str(n2**n7) == "1/27" + assert str(n2**n8) == "27" + assert str(n7**n8) == "1/27" + assert str(Rational("-25")) == "-25" + assert str(Rational("1.25")) == "5/4" + assert str(Rational("-2.6e-2")) == "-13/500" + assert str(S("25/7")) == "25/7" + assert str(S("-123/569")) == "-123/569" + assert str(S("0.1[23]", rational=1)) == "61/495" + assert str(S("5.1[666]", rational=1)) == "31/6" + assert str(S("-5.1[666]", rational=1)) == "-31/6" + assert str(S("0.[9]", rational=1)) == "1" + assert str(S("-0.[9]", rational=1)) == "-1" + + assert str(sqrt(Rational(1, 4))) == "1/2" + assert str(sqrt(Rational(1, 36))) == "1/6" + + assert str((123**25) ** Rational(1, 25)) == "123" + assert str((123**25 + 1)**Rational(1, 25)) != "123" + assert str((123**25 - 1)**Rational(1, 25)) != "123" + assert str((123**25 - 1)**Rational(1, 25)) != "122" + + assert str(sqrt(Rational(81, 36))**3) == "27/8" + assert str(1/sqrt(Rational(81, 36))**3) == "8/27" + + assert str(sqrt(-4)) == str(2*I) + assert str(2**Rational(1, 10**10)) == "2**(1/10000000000)" + + assert sstr(Rational(2, 3), sympy_integers=True) == "S(2)/3" + x = Symbol("x") + assert sstr(x**Rational(2, 3), sympy_integers=True) == "x**(S(2)/3)" + assert sstr(Eq(x, Rational(2, 3)), sympy_integers=True) == "Eq(x, S(2)/3)" + assert sstr(Limit(x, x, Rational(7, 2)), sympy_integers=True) == \ + "Limit(x, x, S(7)/2, dir='+')" + + +def test_Float(): + # NOTE dps is the whole number of decimal digits + assert str(Float('1.23', dps=1 + 2)) == '1.23' + assert str(Float('1.23456789', dps=1 + 8)) == '1.23456789' + assert str( + Float('1.234567890123456789', dps=1 + 18)) == '1.234567890123456789' + assert str(pi.evalf(1 + 2)) == '3.14' + assert str(pi.evalf(1 + 14)) == '3.14159265358979' + assert str(pi.evalf(1 + 64)) == ('3.141592653589793238462643383279' + '5028841971693993751058209749445923') + assert str(pi.round(-1)) == '0.0' + assert str((pi**400 - (pi**400).round(1)).n(2)) == '-0.e+88' + assert sstr(Float("100"), full_prec=False, min=-2, max=2) == '1.0e+2' + assert sstr(Float("100"), full_prec=False, min=-2, max=3) == '100.0' + assert sstr(Float("0.1"), full_prec=False, min=-2, max=3) == '0.1' + assert sstr(Float("0.099"), min=-2, max=3) == '9.90000000000000e-2' + + +def test_Relational(): + assert str(Rel(x, y, "<")) == "x < y" + assert str(Rel(x + y, y, "==")) == "Eq(x + y, y)" + assert str(Rel(x, y, "!=")) == "Ne(x, y)" + assert str(Eq(x, 1) | Eq(x, 2)) == "Eq(x, 1) | Eq(x, 2)" + assert str(Ne(x, 1) & Ne(x, 2)) == "Ne(x, 1) & Ne(x, 2)" + + +def test_AppliedBinaryRelation(): + assert str(Q.eq(x, y)) == "Q.eq(x, y)" + assert str(Q.ne(x, y)) == "Q.ne(x, y)" + + +def test_CRootOf(): + assert str(rootof(x**5 + 2*x - 1, 0)) == "CRootOf(x**5 + 2*x - 1, 0)" + + +def test_RootSum(): + f = x**5 + 2*x - 1 + + assert str( + RootSum(f, Lambda(z, z), auto=False)) == "RootSum(x**5 + 2*x - 1)" + assert str(RootSum(f, Lambda( + z, z**2), auto=False)) == "RootSum(x**5 + 2*x - 1, Lambda(z, z**2))" + + +def test_GroebnerBasis(): + assert str(groebner( + [], x, y)) == "GroebnerBasis([], x, y, domain='ZZ', order='lex')" + + F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1] + + assert str(groebner(F, order='grlex')) == \ + "GroebnerBasis([x**2 - x - 3*y + 1, y**2 - 2*x + y - 1], x, y, domain='ZZ', order='grlex')" + assert str(groebner(F, order='lex')) == \ + "GroebnerBasis([2*x - y**2 - y + 1, y**4 + 2*y**3 - 3*y**2 - 16*y + 7], x, y, domain='ZZ', order='lex')" + +def test_set(): + assert sstr(set()) == 'set()' + assert sstr(frozenset()) == 'frozenset()' + + assert sstr({1}) == '{1}' + assert sstr(frozenset([1])) == 'frozenset({1})' + assert sstr({1, 2, 3}) == '{1, 2, 3}' + assert sstr(frozenset([1, 2, 3])) == 'frozenset({1, 2, 3})' + + assert sstr( + {1, x, x**2, x**3, x**4}) == '{1, x, x**2, x**3, x**4}' + assert sstr( + frozenset([1, x, x**2, x**3, x**4])) == 'frozenset({1, x, x**2, x**3, x**4})' + + +def test_SparseMatrix(): + M = SparseMatrix([[x**+1, 1], [y, x + y]]) + assert str(M) == "Matrix([[x, 1], [y, x + y]])" + assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])" + + +def test_Sum(): + assert str(summation(cos(3*z), (z, x, y))) == "Sum(cos(3*z), (z, x, y))" + assert str(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + "Sum(x*y**2, (x, -2, 2), (y, -5, 5))" + + +def test_Symbol(): + assert str(y) == "y" + assert str(x) == "x" + e = x + assert str(e) == "x" + + +def test_tuple(): + assert str((x,)) == sstr((x,)) == "(x,)" + assert str((x + y, 1 + x)) == sstr((x + y, 1 + x)) == "(x + y, x + 1)" + assert str((x + y, ( + 1 + x, x**2))) == sstr((x + y, (1 + x, x**2))) == "(x + y, (x + 1, x**2))" + + +def test_Series_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Series(tf1, tf2)) == \ + "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))" + assert str(Series(tf1, tf2, tf3)) == \ + "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))" + assert str(Series(-tf2, tf1)) == \ + "Series(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))" + + +def test_MIMOSeries_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + assert str(MIMOSeries(tfm_1, tfm_2)) == \ + "MIMOSeries(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\ + "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\ + "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\ + "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))" + + +def test_TransferFunction_str(): + tf1 = TransferFunction(x - 1, x + 1, x) + assert str(tf1) == "TransferFunction(x - 1, x + 1, x)" + tf2 = TransferFunction(x + 1, 2 - y, x) + assert str(tf2) == "TransferFunction(x + 1, 2 - y, x)" + tf3 = TransferFunction(y, y**2 + 2*y + 3, y) + assert str(tf3) == "TransferFunction(y, y**2 + 2*y + 3, y)" + + +def test_Parallel_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Parallel(tf1, tf2)) == \ + "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))" + assert str(Parallel(tf1, tf2, tf3)) == \ + "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))" + assert str(Parallel(-tf2, tf1)) == \ + "Parallel(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))" + + +def test_MIMOParallel_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + assert str(MIMOParallel(tfm_1, tfm_2)) == \ + "MIMOParallel(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\ + "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\ + "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\ + "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))" + + +def test_Feedback_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Feedback(tf1*tf2, tf3)) == \ + "Feedback(Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), " \ + "TransferFunction(t*x**2 - t**w*x + w, t - y, y), -1)" + assert str(Feedback(tf1, TransferFunction(1, 1, y), 1)) == \ + "Feedback(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(1, 1, y), 1)" + + +def test_MIMOFeedback_str(): + tf1 = TransferFunction(x**2 - y**3, y - z, x) + tf2 = TransferFunction(y - x, z + y, x) + tfm_1 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + tfm_2 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + assert (str(MIMOFeedback(tfm_1, tfm_2)) \ + == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x))," \ + " (TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \ + "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), " \ + "TransferFunction(-x + y, y + z, x)), (TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), -1)") + assert (str(MIMOFeedback(tfm_1, tfm_2, 1)) \ + == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)), " \ + "(TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \ + "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)), "\ + "(TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), 1)") + + +def test_TransferFunctionMatrix_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(TransferFunctionMatrix([[tf1], [tf2]])) == \ + "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y),), (TransferFunction(x - y, x + y, y),)))" + assert str(TransferFunctionMatrix([[tf1, tf2], [tf3, tf2]])) == \ + "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), (TransferFunction(t*x**2 - t**w*x + w, t - y, y), TransferFunction(x - y, x + y, y))))" + + +def test_Quaternion_str_printer(): + q = Quaternion(x, y, z, t) + assert str(q) == "x + y*i + z*j + t*k" + q = Quaternion(x,y,z,x*t) + assert str(q) == "x + y*i + z*j + t*x*k" + q = Quaternion(x,y,z,x+t) + assert str(q) == "x + y*i + z*j + (t + x)*k" + + +def test_Quantity_str(): + assert sstr(second, abbrev=True) == "s" + assert sstr(joule, abbrev=True) == "J" + assert str(second) == "second" + assert str(joule) == "joule" + + +def test_wild_str(): + # Check expressions containing Wild not causing infinite recursion + w = Wild('x') + assert str(w + 1) == 'x_ + 1' + assert str(exp(2**w) + 5) == 'exp(2**x_) + 5' + assert str(3*w + 1) == '3*x_ + 1' + assert str(1/w + 1) == '1 + 1/x_' + assert str(w**2 + 1) == 'x_**2 + 1' + assert str(1/(1 - w)) == '1/(1 - x_)' + + +def test_wild_matchpy(): + from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar + + matchpy = import_module("matchpy") + + if matchpy is None: + return + + wd = WildDot('w_') + wp = WildPlus('w__') + ws = WildStar('w___') + + assert str(wd) == 'w_' + assert str(wp) == 'w__' + assert str(ws) == 'w___' + + assert str(wp/ws + 2**wd) == '2**w_ + w__/w___' + assert str(sin(wd)*cos(wp)*sqrt(ws)) == 'sqrt(w___)*sin(w_)*cos(w__)' + + +def test_zeta(): + assert str(zeta(3)) == "zeta(3)" + + +def test_issue_3101(): + e = x - y + a = str(e) + b = str(e) + assert a == b + + +def test_issue_3103(): + e = -2*sqrt(x) - y/sqrt(x)/2 + assert str(e) not in ["(-2)*x**1/2(-1/2)*x**(-1/2)*y", + "-2*x**1/2(-1/2)*x**(-1/2)*y", "-2*x**1/2-1/2*x**-1/2*w"] + assert str(e) == "-2*sqrt(x) - y/(2*sqrt(x))" + + +def test_issue_4021(): + e = Integral(x, x) + 1 + assert str(e) == 'Integral(x, x) + 1' + + +def test_sstrrepr(): + assert sstr('abc') == 'abc' + assert sstrrepr('abc') == "'abc'" + + e = ['a', 'b', 'c', x] + assert sstr(e) == "[a, b, c, x]" + assert sstrrepr(e) == "['a', 'b', 'c', x]" + + +def test_infinity(): + assert sstr(oo*I) == "oo*I" + + +def test_full_prec(): + assert sstr(S("0.3"), full_prec=True) == "0.300000000000000" + assert sstr(S("0.3"), full_prec="auto") == "0.300000000000000" + assert sstr(S("0.3"), full_prec=False) == "0.3" + assert sstr(S("0.3")*x, full_prec=True) in [ + "0.300000000000000*x", + "x*0.300000000000000" + ] + assert sstr(S("0.3")*x, full_prec="auto") in [ + "0.3*x", + "x*0.3" + ] + assert sstr(S("0.3")*x, full_prec=False) in [ + "0.3*x", + "x*0.3" + ] + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + assert sstr(A*B*C**-1) == "A*B*C**(-1)" + assert sstr(C**-1*A*B) == "C**(-1)*A*B" + assert sstr(A*C**-1*B) == "A*C**(-1)*B" + assert sstr(sqrt(A)) == "sqrt(A)" + assert sstr(1/sqrt(A)) == "A**(-1/2)" + + +def test_empty_printer(): + str_printer = StrPrinter() + assert str_printer.emptyPrinter("foo") == "foo" + assert str_printer.emptyPrinter(x*y) == "x*y" + assert str_printer.emptyPrinter(32) == "32" + +def test_decimal_printer(): + dec_printer = StrPrinter(settings={"dps":3}) + f = Function('f') + assert dec_printer.doprint(f(1.329294)) == "f(1.33)" + + +def test_settings(): + raises(TypeError, lambda: sstr(S(4), method="garbage")) + + +def test_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + X = Normal('x1', 0, 1) + assert str(where(X > 0)) == "Domain: (0 < x1) & (x1 < oo)" + + D = Die('d1', 6) + assert str(where(D > 4)) == "Domain: Eq(d1, 5) | Eq(d1, 6)" + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert str(pspace(Tuple(A, B)).domain) == "Domain: (0 <= a) & (0 <= b) & (a < oo) & (b < oo)" + + +def test_FiniteSet(): + assert str(FiniteSet(*range(1, 51))) == ( + '{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,' + ' 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,' + ' 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50}' + ) + assert str(FiniteSet(*range(1, 6))) == '{1, 2, 3, 4, 5}' + assert str(FiniteSet(*[x*y, x**2])) == '{x**2, x*y}' + assert str(FiniteSet(FiniteSet(FiniteSet(x, y), 5), FiniteSet(x,y), 5) + ) == 'FiniteSet(5, FiniteSet(5, {x, y}), {x, y})' + + +def test_Partition(): + assert str(Partition(FiniteSet(x, y), {z})) == 'Partition({z}, {x, y})' + +def test_UniversalSet(): + assert str(S.UniversalSet) == 'UniversalSet' + + +def test_PrettyPoly(): + F = QQ.frac_field(x, y) + R = QQ[x, y] + assert sstr(F.convert(x/(x + y))) == sstr(x/(x + y)) + assert sstr(R.convert(x + y)) == sstr(x + y) + + +def test_categories(): + from sympy.categories import (Object, NamedMorphism, + IdentityMorphism, Category) + + A = Object("A") + B = Object("B") + + f = NamedMorphism(A, B, "f") + id_A = IdentityMorphism(A) + + K = Category("K") + + assert str(A) == 'Object("A")' + assert str(f) == 'NamedMorphism(Object("A"), Object("B"), "f")' + assert str(id_A) == 'IdentityMorphism(Object("A"))' + + assert str(K) == 'Category("K")' + + +def test_Tr(): + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert str(t) == 'Tr(A*B)' + + +def test_issue_6387(): + assert str(factor(-3.0*z + 3)) == '-3.0*(1.0*z - 1.0)' + + +def test_MatMul_MatAdd(): + X, Y = MatrixSymbol("X", 2, 2), MatrixSymbol("Y", 2, 2) + assert str(2*(X + Y)) == "2*X + 2*Y" + + assert str(I*X) == "I*X" + assert str(-I*X) == "-I*X" + assert str((1 + I)*X) == '(1 + I)*X' + assert str(-(1 + I)*X) == '(-1 - I)*X' + assert str(MatAdd(MatAdd(X, Y), MatAdd(X, Y))) == '(X + Y) + (X + Y)' + + +def test_MatrixSlice(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + assert str(MatrixSlice(X, (None, None, None), (None, None, None))) == 'X[:, :]' + assert str(X[x:x + 1, y:y + 1]) == 'X[x:x + 1, y:y + 1]' + assert str(X[x:x + 1:2, y:y + 1:2]) == 'X[x:x + 1:2, y:y + 1:2]' + assert str(X[:x, y:]) == 'X[:x, y:]' + assert str(X[:x, y:]) == 'X[:x, y:]' + assert str(X[x:, :y]) == 'X[x:, :y]' + assert str(X[x:y, z:w]) == 'X[x:y, z:w]' + assert str(X[x:y:t, w:t:x]) == 'X[x:y:t, w:t:x]' + assert str(X[x::y, t::w]) == 'X[x::y, t::w]' + assert str(X[:x:y, :t:w]) == 'X[:x:y, :t:w]' + assert str(X[::x, ::y]) == 'X[::x, ::y]' + assert str(MatrixSlice(X, (0, None, None), (0, None, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (None, n, None), (None, n, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (0, n, None), (0, n, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (0, n, 2), (0, n, 2))) == 'X[::2, ::2]' + assert str(X[1:2:3, 4:5:6]) == 'X[1:2:3, 4:5:6]' + assert str(X[1:3:5, 4:6:8]) == 'X[1:3:5, 4:6:8]' + assert str(X[1:10:2]) == 'X[1:10:2, :]' + assert str(Y[:5, 1:9:2]) == 'Y[:5, 1:9:2]' + assert str(Y[:5, 1:10:2]) == 'Y[:5, 1::2]' + assert str(Y[5, :5:2]) == 'Y[5:6, :5:2]' + assert str(X[0:1, 0:1]) == 'X[:1, :1]' + assert str(X[0:1:2, 0:1:2]) == 'X[:1:2, :1:2]' + assert str((Y + Z)[2:, 2:]) == '(Y + Z)[2:, 2:]' + +def test_true_false(): + assert str(true) == repr(true) == sstr(true) == "True" + assert str(false) == repr(false) == sstr(false) == "False" + +def test_Equivalent(): + assert str(Equivalent(y, x)) == "Equivalent(x, y)" + +def test_Xor(): + assert str(Xor(y, x, evaluate=False)) == "x ^ y" + +def test_Complement(): + assert str(Complement(S.Reals, S.Naturals)) == 'Complement(Reals, Naturals)' + +def test_SymmetricDifference(): + assert str(SymmetricDifference(Interval(2, 3), Interval(3, 4),evaluate=False)) == \ + 'SymmetricDifference(Interval(2, 3), Interval(3, 4))' + + +def test_UnevaluatedExpr(): + a, b = symbols("a b") + expr1 = 2*UnevaluatedExpr(a+b) + assert str(expr1) == "2*(a + b)" + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(str(A[0, 0]) == "A[0, 0]") + assert(str(3 * A[0, 0]) == "3*A[0, 0]") + + F = C[0, 0].subs(C, A - B) + assert str(F) == "(A - B)[0, 0]" + + +def test_MatrixSymbol_printing(): + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + + assert str(A - A*B - B) == "A - A*B - B" + assert str(A*B - (A+B)) == "-A + A*B - B" + assert str(A**(-1)) == "A**(-1)" + assert str(A**3) == "A**3" + + +def test_MatrixExpressions(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + + assert str(X) == "X" + + # Apply function elementwise (`ElementwiseApplyFunc`): + + expr = (X.T*X).applyfunc(sin) + assert str(expr) == 'Lambda(_d, sin(_d)).(X.T*X)' + + lamda = Lambda(x, 1/x) + expr = (n*X).applyfunc(lamda) + assert str(expr) == 'Lambda(x, 1/x).(n*X)' + + +def test_Subs_printing(): + assert str(Subs(x, (x,), (1,))) == 'Subs(x, x, 1)' + assert str(Subs(x + y, (x, y), (1, 2))) == 'Subs(x + y, (x, y), (1, 2))' + + +def test_issue_15716(): + e = Integral(factorial(x), (x, -oo, oo)) + assert e.as_terms() == ([(e, ((1.0, 0.0), (1,), ()))], [e]) + + +def test_str_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert str(Identity(4)) == 'I' + assert str(ZeroMatrix(2, 2)) == '0' + assert str(OneMatrix(2, 2)) == '1' + + +def test_issue_14567(): + assert factorial(Sum(-1, (x, 0, 0))) + y # doesn't raise an error + + +def test_issue_21823(): + assert str(Partition([1, 2])) == 'Partition({1, 2})' + assert str(Partition({1, 2})) == 'Partition({1, 2})' + + +def test_issue_22689(): + assert str(Mul(Pow(x,-2, evaluate=False), Pow(3,-1,evaluate=False), evaluate=False)) == "1/(x**2*3)" + + +def test_issue_21119_21460(): + ss = lambda x: str(S(x, evaluate=False)) + assert ss('4/2') == '4/2' + assert ss('4/-2') == '4/(-2)' + assert ss('-4/2') == '-4/2' + assert ss('-4/-2') == '-4/(-2)' + assert ss('-2*3/-1') == '-2*3/(-1)' + assert ss('-2*3/-1/2') == '-2*3/(-1*2)' + assert ss('4/2/1') == '4/(2*1)' + assert ss('-2/-1/2') == '-2/(-1*2)' + assert ss('2*3*4**(-2*3)') == '2*3/4**(2*3)' + assert ss('2*3*1*4**(-2*3)') == '2*3*1/4**(2*3)' + + +def test_Str(): + from sympy.core.symbol import Str + assert str(Str('x')) == 'x' + assert sstrrepr(Str('x')) == "Str('x')" + + +def test_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert str(m) == "M" + p = Patch('P', m) + assert str(p) == "P" + rect = CoordSystem('rect', p, [x, y]) + assert str(rect) == "rect" + b = BaseScalarField(rect, 0) + assert str(b) == "x" + +def test_NDimArray(): + assert sstr(NDimArray(1.0), full_prec=True) == '1.00000000000000' + assert sstr(NDimArray(1.0), full_prec=False) == '1.0' + assert sstr(NDimArray([1.0, 2.0]), full_prec=True) == '[1.00000000000000, 2.00000000000000]' + assert sstr(NDimArray([1.0, 2.0]), full_prec=False) == '[1.0, 2.0]' + assert sstr(NDimArray([], (0,))) == 'ImmutableDenseNDimArray([], (0,))' + assert sstr(NDimArray([], (0, 0))) == 'ImmutableDenseNDimArray([], (0, 0))' + assert sstr(NDimArray([], (0, 1))) == 'ImmutableDenseNDimArray([], (0, 1))' + assert sstr(NDimArray([], (1, 0))) == 'ImmutableDenseNDimArray([], (1, 0))' + +def test_Predicate(): + assert sstr(Q.even) == 'Q.even' + +def test_AppliedPredicate(): + assert sstr(Q.even(x)) == 'Q.even(x)' + +def test_printing_str_array_expressions(): + assert sstr(ArraySymbol("A", (2, 3, 4))) == "A" + assert sstr(ArrayElement("A", (2, 1/(1-x), 0))) == "A[2, 1/(1 - x), 0]" + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + assert sstr(ArrayElement(M*N, [x, 0])) == "(M*N)[x, 0]" + +def test_printing_stats(): + # issue 24132 + x = RandomSymbol("x") + y = RandomSymbol("y") + z1 = Probability(x > 0)*Identity(2) + z2 = Expectation(x)*Identity(2) + z3 = Variance(x)*Identity(2) + z4 = Covariance(x, y) * Identity(2) + + assert str(z1) == "Probability(x > 0)*I" + assert str(z2) == "Expectation(x)*I" + assert str(z3) == "Variance(x)*I" + assert str(z4) == "Covariance(x, y)*I" + assert z1.is_commutative == False + assert z2.is_commutative == False + assert z3.is_commutative == False + assert z4.is_commutative == False + assert z2._eval_is_commutative() == False + assert z3._eval_is_commutative() == False + assert z4._eval_is_commutative() == False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tableform.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tableform.py new file mode 100644 index 0000000000000000000000000000000000000000..05802dd104a12f2f53d137167ecf31d201ff8dfc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tableform.py @@ -0,0 +1,182 @@ +from sympy.core.singleton import S +from sympy.printing.tableform import TableForm +from sympy.printing.latex import latex +from sympy.abc import x +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.testing.pytest import raises + +from textwrap import dedent + + +def test_TableForm(): + s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]], + headings="automatic")) + assert s == ( + ' | 1 2\n' + '-------\n' + '1 | a b\n' + '2 | c d\n' + '3 | e ' + ) + s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]], + headings="automatic", wipe_zeros=False)) + assert s == dedent('''\ + | 1 2 + ------- + 1 | a b + 2 | c d + 3 | e 0''') + s = str(TableForm([[x**2, "b"], ["c", x**2], ["e", "f"]], + headings=("automatic", None))) + assert s == ( + '1 | x**2 b \n' + '2 | c x**2\n' + '3 | e f ' + ) + s = str(TableForm([["a", "b"], ["c", "d"], ["e", "f"]], + headings=(None, "automatic"))) + assert s == dedent('''\ + 1 2 + --- + a b + c d + e f''') + s = str(TableForm([[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]])) + assert s == ( + ' | y1 y2\n' + '---------------\n' + 'Group A | 5 7 \n' + 'Group B | 4 2 \n' + 'Group C | 10 3 ' + ) + raises( + ValueError, + lambda: + TableForm( + [[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]], + alignments="middle") + ) + s = str(TableForm([[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]], + alignments="right")) + assert s == dedent('''\ + | y1 y2 + --------------- + Group A | 5 7 + Group B | 4 2 + Group C | 10 3''') + + # other alignment permutations + d = [[1, 100], [100, 1]] + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='l') + assert str(s) == ( + 'xxx | 1 100\n' + ' x | 100 1 ' + ) + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='lr') + assert str(s) == dedent('''\ + xxx | 1 100 + x | 100 1''') + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='clr') + assert str(s) == dedent('''\ + xxx | 1 100 + x | 100 1''') + + s = TableForm(d, headings=(('xxx', 'x'), None)) + assert str(s) == ( + 'xxx | 1 100\n' + ' x | 100 1 ' + ) + + raises(ValueError, lambda: TableForm(d, alignments='clr')) + + #pad + s = str(TableForm([[None, "-", 2], [1]], pad='?')) + assert s == dedent('''\ + ? - 2 + 1 ? ?''') + + +def test_TableForm_latex(): + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"), alignments='l')) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'*3)) + assert s == ( + '\\begin{tabular}{l l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & $a$ & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + formats=['(%s)', None], headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & (a) & $x^{3}$ \\\\\n' + '2 & (c) & $\\frac{1}{4}$ \\\\\n' + '3 & (sqrt(x)) & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + + def neg_in_paren(x, i, j): + if i % 2: + return ('(%s)' if x < 0 else '%s') % x + else: + pass # use default print + s = latex(TableForm([[-1, 2], [-3, 4]], + formats=[neg_in_paren]*2, headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & -1 & 2 \\\\\n' + '2 & (-3) & 4 \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]])) + assert s == ( + '\\begin{tabular}{l l}\n' + '$a$ & $x^{3}$ \\\\\n' + '$c$ & $\\frac{1}{4}$ \\\\\n' + '$\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tensorflow.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c92cd17b13e1148ebf83f13f66854b983491fe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tensorflow.py @@ -0,0 +1,493 @@ +import random +from sympy.core.function import Derivative +from sympy.core.symbol import symbols +from sympy import Piecewise +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt +from sympy.external import import_module +from sympy.functions import \ + Abs, ceiling, exp, floor, sign, sin, asin, sqrt, cos, \ + acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \ + re, im, arg, erf, loggamma, log +from sympy.codegen.cfunctions import isnan, isinf +from sympy.matrices import Matrix, MatrixBase, eye, randMatrix +from sympy.matrices.expressions import \ + Determinant, HadamardProduct, Inverse, MatrixSymbol, Trace +from sympy.printing.tensorflow import tensorflow_code +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import skip +from sympy.testing.pytest import XFAIL + + +tf = tensorflow = import_module("tensorflow") + +if tensorflow: + # Hide Tensorflow warnings + import os + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + + +M = MatrixSymbol("M", 3, 3) +N = MatrixSymbol("N", 3, 3) +P = MatrixSymbol("P", 3, 3) +Q = MatrixSymbol("Q", 3, 3) + +x, y, z, t = symbols("x y z t") + +if tf is not None: + llo = [list(range(i, i+3)) for i in range(0, 9, 3)] + m3x3 = tf.constant(llo) + m3x3sympy = Matrix(llo) + + +def _compare_tensorflow_matrix(variables, expr, use_float=False): + f = lambdify(variables, expr, 'tensorflow') + if not use_float: + random_matrices = [randMatrix(v.rows, v.cols) for v in variables] + else: + random_matrices = [randMatrix(v.rows, v.cols)/100. for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + if e.is_Matrix: + if not isinstance(e, MatrixBase): + e = e.as_explicit() + e = e.tolist() + + if not use_float: + assert (r == e).all() + else: + r = [i for row in r for i in row] + e = [i for row in e for i in row] + assert all( + abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e)) + + +# Creating a custom inverse test. +# See https://github.com/sympy/sympy/issues/18469 +def _compare_tensorflow_matrix_inverse(variables, expr, use_float=False): + f = lambdify(variables, expr, 'tensorflow') + if not use_float: + random_matrices = [eye(v.rows, v.cols)*4 for v in variables] + else: + random_matrices = [eye(v.rows, v.cols)*3.14 for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + if e.is_Matrix: + if not isinstance(e, MatrixBase): + e = e.as_explicit() + e = e.tolist() + + if not use_float: + assert (r == e).all() + else: + r = [i for row in r for i in row] + e = [i for row in e for i in row] + assert all( + abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e)) + + +def _compare_tensorflow_matrix_scalar(variables, expr): + f = lambdify(variables, expr, 'tensorflow') + random_matrices = [ + randMatrix(v.rows, v.cols).evalf() / 100 for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + assert abs(r-e) < 10**-6 + + +def _compare_tensorflow_scalar( + variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'tensorflow') + rvs = [rng() for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + tf_rvs = [eval(tensorflow_code(i)) for i in rvs] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*tf_rvs)) + + e = expr.subs(dict(zip(variables, rvs))).evalf().doit() + assert abs(r-e) < 10**-6 + + +def _compare_tensorflow_relational( + variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'tensorflow') + rvs = [rng() for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + tf_rvs = [eval(tensorflow_code(i)) for i in rvs] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*tf_rvs)) + + e = expr.subs(dict(zip(variables, rvs))).doit() + assert r == e + + +def test_tensorflow_printing(): + assert tensorflow_code(eye(3)) == \ + "tensorflow.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])" + + expr = Matrix([[x, sin(y)], [exp(z), -t]]) + assert tensorflow_code(expr) == \ + "tensorflow.Variable(" \ + "[[x, tensorflow.math.sin(y)]," \ + " [tensorflow.math.exp(z), -t]])" + + +# This (random) test is XFAIL because it fails occasionally +# See https://github.com/sympy/sympy/issues/18469 +@XFAIL +def test_tensorflow_math(): + if not tf: + skip("TensorFlow not installed") + + expr = Abs(x) + assert tensorflow_code(expr) == "tensorflow.math.abs(x)" + _compare_tensorflow_scalar((x,), expr) + + expr = sign(x) + assert tensorflow_code(expr) == "tensorflow.math.sign(x)" + _compare_tensorflow_scalar((x,), expr) + + expr = ceiling(x) + assert tensorflow_code(expr) == "tensorflow.math.ceil(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = floor(x) + assert tensorflow_code(expr) == "tensorflow.math.floor(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = exp(x) + assert tensorflow_code(expr) == "tensorflow.math.exp(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = sqrt(x) + assert tensorflow_code(expr) == "tensorflow.math.sqrt(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = x ** 4 + assert tensorflow_code(expr) == "tensorflow.math.pow(x, 4)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = cos(x) + assert tensorflow_code(expr) == "tensorflow.math.cos(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = acos(x) + assert tensorflow_code(expr) == "tensorflow.math.acos(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(0, 0.95)) + + expr = sin(x) + assert tensorflow_code(expr) == "tensorflow.math.sin(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = asin(x) + assert tensorflow_code(expr) == "tensorflow.math.asin(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = tan(x) + assert tensorflow_code(expr) == "tensorflow.math.tan(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = atan(x) + assert tensorflow_code(expr) == "tensorflow.math.atan(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = atan2(y, x) + assert tensorflow_code(expr) == "tensorflow.math.atan2(y, x)" + _compare_tensorflow_scalar((y, x), expr, rng=lambda: random.random()) + + expr = cosh(x) + assert tensorflow_code(expr) == "tensorflow.math.cosh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = acosh(x) + assert tensorflow_code(expr) == "tensorflow.math.acosh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = sinh(x) + assert tensorflow_code(expr) == "tensorflow.math.sinh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = asinh(x) + assert tensorflow_code(expr) == "tensorflow.math.asinh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = tanh(x) + assert tensorflow_code(expr) == "tensorflow.math.tanh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = atanh(x) + assert tensorflow_code(expr) == "tensorflow.math.atanh(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.uniform(-.5, .5)) + + expr = erf(x) + assert tensorflow_code(expr) == "tensorflow.math.erf(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.random()) + + expr = loggamma(x) + assert tensorflow_code(expr) == "tensorflow.math.lgamma(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.random()) + + +def test_tensorflow_complexes(): + assert tensorflow_code(re(x)) == "tensorflow.math.real(x)" + assert tensorflow_code(im(x)) == "tensorflow.math.imag(x)" + assert tensorflow_code(arg(x)) == "tensorflow.math.angle(x)" + + +def test_tensorflow_relational(): + if not tf: + skip("TensorFlow not installed") + + expr = Eq(x, y) + assert tensorflow_code(expr) == "tensorflow.math.equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Ne(x, y) + assert tensorflow_code(expr) == "tensorflow.math.not_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Ge(x, y) + assert tensorflow_code(expr) == "tensorflow.math.greater_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Gt(x, y) + assert tensorflow_code(expr) == "tensorflow.math.greater(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Le(x, y) + assert tensorflow_code(expr) == "tensorflow.math.less_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Lt(x, y) + assert tensorflow_code(expr) == "tensorflow.math.less(x, y)" + _compare_tensorflow_relational((x, y), expr) + + +# This (random) test is XFAIL because it fails occasionally +# See https://github.com/sympy/sympy/issues/18469 +@XFAIL +def test_tensorflow_matrices(): + if not tf: + skip("TensorFlow not installed") + + expr = M + assert tensorflow_code(expr) == "M" + _compare_tensorflow_matrix((M,), expr) + + expr = M + N + assert tensorflow_code(expr) == "tensorflow.math.add(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = M * N + assert tensorflow_code(expr) == "tensorflow.linalg.matmul(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = HadamardProduct(M, N) + assert tensorflow_code(expr) == "tensorflow.math.multiply(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = M*N*P*Q + assert tensorflow_code(expr) == \ + "tensorflow.linalg.matmul(" \ + "tensorflow.linalg.matmul(" \ + "tensorflow.linalg.matmul(M, N), P), Q)" + _compare_tensorflow_matrix((M, N, P, Q), expr) + + expr = M**3 + assert tensorflow_code(expr) == \ + "tensorflow.linalg.matmul(tensorflow.linalg.matmul(M, M), M)" + _compare_tensorflow_matrix((M,), expr) + + expr = Trace(M) + assert tensorflow_code(expr) == "tensorflow.linalg.trace(M)" + _compare_tensorflow_matrix((M,), expr) + + expr = Determinant(M) + assert tensorflow_code(expr) == "tensorflow.linalg.det(M)" + _compare_tensorflow_matrix_scalar((M,), expr) + + expr = Inverse(M) + assert tensorflow_code(expr) == "tensorflow.linalg.inv(M)" + _compare_tensorflow_matrix_inverse((M,), expr, use_float=True) + + expr = M.T + assert tensorflow_code(expr, tensorflow_version='1.14') == \ + "tensorflow.linalg.matrix_transpose(M)" + assert tensorflow_code(expr, tensorflow_version='1.13') == \ + "tensorflow.matrix_transpose(M)" + + _compare_tensorflow_matrix((M,), expr) + + +def test_codegen_einsum(): + if not tf: + skip("TensorFlow not installed") + + graph = tf.Graph() + with graph.as_default(): + session = tf.compat.v1.Session(graph=graph) + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'tensorflow') + + ma = tf.constant([[1, 2], [3, 4]]) + mb = tf.constant([[1,-2], [-1, 3]]) + y = session.run(f(ma, mb)) + c = session.run(tf.matmul(ma, mb)) + assert (y == c).all() + + +def test_codegen_extra(): + if not tf: + skip("TensorFlow not installed") + + graph = tf.Graph() + with graph.as_default(): + session = tf.compat.v1.Session() + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = tf.constant([[1, 2], [3, 4]]) + mb = tf.constant([[1,-2], [-1, 3]]) + mc = tf.constant([[2, 0], [1, 2]]) + md = tf.constant([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + assert tensorflow_code(cg) == \ + 'tensorflow.linalg.einsum("ab,cd", M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.einsum("ij,kl", ma, mb)) + assert (y == c).all() + + cg = ArrayAdd(M, N) + assert tensorflow_code(cg) == 'tensorflow.math.add(M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(ma + mb) + assert (y == c).all() + + cg = ArrayAdd(M, N, P) + assert tensorflow_code(cg) == \ + 'tensorflow.math.add(tensorflow.math.add(M, N), P)' + f = lambdify((M, N, P), cg, 'tensorflow') + y = session.run(f(ma, mb, mc)) + c = session.run(ma + mb + mc) + assert (y == c).all() + + cg = ArrayAdd(M, N, P, Q) + assert tensorflow_code(cg) == \ + 'tensorflow.math.add(' \ + 'tensorflow.math.add(tensorflow.math.add(M, N), P), Q)' + f = lambdify((M, N, P, Q), cg, 'tensorflow') + y = session.run(f(ma, mb, mc, md)) + c = session.run(ma + mb + mc + md) + assert (y == c).all() + + cg = PermuteDims(M, [1, 0]) + assert tensorflow_code(cg) == 'tensorflow.transpose(M, [1, 0])' + f = lambdify((M,), cg, 'tensorflow') + y = session.run(f(ma)) + c = session.run(tf.transpose(ma)) + assert (y == c).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + assert tensorflow_code(cg) == \ + 'tensorflow.transpose(' \ + 'tensorflow.linalg.einsum("ab,cd", M, N), [1, 2, 3, 0])' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.transpose(tf.einsum("ab,cd", ma, mb), [1, 2, 3, 0])) + assert (y == c).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + assert tensorflow_code(cg) == \ + 'tensorflow.linalg.einsum("ab,bc->acb", M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.einsum("ab,bc->acb", ma, mb)) + assert (y == c).all() + + +def test_MatrixElement_printing(): + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert tensorflow_code(A[0, 0]) == "A[0, 0]" + assert tensorflow_code(3 * A[0, 0]) == "3*A[0, 0]" + + F = C[0, 0].subs(C, A - B) + assert tensorflow_code(F) == "(tensorflow.math.add((-1)*B, A))[0, 0]" + + +def test_tensorflow_Derivative(): + expr = Derivative(sin(x), x) + assert tensorflow_code(expr) == \ + "tensorflow.gradients(tensorflow.math.sin(x), x)[0]" + +def test_tensorflow_isnan_isinf(): + if not tf: + skip("TensorFlow not installed") + + # Test for isnan + x = symbols("x") + # Return 0 if x is of nan value, and 1 otherwise + expression = Piecewise((0.0, isnan(x)), (1.0, True)) + printed_code = tensorflow_code(expression) + expected_printed_code = "tensorflow.where(tensorflow.math.is_nan(x), 0.0, 1.0)" + assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}" + for _input, _expected in [(float('nan'), 0.0), (float('inf'), 1.0), (float('-inf'), 1.0), (1.0, 1.0)]: + _output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input])) + assert (_output == _expected).numpy().all() + + # Test for isinf + x = symbols("x") + # Return 0 if x is of nan value, and 1 otherwise + expression = Piecewise((0.0, isinf(x)), (1.0, True)) + printed_code = tensorflow_code(expression) + expected_printed_code = "tensorflow.where(tensorflow.math.is_inf(x), 0.0, 1.0)" + assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}" + for _input, _expected in [(float('inf'), 0.0), (float('-inf'), 0.0), (float('nan'), 1.0), (1.0, 1.0)]: + _output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input])) + assert (_output == _expected).numpy().all() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_theanocode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_theanocode.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff40f78cb4de16149cb5e780756b7e32b574b71 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_theanocode.py @@ -0,0 +1,639 @@ +""" +Important note on tests in this module - the Theano printing functions use a +global cache by default, which means that tests using it will modify global +state and thus not be independent from each other. Instead of using the "cache" +keyword argument each time, this module uses the theano_code_ and +theano_function_ functions defined below which default to using a new, empty +cache instead. +""" + +import logging + +from sympy.external import import_module +from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy + +theanologger = logging.getLogger('theano.configdefaults') +theanologger.setLevel(logging.CRITICAL) +theano = import_module('theano') +theanologger.setLevel(logging.WARNING) + + +if theano: + import numpy as np + ts = theano.scalar + tt = theano.tensor + xt, yt, zt = [tt.scalar(name, 'floatX') for name in 'xyz'] + Xt, Yt, Zt = [tt.tensor('floatX', (False, False), name=n) for n in 'XYZ'] +else: + #bin/test will not execute any tests now + disabled = True + +import sympy as sy +from sympy.core.singleton import S +from sympy.abc import x, y, z, t +from sympy.printing.theanocode import (theano_code, dim_handling, + theano_function) + + +# Default set of matrix symbols for testing - make square so we can both +# multiply and perform elementwise operations between them. +X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ'] + +# For testing AppliedUndef +f_t = sy.Function('f')(t) + + +def theano_code_(expr, **kwargs): + """ Wrapper for theano_code that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return theano_code(expr, **kwargs) + +def theano_function_(inputs, outputs, **kwargs): + """ Wrapper for theano_function that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return theano_function(inputs, outputs, **kwargs) + + +def fgraph_of(*exprs): + """ Transform SymPy expressions into Theano Computation. + + Parameters + ========== + exprs + SymPy expressions + + Returns + ======= + theano.gof.FunctionGraph + """ + outs = list(map(theano_code_, exprs)) + ins = theano.gof.graph.inputs(outs) + ins, outs = theano.gof.graph.clone(ins, outs) + return theano.gof.FunctionGraph(ins, outs) + + +def theano_simplify(fgraph): + """ Simplify a Theano Computation. + + Parameters + ========== + fgraph : theano.gof.FunctionGraph + + Returns + ======= + theano.gof.FunctionGraph + """ + mode = theano.compile.get_default_mode().excluding("fusion") + fgraph = fgraph.clone() + mode.optimizer.optimize(fgraph) + return fgraph + + +def theq(a, b): + """ Test two Theano objects for equality. + + Also accepts numeric types and lists/tuples of supported types. + + Note - debugprint() has a bug where it will accept numeric types but does + not respect the "file" argument and in this case and instead prints the number + to stdout and returns an empty string. This can lead to tests passing where + they should fail because any two numbers will always compare as equal. To + prevent this we treat numbers as a separate case. + """ + numeric_types = (int, float, np.number) + a_is_num = isinstance(a, numeric_types) + b_is_num = isinstance(b, numeric_types) + + # Compare numeric types using regular equality + if a_is_num or b_is_num: + if not (a_is_num and b_is_num): + return False + + return a == b + + # Compare sequences element-wise + a_is_seq = isinstance(a, (tuple, list)) + b_is_seq = isinstance(b, (tuple, list)) + + if a_is_seq or b_is_seq: + if not (a_is_seq and b_is_seq) or type(a) != type(b): + return False + + return list(map(theq, a)) == list(map(theq, b)) + + # Otherwise, assume debugprint() can handle it + astr = theano.printing.debugprint(a, file='str') + bstr = theano.printing.debugprint(b, file='str') + + # Check for bug mentioned above + for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]: + if argstr == '': + raise TypeError( + 'theano.printing.debugprint(%s) returned empty string ' + '(%s is instance of %r)' + % (argname, argname, type(argval)) + ) + + return astr == bstr + + +def test_example_symbols(): + """ + Check that the example symbols in this module print to their Theano + equivalents, as many of the other tests depend on this. + """ + assert theq(xt, theano_code_(x)) + assert theq(yt, theano_code_(y)) + assert theq(zt, theano_code_(z)) + assert theq(Xt, theano_code_(X)) + assert theq(Yt, theano_code_(Y)) + assert theq(Zt, theano_code_(Z)) + + +def test_Symbol(): + """ Test printing a Symbol to a theano variable. """ + xx = theano_code_(x) + assert isinstance(xx, (tt.TensorVariable, ts.ScalarVariable)) + assert xx.broadcastable == () + assert xx.name == x.name + + xx2 = theano_code_(x, broadcastables={x: (False,)}) + assert xx2.broadcastable == (False,) + assert xx2.name == x.name + +def test_MatrixSymbol(): + """ Test printing a MatrixSymbol to a theano variable. """ + XX = theano_code_(X) + assert isinstance(XX, tt.TensorVariable) + assert XX.broadcastable == (False, False) + +@SKIP # TODO - this is currently not checked but should be implemented +def test_MatrixSymbol_wrong_dims(): + """ Test MatrixSymbol with invalid broadcastable. """ + bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)] + for bc in bcs: + with raises(ValueError): + theano_code_(X, broadcastables={X: bc}) + +def test_AppliedUndef(): + """ Test printing AppliedUndef instance, which works similarly to Symbol. """ + ftt = theano_code_(f_t) + assert isinstance(ftt, tt.TensorVariable) + assert ftt.broadcastable == () + assert ftt.name == 'f_t' + + +def test_add(): + expr = x + y + comp = theano_code_(expr) + assert comp.owner.op == theano.tensor.add + +def test_trig(): + assert theq(theano_code_(sy.sin(x)), tt.sin(xt)) + assert theq(theano_code_(sy.tan(x)), tt.tan(xt)) + +def test_many(): + """ Test printing a complex expression with multiple symbols. """ + expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z) + comp = theano_code_(expr) + expected = tt.exp(xt**2 + tt.cos(yt)) * tt.log(2*zt) + assert theq(comp, expected) + + +def test_dtype(): + """ Test specifying specific data types through the dtype argument. """ + for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']: + assert theano_code_(x, dtypes={x: dtype}).type.dtype == dtype + + # "floatX" type + assert theano_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64') + + # Type promotion + assert theano_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32' + assert theano_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64' + + +def test_broadcastables(): + """ Test the "broadcastables" argument when printing symbol-like objects. """ + + # No restrictions on shape + for s in [x, f_t]: + for bc in [(), (False,), (True,), (False, False), (True, False)]: + assert theano_code_(s, broadcastables={s: bc}).broadcastable == bc + + # TODO - matrix broadcasting? + +def test_broadcasting(): + """ Test "broadcastable" attribute after applying element-wise binary op. """ + + expr = x + y + + cases = [ + [(), (), ()], + [(False,), (False,), (False,)], + [(True,), (False,), (False,)], + [(False, True), (False, False), (False, False)], + [(True, False), (False, False), (False, False)], + ] + + for bc1, bc2, bc3 in cases: + comp = theano_code_(expr, broadcastables={x: bc1, y: bc2}) + assert comp.broadcastable == bc3 + + +def test_MatMul(): + expr = X*Y*Z + expr_t = theano_code_(expr) + assert isinstance(expr_t.owner.op, tt.Dot) + assert theq(expr_t, Xt.dot(Yt).dot(Zt)) + +def test_Transpose(): + assert isinstance(theano_code_(X.T).owner.op, tt.DimShuffle) + +def test_MatAdd(): + expr = X+Y+Z + assert isinstance(theano_code_(expr).owner.op, tt.Elemwise) + + +def test_Rationals(): + assert theq(theano_code_(sy.Integer(2) / 3), tt.true_div(2, 3)) + assert theq(theano_code_(S.Half), tt.true_div(1, 2)) + +def test_Integers(): + assert theano_code_(sy.Integer(3)) == 3 + +def test_factorial(): + n = sy.Symbol('n') + assert theano_code_(sy.factorial(n)) + +def test_Derivative(): + simp = lambda expr: theano_simplify(fgraph_of(expr)) + assert theq(simp(theano_code_(sy.Derivative(sy.sin(x), x, evaluate=False))), + simp(theano.grad(tt.sin(xt), xt))) + + +def test_theano_function_simple(): + """ Test theano_function() with single output. """ + f = theano_function_([x, y], [x+y]) + assert f(2, 3) == 5 + +def test_theano_function_multi(): + """ Test theano_function() with multiple outputs. """ + f = theano_function_([x, y], [x+y, x-y]) + o1, o2 = f(2, 3) + assert o1 == 5 + assert o2 == -1 + +def test_theano_function_numpy(): + """ Test theano_function() vs Numpy implementation. """ + f = theano_function_([x, y], [x+y], dim=1, + dtypes={x: 'float64', y: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9 + + f = theano_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'}, + dim=1) + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9 + + +def test_theano_function_matrix(): + m = sy.Matrix([[x, y], [z, x + y + z]]) + expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]]) + f = theano_function_([x, y, z], [m]) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = theano_function_([x, y, z], [m], scalar=True) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = theano_function_([x, y, z], [m, m]) + assert isinstance(f(1.0, 2.0, 3.0), type([])) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected) + +def test_dim_handling(): + assert dim_handling([x], dim=2) == {x: (False, False)} + assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True), + y: (False, False)} + assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)} + +def test_theano_function_kwargs(): + """ + Test passing additional kwargs from theano_function() to theano.function(). + """ + import numpy as np + f = theano_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore', + dtypes={x: 'float64', y: 'float64', z: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9 + + f = theano_function_([x, y, z], [x+y], + dtypes={x: 'float64', y: 'float64', z: 'float64'}, + dim=1, on_unused_input='ignore') + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + zz = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9 + +def test_theano_function_scalar(): + """ Test the "scalar" argument to theano_function(). """ + + args = [ + ([x, y], [x + y], None, [0]), # Single 0d output + ([X, Y], [X + Y], None, [2]), # Single 2d output + ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output + ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs + ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d + ] + + # Create and test functions with and without the scalar setting + for inputs, outputs, in_dims, out_dims in args: + for scalar in [False, True]: + + f = theano_function_(inputs, outputs, dims=in_dims, scalar=scalar) + + # Check the theano_function attribute is set whether wrapped or not + assert isinstance(f.theano_function, theano.compile.function_module.Function) + + # Feed in inputs of the appropriate size and get outputs + in_values = [ + np.ones([1 if bc else 5 for bc in i.type.broadcastable]) + for i in f.theano_function.input_storage + ] + out_values = f(*in_values) + if not isinstance(out_values, list): + out_values = [out_values] + + # Check output types and shapes + assert len(out_dims) == len(out_values) + for d, value in zip(out_dims, out_values): + + if scalar and d == 0: + # Should have been converted to a scalar value + assert isinstance(value, np.number) + + else: + # Otherwise should be an array + assert isinstance(value, np.ndarray) + assert value.ndim == d + +def test_theano_function_bad_kwarg(): + """ + Passing an unknown keyword argument to theano_function() should raise an + exception. + """ + raises(Exception, lambda : theano_function_([x], [x+1], foobar=3)) + + +def test_slice(): + assert theano_code_(slice(1, 2, 3)) == slice(1, 2, 3) + + def theq_slice(s1, s2): + for attr in ['start', 'stop', 'step']: + a1 = getattr(s1, attr) + a2 = getattr(s2, attr) + if a1 is None or a2 is None: + if not (a1 is None or a2 is None): + return False + elif not theq(a1, a2): + return False + return True + + dtypes = {x: 'int32', y: 'int32'} + assert theq_slice(theano_code_(slice(x, y), dtypes=dtypes), slice(xt, yt)) + assert theq_slice(theano_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3)) + +def test_MatrixSlice(): + from theano import Constant + + cache = {} + + n = sy.Symbol('n', integer=True) + X = sy.MatrixSymbol('X', n, n) + + Y = X[1:2:3, 4:5:6] + Yt = theano_code_(Y, cache=cache) + + s = ts.Scalar('int64') + assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s)) + assert Yt.owner.inputs[0] == theano_code_(X, cache=cache) + # == doesn't work in theano like it does in SymPy. You have to use + # equals. + assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7)) + + k = sy.Symbol('k') + theano_code_(k, dtypes={k: 'int32'}) + start, stop, step = 4, k, 2 + Y = X[start:stop:step] + Yt = theano_code_(Y, dtypes={n: 'int32', k: 'int32'}) + # assert Yt.owner.op.idx_list[0].stop == kt + +def test_BlockMatrix(): + n = sy.Symbol('n', integer=True) + A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD'] + At, Bt, Ct, Dt = map(theano_code_, (A, B, C, D)) + Block = sy.BlockMatrix([[A, B], [C, D]]) + Blockt = theano_code_(Block) + solutions = [tt.join(0, tt.join(1, At, Bt), tt.join(1, Ct, Dt)), + tt.join(1, tt.join(0, At, Ct), tt.join(0, Bt, Dt))] + assert any(theq(Blockt, solution) for solution in solutions) + +@SKIP +def test_BlockMatrix_Inverse_execution(): + k, n = 2, 4 + dtype = 'float32' + A = sy.MatrixSymbol('A', n, k) + B = sy.MatrixSymbol('B', n, n) + inputs = A, B + output = B.I*A + + cutsizes = {A: [(n//2, n//2), (k//2, k//2)], + B: [(n//2, n//2), (n//2, n//2)]} + cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs] + cutoutput = output.subs(dict(zip(inputs, cutinputs))) + + dtypes = dict(zip(inputs, [dtype]*len(inputs))) + f = theano_function_(inputs, [output], dtypes=dtypes, cache={}) + fblocked = theano_function_(inputs, [sy.block_collapse(cutoutput)], + dtypes=dtypes, cache={}) + + ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs] + ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype), + np.eye(n).astype(dtype)] + ninputs[1] += np.ones(B.shape)*1e-5 + + assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5) + +def test_DenseMatrix(): + t = sy.Symbol('theta') + for MatrixType in [sy.Matrix, sy.ImmutableMatrix]: + X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]]) + tX = theano_code_(X) + assert isinstance(tX, tt.TensorVariable) + assert tX.owner.op == tt.join_ + + +def test_cache_basic(): + """ Test single symbol-like objects are cached when printed by themselves. """ + + # Pairs of objects which should be considered equivalent with respect to caching + pairs = [ + (x, sy.Symbol('x')), + (X, sy.MatrixSymbol('X', *X.shape)), + (f_t, sy.Function('f')(sy.Symbol('t'))), + ] + + for s1, s2 in pairs: + cache = {} + st = theano_code_(s1, cache=cache) + + # Test hit with same instance + assert theano_code_(s1, cache=cache) is st + + # Test miss with same instance but new cache + assert theano_code_(s1, cache={}) is not st + + # Test hit with different but equivalent instance + assert theano_code_(s2, cache=cache) is st + +def test_global_cache(): + """ Test use of the global cache. """ + from sympy.printing.theanocode import global_cache + + backup = dict(global_cache) + try: + # Temporarily empty global cache + global_cache.clear() + + for s in [x, X, f_t]: + with warns_deprecated_sympy(): + st = theano_code(s) + assert theano_code(s) is st + + finally: + # Restore global cache + global_cache.update(backup) + +def test_cache_types_distinct(): + """ + Test that symbol-like objects of different types (Symbol, MatrixSymbol, + AppliedUndef) are distinguished by the cache even if they have the same + name. + """ + symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t] + + cache = {} # Single shared cache + printed = {} + + for s in symbols: + st = theano_code_(s, cache=cache) + assert st not in printed.values() + printed[s] = st + + # Check all printed objects are distinct + assert len(set(map(id, printed.values()))) == len(symbols) + + # Check retrieving + for s, st in printed.items(): + with warns_deprecated_sympy(): + assert theano_code(s, cache=cache) is st + +def test_symbols_are_created_once(): + """ + Test that a symbol is cached and reused when it appears in an expression + more than once. + """ + expr = sy.Add(x, x, evaluate=False) + comp = theano_code_(expr) + + assert theq(comp, xt + xt) + assert not theq(comp, xt + theano_code_(x)) + +def test_cache_complex(): + """ + Test caching on a complicated expression with multiple symbols appearing + multiple times. + """ + expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y) + symbol_names = {s.name for s in expr.free_symbols} + expr_t = theano_code_(expr) + + # Iterate through variables in the Theano computational graph that the + # printed expression depends on + seen = set() + for v in theano.gof.graph.ancestors([expr_t]): + # Owner-less, non-constant variables should be our symbols + if v.owner is None and not isinstance(v, theano.gof.graph.Constant): + # Check it corresponds to a symbol and appears only once + assert v.name in symbol_names + assert v.name not in seen + seen.add(v.name) + + # Check all were present + assert seen == symbol_names + + +def test_Piecewise(): + # A piecewise linear + expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III + result = theano_code_(expr) + assert result.owner.op == tt.switch + + expected = tt.switch(xt<0, 0, tt.switch(xt<2, xt, 1)) + assert theq(result, expected) + + expr = sy.Piecewise((x, x < 0)) + result = theano_code_(expr) + expected = tt.switch(xt < 0, xt, np.nan) + assert theq(result, expected) + + expr = sy.Piecewise((0, sy.And(x>0, x<2)), \ + (x, sy.Or(x>2, x<0))) + result = theano_code_(expr) + expected = tt.switch(tt.and_(xt>0,xt<2), 0, \ + tt.switch(tt.or_(xt>2, xt<0), xt, np.nan)) + assert theq(result, expected) + + +def test_Relationals(): + assert theq(theano_code_(sy.Eq(x, y)), tt.eq(xt, yt)) + # assert theq(theano_code_(sy.Ne(x, y)), tt.neq(xt, yt)) # TODO - implement + assert theq(theano_code_(x > y), xt > yt) + assert theq(theano_code_(x < y), xt < yt) + assert theq(theano_code_(x >= y), xt >= yt) + assert theq(theano_code_(x <= y), xt <= yt) + + +def test_complexfunctions(): + with warns_deprecated_sympy(): + xt, yt = theano_code_(x, dtypes={x:'complex128'}), theano_code_(y, dtypes={y: 'complex128'}) + from sympy.functions.elementary.complexes import conjugate + from theano.tensor import as_tensor_variable as atv + from theano.tensor import complex as cplx + with warns_deprecated_sympy(): + assert theq(theano_code_(y*conjugate(x)), yt*(xt.conj())) + assert theq(theano_code_((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1))) + + +def test_constantfunctions(): + with warns_deprecated_sympy(): + tf = theano_function_([],[1+1j]) + assert(tf()==1+1j) + + +def test_Exp1(): + """ + Test that exp(1) prints without error and evaluates close to SymPy's E + """ + # sy.exp(1) should yield same instance of E as sy.E (singleton), but extra + # check added for sanity + e_a = sy.exp(1) + e_b = sy.E + + np.testing.assert_allclose(float(e_a), np.e) + np.testing.assert_allclose(float(e_b), np.e) + + e = theano_code_(e_a) + np.testing.assert_allclose(float(e_a), e.eval()) + + e = theano_code_(e_b) + np.testing.assert_allclose(float(e_b), e.eval()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_torch.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce2c6cec75e03264f93b472a79eb073742e3486 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_torch.py @@ -0,0 +1,531 @@ +import random +import math + +from sympy import symbols, Derivative +from sympy.printing.pytorch import torch_code +from sympy import (eye, MatrixSymbol, Matrix) +from sympy.tensor.array import NDimArray +from sympy.tensor.array.expressions.array_expressions import ( + ArrayTensorProduct, ArrayAdd, + PermuteDims, ArrayDiagonal, _CodegenArrayAbstract) +from sympy.utilities.lambdify import lambdify +from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt +from sympy.functions import \ + Abs, ceiling, exp, floor, sign, sin, asin, cos, \ + acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \ + re, im, arg, erf, loggamma, sqrt +from sympy.testing.pytest import skip +from sympy.external import import_module +from sympy.matrices.expressions import \ + Determinant, HadamardProduct, Inverse, Trace +from sympy.matrices import randMatrix +from sympy.matrices import Identity, ZeroMatrix, OneMatrix +from sympy import conjugate, I +from sympy import Heaviside, gamma, polygamma + + + +torch = import_module("torch") + +M = MatrixSymbol("M", 3, 3) +N = MatrixSymbol("N", 3, 3) +P = MatrixSymbol("P", 3, 3) +Q = MatrixSymbol("Q", 3, 3) + +x, y, z, t = symbols("x y z t") + +if torch is not None: + llo = [list(range(i, i + 3)) for i in range(0, 9, 3)] + m3x3 = torch.tensor(llo, dtype=torch.float64) + m3x3sympy = Matrix(llo) + + +def _compare_torch_matrix(variables, expr): + f = lambdify(variables, expr, 'torch') + + random_matrices = [randMatrix(i.shape[0], i.shape[1]) for i in variables] + random_variables = [torch.tensor(i.tolist(), dtype=torch.float64) for i in random_matrices] + r = f(*random_variables) + e = expr.subs(dict(zip(variables, random_matrices))).doit() + + if isinstance(e, _CodegenArrayAbstract): + e = e.doit() + + if hasattr(e, 'is_number') and e.is_number: + if isinstance(r, torch.Tensor) and r.dim() == 0: + r = r.item() + e = float(e) + assert abs(r - e) < 1e-6 + return + + if e.is_Matrix or isinstance(e, NDimArray): + e = torch.tensor(e.tolist(), dtype=torch.float64) + assert torch.allclose(r, e, atol=1e-6) + else: + raise TypeError(f"Cannot compare {type(r)} with {type(e)}") + + +def _compare_torch_scalar(variables, expr, rng=lambda: random.uniform(-5, 5)): + f = lambdify(variables, expr, 'torch') + rvs = [rng() for v in variables] + t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs] + r = f(*t_rvs) + if isinstance(r, torch.Tensor): + r = r.item() + e = expr.subs(dict(zip(variables, rvs))).doit() + assert abs(r - e) < 1e-6 + + +def _compare_torch_relational(variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'torch') + rvs = [rng() for v in variables] + t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs] + r = f(*t_rvs) + e = bool(expr.subs(dict(zip(variables, rvs))).doit()) + assert r.item() == e + + +def test_torch_math(): + if not torch: + skip("PyTorch not installed") + + expr = Abs(x) + assert torch_code(expr) == "torch.abs(x)" + f = lambdify(x, expr, 'torch') + ma = torch.tensor([[-1, 2, -3, -4]], dtype=torch.float64) + y_abs = f(ma) + c = torch.abs(ma) + assert torch.all(y_abs == c) + + expr = sign(x) + assert torch_code(expr) == "torch.sign(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-10, 10)) + + expr = ceiling(x) + assert torch_code(expr) == "torch.ceil(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = floor(x) + assert torch_code(expr) == "torch.floor(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = exp(x) + assert torch_code(expr) == "torch.exp(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = sqrt(x) + assert torch_code(expr) == "torch.sqrt(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = x ** 4 + assert torch_code(expr) == "torch.pow(x, 4)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = cos(x) + assert torch_code(expr) == "torch.cos(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = acos(x) + assert torch_code(expr) == "torch.acos(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99)) + + expr = sin(x) + assert torch_code(expr) == "torch.sin(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = asin(x) + assert torch_code(expr) == "torch.asin(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99)) + + expr = tan(x) + assert torch_code(expr) == "torch.tan(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-1.5, 1.5)) + + expr = atan(x) + assert torch_code(expr) == "torch.atan(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5)) + + expr = atan2(y, x) + assert torch_code(expr) == "torch.atan2(y, x)" + _compare_torch_scalar((y, x), expr, rng=lambda: random.uniform(-5, 5)) + + expr = cosh(x) + assert torch_code(expr) == "torch.cosh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = acosh(x) + assert torch_code(expr) == "torch.acosh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(1.1, 5)) + + expr = sinh(x) + assert torch_code(expr) == "torch.sinh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = asinh(x) + assert torch_code(expr) == "torch.asinh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5)) + + expr = tanh(x) + assert torch_code(expr) == "torch.tanh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = atanh(x) + assert torch_code(expr) == "torch.atanh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.9, 0.9)) + + expr = erf(x) + assert torch_code(expr) == "torch.erf(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = loggamma(x) + assert torch_code(expr) == "torch.lgamma(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(0.5, 5)) + + +def test_torch_complexes(): + assert torch_code(re(x)) == "torch.real(x)" + assert torch_code(im(x)) == "torch.imag(x)" + assert torch_code(arg(x)) == "torch.angle(x)" + + +def test_torch_relational(): + if not torch: + skip("PyTorch not installed") + + expr = Eq(x, y) + assert torch_code(expr) == "torch.eq(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Ne(x, y) + assert torch_code(expr) == "torch.ne(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Ge(x, y) + assert torch_code(expr) == "torch.ge(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Gt(x, y) + assert torch_code(expr) == "torch.gt(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Le(x, y) + assert torch_code(expr) == "torch.le(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Lt(x, y) + assert torch_code(expr) == "torch.lt(x, y)" + _compare_torch_relational((x, y), expr) + + +def test_torch_matrix(): + if torch is None: + skip("PyTorch not installed") + + expr = M + assert torch_code(expr) == "M" + f = lambdify((M,), expr, "torch") + eye_mat = eye(3) + eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64) + assert torch.allclose(f(eye_tensor), eye_tensor) + + expr = M * N + assert torch_code(expr) == "torch.matmul(M, N)" + _compare_torch_matrix((M, N), expr) + + expr = M ** 3 + assert torch_code(expr) == "torch.mm(torch.mm(M, M), M)" + _compare_torch_matrix((M,), expr) + + expr = M * N * P * Q + assert torch_code(expr) == "torch.matmul(torch.matmul(torch.matmul(M, N), P), Q)" + _compare_torch_matrix((M, N, P, Q), expr) + + expr = Trace(M) + assert torch_code(expr) == "torch.trace(M)" + _compare_torch_matrix((M,), expr) + + expr = Determinant(M) + assert torch_code(expr) == "torch.det(M)" + _compare_torch_matrix((M,), expr) + + expr = HadamardProduct(M, N) + assert torch_code(expr) == "torch.mul(M, N)" + _compare_torch_matrix((M, N), expr) + + expr = Inverse(M) + assert torch_code(expr) == "torch.linalg.inv(M)" + + # For inverse, use a matrix that's guaranteed to be invertible + eye_mat = eye(3) + eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64) + f = lambdify((M,), expr, "torch") + result = f(eye_tensor) + expected = torch.linalg.inv(eye_tensor) + assert torch.allclose(result, expected) + + +def test_torch_array_operations(): + if not torch: + skip("PyTorch not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + + ma = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float64) + mb = torch.tensor([[1., -2.], [-1., 3.]], dtype=torch.float64) + mc = torch.tensor([[2., 0.], [1., 2.]], dtype=torch.float64) + md = torch.tensor([[1., -1.], [4., 7.]], dtype=torch.float64) + + cg = ArrayTensorProduct(M, N) + assert torch_code(cg) == 'torch.einsum("ab,cd", M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ij,kl", ma, mb) + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N) + assert torch_code(cg) == 'torch.add(M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = ma + mb + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N, P) + assert torch_code(cg) == 'torch.add(torch.add(M, N), P)' + f = lambdify((M, N, P), cg, 'torch') + y = f(ma, mb, mc) + c = ma + mb + mc + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N, P, Q) + assert torch_code(cg) == 'torch.add(torch.add(torch.add(M, N), P), Q)' + f = lambdify((M, N, P, Q), cg, 'torch') + y = f(ma, mb, mc, md) + c = ma + mb + mc + md + assert torch.allclose(y, c) + + cg = PermuteDims(M, [1, 0]) + assert torch_code(cg) == 'M.permute(1, 0)' + f = lambdify((M,), cg, 'torch') + y = f(ma) + c = ma.T + assert torch.allclose(y, c) + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + assert torch_code(cg) == 'torch.einsum("ab,cd", M, N).permute(1, 2, 3, 0)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ab,cd", ma, mb).permute(1, 2, 3, 0) + assert torch.allclose(y, c) + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + assert torch_code(cg) == 'torch.einsum("ab,bc->acb", M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ab,bc->acb", ma, mb) + assert torch.allclose(y, c) + + +def test_torch_derivative(): + """Test derivative handling.""" + expr = Derivative(sin(x), x) + assert torch_code(expr) == 'torch.autograd.grad(torch.sin(x), x)[0]' + + +def test_torch_printing_dtype(): + if not torch: + skip("PyTorch not installed") + + # matrix printing with default dtype + expr = Matrix([[x, sin(y)], [exp(z), -t]]) + assert "dtype=torch.float64" in torch_code(expr) + + # explicit dtype + assert "dtype=torch.float32" in torch_code(expr, dtype="torch.float32") + + # with requires_grad + result = torch_code(expr, requires_grad=True) + assert "requires_grad=True" in result + assert "dtype=torch.float64" in result + + # both + result = torch_code(expr, requires_grad=True, dtype="torch.float32") + assert "requires_grad=True" in result + assert "dtype=torch.float32" in result + + +def test_requires_grad(): + if not torch: + skip("PyTorch not installed") + + expr = sin(x) + cos(y) + f = lambdify([x, y], expr, 'torch') + + # make sure the gradients flow + x_val = torch.tensor(1.0, requires_grad=True) + y_val = torch.tensor(2.0, requires_grad=True) + result = f(x_val, y_val) + assert result.requires_grad + result.backward() + + # x_val.grad should be cos(x_val) which is close to cos(1.0) + assert abs(x_val.grad.item() - float(cos(1.0).evalf())) < 1e-6 + + # y_val.grad should be -sin(y_val) which is close to -sin(2.0) + assert abs(y_val.grad.item() - float(-sin(2.0).evalf())) < 1e-6 + + +def test_torch_multi_variable_derivatives(): + if not torch: + skip("PyTorch not installed") + + x, y, z = symbols("x y z") + + expr = Derivative(sin(x), x) + assert torch_code(expr) == "torch.autograd.grad(torch.sin(x), x)[0]" + + expr = Derivative(sin(x), (x, 2)) + assert torch_code( + expr) == "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]" + + expr = Derivative(sin(x * y), x, y) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x*y), x, create_graph=True)[0], y, create_graph=True)[0]" + normalized_result = result.replace(" ", "") + normalized_expected = expected.replace(" ", "") + assert normalized_result == normalized_expected + + expr = Derivative(sin(x), x, x) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]" + assert result == expected + + expr = Derivative(sin(x * y * z), x, (y, 2), z) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.sin(x*y*z), x, create_graph=True)[0], y, create_graph=True)[0], y, create_graph=True)[0], z, create_graph=True)[0]" + normalized_result = result.replace(" ", "") + normalized_expected = expected.replace(" ", "") + assert normalized_result == normalized_expected + + +def test_torch_derivative_lambdify(): + if not torch: + skip("PyTorch not installed") + + x = symbols("x") + y = symbols("y") + + expr = Derivative(x ** 2, x) + f = lambdify(x, expr, 'torch') + x_val = torch.tensor(2.0, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(4.0)) + + expr = Derivative(sin(x), (x, 2)) + f = lambdify(x, expr, 'torch') + # Second derivative of sin(x) at x=0 is 0, not -1 + x_val = torch.tensor(0.0, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(0.0), atol=1e-5) + + x_val = torch.tensor(math.pi / 2, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(-1.0), atol=1e-5) + + expr = Derivative(x * y ** 2, x, y) + f = lambdify((x, y), expr, 'torch') + x_val = torch.tensor(2.0, requires_grad=True) + y_val = torch.tensor(3.0, requires_grad=True) + result = f(x_val, y_val) + assert torch.isclose(result, torch.tensor(6.0)) + + +def test_torch_special_matrices(): + if not torch: + skip("PyTorch not installed") + + expr = Identity(3) + assert torch_code(expr) == "torch.eye(3)" + + n = symbols("n") + expr = Identity(n) + assert torch_code(expr) == "torch.eye(n, n)" + + expr = ZeroMatrix(2, 3) + assert torch_code(expr) == "torch.zeros((2, 3))" + + m, n = symbols("m n") + expr = ZeroMatrix(m, n) + assert torch_code(expr) == "torch.zeros((m, n))" + + expr = OneMatrix(2, 3) + assert torch_code(expr) == "torch.ones((2, 3))" + + expr = OneMatrix(m, n) + assert torch_code(expr) == "torch.ones((m, n))" + + +def test_torch_special_matrices_lambdify(): + if not torch: + skip("PyTorch not installed") + + expr = Identity(3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.eye(3) + assert torch.allclose(result, expected) + + expr = ZeroMatrix(2, 3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.zeros((2, 3)) + assert torch.allclose(result, expected) + + expr = OneMatrix(2, 3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.ones((2, 3)) + assert torch.allclose(result, expected) + + +def test_torch_complex_operations(): + if not torch: + skip("PyTorch not installed") + + expr = conjugate(x) + assert torch_code(expr) == "torch.conj(x)" + + # SymPy distributes conjugate over addition and applies specific rules for each term + expr = conjugate(sin(x) + I * cos(y)) + assert torch_code(expr) == "torch.sin(torch.conj(x)) - 1j*torch.cos(torch.conj(y))" + + expr = I + assert torch_code(expr) == "1j" + + expr = 2 * I + x + assert torch_code(expr) == "x + 2*1j" + + expr = exp(I * x) + assert torch_code(expr) == "torch.exp(1j*x)" + + +def test_torch_special_functions(): + if not torch: + skip("PyTorch not installed") + + expr = Heaviside(x) + assert torch_code(expr) == "torch.heaviside(x, 1/2)" + + expr = Heaviside(x, 0) + assert torch_code(expr) == "torch.heaviside(x, 0)" + + expr = gamma(x) + assert torch_code(expr) == "torch.special.gamma(x)" + + expr = polygamma(0, x) # Use polygamma instead of digamma because sympy will default to that anyway + assert torch_code(expr) == "torch.special.digamma(x)" + + expr = gamma(sin(x)) + assert torch_code(expr) == "torch.special.gamma(torch.sin(x))" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tree.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..cf116d0cac5d38f225815fcd2d4ac90cd0dd96d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tests/test_tree.py @@ -0,0 +1,196 @@ +from sympy.printing.tree import tree +from sympy.testing.pytest import XFAIL + + +# Remove this flag after making _assumptions cache deterministic. +@XFAIL +def test_print_tree_MatAdd(): + from sympy.matrices.expressions import MatrixSymbol + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + + test_str = [ + 'MatAdd: A + B\n', + 'algebraic: False\n', + 'commutative: False\n', + 'complex: False\n', + 'composite: False\n', + 'even: False\n', + 'extended_negative: False\n', + 'extended_nonnegative: False\n', + 'extended_nonpositive: False\n', + 'extended_nonzero: False\n', + 'extended_positive: False\n', + 'extended_real: False\n', + 'imaginary: False\n', + 'integer: False\n', + 'irrational: False\n', + 'negative: False\n', + 'noninteger: False\n', + 'nonnegative: False\n', + 'nonpositive: False\n', + 'nonzero: False\n', + 'odd: False\n', + 'positive: False\n', + 'prime: False\n', + 'rational: False\n', + 'real: False\n', + 'transcendental: False\n', + 'zero: False\n', + '+-MatrixSymbol: A\n', + '| algebraic: False\n', + '| commutative: False\n', + '| complex: False\n', + '| composite: False\n', + '| even: False\n', + '| extended_negative: False\n', + '| extended_nonnegative: False\n', + '| extended_nonpositive: False\n', + '| extended_nonzero: False\n', + '| extended_positive: False\n', + '| extended_real: False\n', + '| imaginary: False\n', + '| integer: False\n', + '| irrational: False\n', + '| negative: False\n', + '| noninteger: False\n', + '| nonnegative: False\n', + '| nonpositive: False\n', + '| nonzero: False\n', + '| odd: False\n', + '| positive: False\n', + '| prime: False\n', + '| rational: False\n', + '| real: False\n', + '| transcendental: False\n', + '| zero: False\n', + '| +-Symbol: A\n', + '| | commutative: True\n', + '| +-Integer: 3\n', + '| | algebraic: True\n', + '| | commutative: True\n', + '| | complex: True\n', + '| | extended_negative: False\n', + '| | extended_nonnegative: True\n', + '| | extended_real: True\n', + '| | finite: True\n', + '| | hermitian: True\n', + '| | imaginary: False\n', + '| | infinite: False\n', + '| | integer: True\n', + '| | irrational: False\n', + '| | negative: False\n', + '| | noninteger: False\n', + '| | nonnegative: True\n', + '| | rational: True\n', + '| | real: True\n', + '| | transcendental: False\n', + '| +-Integer: 3\n', + '| algebraic: True\n', + '| commutative: True\n', + '| complex: True\n', + '| extended_negative: False\n', + '| extended_nonnegative: True\n', + '| extended_real: True\n', + '| finite: True\n', + '| hermitian: True\n', + '| imaginary: False\n', + '| infinite: False\n', + '| integer: True\n', + '| irrational: False\n', + '| negative: False\n', + '| noninteger: False\n', + '| nonnegative: True\n', + '| rational: True\n', + '| real: True\n', + '| transcendental: False\n', + '+-MatrixSymbol: B\n', + ' algebraic: False\n', + ' commutative: False\n', + ' complex: False\n', + ' composite: False\n', + ' even: False\n', + ' extended_negative: False\n', + ' extended_nonnegative: False\n', + ' extended_nonpositive: False\n', + ' extended_nonzero: False\n', + ' extended_positive: False\n', + ' extended_real: False\n', + ' imaginary: False\n', + ' integer: False\n', + ' irrational: False\n', + ' negative: False\n', + ' noninteger: False\n', + ' nonnegative: False\n', + ' nonpositive: False\n', + ' nonzero: False\n', + ' odd: False\n', + ' positive: False\n', + ' prime: False\n', + ' rational: False\n', + ' real: False\n', + ' transcendental: False\n', + ' zero: False\n', + ' +-Symbol: B\n', + ' | commutative: True\n', + ' +-Integer: 3\n', + ' | algebraic: True\n', + ' | commutative: True\n', + ' | complex: True\n', + ' | extended_negative: False\n', + ' | extended_nonnegative: True\n', + ' | extended_real: True\n', + ' | finite: True\n', + ' | hermitian: True\n', + ' | imaginary: False\n', + ' | infinite: False\n', + ' | integer: True\n', + ' | irrational: False\n', + ' | negative: False\n', + ' | noninteger: False\n', + ' | nonnegative: True\n', + ' | rational: True\n', + ' | real: True\n', + ' | transcendental: False\n', + ' +-Integer: 3\n', + ' algebraic: True\n', + ' commutative: True\n', + ' complex: True\n', + ' extended_negative: False\n', + ' extended_nonnegative: True\n', + ' extended_real: True\n', + ' finite: True\n', + ' hermitian: True\n', + ' imaginary: False\n', + ' infinite: False\n', + ' integer: True\n', + ' irrational: False\n', + ' negative: False\n', + ' noninteger: False\n', + ' nonnegative: True\n', + ' rational: True\n', + ' real: True\n', + ' transcendental: False\n' + ] + + assert tree(A + B) == "".join(test_str) + + +def test_print_tree_MatAdd_noassumptions(): + from sympy.matrices.expressions import MatrixSymbol + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + + test_str = \ +"""MatAdd: A + B ++-MatrixSymbol: A +| +-Str: A +| +-Integer: 3 +| +-Integer: 3 ++-MatrixSymbol: B + +-Str: B + +-Integer: 3 + +-Integer: 3 +""" + + assert tree(A + B, assumptions=False) == test_str diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tree.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tree.py new file mode 100644 index 0000000000000000000000000000000000000000..82dac013419fbe93f63dcf5b90b3a529d72a32bc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/printing/tree.py @@ -0,0 +1,175 @@ +def pprint_nodes(subtrees): + """ + Prettyprints systems of nodes. + + Examples + ======== + + >>> from sympy.printing.tree import pprint_nodes + >>> print(pprint_nodes(["a", "b1\\nb2", "c"])) + +-a + +-b1 + | b2 + +-c + + """ + def indent(s, type=1): + x = s.split("\n") + r = "+-%s\n" % x[0] + for a in x[1:]: + if a == "": + continue + if type == 1: + r += "| %s\n" % a + else: + r += " %s\n" % a + return r + if not subtrees: + return "" + f = "" + for a in subtrees[:-1]: + f += indent(a) + f += indent(subtrees[-1], 2) + return f + + +def print_node(node, assumptions=True): + """ + Returns information about the "node". + + This includes class name, string representation and assumptions. + + Parameters + ========== + + assumptions : bool, optional + See the ``assumptions`` keyword in ``tree`` + """ + s = "%s: %s\n" % (node.__class__.__name__, str(node)) + + if assumptions: + d = node._assumptions + else: + d = None + + if d: + for a in sorted(d): + v = d[a] + if v is None: + continue + s += "%s: %s\n" % (a, v) + + return s + + +def tree(node, assumptions=True): + """ + Returns a tree representation of "node" as a string. + + It uses print_node() together with pprint_nodes() on node.args recursively. + + Parameters + ========== + + assumptions : bool, optional + The flag to decide whether to print out all the assumption data + (such as ``is_integer`, ``is_real``) associated with the + expression or not. + + Enabling the flag makes the result verbose, and the printed + result may not be deterministic because of the randomness used + in backtracing the assumptions. + + See Also + ======== + + print_tree + + """ + subtrees = [] + for arg in node.args: + subtrees.append(tree(arg, assumptions=assumptions)) + s = print_node(node, assumptions=assumptions) + pprint_nodes(subtrees) + return s + + +def print_tree(node, assumptions=True): + """ + Prints a tree representation of "node". + + Parameters + ========== + + assumptions : bool, optional + The flag to decide whether to print out all the assumption data + (such as ``is_integer`, ``is_real``) associated with the + expression or not. + + Enabling the flag makes the result verbose, and the printed + result may not be deterministic because of the randomness used + in backtracing the assumptions. + + Examples + ======== + + >>> from sympy.printing import print_tree + >>> from sympy import Symbol + >>> x = Symbol('x', odd=True) + >>> y = Symbol('y', even=True) + + Printing with full assumptions information: + + >>> print_tree(y**x) + Pow: y**x + +-Symbol: y + | algebraic: True + | commutative: True + | complex: True + | even: True + | extended_real: True + | finite: True + | hermitian: True + | imaginary: False + | infinite: False + | integer: True + | irrational: False + | noninteger: False + | odd: False + | rational: True + | real: True + | transcendental: False + +-Symbol: x + algebraic: True + commutative: True + complex: True + even: False + extended_nonzero: True + extended_real: True + finite: True + hermitian: True + imaginary: False + infinite: False + integer: True + irrational: False + noninteger: False + nonzero: True + odd: True + rational: True + real: True + transcendental: False + zero: False + + Hiding the assumptions: + + >>> print_tree(y**x, assumptions=False) + Pow: y**x + +-Symbol: y + +-Symbol: x + + See Also + ======== + + tree + + """ + print(tree(node, assumptions=assumptions)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32c122012a4acfdeed9b6f145b36bd3e3131d8fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/acceleration.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/acceleration.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cc35b1f45888d2f736a05fafa6829df515ba22d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/acceleration.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/approximants.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/approximants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58ab3710a35a53d226040aeb840b1396700a9141 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/approximants.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/aseries.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/aseries.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e6859fb80f1774a3235b371cac1d2625c3bcc22 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/aseries.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/formal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/formal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac7acd92309d3025bd335825f7ec0673432602d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/formal.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/fourier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/fourier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18cb639383ffe40faab3bb7b7762d0a84562d2a1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/fourier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/gruntz.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/gruntz.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dd7386e8560a5b1e13de2236d8834907efd0fe7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/gruntz.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/kauers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/kauers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..723ad5d37848bac0011ad139e8504ea1efbe76e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/kauers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/limits.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/limits.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9fbda671dcf6ae075401e5a31d044d5323a00fb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/limits.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/limitseq.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/limitseq.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f23e4ae9b75efacd636348148ff59b27675a37c3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/limitseq.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/order.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/order.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e577ed3d09c9762e7c5c58759883c928f89508b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/order.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/residues.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/residues.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c841174c988fe0b6b1b0991dadf1b8a085c544a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/residues.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/sequences.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/sequences.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27205e75130aedf56c9711417d88474fd19c0fd1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/sequences.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/series.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/series.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b276750b42d9c8343951f3399c8d09d1ae05a80d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/series.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/series_class.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/series_class.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc0951185b1e6185f1525066a2ad578687f08b36 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/__pycache__/series_class.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cdbbbd04942f88d52f60adb211d4cf069f211a0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/bench_limit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/bench_limit.py new file mode 100644 index 0000000000000000000000000000000000000000..eafc28328848dad4b3ea433537971f5785253afe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/bench_limit.py @@ -0,0 +1,9 @@ +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.series.limits import limit + +x = Symbol('x') + + +def timeit_limit_1x(): + limit(1/x, x, oo) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/bench_order.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/bench_order.py new file mode 100644 index 0000000000000000000000000000000000000000..1c85fa173dfc2a478792de8ab816c23ba9d408ef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/benchmarks/bench_order.py @@ -0,0 +1,10 @@ +from sympy.core.add import Add +from sympy.core.symbol import Symbol +from sympy.series.order import O + +x = Symbol('x') +l = [x**i for i in range(1000)] +l.append(O(x**1001)) + +def timeit_order_1x(): + Add(*l) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0bad333bcc8708a2ff51768221b0e92693749d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2ec9bcda76d70f9ba5b8da0c138313ae3104691 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aa672a3f0e99169e01786e003c1669a3a4b5b95 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1513aea116d4637f60e0d7b6334830082d873e6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..989237a2553049e788840d476cfb666e6c7ffa41 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9db788b89cb83e4be54fa36ab0fa7b844412c6c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6edc95120e82c052de610d5a7596a59fee08b37a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e21866edb66a1b50e17321542a788d858468d060 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6756cafc563b403e826f9092ee0b212977ca046f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdaf5fd2c2c4bd058c1d48e1375f30eb5d158d5e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5aeb6c552db8e65d47650b5d50f0c196ba8adc3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_order.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_order.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19ff589d0d0ece4cce097dbadaf566153758f19f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_order.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41625d872e23f20aa38351e42d402cf9bf3f6366 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a7d38547e55911dfdcbe2adc47a3b87cdc9ae9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_series.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_series.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c92c7487b2e46fff225851f5154afa605c5fb707 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/__pycache__/test_series.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_approximants.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_approximants.py new file mode 100644 index 0000000000000000000000000000000000000000..9c03d2ce38add99b0dce8725b6c8d8844b31f76b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_approximants.py @@ -0,0 +1,23 @@ +from sympy.series import approximants +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.combinatorial.numbers import (fibonacci, lucas) + + +def test_approximants(): + x, t = symbols("x,t") + g = [lucas(k) for k in range(16)] + assert list(approximants(g)) == ( + [2, -4/(x - 2), (5*x - 2)/(3*x - 1), (x - 2)/(x**2 + x - 1)] ) + g = [lucas(k)+fibonacci(k+2) for k in range(16)] + assert list(approximants(g)) == ( + [3, -3/(x - 1), (3*x - 3)/(2*x - 1), -3/(x**2 + x - 1)] ) + g = [lucas(k)**2 for k in range(16)] + assert list(approximants(g)) == ( + [4, -16/(x - 4), (35*x - 4)/(9*x - 1), (37*x - 28)/(13*x**2 + 11*x - 7), + (50*x**2 + 63*x - 52)/(37*x**2 + 19*x - 13), + (-x**2 - 7*x + 4)/(x**3 - 2*x**2 - 2*x + 1)] ) + p = [sum(binomial(k,i)*x**i for i in range(k+1)) for k in range(16)] + y = approximants(p, t, simplify=True) + assert next(y) == 1 + assert next(y) == -1/(t*(x + 1) - 1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_aseries.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_aseries.py new file mode 100644 index 0000000000000000000000000000000000000000..cae0ac0a43f2406dd96e45c6a31939ac6b4cdcaa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_aseries.py @@ -0,0 +1,55 @@ +from sympy.core.function import PoleError +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.series.order import O +from sympy.abc import x + +from sympy.testing.pytest import raises + +def test_simple(): + # Gruntz' theses pp. 91 to 96 + # 6.6 + e = sin(1/x + exp(-x)) - sin(1/x) + assert e.aseries(x) == (1/(24*x**4) - 1/(2*x**2) + 1 + O(x**(-6), (x, oo)))*exp(-x) + + e = exp(x) * (exp(1/x + exp(-x)) - exp(1/x)) + assert e.aseries(x, n=4) == 1/(6*x**3) + 1/(2*x**2) + 1/x + 1 + O(x**(-4), (x, oo)) + + e = exp(exp(x) / (1 - 1/x)) + assert e.aseries(x) == exp(exp(x) / (1 - 1/x)) + + # The implementation of bound in aseries is incorrect currently. This test + # should be commented out when that is fixed. + # assert e.aseries(x, bound=3) == exp(exp(x) / x**2)*exp(exp(x) / x)*exp(-exp(x) + exp(x)/(1 - 1/x) - \ + # exp(x) / x - exp(x) / x**2) * exp(exp(x)) + + e = exp(sin(1/x + exp(-exp(x)))) - exp(sin(1/x)) + assert e.aseries(x, n=4) == (-1/(2*x**3) + 1/x + 1 + O(x**(-4), (x, oo)))*exp(-exp(x)) + + e3 = lambda x:exp(exp(exp(x))) + e = e3(x)/e3(x - 1/e3(x)) + assert e.aseries(x, n=3) == 1 + exp(2*x + 2*exp(x))*exp(-2*exp(exp(x)))/2\ + - exp(2*x + exp(x))*exp(-2*exp(exp(x)))/2 - exp(x + exp(x))*exp(-2*exp(exp(x)))/2\ + + exp(x + exp(x))*exp(-exp(exp(x))) + O(exp(-3*exp(exp(x))), (x, oo)) + + e = exp(exp(x)) * (exp(sin(1/x + 1/exp(exp(x)))) - exp(sin(1/x))) + assert e.aseries(x, n=4) == -1/(2*x**3) + 1/x + 1 + O(x**(-4), (x, oo)) + + n = Symbol('n', integer=True) + e = (sqrt(n)*log(n)**2*exp(sqrt(log(n))*log(log(n))**2*exp(sqrt(log(log(n)))*log(log(log(n)))**3)))/n + assert e.aseries(n) == \ + exp(exp(sqrt(log(log(n)))*log(log(log(n)))**3)*sqrt(log(n))*log(log(n))**2)*log(n)**2/sqrt(n) + + +def test_hierarchical(): + e = sin(1/x + exp(-x)) + assert e.aseries(x, n=3, hir=True) == -exp(-2*x)*sin(1/x)/2 + \ + exp(-x)*cos(1/x) + sin(1/x) + O(exp(-3*x), (x, oo)) + + e = sin(x) * cos(exp(-x)) + assert e.aseries(x, hir=True) == exp(-4*x)*sin(x)/24 - \ + exp(-2*x)*sin(x)/2 + sin(x) + O(exp(-6*x), (x, oo)) + raises(PoleError, lambda: e.aseries(x)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_demidovich.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_demidovich.py new file mode 100644 index 0000000000000000000000000000000000000000..98cafbae6f019dd3d97d306099d5780ed2f37f04 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_demidovich.py @@ -0,0 +1,143 @@ +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, sin, tan) +from sympy.polys.rationaltools import together +from sympy.series.limits import limit + +# Numbers listed with the tests refer to problem numbers in the book +# "Anti-demidovich, problemas resueltos, Ed. URSS" + +x = Symbol("x") + + +def test_leadterm(): + assert (3 + 2*x**(log(3)/log(2) - 1)).leadterm(x) == (3, 0) + + +def root3(x): + return root(x, 3) + + +def root4(x): + return root(x, 4) + + +def test_Limits_simple_0(): + assert limit((2**(x + 1) + 3**(x + 1))/(2**x + 3**x), x, oo) == 3 # 175 + + +def test_Limits_simple_1(): + assert limit((x + 1)*(x + 2)*(x + 3)/x**3, x, oo) == 1 # 172 + assert limit(sqrt(x + 1) - sqrt(x), x, oo) == 0 # 179 + assert limit((2*x - 3)*(3*x + 5)*(4*x - 6)/(3*x**3 + x - 1), x, oo) == 8 # Primjer 1 + assert limit(x/root3(x**3 + 10), x, oo) == 1 # Primjer 2 + assert limit((x + 1)**2/(x**2 + 1), x, oo) == 1 # 181 + + +def test_Limits_simple_2(): + assert limit(1000*x/(x**2 - 1), x, oo) == 0 # 182 + assert limit((x**2 - 5*x + 1)/(3*x + 7), x, oo) is oo # 183 + assert limit((2*x**2 - x + 3)/(x**3 - 8*x + 5), x, oo) == 0 # 184 + assert limit((2*x**2 - 3*x - 4)/sqrt(x**4 + 1), x, oo) == 2 # 186 + assert limit((2*x + 3)/(x + root3(x)), x, oo) == 2 # 187 + assert limit(x**2/(10 + x*sqrt(x)), x, oo) is oo # 188 + assert limit(root3(x**2 + 1)/(x + 1), x, oo) == 0 # 189 + assert limit(sqrt(x)/sqrt(x + sqrt(x + sqrt(x))), x, oo) == 1 # 190 + + +def test_Limits_simple_3a(): + a = Symbol('a') + #issue 3513 + assert together(limit((x**2 - (a + 1)*x + a)/(x**3 - a**3), x, a)) == \ + (a - 1)/(3*a**2) # 196 + + +def test_Limits_simple_3b(): + h = Symbol("h") + assert limit(((x + h)**3 - x**3)/h, h, 0) == 3*x**2 # 197 + assert limit((1/(1 - x) - 3/(1 - x**3)), x, 1) == -1 # 198 + assert limit((sqrt(1 + x) - 1)/(root3(1 + x) - 1), x, 0) == Rational(3)/2 # Primer 4 + assert limit((sqrt(x) - 1)/(x - 1), x, 1) == Rational(1)/2 # 199 + assert limit((sqrt(x) - 8)/(root3(x) - 4), x, 64) == 3 # 200 + assert limit((root3(x) - 1)/(root4(x) - 1), x, 1) == Rational(4)/3 # 201 + assert limit( + (root3(x**2) - 2*root3(x) + 1)/(x - 1)**2, x, 1) == Rational(1)/9 # 202 + + +def test_Limits_simple_4a(): + a = Symbol('a') + assert limit((sqrt(x) - sqrt(a))/(x - a), x, a) == 1/(2*sqrt(a)) # Primer 5 + assert limit((sqrt(x) - 1)/(root3(x) - 1), x, 1) == Rational(3, 2) # 205 + assert limit((sqrt(1 + x) - sqrt(1 - x))/x, x, 0) == 1 # 207 + assert limit(sqrt(x**2 - 5*x + 6) - x, x, oo) == Rational(-5, 2) # 213 + + +def test_limits_simple_4aa(): + assert limit(x*(sqrt(x**2 + 1) - x), x, oo) == Rational(1)/2 # 214 + + +def test_Limits_simple_4b(): + #issue 3511 + assert limit(x - root3(x**3 - 1), x, oo) == 0 # 215 + + +def test_Limits_simple_4c(): + assert limit(log(1 + exp(x))/x, x, -oo) == 0 # 267a + assert limit(log(1 + exp(x))/x, x, oo) == 1 # 267b + + +def test_bounded(): + assert limit(sin(x)/x, x, oo) == 0 # 216b + assert limit(x*sin(1/x), x, 0) == 0 # 227a + + +def test_f1a(): + #issue 3508: + assert limit((sin(2*x)/x)**(1 + x), x, 0) == 2 # Primer 7 + + +def test_f1a2(): + #issue 3509: + assert limit(((x - 1)/(x + 1))**x, x, oo) == exp(-2) # Primer 9 + + +def test_f1b(): + m = Symbol("m") + n = Symbol("n") + h = Symbol("h") + a = Symbol("a") + assert limit(sin(x)/x, x, 2) == sin(2)/2 # 216a + assert limit(sin(3*x)/x, x, 0) == 3 # 217 + assert limit(sin(5*x)/sin(2*x), x, 0) == Rational(5, 2) # 218 + assert limit(sin(pi*x)/sin(3*pi*x), x, 0) == Rational(1, 3) # 219 + assert limit(x*sin(pi/x), x, oo) == pi # 220 + assert limit((1 - cos(x))/x**2, x, 0) == S.Half # 221 + assert limit(x*sin(1/x), x, oo) == 1 # 227b + assert limit((cos(m*x) - cos(n*x))/x**2, x, 0) == -m**2/2 + n**2/2 # 232 + assert limit((tan(x) - sin(x))/x**3, x, 0) == S.Half # 233 + assert limit((x - sin(2*x))/(x + sin(3*x)), x, 0) == -Rational(1, 4) # 237 + assert limit((1 - sqrt(cos(x)))/x**2, x, 0) == Rational(1, 4) # 239 + assert limit((sqrt(1 + sin(x)) - sqrt(1 - sin(x)))/x, x, 0) == 1 # 240 + + assert limit((1 + h/x)**x, x, oo) == exp(h) # Primer 9 + assert limit((sin(x) - sin(a))/(x - a), x, a) == cos(a) # 222, *176 + assert limit((cos(x) - cos(a))/(x - a), x, a) == -sin(a) # 223 + assert limit((sin(x + h) - sin(x))/h, h, 0) == cos(x) # 225 + + +def test_f2a(): + assert limit(((x + 1)/(2*x + 1))**(x**2), x, oo) == 0 # Primer 8 + + +def test_f2(): + assert limit((sqrt( + cos(x)) - root3(cos(x)))/(sin(x)**2), x, 0) == -Rational(1, 12) # *184 + + +def test_f3(): + a = Symbol('a') + #issue 3504 + assert limit(asin(a*x)/x, x, 0) == a diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_formal.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_formal.py new file mode 100644 index 0000000000000000000000000000000000000000..cac60b12534152a5783bb8f0faab2c06da6691fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_formal.py @@ -0,0 +1,618 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (acosh, asech) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin) +from sympy.functions.special.bessel import airyai +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import integrate +from sympy.series.formal import fps +from sympy.series.order import O +from sympy.series.formal import (rational_algorithm, FormalPowerSeries, + FormalPowerSeriesProduct, FormalPowerSeriesCompose, + FormalPowerSeriesInverse, simpleDE, + rational_independent, exp_re, hyper_re) +from sympy.testing.pytest import raises, XFAIL, slow + +x, y, z = symbols('x y z') +n, m, k = symbols('n m k', integer=True) +f, r = Function('f'), Function('r') + + +def test_rational_algorithm(): + f = 1 / ((x - 1)**2 * (x - 2)) + assert rational_algorithm(f, x, k) == \ + (-2**(-k - 1) + 1 - (factorial(k + 1) / factorial(k)), 0, 0) + + f = (1 + x + x**2 + x**3) / ((x - 1) * (x - 2)) + assert rational_algorithm(f, x, k) == \ + (-15*2**(-k - 1) + 4, x + 4, 0) + + f = z / (y*m - m*x - y*x + x**2) + assert rational_algorithm(f, x, k) == \ + (((-y**(-k - 1)*z) / (y - m)) + ((m**(-k - 1)*z) / (y - m)), 0, 0) + + f = x / (1 - x - x**2) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((Rational(-1, 2) + sqrt(5)/2)**(-k - 1) * + (-sqrt(5)/10 + S.Half)) + + ((-sqrt(5)/2 - S.Half)**(-k - 1) * + (sqrt(5)/10 + S.Half)), 0, 0) + + f = 1 / (x**2 + 2*x + 2) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + ((I*(-1 + I)**(-k - 1)) / 2 - (I*(-1 - I)**(-k - 1)) / 2, 0, 0) + + f = log(1 + x) + assert rational_algorithm(f, x, k) == \ + (-(-1)**(-k) / k, 0, 1) + + f = atan(x) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((I*I**(-k)) / 2 - (I*(-I)**(-k)) / 2) / k, 0, 1) + + f = x*atan(x) - log(1 + x**2) / 2 + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((I*I**(-k + 1)) / 2 - (I*(-I)**(-k + 1)) / 2) / + (k*(k - 1)), 0, 2) + + f = log((1 + x) / (1 - x)) / 2 - atan(x) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + ((-(-1)**(-k) / 2 - (I*I**(-k)) / 2 + (I*(-I)**(-k)) / 2 + + S.Half) / k, 0, 1) + + assert rational_algorithm(cos(x), x, k) is None + + +def test_rational_independent(): + ri = rational_independent + assert ri([], x) == [] + assert ri([cos(x), sin(x)], x) == [cos(x), sin(x)] + assert ri([x**2, sin(x), x*sin(x), x**3], x) == \ + [x**3 + x**2, x*sin(x) + sin(x)] + assert ri([S.One, x*log(x), log(x), sin(x)/x, cos(x), sin(x), x], x) == \ + [x + 1, x*log(x) + log(x), sin(x)/x + sin(x), cos(x)] + + +def test_simpleDE(): + # Tests just the first valid DE + for DE in simpleDE(exp(x), x, f): + assert DE == (-f(x) + Derivative(f(x), x), 1) + break + for DE in simpleDE(sin(x), x, f): + assert DE == (f(x) + Derivative(f(x), x, x), 2) + break + for DE in simpleDE(log(1 + x), x, f): + assert DE == ((x + 1)*Derivative(f(x), x, 2) + Derivative(f(x), x), 2) + break + for DE in simpleDE(asin(x), x, f): + assert DE == (x*Derivative(f(x), x) + (x**2 - 1)*Derivative(f(x), x, x), + 2) + break + for DE in simpleDE(exp(x)*sin(x), x, f): + assert DE == (2*f(x) - 2*Derivative(f(x)) + Derivative(f(x), x, x), 2) + break + for DE in simpleDE(((1 + x)/(1 - x))**n, x, f): + assert DE == (2*n*f(x) + (x**2 - 1)*Derivative(f(x), x), 1) + break + for DE in simpleDE(airyai(x), x, f): + assert DE == (-x*f(x) + Derivative(f(x), x, x), 2) + break + + +def test_exp_re(): + d = -f(x) + Derivative(f(x), x) + assert exp_re(d, r, k) == -r(k) + r(k + 1) + + d = f(x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 2) + + d = f(x) + Derivative(f(x), x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 1) + r(k + 2) + + d = Derivative(f(x), x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 1) + + d = Derivative(f(x), x, 3) + Derivative(f(x), x, 4) + Derivative(f(x)) + assert exp_re(d, r, k) == r(k) + r(k + 2) + r(k + 3) + + +def test_hyper_re(): + d = f(x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == r(k) + (k+1)*(k+2)*r(k + 2) + + d = -x*f(x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == (k + 2)*(k + 3)*r(k + 3) - r(k) + + d = 2*f(x) - 2*Derivative(f(x), x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == \ + (-2*k - 2)*r(k + 1) + (k + 1)*(k + 2)*r(k + 2) + 2*r(k) + + d = 2*n*f(x) + (x**2 - 1)*Derivative(f(x), x) + assert hyper_re(d, r, k) == \ + k*r(k) + 2*n*r(k + 1) + (-k - 2)*r(k + 2) + + d = (x**10 + 4)*Derivative(f(x), x) + x*(x**10 - 1)*Derivative(f(x), x, x) + assert hyper_re(d, r, k) == \ + (k*(k - 1) + k)*r(k) + (4*k - (k + 9)*(k + 10) + 40)*r(k + 10) + + d = ((x**2 - 1)*Derivative(f(x), x, 3) + 3*x*Derivative(f(x), x, x) + + Derivative(f(x), x)) + assert hyper_re(d, r, k) == \ + ((k*(k - 2)*(k - 1) + 3*k*(k - 1) + k)*r(k) + + (-k*(k + 1)*(k + 2))*r(k + 2)) + + +def test_fps(): + assert fps(1) == 1 + assert fps(2, x) == 2 + assert fps(2, x, dir='+') == 2 + assert fps(2, x, dir='-') == 2 + assert fps(1/x + 1/x**2) == 1/x + 1/x**2 + assert fps(log(1 + x), hyper=False, rational=False) == log(1 + x) + + f = fps(x**2 + x + 1) + assert isinstance(f, FormalPowerSeries) + assert f.function == x**2 + x + 1 + assert f[0] == 1 + assert f[2] == x**2 + assert f.truncate(4) == x**2 + x + 1 + O(x**4) + assert f.polynomial() == x**2 + x + 1 + + f = fps(log(1 + x)) + assert isinstance(f, FormalPowerSeries) + assert f.function == log(1 + x) + assert f.subs(x, y) == f + assert f[:5] == [0, x, -x**2/2, x**3/3, -x**4/4] + assert f.as_leading_term(x) == x + assert f.polynomial(6) == x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + + k = f.ak.variables[0] + assert f.infinite == Sum((-(-1)**(-k)*x**k)/k, (k, 1, oo)) + + ft, s = f.truncate(n=None), f[:5] + for i, t in enumerate(ft): + if i == 5: + break + assert s[i] == t + + f = sin(x).fps(x) + assert isinstance(f, FormalPowerSeries) + assert f.truncate() == x - x**3/6 + x**5/120 + O(x**6) + + raises(NotImplementedError, lambda: fps(y*x)) + raises(ValueError, lambda: fps(x, dir=0)) + + +@slow +def test_fps__rational(): + assert fps(1/x) == (1/x) + assert fps((x**2 + x + 1) / x**3, dir=-1) == (x**2 + x + 1) / x**3 + + f = 1 / ((x - 1)**2 * (x - 2)) + assert fps(f, x).truncate() == \ + (Rational(-1, 2) - x*Rational(5, 4) - 17*x**2/8 - 49*x**3/16 - 129*x**4/32 - + 321*x**5/64 + O(x**6)) + + f = (1 + x + x**2 + x**3) / ((x - 1) * (x - 2)) + assert fps(f, x).truncate() == \ + (S.Half + x*Rational(5, 4) + 17*x**2/8 + 49*x**3/16 + 113*x**4/32 + + 241*x**5/64 + O(x**6)) + + f = x / (1 - x - x**2) + assert fps(f, x, full=True).truncate() == \ + x + x**2 + 2*x**3 + 3*x**4 + 5*x**5 + O(x**6) + + f = 1 / (x**2 + 2*x + 2) + assert fps(f, x, full=True).truncate() == \ + S.Half - x/2 + x**2/4 - x**4/8 + x**5/8 + O(x**6) + + f = log(1 + x) + assert fps(f, x).truncate() == \ + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + assert fps(f, x, dir=1).truncate() == fps(f, x, dir=-1).truncate() + assert fps(f, x, 2).truncate() == \ + (log(3) - Rational(2, 3) - (x - 2)**2/18 + (x - 2)**3/81 - + (x - 2)**4/324 + (x - 2)**5/1215 + x/3 + O((x - 2)**6, (x, 2))) + assert fps(f, x, 2, dir=-1).truncate() == \ + (log(3) - Rational(2, 3) - (-x + 2)**2/18 - (-x + 2)**3/81 - + (-x + 2)**4/324 - (-x + 2)**5/1215 + x/3 + O((x - 2)**6, (x, 2))) + + f = atan(x) + assert fps(f, x, full=True).truncate() == x - x**3/3 + x**5/5 + O(x**6) + assert fps(f, x, full=True, dir=1).truncate() == \ + fps(f, x, full=True, dir=-1).truncate() + assert fps(f, x, 2, full=True).truncate() == \ + (atan(2) - Rational(2, 5) - 2*(x - 2)**2/25 + 11*(x - 2)**3/375 - + 6*(x - 2)**4/625 + 41*(x - 2)**5/15625 + x/5 + O((x - 2)**6, (x, 2))) + assert fps(f, x, 2, full=True, dir=-1).truncate() == \ + (atan(2) - Rational(2, 5) - 2*(-x + 2)**2/25 - 11*(-x + 2)**3/375 - + 6*(-x + 2)**4/625 - 41*(-x + 2)**5/15625 + x/5 + O((x - 2)**6, (x, 2))) + + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x, full=True).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = log((1 + x) / (1 - x)) / 2 - atan(x) + assert fps(f, x, full=True).truncate(n=10) == 2*x**3/3 + 2*x**7/7 + O(x**10) + + +@slow +def test_fps__hyper(): + f = sin(x) + assert fps(f, x).truncate() == x - x**3/6 + x**5/120 + O(x**6) + + f = cos(x) + assert fps(f, x).truncate() == 1 - x**2/2 + x**4/24 + O(x**6) + + f = exp(x) + assert fps(f, x).truncate() == \ + 1 + x + x**2/2 + x**3/6 + x**4/24 + x**5/120 + O(x**6) + + f = atan(x) + assert fps(f, x).truncate() == x - x**3/3 + x**5/5 + O(x**6) + + f = exp(acos(x)) + assert fps(f, x).truncate() == \ + (exp(pi/2) - x*exp(pi/2) + x**2*exp(pi/2)/2 - x**3*exp(pi/2)/3 + + 5*x**4*exp(pi/2)/24 - x**5*exp(pi/2)/6 + O(x**6)) + + f = exp(acosh(x)) + assert fps(f, x).truncate() == I + x - I*x**2/2 - I*x**4/8 + O(x**6) + + f = atan(1/x) + assert fps(f, x).truncate() == pi/2 - x + x**3/3 - x**5/5 + O(x**6) + + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x, rational=False).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = log(1 + x) + assert fps(f, x, rational=False).truncate() == \ + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + + f = airyai(x**2) + assert fps(f, x).truncate() == \ + (3**Rational(5, 6)*gamma(Rational(1, 3))/(6*pi) - + 3**Rational(2, 3)*x**2/(3*gamma(Rational(1, 3))) + O(x**6)) + + f = exp(x)*sin(x) + assert fps(f, x).truncate() == x + x**2 + x**3/3 - x**5/30 + O(x**6) + + f = exp(x)*sin(x)/x + assert fps(f, x).truncate() == 1 + x + x**2/3 - x**4/30 - x**5/90 + O(x**6) + + f = sin(x) * cos(x) + assert fps(f, x).truncate() == x - 2*x**3/3 + 2*x**5/15 + O(x**6) + + +def test_fps_shift(): + f = x**-5*sin(x) + assert fps(f, x).truncate() == \ + 1/x**4 - 1/(6*x**2) + Rational(1, 120) - x**2/5040 + x**4/362880 + O(x**6) + + f = x**2*atan(x) + assert fps(f, x, rational=False).truncate() == \ + x**3 - x**5/3 + O(x**6) + + f = cos(sqrt(x))*x + assert fps(f, x).truncate() == \ + x - x**2/2 + x**3/24 - x**4/720 + x**5/40320 + O(x**6) + + f = x**2*cos(sqrt(x)) + assert fps(f, x).truncate() == \ + x**2 - x**3/2 + x**4/24 - x**5/720 + O(x**6) + + +def test_fps__Add_expr(): + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = sin(x) + cos(x) - exp(x) + log(1 + x) + assert fps(f, x).truncate() == x - 3*x**2/2 - x**4/4 + x**5/5 + O(x**6) + + f = 1/x + sin(x) + assert fps(f, x).truncate() == 1/x + x - x**3/6 + x**5/120 + O(x**6) + + f = sin(x) - cos(x) + 1/(x - 1) + assert fps(f, x).truncate() == \ + -2 - x**2/2 - 7*x**3/6 - 25*x**4/24 - 119*x**5/120 + O(x**6) + + +def test_fps__asymptotic(): + f = exp(x) + assert fps(f, x, oo) == f + assert fps(f, x, -oo).truncate() == O(1/x**6, (x, oo)) + + f = erf(x) + assert fps(f, x, oo).truncate() == 1 + O(1/x**6, (x, oo)) + assert fps(f, x, -oo).truncate() == -1 + O(1/x**6, (x, oo)) + + f = atan(x) + assert fps(f, x, oo, full=True).truncate() == \ + -1/(5*x**5) + 1/(3*x**3) - 1/x + pi/2 + O(1/x**6, (x, oo)) + assert fps(f, x, -oo, full=True).truncate() == \ + -1/(5*x**5) + 1/(3*x**3) - 1/x - pi/2 + O(1/x**6, (x, oo)) + + f = log(1 + x) + assert fps(f, x, oo) != \ + (-1/(5*x**5) - 1/(4*x**4) + 1/(3*x**3) - 1/(2*x**2) + 1/x - log(1/x) + + O(1/x**6, (x, oo))) + assert fps(f, x, -oo) != \ + (-1/(5*x**5) - 1/(4*x**4) + 1/(3*x**3) - 1/(2*x**2) + 1/x + I*pi - + log(-1/x) + O(1/x**6, (x, oo))) + + +def test_fps__fractional(): + f = sin(sqrt(x)) / x + assert fps(f, x).truncate() == \ + (1/sqrt(x) - sqrt(x)/6 + x**Rational(3, 2)/120 - + x**Rational(5, 2)/5040 + x**Rational(7, 2)/362880 - + x**Rational(9, 2)/39916800 + x**Rational(11, 2)/6227020800 + O(x**6)) + + f = sin(sqrt(x)) * x + assert fps(f, x).truncate() == \ + (x**Rational(3, 2) - x**Rational(5, 2)/6 + x**Rational(7, 2)/120 - + x**Rational(9, 2)/5040 + x**Rational(11, 2)/362880 + O(x**6)) + + f = atan(sqrt(x)) / x**2 + assert fps(f, x).truncate() == \ + (x**Rational(-3, 2) - x**Rational(-1, 2)/3 + x**S.Half/5 - + x**Rational(3, 2)/7 + x**Rational(5, 2)/9 - x**Rational(7, 2)/11 + + x**Rational(9, 2)/13 - x**Rational(11, 2)/15 + O(x**6)) + + f = exp(sqrt(x)) + assert fps(f, x).truncate().expand() == \ + (1 + x/2 + x**2/24 + x**3/720 + x**4/40320 + x**5/3628800 + sqrt(x) + + x**Rational(3, 2)/6 + x**Rational(5, 2)/120 + x**Rational(7, 2)/5040 + + x**Rational(9, 2)/362880 + x**Rational(11, 2)/39916800 + O(x**6)) + + f = exp(sqrt(x))*x + assert fps(f, x).truncate().expand() == \ + (x + x**2/2 + x**3/24 + x**4/720 + x**5/40320 + x**Rational(3, 2) + + x**Rational(5, 2)/6 + x**Rational(7, 2)/120 + x**Rational(9, 2)/5040 + + x**Rational(11, 2)/362880 + O(x**6)) + + +def test_fps__logarithmic_singularity(): + f = log(1 + 1/x) + assert fps(f, x) != \ + -log(x) + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + assert fps(f, x, rational=False) != \ + -log(x) + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + + +@XFAIL +def test_fps__logarithmic_singularity_fail(): + f = asech(x) # Algorithms for computing limits probably needs improvements + assert fps(f, x) == log(2) - log(x) - x**2/4 - 3*x**4/64 + O(x**6) + + +def test_fps_symbolic(): + f = x**n*sin(x**2) + assert fps(f, x).truncate(8) == x**(n + 2) - x**(n + 6)/6 + O(x**(n + 8), x) + + f = x**n*log(1 + x) + fp = fps(f, x) + k = fp.ak.variables[0] + assert fp.infinite == \ + Sum((-(-1)**(-k)*x**(k + n))/k, (k, 1, oo)) + + f = (x - 2)**n*log(1 + x) + assert fps(f, x, 2).truncate() == \ + ((x - 2)**n*log(3) + (x - 2)**(n + 1)/3 - (x - 2)**(n + 2)/18 + (x - 2)**(n + 3)/81 - + (x - 2)**(n + 4)/324 + (x - 2)**(n + 5)/1215 + O((x - 2)**(n + 6), (x, 2))) + + f = x**(n - 2)*cos(x) + assert fps(f, x).truncate() == \ + (x**(n - 2) - x**n/2 + x**(n + 2)/24 + O(x**(n + 4), x)) + + f = x**(n - 2)*sin(x) + x**n*exp(x) + assert fps(f, x).truncate() == \ + (x**(n - 1) + x**(n + 1) + x**(n + 2)/2 + x**n + + x**(n + 4)/24 + x**(n + 5)/60 + O(x**(n + 6), x)) + + f = x**n*atan(x) + assert fps(f, x, oo).truncate() == \ + (-x**(n - 5)/5 + x**(n - 3)/3 + x**n*(pi/2 - 1/x) + + O((1/x)**(-n)/x**6, (x, oo))) + + f = x**(n/2)*cos(x) + assert fps(f, x).truncate() == \ + x**(n/2) - x**(n/2 + 2)/2 + x**(n/2 + 4)/24 + O(x**(n/2 + 6), x) + + f = x**(n + m)*sin(x) + assert fps(f, x).truncate() == \ + x**(m + n + 1) - x**(m + n + 3)/6 + x**(m + n + 5)/120 + O(x**(m + n + 6), x) + + +def test_fps__slow(): + f = x*exp(x)*sin(2*x) # TODO: rsolve needs improvement + assert fps(f, x).truncate() == 2*x**2 + 2*x**3 - x**4/3 - x**5 + O(x**6) + + +def test_fps__operations(): + f1, f2 = fps(sin(x)), fps(cos(x)) + + fsum = f1 + f2 + assert fsum.function == sin(x) + cos(x) + assert fsum.truncate() == \ + 1 + x - x**2/2 - x**3/6 + x**4/24 + x**5/120 + O(x**6) + + fsum = f1 + 1 + assert fsum.function == sin(x) + 1 + assert fsum.truncate() == 1 + x - x**3/6 + x**5/120 + O(x**6) + + fsum = 1 + f2 + assert fsum.function == cos(x) + 1 + assert fsum.truncate() == 2 - x**2/2 + x**4/24 + O(x**6) + + assert (f1 + x) == Add(f1, x) + + assert -f2.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + assert (f1 - f1) is S.Zero + + fsub = f1 - f2 + assert fsub.function == sin(x) - cos(x) + assert fsub.truncate() == \ + -1 + x + x**2/2 - x**3/6 - x**4/24 + x**5/120 + O(x**6) + + fsub = f1 - 1 + assert fsub.function == sin(x) - 1 + assert fsub.truncate() == -1 + x - x**3/6 + x**5/120 + O(x**6) + + fsub = 1 - f2 + assert fsub.function == -cos(x) + 1 + assert fsub.truncate() == x**2/2 - x**4/24 + O(x**6) + + raises(ValueError, lambda: f1 + fps(exp(x), dir=-1)) + raises(ValueError, lambda: f1 + fps(exp(x), x0=1)) + + fm = f1 * 3 + + assert fm.function == 3*sin(x) + assert fm.truncate() == 3*x - x**3/2 + x**5/40 + O(x**6) + + fm = 3 * f2 + + assert fm.function == 3*cos(x) + assert fm.truncate() == 3 - 3*x**2/2 + x**4/8 + O(x**6) + + assert (f1 * f2) == Mul(f1, f2) + assert (f1 * x) == Mul(f1, x) + + fd = f1.diff() + assert fd.function == cos(x) + assert fd.truncate() == 1 - x**2/2 + x**4/24 + O(x**6) + + fd = f2.diff() + assert fd.function == -sin(x) + assert fd.truncate() == -x + x**3/6 - x**5/120 + O(x**6) + + fd = f2.diff().diff() + assert fd.function == -cos(x) + assert fd.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + + f3 = fps(exp(sqrt(x))) + fd = f3.diff() + assert fd.truncate().expand() == \ + (1/(2*sqrt(x)) + S.Half + x/12 + x**2/240 + x**3/10080 + x**4/725760 + + x**5/79833600 + sqrt(x)/4 + x**Rational(3, 2)/48 + x**Rational(5, 2)/1440 + + x**Rational(7, 2)/80640 + x**Rational(9, 2)/7257600 + x**Rational(11, 2)/958003200 + + O(x**6)) + + assert f1.integrate((x, 0, 1)) == -cos(1) + 1 + assert integrate(f1, (x, 0, 1)) == -cos(1) + 1 + + fi = integrate(f1, x) + assert fi.function == -cos(x) + assert fi.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + + fi = f2.integrate(x) + assert fi.function == sin(x) + assert fi.truncate() == x - x**3/6 + x**5/120 + O(x**6) + +def test_fps__product(): + f1, f2, f3 = fps(sin(x)), fps(exp(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.product(exp(x), x)) + raises(ValueError, lambda: f1.product(fps(exp(x), dir=-1), x, 4)) + raises(ValueError, lambda: f1.product(fps(exp(x), x0=1), x, 4)) + raises(ValueError, lambda: f1.product(fps(exp(y)), x, 4)) + + fprod = f1.product(f2, x) + assert isinstance(fprod, FormalPowerSeriesProduct) + assert isinstance(fprod.ffps, FormalPowerSeries) + assert isinstance(fprod.gfps, FormalPowerSeries) + assert fprod.f == sin(x) + assert fprod.g == exp(x) + assert fprod.function == sin(x) * exp(x) + assert fprod._eval_terms(4) == x + x**2 + x**3/3 + assert fprod.truncate(4) == x + x**2 + x**3/3 + O(x**4) + assert fprod.polynomial(4) == x + x**2 + x**3/3 + + raises(NotImplementedError, lambda: fprod._eval_term(5)) + raises(NotImplementedError, lambda: fprod.infinite) + raises(NotImplementedError, lambda: fprod._eval_derivative(x)) + raises(NotImplementedError, lambda: fprod.integrate(x)) + + assert f1.product(f3, x)._eval_terms(4) == x - 2*x**3/3 + assert f1.product(f3, x).truncate(4) == x - 2*x**3/3 + O(x**4) + + +def test_fps__compose(): + f1, f2, f3 = fps(exp(x)), fps(sin(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.compose(sin(x), x)) + raises(ValueError, lambda: f1.compose(fps(sin(x), dir=-1), x, 4)) + raises(ValueError, lambda: f1.compose(fps(sin(x), x0=1), x, 4)) + raises(ValueError, lambda: f1.compose(fps(sin(y)), x, 4)) + + raises(ValueError, lambda: f1.compose(f3, x)) + raises(ValueError, lambda: f2.compose(f3, x)) + + fcomp = f1.compose(f2, x) + assert isinstance(fcomp, FormalPowerSeriesCompose) + assert isinstance(fcomp.ffps, FormalPowerSeries) + assert isinstance(fcomp.gfps, FormalPowerSeries) + assert fcomp.f == exp(x) + assert fcomp.g == sin(x) + assert fcomp.function == exp(sin(x)) + assert fcomp._eval_terms(6) == 1 + x + x**2/2 - x**4/8 - x**5/15 + assert fcomp.truncate() == 1 + x + x**2/2 - x**4/8 - x**5/15 + O(x**6) + assert fcomp.truncate(5) == 1 + x + x**2/2 - x**4/8 + O(x**5) + + raises(NotImplementedError, lambda: fcomp._eval_term(5)) + raises(NotImplementedError, lambda: fcomp.infinite) + raises(NotImplementedError, lambda: fcomp._eval_derivative(x)) + raises(NotImplementedError, lambda: fcomp.integrate(x)) + + assert f1.compose(f2, x).truncate(4) == 1 + x + x**2/2 + O(x**4) + assert f1.compose(f2, x).truncate(8) == \ + 1 + x + x**2/2 - x**4/8 - x**5/15 - x**6/240 + x**7/90 + O(x**8) + assert f1.compose(f2, x).truncate(6) == \ + 1 + x + x**2/2 - x**4/8 - x**5/15 + O(x**6) + + assert f2.compose(f2, x).truncate(4) == x - x**3/3 + O(x**4) + assert f2.compose(f2, x).truncate(8) == x - x**3/3 + x**5/10 - 8*x**7/315 + O(x**8) + assert f2.compose(f2, x).truncate(6) == x - x**3/3 + x**5/10 + O(x**6) + + +def test_fps__inverse(): + f1, f2, f3 = fps(sin(x)), fps(exp(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.inverse(x)) + + finv = f2.inverse(x) + assert isinstance(finv, FormalPowerSeriesInverse) + assert isinstance(finv.ffps, FormalPowerSeries) + raises(ValueError, lambda: finv.gfps) + + assert finv.f == exp(x) + assert finv.function == exp(-x) + assert finv._eval_terms(5) == 1 - x + x**2/2 - x**3/6 + x**4/24 + assert finv.truncate() == 1 - x + x**2/2 - x**3/6 + x**4/24 - x**5/120 + O(x**6) + assert finv.truncate(5) == 1 - x + x**2/2 - x**3/6 + x**4/24 + O(x**5) + + raises(NotImplementedError, lambda: finv._eval_term(5)) + raises(ValueError, lambda: finv.g) + raises(NotImplementedError, lambda: finv.infinite) + raises(NotImplementedError, lambda: finv._eval_derivative(x)) + raises(NotImplementedError, lambda: finv.integrate(x)) + + assert f2.inverse(x).truncate(8) == \ + 1 - x + x**2/2 - x**3/6 + x**4/24 - x**5/120 + x**6/720 - x**7/5040 + O(x**8) + + assert f3.inverse(x).truncate() == 1 + x**2/2 + 5*x**4/24 + O(x**6) + assert f3.inverse(x).truncate(8) == 1 + x**2/2 + 5*x**4/24 + 61*x**6/720 + O(x**8) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_fourier.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..994f182088b09b038e0e1b3885fec1c27f69f2b0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_fourier.py @@ -0,0 +1,165 @@ +from sympy.core.add import Add +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin, sinc, tan) +from sympy.series.fourier import fourier_series +from sympy.series.fourier import FourierSeries +from sympy.testing.pytest import raises +from functools import lru_cache + +x, y, z = symbols('x y z') + +# Don't declare these during import because they are slow +@lru_cache() +def _get_examples(): + fo = fourier_series(x, (x, -pi, pi)) + fe = fourier_series(x**2, (-pi, pi)) + fp = fourier_series(Piecewise((0, x < 0), (pi, True)), (x, -pi, pi)) + return fo, fe, fp + + +def test_FourierSeries(): + fo, fe, fp = _get_examples() + + assert fourier_series(1, (-pi, pi)) == 1 + assert (Piecewise((0, x < 0), (pi, True)). + fourier_series((x, -pi, pi)).truncate()) == fp.truncate() + assert isinstance(fo, FourierSeries) + assert fo.function == x + assert fo.x == x + assert fo.period == (-pi, pi) + + assert fo.term(3) == 2*sin(3*x) / 3 + assert fe.term(3) == -4*cos(3*x) / 9 + assert fp.term(3) == 2*sin(3*x) / 3 + + assert fo.as_leading_term(x) == 2*sin(x) + assert fe.as_leading_term(x) == pi**2 / 3 + assert fp.as_leading_term(x) == pi / 2 + + assert fo.truncate() == 2*sin(x) - sin(2*x) + (2*sin(3*x) / 3) + assert fe.truncate() == -4*cos(x) + cos(2*x) + pi**2 / 3 + assert fp.truncate() == 2*sin(x) + (2*sin(3*x) / 3) + pi / 2 + + fot = fo.truncate(n=None) + s = [0, 2*sin(x), -sin(2*x)] + for i, t in enumerate(fot): + if i == 3: + break + assert s[i] == t + + def _check_iter(f, i): + for ind, t in enumerate(f): + assert t == f[ind] # noqa: PLR1736 + if ind == i: + break + + _check_iter(fo, 3) + _check_iter(fe, 3) + _check_iter(fp, 3) + + assert fo.subs(x, x**2) == fo + + raises(ValueError, lambda: fourier_series(x, (0, 1, 2))) + raises(ValueError, lambda: fourier_series(x, (x, 0, oo))) + raises(ValueError, lambda: fourier_series(x*y, (0, oo))) + + +def test_FourierSeries_2(): + p = Piecewise((0, x < 0), (x, True)) + f = fourier_series(p, (x, -2, 2)) + + assert f.term(3) == (2*sin(3*pi*x / 2) / (3*pi) - + 4*cos(3*pi*x / 2) / (9*pi**2)) + assert f.truncate() == (2*sin(pi*x / 2) / pi - sin(pi*x) / pi - + 4*cos(pi*x / 2) / pi**2 + S.Half) + + +def test_square_wave(): + """Test if fourier_series approximates discontinuous function correctly.""" + square_wave = Piecewise((1, x < pi), (-1, True)) + s = fourier_series(square_wave, (x, 0, 2*pi)) + + assert s.truncate(3) == 4 / pi * sin(x) + 4 / (3 * pi) * sin(3 * x) + \ + 4 / (5 * pi) * sin(5 * x) + assert s.sigma_approximation(4) == 4 / pi * sin(x) * sinc(pi / 4) + \ + 4 / (3 * pi) * sin(3 * x) * sinc(3 * pi / 4) + + +def test_sawtooth_wave(): + s = fourier_series(x, (x, 0, pi)) + assert s.truncate(4) == \ + pi/2 - sin(2*x) - sin(4*x)/2 - sin(6*x)/3 + s = fourier_series(x, (x, 0, 1)) + assert s.truncate(4) == \ + S.Half - sin(2*pi*x)/pi - sin(4*pi*x)/(2*pi) - sin(6*pi*x)/(3*pi) + + +def test_FourierSeries__operations(): + fo, fe, fp = _get_examples() + + fes = fe.scale(-1).shift(pi**2) + assert fes.truncate() == 4*cos(x) - cos(2*x) + 2*pi**2 / 3 + + assert fp.shift(-pi/2).truncate() == (2*sin(x) + (2*sin(3*x) / 3) + + (2*sin(5*x) / 5)) + + fos = fo.scale(3) + assert fos.truncate() == 6*sin(x) - 3*sin(2*x) + 2*sin(3*x) + + fx = fe.scalex(2).shiftx(1) + assert fx.truncate() == -4*cos(2*x + 2) + cos(4*x + 4) + pi**2 / 3 + + fl = fe.scalex(3).shift(-pi).scalex(2).shiftx(1).scale(4) + assert fl.truncate() == (-16*cos(6*x + 6) + 4*cos(12*x + 12) - + 4*pi + 4*pi**2 / 3) + + raises(ValueError, lambda: fo.shift(x)) + raises(ValueError, lambda: fo.shiftx(sin(x))) + raises(ValueError, lambda: fo.scale(x*y)) + raises(ValueError, lambda: fo.scalex(x**2)) + + +def test_FourierSeries__neg(): + fo, fe, fp = _get_examples() + + assert (-fo).truncate() == -2*sin(x) + sin(2*x) - (2*sin(3*x) / 3) + assert (-fe).truncate() == +4*cos(x) - cos(2*x) - pi**2 / 3 + + +def test_FourierSeries__add__sub(): + fo, fe, fp = _get_examples() + + assert fo + fo == fo.scale(2) + assert fo - fo == 0 + assert -fe - fe == fe.scale(-2) + + assert (fo + fe).truncate() == 2*sin(x) - sin(2*x) - 4*cos(x) + cos(2*x) \ + + pi**2 / 3 + assert (fo - fe).truncate() == 2*sin(x) - sin(2*x) + 4*cos(x) - cos(2*x) \ + - pi**2 / 3 + + assert isinstance(fo + 1, Add) + + raises(ValueError, lambda: fo + fourier_series(x, (x, 0, 2))) + + +def test_FourierSeries_finite(): + + assert fourier_series(sin(x)).truncate(1) == sin(x) + # assert type(fourier_series(sin(x)*log(x))).truncate() == FourierSeries + # assert type(fourier_series(sin(x**2+6))).truncate() == FourierSeries + assert fourier_series(sin(x)*log(y)*exp(z),(x,pi,-pi)).truncate() == sin(x)*log(y)*exp(z) + assert fourier_series(sin(x)**6).truncate(oo) == -15*cos(2*x)/32 + 3*cos(4*x)/16 - cos(6*x)/32 \ + + Rational(5, 16) + assert fourier_series(sin(x) ** 6).truncate() == -15 * cos(2 * x) / 32 + 3 * cos(4 * x) / 16 \ + + Rational(5, 16) + assert fourier_series(sin(4*x+3) + cos(3*x+4)).truncate(oo) == -sin(4)*sin(3*x) + sin(4*x)*cos(3) \ + + cos(4)*cos(3*x) + sin(3)*cos(4*x) + assert fourier_series(sin(x)+cos(x)*tan(x)).truncate(oo) == 2*sin(x) + assert fourier_series(cos(pi*x), (x, -1, 1)).truncate(oo) == cos(pi*x) + assert fourier_series(cos(3*pi*x + 4) - sin(4*pi*x)*log(pi*y), (x, -1, 1)).truncate(oo) == -log(pi*y)*sin(4*pi*x)\ + - sin(4)*sin(3*pi*x) + cos(4)*cos(3*pi*x) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_gruntz.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_gruntz.py new file mode 100644 index 0000000000000000000000000000000000000000..4cae15297048bc52a69a3d9ca57a7614cfcdc61c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_gruntz.py @@ -0,0 +1,490 @@ +from sympy.core import EulerGamma +from sympy.core.function import Function +from sympy.core.numbers import (E, I, Integer, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acot, atan, cos, sin) +from sympy.functions.special.error_functions import (Ei, erf) +from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.polys.polytools import cancel +from sympy.functions.elementary.hyperbolic import cosh, coth, sinh, tanh +from sympy.series.gruntz import compare, mrv, rewrite, mrv_leadterm, gruntz, \ + sign +from sympy.testing.pytest import XFAIL, raises, skip, slow + +""" +This test suite is testing the limit algorithm using the bottom up approach. +See the documentation in limits2.py. The algorithm itself is highly recursive +by nature, so "compare" is logically the lowest part of the algorithm, yet in +some sense it's the most complex part, because it needs to calculate a limit +to return the result. + +Nevertheless, the rest of the algorithm depends on compare working correctly. +""" + +x = Symbol('x', real=True) +m = Symbol('m', real=True) + + +runslow = False + + +def _sskip(): + if not runslow: + skip("slow") + + +@slow +def test_gruntz_evaluation(): + # Gruntz' thesis pp. 122 to 123 + # 8.1 + assert gruntz(exp(x)*(exp(1/x - exp(-x)) - exp(1/x)), x, oo) == -1 + # 8.2 + assert gruntz(exp(x)*(exp(1/x + exp(-x) + exp(-x**2)) + - exp(1/x - exp(-exp(x)))), x, oo) == 1 + # 8.3 + assert gruntz(exp(exp(x - exp(-x))/(1 - 1/x)) - exp(exp(x)), x, oo) is oo + # 8.5 + assert gruntz(exp(exp(exp(x + exp(-x)))) / exp(exp(exp(x))), x, oo) is oo + # 8.6 + assert gruntz(exp(exp(exp(x))) / exp(exp(exp(x - exp(-exp(x))))), + x, oo) is oo + # 8.7 + assert gruntz(exp(exp(exp(x))) / exp(exp(exp(x - exp(-exp(exp(x)))))), + x, oo) == 1 + # 8.8 + assert gruntz(exp(exp(x)) / exp(exp(x - exp(-exp(exp(x))))), x, oo) == 1 + # 8.9 + assert gruntz(log(x)**2 * exp(sqrt(log(x))*(log(log(x)))**2 + * exp(sqrt(log(log(x))) * (log(log(log(x))))**3)) / sqrt(x), + x, oo) == 0 + # 8.10 + assert gruntz((x*log(x)*(log(x*exp(x) - x**2))**2) + / (log(log(x**2 + 2*exp(exp(3*x**3*log(x)))))), x, oo) == Rational(1, 3) + # 8.11 + assert gruntz((exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1)))) - exp(x))/x, + x, oo) == -exp(2) + # 8.12 + assert gruntz((3**x + 5**x)**(1/x), x, oo) == 5 + # 8.13 + assert gruntz(x/log(x**(log(x**(log(2)/log(x))))), x, oo) is oo + # 8.14 + assert gruntz(exp(exp(2*log(x**5 + x)*log(log(x)))) + / exp(exp(10*log(x)*log(log(x)))), x, oo) is oo + # 8.15 + assert gruntz(exp(exp(Rational(5, 2)*x**Rational(-5, 7) + Rational(21, 8)*x**Rational(6, 11) + + 2*x**(-8) + Rational(54, 17)*x**Rational(49, 45)))**8 + / log(log(-log(Rational(4, 3)*x**Rational(-5, 14))))**Rational(7, 6), x, oo) is oo + # 8.16 + assert gruntz((exp(4*x*exp(-x)/(1/exp(x) + 1/exp(2*x**2/(x + 1)))) - exp(x)) + / exp(x)**4, x, oo) == 1 + # 8.17 + assert gruntz(exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1))))/exp(x), x, oo) \ + == 1 + # 8.19 + assert gruntz(log(x)*(log(log(x) + log(log(x))) - log(log(x))) + / (log(log(x) + log(log(log(x))))), x, oo) == 1 + # 8.20 + assert gruntz(exp((log(log(x + exp(log(x)*log(log(x)))))) + / (log(log(log(exp(x) + x + log(x)))))), x, oo) == E + # Another + assert gruntz(exp(exp(exp(x + exp(-x)))) / exp(exp(x)), x, oo) is oo + + +def test_gruntz_evaluation_slow(): + _sskip() + # 8.4 + assert gruntz(exp(exp(exp(x)/(1 - 1/x))) + - exp(exp(exp(x)/(1 - 1/x - log(x)**(-log(x))))), x, oo) is -oo + # 8.18 + assert gruntz((exp(exp(-x/(1 + exp(-x))))*exp(-x/(1 + exp(-x/(1 + exp(-x))))) + *exp(exp(-x + exp(-x/(1 + exp(-x)))))) + / (exp(-x/(1 + exp(-x))))**2 - exp(x) + x, x, oo) == 2 + + +@slow +def test_gruntz_eval_special(): + # Gruntz, p. 126 + assert gruntz(exp(x)*(sin(1/x + exp(-x)) - sin(1/x + exp(-x**2))), x, oo) == 1 + assert gruntz((erf(x - exp(-exp(x))) - erf(x)) * exp(exp(x)) * exp(x**2), + x, oo) == -2/sqrt(pi) + assert gruntz(exp(exp(x)) * (exp(sin(1/x + exp(-exp(x)))) - exp(sin(1/x))), + x, oo) == 1 + assert gruntz(exp(x)*(gamma(x + exp(-x)) - gamma(x)), x, oo) is oo + assert gruntz(exp(exp(digamma(digamma(x))))/x, x, oo) == exp(Rational(-1, 2)) + assert gruntz(exp(exp(digamma(log(x))))/x, x, oo) == exp(Rational(-1, 2)) + assert gruntz(digamma(digamma(digamma(x))), x, oo) is oo + assert gruntz(loggamma(loggamma(x)), x, oo) is oo + assert gruntz(((gamma(x + 1/gamma(x)) - gamma(x))/log(x) - cos(1/x)) + * x*log(x), x, oo) == Rational(-1, 2) + assert gruntz(x * (gamma(x - 1/gamma(x)) - gamma(x) + log(x)), x, oo) \ + == S.Half + assert gruntz((gamma(x + 1/gamma(x)) - gamma(x)) / log(x), x, oo) == 1 + + +def test_gruntz_eval_special_slow(): + _sskip() + assert gruntz(gamma(x + 1)/sqrt(2*pi) + - exp(-x)*(x**(x + S.Half) + x**(x - S.Half)/12), x, oo) is oo + assert gruntz(exp(exp(exp(digamma(digamma(digamma(x))))))/x, x, oo) == 0 + + +@XFAIL +def test_grunts_eval_special_slow_sometimes_fail(): + _sskip() + # XXX This sometimes fails!!! + assert gruntz(exp(gamma(x - exp(-x))*exp(1/x)) - exp(gamma(x)), x, oo) is oo + + +def test_gruntz_Ei(): + assert gruntz((Ei(x - exp(-exp(x))) - Ei(x)) *exp(-x)*exp(exp(x))*x, x, oo) == -1 + + +@XFAIL +def test_gruntz_eval_special_fail(): + # TODO zeta function series + assert gruntz( + exp((log(2) + 1)*x) * (zeta(x + exp(-x)) - zeta(x)), x, oo) == -log(2) + + # TODO 8.35 - 8.37 (bessel, max-min) + + +def test_gruntz_hyperbolic(): + assert gruntz(cosh(x), x, oo) is oo + assert gruntz(cosh(x), x, -oo) is oo + assert gruntz(sinh(x), x, oo) is oo + assert gruntz(sinh(x), x, -oo) is -oo + assert gruntz(2*cosh(x)*exp(x), x, oo) is oo + assert gruntz(2*cosh(x)*exp(x), x, -oo) == 1 + assert gruntz(2*sinh(x)*exp(x), x, oo) is oo + assert gruntz(2*sinh(x)*exp(x), x, -oo) == -1 + assert gruntz(tanh(x), x, oo) == 1 + assert gruntz(tanh(x), x, -oo) == -1 + assert gruntz(coth(x), x, oo) == 1 + assert gruntz(coth(x), x, -oo) == -1 + + +def test_compare1(): + assert compare(2, x, x) == "<" + assert compare(x, exp(x), x) == "<" + assert compare(exp(x), exp(x**2), x) == "<" + assert compare(exp(x**2), exp(exp(x)), x) == "<" + assert compare(1, exp(exp(x)), x) == "<" + + assert compare(x, 2, x) == ">" + assert compare(exp(x), x, x) == ">" + assert compare(exp(x**2), exp(x), x) == ">" + assert compare(exp(exp(x)), exp(x**2), x) == ">" + assert compare(exp(exp(x)), 1, x) == ">" + + assert compare(2, 3, x) == "=" + assert compare(3, -5, x) == "=" + assert compare(2, -5, x) == "=" + + assert compare(x, x**2, x) == "=" + assert compare(x**2, x**3, x) == "=" + assert compare(x**3, 1/x, x) == "=" + assert compare(1/x, x**m, x) == "=" + assert compare(x**m, -x, x) == "=" + + assert compare(exp(x), exp(-x), x) == "=" + assert compare(exp(-x), exp(2*x), x) == "=" + assert compare(exp(2*x), exp(x)**2, x) == "=" + assert compare(exp(x)**2, exp(x + exp(-x)), x) == "=" + assert compare(exp(x), exp(x + exp(-x)), x) == "=" + + assert compare(exp(x**2), 1/exp(x**2), x) == "=" + + +def test_compare2(): + assert compare(exp(x), x**5, x) == ">" + assert compare(exp(x**2), exp(x)**2, x) == ">" + assert compare(exp(x), exp(x + exp(-x)), x) == "=" + assert compare(exp(x + exp(-x)), exp(x), x) == "=" + assert compare(exp(x + exp(-x)), exp(-x), x) == "=" + assert compare(exp(-x), x, x) == ">" + assert compare(x, exp(-x), x) == "<" + assert compare(exp(x + 1/x), x, x) == ">" + assert compare(exp(-exp(x)), exp(x), x) == ">" + assert compare(exp(exp(-exp(x)) + x), exp(-exp(x)), x) == "<" + + +def test_compare3(): + assert compare(exp(exp(x)), exp(x + exp(-exp(x))), x) == ">" + + +def test_sign1(): + assert sign(Rational(0), x) == 0 + assert sign(Rational(3), x) == 1 + assert sign(Rational(-5), x) == -1 + assert sign(log(x), x) == 1 + assert sign(exp(-x), x) == 1 + assert sign(exp(x), x) == 1 + assert sign(-exp(x), x) == -1 + assert sign(3 - 1/x, x) == 1 + assert sign(-3 - 1/x, x) == -1 + assert sign(sin(1/x), x) == 1 + assert sign((x**Integer(2)), x) == 1 + assert sign(x**2, x) == 1 + assert sign(x**5, x) == 1 + + +def test_sign2(): + assert sign(x, x) == 1 + assert sign(-x, x) == -1 + y = Symbol("y", positive=True) + assert sign(y, x) == 1 + assert sign(-y, x) == -1 + assert sign(y*x, x) == 1 + assert sign(-y*x, x) == -1 + + +def mmrv(a, b): + return set(mrv(a, b)[0].keys()) + + +def test_mrv1(): + assert mmrv(x, x) == {x} + assert mmrv(x + 1/x, x) == {x} + assert mmrv(x**2, x) == {x} + assert mmrv(log(x), x) == {x} + assert mmrv(exp(x), x) == {exp(x)} + assert mmrv(exp(-x), x) == {exp(-x)} + assert mmrv(exp(x**2), x) == {exp(x**2)} + assert mmrv(-exp(1/x), x) == {x} + assert mmrv(exp(x + 1/x), x) == {exp(x + 1/x)} + + +def test_mrv2a(): + assert mmrv(exp(x + exp(-exp(x))), x) == {exp(-exp(x))} + assert mmrv(exp(x + exp(-x)), x) == {exp(x + exp(-x)), exp(-x)} + assert mmrv(exp(1/x + exp(-x)), x) == {exp(-x)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv2b(): + assert mmrv(exp(x + exp(-x**2)), x) == {exp(-x**2)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv2c(): + assert mmrv( + exp(-x + 1/x**2) - exp(x + 1/x), x) == {exp(x + 1/x), exp(1/x**2 - x)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv3(): + assert mmrv(exp(x**2) + x*exp(x) + log(x)**x/x, x) == {exp(x**2)} + assert mmrv( + exp(x)*(exp(1/x + exp(-x)) - exp(1/x)), x) == {exp(x), exp(-x)} + assert mmrv(log( + x**2 + 2*exp(exp(3*x**3*log(x)))), x) == {exp(exp(3*x**3*log(x)))} + assert mmrv(log(x - log(x))/log(x), x) == {x} + assert mmrv( + (exp(1/x - exp(-x)) - exp(1/x))*exp(x), x) == {exp(x), exp(-x)} + assert mmrv( + 1/exp(-x + exp(-x)) - exp(x), x) == {exp(x), exp(-x), exp(x - exp(-x))} + assert mmrv(log(log(x*exp(x*exp(x)) + 1)), x) == {exp(x*exp(x))} + assert mmrv(exp(exp(log(log(x) + 1/x))), x) == {x} + + +def test_mrv4(): + ln = log + assert mmrv((ln(ln(x) + ln(ln(x))) - ln(ln(x)))/ln(ln(x) + ln(ln(ln(x))))*ln(x), + x) == {x} + assert mmrv(log(log(x*exp(x*exp(x)) + 1)) - exp(exp(log(log(x) + 1/x))), x) == \ + {exp(x*exp(x))} + + +def mrewrite(a, b, c): + return rewrite(a[1], a[0], b, c) + + +def test_rewrite1(): + e = exp(x) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x) + e = exp(x**2) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x**2) + e = exp(x + 1/x) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x - 1/x) + e = 1/exp(-x + exp(-x)) - exp(x) + assert mrewrite(mrv(e, x), x, m) == ((-m*exp(m) + m)*exp(-m)/m**2, -x) + + +def test_rewrite2(): + e = exp(x)*log(log(exp(x))) + assert mmrv(e, x) == {exp(x)} + assert mrewrite(mrv(e, x), x, m) == (1/m*log(x), -x) + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_rewrite3(): + e = exp(-x + 1/x**2) - exp(x + 1/x) + #both of these are correct and should be equivalent: + assert mrewrite(mrv(e, x), x, m) in [(-1/m + m*exp( + (x**2 + x)/x**3), -x - 1/x), ((m**2 - exp((x**2 + x)/x**3))/m, x**(-2) - x)] + + +def test_mrv_leadterm1(): + assert mrv_leadterm(-exp(1/x), x) == (-1, 0) + assert mrv_leadterm(1/exp(-x + exp(-x)) - exp(x), x) == (-1, 0) + assert mrv_leadterm( + (exp(1/x - exp(-x)) - exp(1/x))*exp(x), x) == (-exp(1/x), 0) + + +def test_mrv_leadterm2(): + #Gruntz: p51, 3.25 + assert mrv_leadterm((log(exp(x) + x) - x)/log(exp(x) + log(x))*exp(x), x) == \ + (1, 0) + + +def test_mrv_leadterm3(): + #Gruntz: p56, 3.27 + assert mmrv(exp(-x + exp(-x)*exp(-x*log(x))), x) == {exp(-x - x*log(x))} + assert mrv_leadterm(exp(-x + exp(-x)*exp(-x*log(x))), x) == (exp(-x), 0) + + +def test_limit1(): + assert gruntz(x, x, oo) is oo + assert gruntz(x, x, -oo) is -oo + assert gruntz(-x, x, oo) is -oo + assert gruntz(x**2, x, -oo) is oo + assert gruntz(-x**2, x, oo) is -oo + assert gruntz(x*log(x), x, 0, dir="+") == 0 + assert gruntz(1/x, x, oo) == 0 + assert gruntz(exp(x), x, oo) is oo + assert gruntz(-exp(x), x, oo) is -oo + assert gruntz(exp(x)/x, x, oo) is oo + assert gruntz(1/x - exp(-x), x, oo) == 0 + assert gruntz(x + 1/x, x, oo) is oo + + +def test_limit2(): + assert gruntz(x**x, x, 0, dir="+") == 1 + assert gruntz((exp(x) - 1)/x, x, 0) == 1 + assert gruntz(1 + 1/x, x, oo) == 1 + assert gruntz(-exp(1/x), x, oo) == -1 + assert gruntz(x + exp(-x), x, oo) is oo + assert gruntz(x + exp(-x**2), x, oo) is oo + assert gruntz(x + exp(-exp(x)), x, oo) is oo + assert gruntz(13 + 1/x - exp(-x), x, oo) == 13 + + +def test_limit3(): + a = Symbol('a') + assert gruntz(x - log(1 + exp(x)), x, oo) == 0 + assert gruntz(x - log(a + exp(x)), x, oo) == 0 + assert gruntz(exp(x)/(1 + exp(x)), x, oo) == 1 + assert gruntz(exp(x)/(a + exp(x)), x, oo) == 1 + + +def test_limit4(): + #issue 3463 + assert gruntz((3**x + 5**x)**(1/x), x, oo) == 5 + #issue 3463 + assert gruntz((3**(1/x) + 5**(1/x))**x, x, 0) == 5 + + +@XFAIL +def test_MrvTestCase_page47_ex3_21(): + h = exp(-x/(1 + exp(-x))) + expr = exp(h)*exp(-x/(1 + h))*exp(exp(-x + h))/h**2 - exp(x) + x + assert mmrv(expr, x) == {1/h, exp(-x), exp(x), exp(x - h), exp(x/(1 + h))} + + +def test_gruntz_I(): + y = Symbol("y") + assert gruntz(I*x, x, oo) == I*oo + assert gruntz(y*I*x, x, oo) == y*I*oo + assert gruntz(y*3*I*x, x, oo) == y*I*oo + assert gruntz(y*3*sin(I)*x, x, oo) == y*I*oo + + +def test_issue_4814(): + assert gruntz((x + 1)**(1/log(x + 1)), x, oo) == E + + +def test_intractable(): + assert gruntz(1/gamma(x), x, oo) == 0 + assert gruntz(1/loggamma(x), x, oo) == 0 + assert gruntz(gamma(x)/loggamma(x), x, oo) is oo + assert gruntz(exp(gamma(x))/gamma(x), x, oo) is oo + assert gruntz(gamma(x), x, 3) == 2 + assert gruntz(gamma(Rational(1, 7) + 1/x), x, oo) == gamma(Rational(1, 7)) + assert gruntz(log(x**x)/log(gamma(x)), x, oo) == 1 + assert gruntz(log(gamma(gamma(x)))/exp(x), x, oo) is oo + + +def test_aseries_trig(): + assert cancel(gruntz(1/log(atan(x)), x, oo) + - 1/(log(pi) + log(S.Half))) == 0 + assert gruntz(1/acot(x), x, -oo) is -oo + + +def test_exp_log_series(): + assert gruntz(x/log(log(x*exp(x))), x, oo) is oo + + +def test_issue_3644(): + assert gruntz(((x**7 + x + 1)/(2**x + x**2))**(-1/x), x, oo) == 2 + + +def test_issue_6843(): + n = Symbol('n', integer=True, positive=True) + r = (n + 1)*x**(n + 1)/(x**(n + 1) - 1) - x/(x - 1) + assert gruntz(r, x, 1).simplify() == n/2 + + +def test_issue_4190(): + assert gruntz(x - gamma(1/x), x, oo) == S.EulerGamma + + +@XFAIL +def test_issue_5172(): + n = Symbol('n') + r = Symbol('r', positive=True) + c = Symbol('c') + p = Symbol('p', positive=True) + m = Symbol('m', negative=True) + expr = ((2*n*(n - r + 1)/(n + r*(n - r + 1)))**c + \ + (r - 1)*(n*(n - r + 2)/(n + r*(n - r + 1)))**c - n)/(n**c - n) + expr = expr.subs(c, c + 1) + assert gruntz(expr.subs(c, m), n, oo) == 1 + # fail: + assert gruntz(expr.subs(c, p), n, oo).simplify() == \ + (2**(p + 1) + r - 1)/(r + 1)**(p + 1) + + +def test_issue_4109(): + assert gruntz(1/gamma(x), x, 0) == 0 + assert gruntz(x*gamma(x), x, 0) == 1 + + +def test_issue_6682(): + assert gruntz(exp(2*Ei(-x))/x**2, x, 0) == exp(2*EulerGamma) + + +def test_issue_7096(): + from sympy.functions import sign + assert gruntz(x**-pi, x, 0, dir='-') == oo*sign((-1)**(-pi)) + + +def test_issue_7391_8166(): + f = Function('f') + # limit should depend on the continuity of the expression at the point passed + raises(ValueError, lambda: gruntz(f(x), x, 4)) + raises(ValueError, lambda: gruntz(x*f(x)**2/(x**2 + f(x)**4), x, 0)) + + +def test_issue_24210_25885(): + eq = exp(x)/(1+1/x)**x**2 + ans = sqrt(E) + assert gruntz(eq, x, oo) == ans + assert gruntz(1/eq, x, oo) == 1/ans diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_kauers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_kauers.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb9044b33416bc38879649b258150ba2906250c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_kauers.py @@ -0,0 +1,23 @@ +from sympy.series.kauers import finite_diff +from sympy.series.kauers import finite_diff_kauers +from sympy.abc import x, y, z, m, n, w +from sympy.core.numbers import pi +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.concrete.summations import Sum + + +def test_finite_diff(): + assert finite_diff(x**2 + 2*x + 1, x) == 2*x + 3 + assert finite_diff(y**3 + 2*y**2 + 3*y + 5, y) == 3*y**2 + 7*y + 6 + assert finite_diff(z**2 - 2*z + 3, z) == 2*z - 1 + assert finite_diff(w**2 + 3*w - 2, w) == 2*w + 4 + assert finite_diff(sin(x), x, pi/6) == -sin(x) + sin(x + pi/6) + assert finite_diff(cos(y), y, pi/3) == -cos(y) + cos(y + pi/3) + assert finite_diff(x**2 - 2*x + 3, x, 2) == 4*x + assert finite_diff(n**2 - 2*n + 3, n, 3) == 6*n + 3 + +def test_finite_diff_kauers(): + assert finite_diff_kauers(Sum(x**2, (x, 1, n))) == (n + 1)**2 + assert finite_diff_kauers(Sum(y, (y, 1, m))) == (m + 1) + assert finite_diff_kauers(Sum((x*y), (x, 1, m), (y, 1, n))) == (m + 1)*(n + 1) + assert finite_diff_kauers(Sum((x*y**2), (x, 1, m), (y, 1, n))) == (n + 1)**2*(m + 1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_limits.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_limits.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3ab7683f057424f1c3215a06381d27687710dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_limits.py @@ -0,0 +1,1440 @@ +from itertools import product + +from sympy.concrete.summations import Sum +from sympy.core.function import (Function, diff) +from sympy.core import EulerGamma, GoldenRatio +from sympy.core.mod import Mod +from sympy.core.numbers import (E, I, Rational, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.numbers import fibonacci +from sympy.functions.combinatorial.factorials import (binomial, factorial, subfactorial) +from sympy.functions.elementary.complexes import (Abs, re, sign) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (atanh, asinh, acosh, acoth, acsch, asech, tanh, sinh) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (cbrt, real_root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, acot, acsc, asec, asin, + atan, cos, cot, csc, sec, sin, tan) +from sympy.functions.special.bessel import (besseli, bessely, besselj, besselk) +from sympy.functions.special.error_functions import (Ei, erf, erfc, erfi, fresnelc, fresnels) +from sympy.functions.special.gamma_functions import (digamma, gamma, uppergamma) +from sympy.functions.special.hyper import meijerg +from sympy.integrals.integrals import (Integral, integrate) +from sympy.series.limits import (Limit, limit) +from sympy.simplify.simplify import (logcombine, simplify) +from sympy.simplify.hyperexpand import hyperexpand + +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.core.mul import Mul +from sympy.series.limits import heuristics +from sympy.series.order import Order +from sympy.testing.pytest import XFAIL, raises + +from sympy import elliptic_e, elliptic_k + +from sympy.abc import x, y, z, k +n = Symbol('n', integer=True, positive=True) + + +def test_basic1(): + assert limit(x, x, oo) is oo + assert limit(x, x, -oo) is -oo + assert limit(-x, x, oo) is -oo + assert limit(x**2, x, -oo) is oo + assert limit(-x**2, x, oo) is -oo + assert limit(x*log(x), x, 0, dir="+") == 0 + assert limit(1/x, x, oo) == 0 + assert limit(exp(x), x, oo) is oo + assert limit(-exp(x), x, oo) is -oo + assert limit(exp(x)/x, x, oo) is oo + assert limit(1/x - exp(-x), x, oo) == 0 + assert limit(x + 1/x, x, oo) is oo + assert limit(x - x**2, x, oo) is -oo + assert limit((1 + x)**(1 + sqrt(2)), x, 0) == 1 + assert limit((1 + x)**oo, x, 0) == Limit((x + 1)**oo, x, 0) + assert limit((1 + x)**oo, x, 0, dir='-') == Limit((x + 1)**oo, x, 0, dir='-') + assert limit((1 + x + y)**oo, x, 0, dir='-') == Limit((1 + x + y)**oo, x, 0, dir='-') + assert limit(y/x/log(x), x, 0) == -oo*sign(y) + assert limit(cos(x + y)/x, x, 0) == sign(cos(y))*oo + assert limit(gamma(1/x + 3), x, oo) == 2 + assert limit(S.NaN, x, -oo) is S.NaN + assert limit(Order(2)*x, x, S.NaN) is S.NaN + assert limit(1/(x - 1), x, 1, dir="+") is oo + assert limit(1/(x - 1), x, 1, dir="-") is -oo + assert limit(1/(5 - x)**3, x, 5, dir="+") is -oo + assert limit(1/(5 - x)**3, x, 5, dir="-") is oo + assert limit(1/sin(x), x, pi, dir="+") is -oo + assert limit(1/sin(x), x, pi, dir="-") is oo + assert limit(1/cos(x), x, pi/2, dir="+") is -oo + assert limit(1/cos(x), x, pi/2, dir="-") is oo + assert limit(1/tan(x**3), x, (2*pi)**Rational(1, 3), dir="+") is oo + assert limit(1/tan(x**3), x, (2*pi)**Rational(1, 3), dir="-") is -oo + assert limit(1/cot(x)**3, x, (pi*Rational(3, 2)), dir="+") is -oo + assert limit(1/cot(x)**3, x, (pi*Rational(3, 2)), dir="-") is oo + assert limit(tan(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cot(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(sec(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(csc(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + + # test bi-directional limits + assert limit(sin(x)/x, x, 0, dir="+-") == 1 + assert limit(x**2, x, 0, dir="+-") == 0 + assert limit(1/x**2, x, 0, dir="+-") is oo + + # test failing bi-directional limits + assert limit(1/x, x, 0, dir="+-") is zoo + # approaching 0 + # from dir="+" + assert limit(1 + 1/x, x, 0) is oo + # from dir='-' + # Add + assert limit(1 + 1/x, x, 0, dir='-') is -oo + # Pow + assert limit(x**(-2), x, 0, dir='-') is oo + assert limit(x**(-3), x, 0, dir='-') is -oo + assert limit(1/sqrt(x), x, 0, dir='-') == (-oo)*I + assert limit(x**2, x, 0, dir='-') == 0 + assert limit(sqrt(x), x, 0, dir='-') == 0 + assert limit(x**-pi, x, 0, dir='-') == -oo*(-1)**(1 - pi) + assert limit((1 + cos(x))**oo, x, 0) == Limit((cos(x) + 1)**oo, x, 0) + + # test pull request 22491 + assert limit(1/asin(x), x, 0, dir = '+') == oo + assert limit(1/asin(x), x, 0, dir = '-') == -oo + assert limit(1/sinh(x), x, 0, dir = '+') == oo + assert limit(1/sinh(x), x, 0, dir = '-') == -oo + assert limit(log(1/x) + 1/sin(x), x, 0, dir = '+') == oo + assert limit(log(1/x) + 1/x, x, 0, dir = '+') == oo + + +def test_basic2(): + assert limit(x**x, x, 0, dir="+") == 1 + assert limit((exp(x) - 1)/x, x, 0) == 1 + assert limit(1 + 1/x, x, oo) == 1 + assert limit(-exp(1/x), x, oo) == -1 + assert limit(x + exp(-x), x, oo) is oo + assert limit(x + exp(-x**2), x, oo) is oo + assert limit(x + exp(-exp(x)), x, oo) is oo + assert limit(13 + 1/x - exp(-x), x, oo) == 13 + + +def test_basic3(): + assert limit(1/x, x, 0, dir="+") is oo + assert limit(1/x, x, 0, dir="-") is -oo + + +def test_basic4(): + assert limit(2*x + y*x, x, 0) == 0 + assert limit(2*x + y*x, x, 1) == 2 + y + assert limit(2*x**8 + y*x**(-3), x, -2) == 512 - y/8 + assert limit(sqrt(x + 1) - sqrt(x), x, oo) == 0 + assert integrate(1/(x**3 + 1), (x, 0, oo)) == 2*pi*sqrt(3)/9 + + +def test_log(): + # https://github.com/sympy/sympy/issues/21598 + a, b, c = symbols('a b c', positive=True) + A = log(a/b) - (log(a) - log(b)) + assert A.limit(a, oo) == 0 + assert (A * c).limit(a, oo) == 0 + + tau, x = symbols('tau x', positive=True) + # The value of manualintegrate in the issue + expr = tau**2*((tau - 1)*(tau + 1)*log(x + 1)/(tau**2 + 1)**2 + 1/((tau**2\ + + 1)*(x + 1)) - (-2*tau*atan(x/tau) + (tau**2/2 - 1/2)*log(tau**2\ + + x**2))/(tau**2 + 1)**2) + assert limit(expr, x, oo) == pi*tau**3/(tau**2 + 1)**2 + + +def test_piecewise(): + # https://github.com/sympy/sympy/issues/18363 + assert limit((real_root(x - 6, 3) + 2)/(x + 2), x, -2, '+') == Rational(1, 12) + + +def test_piecewise2(): + func1 = 2*sqrt(x)*Piecewise(((4*x - 2)/Abs(sqrt(4 - 4*(2*x - 1)**2)), 4*x - 2\ + >= 0), ((2 - 4*x)/Abs(sqrt(4 - 4*(2*x - 1)**2)), True)) + func2 = Piecewise((x**2/2, x <= 0.5), (x/2 - 0.125, True)) + func3 = Piecewise(((x - 9) / 5, x < -1), ((x - 9) / 5, x > 4), (sqrt(Abs(x - 3)), True)) + assert limit(func1, x, 0) == 1 + assert limit(func2, x, 0) == 0 + assert limit(func3, x, -1) == 2 + + +def test_basic5(): + class my(Function): + @classmethod + def eval(cls, arg): + if arg is S.Infinity: + return S.NaN + assert limit(my(x), x, oo) == Limit(my(x), x, oo) + + +def test_issue_3885(): + assert limit(x*y + x*z, z, 2) == x*y + 2*x + + +def test_Limit(): + assert Limit(sin(x)/x, x, 0) != 1 + assert Limit(sin(x)/x, x, 0).doit() == 1 + assert Limit(x, x, 0, dir='+-').args == (x, x, 0, Symbol('+-')) + + +def test_floor(): + assert limit(floor(x), x, -2, "+") == -2 + assert limit(floor(x), x, -2, "-") == -3 + assert limit(floor(x), x, -1, "+") == -1 + assert limit(floor(x), x, -1, "-") == -2 + assert limit(floor(x), x, 0, "+") == 0 + assert limit(floor(x), x, 0, "-") == -1 + assert limit(floor(x), x, 1, "+") == 1 + assert limit(floor(x), x, 1, "-") == 0 + assert limit(floor(x), x, 2, "+") == 2 + assert limit(floor(x), x, 2, "-") == 1 + assert limit(floor(x), x, 248, "+") == 248 + assert limit(floor(x), x, 248, "-") == 247 + + # https://github.com/sympy/sympy/issues/14478 + assert limit(x*floor(3/x)/2, x, 0, '+') == Rational(3, 2) + assert limit(floor(x + 1/2) - floor(x), x, oo) == AccumBounds(-S.Half, S(3)/2) + + # test issue 9158 + assert limit(floor(atan(x)), x, oo) == 1 + assert limit(floor(atan(x)), x, -oo) == -2 + assert limit(ceiling(atan(x)), x, oo) == 2 + assert limit(ceiling(atan(x)), x, -oo) == -1 + + +def test_floor_requires_robust_assumptions(): + assert limit(floor(sin(x)), x, 0, "+") == 0 + assert limit(floor(sin(x)), x, 0, "-") == -1 + assert limit(floor(cos(x)), x, 0, "+") == 0 + assert limit(floor(cos(x)), x, 0, "-") == 0 + assert limit(floor(5 + sin(x)), x, 0, "+") == 5 + assert limit(floor(5 + sin(x)), x, 0, "-") == 4 + assert limit(floor(5 + cos(x)), x, 0, "+") == 5 + assert limit(floor(5 + cos(x)), x, 0, "-") == 5 + + +def test_ceiling(): + assert limit(ceiling(x), x, -2, "+") == -1 + assert limit(ceiling(x), x, -2, "-") == -2 + assert limit(ceiling(x), x, -1, "+") == 0 + assert limit(ceiling(x), x, -1, "-") == -1 + assert limit(ceiling(x), x, 0, "+") == 1 + assert limit(ceiling(x), x, 0, "-") == 0 + assert limit(ceiling(x), x, 1, "+") == 2 + assert limit(ceiling(x), x, 1, "-") == 1 + assert limit(ceiling(x), x, 2, "+") == 3 + assert limit(ceiling(x), x, 2, "-") == 2 + assert limit(ceiling(x), x, 248, "+") == 249 + assert limit(ceiling(x), x, 248, "-") == 248 + + # https://github.com/sympy/sympy/issues/14478 + assert limit(x*ceiling(3/x)/2, x, 0, '+') == Rational(3, 2) + assert limit(ceiling(x + 1/2) - ceiling(x), x, oo) == AccumBounds(-S.Half, S(3)/2) + + +def test_ceiling_requires_robust_assumptions(): + assert limit(ceiling(sin(x)), x, 0, "+") == 1 + assert limit(ceiling(sin(x)), x, 0, "-") == 0 + assert limit(ceiling(cos(x)), x, 0, "+") == 1 + assert limit(ceiling(cos(x)), x, 0, "-") == 1 + assert limit(ceiling(5 + sin(x)), x, 0, "+") == 6 + assert limit(ceiling(5 + sin(x)), x, 0, "-") == 5 + assert limit(ceiling(5 + cos(x)), x, 0, "+") == 6 + assert limit(ceiling(5 + cos(x)), x, 0, "-") == 6 + + +def test_frac(): + assert limit(frac(x), x, oo) == AccumBounds(0, 1) + assert limit(frac(x)**(1/x), x, oo) == AccumBounds(0, 1) + assert limit(frac(x)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit(frac(x)**x, x, oo) == AccumBounds(0, oo) # wolfram gives (0, 1) + assert limit(frac(sin(x)), x, 0, "+") == 0 + assert limit(frac(sin(x)), x, 0, "-") == 1 + assert limit(frac(cos(x)), x, 0, "+-") == 1 + assert limit(frac(x**2), x, 0, "+-") == 0 + raises(ValueError, lambda: limit(frac(x), x, 0, '+-')) + assert limit(frac(-2*x + 1), x, 0, "+") == 1 + assert limit(frac(-2*x + 1), x, 0, "-") == 0 + assert limit(frac(x + S.Half), x, 0, "+-") == S(1)/2 + assert limit(frac(1/x), x, 0) == AccumBounds(0, 1) + + +def test_issue_14355(): + assert limit(floor(sin(x)/x), x, 0, '+') == 0 + assert limit(floor(sin(x)/x), x, 0, '-') == 0 + # test comment https://github.com/sympy/sympy/issues/14355#issuecomment-372121314 + assert limit(floor(-tan(x)/x), x, 0, '+') == -2 + assert limit(floor(-tan(x)/x), x, 0, '-') == -2 + + +def test_atan(): + x = Symbol("x", real=True) + assert limit(atan(x)*sin(1/x), x, 0) == 0 + assert limit(atan(x) + sqrt(x + 1) - sqrt(x), x, oo) == pi/2 + + +def test_set_signs(): + assert limit(abs(x), x, 0) == 0 + assert limit(abs(sin(x)), x, 0) == 0 + assert limit(abs(cos(x)), x, 0) == 1 + assert limit(abs(sin(x + 1)), x, 0) == sin(1) + + # https://github.com/sympy/sympy/issues/9449 + assert limit((Abs(x + y) - Abs(x - y))/(2*x), x, 0) == sign(y) + + # https://github.com/sympy/sympy/issues/12398 + assert limit(Abs(log(x)/x**3), x, oo) == 0 + assert limit(x*(Abs(log(x)/x**3)/Abs(log(x + 1)/(x + 1)**3) - 1), x, oo) == 3 + + # https://github.com/sympy/sympy/issues/18501 + assert limit(Abs(log(x - 1)**3 - 1), x, 1, '+') == oo + + # https://github.com/sympy/sympy/issues/18997 + assert limit(Abs(log(x)), x, 0) == oo + assert limit(Abs(log(Abs(x))), x, 0) == oo + + # https://github.com/sympy/sympy/issues/19026 + z = Symbol('z', positive=True) + assert limit(Abs(log(z) + 1)/log(z), z, oo) == 1 + + # https://github.com/sympy/sympy/issues/20704 + assert limit(z*(Abs(1/z + y) - Abs(y - 1/z))/2, z, 0) == 0 + + # https://github.com/sympy/sympy/issues/21606 + assert limit(cos(z)/sign(z), z, pi, '-') == -1 + + +def test_heuristic(): + x = Symbol("x", real=True) + assert heuristics(sin(1/x) + atan(x), x, 0, '+') == AccumBounds(-1, 1) + assert limit(log(2 + sqrt(atan(x))*sqrt(sin(1/x))), x, 0) == log(2) + + +def test_issue_3871(): + z = Symbol("z", positive=True) + f = -1/z*exp(-z*x) + assert limit(f, x, oo) == 0 + assert f.limit(x, oo) == 0 + + +def test_exponential(): + n = Symbol('n') + x = Symbol('x', real=True) + assert limit((1 + x/n)**n, n, oo) == exp(x) + assert limit((1 + x/(2*n))**n, n, oo) == exp(x/2) + assert limit((1 + x/(2*n + 1))**n, n, oo) == exp(x/2) + assert limit(((x - 1)/(x + 1))**x, x, oo) == exp(-2) + assert limit(1 + (1 + 1/x)**x, x, oo) == 1 + S.Exp1 + assert limit((2 + 6*x)**x/(6*x)**x, x, oo) == exp(S('1/3')) + + +def test_exponential2(): + n = Symbol('n') + assert limit((1 + x/(n + sin(n)))**n, n, oo) == exp(x) + + +def test_doit(): + f = Integral(2 * x, x) + l = Limit(f, x, oo) + assert l.doit() is oo + + +def test_series_AccumBounds(): + assert limit(sin(k) - sin(k + 1), k, oo) == AccumBounds(-2, 2) + assert limit(cos(k) - cos(k + 1) + 1, k, oo) == AccumBounds(-1, 3) + + # not the exact bound + assert limit(sin(k) - sin(k)*cos(k), k, oo) == AccumBounds(-2, 2) + + # test for issue #9934 + lo = (-3 + cos(1))/2 + hi = (1 + cos(1))/2 + t1 = Mul(AccumBounds(lo, hi), 1/(-1 + cos(1)), evaluate=False) + assert limit(simplify(Sum(cos(n).rewrite(exp), (n, 0, k)).doit().rewrite(sin)), k, oo) == t1 + + t2 = Mul(AccumBounds(-1 + sin(1)/2, sin(1)/2 + 1), 1/(1 - cos(1))) + assert limit(simplify(Sum(sin(n).rewrite(exp), (n, 0, k)).doit().rewrite(sin)), k, oo) == t2 + + assert limit(((sin(x) + 1)/2)**x, x, oo) == AccumBounds(0, oo) # wolfram says 0 + + # https://github.com/sympy/sympy/issues/12312 + e = 2**(-x)*(sin(x) + 1)**x + assert limit(e, x, oo) == AccumBounds(0, oo) + + +def test_bessel_functions_at_infinity(): + # Pull Request 23844 implements limits for all bessel and modified bessel + # functions approaching infinity along any direction i.e. abs(z0) tends to oo + + assert limit(besselj(1, x), x, oo) == 0 + assert limit(besselj(1, x), x, -oo) == 0 + assert limit(besselj(1, x), x, I*oo) == oo*I + assert limit(besselj(1, x), x, -I*oo) == -oo*I + assert limit(bessely(1, x), x, oo) == 0 + assert limit(bessely(1, x), x, -oo) == 0 + assert limit(bessely(1, x), x, I*oo) == -oo + assert limit(bessely(1, x), x, -I*oo) == -oo + assert limit(besseli(1, x), x, oo) == oo + assert limit(besseli(1, x), x, -oo) == -oo + assert limit(besseli(1, x), x, I*oo) == 0 + assert limit(besseli(1, x), x, -I*oo) == 0 + assert limit(besselk(1, x), x, oo) == 0 + assert limit(besselk(1, x), x, -oo) == -oo*I + assert limit(besselk(1, x), x, I*oo) == 0 + assert limit(besselk(1, x), x, -I*oo) == 0 + + # test issue 14874 + assert limit(besselk(0, x), x, oo) == 0 + + +@XFAIL +def test_doit2(): + f = Integral(2 * x, x) + l = Limit(f, x, oo) + # limit() breaks on the contained Integral. + assert l.doit(deep=False) == l + + +def test_issue_2929(): + assert limit((x * exp(x))/(exp(x) - 1), x, -oo) == 0 + + +def test_issue_3792(): + assert limit((1 - cos(x))/x**2, x, S.Half) == 4 - 4*cos(S.Half) + assert limit(sin(sin(x + 1) + 1), x, 0) == sin(1 + sin(1)) + assert limit(abs(sin(x + 1) + 1), x, 0) == 1 + sin(1) + + +def test_issue_4090(): + assert limit(1/(x + 3), x, 2) == Rational(1, 5) + assert limit(1/(x + pi), x, 2) == S.One/(2 + pi) + assert limit(log(x)/(x**2 + 3), x, 2) == log(2)/7 + assert limit(log(x)/(x**2 + pi), x, 2) == log(2)/(4 + pi) + + +def test_issue_4547(): + assert limit(cot(x), x, 0, dir='+') is oo + assert limit(cot(x), x, pi/2, dir='+') == 0 + + +def test_issue_5164(): + assert limit(x**0.5, x, oo) == oo**0.5 is oo + assert limit(x**0.5, x, 16) == 4 # Should this be a float? + assert limit(x**0.5, x, 0) == 0 + assert limit(x**(-0.5), x, oo) == 0 + assert limit(x**(-0.5), x, 4) == S.Half # Should this be a float? + + +def test_issue_5383(): + func = (1.0 * 1 + 1.0 * x)**(1.0 * 1 / x) + assert limit(func, x, 0) == E + + +def test_issue_14793(): + expr = ((x + S(1)/2) * log(x) - x + log(2*pi)/2 - \ + log(factorial(x)) + S(1)/(12*x))*x**3 + assert limit(expr, x, oo) == S(1)/360 + + +def test_issue_5183(): + # using list(...) so py.test can recalculate values + tests = list(product([x, -x], + [-1, 1], + [2, 3, S.Half, Rational(2, 3)], + ['-', '+'])) + results = (oo, oo, -oo, oo, -oo*I, oo, -oo*(-1)**Rational(1, 3), oo, + 0, 0, 0, 0, 0, 0, 0, 0, + oo, oo, oo, -oo, oo, -oo*I, oo, -oo*(-1)**Rational(1, 3), + 0, 0, 0, 0, 0, 0, 0, 0) + assert len(tests) == len(results) + for i, (args, res) in enumerate(zip(tests, results)): + y, s, e, d = args + eq = y**(s*e) + try: + assert limit(eq, x, 0, dir=d) == res + except AssertionError: + if 0: # change to 1 if you want to see the failing tests + print() + print(i, res, eq, d, limit(eq, x, 0, dir=d)) + else: + assert None + + +def test_issue_5184(): + assert limit(sin(x)/x, x, oo) == 0 + assert limit(atan(x), x, oo) == pi/2 + assert limit(gamma(x), x, oo) is oo + assert limit(cos(x)/x, x, oo) == 0 + assert limit(gamma(x), x, S.Half) == sqrt(pi) + + r = Symbol('r', real=True) + assert limit(r*sin(1/r), r, 0) == 0 + + +def test_issue_5229(): + assert limit((1 + y)**(1/y) - S.Exp1, y, 0) == 0 + + +def test_issue_4546(): + # using list(...) so py.test can recalculate values + tests = list(product([cot, tan], + [-pi/2, 0, pi/2, pi, pi*Rational(3, 2)], + ['-', '+'])) + results = (0, 0, -oo, oo, 0, 0, -oo, oo, 0, 0, + oo, -oo, 0, 0, oo, -oo, 0, 0, oo, -oo) + assert len(tests) == len(results) + for i, (args, res) in enumerate(zip(tests, results)): + f, l, d = args + eq = f(x) + try: + assert limit(eq, x, l, dir=d) == res + except AssertionError: + if 0: # change to 1 if you want to see the failing tests + print() + print(i, res, eq, l, d, limit(eq, x, l, dir=d)) + else: + assert None + + +def test_issue_3934(): + assert limit((1 + x**log(3))**(1/x), x, 0) == 1 + assert limit((5**(1/x) + 3**(1/x))**x, x, 0) == 5 + + +def test_issue_5955(): + assert limit((x**16)/(1 + x**16), x, oo) == 1 + assert limit((x**100)/(1 + x**100), x, oo) == 1 + assert limit((x**1885)/(1 + x**1885), x, oo) == 1 + assert limit((x**1000/((x + 1)**1000 + exp(-x))), x, oo) == 1 + + +def test_newissue(): + assert limit(exp(1/sin(x))/exp(cot(x)), x, 0) == 1 + + +def test_extended_real_line(): + assert limit(x - oo, x, oo) == Limit(x - oo, x, oo) + assert limit(1/(x + sin(x)) - oo, x, 0) == Limit(1/(x + sin(x)) - oo, x, 0) + assert limit(oo/x, x, oo) == Limit(oo/x, x, oo) + assert limit(x - oo + 1/x, x, oo) == Limit(x - oo + 1/x, x, oo) + + +@XFAIL +def test_order_oo(): + x = Symbol('x', positive=True) + assert Order(x)*oo != Order(1, x) + assert limit(oo/(x**2 - 4), x, oo) is oo + + +def test_issue_5436(): + raises(NotImplementedError, lambda: limit(exp(x*y), x, oo)) + raises(NotImplementedError, lambda: limit(exp(-x*y), x, oo)) + + +def test_Limit_dir(): + raises(TypeError, lambda: Limit(x, x, 0, dir=0)) + raises(ValueError, lambda: Limit(x, x, 0, dir='0')) + + +def test_polynomial(): + assert limit((x + 1)**1000/((x + 1)**1000 + 1), x, oo) == 1 + assert limit((x + 1)**1000/((x + 1)**1000 + 1), x, -oo) == 1 + assert limit(x ** Rational(77, 3) / (1 + x ** Rational(77, 3)), x, oo) == 1 + assert limit(x ** 101.1 / (1 + x ** 101.1), x, oo) == 1 + + +def test_rational(): + assert limit(1/y - (1/(y + x) + x/(y + x)/y)/z, x, oo) == (z - 1)/(y*z) + assert limit(1/y - (1/(y + x) + x/(y + x)/y)/z, x, -oo) == (z - 1)/(y*z) + + +def test_issue_5740(): + assert limit(log(x)*z - log(2*x)*y, x, 0) == oo*sign(y - z) + + +def test_issue_6366(): + n = Symbol('n', integer=True, positive=True) + r = (n + 1)*x**(n + 1)/(x**(n + 1) - 1) - x/(x - 1) + assert limit(r, x, 1).cancel() == n/2 + + +def test_factorial(): + f = factorial(x) + assert limit(f, x, oo) is oo + assert limit(x/f, x, oo) == 0 + # see Stirling's approximation: + # https://en.wikipedia.org/wiki/Stirling's_approximation + assert limit(f/(sqrt(2*pi*x)*(x/E)**x), x, oo) == 1 + assert limit(f, x, -oo) == gamma(-oo) + + +def test_issue_6560(): + e = (5*x**3/4 - x*Rational(3, 4) + (y*(3*x**2/2 - S.Half) + + 35*x**4/8 - 15*x**2/4 + Rational(3, 8))/(2*(y + 1))) + assert limit(e, y, oo) == 5*x**3/4 + 3*x**2/4 - 3*x/4 - Rational(1, 4) + +@XFAIL +def test_issue_5172(): + n = Symbol('n') + r = Symbol('r', positive=True) + c = Symbol('c') + p = Symbol('p', positive=True) + m = Symbol('m', negative=True) + expr = ((2*n*(n - r + 1)/(n + r*(n - r + 1)))**c + + (r - 1)*(n*(n - r + 2)/(n + r*(n - r + 1)))**c - n)/(n**c - n) + expr = expr.subs(c, c + 1) + raises(NotImplementedError, lambda: limit(expr, n, oo)) + assert limit(expr.subs(c, m), n, oo) == 1 + assert limit(expr.subs(c, p), n, oo).simplify() == \ + (2**(p + 1) + r - 1)/(r + 1)**(p + 1) + + +def test_issue_7088(): + a = Symbol('a') + assert limit(sqrt(x/(x + a)), x, oo) == 1 + + +def test_branch_cuts(): + assert limit(asin(I*x + 2), x, 0) == pi - asin(2) + assert limit(asin(I*x + 2), x, 0, '-') == asin(2) + assert limit(asin(I*x - 2), x, 0) == -asin(2) + assert limit(asin(I*x - 2), x, 0, '-') == -pi + asin(2) + assert limit(acos(I*x + 2), x, 0) == -acos(2) + assert limit(acos(I*x + 2), x, 0, '-') == acos(2) + assert limit(acos(I*x - 2), x, 0) == acos(-2) + assert limit(acos(I*x - 2), x, 0, '-') == 2*pi - acos(-2) + assert limit(atan(x + 2*I), x, 0) == I*atanh(2) + assert limit(atan(x + 2*I), x, 0, '-') == -pi + I*atanh(2) + assert limit(atan(x - 2*I), x, 0) == pi - I*atanh(2) + assert limit(atan(x - 2*I), x, 0, '-') == -I*atanh(2) + assert limit(atan(1/x), x, 0) == pi/2 + assert limit(atan(1/x), x, 0, '-') == -pi/2 + assert limit(atan(x), x, oo) == pi/2 + assert limit(atan(x), x, -oo) == -pi/2 + assert limit(acot(x + S(1)/2*I), x, 0) == pi - I*acoth(S(1)/2) + assert limit(acot(x + S(1)/2*I), x, 0, '-') == -I*acoth(S(1)/2) + assert limit(acot(x - S(1)/2*I), x, 0) == I*acoth(S(1)/2) + assert limit(acot(x - S(1)/2*I), x, 0, '-') == -pi + I*acoth(S(1)/2) + assert limit(acot(x), x, 0) == pi/2 + assert limit(acot(x), x, 0, '-') == -pi/2 + assert limit(asec(I*x + S(1)/2), x, 0) == asec(S(1)/2) + assert limit(asec(I*x + S(1)/2), x, 0, '-') == -asec(S(1)/2) + assert limit(asec(I*x - S(1)/2), x, 0) == 2*pi - asec(-S(1)/2) + assert limit(asec(I*x - S(1)/2), x, 0, '-') == asec(-S(1)/2) + assert limit(acsc(I*x + S(1)/2), x, 0) == acsc(S(1)/2) + assert limit(acsc(I*x + S(1)/2), x, 0, '-') == pi - acsc(S(1)/2) + assert limit(acsc(I*x - S(1)/2), x, 0) == -pi + acsc(S(1)/2) + assert limit(acsc(I*x - S(1)/2), x, 0, '-') == -acsc(S(1)/2) + + assert limit(log(I*x - 1), x, 0) == I*pi + assert limit(log(I*x - 1), x, 0, '-') == -I*pi + assert limit(log(-I*x - 1), x, 0) == -I*pi + assert limit(log(-I*x - 1), x, 0, '-') == I*pi + + assert limit(sqrt(I*x - 1), x, 0) == I + assert limit(sqrt(I*x - 1), x, 0, '-') == -I + assert limit(sqrt(-I*x - 1), x, 0) == -I + assert limit(sqrt(-I*x - 1), x, 0, '-') == I + + assert limit(cbrt(I*x - 1), x, 0) == (-1)**(S(1)/3) + assert limit(cbrt(I*x - 1), x, 0, '-') == -(-1)**(S(2)/3) + assert limit(cbrt(-I*x - 1), x, 0) == -(-1)**(S(2)/3) + assert limit(cbrt(-I*x - 1), x, 0, '-') == (-1)**(S(1)/3) + + +def test_issue_6364(): + a = Symbol('a') + e = z/(1 - sqrt(1 + z)*sin(a)**2 - sqrt(1 - z)*cos(a)**2) + assert limit(e, z, 0) == 1/(cos(a)**2 - S.Half) + + +def test_issue_6682(): + assert limit(exp(2*Ei(-x))/x**2, x, 0) == exp(2*EulerGamma) + + +def test_issue_4099(): + a = Symbol('a') + assert limit(a/x, x, 0) == oo*sign(a) + assert limit(-a/x, x, 0) == -oo*sign(a) + assert limit(-a*x, x, oo) == -oo*sign(a) + assert limit(a*x, x, oo) == oo*sign(a) + + +def test_issue_4503(): + dx = Symbol('dx') + assert limit((sqrt(1 + exp(x + dx)) - sqrt(1 + exp(x)))/dx, dx, 0) == \ + exp(x)/(2*sqrt(exp(x) + 1)) + + +def test_issue_6052(): + G = meijerg((), (), (1,), (0,), -x) + g = hyperexpand(G) + assert limit(g, x, 0, '+-') == 0 + assert limit(g, x, oo) == -oo + + +def test_issue_7224(): + expr = sqrt(x)*besseli(1,sqrt(8*x)) + assert limit(x*diff(expr, x, x)/expr, x, 0) == 2 + assert limit(x*diff(expr, x, x)/expr, x, 1).evalf() == 2.0 + + +def test_issue_7391_8166(): + f = Function('f') + # limit should depend on the continuity of the expression at the point passed + assert limit(f(x), x, 4) == Limit(f(x), x, 4, dir='+') + assert limit(x*f(x)**2/(x**2 + f(x)**4), x, 0) == Limit(x*f(x)**2/(x**2 + f(x)**4), x, 0, dir='+') + + +def test_issue_8208(): + assert limit(n**(Rational(1, 1e9) - 1), n, oo) == 0 + + +def test_issue_8229(): + assert limit((x**Rational(1, 4) - 2)/(sqrt(x) - 4)**Rational(2, 3), x, 16) == 0 + + +def test_issue_8433(): + d, t = symbols('d t', positive=True) + assert limit(erf(1 - t/d), t, oo) == -1 + + +def test_issue_8481(): + k = Symbol('k', integer=True, nonnegative=True) + lamda = Symbol('lamda', positive=True) + assert limit(lamda**k * exp(-lamda) / factorial(k), k, oo) == 0 + + +def test_issue_8462(): + assert limit(binomial(n, n/2), n, oo) == oo + assert limit(binomial(n, n/2) * 3 ** (-n), n, oo) == 0 + + +def test_issue_8634(): + n = Symbol('n', integer=True, positive=True) + x = Symbol('x') + assert limit(x**n, x, -oo) == oo*sign((-1)**n) + + +def test_issue_8635_18176(): + x = Symbol('x', real=True) + k = Symbol('k', positive=True) + assert limit(x**n - x**(n - 0), x, oo) == 0 + assert limit(x**n - x**(n - 5), x, oo) == oo + assert limit(x**n - x**(n - 2.5), x, oo) == oo + assert limit(x**n - x**(n - k - 1), x, oo) == oo + x = Symbol('x', positive=True) + assert limit(x**n - x**(n - 1), x, oo) == oo + assert limit(x**n - x**(n + 2), x, oo) == -oo + + +def test_issue_8730(): + assert limit(subfactorial(x), x, oo) is oo + + +def test_issue_9252(): + n = Symbol('n', integer=True) + c = Symbol('c', positive=True) + assert limit((log(n))**(n/log(n)) / (1 + c)**n, n, oo) == 0 + # limit should depend on the value of c + raises(NotImplementedError, lambda: limit((log(n))**(n/log(n)) / c**n, n, oo)) + + +def test_issue_9558(): + assert limit(sin(x)**15, x, 0, '-') == 0 + + +def test_issue_10801(): + # make sure limits work with binomial + assert limit(16**k / (k * binomial(2*k, k)**2), k, oo) == pi + + +def test_issue_10976(): + s, x = symbols('s x', real=True) + assert limit(erf(s*x)/erf(s), s, 0) == x + + +def test_issue_9041(): + assert limit(factorial(n) / ((n/exp(1))**n * sqrt(2*pi*n)), n, oo) == 1 + + +def test_issue_9205(): + x, y, a = symbols('x, y, a') + assert Limit(x, x, a).free_symbols == {a} + assert Limit(x, x, a, '-').free_symbols == {a} + assert Limit(x + y, x + y, a).free_symbols == {a} + assert Limit(-x**2 + y, x**2, a).free_symbols == {y, a} + + +def test_issue_9471(): + assert limit(((27**(log(n,3)))/n**3),n,oo) == 1 + assert limit(((27**(log(n,3)+1))/n**3),n,oo) == 27 + + +def test_issue_10382(): + assert limit(fibonacci(n + 1)/fibonacci(n), n, oo) == GoldenRatio + + +def test_issue_11496(): + assert limit(erfc(log(1/x)), x, oo) == 2 + + +def test_issue_11879(): + assert simplify(limit(((x+y)**n-x**n)/y, y, 0)) == n*x**(n-1) + + +def test_limit_with_Float(): + k = symbols("k") + assert limit(1.0 ** k, k, oo) == 1 + assert limit(0.3*1.0**k, k, oo) == Rational(3, 10) + + +def test_issue_10610(): + assert limit(3**x*3**(-x - 1)*(x + 1)**2/x**2, x, oo) == Rational(1, 3) + + +def test_issue_10868(): + assert limit(log(x) + asech(x), x, 0, '+') == log(2) + assert limit(log(x) + asech(x), x, 0, '-') == log(2) + 2*I*pi + raises(ValueError, lambda: limit(log(x) + asech(x), x, 0, '+-')) + assert limit(log(x) + asech(x), x, oo) == oo + assert limit(log(x) + acsch(x), x, 0, '+') == log(2) + assert limit(log(x) + acsch(x), x, 0, '-') == -oo + raises(ValueError, lambda: limit(log(x) + acsch(x), x, 0, '+-')) + assert limit(log(x) + acsch(x), x, oo) == oo + + +def test_issue_6599(): + assert limit((n + cos(n))/n, n, oo) == 1 + + +def test_issue_12555(): + assert limit((3**x + 2* x**10) / (x**10 + exp(x)), x, -oo) == 2 + assert limit((3**x + 2* x**10) / (x**10 + exp(x)), x, oo) is oo + + +def test_issue_12769(): + r, z, x = symbols('r z x', real=True) + a, b, s0, K, F0, s, T = symbols('a b s0 K F0 s T', positive=True, real=True) + fx = (F0**b*K**b*r*s0 - sqrt((F0**2*K**(2*b)*a**2*(b - 1) + \ + F0**(2*b)*K**2*a**2*(b - 1) + F0**(2*b)*K**(2*b)*s0**2*(b - 1)*(b**2 - 2*b + 1) - \ + 2*F0**(2*b)*K**(b + 1)*a*r*s0*(b**2 - 2*b + 1) + \ + 2*F0**(b + 1)*K**(2*b)*a*r*s0*(b**2 - 2*b + 1) - \ + 2*F0**(b + 1)*K**(b + 1)*a**2*(b - 1))/((b - 1)*(b**2 - 2*b + 1))))*(b*r - b - r + 1) + + assert fx.subs(K, F0).factor(deep=True) == limit(fx, K, F0).factor(deep=True) + + +def test_issue_13332(): + assert limit(sqrt(30)*5**(-5*x - 1)*(46656*x)**x*(5*x + 2)**(5*x + 5*S.Half) * + (6*x + 2)**(-6*x - 5*S.Half), x, oo) == Rational(25, 36) + + +def test_issue_12564(): + assert limit(x**2 + x*sin(x) + cos(x), x, -oo) is oo + assert limit(x**2 + x*sin(x) + cos(x), x, oo) is oo + assert limit(((x + cos(x))**2).expand(), x, oo) is oo + assert limit(((x + sin(x))**2).expand(), x, oo) is oo + assert limit(((x + cos(x))**2).expand(), x, -oo) is oo + assert limit(((x + sin(x))**2).expand(), x, -oo) is oo + + +def test_issue_14456(): + raises(NotImplementedError, lambda: Limit(exp(x), x, zoo).doit()) + raises(NotImplementedError, lambda: Limit(x**2/(x+1), x, zoo).doit()) + + +def test_issue_14411(): + assert limit(3*sec(4*pi*x - x/3), x, 3*pi/(24*pi - 2)) is -oo + + +def test_issue_13382(): + assert limit(x*(((x + 1)**2 + 1)/(x**2 + 1) - 1), x, oo) == 2 + + +def test_issue_13403(): + assert limit(x*(-1 + (x + log(x + 1) + 1)/(x + log(x))), x, oo) == 1 + + +def test_issue_13416(): + assert limit((-x**3*log(x)**3 + (x - 1)*(x + 1)**2*log(x + 1)**3)/(x**2*log(x)**3), x, oo) == 1 + + +def test_issue_13462(): + assert limit(n**2*(2*n*(-(1 - 1/(2*n))**x + 1) - x - (-x**2/4 + x/4)/n), n, oo) == x**3/24 - x**2/8 + x/12 + + +def test_issue_13750(): + a = Symbol('a') + assert limit(erf(a - x), x, oo) == -1 + assert limit(erf(sqrt(x) - x), x, oo) == -1 + + +def test_issue_14276(): + assert isinstance(limit(sin(x)**log(x), x, oo), Limit) + assert isinstance(limit(sin(x)**cos(x), x, oo), Limit) + assert isinstance(limit(sin(log(cos(x))), x, oo), Limit) + assert limit((1 + 1/(x**2 + cos(x)))**(x**2 + x), x, oo) == E + + +def test_issue_14514(): + assert limit((1/(log(x)**log(x)))**(1/x), x, oo) == 1 + + +def test_issues_14525(): + assert limit(sin(x)**2 - cos(x) + tan(x)*csc(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(sin(x)**2 - cos(x) + sin(x)*cot(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cot(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cos(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(sin(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(cos(x)**2 - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(tan(x)**2 + sin(x)**2 - cos(x), x, oo) == AccumBounds(-S.One, S.Infinity) + + +def test_issue_14574(): + assert limit(sqrt(x)*cos(x - x**2) / (x + 1), x, oo) == 0 + + +def test_issue_10102(): + assert limit(fresnels(x), x, oo) == S.Half + assert limit(3 + fresnels(x), x, oo) == 3 + S.Half + assert limit(5*fresnels(x), x, oo) == Rational(5, 2) + assert limit(fresnelc(x), x, oo) == S.Half + assert limit(fresnels(x), x, -oo) == Rational(-1, 2) + assert limit(4*fresnelc(x), x, -oo) == -2 + + +def test_issue_14377(): + raises(NotImplementedError, lambda: limit(exp(I*x)*sin(pi*x), x, oo)) + + +def test_issue_15146(): + e = (x/2) * (-2*x**3 - 2*(x**3 - 1) * x**2 * digamma(x**3 + 1) + \ + 2*(x**3 - 1) * x**2 * digamma(x**3 + x + 1) + x + 3) + assert limit(e, x, oo) == S(1)/3 + + +def test_issue_15202(): + e = (2**x*(2 + 2**(-x)*(-2*2**x + x + 2))/(x + 1))**(x + 1) + assert limit(e, x, oo) == exp(1) + + e = (log(x, 2)**7 + 10*x*factorial(x) + 5**x) / (factorial(x + 1) + 3*factorial(x) + 10**x) + assert limit(e, x, oo) == 10 + + +def test_issue_15282(): + assert limit((x**2000 - (x + 1)**2000) / x**1999, x, oo) == -2000 + + +def test_issue_15984(): + assert limit((-x + log(exp(x) + 1))/x, x, oo, dir='-') == 0 + + +def test_issue_13571(): + assert limit(uppergamma(x, 1) / gamma(x), x, oo) == 1 + + +def test_issue_13575(): + assert limit(acos(erfi(x)), x, 1) == acos(erfi(S.One)) + + +def test_issue_17325(): + assert Limit(sin(x)/x, x, 0, dir="+-").doit() == 1 + assert Limit(x**2, x, 0, dir="+-").doit() == 0 + assert Limit(1/x**2, x, 0, dir="+-").doit() is oo + assert Limit(1/x, x, 0, dir="+-").doit() is zoo + + +def test_issue_10978(): + assert LambertW(x).limit(x, 0) == 0 + + +def test_issue_14313_comment(): + assert limit(floor(n/2), n, oo) is oo + + +def test_issue_15323(): + d = ((1 - 1/x)**x).diff(x) + assert limit(d, x, 1, dir='+') == 1 + + +def test_issue_12571(): + assert limit(-LambertW(-log(x))/log(x), x, 1) == 1 + + +def test_issue_14590(): + assert limit((x**3*((x + 1)/x)**x)/((x + 1)*(x + 2)*(x + 3)), x, oo) == exp(1) + + +def test_issue_14393(): + a, b = symbols('a b') + assert limit((x**b - y**b)/(x**a - y**a), x, y) == b*y**(-a + b)/a + + +def test_issue_14556(): + assert limit(factorial(n + 1)**(1/(n + 1)) - factorial(n)**(1/n), n, oo) == exp(-1) + + +def test_issue_14811(): + assert limit(((1 + ((S(2)/3)**(x + 1)))**(2**x))/(2**((S(4)/3)**(x - 1))), x, oo) == oo + + +def test_issue_16222(): + assert limit(exp(x), x, 1000000000) == exp(1000000000) + + +def test_issue_16714(): + assert limit(((x**(x + 1) + (x + 1)**x) / x**(x + 1))**x, x, oo) == exp(exp(1)) + + +def test_issue_16722(): + z = symbols('z', positive=True) + assert limit(binomial(n + z, n)*n**-z, n, oo) == 1/gamma(z + 1) + z = symbols('z', positive=True, integer=True) + assert limit(binomial(n + z, n)*n**-z, n, oo) == 1/gamma(z + 1) + + +def test_issue_17431(): + assert limit(((n + 1) + 1) / (((n + 1) + 2) * factorial(n + 1)) * + (n + 2) * factorial(n) / (n + 1), n, oo) == 0 + assert limit((n + 2)**2*factorial(n)/((n + 1)*(n + 3)*factorial(n + 1)) + , n, oo) == 0 + assert limit((n + 1) * factorial(n) / (n * factorial(n + 1)), n, oo) == 0 + + +def test_issue_17671(): + assert limit(Ei(-log(x)) - log(log(x))/x, x, 1) == EulerGamma + + +def test_issue_17751(): + a, b, c, x = symbols('a b c x', positive=True) + assert limit((a + 1)*x - sqrt((a + 1)**2*x**2 + b*x + c), x, oo) == -b/(2*a + 2) + + +def test_issue_17792(): + assert limit(factorial(n)/sqrt(n)*(exp(1)/n)**n, n, oo) == sqrt(2)*sqrt(pi) + + +def test_issue_18118(): + assert limit(sign(sin(x)), x, 0, "-") == -1 + assert limit(sign(sin(x)), x, 0, "+") == 1 + + +def test_issue_18306(): + assert limit(sin(sqrt(x))/sqrt(sin(x)), x, 0, '+') == 1 + + +def test_issue_18378(): + assert limit(log(exp(3*x) + x)/log(exp(x) + x**100), x, oo) == 3 + + +def test_issue_18399(): + assert limit((1 - S(1)/2*x)**(3*x), x, oo) is zoo + assert limit((-x)**x, x, oo) is zoo + + +def test_issue_18442(): + assert limit(tan(x)**(2**(sqrt(pi))), x, oo, dir='-') == Limit(tan(x)**(2**(sqrt(pi))), x, oo, dir='-') + + +def test_issue_18452(): + assert limit(abs(log(x))**x, x, 0) == 1 + assert limit(abs(log(x))**x, x, 0, "-") == 1 + + +def test_issue_18473(): + assert limit(sin(x)**(1/x), x, oo) == Limit(sin(x)**(1/x), x, oo, dir='-') + assert limit(cos(x)**(1/x), x, oo) == Limit(cos(x)**(1/x), x, oo, dir='-') + assert limit(tan(x)**(1/x), x, oo) == Limit(tan(x)**(1/x), x, oo, dir='-') + assert limit((cos(x) + 2)**(1/x), x, oo) == 1 + assert limit((sin(x) + 10)**(1/x), x, oo) == 1 + assert limit((cos(x) - 2)**(1/x), x, oo) == Limit((cos(x) - 2)**(1/x), x, oo, dir='-') + assert limit((cos(x) + 1)**(1/x), x, oo) == AccumBounds(0, 1) + assert limit((tan(x)**2)**(2/x) , x, oo) == AccumBounds(0, oo) + assert limit((sin(x)**2)**(1/x), x, oo) == AccumBounds(0, 1) + # Tests for issue #23751 + assert limit((cos(x) + 1)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit((sin(x)**2)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit((tan(x)**2)**(2/x) , x, -oo) == AccumBounds(0, oo) + + +def test_issue_18482(): + assert limit((2*exp(3*x)/(exp(2*x) + 1))**(1/x), x, oo) == exp(1) + + +def test_issue_18508(): + assert limit(sin(x)/sqrt(1-cos(x)), x, 0) == sqrt(2) + assert limit(sin(x)/sqrt(1-cos(x)), x, 0, dir='+') == sqrt(2) + assert limit(sin(x)/sqrt(1-cos(x)), x, 0, dir='-') == -sqrt(2) + + +def test_issue_18521(): + raises(NotImplementedError, lambda: limit(exp((2 - n) * x), x, oo)) + + +def test_issue_18969(): + a, b = symbols('a b', positive=True) + assert limit(LambertW(a), a, b) == LambertW(b) + assert limit(exp(LambertW(a)), a, b) == exp(LambertW(b)) + + +def test_issue_18992(): + assert limit(n/(factorial(n)**(1/n)), n, oo) == exp(1) + + +def test_issue_19067(): + x = Symbol('x') + assert limit(gamma(x)/(gamma(x - 1)*gamma(x + 2)), x, 0) == -1 + + +def test_issue_19586(): + assert limit(x**(2**x*3**(-x)), x, oo) == 1 + + +def test_issue_13715(): + n = Symbol('n') + p = Symbol('p', zero=True) + assert limit(n + p, n, 0) == 0 + + +def test_issue_15055(): + assert limit(n**3*((-n - 1)*sin(1/n) + (n + 2)*sin(1/(n + 1)))/(-n + 1), n, oo) == 1 + + +def test_issue_16708(): + m, vi = symbols('m vi', positive=True) + B, ti, d = symbols('B ti d') + assert limit((B*ti*vi - sqrt(m)*sqrt(-2*B*d*vi + m*(vi)**2) + m*vi)/(B*vi), B, 0) == (d + ti*vi)/vi + + +def test_issue_19154(): + assert limit(besseli(1, 3 *x)/(x *besseli(1, x)**3), x , oo) == 2*sqrt(3)*pi/3 + assert limit(besseli(1, 3 *x)/(x *besseli(1, x)**3), x , -oo) == -2*sqrt(3)*pi/3 + + +def test_issue_19453(): + beta = Symbol("beta", positive=True) + h = Symbol("h", positive=True) + m = Symbol("m", positive=True) + w = Symbol("omega", positive=True) + g = Symbol("g", positive=True) + + e = exp(1) + q = 3*h**2*beta*g*e**(0.5*h*beta*w) + p = m**2*w**2 + s = e**(h*beta*w) - 1 + Z = -q/(4*p*s) - q/(2*p*s**2) - q*(e**(h*beta*w) + 1)/(2*p*s**3)\ + + e**(0.5*h*beta*w)/s + E = -diff(log(Z), beta) + + assert limit(E - 0.5*h*w, beta, oo) == 0 + assert limit(E.simplify() - 0.5*h*w, beta, oo) == 0 + + +def test_issue_19739(): + assert limit((-S(1)/4)**x, x, oo) == 0 + + +def test_issue_19766(): + assert limit(2**(-x)*sqrt(4**(x + 1) + 1), x, oo) == 2 + + +def test_issue_19770(): + m = Symbol('m') + # the result is not 0 for non-real m + assert limit(cos(m*x)/x, x, oo) == Limit(cos(m*x)/x, x, oo, dir='-') + m = Symbol('m', real=True) + # can be improved to give the correct result 0 + assert limit(cos(m*x)/x, x, oo) == Limit(cos(m*x)/x, x, oo, dir='-') + m = Symbol('m', nonzero=True) + assert limit(cos(m*x), x, oo) == AccumBounds(-1, 1) + assert limit(cos(m*x)/x, x, oo) == 0 + + +def test_issue_7535(): + assert limit(tan(x)/sin(tan(x)), x, pi/2) == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='+') + assert limit(tan(x)/sin(tan(x)), x, pi/2, dir='-') == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='-') + assert limit(tan(x)/sin(tan(x)), x, pi/2, dir='+-') == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='+-') + assert limit(sin(tan(x)),x,pi/2) == AccumBounds(-1, 1) + assert -oo*(1/sin(-oo)) == AccumBounds(-oo, oo) + assert oo*(1/sin(oo)) == AccumBounds(-oo, oo) + assert oo*(1/sin(-oo)) == AccumBounds(-oo, oo) + assert -oo*(1/sin(oo)) == AccumBounds(-oo, oo) + + +def test_issue_20365(): + assert limit(((x + 1)**(1/x) - E)/x, x, 0) == -E/2 + + +def test_issue_21031(): + assert limit(((1 + x)**(1/x) - (1 + 2*x)**(1/(2*x)))/asin(x), x, 0) == E/2 + + +def test_issue_21038(): + assert limit(sin(pi*x)/(3*x - 12), x, 4) == pi/3 + + +def test_issue_20578(): + expr = abs(x) * sin(1/x) + assert limit(expr,x,0,'+') == 0 + assert limit(expr,x,0,'-') == 0 + assert limit(expr,x,0,'+-') == 0 + + +def test_issue_21227(): + f = log(x) + + assert f.nseries(x, logx=y) == y + assert f.nseries(x, logx=-x) == -x + + f = log(-log(x)) + + assert f.nseries(x, logx=y) == log(-y) + assert f.nseries(x, logx=-x) == log(x) + + f = log(log(x)) + + assert f.nseries(x, logx=y) == log(y) + assert f.nseries(x, logx=-x) == log(-x) + assert f.nseries(x, logx=x) == log(x) + + f = log(log(log(1/x))) + + assert f.nseries(x, logx=y) == log(log(-y)) + assert f.nseries(x, logx=-y) == log(log(y)) + assert f.nseries(x, logx=x) == log(log(-x)) + assert f.nseries(x, logx=-x) == log(log(x)) + + +def test_issue_21415(): + exp = (x-1)*cos(1/(x-1)) + assert exp.limit(x,1) == 0 + assert exp.expand().limit(x,1) == 0 + + +def test_issue_21530(): + assert limit(sinh(n + 1)/sinh(n), n, oo) == E + + +def test_issue_21550(): + r = (sqrt(5) - 1)/2 + assert limit((x - r)/(x**2 + x - 1), x, r) == sqrt(5)/5 + + +def test_issue_21661(): + out = limit((x**(x + 1) * (log(x) + 1) + 1) / x, x, 11) + assert out == S(3138428376722)/11 + 285311670611*log(11) + + +def test_issue_21701(): + assert limit((besselj(z, x)/x**z).subs(z, 7), x, 0) == S(1)/645120 + + +def test_issue_21721(): + a = Symbol('a', real=True) + I = integrate(1/(pi*(1 + (x - a)**2)), x) + assert I.limit(x, oo) == S.Half + + +def test_issue_21756(): + term = (1 - exp(-2*I*pi*z))/(1 - exp(-2*I*pi*z/5)) + assert term.limit(z, 0) == 5 + assert re(term).limit(z, 0) == 5 + + +def test_issue_21785(): + a = Symbol('a') + assert sqrt((-a**2 + x**2)/(1 - x**2)).limit(a, 1, '-') == I + + +def test_issue_22181(): + assert limit((-1)**x * 2**(-x), x, oo) == 0 + + +def test_issue_22220(): + e1 = sqrt(30)*atan(sqrt(30)*tan(x/2)/6)/30 + e2 = sqrt(30)*I*(-log(sqrt(2)*tan(x/2) - 2*sqrt(15)*I/5) + + +log(sqrt(2)*tan(x/2) + 2*sqrt(15)*I/5))/60 + + assert limit(e1, x, -pi) == -sqrt(30)*pi/60 + assert limit(e2, x, -pi) == -sqrt(30)*pi/30 + + assert limit(e1, x, -pi, '-') == sqrt(30)*pi/60 + assert limit(e2, x, -pi, '-') == 0 + + # test https://github.com/sympy/sympy/issues/22220#issuecomment-972727694 + expr = log(x - I) - log(-x - I) + expr2 = logcombine(expr, force=True) + assert limit(expr, x, oo) == limit(expr2, x, oo) == I*pi + + # test https://github.com/sympy/sympy/issues/22220#issuecomment-1077618340 + expr = expr = (-log(tan(x/2) - I) +log(tan(x/2) + I)) + assert limit(expr, x, pi, '+') == 2*I*pi + assert limit(expr, x, pi, '-') == 0 + + +def test_issue_22334(): + k, n = symbols('k, n', positive=True) + assert limit((n+1)**k/((n+1)**(k+1) - (n)**(k+1)), n, oo) == 1/(k + 1) + assert limit((n+1)**k/((n+1)**(k+1) - (n)**(k+1)).expand(), n, oo) == 1/(k + 1) + assert limit((n+1)**k/(n*(-n**k + (n + 1)**k) + (n + 1)**k), n, oo) == 1/(k + 1) + + +def test_issue_22836_limit(): + assert limit(2**(1/x)/factorial(1/(x)), x, 0) == S.Zero + + +def test_sympyissue_22986(): + assert limit(acosh(1 + 1/x)*sqrt(x), x, oo) == sqrt(2) + + +def test_issue_23231(): + f = (2**x - 2**(-x))/(2**x + 2**(-x)) + assert limit(f, x, -oo) == -1 + + +def test_issue_23596(): + assert integrate(((1 + x)/x**2)*exp(-1/x), (x, 0, oo)) == oo + + +def test_issue_23752(): + expr1 = sqrt(-I*x**2 + x - 3) + expr2 = sqrt(-I*x**2 + I*x - 3) + assert limit(expr1, x, 0, '+') == -sqrt(3)*I + assert limit(expr1, x, 0, '-') == -sqrt(3)*I + assert limit(expr2, x, 0, '+') == sqrt(3)*I + assert limit(expr2, x, 0, '-') == -sqrt(3)*I + + +def test_issue_24276(): + fx = log(tan(pi/2*tanh(x))).diff(x) + assert fx.limit(x, oo) == 2 + assert fx.simplify().limit(x, oo) == 2 + assert fx.rewrite(sin).limit(x, oo) == 2 + assert fx.rewrite(sin).simplify().limit(x, oo) == 2 + +def test_issue_25230(): + a = Symbol('a', real = True) + b = Symbol('b', positive = True) + c = Symbol('c', negative = True) + n = Symbol('n', integer = True) + raises(NotImplementedError, lambda: limit(Mod(x, a), x, a)) + assert limit(Mod(x, b), x, n*b, '+') == 0 + assert limit(Mod(x, b), x, n*b, '-') == b + assert limit(Mod(x, c), x, n*c, '+') == c + assert limit(Mod(x, c), x, n*c, '-') == 0 + + +def test_issue_25582(): + + assert limit(asin(exp(x)), x, oo, '-') == -oo*I + assert limit(acos(exp(x)), x, oo, '-') == oo*I + assert limit(atan(exp(x)), x, oo, '-') == pi/2 + assert limit(acot(exp(x)), x, oo, '-') == 0 + assert limit(asec(exp(x)), x, oo, '-') == pi/2 + assert limit(acsc(exp(x)), x, oo, '-') == 0 + + +def test_issue_25847(): + #atan + assert limit(atan(sin(x)/x), x, 0, '+-') == pi/4 + assert limit(atan(exp(1/x)), x, 0, '+') == pi/2 + assert limit(atan(exp(1/x)), x, 0, '-') == 0 + + #asin + assert limit(asin(sin(x)/x), x, 0, '+-') == pi/2 + assert limit(asin(exp(1/x)), x, 0, '+') == -oo*I + assert limit(asin(exp(1/x)), x, 0, '-') == 0 + + #acos + assert limit(acos(sin(x)/x), x, 0, '+-') == 0 + assert limit(acos(exp(1/x)), x, 0, '+') == oo*I + assert limit(acos(exp(1/x)), x, 0, '-') == pi/2 + + #acot + assert limit(acot(sin(x)/x), x, 0, '+-') == pi/4 + assert limit(acot(exp(1/x)), x, 0, '+') == 0 + assert limit(acot(exp(1/x)), x, 0, '-') == pi/2 + + #asec + assert limit(asec(sin(x)/x), x, 0, '+-') == 0 + assert limit(asec(exp(1/x)), x, 0, '+') == pi/2 + assert limit(asec(exp(1/x)), x, 0, '-') == oo*I + + #acsc + assert limit(acsc(sin(x)/x), x, 0, '+-') == pi/2 + assert limit(acsc(exp(1/x)), x, 0, '+') == 0 + assert limit(acsc(exp(1/x)), x, 0, '-') == -oo*I + + #atanh + assert limit(atanh(sin(x)/x), x, 0, '+-') == oo + assert limit(atanh(exp(1/x)), x, 0, '+') == -I*pi/2 + assert limit(atanh(exp(1/x)), x, 0, '-') == 0 + + #asinh + assert limit(asinh(sin(x)/x), x, 0, '+-') == log(1 + sqrt(2)) + assert limit(asinh(exp(1/x)), x, 0, '+') == oo + assert limit(asinh(exp(1/x)), x, 0, '-') == 0 + + #acosh + assert limit(acosh(sin(x)/x), x, 0, '+-') == 0 + assert limit(acosh(exp(1/x)), x, 0, '+') == oo + assert limit(acosh(exp(1/x)), x, 0, '-') == I*pi/2 + + #acoth + assert limit(acoth(sin(x)/x), x, 0, '+-') == oo + assert limit(acoth(exp(1/x)), x, 0, '+') == 0 + assert limit(acoth(exp(1/x)), x, 0, '-') == -I*pi/2 + + #asech + assert limit(asech(sin(x)/x), x, 0, '+-') == 0 + assert limit(asech(exp(1/x)), x, 0, '+') == I*pi/2 + assert limit(asech(exp(1/x)), x, 0, '-') == oo + + #acsch + assert limit(acsch(sin(x)/x), x, 0, '+-') == log(1 + sqrt(2)) + assert limit(acsch(exp(1/x)), x, 0, '+') == 0 + assert limit(acsch(exp(1/x)), x, 0, '-') == oo + + +def test_issue_26040(): + assert limit(besseli(0, x + 1)/besseli(0, x), x, oo) == S.Exp1 + + +def test_issue_26250(): + e = elliptic_e(4*x/(x**2 + 2*x + 1)) + k = elliptic_k(4*x/(x**2 + 2*x + 1)) + e1 = ((1-3*x**2)*e**2/2 - (x**2-2*x+1)*e*k/2) + e2 = pi**2*(x**8 - 2*x**7 - x**6 + 4*x**5 - x**4 - 2*x**3 + x**2) + assert limit(e1/e2, x, 0) == -S(1)/8 + + +def test_issue_26513(): + assert limit(abs((-x/(x+1))**x), x ,oo) == exp(-1) + assert limit((x/(x + 1))**x, x, oo) == exp(-1) + raises (NotImplementedError, lambda: limit((-x/(x+1))**x, x, oo)) + + +def test_issue_26916(): + assert limit(Ei(x)*exp(-x), x, +oo) == 0 + assert limit(Ei(x)*exp(-x), x, -oo) == 0 + + +def test_issue_22982_15323(): + assert limit((log(E + 1/x) - 1)**(1 - sqrt(E + 1/x)), x, oo) == oo + assert limit((1 - 1/x)**x*(log(1 - 1/x) + 1/(x*(1 - 1/x))), x, 1, dir='+') == 1 + assert limit((log(E + 1/x) )**(1 - sqrt(E + 1/x)), x, oo) == 1 + assert limit((log(E + 1/x) - 1)**(- sqrt(E + 1/x)), x, oo) == oo + + +def test_issue_26991(): + assert limit(x/((x - 6)*sinh(tanh(0.03*x)) + tanh(x) - 0.5), x, oo) == 1/sinh(1) + +def test_issue_27278(): + expr = (1/(x*log((x + 3)/x)))**x*((x + 1)*log((x + 4)/(x + 1)))**(x + 1)/3 + assert limit(expr, x, oo) == 1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_limitseq.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_limitseq.py new file mode 100644 index 0000000000000000000000000000000000000000..362bb0397feb0ec63929920855c81279eca0bd6a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_limitseq.py @@ -0,0 +1,177 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import (binomial, factorial, subfactorial) +from sympy.functions.combinatorial.numbers import (fibonacci, harmonic) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.series.limitseq import limit_seq +from sympy.series.limitseq import difference_delta as dd +from sympy.testing.pytest import raises, XFAIL +from sympy.calculus.accumulationbounds import AccumulationBounds + +n, m, k = symbols('n m k', integer=True) + + +def test_difference_delta(): + e = n*(n + 1) + e2 = e * k + + assert dd(e) == 2*n + 2 + assert dd(e2, n, 2) == k*(4*n + 6) + + raises(ValueError, lambda: dd(e2)) + raises(ValueError, lambda: dd(e2, n, oo)) + + +def test_difference_delta__Sum(): + e = Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1/(n + 1) + assert dd(e, n, 5) == Add(*[1/(i + n + 1) for i in range(5)]) + + e = Sum(1/k, (k, 1, 3*n)) + assert dd(e, n) == Add(*[1/(i + 3*n + 1) for i in range(3)]) + + e = n * Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1 + Sum(1/k, (k, 1, n)) + + e = Sum(1/k, (k, 1, n), (m, 1, n)) + assert dd(e, n) == harmonic(n) + + +def test_difference_delta__Add(): + e = n + n*(n + 1) + assert dd(e, n) == 2*n + 3 + assert dd(e, n, 2) == 4*n + 8 + + e = n + Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1 + 1/(n + 1) + assert dd(e, n, 5) == 5 + Add(*[1/(i + n + 1) for i in range(5)]) + + +def test_difference_delta__Pow(): + e = 4**n + assert dd(e, n) == 3*4**n + assert dd(e, n, 2) == 15*4**n + + e = 4**(2*n) + assert dd(e, n) == 15*4**(2*n) + assert dd(e, n, 2) == 255*4**(2*n) + + e = n**4 + assert dd(e, n) == (n + 1)**4 - n**4 + + e = n**n + assert dd(e, n) == (n + 1)**(n + 1) - n**n + + +def test_limit_seq(): + e = binomial(2*n, n) / Sum(binomial(2*k, k), (k, 1, n)) + assert limit_seq(e) == S(3) / 4 + assert limit_seq(e, m) == e + + e = (5*n**3 + 3*n**2 + 4) / (3*n**3 + 4*n - 5) + assert limit_seq(e, n) == S(5) / 3 + + e = (harmonic(n) * Sum(harmonic(k), (k, 1, n))) / (n * harmonic(2*n)**2) + assert limit_seq(e, n) == 1 + + e = Sum(k**2 * Sum(2**m/m, (m, 1, k)), (k, 1, n)) / (2**n*n) + assert limit_seq(e, n) == 4 + + e = (Sum(binomial(3*k, k) * binomial(5*k, k), (k, 1, n)) / + (binomial(3*n, n) * binomial(5*n, n))) + assert limit_seq(e, n) == S(84375) / 83351 + + e = Sum(harmonic(k)**2/k, (k, 1, 2*n)) / harmonic(n)**3 + assert limit_seq(e, n) == S.One / 3 + + raises(ValueError, lambda: limit_seq(e * m)) + + +def test_alternating_sign(): + assert limit_seq((-1)**n/n**2, n) == 0 + assert limit_seq((-2)**(n+1)/(n + 3**n), n) == 0 + assert limit_seq((2*n + (-1)**n)/(n + 1), n) == 2 + assert limit_seq(sin(pi*n), n) == 0 + assert limit_seq(cos(2*pi*n), n) == 1 + assert limit_seq((S.NegativeOne/5)**n, n) == 0 + assert limit_seq((Rational(-1, 5))**n, n) == 0 + assert limit_seq((I/3)**n, n) == 0 + assert limit_seq(sqrt(n)*(I/2)**n, n) == 0 + assert limit_seq(n**7*(I/3)**n, n) == 0 + assert limit_seq(n/(n + 1) + (I/2)**n, n) == 1 + + +def test_accum_bounds(): + assert limit_seq((-1)**n, n) == AccumulationBounds(-1, 1) + assert limit_seq(cos(pi*n), n) == AccumulationBounds(-1, 1) + assert limit_seq(sin(pi*n/2)**2, n) == AccumulationBounds(0, 1) + assert limit_seq(2*(-3)**n/(n + 3**n), n) == AccumulationBounds(-2, 2) + assert limit_seq(3*n/(n + 1) + 2*(-1)**n, n) == AccumulationBounds(1, 5) + + +def test_limitseq_sum(): + from sympy.abc import x, y, z + assert limit_seq(Sum(1/x, (x, 1, y)) - log(y), y) == S.EulerGamma + assert limit_seq(Sum(1/x, (x, 1, y)) - 1/y, y) is S.Infinity + assert (limit_seq(binomial(2*x, x) / Sum(binomial(2*y, y), (y, 1, x)), x) == + S(3) / 4) + assert (limit_seq(Sum(y**2 * Sum(2**z/z, (z, 1, y)), (y, 1, x)) / + (2**x*x), x) == 4) + + +def test_issue_9308(): + assert limit_seq(subfactorial(n)/factorial(n), n) == exp(-1) + + +def test_issue_10382(): + n = Symbol('n', integer=True) + assert limit_seq(fibonacci(n+1)/fibonacci(n), n).together() == S.GoldenRatio + + +def test_issue_11672(): + assert limit_seq(Rational(-1, 2)**n, n) == 0 + + +def test_issue_14196(): + k, n = symbols('k, n', positive=True) + m = Symbol('m') + assert limit_seq(Sum(m**k, (m, 1, n)).doit()/(n**(k + 1)), n) == 1/(k + 1) + + +def test_issue_16735(): + assert limit_seq(5**n/factorial(n), n) == 0 + + +def test_issue_19868(): + assert limit_seq(1/gamma(n + S.One/2), n) == 0 + + +@XFAIL +def test_limit_seq_fail(): + # improve Summation algorithm or add ad-hoc criteria + e = (harmonic(n)**3 * Sum(1/harmonic(k), (k, 1, n)) / + (n * Sum(harmonic(k)/k, (k, 1, n)))) + assert limit_seq(e, n) == 2 + + # No unique dominant term + e = (Sum(2**k * binomial(2*k, k) / k**2, (k, 1, n)) / + (Sum(2**k/k*2, (k, 1, n)) * Sum(binomial(2*k, k), (k, 1, n)))) + assert limit_seq(e, n) == S(3) / 7 + + # Simplifications of summations needs to be improved. + e = n**3*Sum(2**k/k**2, (k, 1, n))**2 / (2**n * Sum(2**k/k, (k, 1, n))) + assert limit_seq(e, n) == 2 + + e = (harmonic(n) * Sum(2**k/k, (k, 1, n)) / + (n * Sum(2**k*harmonic(k)/k**2, (k, 1, n)))) + assert limit_seq(e, n) == 1 + + e = (Sum(2**k*factorial(k) / k**2, (k, 1, 2*n)) / + (Sum(4**k/k**2, (k, 1, n)) * Sum(factorial(k), (k, 1, 2*n)))) + assert limit_seq(e, n) == S(3) / 16 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_lseries.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_lseries.py new file mode 100644 index 0000000000000000000000000000000000000000..42d327bf60c76eebdc4570d631efef4bc84b58e3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_lseries.py @@ -0,0 +1,65 @@ +from sympy.core.numbers import E +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.series.order import Order +from sympy.abc import x, y + + +def test_sin(): + e = sin(x).lseries(x) + assert next(e) == x + assert next(e) == -x**3/6 + assert next(e) == x**5/120 + + +def test_cos(): + e = cos(x).lseries(x) + assert next(e) == 1 + assert next(e) == -x**2/2 + assert next(e) == x**4/24 + + +def test_exp(): + e = exp(x).lseries(x) + assert next(e) == 1 + assert next(e) == x + assert next(e) == x**2/2 + assert next(e) == x**3/6 + + +def test_exp2(): + e = exp(cos(x)).lseries(x) + assert next(e) == E + assert next(e) == -E*x**2/2 + assert next(e) == E*x**4/6 + assert next(e) == -31*E*x**6/720 + + +def test_simple(): + assert list(x.lseries()) == [x] + assert list(S.One.lseries(x)) == [1] + assert not next((x/(x + y)).lseries(y)).has(Order) + + +def test_issue_5183(): + s = (x + 1/x).lseries() + assert list(s) == [1/x, x] + assert next((x + x**2).lseries()) == x + assert next(((1 + x)**7).lseries(x)) == 1 + assert next((sin(x + y)).series(x, n=3).lseries(y)) == x + # it would be nice if all terms were grouped, but in the + # following case that would mean that all the terms would have + # to be known since, for example, every term has a constant in it. + s = ((1 + x)**7).series(x, 1, n=None) + assert [next(s) for i in range(2)] == [128, -448 + 448*x] + + +def test_issue_6999(): + s = tanh(x).lseries(x, 1) + assert next(s) == tanh(1) + assert next(s) == x - (x - 1)*tanh(1)**2 - 1 + assert next(s) == -(x - 1)**2*tanh(1) + (x - 1)**2*tanh(1)**3 + assert next(s) == -(x - 1)**3*tanh(1)**4 - (x - 1)**3/3 + \ + 4*(x - 1)**3*tanh(1)**2/3 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_nseries.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_nseries.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f20add82d3e858e2ce145fc9fcd4a6548a48cc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_nseries.py @@ -0,0 +1,557 @@ +from sympy.calculus.util import AccumBounds +from sympy.core.function import (Derivative, PoleError) +from sympy.core.numbers import (E, I, Integer, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (acosh, acoth, asinh, atanh, cosh, coth, sinh, tanh) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (cbrt, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, cot, sin, tan) +from sympy.series.limits import limit +from sympy.series.order import O +from sympy.abc import x, y, z + +from sympy.testing.pytest import raises, XFAIL + + +def test_simple_1(): + assert x.nseries(x, n=5) == x + assert y.nseries(x, n=5) == y + assert (1/(x*y)).nseries(y, n=5) == 1/(x*y) + assert Rational(3, 4).nseries(x, n=5) == Rational(3, 4) + assert x.nseries() == x + + +def test_mul_0(): + assert (x*log(x)).nseries(x, n=5) == x*log(x) + + +def test_mul_1(): + assert (x*log(2 + x)).nseries(x, n=5) == x*log(2) + x**2/2 - x**3/8 + \ + x**4/24 + O(x**5) + assert (x*log(1 + x)).nseries( + x, n=5) == x**2 - x**3/2 + x**4/3 + O(x**5) + + +def test_pow_0(): + assert (x**2).nseries(x, n=5) == x**2 + assert (1/x).nseries(x, n=5) == 1/x + assert (1/x**2).nseries(x, n=5) == 1/x**2 + assert (x**Rational(2, 3)).nseries(x, n=5) == (x**Rational(2, 3)) + assert (sqrt(x)**3).nseries(x, n=5) == (sqrt(x)**3) + + +def test_pow_1(): + assert ((1 + x)**2).nseries(x, n=5) == x**2 + 2*x + 1 + + # https://github.com/sympy/sympy/issues/21075 + assert ((sqrt(x) + 1)**2).nseries(x) == 2*sqrt(x) + x + 1 + assert ((sqrt(x) + cbrt(x))**2).nseries(x) == 2*x**Rational(5, 6)\ + + x**Rational(2, 3) + x + + +def test_geometric_1(): + assert (1/(1 - x)).nseries(x, n=5) == 1 + x + x**2 + x**3 + x**4 + O(x**5) + assert (x/(1 - x)).nseries(x, n=6) == x + x**2 + x**3 + x**4 + x**5 + O(x**6) + assert (x**3/(1 - x)).nseries(x, n=8) == x**3 + x**4 + x**5 + x**6 + \ + x**7 + O(x**8) + + +def test_sqrt_1(): + assert sqrt(1 + x).nseries(x, n=5) == 1 + x/2 - x**2/8 + x**3/16 - 5*x**4/128 + O(x**5) + + +def test_exp_1(): + assert exp(x).nseries(x, n=5) == 1 + x + x**2/2 + x**3/6 + x**4/24 + O(x**5) + assert exp(x).nseries(x, n=12) == 1 + x + x**2/2 + x**3/6 + x**4/24 + x**5/120 + \ + x**6/720 + x**7/5040 + x**8/40320 + x**9/362880 + x**10/3628800 + \ + x**11/39916800 + O(x**12) + assert exp(1/x).nseries(x, n=5) == exp(1/x) + assert exp(1/(1 + x)).nseries(x, n=4) == \ + (E*(1 - x - 13*x**3/6 + 3*x**2/2)).expand() + O(x**4) + assert exp(2 + x).nseries(x, n=5) == \ + (exp(2)*(1 + x + x**2/2 + x**3/6 + x**4/24)).expand() + O(x**5) + + +def test_exp_sqrt_1(): + assert exp(1 + sqrt(x)).nseries(x, n=3) == \ + (exp(1)*(1 + sqrt(x) + x/2 + sqrt(x)*x/6)).expand() + O(sqrt(x)**3) + + +def test_power_x_x1(): + assert (exp(x*log(x))).nseries(x, n=4) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + x**3*log(x)**3/6 + O(x**4*log(x)**4) + + +def test_power_x_x2(): + assert (x**x).nseries(x, n=4) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + x**3*log(x)**3/6 + O(x**4*log(x)**4) + + +def test_log_singular1(): + assert log(1 + 1/x).nseries(x, n=5) == x - log(x) - x**2/2 + x**3/3 - \ + x**4/4 + O(x**5) + + +def test_log_power1(): + e = 1 / (1/x + x ** (log(3)/log(2))) + assert e.nseries(x, n=5) == -x**(log(3)/log(2) + 2) + x + O(x**5) + + +def test_log_series(): + l = Symbol('l') + e = 1/(1 - log(x)) + assert e.nseries(x, n=5, logx=l) == 1/(1 - l) + + +def test_log2(): + e = log(-1/x) + assert e.nseries(x, n=5) == -log(x) + log(-1) + + +def test_log3(): + l = Symbol('l') + e = 1/log(-1/x) + assert e.nseries(x, n=4, logx=l) == 1/(-l + log(-1)) + + +def test_series1(): + e = sin(x) + assert e.nseries(x, 0, 0) != 0 + assert e.nseries(x, 0, 0) == O(1, x) + assert e.nseries(x, 0, 1) == O(x, x) + assert e.nseries(x, 0, 2) == x + O(x**2, x) + assert e.nseries(x, 0, 3) == x + O(x**3, x) + assert e.nseries(x, 0, 4) == x - x**3/6 + O(x**4, x) + + e = (exp(x) - 1)/x + assert e.nseries(x, 0, 3) == 1 + x/2 + x**2/6 + O(x**3) + + assert x.nseries(x, 0, 2) == x + + +@XFAIL +def test_series1_failing(): + assert x.nseries(x, 0, 0) == O(1, x) + assert x.nseries(x, 0, 1) == O(x, x) + + +def test_seriesbug1(): + assert (1/x).nseries(x, 0, 3) == 1/x + assert (x + 1/x).nseries(x, 0, 3) == x + 1/x + + +def test_series2x(): + assert ((x + 1)**(-2)).nseries(x, 0, 4) == 1 - 2*x + 3*x**2 - 4*x**3 + O(x**4, x) + assert ((x + 1)**(-1)).nseries(x, 0, 4) == 1 - x + x**2 - x**3 + O(x**4, x) + assert ((x + 1)**0).nseries(x, 0, 3) == 1 + assert ((x + 1)**1).nseries(x, 0, 3) == 1 + x + assert ((x + 1)**2).nseries(x, 0, 3) == x**2 + 2*x + 1 + assert ((x + 1)**3).nseries(x, 0, 3) == 1 + 3*x + 3*x**2 + O(x**3) + + assert (1/(1 + x)).nseries(x, 0, 4) == 1 - x + x**2 - x**3 + O(x**4, x) + assert (x + 3/(1 + 2*x)).nseries(x, 0, 4) == 3 - 5*x + 12*x**2 - 24*x**3 + O(x**4, x) + + assert ((1/x + 1)**3).nseries(x, 0, 3) == 1 + 3/x + 3/x**2 + x**(-3) + assert (1/(1 + 1/x)).nseries(x, 0, 4) == x - x**2 + x**3 - O(x**4, x) + assert (1/(1 + 1/x**2)).nseries(x, 0, 6) == x**2 - x**4 + O(x**6, x) + + +def test_bug2(): # 1/log(0)*log(0) problem + w = Symbol("w") + e = (w**(-1) + w**( + -log(3)*log(2)**(-1)))**(-1)*(3*w**(-log(3)*log(2)**(-1)) + 2*w**(-1)) + e = e.expand() + assert e.nseries(w, 0, 4).subs(w, 0) == 3 + + +def test_exp(): + e = (1 + x)**(1/x) + assert e.nseries(x, n=3) == exp(1) - x*exp(1)/2 + 11*exp(1)*x**2/24 + O(x**3) + + +def test_exp2(): + w = Symbol("w") + e = w**(1 - log(x)/(log(2) + log(x))) + logw = Symbol("logw") + assert e.nseries( + w, 0, 1, logx=logw) == exp(logw*log(2)/(log(x) + log(2))) + + +def test_bug3(): + e = (2/x + 3/x**2)/(1/x + 1/x**2) + assert e.nseries(x, n=3) == 3 - x + x**2 + O(x**3) + + +def test_generalexponent(): + p = 2 + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 3) == 3 - x + x**2 + O(x**3) + p = S.Half + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 2) == 2 - x + sqrt(x) + x**(S(3)/2) + O(x**2) + + e = 1 + sqrt(x) + assert e.nseries(x, 0, 4) == 1 + sqrt(x) + +# more complicated example + + +def test_genexp_x(): + e = 1/(1 + sqrt(x)) + assert e.nseries(x, 0, 2) == \ + 1 + x - sqrt(x) - sqrt(x)**3 + O(x**2, x) + +# more complicated example + + +def test_genexp_x2(): + p = Rational(3, 2) + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 3) == 3 + x + x**2 - sqrt(x) - x**(S(3)/2) - x**(S(5)/2) + O(x**3) + + +def test_seriesbug2(): + w = Symbol("w") + #simple case (1): + e = ((2*w)/w)**(1 + w) + assert e.nseries(w, 0, 1) == 2 + O(w, w) + assert e.nseries(w, 0, 1).subs(w, 0) == 2 + + +def test_seriesbug2b(): + w = Symbol("w") + #test sin + e = sin(2*w)/w + assert e.nseries(w, 0, 3) == 2 - 4*w**2/3 + O(w**3) + + +def test_seriesbug2d(): + w = Symbol("w", real=True) + e = log(sin(2*w)/w) + assert e.series(w, n=5) == log(2) - 2*w**2/3 - 4*w**4/45 + O(w**5) + + +def test_seriesbug2c(): + w = Symbol("w", real=True) + #more complicated case, but sin(x)~x, so the result is the same as in (1) + e = (sin(2*w)/w)**(1 + w) + assert e.series(w, 0, 1) == 2 + O(w) + assert e.series(w, 0, 3) == 2 + 2*w*log(2) + \ + w**2*(Rational(-4, 3) + log(2)**2) + O(w**3) + assert e.series(w, 0, 2).subs(w, 0) == 2 + + +def test_expbug4(): + x = Symbol("x", real=True) + assert (log( + sin(2*x)/x)*(1 + x)).series(x, 0, 2) == log(2) + x*log(2) + O(x**2, x) + assert exp( + log(sin(2*x)/x)*(1 + x)).series(x, 0, 2) == 2 + 2*x*log(2) + O(x**2) + + assert exp(log(2) + O(x)).nseries(x, 0, 2) == 2 + O(x) + assert ((2 + O(x))**(1 + x)).nseries(x, 0, 2) == 2 + O(x) + + +def test_logbug4(): + assert log(2 + O(x)).nseries(x, 0, 2) == log(2) + O(x, x) + + +def test_expbug5(): + assert exp(log(1 + x)/x).nseries(x, n=3) == exp(1) + -exp(1)*x/2 + 11*exp(1)*x**2/24 + O(x**3) + + assert exp(O(x)).nseries(x, 0, 2) == 1 + O(x) + + +def test_sinsinbug(): + assert sin(sin(x)).nseries(x, 0, 8) == x - x**3/3 + x**5/10 - 8*x**7/315 + O(x**8) + + +def test_issue_3258(): + a = x/(exp(x) - 1) + assert a.nseries(x, 0, 5) == 1 - x/2 - x**4/720 + x**2/12 + O(x**5) + + +def test_issue_3204(): + x = Symbol("x", nonnegative=True) + f = sin(x**3)**Rational(1, 3) + assert f.nseries(x, 0, 17) == x - x**7/18 - x**13/3240 + O(x**17) + + +def test_issue_3224(): + f = sqrt(1 - sqrt(y)) + assert f.nseries(y, 0, 2) == 1 - sqrt(y)/2 - y/8 - sqrt(y)**3/16 + O(y**2) + + +def test_issue_3463(): + w, i = symbols('w,i') + r = log(5)/log(3) + p = w**(-1 + r) + e = 1/x*(-log(w**(1 + r)) + log(w + w**r)) + e_ser = -r*log(w)/x + p/x - p**2/(2*x) + O(w) + assert e.nseries(w, n=1) == e_ser + + +def test_sin(): + assert sin(8*x).nseries(x, n=4) == 8*x - 256*x**3/3 + O(x**4) + assert sin(x + y).nseries(x, n=1) == sin(y) + O(x) + assert sin(x + y).nseries(x, n=2) == sin(y) + cos(y)*x + O(x**2) + assert sin(x + y).nseries(x, n=5) == sin(y) + cos(y)*x - sin(y)*x**2/2 - \ + cos(y)*x**3/6 + sin(y)*x**4/24 + O(x**5) + + +def test_issue_3515(): + e = sin(8*x)/x + assert e.nseries(x, n=6) == 8 - 256*x**2/3 + 4096*x**4/15 + O(x**6) + + +def test_issue_3505(): + e = sin(x)**(-4)*(sqrt(cos(x))*sin(x)**2 - + cos(x)**Rational(1, 3)*sin(x)**2) + assert e.nseries(x, n=9) == Rational(-1, 12) - 7*x**2/288 - \ + 43*x**4/10368 - 1123*x**6/2488320 + 377*x**8/29859840 + O(x**9) + + +def test_issue_3501(): + a = Symbol("a") + e = x**(-2)*(x*sin(a + x) - x*sin(a)) + assert e.nseries(x, n=6) == cos(a) - sin(a)*x/2 - cos(a)*x**2/6 + \ + x**3*sin(a)/24 + x**4*cos(a)/120 - x**5*sin(a)/720 + O(x**6) + e = x**(-2)*(x*cos(a + x) - x*cos(a)) + assert e.nseries(x, n=6) == -sin(a) - cos(a)*x/2 + sin(a)*x**2/6 + \ + cos(a)*x**3/24 - x**4*sin(a)/120 - x**5*cos(a)/720 + O(x**6) + + +def test_issue_3502(): + e = sin(5*x)/sin(2*x) + assert e.nseries(x, n=2) == Rational(5, 2) + O(x**2) + assert e.nseries(x, n=6) == \ + Rational(5, 2) - 35*x**2/4 + 329*x**4/48 + O(x**6) + + +def test_issue_3503(): + e = sin(2 + x)/(2 + x) + assert e.nseries(x, n=2) == sin(2)/2 + x*cos(2)/2 - x*sin(2)/4 + O(x**2) + + +def test_issue_3506(): + e = (x + sin(3*x))**(-2)*(x*(x + sin(3*x)) - (x + sin(3*x))*sin(2*x)) + assert e.nseries(x, n=7) == \ + Rational(-1, 4) + 5*x**2/96 + 91*x**4/768 + 11117*x**6/129024 + O(x**7) + + +def test_issue_3508(): + x = Symbol("x", real=True) + assert log(sin(x)).series(x, n=5) == log(x) - x**2/6 - x**4/180 + O(x**5) + e = -log(x) + x*(-log(x) + log(sin(2*x))) + log(sin(2*x)) + assert e.series(x, n=5) == \ + log(2) + log(2)*x - 2*x**2/3 - 2*x**3/3 - 4*x**4/45 + O(x**5) + + +def test_issue_3507(): + e = x**(-4)*(x**2 - x**2*sqrt(cos(x))) + assert e.nseries(x, n=9) == \ + Rational(1, 4) + x**2/96 + 19*x**4/5760 + 559*x**6/645120 + 29161*x**8/116121600 + O(x**9) + + +def test_issue_3639(): + assert sin(cos(x)).nseries(x, n=5) == \ + sin(1) - x**2*cos(1)/2 - x**4*sin(1)/8 + x**4*cos(1)/24 + O(x**5) + + +def test_hyperbolic(): + assert sinh(x).nseries(x, n=6) == x + x**3/6 + x**5/120 + O(x**6) + assert cosh(x).nseries(x, n=5) == 1 + x**2/2 + x**4/24 + O(x**5) + assert tanh(x).nseries(x, n=6) == x - x**3/3 + 2*x**5/15 + O(x**6) + assert coth(x).nseries(x, n=6) == \ + 1/x - x**3/45 + x/3 + 2*x**5/945 + O(x**6) + assert asinh(x).nseries(x, n=6) == x - x**3/6 + 3*x**5/40 + O(x**6) + assert acosh(x).nseries(x, n=6) == \ + pi*I/2 - I*x - 3*I*x**5/40 - I*x**3/6 + O(x**6) + assert atanh(x).nseries(x, n=6) == x + x**3/3 + x**5/5 + O(x**6) + assert acoth(x).nseries(x, n=6) == -I*pi/2 + x + x**3/3 + x**5/5 + O(x**6) + + +def test_series2(): + w = Symbol("w", real=True) + x = Symbol("x", real=True) + e = w**(-2)*(w*exp(1/x - w) - w*exp(1/x)) + assert e.nseries(w, n=4) == -exp(1/x) + w*exp(1/x)/2 - w**2*exp(1/x)/6 + w**3*exp(1/x)/24 + O(w**4) + + +def test_series3(): + w = Symbol("w", real=True) + e = w**(-6)*(w**3*tan(w) - w**3*sin(w)) + assert e.nseries(w, n=8) == Integer(1)/2 + w**2/8 + 13*w**4/240 + 529*w**6/24192 + O(w**8) + + +def test_bug4(): + w = Symbol("w") + e = x/(w**4 + x**2*w**4 + 2*x*w**4)*w**4 + assert e.nseries(w, n=2).removeO().expand() in [x/(1 + 2*x + x**2), + 1/(1 + x/2 + 1/x/2)/2, 1/x/(1 + 2/x + x**(-2))] + + +def test_bug5(): + w = Symbol("w") + l = Symbol('l') + e = (-log(w) + log(1 + w*log(x)))**(-2)*w**(-2)*((-log(w) + + log(1 + x*w))*(-log(w) + log(1 + w*log(x)))*w - x*(-log(w) + + log(1 + w*log(x)))*w) + assert e.nseries(w, n=0, logx=l) == x/w/l + 1/w + O(1, w) + assert e.nseries(w, n=1, logx=l) == x/w/l + 1/w - x/l + 1/l*log(x) \ + + x*log(x)/l**2 + O(w) + + +def test_issue_4115(): + assert (sin(x)/(1 - cos(x))).nseries(x, n=1) == 2/x + O(x) + assert (sin(x)**2/(1 - cos(x))).nseries(x, n=1) == 2 + O(x) + + +def test_pole(): + raises(PoleError, lambda: sin(1/x).series(x, 0, 5)) + raises(PoleError, lambda: sin(1 + 1/x).series(x, 0, 5)) + raises(PoleError, lambda: (x*sin(1/x)).series(x, 0, 5)) + + +def test_expsinbug(): + assert exp(sin(x)).series(x, 0, 0) == O(1, x) + assert exp(sin(x)).series(x, 0, 1) == 1 + O(x) + assert exp(sin(x)).series(x, 0, 2) == 1 + x + O(x**2) + assert exp(sin(x)).series(x, 0, 3) == 1 + x + x**2/2 + O(x**3) + assert exp(sin(x)).series(x, 0, 4) == 1 + x + x**2/2 + O(x**4) + assert exp(sin(x)).series(x, 0, 5) == 1 + x + x**2/2 - x**4/8 + O(x**5) + + +def test_floor(): + x = Symbol('x') + assert floor(x).series(x) == 0 + assert floor(-x).series(x) == -1 + assert floor(sin(x)).series(x) == 0 + assert floor(sin(-x)).series(x) == -1 + assert floor(x**3).series(x) == 0 + assert floor(-x**3).series(x) == -1 + assert floor(cos(x)).series(x) == 0 + assert floor(cos(-x)).series(x) == 0 + assert floor(5 + sin(x)).series(x) == 5 + assert floor(5 + sin(-x)).series(x) == 4 + + assert floor(x).series(x, 2) == 2 + assert floor(-x).series(x, 2) == -3 + + x = Symbol('x', negative=True) + assert floor(x + 1.5).series(x) == 1 + + +def test_frac(): + assert frac(x).series(x, cdir=1) == x + assert frac(x).series(x, cdir=-1) == 1 + x + assert frac(2*x + 1).series(x, cdir=1) == 2*x + assert frac(2*x + 1).series(x, cdir=-1) == 1 + 2*x + assert frac(x**2).series(x, cdir=1) == x**2 + assert frac(x**2).series(x, cdir=-1) == x**2 + assert frac(sin(x) + 5).series(x, cdir=1) == x - x**3/6 + x**5/120 + O(x**6) + assert frac(sin(x) + 5).series(x, cdir=-1) == 1 + x - x**3/6 + x**5/120 + O(x**6) + assert frac(sin(x) + S.Half).series(x) == S.Half + x - x**3/6 + x**5/120 + O(x**6) + assert frac(x**8).series(x, cdir=1) == O(x**6) + assert frac(1/x).series(x) == AccumBounds(0, 1) + O(x**6) + + +def test_ceiling(): + assert ceiling(x).series(x) == 1 + assert ceiling(-x).series(x) == 0 + assert ceiling(sin(x)).series(x) == 1 + assert ceiling(sin(-x)).series(x) == 0 + assert ceiling(1 - cos(x)).series(x) == 1 + assert ceiling(1 - cos(-x)).series(x) == 1 + assert ceiling(x).series(x, 2) == 3 + assert ceiling(-x).series(x, 2) == -2 + + +def test_abs(): + a = Symbol('a') + assert abs(x).nseries(x, n=4) == x + assert abs(-x).nseries(x, n=4) == x + assert abs(x + 1).nseries(x, n=4) == x + 1 + assert abs(sin(x)).nseries(x, n=4) == x - Rational(1, 6)*x**3 + O(x**4) + assert abs(sin(-x)).nseries(x, n=4) == x - Rational(1, 6)*x**3 + O(x**4) + assert abs(x - a).nseries(x, 1) == -a*sign(1 - a) + (x - 1)*sign(1 - a) + sign(1 - a) + + +def test_dir(): + assert abs(x).series(x, 0, dir="+") == x + assert abs(x).series(x, 0, dir="-") == -x + assert floor(x + 2).series(x, 0, dir='+') == 2 + assert floor(x + 2).series(x, 0, dir='-') == 1 + assert floor(x + 2.2).series(x, 0, dir='-') == 2 + assert ceiling(x + 2.2).series(x, 0, dir='-') == 3 + assert sin(x + y).series(x, 0, dir='-') == sin(x + y).series(x, 0, dir='+') + + +def test_cdir(): + assert abs(x).series(x, 0, cdir=1) == x + assert abs(x).series(x, 0, cdir=-1) == -x + assert floor(x + 2).series(x, 0, cdir=1) == 2 + assert floor(x + 2).series(x, 0, cdir=-1) == 1 + assert floor(x + 2.2).series(x, 0, cdir=1) == 2 + assert ceiling(x + 2.2).series(x, 0, cdir=-1) == 3 + assert sin(x + y).series(x, 0, cdir=-1) == sin(x + y).series(x, 0, cdir=1) + + +def test_issue_3504(): + a = Symbol("a") + e = asin(a*x)/x + assert e.series(x, 4, n=2).removeO() == \ + (x - 4)*(a/(4*sqrt(-16*a**2 + 1)) - asin(4*a)/16) + asin(4*a)/4 + + +def test_issue_4441(): + a, b = symbols('a,b') + f = 1/(1 + a*x) + assert f.series(x, 0, 5) == 1 - a*x + a**2*x**2 - a**3*x**3 + \ + a**4*x**4 + O(x**5) + f = 1/(1 + (a + b)*x) + assert f.series(x, 0, 3) == 1 + x*(-a - b)\ + + x**2*(a + b)**2 + O(x**3) + + +def test_issue_4329(): + assert tan(x).series(x, pi/2, n=3).removeO() == \ + -pi/6 + x/3 - 1/(x - pi/2) + assert cot(x).series(x, pi, n=3).removeO() == \ + -x/3 + pi/3 + 1/(x - pi) + assert limit(tan(x)**tan(2*x), x, pi/4) == exp(-1) + + +def test_issue_5183(): + assert abs(x + x**2).series(n=1) == O(x) + assert abs(x + x**2).series(n=2) == x + O(x**2) + assert ((1 + x)**2).series(x, n=6) == x**2 + 2*x + 1 + assert (1 + 1/x).series() == 1 + 1/x + assert Derivative(exp(x).series(), x).doit() == \ + 1 + x + x**2/2 + x**3/6 + x**4/24 + O(x**5) + + +def test_issue_5654(): + a = Symbol('a') + assert (1/(x**2+a**2)**2).nseries(x, x0=I*a, n=0) == \ + -I/(4*a**3*(-I*a + x)) - 1/(4*a**2*(-I*a + x)**2) + O(1, (x, I*a)) + assert (1/(x**2+a**2)**2).nseries(x, x0=I*a, n=1) == 3/(16*a**4) \ + -I/(4*a**3*(-I*a + x)) - 1/(4*a**2*(-I*a + x)**2) + O(-I*a + x, (x, I*a)) + + +def test_issue_5925(): + sx = sqrt(x + z).series(z, 0, 1) + sxy = sqrt(x + y + z).series(z, 0, 1) + s1, s2 = sx.subs(x, x + y), sxy + assert (s1 - s2).expand().removeO().simplify() == 0 + + sx = sqrt(x + z).series(z, 0, 1) + sxy = sqrt(x + y + z).series(z, 0, 1) + assert sxy.subs({x:1, y:2}) == sx.subs(x, 3) + + +def test_exp_2(): + assert exp(x**3).nseries(x, 0, 14) == 1 + x**3 + x**6/2 + x**9/6 + x**12/24 + O(x**14) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_order.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_order.py new file mode 100644 index 0000000000000000000000000000000000000000..50fcb861ee2a76c730baae6d26cc1e7a00347176 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_order.py @@ -0,0 +1,503 @@ +from sympy.core.add import Add +from sympy.core.function import (Function, expand) +from sympy.core.numbers import (I, Rational, nan, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import (conjugate, transpose) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.series.order import O, Order +from sympy.core.expr import unchanged +from sympy.testing.pytest import raises +from sympy.abc import w, x, y, z +from sympy.testing.pytest import XFAIL + + +def test_caching_bug(): + #needs to be a first test, so that all caches are clean + #cache it + O(w) + #and test that this won't raise an exception + O(w**(-1/x/log(3)*log(5)), w) + + +def test_free_symbols(): + assert Order(1).free_symbols == set() + assert Order(x).free_symbols == {x} + assert Order(1, x).free_symbols == {x} + assert Order(x*y).free_symbols == {x, y} + assert Order(x, x, y).free_symbols == {x, y} + + +def test_simple_1(): + o = Rational(0) + assert Order(2*x) == Order(x) + assert Order(x)*3 == Order(x) + assert -28*Order(x) == Order(x) + assert Order(Order(x)) == Order(x) + assert Order(Order(x), y) == Order(Order(x), x, y) + assert Order(-23) == Order(1) + assert Order(exp(x)) == Order(1, x) + assert Order(exp(1/x)).expr == exp(1/x) + assert Order(x*exp(1/x)).expr == x*exp(1/x) + assert Order(x**(o/3)).expr == x**(o/3) + assert Order(x**(o*Rational(5, 3))).expr == x**(o*Rational(5, 3)) + assert Order(x**2 + x + y, x) == O(1, x) + assert Order(x**2 + x + y, y) == O(1, y) + raises(ValueError, lambda: Order(exp(x), x, x)) + raises(TypeError, lambda: Order(x, 2 - x)) + + +def test_simple_2(): + assert Order(2*x)*x == Order(x**2) + assert Order(2*x)/x == Order(1, x) + assert Order(2*x)*x*exp(1/x) == Order(x**2*exp(1/x)) + assert (Order(2*x)*x*exp(1/x)/log(x)**3).expr == x**2*exp(1/x)*log(x)**-3 + + +def test_simple_3(): + assert Order(x) + x == Order(x) + assert Order(x) + 2 == 2 + Order(x) + assert Order(x) + x**2 == Order(x) + assert Order(x) + 1/x == 1/x + Order(x) + assert Order(1/x) + 1/x**2 == 1/x**2 + Order(1/x) + assert Order(x) + exp(1/x) == Order(x) + exp(1/x) + + +def test_simple_4(): + assert Order(x)**2 == Order(x**2) + + +def test_simple_5(): + assert Order(x) + Order(x**2) == Order(x) + assert Order(x) + Order(x**-2) == Order(x**-2) + assert Order(x) + Order(1/x) == Order(1/x) + + +def test_simple_6(): + assert Order(x) - Order(x) == Order(x) + assert Order(x) + Order(1) == Order(1) + assert Order(x) + Order(x**2) == Order(x) + assert Order(1/x) + Order(1) == Order(1/x) + assert Order(x) + Order(exp(1/x)) == Order(exp(1/x)) + assert Order(x**3) + Order(exp(2/x)) == Order(exp(2/x)) + assert Order(x**-3) + Order(exp(2/x)) == Order(exp(2/x)) + + +def test_simple_7(): + assert 1 + O(1) == O(1) + assert 2 + O(1) == O(1) + assert x + O(1) == O(1) + assert 1/x + O(1) == 1/x + O(1) + + +def test_simple_8(): + assert O(sqrt(-x)) == O(sqrt(x)) + assert O(x**2*sqrt(x)) == O(x**Rational(5, 2)) + assert O(x**3*sqrt(-(-x)**3)) == O(x**Rational(9, 2)) + assert O(x**Rational(3, 2)*sqrt((-x)**3)) == O(x**3) + assert O(x*(-2*x)**(I/2)) == O(x*(-x)**(I/2)) + + +def test_as_expr_variables(): + assert Order(x).as_expr_variables(None) == (x, ((x, 0),)) + assert Order(x).as_expr_variables(((x, 0),)) == (x, ((x, 0),)) + assert Order(y).as_expr_variables(((x, 0),)) == (y, ((x, 0), (y, 0))) + assert Order(y).as_expr_variables(((x, 0), (y, 0))) == (y, ((x, 0), (y, 0))) + + +def test_contains_0(): + assert Order(1, x).contains(Order(1, x)) + assert Order(1, x).contains(Order(1)) + assert Order(1).contains(Order(1, x)) is False + + +def test_contains_1(): + assert Order(x).contains(Order(x)) + assert Order(x).contains(Order(x**2)) + assert not Order(x**2).contains(Order(x)) + assert not Order(x).contains(Order(1/x)) + assert not Order(1/x).contains(Order(exp(1/x))) + assert not Order(x).contains(Order(exp(1/x))) + assert Order(1/x).contains(Order(x)) + assert Order(exp(1/x)).contains(Order(x)) + assert Order(exp(1/x)).contains(Order(1/x)) + assert Order(exp(1/x)).contains(Order(exp(1/x))) + assert Order(exp(2/x)).contains(Order(exp(1/x))) + assert not Order(exp(1/x)).contains(Order(exp(2/x))) + + +def test_contains_2(): + assert Order(x).contains(Order(y)) is None + assert Order(x).contains(Order(y*x)) + assert Order(y*x).contains(Order(x)) + assert Order(y).contains(Order(x*y)) + assert Order(x).contains(Order(y**2*x)) + + +def test_contains_3(): + assert Order(x*y**2).contains(Order(x**2*y)) is None + assert Order(x**2*y).contains(Order(x*y**2)) is None + + +def test_contains_4(): + assert Order(sin(1/x**2)).contains(Order(cos(1/x**2))) is True + assert Order(cos(1/x**2)).contains(Order(sin(1/x**2))) is True + + +def test_contains(): + assert Order(1, x) not in Order(1) + assert Order(1) in Order(1, x) + raises(TypeError, lambda: Order(x*y**2) in Order(x**2*y)) + + +def test_add_1(): + assert Order(x + x) == Order(x) + assert Order(3*x - 2*x**2) == Order(x) + assert Order(1 + x) == Order(1, x) + assert Order(1 + 1/x) == Order(1/x) + # TODO : A better output for Order(log(x) + 1/log(x)) + # could be Order(log(x)). Currently Order for expressions + # where all arguments would involve a log term would fall + # in this category and outputs for these should be improved. + assert Order(log(x) + 1/log(x)) == Order((log(x)**2 + 1)/log(x)) + assert Order(exp(1/x) + x) == Order(exp(1/x)) + assert Order(exp(1/x) + 1/x**20) == Order(exp(1/x)) + + +def test_ln_args(): + assert O(log(x)) + O(log(2*x)) == O(log(x)) + assert O(log(x)) + O(log(x**3)) == O(log(x)) + assert O(log(x*y)) + O(log(x) + log(y)) == O(log(x) + log(y), x, y) + + +def test_multivar_0(): + assert Order(x*y).expr == x*y + assert Order(x*y**2).expr == x*y**2 + assert Order(x*y, x).expr == x + assert Order(x*y**2, y).expr == y**2 + assert Order(x*y*z).expr == x*y*z + assert Order(x/y).expr == x/y + assert Order(x*exp(1/y)).expr == x*exp(1/y) + assert Order(exp(x)*exp(1/y)).expr == exp(x)*exp(1/y) + + +def test_multivar_0a(): + assert Order(exp(1/x)*exp(1/y)).expr == exp(1/x)*exp(1/y) + + +def test_multivar_1(): + assert Order(x + y).expr == x + y + assert Order(x + 2*y).expr == x + y + assert (Order(x + y) + x).expr == (x + y) + assert (Order(x + y) + x**2) == Order(x + y) + assert (Order(x + y) + 1/x) == 1/x + Order(x + y) + assert Order(x**2 + y*x).expr == x**2 + y*x + + +def test_multivar_2(): + assert Order(x**2*y + y**2*x, x, y).expr == x**2*y + y**2*x + + +def test_multivar_mul_1(): + assert Order(x + y)*x == Order(x**2 + y*x, x, y) + + +def test_multivar_3(): + assert (Order(x) + Order(y)).args in [ + (Order(x), Order(y)), + (Order(y), Order(x))] + assert Order(x) + Order(y) + Order(x + y) == Order(x + y) + assert (Order(x**2*y) + Order(y**2*x)).args in [ + (Order(x*y**2), Order(y*x**2)), + (Order(y*x**2), Order(x*y**2))] + assert (Order(x**2*y) + Order(y*x)) == Order(x*y) + + +def test_issue_3468(): + y = Symbol('y', negative=True) + z = Symbol('z', complex=True) + + # check that Order does not modify assumptions about symbols + Order(x) + Order(y) + Order(z) + + assert x.is_positive is None + assert y.is_positive is False + assert z.is_positive is None + + +def test_leading_order(): + assert (x + 1 + 1/x**5).extract_leading_order(x) == ((1/x**5, O(1/x**5)),) + assert (1 + 1/x).extract_leading_order(x) == ((1/x, O(1/x)),) + assert (1 + x).extract_leading_order(x) == ((1, O(1, x)),) + assert (1 + x**2).extract_leading_order(x) == ((1, O(1, x)),) + assert (2 + x**2).extract_leading_order(x) == ((2, O(1, x)),) + assert (x + x**2).extract_leading_order(x) == ((x, O(x)),) + + +def test_leading_order2(): + assert set((2 + pi + x**2).extract_leading_order(x)) == {(pi, O(1, x)), + (S(2), O(1, x))} + assert set((2*x + pi*x + x**2).extract_leading_order(x)) == {(2*x, O(x)), + (x*pi, O(x))} + + +def test_order_leadterm(): + assert O(x**2)._eval_as_leading_term(x, None, 1) == O(x**2) + + +def test_order_symbols(): + e = x*y*sin(x)*Integral(x, (x, 1, 2)) + assert O(e) == O(x**2*y, x, y) + assert O(e, x) == O(x**2) + + +def test_nan(): + assert O(nan) is nan + assert not O(x).contains(nan) + + +def test_O1(): + assert O(1, x) * x == O(x) + assert O(1, y) * x == O(1, y) + + +def test_getn(): + # other lines are tested incidentally by the suite + assert O(x).getn() == 1 + assert O(x/log(x)).getn() == 1 + assert O(x**2/log(x)**2).getn() == 2 + assert O(x*log(x)).getn() == 1 + raises(NotImplementedError, lambda: (O(x) + O(y)).getn()) + + +def test_diff(): + assert O(x**2).diff(x) == O(x) + + +def test_getO(): + assert (x).getO() is None + assert (x).removeO() == x + assert (O(x)).getO() == O(x) + assert (O(x)).removeO() == 0 + assert (z + O(x) + O(y)).getO() == O(x) + O(y) + assert (z + O(x) + O(y)).removeO() == z + raises(NotImplementedError, lambda: (O(x) + O(y)).getn()) + + +def test_leading_term(): + from sympy.functions.special.gamma_functions import digamma + assert O(1/digamma(1/x)) == O(1/log(x)) + + +def test_eval(): + assert Order(x).subs(Order(x), 1) == 1 + assert Order(x).subs(x, y) == Order(y) + assert Order(x).subs(y, x) == Order(x) + assert Order(x).subs(x, x + y) == Order(x + y, (x, -y)) + assert (O(1)**x).is_Pow + + +def test_issue_4279(): + a, b = symbols('a b') + assert O(a, a, b) + O(1, a, b) == O(1, a, b) + assert O(b, a, b) + O(1, a, b) == O(1, a, b) + assert O(a + b, a, b) + O(1, a, b) == O(1, a, b) + assert O(1, a, b) + O(a, a, b) == O(1, a, b) + assert O(1, a, b) + O(b, a, b) == O(1, a, b) + assert O(1, a, b) + O(a + b, a, b) == O(1, a, b) + + +def test_issue_4855(): + assert 1/O(1) != O(1) + assert 1/O(x) != O(1/x) + assert 1/O(x, (x, oo)) != O(1/x, (x, oo)) + + f = Function('f') + assert 1/O(f(x)) != O(1/x) + + +def test_order_conjugate_transpose(): + x = Symbol('x', real=True) + y = Symbol('y', imaginary=True) + assert conjugate(Order(x)) == Order(conjugate(x)) + assert conjugate(Order(y)) == Order(conjugate(y)) + assert conjugate(Order(x**2)) == Order(conjugate(x)**2) + assert conjugate(Order(y**2)) == Order(conjugate(y)**2) + assert transpose(Order(x)) == Order(transpose(x)) + assert transpose(Order(y)) == Order(transpose(y)) + assert transpose(Order(x**2)) == Order(transpose(x)**2) + assert transpose(Order(y**2)) == Order(transpose(y)**2) + + +def test_order_noncommutative(): + A = Symbol('A', commutative=False) + assert Order(A + A*x, x) == Order(1, x) + assert (A + A*x)*Order(x) == Order(x) + assert (A*x)*Order(x) == Order(x**2, x) + assert expand((1 + Order(x))*A*A*x) == A*A*x + Order(x**2, x) + assert expand((A*A + Order(x))*x) == A*A*x + Order(x**2, x) + assert expand((A + Order(x))*A*x) == A*A*x + Order(x**2, x) + + +def test_issue_6753(): + assert (1 + x**2)**10000*O(x) == O(x) + + +def test_order_at_infinity(): + assert Order(1 + x, (x, oo)) == Order(x, (x, oo)) + assert Order(3*x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo))*3 == Order(x, (x, oo)) + assert -28*Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(Order(x, (x, oo)), (x, oo)) == Order(x, (x, oo)) + assert Order(Order(x, (x, oo)), (y, oo)) == Order(x, (x, oo), (y, oo)) + assert Order(3, (x, oo)) == Order(1, (x, oo)) + assert Order(x**2 + x + y, (x, oo)) == O(x**2, (x, oo)) + assert Order(x**2 + x + y, (y, oo)) == O(y, (y, oo)) + + assert Order(2*x, (x, oo))*x == Order(x**2, (x, oo)) + assert Order(2*x, (x, oo))/x == Order(1, (x, oo)) + assert Order(2*x, (x, oo))*x*exp(1/x) == Order(x**2*exp(1/x), (x, oo)) + assert Order(2*x, (x, oo))*x*exp(1/x)/log(x)**3 == Order(x**2*exp(1/x)*log(x)**-3, (x, oo)) + + assert Order(x, (x, oo)) + 1/x == 1/x + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + 1 == 1 + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + x == x + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + x**2 == x**2 + Order(x, (x, oo)) + assert Order(1/x, (x, oo)) + 1/x**2 == 1/x**2 + Order(1/x, (x, oo)) == Order(1/x, (x, oo)) + assert Order(x, (x, oo)) + exp(1/x) == exp(1/x) + Order(x, (x, oo)) + + assert Order(x, (x, oo))**2 == Order(x**2, (x, oo)) + + assert Order(x, (x, oo)) + Order(x**2, (x, oo)) == Order(x**2, (x, oo)) + assert Order(x, (x, oo)) + Order(x**-2, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(1/x, (x, oo)) == Order(x, (x, oo)) + + assert Order(x, (x, oo)) - Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(1, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(x**2, (x, oo)) == Order(x**2, (x, oo)) + assert Order(1/x, (x, oo)) + Order(1, (x, oo)) == Order(1, (x, oo)) + assert Order(x, (x, oo)) + Order(exp(1/x), (x, oo)) == Order(x, (x, oo)) + assert Order(x**3, (x, oo)) + Order(exp(2/x), (x, oo)) == Order(x**3, (x, oo)) + assert Order(x**-3, (x, oo)) + Order(exp(2/x), (x, oo)) == Order(exp(2/x), (x, oo)) + + # issue 7207 + assert Order(exp(x), (x, oo)).expr == Order(2*exp(x), (x, oo)).expr == exp(x) + assert Order(y**x, (x, oo)).expr == Order(2*y**x, (x, oo)).expr == exp(x*log(y)) + + # issue 19545 + assert Order(1/x - 3/(3*x + 2), (x, oo)).expr == x**(-2) + +def test_mixing_order_at_zero_and_infinity(): + assert (Order(x, (x, 0)) + Order(x, (x, oo))).is_Add + assert Order(x, (x, 0)) + Order(x, (x, oo)) == Order(x, (x, oo)) + Order(x, (x, 0)) + assert Order(Order(x, (x, oo))) == Order(x, (x, oo)) + + # not supported (yet) + raises(NotImplementedError, lambda: Order(x, (x, 0))*Order(x, (x, oo))) + raises(NotImplementedError, lambda: Order(x, (x, oo))*Order(x, (x, 0))) + raises(NotImplementedError, lambda: Order(Order(x, (x, oo)), y)) + raises(NotImplementedError, lambda: Order(Order(x), (x, oo))) + + +def test_order_at_some_point(): + assert Order(x, (x, 1)) == Order(1, (x, 1)) + assert Order(2*x - 2, (x, 1)) == Order(x - 1, (x, 1)) + assert Order(-x + 1, (x, 1)) == Order(x - 1, (x, 1)) + assert Order(x - 1, (x, 1))**2 == Order((x - 1)**2, (x, 1)) + assert Order(x - 2, (x, 2)) - O(x - 2, (x, 2)) == Order(x - 2, (x, 2)) + + +def test_order_subs_limits(): + # issue 3333 + assert (1 + Order(x)).subs(x, 1/x) == 1 + Order(1/x, (x, oo)) + assert (1 + Order(x)).limit(x, 0) == 1 + # issue 5769 + assert ((x + Order(x**2))/x).limit(x, 0) == 1 + + assert Order(x**2).subs(x, y - 1) == Order((y - 1)**2, (y, 1)) + assert Order(10*x**2, (x, 2)).subs(x, y - 1) == Order(1, (y, 3)) + + #issue 19120 + assert O(x).subs(x, O(x)) == O(x) + assert O(x**2).subs(x, x + O(x)) == O(x**2) + assert O(x, (x, oo)).subs(x, O(x, (x, oo))) == O(x, (x, oo)) + assert O(x**2, (x, oo)).subs(x, x + O(x, (x, oo))) == O(x**2, (x, oo)) + assert (x + O(x**2)).subs(x, x + O(x**2)) == x + O(x**2) + assert (x**2 + O(x**2) + 1/x**2).subs(x, x + O(x**2)) == (x + O(x**2))**(-2) + O(x**2) + assert (x**2 + O(x**2) + 1).subs(x, x + O(x**2)) == 1 + O(x**2) + assert O(x, (x, oo)).subs(x, x + O(x**2, (x, oo))) == O(x**2, (x, oo)) + assert sin(x).series(n=8).subs(x,sin(x).series(n=8)).expand() == x - x**3/3 + x**5/10 - 8*x**7/315 + O(x**8) + assert cos(x).series(n=8).subs(x,sin(x).series(n=8)).expand() == 1 - x**2/2 + 5*x**4/24 - 37*x**6/720 + O(x**8) + assert O(x).subs(x, O(1/x, (x, oo))) == O(1/x, (x, oo)) + +@XFAIL +def test_order_failing_due_to_solveset(): + assert O(x**3).subs(x, exp(-x**2)) == O(exp(-3*x**2), (x, -oo)) + raises(NotImplementedError, lambda: O(x).subs(x, O(1/x))) # mixing of order at different points + + +def test_issue_9351(): + assert exp(x).series(x, 10, 1) == exp(10) + Order(x - 10, (x, 10)) + + +def test_issue_9192(): + assert O(1)*O(1) == O(1) + assert O(1)**O(1) == O(1) + + +def test_issue_9910(): + assert O(x*log(x) + sin(x), (x, oo)) == O(x*log(x), (x, oo)) + + +def test_performance_of_adding_order(): + l = [x**i for i in range(1000)] + l.append(O(x**1001)) + assert Add(*l).subs(x,1) == O(1) + +def test_issue_14622(): + assert (x**(-4) + x**(-3) + x**(-1) + O(x**(-6), (x, oo))).as_numer_denom() == ( + x**4 + x**5 + x**7 + O(x**2, (x, oo)), x**8) + assert (x**3 + O(x**2, (x, oo))).is_Add + assert O(x**2, (x, oo)).contains(x**3) is False + assert O(x, (x, oo)).contains(O(x, (x, 0))) is None + assert O(x, (x, 0)).contains(O(x, (x, oo))) is None + raises(NotImplementedError, lambda: O(x**3).contains(x**w)) + + +def test_issue_15539(): + assert O(1/x**2 + 1/x**4, (x, -oo)) == O(1/x**2, (x, -oo)) + assert O(1/x**4 + exp(x), (x, -oo)) == O(1/x**4, (x, -oo)) + assert O(1/x**4 + exp(-x), (x, -oo)) == O(exp(-x), (x, -oo)) + assert O(1/x, (x, oo)).subs(x, -x) == O(-1/x, (x, -oo)) + +def test_issue_18606(): + assert unchanged(Order, 0) + + +def test_issue_22165(): + assert O(log(x)).contains(2) + + +def test_issue_23231(): + # This test checks Order for expressions having + # arguments containing variables in exponents/powers. + assert O(x**x + 2**x, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(x**x + x**2, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(x**x + 1/x**2, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(2**x + 3**x , (x, oo)) == O(exp(x*log(3)), (x, oo)) + + +def test_issue_9917(): + assert O(x*sin(x) + 1, (x, oo)) == O(x, (x, oo)) + + +def test_issue_22836(): + assert O(2**x + factorial(x), (x, oo)) == O(factorial(x), (x, oo)) + assert O(2**x + factorial(x) + x**x, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(x + factorial(x), (x, oo)) == O(factorial(x), (x, oo)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_residues.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_residues.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7d075a56500d008e3c8b46c1fda5db890fd76a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_residues.py @@ -0,0 +1,101 @@ +from sympy.core.function import Function +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cot, sin, tan) +from sympy.series.residues import residue +from sympy.testing.pytest import XFAIL, raises +from sympy.abc import x, z, a, s, k + + +def test_basic1(): + assert residue(1/x, x, 0) == 1 + assert residue(-2/x, x, 0) == -2 + assert residue(81/x, x, 0) == 81 + assert residue(1/x**2, x, 0) == 0 + assert residue(0, x, 0) == 0 + assert residue(5, x, 0) == 0 + assert residue(x, x, 0) == 0 + assert residue(x**2, x, 0) == 0 + + +def test_basic2(): + assert residue(1/x, x, 1) == 0 + assert residue(-2/x, x, 1) == 0 + assert residue(81/x, x, -1) == 0 + assert residue(1/x**2, x, 1) == 0 + assert residue(0, x, 1) == 0 + assert residue(5, x, 1) == 0 + assert residue(x, x, 1) == 0 + assert residue(x**2, x, 5) == 0 + + +def test_f(): + f = Function("f") + assert residue(f(x)/x**5, x, 0) == f(x).diff(x, 4).subs(x, 0)/24 + + +def test_functions(): + assert residue(1/sin(x), x, 0) == 1 + assert residue(2/sin(x), x, 0) == 2 + assert residue(1/sin(x)**2, x, 0) == 0 + assert residue(1/sin(x)**5, x, 0) == Rational(3, 8) + + +def test_expressions(): + assert residue(1/(x + 1), x, 0) == 0 + assert residue(1/(x + 1), x, -1) == 1 + assert residue(1/(x**2 + 1), x, -1) == 0 + assert residue(1/(x**2 + 1), x, I) == -I/2 + assert residue(1/(x**2 + 1), x, -I) == I/2 + assert residue(1/(x**4 + 1), x, 0) == 0 + assert residue(1/(x**4 + 1), x, exp(I*pi/4)).equals(-(Rational(1, 4) + I/4)/sqrt(2)) + assert residue(1/(x**2 + a**2)**2, x, a*I) == -I/4/a**3 + + +@XFAIL +def test_expressions_failing(): + n = Symbol('n', integer=True, positive=True) + assert residue(exp(z)/(z - pi*I/4*a)**n, z, I*pi*a) == \ + exp(I*pi*a/4)/factorial(n - 1) + + +def test_NotImplemented(): + raises(NotImplementedError, lambda: residue(exp(1/z), z, 0)) + + +def test_bug(): + assert residue(2**(z)*(s + z)*(1 - s - z)/z**2, z, 0) == \ + 1 + s*log(2) - s**2*log(2) - 2*s + + +def test_issue_5654(): + assert residue(1/(x**2 + a**2)**2, x, a*I) == -I/(4*a**3) + assert residue(1/s*1/(z - exp(s)), s, 0) == 1/(z - 1) + assert residue((1 + k)/s*1/(z - exp(s)), s, 0) == k/(z - 1) + 1/(z - 1) + + +def test_issue_6499(): + assert residue(1/(exp(z) - 1), z, 0) == 1 + + +def test_issue_14037(): + assert residue(sin(x**50)/x**51, x, 0) == 1 + + +def test_issue_21176(): + f = x**2*cot(pi*x)/(x**4 + 1) + assert residue(f, x, -sqrt(2)/2 - sqrt(2)*I/2).cancel().together(deep=True)\ + == sqrt(2)*(1 - I)/(8*tan(sqrt(2)*pi*(1 + I)/2)) + + +def test_issue_21177(): + r = -sqrt(3)*tanh(sqrt(3)*pi/2)/3 + a = residue(cot(pi*x)/((x - 1)*(x - 2) + 1), x, S(3)/2 - sqrt(3)*I/2) + b = residue(cot(pi*x)/(x**2 - 3*x + 3), x, S(3)/2 - sqrt(3)*I/2) + assert a == r + assert (b - a).cancel() == 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_sequences.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_sequences.py new file mode 100644 index 0000000000000000000000000000000000000000..61e276ad67982f0a9877de3548d70238976d28a5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_sequences.py @@ -0,0 +1,312 @@ +from sympy.core.containers import Tuple +from sympy.core.function import Function +from sympy.core.numbers import oo, Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols, Symbol +from sympy.functions.combinatorial.numbers import tribonacci, fibonacci +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.series import EmptySequence +from sympy.series.sequences import (SeqMul, SeqAdd, SeqPer, SeqFormula, + sequence) +from sympy.sets.sets import Interval +from sympy.tensor.indexed import Indexed, Idx +from sympy.series.sequences import SeqExpr, SeqExprOp, RecursiveSeq +from sympy.testing.pytest import raises, slow + +x, y, z = symbols('x y z') +n, m = symbols('n m') + + +def test_EmptySequence(): + assert S.EmptySequence is EmptySequence + + assert S.EmptySequence.interval is S.EmptySet + assert S.EmptySequence.length is S.Zero + + assert list(S.EmptySequence) == [] + + +def test_SeqExpr(): + #SeqExpr is a baseclass and does not take care of + #ensuring all arguments are Basics hence the use of + #Tuple(...) here. + s = SeqExpr(Tuple(1, n, y), Tuple(x, 0, 10)) + + assert isinstance(s, SeqExpr) + assert s.gen == (1, n, y) + assert s.interval == Interval(0, 10) + assert s.start == 0 + assert s.stop == 10 + assert s.length == 11 + assert s.variables == (x,) + + assert SeqExpr(Tuple(1, 2, 3), Tuple(x, 0, oo)).length is oo + + +def test_SeqPer(): + s = SeqPer((1, n, 3), (x, 0, 5)) + + assert isinstance(s, SeqPer) + assert s.periodical == Tuple(1, n, 3) + assert s.period == 3 + assert s.coeff(3) == 1 + assert s.free_symbols == {n} + + assert list(s) == [1, n, 3, 1, n, 3] + assert s[:] == [1, n, 3, 1, n, 3] + assert SeqPer((1, n, 3), (x, -oo, 0))[0:6] == [1, n, 3, 1, n, 3] + + raises(ValueError, lambda: SeqPer((1, 2, 3), (0, 1, 2))) + raises(ValueError, lambda: SeqPer((1, 2, 3), (x, -oo, oo))) + raises(ValueError, lambda: SeqPer(n**2, (0, oo))) + + assert SeqPer((n, n**2, n**3), (m, 0, oo))[:6] == \ + [n, n**2, n**3, n, n**2, n**3] + assert SeqPer((n, n**2, n**3), (n, 0, oo))[:6] == [0, 1, 8, 3, 16, 125] + assert SeqPer((n, m), (n, 0, oo))[:6] == [0, m, 2, m, 4, m] + + +def test_SeqFormula(): + s = SeqFormula(n**2, (n, 0, 5)) + + assert isinstance(s, SeqFormula) + assert s.formula == n**2 + assert s.coeff(3) == 9 + + assert list(s) == [i**2 for i in range(6)] + assert s[:] == [i**2 for i in range(6)] + assert SeqFormula(n**2, (n, -oo, 0))[0:6] == [i**2 for i in range(6)] + + assert SeqFormula(n**2, (0, oo)) == SeqFormula(n**2, (n, 0, oo)) + + assert SeqFormula(n**2, (0, m)).subs(m, x) == SeqFormula(n**2, (0, x)) + assert SeqFormula(m*n**2, (n, 0, oo)).subs(m, x) == \ + SeqFormula(x*n**2, (n, 0, oo)) + + raises(ValueError, lambda: SeqFormula(n**2, (0, 1, 2))) + raises(ValueError, lambda: SeqFormula(n**2, (n, -oo, oo))) + raises(ValueError, lambda: SeqFormula(m*n**2, (0, oo))) + + seq = SeqFormula(x*(y**2 + z), (z, 1, 100)) + assert seq.expand() == SeqFormula(x*y**2 + x*z, (z, 1, 100)) + seq = SeqFormula(sin(x*(y**2 + z)),(z, 1, 100)) + assert seq.expand(trig=True) == SeqFormula(sin(x*y**2)*cos(x*z) + sin(x*z)*cos(x*y**2), (z, 1, 100)) + assert seq.expand() == SeqFormula(sin(x*y**2 + x*z), (z, 1, 100)) + assert seq.expand(trig=False) == SeqFormula(sin(x*y**2 + x*z), (z, 1, 100)) + seq = SeqFormula(exp(x*(y**2 + z)), (z, 1, 100)) + assert seq.expand() == SeqFormula(exp(x*y**2)*exp(x*z), (z, 1, 100)) + assert seq.expand(power_exp=False) == SeqFormula(exp(x*y**2 + x*z), (z, 1, 100)) + assert seq.expand(mul=False, power_exp=False) == SeqFormula(exp(x*(y**2 + z)), (z, 1, 100)) + +def test_sequence(): + form = SeqFormula(n**2, (n, 0, 5)) + per = SeqPer((1, 2, 3), (n, 0, 5)) + inter = SeqFormula(n**2) + + assert sequence(n**2, (n, 0, 5)) == form + assert sequence((1, 2, 3), (n, 0, 5)) == per + assert sequence(n**2) == inter + + +def test_SeqExprOp(): + form = SeqFormula(n**2, (n, 0, 10)) + per = SeqPer((1, 2, 3), (m, 5, 10)) + + s = SeqExprOp(form, per) + assert s.gen == (n**2, (1, 2, 3)) + assert s.interval == Interval(5, 10) + assert s.start == 5 + assert s.stop == 10 + assert s.length == 6 + assert s.variables == (n, m) + + +def test_SeqAdd(): + per = SeqPer((1, 2, 3), (n, 0, oo)) + form = SeqFormula(n**2) + + per_bou = SeqPer((1, 2), (n, 1, 5)) + form_bou = SeqFormula(n**2, (6, 10)) + form_bou2 = SeqFormula(n**2, (1, 5)) + + assert SeqAdd() == S.EmptySequence + assert SeqAdd(S.EmptySequence) == S.EmptySequence + assert SeqAdd(per) == per + assert SeqAdd(per, S.EmptySequence) == per + assert SeqAdd(per_bou, form_bou) == S.EmptySequence + + s = SeqAdd(per_bou, form_bou2, evaluate=False) + assert s.args == (form_bou2, per_bou) + assert s[:] == [2, 6, 10, 18, 26] + assert list(s) == [2, 6, 10, 18, 26] + + assert isinstance(SeqAdd(per, per_bou, evaluate=False), SeqAdd) + + s1 = SeqAdd(per, per_bou) + assert isinstance(s1, SeqPer) + assert s1 == SeqPer((2, 4, 4, 3, 3, 5), (n, 1, 5)) + s2 = SeqAdd(form, form_bou) + assert isinstance(s2, SeqFormula) + assert s2 == SeqFormula(2*n**2, (6, 10)) + + assert SeqAdd(form, form_bou, per) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + assert SeqAdd(form, SeqAdd(form_bou, per)) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + assert SeqAdd(per, SeqAdd(form, form_bou), evaluate=False) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + + assert SeqAdd(SeqPer((1, 2), (n, 0, oo)), SeqPer((1, 2), (m, 0, oo))) == \ + SeqPer((2, 4), (n, 0, oo)) + + +def test_SeqMul(): + per = SeqPer((1, 2, 3), (n, 0, oo)) + form = SeqFormula(n**2) + + per_bou = SeqPer((1, 2), (n, 1, 5)) + form_bou = SeqFormula(n**2, (n, 6, 10)) + form_bou2 = SeqFormula(n**2, (1, 5)) + + assert SeqMul() == S.EmptySequence + assert SeqMul(S.EmptySequence) == S.EmptySequence + assert SeqMul(per) == per + assert SeqMul(per, S.EmptySequence) == S.EmptySequence + assert SeqMul(per_bou, form_bou) == S.EmptySequence + + s = SeqMul(per_bou, form_bou2, evaluate=False) + assert s.args == (form_bou2, per_bou) + assert s[:] == [1, 8, 9, 32, 25] + assert list(s) == [1, 8, 9, 32, 25] + + assert isinstance(SeqMul(per, per_bou, evaluate=False), SeqMul) + + s1 = SeqMul(per, per_bou) + assert isinstance(s1, SeqPer) + assert s1 == SeqPer((1, 4, 3, 2, 2, 6), (n, 1, 5)) + s2 = SeqMul(form, form_bou) + assert isinstance(s2, SeqFormula) + assert s2 == SeqFormula(n**4, (6, 10)) + + assert SeqMul(form, form_bou, per) == \ + SeqMul(per, SeqFormula(n**4, (6, 10))) + assert SeqMul(form, SeqMul(form_bou, per)) == \ + SeqMul(per, SeqFormula(n**4, (6, 10))) + assert SeqMul(per, SeqMul(form, form_bou2, + evaluate=False), evaluate=False) == \ + SeqMul(form, per, form_bou2, evaluate=False) + + assert SeqMul(SeqPer((1, 2), (n, 0, oo)), SeqPer((1, 2), (n, 0, oo))) == \ + SeqPer((1, 4), (n, 0, oo)) + + +def test_add(): + per = SeqPer((1, 2), (n, 0, oo)) + form = SeqFormula(n**2) + + assert per + (SeqPer((2, 3))) == SeqPer((3, 5), (n, 0, oo)) + assert form + SeqFormula(n**3) == SeqFormula(n**2 + n**3) + + assert per + form == SeqAdd(per, form) + + raises(TypeError, lambda: per + n) + raises(TypeError, lambda: n + per) + + +def test_sub(): + per = SeqPer((1, 2), (n, 0, oo)) + form = SeqFormula(n**2) + + assert per - (SeqPer((2, 3))) == SeqPer((-1, -1), (n, 0, oo)) + assert form - (SeqFormula(n**3)) == SeqFormula(n**2 - n**3) + + assert per - form == SeqAdd(per, -form) + + raises(TypeError, lambda: per - n) + raises(TypeError, lambda: n - per) + + +def test_mul__coeff_mul(): + assert SeqPer((1, 2), (n, 0, oo)).coeff_mul(2) == SeqPer((2, 4), (n, 0, oo)) + assert SeqFormula(n**2).coeff_mul(2) == SeqFormula(2*n**2) + assert S.EmptySequence.coeff_mul(100) == S.EmptySequence + + assert SeqPer((1, 2), (n, 0, oo)) * (SeqPer((2, 3))) == \ + SeqPer((2, 6), (n, 0, oo)) + assert SeqFormula(n**2) * SeqFormula(n**3) == SeqFormula(n**5) + + assert S.EmptySequence * SeqFormula(n**2) == S.EmptySequence + assert SeqFormula(n**2) * S.EmptySequence == S.EmptySequence + + raises(TypeError, lambda: sequence(n**2) * n) + raises(TypeError, lambda: n * sequence(n**2)) + + +def test_neg(): + assert -SeqPer((1, -2), (n, 0, oo)) == SeqPer((-1, 2), (n, 0, oo)) + assert -SeqFormula(n**2) == SeqFormula(-n**2) + + +def test_operations(): + per = SeqPer((1, 2), (n, 0, oo)) + per2 = SeqPer((2, 4), (n, 0, oo)) + form = SeqFormula(n**2) + form2 = SeqFormula(n**3) + + assert per + form + form2 == SeqAdd(per, form, form2) + assert per + form - form2 == SeqAdd(per, form, -form2) + assert per + form - S.EmptySequence == SeqAdd(per, form) + assert per + per2 + form == SeqAdd(SeqPer((3, 6), (n, 0, oo)), form) + assert S.EmptySequence - per == -per + assert form + form == SeqFormula(2*n**2) + + assert per * form * form2 == SeqMul(per, form, form2) + assert form * form == SeqFormula(n**4) + assert form * -form == SeqFormula(-n**4) + + assert form * (per + form2) == SeqMul(form, SeqAdd(per, form2)) + assert form * (per + per) == SeqMul(form, per2) + + assert form.coeff_mul(m) == SeqFormula(m*n**2, (n, 0, oo)) + assert per.coeff_mul(m) == SeqPer((m, 2*m), (n, 0, oo)) + + +def test_Idx_limits(): + i = symbols('i', cls=Idx) + r = Indexed('r', i) + + assert SeqFormula(r, (i, 0, 5))[:] == [r.subs(i, j) for j in range(6)] + assert SeqPer((1, 2), (i, 0, 5))[:] == [1, 2, 1, 2, 1, 2] + + +@slow +def test_find_linear_recurrence(): + assert sequence((0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55), \ + (n, 0, 10)).find_linear_recurrence(11) == [1, 1] + assert sequence((1, 2, 4, 7, 28, 128, 582, 2745, 13021, 61699, 292521, \ + 1387138), (n, 0, 11)).find_linear_recurrence(12) == [5, -2, 6, -11] + assert sequence(x*n**3+y*n, (n, 0, oo)).find_linear_recurrence(10) \ + == [4, -6, 4, -1] + assert sequence(x**n, (n,0,20)).find_linear_recurrence(21) == [x] + assert sequence((1,2,3)).find_linear_recurrence(10, 5) == [0, 0, 1] + assert sequence(((1 + sqrt(5))/2)**n + \ + (-(1 + sqrt(5))/2)**(-n)).find_linear_recurrence(10) == [1, 1] + assert sequence(x*((1 + sqrt(5))/2)**n + y*(-(1 + sqrt(5))/2)**(-n), \ + (n,0,oo)).find_linear_recurrence(10) == [1, 1] + assert sequence((1,2,3,4,6),(n, 0, 4)).find_linear_recurrence(5) == [] + assert sequence((2,3,4,5,6,79),(n, 0, 5)).find_linear_recurrence(6,gfvar=x) \ + == ([], None) + assert sequence((2,3,4,5,8,30),(n, 0, 5)).find_linear_recurrence(6,gfvar=x) \ + == ([Rational(19, 2), -20, Rational(27, 2)], (-31*x**2 + 32*x - 4)/(27*x**3 - 40*x**2 + 19*x -2)) + assert sequence(fibonacci(n)).find_linear_recurrence(30,gfvar=x) \ + == ([1, 1], -x/(x**2 + x - 1)) + assert sequence(tribonacci(n)).find_linear_recurrence(30,gfvar=x) \ + == ([1, 1, 1], -x/(x**3 + x**2 + x - 1)) + +def test_RecursiveSeq(): + y = Function('y') + n = Symbol('n') + fib = RecursiveSeq(y(n - 1) + y(n - 2), y(n), n, [0, 1]) + assert fib.coeff(3) == 2 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_series.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f3c122b98c14a58c6d5c6636cbb53e1e66a75d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/series/tests/test_series.py @@ -0,0 +1,421 @@ +from sympy.core.evalf import N +from sympy.core.function import (Derivative, Function, PoleError, Subs) +from sympy.core.numbers import (E, Float, Rational, oo, pi, I) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan, cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral, integrate +from sympy.series.order import O +from sympy.series.series import series +from sympy.abc import x, y, n, k +from sympy.testing.pytest import raises +from sympy.core import EulerGamma + + +def test_sin(): + e1 = sin(x).series(x, 0) + e2 = series(sin(x), x, 0) + assert e1 == e2 + + +def test_cos(): + e1 = cos(x).series(x, 0) + e2 = series(cos(x), x, 0) + assert e1 == e2 + + +def test_exp(): + e1 = exp(x).series(x, 0) + e2 = series(exp(x), x, 0) + assert e1 == e2 + + +def test_exp2(): + e1 = exp(cos(x)).series(x, 0) + e2 = series(exp(cos(x)), x, 0) + assert e1 == e2 + + +def test_issue_5223(): + assert series(1, x) == 1 + assert next(S.Zero.lseries(x)) == 0 + assert cos(x).series() == cos(x).series(x) + raises(ValueError, lambda: cos(x + y).series()) + raises(ValueError, lambda: x.series(dir="")) + + assert (cos(x).series(x, 1) - + cos(x + 1).series(x).subs(x, x - 1)).removeO() == 0 + e = cos(x).series(x, 1, n=None) + assert [next(e) for i in range(2)] == [cos(1), -((x - 1)*sin(1))] + e = cos(x).series(x, 1, n=None, dir='-') + assert [next(e) for i in range(2)] == [cos(1), (1 - x)*sin(1)] + # the following test is exact so no need for x -> x - 1 replacement + assert abs(x).series(x, 1, dir='-') == x + assert exp(x).series(x, 1, dir='-', n=3).removeO() == \ + E - E*(-x + 1) + E*(-x + 1)**2/2 + + D = Derivative + assert D(x**2 + x**3*y**2, x, 2, y, 1).series(x).doit() == 12*x*y + assert next(D(cos(x), x).lseries()) == D(1, x) + assert D( + exp(x), x).series(n=3) == D(1, x) + D(x, x) + D(x**2/2, x) + D(x**3/6, x) + O(x**3) + + assert Integral(x, (x, 1, 3), (y, 1, x)).series(x) == -4 + 4*x + + assert (1 + x + O(x**2)).getn() == 2 + assert (1 + x).getn() is None + + raises(PoleError, lambda: ((1/sin(x))**oo).series()) + logx = Symbol('logx') + assert ((sin(x))**y).nseries(x, n=1, logx=logx) == \ + exp(y*logx) + O(x*exp(y*logx), x) + + assert sin(1/x).series(x, oo, n=5) == 1/x - 1/(6*x**3) + O(x**(-5), (x, oo)) + assert abs(x).series(x, oo, n=5, dir='+') == x + assert abs(x).series(x, -oo, n=5, dir='-') == -x + assert abs(-x).series(x, oo, n=5, dir='+') == x + assert abs(-x).series(x, -oo, n=5, dir='-') == -x + + assert exp(x*log(x)).series(n=3) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + O(x**3*log(x)**3) + # XXX is this right? If not, fix "ngot > n" handling in expr. + p = Symbol('p', positive=True) + assert exp(sqrt(p)**3*log(p)).series(n=3) == \ + 1 + p**S('3/2')*log(p) + O(p**3*log(p)**3) + + assert exp(sin(x)*log(x)).series(n=2) == 1 + x*log(x) + O(x**2*log(x)**2) + + +def test_issue_6350(): + expr = integrate(exp(k*(y**3 - 3*y)), (y, 0, oo), conds='none') + assert expr.series(k, 0, 3) == -(-1)**(S(2)/3)*sqrt(3)*gamma(S(1)/3)**2*gamma(S(2)/3)/(6*pi*k**(S(1)/3)) - \ + sqrt(3)*k*gamma(-S(2)/3)*gamma(-S(1)/3)/(6*pi) - \ + (-1)**(S(1)/3)*sqrt(3)*k**(S(1)/3)*gamma(-S(1)/3)*gamma(S(1)/3)*gamma(S(2)/3)/(6*pi) - \ + (-1)**(S(2)/3)*sqrt(3)*k**(S(5)/3)*gamma(S(1)/3)**2*gamma(S(2)/3)/(4*pi) - \ + (-1)**(S(1)/3)*sqrt(3)*k**(S(7)/3)*gamma(-S(1)/3)*gamma(S(1)/3)*gamma(S(2)/3)/(8*pi) + O(k**3) + + +def test_issue_11313(): + assert Integral(cos(x), x).series(x) == sin(x).series(x) + assert Derivative(sin(x), x).series(x, n=3).doit() == cos(x).series(x, n=3) + + assert Derivative(x**3, x).as_leading_term(x) == 3*x**2 + assert Derivative(x**3, y).as_leading_term(x) == 0 + assert Derivative(sin(x), x).as_leading_term(x) == 1 + assert Derivative(cos(x), x).as_leading_term(x) == -x + + # This result is equivalent to zero, zero is not return because + # `Expr.series` doesn't currently detect an `x` in its `free_symbol`s. + assert Derivative(1, x).as_leading_term(x) == Derivative(1, x) + + assert Derivative(exp(x), x).series(x).doit() == exp(x).series(x) + assert 1 + Integral(exp(x), x).series(x) == exp(x).series(x) + + assert Derivative(log(x), x).series(x).doit() == (1/x).series(x) + assert Integral(log(x), x).series(x) == Integral(log(x), x).doit().series(x).removeO() + + +def test_series_of_Subs(): + from sympy.abc import z + + subs1 = Subs(sin(x), x, y) + subs2 = Subs(sin(x) * cos(z), x, y) + subs3 = Subs(sin(x * z), (x, z), (y, x)) + + assert subs1.series(x) == subs1 + subs1_series = (Subs(x, x, y) + Subs(-x**3/6, x, y) + + Subs(x**5/120, x, y) + O(y**6)) + assert subs1.series() == subs1_series + assert subs1.series(y) == subs1_series + assert subs1.series(z) == subs1 + assert subs2.series(z) == (Subs(z**4*sin(x)/24, x, y) + + Subs(-z**2*sin(x)/2, x, y) + Subs(sin(x), x, y) + O(z**6)) + assert subs3.series(x).doit() == subs3.doit().series(x) + assert subs3.series(z).doit() == sin(x*y) + + raises(ValueError, lambda: Subs(x + 2*y, y, z).series()) + assert Subs(x + y, y, z).series(x).doit() == x + z + + +def test_issue_3978(): + f = Function('f') + assert f(x).series(x, 0, 3, dir='-') == \ + f(0) + x*Subs(Derivative(f(x), x), x, 0) + \ + x**2*Subs(Derivative(f(x), x, x), x, 0)/2 + O(x**3) + assert f(x).series(x, 0, 3) == \ + f(0) + x*Subs(Derivative(f(x), x), x, 0) + \ + x**2*Subs(Derivative(f(x), x, x), x, 0)/2 + O(x**3) + assert f(x**2).series(x, 0, 3) == \ + f(0) + x**2*Subs(Derivative(f(x), x), x, 0) + O(x**3) + assert f(x**2+1).series(x, 0, 3) == \ + f(1) + x**2*Subs(Derivative(f(x), x), x, 1) + O(x**3) + + class TestF(Function): + pass + + assert TestF(x).series(x, 0, 3) == TestF(0) + \ + x*Subs(Derivative(TestF(x), x), x, 0) + \ + x**2*Subs(Derivative(TestF(x), x, x), x, 0)/2 + O(x**3) + +from sympy.series.acceleration import richardson, shanks +from sympy.concrete.summations import Sum +from sympy.core.numbers import Integer + + +def test_acceleration(): + e = (1 + 1/n)**n + assert round(richardson(e, n, 10, 20).evalf(), 10) == round(E.evalf(), 10) + + A = Sum(Integer(-1)**(k + 1) / k, (k, 1, n)) + assert round(shanks(A, n, 25).evalf(), 4) == round(log(2).evalf(), 4) + assert round(shanks(A, n, 25, 5).evalf(), 10) == round(log(2).evalf(), 10) + + +def test_issue_5852(): + assert series(1/cos(x/log(x)), x, 0) == 1 + x**2/(2*log(x)**2) + \ + 5*x**4/(24*log(x)**4) + O(x**6) + + +def test_issue_4583(): + assert cos(1 + x + x**2).series(x, 0, 5) == cos(1) - x*sin(1) + \ + x**2*(-sin(1) - cos(1)/2) + x**3*(-cos(1) + sin(1)/6) + \ + x**4*(-11*cos(1)/24 + sin(1)/2) + O(x**5) + + +def test_issue_6318(): + eq = (1/x)**Rational(2, 3) + assert (eq + 1).as_leading_term(x) == eq + + +def test_x_is_base_detection(): + eq = (x**2)**Rational(2, 3) + assert eq.series() == x**Rational(4, 3) + + +def test_issue_7203(): + assert series(cos(x), x, pi, 3) == \ + -1 + (x - pi)**2/2 + O((x - pi)**3, (x, pi)) + + +def test_exp_product_positive_factors(): + a, b = symbols('a, b', positive=True) + x = a * b + assert series(exp(x), x, n=8) == 1 + a*b + a**2*b**2/2 + \ + a**3*b**3/6 + a**4*b**4/24 + a**5*b**5/120 + a**6*b**6/720 + \ + a**7*b**7/5040 + O(a**8*b**8, a, b) + + +def test_issue_8805(): + assert series(1, n=8) == 1 + + +def test_issue_9173(): + p0,p1,p2,p3,b0,b1,b2=symbols('p0 p1 p2 p3 b0 b1 b2') + Q=(p0+(p1+(p2+p3/y)/y)/y)/(1+((p3/(b0*y)+(b0*p2-b1*p3)/b0**2)/y+\ + (b0**2*p1-b0*b1*p2-p3*(b0*b2-b1**2))/b0**3)/y) + + series = Q.series(y,n=3) + + assert series == y*(b0*p2/p3+b0*(-p2/p3+b1/b0))+y**2*(b0*p1/p3+b0*p2*\ + (-p2/p3+b1/b0)/p3+b0*(-p1/p3+(p2/p3-b1/b0)**2+b1*p2/(b0*p3)+\ + b2/b0-b1**2/b0**2))+b0+O(y**3) + assert series.simplify() == b2*y**2 + b1*y + b0 + O(y**3) + + +def test_issue_9549(): + y = (x**2 + x + 1) / (x**3 + x**2) + assert series(y, x, oo) == x**(-5) - 1/x**4 + x**(-3) + 1/x + O(x**(-6), (x, oo)) + + +def test_issue_10761(): + assert series(1/(x**-2 + x**-3), x, 0) == x**3 - x**4 + x**5 + O(x**6) + + +def test_issue_12578(): + y = (1 - 1/(x/2 - 1/(2*x))**4)**(S(1)/8) + assert y.series(x, 0, n=17) == 1 - 2*x**4 - 8*x**6 - 34*x**8 - 152*x**10 - 714*x**12 - \ + 3472*x**14 - 17318*x**16 + O(x**17) + + +def test_issue_12791(): + beta = symbols('beta', positive=True) + theta, varphi = symbols('theta varphi', real=True) + + expr = (-beta**2*varphi*sin(theta) + beta**2*cos(theta) + \ + beta*varphi*sin(theta) - beta*cos(theta) - beta + 1)/(beta*cos(theta) - 1)**2 + + sol = (0.5/(0.5*cos(theta) - 1.0)**2 - 0.25*cos(theta)/(0.5*cos(theta) - 1.0)**2 + + (beta - 0.5)*(-0.25*varphi*sin(2*theta) - 1.5*cos(theta) + + 0.25*cos(2*theta) + 1.25)/((0.5*cos(theta) - 1.0)**2*(0.5*cos(theta) - 1.0)) + + 0.25*varphi*sin(theta)/(0.5*cos(theta) - 1.0)**2 + + O((beta - S.Half)**2, (beta, S.Half))) + + assert expr.series(beta, 0.5, 2).trigsimp() == sol + + +def test_issue_14384(): + x, a = symbols('x a') + assert series(x**a, x) == x**a + assert series(x**(-2*a), x) == x**(-2*a) + assert series(exp(a*log(x)), x) == exp(a*log(x)) + raises(PoleError, lambda: series(x**I, x)) + raises(PoleError, lambda: series(x**(I + 1), x)) + raises(PoleError, lambda: series(exp(I*log(x)), x)) + + +def test_issue_14885(): + assert series(x**Rational(-3, 2)*exp(x), x, 0) == (x**Rational(-3, 2) + 1/sqrt(x) + + sqrt(x)/2 + x**Rational(3, 2)/6 + x**Rational(5, 2)/24 + x**Rational(7, 2)/120 + + x**Rational(9, 2)/720 + x**Rational(11, 2)/5040 + O(x**6)) + + +def test_issue_15539(): + assert series(atan(x), x, -oo) == (-1/(5*x**5) + 1/(3*x**3) - 1/x - pi/2 + + O(x**(-6), (x, -oo))) + assert series(atan(x), x, oo) == (-1/(5*x**5) + 1/(3*x**3) - 1/x + pi/2 + + O(x**(-6), (x, oo))) + + +def test_issue_7259(): + assert series(LambertW(x), x) == x - x**2 + 3*x**3/2 - 8*x**4/3 + 125*x**5/24 + O(x**6) + assert series(LambertW(x**2), x, n=8) == x**2 - x**4 + 3*x**6/2 + O(x**8) + assert series(LambertW(sin(x)), x, n=4) == x - x**2 + 4*x**3/3 + O(x**4) + +def test_issue_11884(): + assert cos(x).series(x, 1, n=1) == cos(1) + O(x - 1, (x, 1)) + + +def test_issue_18008(): + y = x*(1 + x*(1 - x))/((1 + x*(1 - x)) - (1 - x)*(1 - x)) + assert y.series(x, oo, n=4) == -9/(32*x**3) - 3/(16*x**2) - 1/(8*x) + S(1)/4 + x/2 + \ + O(x**(-4), (x, oo)) + + +def test_issue_18842(): + f = log(x/(1 - x)) + assert f.series(x, 0.491, n=1).removeO().nsimplify() == \ + -S(180019443780011)/5000000000000000 + + +def test_issue_19534(): + dt = symbols('dt', real=True) + expr = 16*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0)/45 + \ + 49*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.051640768506639183825*dt + \ + dt*(1/2 - sqrt(21)/14) + 1.0)/180 + 49*dt*(-0.23637909581542530626*dt*(2.0*dt + 1.0) - \ + 0.74817562366625959291*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.88085458023927036857*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + \ + 2.1165151389911680013*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) - \ + 1.1854881643947648988*dt + dt*(sqrt(21)/14 + 1/2) + 1.0)/180 + \ + dt*(0.66666666666666666667*dt*(2.0*dt + 1.0) + \ + 6.0173399699313066769*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 4.1117044797036320069*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) - \ + 7.0189140975801991157*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) + \ + 0.94010945196161777522*dt*(-0.23637909581542530626*dt*(2.0*dt + 1.0) - \ + 0.74817562366625959291*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.88085458023927036857*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + \ + 2.1165151389911680013*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) - \ + 0.35816132904077632692*dt + 1.0) + 5.5065024887242400038*dt + 1.0)/20 + dt/20 + 1 + + assert N(expr.series(dt, 0, 8), 20) == ( + - Float('0.00092592592592592596126289', precision=70) * dt**7 + + Float('0.0027777777777777783174695', precision=70) * dt**6 + + Float('0.016666666666666656027029', precision=70) * dt**5 + + Float('0.083333333333333300951828', precision=70) * dt**4 + + Float('0.33333333333333337034077', precision=70) * dt**3 + + Float('1.0', precision=70) * dt**2 + + Float('1.0', precision=70) * dt + + Float('1.0', precision=70) + ) + + +def test_issue_11407(): + a, b, c, x = symbols('a b c x') + assert series(sqrt(a + b + c*x), x, 0, 1) == sqrt(a + b) + O(x) + assert series(sqrt(a + b + c + c*x), x, 0, 1) == sqrt(a + b + c) + O(x) + + +def test_issue_14037(): + assert (sin(x**50)/x**51).series(x, n=0) == 1/x + O(1, x) + + +def test_issue_20551(): + expr = (exp(x)/x).series(x, n=None) + terms = [ next(expr) for i in range(3) ] + assert terms == [1/x, 1, x/2] + + +def test_issue_20697(): + p_0, p_1, p_2, p_3, b_0, b_1, b_2 = symbols('p_0 p_1 p_2 p_3 b_0 b_1 b_2') + Q = (p_0 + (p_1 + (p_2 + p_3/y)/y)/y)/(1 + ((p_3/(b_0*y) + (b_0*p_2\ + - b_1*p_3)/b_0**2)/y + (b_0**2*p_1 - b_0*b_1*p_2 - p_3*(b_0*b_2\ + - b_1**2))/b_0**3)/y) + assert Q.series(y, n=3).ratsimp() == b_2*y**2 + b_1*y + b_0 + O(y**3) + + +def test_issue_21245(): + fi = (1 + sqrt(5))/2 + assert (1/(1 - x - x**2)).series(x, 1/fi, 1).factor() == \ + (-37*sqrt(5) - 83 + 13*sqrt(5)*x + 29*x + O((x - 2/(1 + sqrt(5)))**2, (x\ + , 2/(1 + sqrt(5)))))/((2*sqrt(5) + 5)**2*(x + sqrt(5)*x - 2)) + + + +def test_issue_21938(): + expr = sin(1/x + exp(-x)) - sin(1/x) + assert expr.series(x, oo) == (1/(24*x**4) - 1/(2*x**2) + 1 + O(x**(-6), (x, oo)))*exp(-x) + + +def test_issue_23432(): + expr = 1/sqrt(1 - x**2) + result = expr.series(x, 0.5) + assert result.is_Add and len(result.args) == 7 + + +def test_issue_23727(): + res = series(sqrt(1 - x**2), x, 0.1) + assert res.is_Add == True + + +def test_issue_24266(): + #type1: exp(f(x)) + assert (exp(-I*pi*(2*x+1))).series(x, 0, 3) == -1 + 2*I*pi*x + 2*pi**2*x**2 + O(x**3) + assert (exp(-I*pi*(2*x+1))*gamma(1+x)).series(x, 0, 3) == -1 + x*(EulerGamma + 2*I*pi) + \ + x**2*(-EulerGamma**2/2 + 23*pi**2/12 - 2*EulerGamma*I*pi) + O(x**3) + + #type2: c**f(x) + assert ((2*I)**(-I*pi*(2*x+1))).series(x, 0, 2) == exp(pi**2/2 - I*pi*log(2)) + \ + x*(pi**2*exp(pi**2/2 - I*pi*log(2)) - 2*I*pi*exp(pi**2/2 - I*pi*log(2))*log(2)) + O(x**2) + assert ((2)**(-I*pi*(2*x+1))).series(x, 0, 2) == exp(-I*pi*log(2)) - 2*I*pi*x*exp(-I*pi*log(2))*log(2) + O(x**2) + + #type3: f(y)**g(x) + assert ((y)**(I*pi*(2*x+1))).series(x, 0, 2) == exp(I*pi*log(y)) + 2*I*pi*x*exp(I*pi*log(y))*log(y) + O(x**2) + assert ((I*y)**(I*pi*(2*x+1))).series(x, 0, 2) == exp(I*pi*log(I*y)) + 2*I*pi*x*exp(I*pi*log(I*y))*log(I*y) + O(x**2) + + +def test_issue_26856(): + raises(ValueError, lambda: (2**x).series(x, oo, -1)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b909c0b5ef03b1e1e76dfbf4288f61860575da7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__init__.py @@ -0,0 +1,36 @@ +from .sets import (Set, Interval, Union, FiniteSet, ProductSet, + Intersection, imageset, Complement, SymmetricDifference, + DisjointUnion) + +from .fancysets import ImageSet, Range, ComplexRegion +from .contains import Contains +from .conditionset import ConditionSet +from .ordinals import Ordinal, OmegaPower, ord0 +from .powerset import PowerSet +from ..core.singleton import S +from .handlers.comparison import _eval_is_eq # noqa:F401 +Complexes = S.Complexes +EmptySet = S.EmptySet +Integers = S.Integers +Naturals = S.Naturals +Naturals0 = S.Naturals0 +Rationals = S.Rationals +Reals = S.Reals +UniversalSet = S.UniversalSet + +__all__ = [ + 'Set', 'Interval', 'Union', 'EmptySet', 'FiniteSet', 'ProductSet', + 'Intersection', 'imageset', 'Complement', 'SymmetricDifference', 'DisjointUnion', + + 'ImageSet', 'Range', 'ComplexRegion', 'Reals', + + 'Contains', + + 'ConditionSet', + + 'Ordinal', 'OmegaPower', 'ord0', + + 'PowerSet', + + 'Reals', 'Naturals', 'Naturals0', 'UniversalSet', 'Integers', 'Rationals', +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad49e3df72b1020d8bf00e590845a5235a56c54d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/conditionset.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/conditionset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f7bc534c0de86df3948ec9151e213bd6f5e5b79 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/conditionset.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/contains.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/contains.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88fc24a03a792b66db7bac0f7d05f47f529ad93f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/contains.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/fancysets.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/fancysets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5f211ad23ee363d869c218fb5adc3adea61ab40 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/fancysets.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/ordinals.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/ordinals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d055dbc99b7ff5d2e7106001c4fea459464998f4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/ordinals.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/powerset.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/powerset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc85f4075459fc6b1a68c19ad2cd1c0d449f5f6a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/powerset.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/setexpr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/setexpr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06bd4cdabf0376ee42cc7e1f40bfc714c66fa0eb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/__pycache__/setexpr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/conditionset.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/conditionset.py new file mode 100644 index 0000000000000000000000000000000000000000..e847e60ce97d7e9922ce907042ace941838b0ab1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/conditionset.py @@ -0,0 +1,246 @@ +from sympy.core.singleton import S +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.function import Lambda, BadSignatureError +from sympy.core.logic import fuzzy_bool +from sympy.core.relational import Eq +from sympy.core.symbol import Dummy +from sympy.core.sympify import _sympify +from sympy.logic.boolalg import And, as_Boolean +from sympy.utilities.iterables import sift, flatten, has_dups +from sympy.utilities.exceptions import sympy_deprecation_warning +from .contains import Contains +from .sets import Set, Union, FiniteSet, SetKind + + +adummy = Dummy('conditionset') + + +class ConditionSet(Set): + r""" + Set of elements which satisfies a given condition. + + .. math:: \{x \mid \textrm{condition}(x) = \texttt{True}, x \in S\} + + Examples + ======== + + >>> from sympy import Symbol, S, ConditionSet, pi, Eq, sin, Interval + >>> from sympy.abc import x, y, z + + >>> sin_sols = ConditionSet(x, Eq(sin(x), 0), Interval(0, 2*pi)) + >>> 2*pi in sin_sols + True + >>> pi/2 in sin_sols + False + >>> 3*pi in sin_sols + False + >>> 5 in ConditionSet(x, x**2 > 4, S.Reals) + True + + If the value is not in the base set, the result is false: + + >>> 5 in ConditionSet(x, x**2 > 4, Interval(2, 4)) + False + + Notes + ===== + + Symbols with assumptions should be avoided or else the + condition may evaluate without consideration of the set: + + >>> n = Symbol('n', negative=True) + >>> cond = (n > 0); cond + False + >>> ConditionSet(n, cond, S.Integers) + EmptySet + + Only free symbols can be changed by using `subs`: + + >>> c = ConditionSet(x, x < 1, {x, z}) + >>> c.subs(x, y) + ConditionSet(x, x < 1, {y, z}) + + To check if ``pi`` is in ``c`` use: + + >>> pi in c + False + + If no base set is specified, the universal set is implied: + + >>> ConditionSet(x, x < 1).base_set + UniversalSet + + Only symbols or symbol-like expressions can be used: + + >>> ConditionSet(x + 1, x + 1 < 1, S.Integers) + Traceback (most recent call last): + ... + ValueError: non-symbol dummy not recognized in condition + + When the base set is a ConditionSet, the symbols will be + unified if possible with preference for the outermost symbols: + + >>> ConditionSet(x, x < y, ConditionSet(z, z + y < 2, S.Integers)) + ConditionSet(x, (x < y) & (x + y < 2), Integers) + + """ + def __new__(cls, sym, condition, base_set=S.UniversalSet): + sym = _sympify(sym) + flat = flatten([sym]) + if has_dups(flat): + raise BadSignatureError("Duplicate symbols detected") + base_set = _sympify(base_set) + if not isinstance(base_set, Set): + raise TypeError( + 'base set should be a Set object, not %s' % base_set) + condition = _sympify(condition) + + if isinstance(condition, FiniteSet): + condition_orig = condition + temp = (Eq(lhs, 0) for lhs in condition) + condition = And(*temp) + sympy_deprecation_warning( + f""" +Using a set for the condition in ConditionSet is deprecated. Use a boolean +instead. + +In this case, replace + + {condition_orig} + +with + + {condition} +""", + deprecated_since_version='1.5', + active_deprecations_target="deprecated-conditionset-set", + ) + + condition = as_Boolean(condition) + + if condition is S.true: + return base_set + + if condition is S.false: + return S.EmptySet + + if base_set is S.EmptySet: + return S.EmptySet + + # no simple answers, so now check syms + for i in flat: + if not getattr(i, '_diff_wrt', False): + raise ValueError('`%s` is not symbol-like' % i) + + if base_set.contains(sym) is S.false: + raise TypeError('sym `%s` is not in base_set `%s`' % (sym, base_set)) + + know = None + if isinstance(base_set, FiniteSet): + sifted = sift( + base_set, lambda _: fuzzy_bool(condition.subs(sym, _))) + if sifted[None]: + know = FiniteSet(*sifted[True]) + base_set = FiniteSet(*sifted[None]) + else: + return FiniteSet(*sifted[True]) + + if isinstance(base_set, cls): + s, c, b = base_set.args + def sig(s): + return cls(s, Eq(adummy, 0)).as_dummy().sym + sa, sb = map(sig, (sym, s)) + if sa != sb: + raise BadSignatureError('sym does not match sym of base set') + reps = dict(zip(flatten([sym]), flatten([s]))) + if s == sym: + condition = And(condition, c) + base_set = b + elif not c.free_symbols & sym.free_symbols: + reps = {v: k for k, v in reps.items()} + condition = And(condition, c.xreplace(reps)) + base_set = b + elif not condition.free_symbols & s.free_symbols: + sym = sym.xreplace(reps) + condition = And(condition.xreplace(reps), c) + base_set = b + + # flatten ConditionSet(Contains(ConditionSet())) expressions + if isinstance(condition, Contains) and (sym == condition.args[0]): + if isinstance(condition.args[1], Set): + return condition.args[1].intersect(base_set) + + rv = Basic.__new__(cls, sym, condition, base_set) + return rv if know is None else Union(know, rv) + + sym = property(lambda self: self.args[0]) + condition = property(lambda self: self.args[1]) + base_set = property(lambda self: self.args[2]) + + @property + def free_symbols(self): + cond_syms = self.condition.free_symbols - self.sym.free_symbols + return cond_syms | self.base_set.free_symbols + + @property + def bound_symbols(self): + return flatten([self.sym]) + + def _contains(self, other): + def ok_sig(a, b): + tuples = [isinstance(i, Tuple) for i in (a, b)] + c = tuples.count(True) + if c == 1: + return False + if c == 0: + return True + return len(a) == len(b) and all( + ok_sig(i, j) for i, j in zip(a, b)) + if not ok_sig(self.sym, other): + return S.false + + # try doing base_cond first and return + # False immediately if it is False + base_cond = Contains(other, self.base_set) + if base_cond is S.false: + return S.false + + # Substitute other into condition. This could raise e.g. for + # ConditionSet(x, 1/x >= 0, Reals).contains(0) + lamda = Lambda((self.sym,), self.condition) + try: + lambda_cond = lamda(other) + except TypeError: + return None + else: + return And(base_cond, lambda_cond) + + def as_relational(self, other): + f = Lambda(self.sym, self.condition) + if isinstance(self.sym, Tuple): + f = f(*other) + else: + f = f(other) + return And(f, self.base_set.contains(other)) + + def _eval_subs(self, old, new): + sym, cond, base = self.args + dsym = sym.subs(old, adummy) + insym = dsym.has(adummy) + # prioritize changing a symbol in the base + newbase = base.subs(old, new) + if newbase != base: + if not insym: + cond = cond.subs(old, new) + return self.func(sym, cond, newbase) + if insym: + pass # no change of bound symbols via subs + elif getattr(new, '_diff_wrt', False): + cond = cond.subs(old, new) + else: + pass # let error about the symbol raise from __new__ + return self.func(sym, cond, base) + + def _kind(self): + return SetKind(self.sym.kind) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/contains.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/contains.py new file mode 100644 index 0000000000000000000000000000000000000000..403d4875279d718724a898efa5cba41bc7bed6ea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/contains.py @@ -0,0 +1,63 @@ +from sympy.core import S +from sympy.core.sympify import sympify +from sympy.core.relational import Eq, Ne +from sympy.core.parameters import global_parameters +from sympy.logic.boolalg import Boolean +from sympy.utilities.misc import func_name +from .sets import Set + + +class Contains(Boolean): + """ + Asserts that x is an element of the set S. + + Examples + ======== + + >>> from sympy import Symbol, Integer, S, Contains + >>> Contains(Integer(2), S.Integers) + True + >>> Contains(Integer(-2), S.Naturals) + False + >>> i = Symbol('i', integer=True) + >>> Contains(i, S.Naturals) + Contains(i, Naturals) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Element_%28mathematics%29 + """ + def __new__(cls, x, s, evaluate=None): + x = sympify(x) + s = sympify(s) + + if evaluate is None: + evaluate = global_parameters.evaluate + + if not isinstance(s, Set): + raise TypeError('expecting Set, not %s' % func_name(s)) + + if evaluate: + # _contains can return symbolic booleans that would be returned by + # s.contains(x) but here for Contains(x, s) we only evaluate to + # true, false or return the unevaluated Contains. + result = s._contains(x) + + if isinstance(result, Boolean): + if result in (S.true, S.false): + return result + elif result is not None: + raise TypeError("_contains() should return Boolean or None") + + return super().__new__(cls, x, s) + + @property + def binary_symbols(self): + return set().union(*[i.binary_symbols + for i in self.args[1].args + if i.is_Boolean or i.is_Symbol or + isinstance(i, (Eq, Ne))]) + + def as_set(self): + return self.args[1] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/fancysets.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/fancysets.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e24a2a864222d16ba1a697558b5211127fb2ad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/fancysets.py @@ -0,0 +1,1523 @@ +from functools import reduce +from itertools import product + +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.logic import fuzzy_not, fuzzy_or, fuzzy_and +from sympy.core.mod import Mod +from sympy.core.intfunc import igcd +from sympy.core.numbers import oo, Rational, Integer +from sympy.core.relational import Eq, is_eq +from sympy.core.kind import NumberKind +from sympy.core.singleton import Singleton, S +from sympy.core.symbol import Dummy, symbols, Symbol +from sympy.core.sympify import _sympify, sympify, _sympy_converter +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.logic.boolalg import And, Or +from .sets import tfn, Set, Interval, Union, FiniteSet, ProductSet, SetKind +from sympy.utilities.misc import filldedent + + +class Rationals(Set, metaclass=Singleton): + """ + Represents the rational numbers. This set is also available as + the singleton ``S.Rationals``. + + Examples + ======== + + >>> from sympy import S + >>> S.Half in S.Rationals + True + >>> iterable = iter(S.Rationals) + >>> [next(iterable) for i in range(12)] + [0, 1, -1, 1/2, 2, -1/2, -2, 1/3, 3, -1/3, -3, 2/3] + """ + + is_iterable = True + _inf = S.NegativeInfinity + _sup = S.Infinity + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + return tfn[other.is_rational] + + def __iter__(self): + yield S.Zero + yield S.One + yield S.NegativeOne + d = 2 + while True: + for n in range(d): + if igcd(n, d) == 1: + yield Rational(n, d) + yield Rational(d, n) + yield Rational(-n, d) + yield Rational(-d, n) + d += 1 + + @property + def _boundary(self): + return S.Reals + + def _kind(self): + return SetKind(NumberKind) + + +class Naturals(Set, metaclass=Singleton): + """ + Represents the natural numbers (or counting numbers) which are all + positive integers starting from 1. This set is also available as + the singleton ``S.Naturals``. + + Examples + ======== + + >>> from sympy import S, Interval, pprint + >>> 5 in S.Naturals + True + >>> iterable = iter(S.Naturals) + >>> next(iterable) + 1 + >>> next(iterable) + 2 + >>> next(iterable) + 3 + >>> pprint(S.Naturals.intersect(Interval(0, 10))) + {1, 2, ..., 10} + + See Also + ======== + + Naturals0 : non-negative integers (i.e. includes 0, too) + Integers : also includes negative integers + """ + + is_iterable = True + _inf: Integer = S.One + _sup = S.Infinity + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + elif other.is_positive and other.is_integer: + return S.true + elif other.is_integer is False or other.is_positive is False: + return S.false + + def _eval_is_subset(self, other): + return Range(1, oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(1, oo).is_superset(other) + + def __iter__(self): + i = self._inf + while True: + yield i + i = i + 1 + + @property + def _boundary(self): + return self + + def as_relational(self, x): + return And(Eq(floor(x), x), x >= self.inf, x < oo) + + def _kind(self): + return SetKind(NumberKind) + + +class Naturals0(Naturals): + """Represents the whole numbers which are all the non-negative integers, + inclusive of zero. + + See Also + ======== + + Naturals : positive integers; does not include 0 + Integers : also includes the negative integers + """ + _inf = S.Zero + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + elif other.is_integer and other.is_nonnegative: + return S.true + elif other.is_integer is False or other.is_nonnegative is False: + return S.false + + def _eval_is_subset(self, other): + return Range(oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(oo).is_superset(other) + + +class Integers(Set, metaclass=Singleton): + """ + Represents all integers: positive, negative and zero. This set is also + available as the singleton ``S.Integers``. + + Examples + ======== + + >>> from sympy import S, Interval, pprint + >>> 5 in S.Naturals + True + >>> iterable = iter(S.Integers) + >>> next(iterable) + 0 + >>> next(iterable) + 1 + >>> next(iterable) + -1 + >>> next(iterable) + 2 + + >>> pprint(S.Integers.intersect(Interval(-4, 4))) + {-4, -3, ..., 4} + + See Also + ======== + + Naturals0 : non-negative integers + Integers : positive and negative integers and zero + """ + + is_iterable = True + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + return tfn[other.is_integer] + + def __iter__(self): + yield S.Zero + i = S.One + while True: + yield i + yield -i + i = i + 1 + + @property + def _inf(self): + return S.NegativeInfinity + + @property + def _sup(self): + return S.Infinity + + @property + def _boundary(self): + return self + + def _kind(self): + return SetKind(NumberKind) + + def as_relational(self, x): + return And(Eq(floor(x), x), -oo < x, x < oo) + + def _eval_is_subset(self, other): + return Range(-oo, oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(-oo, oo).is_superset(other) + + +class Reals(Interval, metaclass=Singleton): + """ + Represents all real numbers + from negative infinity to positive infinity, + including all integer, rational and irrational numbers. + This set is also available as the singleton ``S.Reals``. + + + Examples + ======== + + >>> from sympy import S, Rational, pi, I + >>> 5 in S.Reals + True + >>> Rational(-1, 2) in S.Reals + True + >>> pi in S.Reals + True + >>> 3*I in S.Reals + False + >>> S.Reals.contains(pi) + True + + + See Also + ======== + + ComplexRegion + """ + @property + def start(self): + return S.NegativeInfinity + + @property + def end(self): + return S.Infinity + + @property + def left_open(self): + return True + + @property + def right_open(self): + return True + + def __eq__(self, other): + return other == Interval(S.NegativeInfinity, S.Infinity) + + def __hash__(self): + return hash(Interval(S.NegativeInfinity, S.Infinity)) + + +class ImageSet(Set): + """ + Image of a set under a mathematical function. The transformation + must be given as a Lambda function which has as many arguments + as the elements of the set upon which it operates, e.g. 1 argument + when acting on the set of integers or 2 arguments when acting on + a complex region. + + This function is not normally called directly, but is called + from ``imageset``. + + + Examples + ======== + + >>> from sympy import Symbol, S, pi, Dummy, Lambda + >>> from sympy import FiniteSet, ImageSet, Interval + + >>> x = Symbol('x') + >>> N = S.Naturals + >>> squares = ImageSet(Lambda(x, x**2), N) # {x**2 for x in N} + >>> 4 in squares + True + >>> 5 in squares + False + + >>> FiniteSet(0, 1, 2, 3, 4, 5, 6, 7, 9, 10).intersect(squares) + {1, 4, 9} + + >>> square_iterable = iter(squares) + >>> for i in range(4): + ... next(square_iterable) + 1 + 4 + 9 + 16 + + If you want to get value for `x` = 2, 1/2 etc. (Please check whether the + `x` value is in ``base_set`` or not before passing it as args) + + >>> squares.lamda(2) + 4 + >>> squares.lamda(S(1)/2) + 1/4 + + >>> n = Dummy('n') + >>> solutions = ImageSet(Lambda(n, n*pi), S.Integers) # solutions of sin(x) = 0 + >>> dom = Interval(-1, 1) + >>> dom.intersect(solutions) + {0} + + See Also + ======== + + sympy.sets.sets.imageset + """ + def __new__(cls, flambda, *sets): + if not isinstance(flambda, Lambda): + raise ValueError('First argument must be a Lambda') + + signature = flambda.signature + + if len(signature) != len(sets): + raise ValueError('Incompatible signature') + + sets = [_sympify(s) for s in sets] + + if not all(isinstance(s, Set) for s in sets): + raise TypeError("Set arguments to ImageSet should of type Set") + + if not all(cls._check_sig(sg, st) for sg, st in zip(signature, sets)): + raise ValueError("Signature %s does not match sets %s" % (signature, sets)) + + if flambda is S.IdentityFunction and len(sets) == 1: + return sets[0] + + if not set(flambda.variables) & flambda.expr.free_symbols: + is_empty = fuzzy_or(s.is_empty for s in sets) + if is_empty == True: + return S.EmptySet + elif is_empty == False: + return FiniteSet(flambda.expr) + + return Basic.__new__(cls, flambda, *sets) + + lamda = property(lambda self: self.args[0]) + base_sets = property(lambda self: self.args[1:]) + + @property + def base_set(self): + # XXX: Maybe deprecate this? It is poorly defined in handling + # the multivariate case... + sets = self.base_sets + if len(sets) == 1: + return sets[0] + else: + return ProductSet(*sets).flatten() + + @property + def base_pset(self): + return ProductSet(*self.base_sets) + + @classmethod + def _check_sig(cls, sig_i, set_i): + if sig_i.is_symbol: + return True + elif isinstance(set_i, ProductSet): + sets = set_i.sets + if len(sig_i) != len(sets): + return False + # Recurse through the signature for nested tuples: + return all(cls._check_sig(ts, ps) for ts, ps in zip(sig_i, sets)) + else: + # XXX: Need a better way of checking whether a set is a set of + # Tuples or not. For example a FiniteSet can contain Tuples + # but so can an ImageSet or a ConditionSet. Others like + # Integers, Reals etc can not contain Tuples. We could just + # list the possibilities here... Current code for e.g. + # _contains probably only works for ProductSet. + return True # Give the benefit of the doubt + + def __iter__(self): + already_seen = set() + for i in self.base_pset: + val = self.lamda(*i) + if val in already_seen: + continue + else: + already_seen.add(val) + yield val + + def _is_multivariate(self): + return len(self.lamda.variables) > 1 + + def _contains(self, other): + from sympy.solvers.solveset import _solveset_multi + + def get_symsetmap(signature, base_sets): + '''Attempt to get a map of symbols to base_sets''' + queue = list(zip(signature, base_sets)) + symsetmap = {} + for sig, base_set in queue: + if sig.is_symbol: + symsetmap[sig] = base_set + elif base_set.is_ProductSet: + sets = base_set.sets + if len(sig) != len(sets): + raise ValueError("Incompatible signature") + # Recurse + queue.extend(zip(sig, sets)) + else: + # If we get here then we have something like sig = (x, y) and + # base_set = {(1, 2), (3, 4)}. For now we give up. + return None + + return symsetmap + + def get_equations(expr, candidate): + '''Find the equations relating symbols in expr and candidate.''' + queue = [(expr, candidate)] + for e, c in queue: + if not isinstance(e, Tuple): + yield Eq(e, c) + elif not isinstance(c, Tuple) or len(e) != len(c): + yield False + return + else: + queue.extend(zip(e, c)) + + # Get the basic objects together: + other = _sympify(other) + expr = self.lamda.expr + sig = self.lamda.signature + variables = self.lamda.variables + base_sets = self.base_sets + + # Use dummy symbols for ImageSet parameters so they don't match + # anything in other + rep = {v: Dummy(v.name) for v in variables} + variables = [v.subs(rep) for v in variables] + sig = sig.subs(rep) + expr = expr.subs(rep) + + # Map the parts of other to those in the Lambda expr + equations = [] + for eq in get_equations(expr, other): + # Unsatisfiable equation? + if eq is False: + return S.false + equations.append(eq) + + # Map the symbols in the signature to the corresponding domains + symsetmap = get_symsetmap(sig, base_sets) + if symsetmap is None: + # Can't factor the base sets to a ProductSet + return None + + # Which of the variables in the Lambda signature need to be solved for? + symss = (eq.free_symbols for eq in equations) + variables = set(variables) & reduce(set.union, symss, set()) + + # Use internal multivariate solveset + variables = tuple(variables) + base_sets = [symsetmap[v] for v in variables] + solnset = _solveset_multi(equations, variables, base_sets) + if solnset is None: + return None + return tfn[fuzzy_not(solnset.is_empty)] + + @property + def is_iterable(self): + return all(s.is_iterable for s in self.base_sets) + + def doit(self, **hints): + from sympy.sets.setexpr import SetExpr + f = self.lamda + sig = f.signature + if len(sig) == 1 and sig[0].is_symbol and isinstance(f.expr, Expr): + base_set = self.base_sets[0] + return SetExpr(base_set)._eval_func(f).set + if all(s.is_FiniteSet for s in self.base_sets): + return FiniteSet(*(f(*a) for a in product(*self.base_sets))) + return self + + def _kind(self): + return SetKind(self.lamda.expr.kind) + + +class Range(Set): + """ + Represents a range of integers. Can be called as ``Range(stop)``, + ``Range(start, stop)``, or ``Range(start, stop, step)``; when ``step`` is + not given it defaults to 1. + + ``Range(stop)`` is the same as ``Range(0, stop, 1)`` and the stop value + (just as for Python ranges) is not included in the Range values. + + >>> from sympy import Range + >>> list(Range(3)) + [0, 1, 2] + + The step can also be negative: + + >>> list(Range(10, 0, -2)) + [10, 8, 6, 4, 2] + + The stop value is made canonical so equivalent ranges always + have the same args: + + >>> Range(0, 10, 3) + Range(0, 12, 3) + + Infinite ranges are allowed. ``oo`` and ``-oo`` are never included in the + set (``Range`` is always a subset of ``Integers``). If the starting point + is infinite, then the final value is ``stop - step``. To iterate such a + range, it needs to be reversed: + + >>> from sympy import oo + >>> r = Range(-oo, 1) + >>> r[-1] + 0 + >>> next(iter(r)) + Traceback (most recent call last): + ... + TypeError: Cannot iterate over Range with infinite start + >>> next(iter(r.reversed)) + 0 + + Although ``Range`` is a :class:`Set` (and supports the normal set + operations) it maintains the order of the elements and can + be used in contexts where ``range`` would be used. + + >>> from sympy import Interval + >>> Range(0, 10, 2).intersect(Interval(3, 7)) + Range(4, 8, 2) + >>> list(_) + [4, 6] + + Although slicing of a Range will always return a Range -- possibly + empty -- an empty set will be returned from any intersection that + is empty: + + >>> Range(3)[:0] + Range(0, 0, 1) + >>> Range(3).intersect(Interval(4, oo)) + EmptySet + >>> Range(3).intersect(Range(4, oo)) + EmptySet + + Range will accept symbolic arguments but has very limited support + for doing anything other than displaying the Range: + + >>> from sympy import Symbol, pprint + >>> from sympy.abc import i, j, k + >>> Range(i, j, k).start + i + >>> Range(i, j, k).inf + Traceback (most recent call last): + ... + ValueError: invalid method for symbolic range + + Better success will be had when using integer symbols: + + >>> n = Symbol('n', integer=True) + >>> r = Range(n, n + 20, 3) + >>> r.inf + n + >>> pprint(r) + {n, n + 3, ..., n + 18} + """ + + def __new__(cls, *args): + if len(args) == 1: + if isinstance(args[0], range): + raise TypeError( + 'use sympify(%s) to convert range to Range' % args[0]) + + # expand range + slc = slice(*args) + + if slc.step == 0: + raise ValueError("step cannot be 0") + + start, stop, step = slc.start or 0, slc.stop, slc.step or 1 + try: + ok = [] + for w in (start, stop, step): + w = sympify(w) + if w in [S.NegativeInfinity, S.Infinity] or ( + w.has(Symbol) and w.is_integer != False): + ok.append(w) + elif not w.is_Integer: + if w.is_infinite: + raise ValueError('infinite symbols not allowed') + raise ValueError + else: + ok.append(w) + except ValueError: + raise ValueError(filldedent(''' + Finite arguments to Range must be integers; `imageset` can define + other cases, e.g. use `imageset(i, i/10, Range(3))` to give + [0, 1/10, 1/5].''')) + start, stop, step = ok + + null = False + if any(i.has(Symbol) for i in (start, stop, step)): + dif = stop - start + n = dif/step + if n.is_Rational: + if dif == 0: + null = True + else: # (x, x + 5, 2) or (x, 3*x, x) + n = floor(n) + end = start + n*step + if dif.is_Rational: # (x, x + 5, 2) + if (end - stop).is_negative: + end += step + else: # (x, 3*x, x) + if (end/stop - 1).is_negative: + end += step + elif n.is_extended_negative: + null = True + else: + end = stop # other methods like sup and reversed must fail + elif start.is_infinite: + span = step*(stop - start) + if span is S.NaN or span <= 0: + null = True + elif step.is_Integer and stop.is_infinite and abs(step) != 1: + raise ValueError(filldedent(''' + Step size must be %s in this case.''' % (1 if step > 0 else -1))) + else: + end = stop + else: + oostep = step.is_infinite + if oostep: + step = S.One if step > 0 else S.NegativeOne + n = ceiling((stop - start)/step) + if n <= 0: + null = True + elif oostep: + step = S.One # make it canonical + end = start + step + else: + end = start + n*step + if null: + start = end = S.Zero + step = S.One + return Basic.__new__(cls, start, end, step) + + start = property(lambda self: self.args[0]) + stop = property(lambda self: self.args[1]) + step = property(lambda self: self.args[2]) + + @property + def reversed(self): + """Return an equivalent Range in the opposite order. + + Examples + ======== + + >>> from sympy import Range + >>> Range(10).reversed + Range(9, -1, -1) + """ + if self.has(Symbol): + n = (self.stop - self.start)/self.step + if not n.is_extended_positive or not all( + i.is_integer or i.is_infinite for i in self.args): + raise ValueError('invalid method for symbolic range') + if self.start == self.stop: + return self + return self.func( + self.stop - self.step, self.start - self.step, -self.step) + + def _kind(self): + return SetKind(NumberKind) + + def _contains(self, other): + if self.start == self.stop: + return S.false + if other.is_infinite: + return S.false + if not other.is_integer: + return tfn[other.is_integer] + if self.has(Symbol): + n = (self.stop - self.start)/self.step + if not n.is_extended_positive or not all( + i.is_integer or i.is_infinite for i in self.args): + return + else: + n = self.size + if self.start.is_finite: + ref = self.start + elif self.stop.is_finite: + ref = self.stop + else: # both infinite; step is +/- 1 (enforced by __new__) + return S.true + if n == 1: + return Eq(other, self[0]) + res = (ref - other) % self.step + if res == S.Zero: + if self.has(Symbol): + d = Dummy('i') + return self.as_relational(d).subs(d, other) + return And(other >= self.inf, other <= self.sup) + elif res.is_Integer: # off sequence + return S.false + else: # symbolic/unsimplified residue modulo step + return None + + def __iter__(self): + n = self.size # validate + if not (n.has(S.Infinity) or n.has(S.NegativeInfinity) or n.is_Integer): + raise TypeError("Cannot iterate over symbolic Range") + if self.start in [S.NegativeInfinity, S.Infinity]: + raise TypeError("Cannot iterate over Range with infinite start") + elif self.start != self.stop: + i = self.start + if n.is_infinite: + while True: + yield i + i += self.step + else: + for _ in range(n): + yield i + i += self.step + + @property + def is_iterable(self): + # Check that size can be determined, used by __iter__ + dif = self.stop - self.start + n = dif/self.step + if not (n.has(S.Infinity) or n.has(S.NegativeInfinity) or n.is_Integer): + return False + if self.start in [S.NegativeInfinity, S.Infinity]: + return False + if not (n.is_extended_nonnegative and all(i.is_integer for i in self.args)): + return False + return True + + def __len__(self): + rv = self.size + if rv is S.Infinity: + raise ValueError('Use .size to get the length of an infinite Range') + return int(rv) + + @property + def size(self): + if self.start == self.stop: + return S.Zero + dif = self.stop - self.start + n = dif/self.step + if n.is_infinite: + return S.Infinity + if n.is_extended_nonnegative and all(i.is_integer for i in self.args): + return abs(floor(n)) + raise ValueError('Invalid method for symbolic Range') + + @property + def is_finite_set(self): + if self.start.is_integer and self.stop.is_integer: + return True + return self.size.is_finite + + @property + def is_empty(self): + try: + return self.size.is_zero + except ValueError: + return None + + def __bool__(self): + # this only distinguishes between definite null range + # and non-null/unknown null; getting True doesn't mean + # that it actually is not null + b = is_eq(self.start, self.stop) + if b is None: + raise ValueError('cannot tell if Range is null or not') + return not bool(b) + + def __getitem__(self, i): + ooslice = "cannot slice from the end with an infinite value" + zerostep = "slice step cannot be zero" + infinite = "slicing not possible on range with infinite start" + # if we had to take every other element in the following + # oo, ..., 6, 4, 2, 0 + # we might get oo, ..., 4, 0 or oo, ..., 6, 2 + ambiguous = "cannot unambiguously re-stride from the end " + \ + "with an infinite value" + if isinstance(i, slice): + if self.size.is_finite: # validates, too + if self.start == self.stop: + return Range(0) + start, stop, step = i.indices(self.size) + n = ceiling((stop - start)/step) + if n <= 0: + return Range(0) + canonical_stop = start + n*step + end = canonical_stop - step + ss = step*self.step + return Range(self[start], self[end] + ss, ss) + else: # infinite Range + start = i.start + stop = i.stop + if i.step == 0: + raise ValueError(zerostep) + step = i.step or 1 + ss = step*self.step + #--------------------- + # handle infinite Range + # i.e. Range(-oo, oo) or Range(oo, -oo, -1) + # -------------------- + if self.start.is_infinite and self.stop.is_infinite: + raise ValueError(infinite) + #--------------------- + # handle infinite on right + # e.g. Range(0, oo) or Range(0, -oo, -1) + # -------------------- + if self.stop.is_infinite: + # start and stop are not interdependent -- + # they only depend on step --so we use the + # equivalent reversed values + return self.reversed[ + stop if stop is None else -stop + 1: + start if start is None else -start: + step].reversed + #--------------------- + # handle infinite on the left + # e.g. Range(oo, 0, -1) or Range(-oo, 0) + # -------------------- + # consider combinations of + # start/stop {== None, < 0, == 0, > 0} and + # step {< 0, > 0} + if start is None: + if stop is None: + if step < 0: + return Range(self[-1], self.start, ss) + elif step > 1: + raise ValueError(ambiguous) + else: # == 1 + return self + elif stop < 0: + if step < 0: + return Range(self[-1], self[stop], ss) + else: # > 0 + return Range(self.start, self[stop], ss) + elif stop == 0: + if step > 0: + return Range(0) + else: # < 0 + raise ValueError(ooslice) + elif stop == 1: + if step > 0: + raise ValueError(ooslice) # infinite singleton + else: # < 0 + raise ValueError(ooslice) + else: # > 1 + raise ValueError(ooslice) + elif start < 0: + if stop is None: + if step < 0: + return Range(self[start], self.start, ss) + else: # > 0 + return Range(self[start], self.stop, ss) + elif stop < 0: + return Range(self[start], self[stop], ss) + elif stop == 0: + if step < 0: + raise ValueError(ooslice) + else: # > 0 + return Range(0) + elif stop > 0: + raise ValueError(ooslice) + elif start == 0: + if stop is None: + if step < 0: + raise ValueError(ooslice) # infinite singleton + elif step > 1: + raise ValueError(ambiguous) + else: # == 1 + return self + elif stop < 0: + if step > 1: + raise ValueError(ambiguous) + elif step == 1: + return Range(self.start, self[stop], ss) + else: # < 0 + return Range(0) + else: # >= 0 + raise ValueError(ooslice) + elif start > 0: + raise ValueError(ooslice) + else: + if self.start == self.stop: + raise IndexError('Range index out of range') + if not (all(i.is_integer or i.is_infinite + for i in self.args) and ((self.stop - self.start)/ + self.step).is_extended_positive): + raise ValueError('Invalid method for symbolic Range') + if i == 0: + if self.start.is_infinite: + raise ValueError(ooslice) + return self.start + if i == -1: + if self.stop.is_infinite: + raise ValueError(ooslice) + return self.stop - self.step + n = self.size # must be known for any other index + rv = (self.stop if i < 0 else self.start) + i*self.step + if rv.is_infinite: + raise ValueError(ooslice) + val = (rv - self.start)/self.step + rel = fuzzy_or([val.is_infinite, + fuzzy_and([val.is_nonnegative, (n-val).is_nonnegative])]) + if rel: + return rv + if rel is None: + raise ValueError('Invalid method for symbolic Range') + raise IndexError("Range index out of range") + + @property + def _inf(self): + if not self: + return S.EmptySet.inf + if self.has(Symbol): + if all(i.is_integer or i.is_infinite for i in self.args): + dif = self.stop - self.start + if self.step.is_positive and dif.is_positive: + return self.start + elif self.step.is_negative and dif.is_negative: + return self.stop - self.step + raise ValueError('invalid method for symbolic range') + if self.step > 0: + return self.start + else: + return self.stop - self.step + + @property + def _sup(self): + if not self: + return S.EmptySet.sup + if self.has(Symbol): + if all(i.is_integer or i.is_infinite for i in self.args): + dif = self.stop - self.start + if self.step.is_positive and dif.is_positive: + return self.stop - self.step + elif self.step.is_negative and dif.is_negative: + return self.start + raise ValueError('invalid method for symbolic range') + if self.step > 0: + return self.stop - self.step + else: + return self.start + + @property + def _boundary(self): + return self + + def as_relational(self, x): + """Rewrite a Range in terms of equalities and logic operators. """ + if self.start.is_infinite: + assert not self.stop.is_infinite # by instantiation + a = self.reversed.start + else: + a = self.start + step = self.step + in_seq = Eq(Mod(x - a, step), 0) + ints = And(Eq(Mod(a, 1), 0), Eq(Mod(step, 1), 0)) + n = (self.stop - self.start)/self.step + if n == 0: + return S.EmptySet.as_relational(x) + if n == 1: + return And(Eq(x, a), ints) + try: + a, b = self.inf, self.sup + except ValueError: + a = None + if a is not None: + range_cond = And( + x > a if a.is_infinite else x >= a, + x < b if b.is_infinite else x <= b) + else: + a, b = self.start, self.stop - self.step + range_cond = Or( + And(self.step >= 1, x > a if a.is_infinite else x >= a, + x < b if b.is_infinite else x <= b), + And(self.step <= -1, x < a if a.is_infinite else x <= a, + x > b if b.is_infinite else x >= b)) + return And(in_seq, ints, range_cond) + + +_sympy_converter[range] = lambda r: Range(r.start, r.stop, r.step) + +def normalize_theta_set(theta): + r""" + Normalize a Real Set `theta` in the interval `[0, 2\pi)`. It returns + a normalized value of theta in the Set. For Interval, a maximum of + one cycle $[0, 2\pi]$, is returned i.e. for theta equal to $[0, 10\pi]$, + returned normalized value would be $[0, 2\pi)$. As of now intervals + with end points as non-multiples of ``pi`` is not supported. + + Raises + ====== + + NotImplementedError + The algorithms for Normalizing theta Set are not yet + implemented. + ValueError + The input is not valid, i.e. the input is not a real set. + RuntimeError + It is a bug, please report to the github issue tracker. + + Examples + ======== + + >>> from sympy.sets.fancysets import normalize_theta_set + >>> from sympy import Interval, FiniteSet, pi + >>> normalize_theta_set(Interval(9*pi/2, 5*pi)) + Interval(pi/2, pi) + >>> normalize_theta_set(Interval(-3*pi/2, pi/2)) + Interval.Ropen(0, 2*pi) + >>> normalize_theta_set(Interval(-pi/2, pi/2)) + Union(Interval(0, pi/2), Interval.Ropen(3*pi/2, 2*pi)) + >>> normalize_theta_set(Interval(-4*pi, 3*pi)) + Interval.Ropen(0, 2*pi) + >>> normalize_theta_set(Interval(-3*pi/2, -pi/2)) + Interval(pi/2, 3*pi/2) + >>> normalize_theta_set(FiniteSet(0, pi, 3*pi)) + {0, pi} + + """ + from sympy.functions.elementary.trigonometric import _pi_coeff + + if theta.is_Interval: + interval_len = theta.measure + # one complete circle + if interval_len >= 2*S.Pi: + if interval_len == 2*S.Pi and theta.left_open and theta.right_open: + k = _pi_coeff(theta.start) + return Union(Interval(0, k*S.Pi, False, True), + Interval(k*S.Pi, 2*S.Pi, True, True)) + return Interval(0, 2*S.Pi, False, True) + + k_start, k_end = _pi_coeff(theta.start), _pi_coeff(theta.end) + + if k_start is None or k_end is None: + raise NotImplementedError("Normalizing theta without pi as coefficient is " + "not yet implemented") + new_start = k_start*S.Pi + new_end = k_end*S.Pi + + if new_start > new_end: + return Union(Interval(S.Zero, new_end, False, theta.right_open), + Interval(new_start, 2*S.Pi, theta.left_open, True)) + else: + return Interval(new_start, new_end, theta.left_open, theta.right_open) + + elif theta.is_FiniteSet: + new_theta = [] + for element in theta: + k = _pi_coeff(element) + if k is None: + raise NotImplementedError('Normalizing theta without pi as ' + 'coefficient, is not Implemented.') + else: + new_theta.append(k*S.Pi) + return FiniteSet(*new_theta) + + elif theta.is_Union: + return Union(*[normalize_theta_set(interval) for interval in theta.args]) + + elif theta.is_subset(S.Reals): + raise NotImplementedError("Normalizing theta when, it is of type %s is not " + "implemented" % type(theta)) + else: + raise ValueError(" %s is not a real set" % (theta)) + + +class ComplexRegion(Set): + r""" + Represents the Set of all Complex Numbers. It can represent a + region of Complex Plane in both the standard forms Polar and + Rectangular coordinates. + + * Polar Form + Input is in the form of the ProductSet or Union of ProductSets + of the intervals of ``r`` and ``theta``, and use the flag ``polar=True``. + + .. math:: Z = \{z \in \mathbb{C} \mid z = r\times (\cos(\theta) + I\sin(\theta)), r \in [\texttt{r}], \theta \in [\texttt{theta}]\} + + * Rectangular Form + Input is in the form of the ProductSet or Union of ProductSets + of interval of x and y, the real and imaginary parts of the Complex numbers in a plane. + Default input type is in rectangular form. + + .. math:: Z = \{z \in \mathbb{C} \mid z = x + Iy, x \in [\operatorname{re}(z)], y \in [\operatorname{im}(z)]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, Interval, S, I, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 6) + >>> c1 = ComplexRegion(a*b) # Rectangular Form + >>> c1 + CartesianComplexRegion(ProductSet(Interval(2, 3), Interval(4, 6))) + + * c1 represents the rectangular region in complex plane + surrounded by the coordinates (2, 4), (3, 4), (3, 6) and + (2, 6), of the four vertices. + + >>> c = Interval(1, 8) + >>> c2 = ComplexRegion(Union(a*b, b*c)) + >>> c2 + CartesianComplexRegion(Union(ProductSet(Interval(2, 3), Interval(4, 6)), ProductSet(Interval(4, 6), Interval(1, 8)))) + + * c2 represents the Union of two rectangular regions in complex + plane. One of them surrounded by the coordinates of c1 and + other surrounded by the coordinates (4, 1), (6, 1), (6, 8) and + (4, 8). + + >>> 2.5 + 4.5*I in c1 + True + >>> 2.5 + 6.5*I in c1 + False + + >>> r = Interval(0, 1) + >>> theta = Interval(0, 2*S.Pi) + >>> c2 = ComplexRegion(r*theta, polar=True) # Polar Form + >>> c2 # unit Disk + PolarComplexRegion(ProductSet(Interval(0, 1), Interval.Ropen(0, 2*pi))) + + * c2 represents the region in complex plane inside the + Unit Disk centered at the origin. + + >>> 0.5 + 0.5*I in c2 + True + >>> 1 + 2*I in c2 + False + + >>> unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + >>> upper_half_unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + >>> intersection = unit_disk.intersect(upper_half_unit_disk) + >>> intersection + PolarComplexRegion(ProductSet(Interval(0, 1), Interval(0, pi))) + >>> intersection == upper_half_unit_disk + True + + See Also + ======== + + CartesianComplexRegion + PolarComplexRegion + Complexes + + """ + is_ComplexRegion = True + + def __new__(cls, sets, polar=False): + if polar is False: + return CartesianComplexRegion(sets) + elif polar is True: + return PolarComplexRegion(sets) + else: + raise ValueError("polar should be either True or False") + + @property + def sets(self): + """ + Return raw input sets to the self. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.sets + ProductSet(Interval(2, 3), Interval(4, 5)) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.sets + Union(ProductSet(Interval(2, 3), Interval(4, 5)), ProductSet(Interval(4, 5), Interval(1, 7))) + + """ + return self.args[0] + + @property + def psets(self): + """ + Return a tuple of sets (ProductSets) input of the self. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.psets + (ProductSet(Interval(2, 3), Interval(4, 5)),) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.psets + (ProductSet(Interval(2, 3), Interval(4, 5)), ProductSet(Interval(4, 5), Interval(1, 7))) + + """ + if self.sets.is_ProductSet: + psets = () + psets = psets + (self.sets, ) + else: + psets = self.sets.args + return psets + + @property + def a_interval(self): + """ + Return the union of intervals of `x` when, self is in + rectangular form, or the union of intervals of `r` when + self is in polar form. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.a_interval + Interval(2, 3) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.a_interval + Union(Interval(2, 3), Interval(4, 5)) + + """ + a_interval = [] + for element in self.psets: + a_interval.append(element.args[0]) + + a_interval = Union(*a_interval) + return a_interval + + @property + def b_interval(self): + """ + Return the union of intervals of `y` when, self is in + rectangular form, or the union of intervals of `theta` + when self is in polar form. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.b_interval + Interval(4, 5) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.b_interval + Interval(1, 7) + + """ + b_interval = [] + for element in self.psets: + b_interval.append(element.args[1]) + + b_interval = Union(*b_interval) + return b_interval + + @property + def _measure(self): + """ + The measure of self.sets. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, S + >>> a, b = Interval(2, 5), Interval(4, 8) + >>> c = Interval(0, 2*S.Pi) + >>> c1 = ComplexRegion(a*b) + >>> c1.measure + 12 + >>> c2 = ComplexRegion(a*c, polar=True) + >>> c2.measure + 6*pi + + """ + return self.sets._measure + + def _kind(self): + return self.args[0].kind + + @classmethod + def from_real(cls, sets): + """ + Converts given subset of real numbers to a complex region. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion + >>> unit = Interval(0,1) + >>> ComplexRegion.from_real(unit) + CartesianComplexRegion(ProductSet(Interval(0, 1), {0})) + + """ + if not sets.is_subset(S.Reals): + raise ValueError("sets must be a subset of the real line") + + return CartesianComplexRegion(sets * FiniteSet(0)) + + def _contains(self, other): + from sympy.functions import arg, Abs + + isTuple = isinstance(other, Tuple) + if isTuple and len(other) != 2: + raise ValueError('expecting Tuple of length 2') + + # If the other is not an Expression, and neither a Tuple + if not isinstance(other, (Expr, Tuple)): + return S.false + + # self in rectangular form + if not self.polar: + re, im = other if isTuple else other.as_real_imag() + return tfn[fuzzy_or(fuzzy_and([ + pset.args[0]._contains(re), + pset.args[1]._contains(im)]) + for pset in self.psets)] + + # self in polar form + elif self.polar: + if other.is_zero: + # ignore undefined complex argument + return tfn[fuzzy_or(pset.args[0]._contains(S.Zero) + for pset in self.psets)] + if isTuple: + r, theta = other + else: + r, theta = Abs(other), arg(other) + if theta.is_real and theta.is_number: + # angles in psets are normalized to [0, 2pi) + theta %= 2*S.Pi + return tfn[fuzzy_or(fuzzy_and([ + pset.args[0]._contains(r), + pset.args[1]._contains(theta)]) + for pset in self.psets)] + + +class CartesianComplexRegion(ComplexRegion): + r""" + Set representing a square region of the complex plane. + + .. math:: Z = \{z \in \mathbb{C} \mid z = x + Iy, x \in [\operatorname{re}(z)], y \in [\operatorname{im}(z)]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, I, Interval + >>> region = ComplexRegion(Interval(1, 3) * Interval(4, 6)) + >>> 2 + 5*I in region + True + >>> 5*I in region + False + + See also + ======== + + ComplexRegion + PolarComplexRegion + Complexes + """ + + polar = False + variables = symbols('x, y', cls=Dummy) + + def __new__(cls, sets): + + if sets == S.Reals*S.Reals: + return S.Complexes + + if all(_a.is_FiniteSet for _a in sets.args) and (len(sets.args) == 2): + + # ** ProductSet of FiniteSets in the Complex Plane. ** + # For Cases like ComplexRegion({2, 4}*{3}), It + # would return {2 + 3*I, 4 + 3*I} + + # FIXME: This should probably be handled with something like: + # return ImageSet(Lambda((x, y), x+I*y), sets).rewrite(FiniteSet) + complex_num = [] + for x in sets.args[0]: + for y in sets.args[1]: + complex_num.append(x + S.ImaginaryUnit*y) + return FiniteSet(*complex_num) + else: + return Set.__new__(cls, sets) + + @property + def expr(self): + x, y = self.variables + return x + S.ImaginaryUnit*y + + +class PolarComplexRegion(ComplexRegion): + r""" + Set representing a polar region of the complex plane. + + .. math:: Z = \{z \in \mathbb{C} \mid z = r\times (\cos(\theta) + I\sin(\theta)), r \in [\texttt{r}], \theta \in [\texttt{theta}]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, Interval, oo, pi, I + >>> rset = Interval(0, oo) + >>> thetaset = Interval(0, pi) + >>> upper_half_plane = ComplexRegion(rset * thetaset, polar=True) + >>> 1 + I in upper_half_plane + True + >>> 1 - I in upper_half_plane + False + + See also + ======== + + ComplexRegion + CartesianComplexRegion + Complexes + + """ + + polar = True + variables = symbols('r, theta', cls=Dummy) + + def __new__(cls, sets): + + new_sets = [] + # sets is Union of ProductSets + if not sets.is_ProductSet: + for k in sets.args: + new_sets.append(k) + # sets is ProductSets + else: + new_sets.append(sets) + # Normalize input theta + for k, v in enumerate(new_sets): + new_sets[k] = ProductSet(v.args[0], + normalize_theta_set(v.args[1])) + sets = Union(*new_sets) + return Set.__new__(cls, sets) + + @property + def expr(self): + r, theta = self.variables + return r*(cos(theta) + S.ImaginaryUnit*sin(theta)) + + +class Complexes(CartesianComplexRegion, metaclass=Singleton): + """ + The :class:`Set` of all complex numbers + + Examples + ======== + + >>> from sympy import S, I + >>> S.Complexes + Complexes + >>> 1 + I in S.Complexes + True + + See also + ======== + + Reals + ComplexRegion + + """ + + is_empty = False + is_finite_set = False + + # Override property from superclass since Complexes has no args + @property + def sets(self): + return ProductSet(S.Reals, S.Reals) + + def __new__(cls): + return Set.__new__(cls) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7700570b7bc3535e8ff522bd652d18cdd8103053 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/add.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/add.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdd65f4ca444e0e6a910349a95f97d4e17cffeee Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/add.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..487b473444b6dbc69bcb30fbdbb1b026c63ab6e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96ca5dcdf8a16008f293e5ac76e96a42c4c0abd2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ac5260c2017c530c1408206e983b61b270dde22 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f33dee69b7257fb485ea6be4ac5e1f47d4a6d37 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6338d5fa4f4a499d056f96acdadef08a59e46110 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/power.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/power.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..824d4fdf58e8b02e060aa87a8b30a7b6c4371b08 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/power.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/union.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/union.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f13f8908e5e29b3c9609703c285f65f18b367fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/__pycache__/union.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/add.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/add.py new file mode 100644 index 0000000000000000000000000000000000000000..8c07b25ed19d21febffd6b23a92b34b787179f44 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/add.py @@ -0,0 +1,79 @@ +from sympy.core.numbers import oo, Infinity, NegativeInfinity +from sympy.core.singleton import S +from sympy.core import Basic, Expr +from sympy.multipledispatch import Dispatcher +from sympy.sets import Interval, FiniteSet + + + +# XXX: The functions in this module are clearly not tested and are broken in a +# number of ways. + +_set_add = Dispatcher('_set_add') +_set_sub = Dispatcher('_set_sub') + + +@_set_add.register(Basic, Basic) +def _(x, y): + return None + + +@_set_add.register(Expr, Expr) +def _(x, y): + return x+y + + +@_set_add.register(Interval, Interval) +def _(x, y): + """ + Additions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + return Interval(x.start + y.start, x.end + y.end, + x.left_open or y.left_open, x.right_open or y.right_open) + + +@_set_add.register(Interval, Infinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet({S.Infinity}) + +@_set_add.register(Interval, NegativeInfinity) +def _(x, y): + if x.end is S.Infinity: + return Interval(-oo, oo) + return FiniteSet({S.NegativeInfinity}) + + +@_set_sub.register(Basic, Basic) +def _(x, y): + return None + + +@_set_sub.register(Expr, Expr) +def _(x, y): + return x-y + + +@_set_sub.register(Interval, Interval) +def _(x, y): + """ + Subtractions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + return Interval(x.start - y.end, x.end - y.start, + x.left_open or y.right_open, x.right_open or y.left_open) + + +@_set_sub.register(Interval, Infinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet(-oo) + +@_set_sub.register(Interval, NegativeInfinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet(-oo) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/comparison.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..b64d1a2a22e15d09f6f10fb4fef730163d468d45 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/comparison.py @@ -0,0 +1,53 @@ +from sympy.core.relational import Eq, is_eq +from sympy.core.basic import Basic +from sympy.core.logic import fuzzy_and, fuzzy_bool +from sympy.logic.boolalg import And +from sympy.multipledispatch import dispatch +from sympy.sets.sets import tfn, ProductSet, Interval, FiniteSet, Set + + +@dispatch(Interval, FiniteSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(FiniteSet, Interval) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(Interval, Interval) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return And(Eq(lhs.left, rhs.left), + Eq(lhs.right, rhs.right), + lhs.left_open == rhs.left_open, + lhs.right_open == rhs.right_open) + +@dispatch(FiniteSet, FiniteSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + def all_in_both(): + s_set = set(lhs.args) + o_set = set(rhs.args) + yield fuzzy_and(lhs._contains(e) for e in o_set - s_set) + yield fuzzy_and(rhs._contains(e) for e in s_set - o_set) + + return tfn[fuzzy_and(all_in_both())] + + +@dispatch(ProductSet, ProductSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + if len(lhs.sets) != len(rhs.sets): + return False + + eqs = (is_eq(x, y) for x, y in zip(lhs.sets, rhs.sets)) + return tfn[fuzzy_and(map(fuzzy_bool, eqs))] + + +@dispatch(Set, Basic) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(Set, Set) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return tfn[fuzzy_and(a.is_subset(b) for a, b in [(lhs, rhs), (rhs, lhs)])] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..2529dbfd458451d7d09e91c717b170df77b1d9fe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/functions.py @@ -0,0 +1,262 @@ +from sympy.core.singleton import S +from sympy.sets.sets import Set +from sympy.calculus.singularities import singularities +from sympy.core import Expr, Add +from sympy.core.function import Lambda, FunctionClass, diff, expand_mul +from sympy.core.numbers import Float, oo +from sympy.core.symbol import Dummy, symbols, Wild +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import true +from sympy.multipledispatch import Dispatcher +from sympy.sets import (imageset, Interval, FiniteSet, Union, ImageSet, + Intersection, Range, Complement) +from sympy.sets.sets import EmptySet, is_function_invertible_in_set +from sympy.sets.fancysets import Integers, Naturals, Reals +from sympy.functions.elementary.exponential import match_real_imag + + +_x, _y = symbols("x y") + +FunctionUnion = (FunctionClass, Lambda) + +_set_function = Dispatcher('_set_function') + + +@_set_function.register(FunctionClass, Set) +def _(f, x): + return None + +@_set_function.register(FunctionUnion, FiniteSet) +def _(f, x): + return FiniteSet(*map(f, x)) + +@_set_function.register(Lambda, Interval) +def _(f, x): + from sympy.solvers.solveset import solveset + from sympy.series import limit + # TODO: handle functions with infinitely many solutions (eg, sin, tan) + # TODO: handle multivariate functions + + expr = f.expr + if len(expr.free_symbols) > 1 or len(f.variables) != 1: + return + var = f.variables[0] + if not var.is_real: + if expr.subs(var, Dummy(real=True)).is_real is False: + return + + if expr.is_Piecewise: + result = S.EmptySet + domain_set = x + for (p_expr, p_cond) in expr.args: + if p_cond is true: + intrvl = domain_set + else: + intrvl = p_cond.as_set() + intrvl = Intersection(domain_set, intrvl) + + if p_expr.is_Number: + image = FiniteSet(p_expr) + else: + image = imageset(Lambda(var, p_expr), intrvl) + result = Union(result, image) + + # remove the part which has been `imaged` + domain_set = Complement(domain_set, intrvl) + if domain_set is S.EmptySet: + break + return result + + if not x.start.is_comparable or not x.end.is_comparable: + return + + try: + from sympy.polys.polyutils import _nsort + sing = list(singularities(expr, var, x)) + if len(sing) > 1: + sing = _nsort(sing) + except NotImplementedError: + return + + if x.left_open: + _start = limit(expr, var, x.start, dir="+") + elif x.start not in sing: + _start = f(x.start) + if x.right_open: + _end = limit(expr, var, x.end, dir="-") + elif x.end not in sing: + _end = f(x.end) + + if len(sing) == 0: + soln_expr = solveset(diff(expr, var), var) + if not (isinstance(soln_expr, FiniteSet) + or soln_expr is S.EmptySet): + return + solns = list(soln_expr) + + extr = [_start, _end] + [f(i) for i in solns + if i.is_real and i in x] + start, end = Min(*extr), Max(*extr) + + left_open, right_open = False, False + if _start <= _end: + # the minimum or maximum value can occur simultaneously + # on both the edge of the interval and in some interior + # point + if start == _start and start not in solns: + left_open = x.left_open + if end == _end and end not in solns: + right_open = x.right_open + else: + if start == _end and start not in solns: + left_open = x.right_open + if end == _start and end not in solns: + right_open = x.left_open + + return Interval(start, end, left_open, right_open) + else: + return imageset(f, Interval(x.start, sing[0], + x.left_open, True)) + \ + Union(*[imageset(f, Interval(sing[i], sing[i + 1], True, True)) + for i in range(0, len(sing) - 1)]) + \ + imageset(f, Interval(sing[-1], x.end, True, x.right_open)) + +@_set_function.register(FunctionClass, Interval) +def _(f, x): + if f == exp: + return Interval(exp(x.start), exp(x.end), x.left_open, x.right_open) + elif f == log: + return Interval(log(x.start), log(x.end), x.left_open, x.right_open) + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, Union) +def _(f, x): + return Union(*(imageset(f, arg) for arg in x.args)) + +@_set_function.register(FunctionUnion, Intersection) +def _(f, x): + # If the function is invertible, intersect the maps of the sets. + if is_function_invertible_in_set(f, x): + return Intersection(*(imageset(f, arg) for arg in x.args)) + else: + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, EmptySet) +def _(f, x): + return x + +@_set_function.register(FunctionUnion, Set) +def _(f, x): + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, Range) +def _(f, self): + if not self: + return S.EmptySet + if not isinstance(f.expr, Expr): + return + if self.size == 1: + return FiniteSet(f(self[0])) + if f is S.IdentityFunction: + return self + + x = f.variables[0] + expr = f.expr + # handle f that is linear in f's variable + if x not in expr.free_symbols or x in expr.diff(x).free_symbols: + return + if self.start.is_finite: + F = f(self.step*x + self.start) # for i in range(len(self)) + else: + F = f(-self.step*x + self[-1]) + F = expand_mul(F) + if F != expr: + return imageset(x, F, Range(self.size)) + +@_set_function.register(FunctionUnion, Integers) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + + n = f.variables[0] + if expr == abs(n): + return S.Naturals0 + + # f(x) + c and f(-x) + c cover the same integers + # so choose the form that has the fewest negatives + c = f(0) + fx = f(n) - c + f_x = f(-n) - c + neg_count = lambda e: sum(_.could_extract_minus_sign() + for _ in Add.make_args(e)) + if neg_count(f_x) < neg_count(fx): + expr = f_x + c + + a = Wild('a', exclude=[n]) + b = Wild('b', exclude=[n]) + match = expr.match(a*n + b) + if match and match[a] and ( + not match[a].atoms(Float) and + not match[b].atoms(Float)): + # canonical shift + a, b = match[a], match[b] + if a in [1, -1]: + # drop integer addends in b + nonint = [] + for bi in Add.make_args(b): + if not bi.is_integer: + nonint.append(bi) + b = Add(*nonint) + if b.is_number and a.is_real: + # avoid Mod for complex numbers, #11391 + br, bi = match_real_imag(b) + if br and br.is_comparable and a.is_comparable: + br %= a + b = br + S.ImaginaryUnit*bi + elif b.is_number and a.is_imaginary: + br, bi = match_real_imag(b) + ai = a/S.ImaginaryUnit + if bi and bi.is_comparable and ai.is_comparable: + bi %= ai + b = br + S.ImaginaryUnit*bi + expr = a*n + b + + if expr != f.expr: + return ImageSet(Lambda(n, expr), S.Integers) + + +@_set_function.register(FunctionUnion, Naturals) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + + x = f.variables[0] + if not expr.free_symbols - {x}: + if expr == abs(x): + if self is S.Naturals: + return self + return S.Naturals0 + step = expr.coeff(x) + c = expr.subs(x, 0) + if c.is_Integer and step.is_Integer and expr == step*x + c: + if self is S.Naturals: + c += step + if step > 0: + if step == 1: + if c == 0: + return S.Naturals0 + elif c == 1: + return S.Naturals + return Range(c, oo, step) + return Range(c, -oo, step) + + +@_set_function.register(FunctionUnion, Reals) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + return _set_function(f, Interval(-oo, oo)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/intersection.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/intersection.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb9309ef3e9d2722ab1bfe664f1d1644f17da5d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/intersection.py @@ -0,0 +1,533 @@ +from sympy.core.basic import _aresame +from sympy.core.function import Lambda, expand_complex +from sympy.core.mul import Mul +from sympy.core.numbers import ilcm, Float +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sorting import ordered +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.integers import floor, ceiling +from sympy.sets.fancysets import ComplexRegion +from sympy.sets.sets import (FiniteSet, Intersection, Interval, Set, Union) +from sympy.multipledispatch import Dispatcher +from sympy.sets.conditionset import ConditionSet +from sympy.sets.fancysets import (Integers, Naturals, Reals, Range, + ImageSet, Rationals) +from sympy.sets.sets import EmptySet, UniversalSet, imageset, ProductSet +from sympy.simplify.radsimp import numer + + +intersection_sets = Dispatcher('intersection_sets') + + +@intersection_sets.register(ConditionSet, ConditionSet) +def _(a, b): + return None + +@intersection_sets.register(ConditionSet, Set) +def _(a, b): + return ConditionSet(a.sym, a.condition, Intersection(a.base_set, b)) + +@intersection_sets.register(Naturals, Integers) +def _(a, b): + return a + +@intersection_sets.register(Naturals, Naturals) +def _(a, b): + return a if a is S.Naturals else b + +@intersection_sets.register(Interval, Naturals) +def _(a, b): + return intersection_sets(b, a) + +@intersection_sets.register(ComplexRegion, Set) +def _(self, other): + if other.is_ComplexRegion: + # self in rectangular form + if (not self.polar) and (not other.polar): + return ComplexRegion(Intersection(self.sets, other.sets)) + + # self in polar form + elif self.polar and other.polar: + r1, theta1 = self.a_interval, self.b_interval + r2, theta2 = other.a_interval, other.b_interval + new_r_interval = Intersection(r1, r2) + new_theta_interval = Intersection(theta1, theta2) + + # 0 and 2*Pi means the same + if ((2*S.Pi in theta1 and S.Zero in theta2) or + (2*S.Pi in theta2 and S.Zero in theta1)): + new_theta_interval = Union(new_theta_interval, + FiniteSet(0)) + return ComplexRegion(new_r_interval*new_theta_interval, + polar=True) + + + if other.is_subset(S.Reals): + new_interval = [] + x = symbols("x", cls=Dummy, real=True) + + # self in rectangular form + if not self.polar: + for element in self.psets: + if S.Zero in element.args[1]: + new_interval.append(element.args[0]) + new_interval = Union(*new_interval) + return Intersection(new_interval, other) + + # self in polar form + elif self.polar: + for element in self.psets: + if S.Zero in element.args[1]: + new_interval.append(element.args[0]) + if S.Pi in element.args[1]: + new_interval.append(ImageSet(Lambda(x, -x), element.args[0])) + if S.Zero in element.args[0]: + new_interval.append(FiniteSet(0)) + new_interval = Union(*new_interval) + return Intersection(new_interval, other) + +@intersection_sets.register(Integers, Reals) +def _(a, b): + return a + +@intersection_sets.register(Range, Interval) +def _(a, b): + # Check that there are no symbolic arguments + if not all(i.is_number for i in a.args + b.args[:2]): + return + + # In case of null Range, return an EmptySet. + if a.size == 0: + return S.EmptySet + + # trim down to self's size, and represent + # as a Range with step 1. + start = ceiling(max(b.inf, a.inf)) + if start not in b: + start += 1 + end = floor(min(b.sup, a.sup)) + if end not in b: + end -= 1 + return intersection_sets(a, Range(start, end + 1)) + +@intersection_sets.register(Range, Naturals) +def _(a, b): + return intersection_sets(a, Interval(b.inf, S.Infinity)) + +@intersection_sets.register(Range, Range) +def _(a, b): + # Check that there are no symbolic range arguments + if not all(all(v.is_number for v in r.args) for r in [a, b]): + return None + + # non-overlap quick exits + if not b: + return S.EmptySet + if not a: + return S.EmptySet + if b.sup < a.inf: + return S.EmptySet + if b.inf > a.sup: + return S.EmptySet + + # work with finite end at the start + r1 = a + if r1.start.is_infinite: + r1 = r1.reversed + r2 = b + if r2.start.is_infinite: + r2 = r2.reversed + + # If both ends are infinite then it means that one Range is just the set + # of all integers (the step must be 1). + if r1.start.is_infinite: + return b + if r2.start.is_infinite: + return a + + from sympy.solvers.diophantine.diophantine import diop_linear + + # this equation represents the values of the Range; + # it's a linear equation + eq = lambda r, i: r.start + i*r.step + + # we want to know when the two equations might + # have integer solutions so we use the diophantine + # solver + va, vb = diop_linear(eq(r1, Dummy('a')) - eq(r2, Dummy('b'))) + + # check for no solution + no_solution = va is None and vb is None + if no_solution: + return S.EmptySet + + # there is a solution + # ------------------- + + # find the coincident point, c + a0 = va.as_coeff_Add()[0] + c = eq(r1, a0) + + # find the first point, if possible, in each range + # since c may not be that point + def _first_finite_point(r1, c): + if c == r1.start: + return c + # st is the signed step we need to take to + # get from c to r1.start + st = sign(r1.start - c)*step + # use Range to calculate the first point: + # we want to get as close as possible to + # r1.start; the Range will not be null since + # it will at least contain c + s1 = Range(c, r1.start + st, st)[-1] + if s1 == r1.start: + pass + else: + # if we didn't hit r1.start then, if the + # sign of st didn't match the sign of r1.step + # we are off by one and s1 is not in r1 + if sign(r1.step) != sign(st): + s1 -= st + if s1 not in r1: + return + return s1 + + # calculate the step size of the new Range + step = abs(ilcm(r1.step, r2.step)) + s1 = _first_finite_point(r1, c) + if s1 is None: + return S.EmptySet + s2 = _first_finite_point(r2, c) + if s2 is None: + return S.EmptySet + + # replace the corresponding start or stop in + # the original Ranges with these points; the + # result must have at least one point since + # we know that s1 and s2 are in the Ranges + def _updated_range(r, first): + st = sign(r.step)*step + if r.start.is_finite: + rv = Range(first, r.stop, st) + else: + rv = Range(r.start, first + st, st) + return rv + r1 = _updated_range(a, s1) + r2 = _updated_range(b, s2) + + # work with them both in the increasing direction + if sign(r1.step) < 0: + r1 = r1.reversed + if sign(r2.step) < 0: + r2 = r2.reversed + + # return clipped Range with positive step; it + # can't be empty at this point + start = max(r1.start, r2.start) + stop = min(r1.stop, r2.stop) + return Range(start, stop, step) + + +@intersection_sets.register(Range, Integers) +def _(a, b): + return a + + +@intersection_sets.register(Range, Rationals) +def _(a, b): + return a + + +@intersection_sets.register(ImageSet, Set) +def _(self, other): + from sympy.solvers.diophantine import diophantine + + # Only handle the straight-forward univariate case + if (len(self.lamda.variables) > 1 + or self.lamda.signature != self.lamda.variables): + return None + base_set = self.base_sets[0] + + # Intersection between ImageSets with Integers as base set + # For {f(n) : n in Integers} & {g(m) : m in Integers} we solve the + # diophantine equations f(n)=g(m). + # If the solutions for n are {h(t) : t in Integers} then we return + # {f(h(t)) : t in integers}. + # If the solutions for n are {n_1, n_2, ..., n_k} then we return + # {f(n_i) : 1 <= i <= k}. + if base_set is S.Integers: + gm = None + if isinstance(other, ImageSet) and other.base_sets == (S.Integers,): + gm = other.lamda.expr + var = other.lamda.variables[0] + # Symbol of second ImageSet lambda must be distinct from first + m = Dummy('m') + gm = gm.subs(var, m) + elif other is S.Integers: + m = gm = Dummy('m') + if gm is not None: + fn = self.lamda.expr + n = self.lamda.variables[0] + try: + solns = list(diophantine(fn - gm, syms=(n, m), permute=True)) + except (TypeError, NotImplementedError): + # TypeError if equation not polynomial with rational coeff. + # NotImplementedError if correct format but no solver. + return + # 3 cases are possible for solns: + # - empty set, + # - one or more parametric (infinite) solutions, + # - a finite number of (non-parametric) solution couples. + # Among those, there is one type of solution set that is + # not helpful here: multiple parametric solutions. + if len(solns) == 0: + return S.EmptySet + elif any(s.free_symbols for tupl in solns for s in tupl): + if len(solns) == 1: + soln, solm = solns[0] + (t,) = soln.free_symbols + expr = fn.subs(n, soln.subs(t, n)).expand() + return imageset(Lambda(n, expr), S.Integers) + else: + return + else: + return FiniteSet(*(fn.subs(n, s[0]) for s in solns)) + + if other == S.Reals: + from sympy.solvers.solvers import denoms, solve_linear + + def _solution_union(exprs, sym): + # return a union of linear solutions to i in expr; + # if i cannot be solved, use a ConditionSet for solution + sols = [] + for i in exprs: + x, xis = solve_linear(i, 0, [sym]) + if x == sym: + sols.append(FiniteSet(xis)) + else: + sols.append(ConditionSet(sym, Eq(i, 0))) + return Union(*sols) + + f = self.lamda.expr + n = self.lamda.variables[0] + + n_ = Dummy(n.name, real=True) + f_ = f.subs(n, n_) + + re, im = f_.as_real_imag() + im = expand_complex(im) + + re = re.subs(n_, n) + im = im.subs(n_, n) + ifree = im.free_symbols + lam = Lambda(n, re) + if im.is_zero: + # allow re-evaluation + # of self in this case to make + # the result canonical + pass + elif im.is_zero is False: + return S.EmptySet + elif ifree != {n}: + return None + else: + # univarite imaginary part in same variable; + # use numer instead of as_numer_denom to keep + # this as fast as possible while still handling + # simple cases + base_set &= _solution_union( + Mul.make_args(numer(im)), n) + # exclude values that make denominators 0 + base_set -= _solution_union(denoms(f), n) + return imageset(lam, base_set) + + elif isinstance(other, Interval): + from sympy.solvers.solveset import (invert_real, invert_complex, + solveset) + + f = self.lamda.expr + n = self.lamda.variables[0] + new_inf, new_sup = None, None + new_lopen, new_ropen = other.left_open, other.right_open + + if f.is_real: + inverter = invert_real + else: + inverter = invert_complex + + g1, h1 = inverter(f, other.inf, n) + g2, h2 = inverter(f, other.sup, n) + + if all(isinstance(i, FiniteSet) for i in (h1, h2)): + if g1 == n: + if len(h1) == 1: + new_inf = h1.args[0] + if g2 == n: + if len(h2) == 1: + new_sup = h2.args[0] + # TODO: Design a technique to handle multiple-inverse + # functions + + # Any of the new boundary values cannot be determined + if any(i is None for i in (new_sup, new_inf)): + return + + + range_set = S.EmptySet + + if all(i.is_real for i in (new_sup, new_inf)): + # this assumes continuity of underlying function + # however fixes the case when it is decreasing + if new_inf > new_sup: + new_inf, new_sup = new_sup, new_inf + new_interval = Interval(new_inf, new_sup, new_lopen, new_ropen) + range_set = base_set.intersect(new_interval) + else: + if other.is_subset(S.Reals): + solutions = solveset(f, n, S.Reals) + if not isinstance(range_set, (ImageSet, ConditionSet)): + range_set = solutions.intersect(other) + else: + return + + if range_set is S.EmptySet: + return S.EmptySet + elif isinstance(range_set, Range) and range_set.size is not S.Infinity: + range_set = FiniteSet(*list(range_set)) + + if range_set is not None: + return imageset(Lambda(n, f), range_set) + return + else: + return + + +@intersection_sets.register(ProductSet, ProductSet) +def _(a, b): + if len(b.args) != len(a.args): + return S.EmptySet + return ProductSet(*(i.intersect(j) for i, j in zip(a.sets, b.sets))) + + +@intersection_sets.register(Interval, Interval) +def _(a, b): + # handle (-oo, oo) + infty = S.NegativeInfinity, S.Infinity + if a == Interval(*infty): + l, r = a.left, a.right + if l.is_real or l in infty or r.is_real or r in infty: + return b + + # We can't intersect [0,3] with [x,6] -- we don't know if x>0 or x<0 + if not a._is_comparable(b): + return None + + empty = False + + if a.start <= b.end and b.start <= a.end: + # Get topology right. + if a.start < b.start: + start = b.start + left_open = b.left_open + elif a.start > b.start: + start = a.start + left_open = a.left_open + else: + start = a.start + if not _aresame(a.start, b.start): + # For example Integer(2) != Float(2) + # Prefer the Float boundary because Floats should be + # contagious in calculations. + if b.start.has(Float) and not a.start.has(Float): + start = b.start + elif a.start.has(Float) and not b.start.has(Float): + start = a.start + else: + #this is to ensure that if Eq(a.start, b.start) but + #type(a.start) != type(b.start) the order of a and b + #does not matter for the result + start = list(ordered([a,b]))[0].start + left_open = a.left_open or b.left_open + + if a.end < b.end: + end = a.end + right_open = a.right_open + elif a.end > b.end: + end = b.end + right_open = b.right_open + else: + # see above for logic with start + end = a.end + if not _aresame(a.end, b.end): + if b.end.has(Float) and not a.end.has(Float): + end = b.end + elif a.end.has(Float) and not b.end.has(Float): + end = a.end + else: + end = list(ordered([a,b]))[0].end + right_open = a.right_open or b.right_open + + if end - start == 0 and (left_open or right_open): + empty = True + else: + empty = True + + if empty: + return S.EmptySet + + return Interval(start, end, left_open, right_open) + +@intersection_sets.register(EmptySet, Set) +def _(a, b): + return S.EmptySet + +@intersection_sets.register(UniversalSet, Set) +def _(a, b): + return b + +@intersection_sets.register(FiniteSet, FiniteSet) +def _(a, b): + return FiniteSet(*(a._elements & b._elements)) + +@intersection_sets.register(FiniteSet, Set) +def _(a, b): + try: + return FiniteSet(*[el for el in a if el in b]) + except TypeError: + return None # could not evaluate `el in b` due to symbolic ranges. + +@intersection_sets.register(Set, Set) +def _(a, b): + return None + +@intersection_sets.register(Integers, Rationals) +def _(a, b): + return a + +@intersection_sets.register(Naturals, Rationals) +def _(a, b): + return a + +@intersection_sets.register(Rationals, Reals) +def _(a, b): + return a + +def _intlike_interval(a, b): + try: + if b._inf is S.NegativeInfinity and b._sup is S.Infinity: + return a + s = Range(max(a.inf, ceiling(b.left)), floor(b.right) + 1) + return intersection_sets(s, b) # take out endpoints if open interval + except ValueError: + return None + +@intersection_sets.register(Integers, Interval) +def _(a, b): + return _intlike_interval(a, b) + +@intersection_sets.register(Naturals, Interval) +def _(a, b): + return _intlike_interval(a, b) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/issubset.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/issubset.py new file mode 100644 index 0000000000000000000000000000000000000000..cc23e8bf56f1743cd7f08452dd09a0acf981f5da --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/issubset.py @@ -0,0 +1,144 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.logic import fuzzy_and, fuzzy_bool, fuzzy_not, fuzzy_or +from sympy.core.relational import Eq +from sympy.sets.sets import FiniteSet, Interval, Set, Union, ProductSet +from sympy.sets.fancysets import Complexes, Reals, Range, Rationals +from sympy.multipledispatch import Dispatcher + + +_inf_sets = [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, S.Complexes] + + +is_subset_sets = Dispatcher('is_subset_sets') + + +@is_subset_sets.register(Set, Set) +def _(a, b): + return None + +@is_subset_sets.register(Interval, Interval) +def _(a, b): + # This is correct but can be made more comprehensive... + if fuzzy_bool(a.start < b.start): + return False + if fuzzy_bool(a.end > b.end): + return False + if (b.left_open and not a.left_open and fuzzy_bool(Eq(a.start, b.start))): + return False + if (b.right_open and not a.right_open and fuzzy_bool(Eq(a.end, b.end))): + return False + +@is_subset_sets.register(Interval, FiniteSet) +def _(a_interval, b_fs): + # An Interval can only be a subset of a finite set if it is finite + # which can only happen if it has zero measure. + if fuzzy_not(a_interval.measure.is_zero): + return False + +@is_subset_sets.register(Interval, Union) +def _(a_interval, b_u): + if all(isinstance(s, (Interval, FiniteSet)) for s in b_u.args): + intervals = [s for s in b_u.args if isinstance(s, Interval)] + if all(fuzzy_bool(a_interval.start < s.start) for s in intervals): + return False + if all(fuzzy_bool(a_interval.end > s.end) for s in intervals): + return False + if a_interval.measure.is_nonzero: + no_overlap = lambda s1, s2: fuzzy_or([ + fuzzy_bool(s1.end <= s2.start), + fuzzy_bool(s1.start >= s2.end), + ]) + if all(no_overlap(s, a_interval) for s in intervals): + return False + +@is_subset_sets.register(Range, Range) +def _(a, b): + if a.step == b.step == 1: + return fuzzy_and([fuzzy_bool(a.start >= b.start), + fuzzy_bool(a.stop <= b.stop)]) + +@is_subset_sets.register(Range, Interval) +def _(a_range, b_interval): + if a_range.step.is_positive: + if b_interval.left_open and a_range.inf.is_finite: + cond_left = a_range.inf > b_interval.left + else: + cond_left = a_range.inf >= b_interval.left + if b_interval.right_open and a_range.sup.is_finite: + cond_right = a_range.sup < b_interval.right + else: + cond_right = a_range.sup <= b_interval.right + return fuzzy_and([cond_left, cond_right]) + +@is_subset_sets.register(Range, FiniteSet) +def _(a_range, b_finiteset): + try: + a_size = a_range.size + except ValueError: + # symbolic Range of unknown size + return None + if a_size > len(b_finiteset): + return False + elif any(arg.has(Symbol) for arg in a_range.args): + return fuzzy_and(b_finiteset.contains(x) for x in a_range) + else: + # Checking A \ B == EmptySet is more efficient than repeated naive + # membership checks on an arbitrary FiniteSet. + a_set = set(a_range) + b_remaining = len(b_finiteset) + # Symbolic expressions and numbers of unknown type (integer or not) are + # all counted as "candidates", i.e. *potentially* matching some a in + # a_range. + cnt_candidate = 0 + for b in b_finiteset: + if b.is_Integer: + a_set.discard(b) + elif fuzzy_not(b.is_integer): + pass + else: + cnt_candidate += 1 + b_remaining -= 1 + if len(a_set) > b_remaining + cnt_candidate: + return False + if len(a_set) == 0: + return True + return None + +@is_subset_sets.register(Interval, Range) +def _(a_interval, b_range): + if a_interval.measure.is_extended_nonzero: + return False + +@is_subset_sets.register(Interval, Rationals) +def _(a_interval, b_rationals): + if a_interval.measure.is_extended_nonzero: + return False + +@is_subset_sets.register(Range, Complexes) +def _(a, b): + return True + +@is_subset_sets.register(Complexes, Interval) +def _(a, b): + return False + +@is_subset_sets.register(Complexes, Range) +def _(a, b): + return False + +@is_subset_sets.register(Complexes, Rationals) +def _(a, b): + return False + +@is_subset_sets.register(Rationals, Reals) +def _(a, b): + return True + +@is_subset_sets.register(Rationals, Range) +def _(a, b): + return False + +@is_subset_sets.register(ProductSet, FiniteSet) +def _(a_ps, b_fs): + return fuzzy_and(b_fs.contains(x) for x in a_ps) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/mul.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/mul.py new file mode 100644 index 0000000000000000000000000000000000000000..0dedc8068b7973fd4cb6fbf2854e5fa671d188de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/mul.py @@ -0,0 +1,79 @@ +from sympy.core import Basic, Expr +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.multipledispatch import Dispatcher +from sympy.sets.setexpr import set_mul +from sympy.sets.sets import Interval, Set + + +_x, _y = symbols("x y") + + +_set_mul = Dispatcher('_set_mul') +_set_div = Dispatcher('_set_div') + + +@_set_mul.register(Basic, Basic) +def _(x, y): + return None + +@_set_mul.register(Set, Set) +def _(x, y): + return None + +@_set_mul.register(Expr, Expr) +def _(x, y): + return x*y + +@_set_mul.register(Interval, Interval) +def _(x, y): + """ + Multiplications in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + # TODO: some intervals containing 0 and oo will fail as 0*oo returns nan. + comvals = ( + (x.start * y.start, bool(x.left_open or y.left_open)), + (x.start * y.end, bool(x.left_open or y.right_open)), + (x.end * y.start, bool(x.right_open or y.left_open)), + (x.end * y.end, bool(x.right_open or y.right_open)), + ) + # TODO: handle symbolic intervals + minval, minopen = min(comvals) + maxval, maxopen = max(comvals) + return Interval( + minval, + maxval, + minopen, + maxopen + ) + +@_set_div.register(Basic, Basic) +def _(x, y): + return None + +@_set_div.register(Expr, Expr) +def _(x, y): + return x/y + +@_set_div.register(Set, Set) +def _(x, y): + return None + +@_set_div.register(Interval, Interval) +def _(x, y): + """ + Divisions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + if (y.start*y.end).is_negative: + return Interval(-oo, oo) + if y.start == 0: + s2 = oo + else: + s2 = 1/y.start + if y.end == 0: + s1 = -oo + else: + s1 = 1/y.end + return set_mul(x, Interval(s1, s2, y.right_open, y.left_open)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/power.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/power.py new file mode 100644 index 0000000000000000000000000000000000000000..3cad4ee49ab27770143bc121d1fbcd024bf01548 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/power.py @@ -0,0 +1,107 @@ +from sympy.core import Basic, Expr +from sympy.core.function import Lambda +from sympy.core.numbers import oo, Infinity, NegativeInfinity, Zero, Integer +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import (Max, Min) +from sympy.sets.fancysets import ImageSet +from sympy.sets.setexpr import set_div +from sympy.sets.sets import Set, Interval, FiniteSet, Union +from sympy.multipledispatch import Dispatcher + + +_x, _y = symbols("x y") + + +_set_pow = Dispatcher('_set_pow') + + +@_set_pow.register(Basic, Basic) +def _(x, y): + return None + +@_set_pow.register(Set, Set) +def _(x, y): + return ImageSet(Lambda((_x, _y), (_x ** _y)), x, y) + +@_set_pow.register(Expr, Expr) +def _(x, y): + return x**y + +@_set_pow.register(Interval, Zero) +def _(x, z): + return FiniteSet(S.One) + +@_set_pow.register(Interval, Integer) +def _(x, exponent): + """ + Powers in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + s1 = x.start**exponent + s2 = x.end**exponent + if ((s2 > s1) if exponent > 0 else (x.end > -x.start)) == True: + left_open = x.left_open + right_open = x.right_open + # TODO: handle unevaluated condition. + sleft = s2 + else: + # TODO: `s2 > s1` could be unevaluated. + left_open = x.right_open + right_open = x.left_open + sleft = s1 + + if x.start.is_positive: + return Interval( + Min(s1, s2), + Max(s1, s2), left_open, right_open) + elif x.end.is_negative: + return Interval( + Min(s1, s2), + Max(s1, s2), left_open, right_open) + + # Case where x.start < 0 and x.end > 0: + if exponent.is_odd: + if exponent.is_negative: + if x.start.is_zero: + return Interval(s2, oo, x.right_open) + if x.end.is_zero: + return Interval(-oo, s1, True, x.left_open) + return Union(Interval(-oo, s1, True, x.left_open), Interval(s2, oo, x.right_open)) + else: + return Interval(s1, s2, x.left_open, x.right_open) + elif exponent.is_even: + if exponent.is_negative: + if x.start.is_zero: + return Interval(s2, oo, x.right_open) + if x.end.is_zero: + return Interval(s1, oo, x.left_open) + return Interval(0, oo) + else: + return Interval(S.Zero, sleft, S.Zero not in x, left_open) + +@_set_pow.register(Interval, Infinity) +def _(b, e): + # TODO: add logic for open intervals? + if b.start.is_nonnegative: + if b.end < 1: + return FiniteSet(S.Zero) + if b.start > 1: + return FiniteSet(S.Infinity) + return Interval(0, oo) + elif b.end.is_negative: + if b.start > -1: + return FiniteSet(S.Zero) + if b.end < -1: + return FiniteSet(-oo, oo) + return Interval(-oo, oo) + else: + if b.start > -1: + if b.end < 1: + return FiniteSet(S.Zero) + return Interval(0, oo) + return Interval(-oo, oo) + +@_set_pow.register(Interval, NegativeInfinity) +def _(b, e): + return _set_pow(set_div(S.One, b), oo) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/union.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/union.py new file mode 100644 index 0000000000000000000000000000000000000000..75d867b49969ae2aeea76155dbaae7e05c1a6847 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/handlers/union.py @@ -0,0 +1,147 @@ +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.sets.sets import (EmptySet, FiniteSet, Intersection, + Interval, ProductSet, Set, Union, UniversalSet) +from sympy.sets.fancysets import (ComplexRegion, Naturals, Naturals0, + Integers, Rationals, Reals) +from sympy.multipledispatch import Dispatcher + + +union_sets = Dispatcher('union_sets') + + +@union_sets.register(Naturals0, Naturals) +def _(a, b): + return a + +@union_sets.register(Rationals, Naturals) +def _(a, b): + return a + +@union_sets.register(Rationals, Naturals0) +def _(a, b): + return a + +@union_sets.register(Reals, Naturals) +def _(a, b): + return a + +@union_sets.register(Reals, Naturals0) +def _(a, b): + return a + +@union_sets.register(Reals, Rationals) +def _(a, b): + return a + +@union_sets.register(Integers, Set) +def _(a, b): + intersect = Intersection(a, b) + if intersect == a: + return b + elif intersect == b: + return a + +@union_sets.register(ComplexRegion, Set) +def _(a, b): + if b.is_subset(S.Reals): + # treat a subset of reals as a complex region + b = ComplexRegion.from_real(b) + + if b.is_ComplexRegion: + # a in rectangular form + if (not a.polar) and (not b.polar): + return ComplexRegion(Union(a.sets, b.sets)) + # a in polar form + elif a.polar and b.polar: + return ComplexRegion(Union(a.sets, b.sets), polar=True) + return None + +@union_sets.register(EmptySet, Set) +def _(a, b): + return b + + +@union_sets.register(UniversalSet, Set) +def _(a, b): + return a + +@union_sets.register(ProductSet, ProductSet) +def _(a, b): + if b.is_subset(a): + return a + if len(b.sets) != len(a.sets): + return None + if len(a.sets) == 2: + a1, a2 = a.sets + b1, b2 = b.sets + if a1 == b1: + return a1 * Union(a2, b2) + if a2 == b2: + return Union(a1, b1) * a2 + return None + +@union_sets.register(ProductSet, Set) +def _(a, b): + if b.is_subset(a): + return a + return None + +@union_sets.register(Interval, Interval) +def _(a, b): + if a._is_comparable(b): + # Non-overlapping intervals + end = Min(a.end, b.end) + start = Max(a.start, b.start) + if (end < start or + (end == start and (end not in a and end not in b))): + return None + else: + start = Min(a.start, b.start) + end = Max(a.end, b.end) + + left_open = ((a.start != start or a.left_open) and + (b.start != start or b.left_open)) + right_open = ((a.end != end or a.right_open) and + (b.end != end or b.right_open)) + return Interval(start, end, left_open, right_open) + +@union_sets.register(Interval, UniversalSet) +def _(a, b): + return S.UniversalSet + +@union_sets.register(Interval, Set) +def _(a, b): + # If I have open end points and these endpoints are contained in b + # But only in case, when endpoints are finite. Because + # interval does not contain oo or -oo. + open_left_in_b_and_finite = (a.left_open and + sympify(b.contains(a.start)) is S.true and + a.start.is_finite) + open_right_in_b_and_finite = (a.right_open and + sympify(b.contains(a.end)) is S.true and + a.end.is_finite) + if open_left_in_b_and_finite or open_right_in_b_and_finite: + # Fill in my end points and return + open_left = a.left_open and a.start not in b + open_right = a.right_open and a.end not in b + new_a = Interval(a.start, a.end, open_left, open_right) + return {new_a, b} + return None + +@union_sets.register(FiniteSet, FiniteSet) +def _(a, b): + return FiniteSet(*(a._elements | b._elements)) + +@union_sets.register(FiniteSet, Set) +def _(a, b): + # If `b` set contains one of my elements, remove it from `a` + if any(b.contains(x) == True for x in a): + return { + FiniteSet(*[x for x in a if b.contains(x) != True]), b} + return None + +@union_sets.register(Set, Set) +def _(a, b): + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/ordinals.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/ordinals.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe062354cfe58a4747998e51fa0d261e67576cc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/ordinals.py @@ -0,0 +1,282 @@ +from sympy.core import Basic, Integer +import operator + + +class OmegaPower(Basic): + """ + Represents ordinal exponential and multiplication terms one of the + building blocks of the :class:`Ordinal` class. + In ``OmegaPower(a, b)``, ``a`` represents exponent and ``b`` represents multiplicity. + """ + def __new__(cls, a, b): + if isinstance(b, int): + b = Integer(b) + if not isinstance(b, Integer) or b <= 0: + raise TypeError("multiplicity must be a positive integer") + + if not isinstance(a, Ordinal): + a = Ordinal.convert(a) + + return Basic.__new__(cls, a, b) + + @property + def exp(self): + return self.args[0] + + @property + def mult(self): + return self.args[1] + + def _compare_term(self, other, op): + if self.exp == other.exp: + return op(self.mult, other.mult) + else: + return op(self.exp, other.exp) + + def __eq__(self, other): + if not isinstance(other, OmegaPower): + try: + other = OmegaPower(0, other) + except TypeError: + return NotImplemented + return self.args == other.args + + def __hash__(self): + return Basic.__hash__(self) + + def __lt__(self, other): + if not isinstance(other, OmegaPower): + try: + other = OmegaPower(0, other) + except TypeError: + return NotImplemented + return self._compare_term(other, operator.lt) + + +class Ordinal(Basic): + """ + Represents ordinals in Cantor normal form. + + Internally, this class is just a list of instances of OmegaPower. + + Examples + ======== + >>> from sympy import Ordinal, OmegaPower + >>> from sympy.sets.ordinals import omega + >>> w = omega + >>> w.is_limit_ordinal + True + >>> Ordinal(OmegaPower(w + 1, 1), OmegaPower(3, 2)) + w**(w + 1) + w**3*2 + >>> 3 + w + w + >>> (w + 1) * w + w**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ordinal_arithmetic + """ + def __new__(cls, *terms): + obj = super().__new__(cls, *terms) + powers = [i.exp for i in obj.args] + if not all(powers[i] >= powers[i+1] for i in range(len(powers) - 1)): + raise ValueError("powers must be in decreasing order") + return obj + + @property + def terms(self): + return self.args + + @property + def leading_term(self): + if self == ord0: + raise ValueError("ordinal zero has no leading term") + return self.terms[0] + + @property + def trailing_term(self): + if self == ord0: + raise ValueError("ordinal zero has no trailing term") + return self.terms[-1] + + @property + def is_successor_ordinal(self): + try: + return self.trailing_term.exp == ord0 + except ValueError: + return False + + @property + def is_limit_ordinal(self): + try: + return not self.trailing_term.exp == ord0 + except ValueError: + return False + + @property + def degree(self): + return self.leading_term.exp + + @classmethod + def convert(cls, integer_value): + if integer_value == 0: + return ord0 + return Ordinal(OmegaPower(0, integer_value)) + + def __eq__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return self.terms == other.terms + + def __hash__(self): + return hash(self.args) + + def __lt__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + for term_self, term_other in zip(self.terms, other.terms): + if term_self != term_other: + return term_self < term_other + return len(self.terms) < len(other.terms) + + def __le__(self, other): + return (self == other or self < other) + + def __gt__(self, other): + return not self <= other + + def __ge__(self, other): + return not self < other + + def __str__(self): + net_str = "" + plus_count = 0 + if self == ord0: + return 'ord0' + for i in self.terms: + if plus_count: + net_str += " + " + + if i.exp == ord0: + net_str += str(i.mult) + elif i.exp == 1: + net_str += 'w' + elif len(i.exp.terms) > 1 or i.exp.is_limit_ordinal: + net_str += 'w**(%s)'%i.exp + else: + net_str += 'w**%s'%i.exp + + if not i.mult == 1 and not i.exp == ord0: + net_str += '*%s'%i.mult + + plus_count += 1 + return(net_str) + + __repr__ = __str__ + + def __add__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + if other == ord0: + return self + a_terms = list(self.terms) + b_terms = list(other.terms) + r = len(a_terms) - 1 + b_exp = other.degree + while r >= 0 and a_terms[r].exp < b_exp: + r -= 1 + if r < 0: + terms = b_terms + elif a_terms[r].exp == b_exp: + sum_term = OmegaPower(b_exp, a_terms[r].mult + other.leading_term.mult) + terms = a_terms[:r] + [sum_term] + b_terms[1:] + else: + terms = a_terms[:r+1] + b_terms + return Ordinal(*terms) + + def __radd__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return other + self + + def __mul__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + if ord0 in (self, other): + return ord0 + a_exp = self.degree + a_mult = self.leading_term.mult + summation = [] + if other.is_limit_ordinal: + for arg in other.terms: + summation.append(OmegaPower(a_exp + arg.exp, arg.mult)) + + else: + for arg in other.terms[:-1]: + summation.append(OmegaPower(a_exp + arg.exp, arg.mult)) + b_mult = other.trailing_term.mult + summation.append(OmegaPower(a_exp, a_mult*b_mult)) + summation += list(self.terms[1:]) + return Ordinal(*summation) + + def __rmul__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return other * self + + def __pow__(self, other): + if not self == omega: + return NotImplemented + return Ordinal(OmegaPower(other, 1)) + + +class OrdinalZero(Ordinal): + """The ordinal zero. + + OrdinalZero can be imported as ``ord0``. + """ + pass + + +class OrdinalOmega(Ordinal): + """The ordinal omega which forms the base of all ordinals in cantor normal form. + + OrdinalOmega can be imported as ``omega``. + + Examples + ======== + + >>> from sympy.sets.ordinals import omega + >>> omega + omega + w*2 + """ + def __new__(cls): + return Ordinal.__new__(cls) + + @property + def terms(self): + return (OmegaPower(1, 1),) + + +ord0 = OrdinalZero() +omega = OrdinalOmega() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/powerset.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/powerset.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb3b41b9859281480bc9517a1cad0abe7a5683f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/powerset.py @@ -0,0 +1,119 @@ +from sympy.core.decorators import _sympifyit +from sympy.core.parameters import global_parameters +from sympy.core.logic import fuzzy_bool +from sympy.core.singleton import S +from sympy.core.sympify import _sympify + +from .sets import Set, FiniteSet, SetKind + + +class PowerSet(Set): + r"""A symbolic object representing a power set. + + Parameters + ========== + + arg : Set + The set to take power of. + + evaluate : bool + The flag to control evaluation. + + If the evaluation is disabled for finite sets, it can take + advantage of using subset test as a membership test. + + Notes + ===== + + Power set `\mathcal{P}(S)` is defined as a set containing all the + subsets of `S`. + + If the set `S` is a finite set, its power set would have + `2^{\left| S \right|}` elements, where `\left| S \right|` denotes + the cardinality of `S`. + + Examples + ======== + + >>> from sympy import PowerSet, S, FiniteSet + + A power set of a finite set: + + >>> PowerSet(FiniteSet(1, 2, 3)) + PowerSet({1, 2, 3}) + + A power set of an empty set: + + >>> PowerSet(S.EmptySet) + PowerSet(EmptySet) + >>> PowerSet(PowerSet(S.EmptySet)) + PowerSet(PowerSet(EmptySet)) + + A power set of an infinite set: + + >>> PowerSet(S.Reals) + PowerSet(Reals) + + Evaluating the power set of a finite set to its explicit form: + + >>> PowerSet(FiniteSet(1, 2, 3)).rewrite(FiniteSet) + FiniteSet(EmptySet, {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Power_set + + .. [2] https://en.wikipedia.org/wiki/Axiom_of_power_set + """ + def __new__(cls, arg, evaluate=None): + if evaluate is None: + evaluate=global_parameters.evaluate + + arg = _sympify(arg) + + if not isinstance(arg, Set): + raise ValueError('{} must be a set.'.format(arg)) + + return super().__new__(cls, arg) + + @property + def arg(self): + return self.args[0] + + def _eval_rewrite_as_FiniteSet(self, *args, **kwargs): + arg = self.arg + if arg.is_FiniteSet: + return arg.powerset() + return None + + @_sympifyit('other', NotImplemented) + def _contains(self, other): + if not isinstance(other, Set): + return None + + return fuzzy_bool(self.arg.is_superset(other)) + + def _eval_is_subset(self, other): + if isinstance(other, PowerSet): + return self.arg.is_subset(other.arg) + + def __len__(self): + return 2 ** len(self.arg) + + def __iter__(self): + found = [S.EmptySet] + yield S.EmptySet + + for x in self.arg: + temp = [] + x = FiniteSet(x) + for y in found: + new = x + y + yield new + temp.append(new) + found.extend(temp) + + @property + def kind(self): + return SetKind(self.arg.kind) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/setexpr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/setexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..94d77d5293617a620b70a945888987ce6cc61157 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/setexpr.py @@ -0,0 +1,97 @@ +from sympy.core import Expr +from sympy.core.decorators import call_highest_priority, _sympifyit +from .fancysets import ImageSet +from .sets import set_add, set_sub, set_mul, set_div, set_pow, set_function + + +class SetExpr(Expr): + """An expression that can take on values of a set. + + Examples + ======== + + >>> from sympy import Interval, FiniteSet + >>> from sympy.sets.setexpr import SetExpr + + >>> a = SetExpr(Interval(0, 5)) + >>> b = SetExpr(FiniteSet(1, 10)) + >>> (a + b).set + Union(Interval(1, 6), Interval(10, 15)) + >>> (2*a + b).set + Interval(1, 20) + """ + _op_priority = 11.0 + + def __new__(cls, setarg): + return Expr.__new__(cls, setarg) + + set = property(lambda self: self.args[0]) + + def _latex(self, printer): + return r"SetExpr\left({}\right)".format(printer._print(self.set)) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__radd__') + def __add__(self, other): + return _setexpr_apply_operation(set_add, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__add__') + def __radd__(self, other): + return _setexpr_apply_operation(set_add, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __mul__(self, other): + return _setexpr_apply_operation(set_mul, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmul__(self, other): + return _setexpr_apply_operation(set_mul, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rsub__') + def __sub__(self, other): + return _setexpr_apply_operation(set_sub, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__sub__') + def __rsub__(self, other): + return _setexpr_apply_operation(set_sub, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rpow__') + def __pow__(self, other): + return _setexpr_apply_operation(set_pow, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__pow__') + def __rpow__(self, other): + return _setexpr_apply_operation(set_pow, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return _setexpr_apply_operation(set_div, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__truediv__') + def __rtruediv__(self, other): + return _setexpr_apply_operation(set_div, other, self) + + def _eval_func(self, func): + # TODO: this could be implemented straight into `imageset`: + res = set_function(func, self.set) + if res is None: + return SetExpr(ImageSet(func, self.set)) + return SetExpr(res) + + +def _setexpr_apply_operation(op, x, y): + if isinstance(x, SetExpr): + x = x.set + if isinstance(y, SetExpr): + y = y.set + out = op(x, y) + return SetExpr(out) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/sets.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/sets.py new file mode 100644 index 0000000000000000000000000000000000000000..3c85ce87c515cfd4520dcc6b9265fe76d8c6163f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/sets.py @@ -0,0 +1,2804 @@ +from __future__ import annotations + +from typing import Any, Callable, TYPE_CHECKING, overload +from functools import reduce +from collections import defaultdict +from collections.abc import Mapping, Iterable +import inspect + +from sympy.core.kind import Kind, UndefinedKind, NumberKind +from sympy.core.basic import Basic +from sympy.core.containers import Tuple, TupleKind +from sympy.core.decorators import sympify_method_args, sympify_return +from sympy.core.evalf import EvalfMixin +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.logic import (FuzzyBool, fuzzy_bool, fuzzy_or, fuzzy_and, + fuzzy_not) +from sympy.core.numbers import Float, Integer +from sympy.core.operations import LatticeOp +from sympy.core.parameters import global_parameters +from sympy.core.relational import Eq, Ne, is_lt +from sympy.core.singleton import Singleton, S +from sympy.core.sorting import ordered +from sympy.core.symbol import symbols, Symbol, Dummy, uniquely_named_symbol +from sympy.core.sympify import _sympify, sympify, _sympy_converter +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import Max, Min +from sympy.logic.boolalg import And, Or, Not, Xor, true, false +from sympy.utilities.decorator import deprecated +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import (iproduct, sift, roundrobin, iterable, + subsets) +from sympy.utilities.misc import func_name, filldedent + +from mpmath import mpi, mpf + +from mpmath.libmp.libmpf import prec_to_dps + + +tfn = defaultdict(lambda: None, { + True: S.true, + S.true: S.true, + False: S.false, + S.false: S.false}) + + +@sympify_method_args +class Set(Basic, EvalfMixin): + """ + The base class for any kind of set. + + Explanation + =========== + + This is not meant to be used directly as a container of items. It does not + behave like the builtin ``set``; see :class:`FiniteSet` for that. + + Real intervals are represented by the :class:`Interval` class and unions of + sets by the :class:`Union` class. The empty set is represented by the + :class:`EmptySet` class and available as a singleton as ``S.EmptySet``. + """ + + __slots__: tuple[()] = () + + is_number = False + is_iterable = False + is_interval = False + + is_FiniteSet = False + is_Interval = False + is_ProductSet = False + is_Union = False + is_Intersection: FuzzyBool = None + is_UniversalSet: FuzzyBool = None + is_Complement: FuzzyBool = None + is_ComplexRegion = False + + is_empty: FuzzyBool = None + is_finite_set: FuzzyBool = None + + @property # type: ignore + @deprecated( + """ + The is_EmptySet attribute of Set objects is deprecated. + Use 's is S.EmptySet" or 's.is_empty' instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-is-emptyset", + ) + def is_EmptySet(self): + return None + + if TYPE_CHECKING: + + def __new__(cls, *args: Basic | complex) -> Set: + ... + + @overload # type: ignore + def subs(self, arg1: Mapping[Basic | complex, Set | complex], arg2: None=None) -> Set: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Set | complex]], arg2: None=None, **kwargs: Any) -> Set: ... + @overload + def subs(self, arg1: Set | complex, arg2: Set | complex) -> Set: ... + @overload + def subs(self, arg1: Mapping[Basic | complex, Basic | complex], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Basic | complex]], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Basic | complex, arg2: Basic | complex, **kwargs: Any) -> Basic: ... + + def subs(self, arg1: Mapping[Basic | complex, Basic | complex] | Basic | complex, # type: ignore + arg2: Basic | complex | None = None, **kwargs: Any) -> Basic: + ... + + def simplify(self, **kwargs) -> Set: + assert False + + def evalf(self, n: int = 15, subs: dict[Basic, Basic | float] | None = None, + maxn: int = 100, chop: bool = False, strict: bool = False, + quad: str | None = None, verbose: bool = False) -> Set: + ... + + n = evalf + + @staticmethod + def _infimum_key(expr): + """ + Return infimum (if possible) else S.Infinity. + """ + try: + infimum = expr.inf + assert infimum.is_comparable + infimum = infimum.evalf() # issue #18505 + except (NotImplementedError, + AttributeError, AssertionError, ValueError): + infimum = S.Infinity + return infimum + + def union(self, other): + """ + Returns the union of ``self`` and ``other``. + + Examples + ======== + + As a shortcut it is possible to use the ``+`` operator: + + >>> from sympy import Interval, FiniteSet + >>> Interval(0, 1).union(Interval(2, 3)) + Union(Interval(0, 1), Interval(2, 3)) + >>> Interval(0, 1) + Interval(2, 3) + Union(Interval(0, 1), Interval(2, 3)) + >>> Interval(1, 2, True, True) + FiniteSet(2, 3) + Union({3}, Interval.Lopen(1, 2)) + + Similarly it is possible to use the ``-`` operator for set differences: + + >>> Interval(0, 2) - Interval(0, 1) + Interval.Lopen(1, 2) + >>> Interval(1, 3) - FiniteSet(2) + Union(Interval.Ropen(1, 2), Interval.Lopen(2, 3)) + + """ + return Union(self, other) + + def intersect(self, other): + """ + Returns the intersection of 'self' and 'other'. + + Examples + ======== + + >>> from sympy import Interval + + >>> Interval(1, 3).intersect(Interval(1, 2)) + Interval(1, 2) + + >>> from sympy import imageset, Lambda, symbols, S + >>> n, m = symbols('n m') + >>> a = imageset(Lambda(n, 2*n), S.Integers) + >>> a.intersect(imageset(Lambda(m, 2*m + 1), S.Integers)) + EmptySet + + """ + return Intersection(self, other) + + def intersection(self, other): + """ + Alias for :meth:`intersect()` + """ + return self.intersect(other) + + def is_disjoint(self, other): + """ + Returns True if ``self`` and ``other`` are disjoint. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 2).is_disjoint(Interval(1, 2)) + False + >>> Interval(0, 2).is_disjoint(Interval(3, 4)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Disjoint_sets + """ + return self.intersect(other) == S.EmptySet + + def isdisjoint(self, other): + """ + Alias for :meth:`is_disjoint()` + """ + return self.is_disjoint(other) + + def complement(self, universe): + r""" + The complement of 'self' w.r.t the given universe. + + Examples + ======== + + >>> from sympy import Interval, S + >>> Interval(0, 1).complement(S.Reals) + Union(Interval.open(-oo, 0), Interval.open(1, oo)) + + >>> Interval(0, 1).complement(S.UniversalSet) + Complement(UniversalSet, Interval(0, 1)) + + """ + return Complement(universe, self) + + def _complement(self, other): + # this behaves as other - self + if isinstance(self, ProductSet) and isinstance(other, ProductSet): + # If self and other are disjoint then other - self == self + if len(self.sets) != len(other.sets): + return other + + # There can be other ways to represent this but this gives: + # (A x B) - (C x D) = ((A - C) x B) U (A x (B - D)) + overlaps = [] + pairs = list(zip(self.sets, other.sets)) + for n in range(len(pairs)): + sets = (o if i != n else o-s for i, (s, o) in enumerate(pairs)) + overlaps.append(ProductSet(*sets)) + return Union(*overlaps) + + elif isinstance(other, Interval): + if isinstance(self, (Interval, FiniteSet)): + return Intersection(other, self.complement(S.Reals)) + + elif isinstance(other, Union): + return Union(*(o - self for o in other.args)) + + elif isinstance(other, Complement): + return Complement(other.args[0], Union(other.args[1], self), evaluate=False) + + elif other is S.EmptySet: + return S.EmptySet + + elif isinstance(other, FiniteSet): + sifted = sift(other, lambda x: fuzzy_bool(self.contains(x))) + # ignore those that are contained in self + return Union(FiniteSet(*(sifted[False])), + Complement(FiniteSet(*(sifted[None])), self, evaluate=False) + if sifted[None] else S.EmptySet) + + def symmetric_difference(self, other): + """ + Returns symmetric difference of ``self`` and ``other``. + + Examples + ======== + + >>> from sympy import Interval, S + >>> Interval(1, 3).symmetric_difference(S.Reals) + Union(Interval.open(-oo, 1), Interval.open(3, oo)) + >>> Interval(1, 10).symmetric_difference(S.Reals) + Union(Interval.open(-oo, 1), Interval.open(10, oo)) + + >>> from sympy import S, EmptySet + >>> S.Reals.symmetric_difference(EmptySet) + Reals + + References + ========== + .. [1] https://en.wikipedia.org/wiki/Symmetric_difference + + """ + return SymmetricDifference(self, other) + + def _symmetric_difference(self, other): + return Union(Complement(self, other), Complement(other, self)) + + @property + def inf(self): + """ + The infimum of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).inf + 0 + >>> Union(Interval(0, 1), Interval(2, 3)).inf + 0 + + """ + return self._inf + + @property + def _inf(self): + raise NotImplementedError("(%s)._inf" % self) + + @property + def sup(self): + """ + The supremum of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).sup + 1 + >>> Union(Interval(0, 1), Interval(2, 3)).sup + 3 + + """ + return self._sup + + @property + def _sup(self): + raise NotImplementedError("(%s)._sup" % self) + + def contains(self, other): + """ + Returns a SymPy value indicating whether ``other`` is contained + in ``self``: ``true`` if it is, ``false`` if it is not, else + an unevaluated ``Contains`` expression (or, as in the case of + ConditionSet and a union of FiniteSet/Intervals, an expression + indicating the conditions for containment). + + Examples + ======== + + >>> from sympy import Interval, S + >>> from sympy.abc import x + + >>> Interval(0, 1).contains(0.5) + True + + As a shortcut it is possible to use the ``in`` operator, but that + will raise an error unless an affirmative true or false is not + obtained. + + >>> Interval(0, 1).contains(x) + (0 <= x) & (x <= 1) + >>> x in Interval(0, 1) + Traceback (most recent call last): + ... + TypeError: did not evaluate to a bool: None + + The result of 'in' is a bool, not a SymPy value + + >>> 1 in Interval(0, 2) + True + >>> _ is S.true + False + """ + from .contains import Contains + other = sympify(other, strict=True) + + c = self._contains(other) + if isinstance(c, Contains): + return c + if c is None: + return Contains(other, self, evaluate=False) + b = tfn[c] + if b is None: + return c + return b + + def _contains(self, other): + """Test if ``other`` is an element of the set ``self``. + + This is an internal method that is expected to be overridden by + subclasses of ``Set`` and will be called by the public + :func:`Set.contains` method or the :class:`Contains` expression. + + Parameters + ========== + + other: Sympified :class:`Basic` instance + The object whose membership in ``self`` is to be tested. + + Returns + ======= + + Symbolic :class:`Boolean` or ``None``. + + A return value of ``None`` indicates that it is unknown whether + ``other`` is contained in ``self``. Returning ``None`` from here + ensures that ``self.contains(other)`` or ``Contains(self, other)`` will + return an unevaluated :class:`Contains` expression. + + If not ``None`` then the returned value is a :class:`Boolean` that is + logically equivalent to the statement that ``other`` is an element of + ``self``. Usually this would be either ``S.true`` or ``S.false`` but + not always. + """ + raise NotImplementedError(f"{type(self).__name__}._contains") + + def is_subset(self, other): + """ + Returns True if ``self`` is a subset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_subset(Interval(0, 1)) + True + >>> Interval(0, 1).is_subset(Interval(0, 1, left_open=True)) + False + + """ + if not isinstance(other, Set): + raise ValueError("Unknown argument '%s'" % other) + + # Handle the trivial cases + if self == other: + return True + is_empty = self.is_empty + if is_empty is True: + return True + elif fuzzy_not(is_empty) and other.is_empty: + return False + if self.is_finite_set is False and other.is_finite_set: + return False + + # Dispatch on subclass rules + ret = self._eval_is_subset(other) + if ret is not None: + return ret + ret = other._eval_is_superset(self) + if ret is not None: + return ret + + # Use pairwise rules from multiple dispatch + from sympy.sets.handlers.issubset import is_subset_sets + ret = is_subset_sets(self, other) + if ret is not None: + return ret + + # Fall back on computing the intersection + # XXX: We shouldn't do this. A query like this should be handled + # without evaluating new Set objects. It should be the other way round + # so that the intersect method uses is_subset for evaluation. + if self.intersect(other) == self: + return True + + def _eval_is_subset(self, other): + '''Returns a fuzzy bool for whether self is a subset of other.''' + return None + + def _eval_is_superset(self, other): + '''Returns a fuzzy bool for whether self is a subset of other.''' + return None + + # This should be deprecated: + def issubset(self, other): + """ + Alias for :meth:`is_subset()` + """ + return self.is_subset(other) + + def is_proper_subset(self, other): + """ + Returns True if ``self`` is a proper subset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_proper_subset(Interval(0, 1)) + True + >>> Interval(0, 1).is_proper_subset(Interval(0, 1)) + False + + """ + if isinstance(other, Set): + return self != other and self.is_subset(other) + else: + raise ValueError("Unknown argument '%s'" % other) + + def is_superset(self, other): + """ + Returns True if ``self`` is a superset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_superset(Interval(0, 1)) + False + >>> Interval(0, 1).is_superset(Interval(0, 1, left_open=True)) + True + + """ + if isinstance(other, Set): + return other.is_subset(self) + else: + raise ValueError("Unknown argument '%s'" % other) + + # This should be deprecated: + def issuperset(self, other): + """ + Alias for :meth:`is_superset()` + """ + return self.is_superset(other) + + def is_proper_superset(self, other): + """ + Returns True if ``self`` is a proper superset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).is_proper_superset(Interval(0, 0.5)) + True + >>> Interval(0, 1).is_proper_superset(Interval(0, 1)) + False + + """ + if isinstance(other, Set): + return self != other and self.is_superset(other) + else: + raise ValueError("Unknown argument '%s'" % other) + + def _eval_powerset(self): + from .powerset import PowerSet + return PowerSet(self) + + def powerset(self): + """ + Find the Power set of ``self``. + + Examples + ======== + + >>> from sympy import EmptySet, FiniteSet, Interval + + A power set of an empty set: + + >>> A = EmptySet + >>> A.powerset() + {EmptySet} + + A power set of a finite set: + + >>> A = FiniteSet(1, 2) + >>> a, b, c = FiniteSet(1), FiniteSet(2), FiniteSet(1, 2) + >>> A.powerset() == FiniteSet(a, b, c, EmptySet) + True + + A power set of an interval: + + >>> Interval(1, 2).powerset() + PowerSet(Interval(1, 2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Power_set + + """ + return self._eval_powerset() + + @property + def measure(self): + """ + The (Lebesgue) measure of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).measure + 1 + >>> Union(Interval(0, 1), Interval(2, 3)).measure + 2 + + """ + return self._measure + + @property + def kind(self): + """ + The kind of a Set + + Explanation + =========== + + Any :class:`Set` will have kind :class:`SetKind` which is + parametrised by the kind of the elements of the set. For example + most sets are sets of numbers and will have kind + ``SetKind(NumberKind)``. If elements of sets are different in kind than + their kind will ``SetKind(UndefinedKind)``. See + :class:`sympy.core.kind.Kind` for an explanation of the kind system. + + Examples + ======== + + >>> from sympy import Interval, Matrix, FiniteSet, EmptySet, ProductSet, PowerSet + + >>> FiniteSet(Matrix([1, 2])).kind + SetKind(MatrixKind(NumberKind)) + + >>> Interval(1, 2).kind + SetKind(NumberKind) + + >>> EmptySet.kind + SetKind() + + A :class:`sympy.sets.powerset.PowerSet` is a set of sets: + + >>> PowerSet({1, 2, 3}).kind + SetKind(SetKind(NumberKind)) + + A :class:`ProductSet` represents the set of tuples of elements of + other sets. Its kind is :class:`sympy.core.containers.TupleKind` + parametrised by the kinds of the elements of those sets: + + >>> p = ProductSet(FiniteSet(1, 2), FiniteSet(3, 4)) + >>> list(p) + [(1, 3), (2, 3), (1, 4), (2, 4)] + >>> p.kind + SetKind(TupleKind(NumberKind, NumberKind)) + + When all elements of the set do not have same kind, the kind + will be returned as ``SetKind(UndefinedKind)``: + + >>> FiniteSet(0, Matrix([1, 2])).kind + SetKind(UndefinedKind) + + The kind of the elements of a set are given by the ``element_kind`` + attribute of ``SetKind``: + + >>> Interval(1, 2).kind.element_kind + NumberKind + + See Also + ======== + + NumberKind + sympy.core.kind.UndefinedKind + sympy.core.containers.TupleKind + MatrixKind + sympy.matrices.expressions.sets.MatrixSet + sympy.sets.conditionset.ConditionSet + Rationals + Naturals + Integers + sympy.sets.fancysets.ImageSet + sympy.sets.fancysets.Range + sympy.sets.fancysets.ComplexRegion + sympy.sets.powerset.PowerSet + sympy.sets.sets.ProductSet + sympy.sets.sets.Interval + sympy.sets.sets.Union + sympy.sets.sets.Intersection + sympy.sets.sets.Complement + sympy.sets.sets.EmptySet + sympy.sets.sets.UniversalSet + sympy.sets.sets.FiniteSet + sympy.sets.sets.SymmetricDifference + sympy.sets.sets.DisjointUnion + """ + return self._kind() + + @property + def boundary(self): + """ + The boundary or frontier of a set. + + Explanation + =========== + + A point x is on the boundary of a set S if + + 1. x is in the closure of S. + I.e. Every neighborhood of x contains a point in S. + 2. x is not in the interior of S. + I.e. There does not exist an open set centered on x contained + entirely within S. + + There are the points on the outer rim of S. If S is open then these + points need not actually be contained within S. + + For example, the boundary of an interval is its start and end points. + This is true regardless of whether or not the interval is open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).boundary + {0, 1} + >>> Interval(0, 1, True, False).boundary + {0, 1} + """ + return self._boundary + + @property + def is_open(self): + """ + Property method to check whether a set is open. + + Explanation + =========== + + A set is open if and only if it has an empty intersection with its + boundary. In particular, a subset A of the reals is open if and only + if each one of its points is contained in an open interval that is a + subset of A. + + Examples + ======== + >>> from sympy import S + >>> S.Reals.is_open + True + >>> S.Rationals.is_open + False + """ + return Intersection(self, self.boundary).is_empty + + @property + def is_closed(self): + """ + A property method to check whether a set is closed. + + Explanation + =========== + + A set is closed if its complement is an open set. The closedness of a + subset of the reals is determined with respect to R and its standard + topology. + + Examples + ======== + >>> from sympy import Interval + >>> Interval(0, 1).is_closed + True + """ + return self.boundary.is_subset(self) + + @property + def closure(self): + """ + Property method which returns the closure of a set. + The closure is defined as the union of the set itself and its + boundary. + + Examples + ======== + >>> from sympy import S, Interval + >>> S.Reals.closure + Reals + >>> Interval(0, 1).closure + Interval(0, 1) + """ + return self + self.boundary + + @property + def interior(self): + """ + Property method which returns the interior of a set. + The interior of a set S consists all points of S that do not + belong to the boundary of S. + + Examples + ======== + >>> from sympy import Interval + >>> Interval(0, 1).interior + Interval.open(0, 1) + >>> Interval(0, 1).boundary.interior + EmptySet + """ + return self - self.boundary + + @property + def _boundary(self): + raise NotImplementedError() + + @property + def _measure(self): + raise NotImplementedError("(%s)._measure" % self) + + def _kind(self): + return SetKind(UndefinedKind) + + def _eval_evalf(self, prec): + dps = prec_to_dps(prec) + return self.func(*[arg.evalf(n=dps) for arg in self.args]) + + @sympify_return([('other', 'Set')], NotImplemented) + def __add__(self, other): + return self.union(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __or__(self, other): + return self.union(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __and__(self, other): + return self.intersect(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __mul__(self, other): + return ProductSet(self, other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __xor__(self, other): + return SymmetricDifference(self, other) + + @sympify_return([('exp', Expr)], NotImplemented) + def __pow__(self, exp): + if not (exp.is_Integer and exp >= 0): + raise ValueError("%s: Exponent must be a positive Integer" % exp) + return ProductSet(*[self]*exp) + + @sympify_return([('other', 'Set')], NotImplemented) + def __sub__(self, other): + return Complement(self, other) + + def __contains__(self, other): + other = _sympify(other) + c = self._contains(other) + b = tfn[c] + if b is None: + # x in y must evaluate to T or F; to entertain a None + # result with Set use y.contains(x) + raise TypeError('did not evaluate to a bool: %r' % c) + return b + + +class ProductSet(Set): + """ + Represents a Cartesian Product of Sets. + + Explanation + =========== + + Returns a Cartesian product given several sets as either an iterable + or individual arguments. + + Can use ``*`` operator on any sets for convenient shorthand. + + Examples + ======== + + >>> from sympy import Interval, FiniteSet, ProductSet + >>> I = Interval(0, 5); S = FiniteSet(1, 2, 3) + >>> ProductSet(I, S) + ProductSet(Interval(0, 5), {1, 2, 3}) + + >>> (2, 2) in ProductSet(I, S) + True + + >>> Interval(0, 1) * Interval(0, 1) # The unit square + ProductSet(Interval(0, 1), Interval(0, 1)) + + >>> coin = FiniteSet('H', 'T') + >>> set(coin**2) + {(H, H), (H, T), (T, H), (T, T)} + + The Cartesian product is not commutative or associative e.g.: + + >>> I*S == S*I + False + >>> (I*I)*I == I*(I*I) + False + + Notes + ===== + + - Passes most operations down to the argument sets + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cartesian_product + """ + is_ProductSet = True + + def __new__(cls, *sets, **assumptions): + if len(sets) == 1 and iterable(sets[0]) and not isinstance(sets[0], (Set, set)): + sympy_deprecation_warning( + """ +ProductSet(iterable) is deprecated. Use ProductSet(*iterable) instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-productset-iterable", + ) + sets = tuple(sets[0]) + + sets = [sympify(s) for s in sets] + + if not all(isinstance(s, Set) for s in sets): + raise TypeError("Arguments to ProductSet should be of type Set") + + # Nullary product of sets is *not* the empty set + if len(sets) == 0: + return FiniteSet(()) + + if S.EmptySet in sets: + return S.EmptySet + + return Basic.__new__(cls, *sets, **assumptions) + + @property + def sets(self): + return self.args + + def flatten(self): + def _flatten(sets): + for s in sets: + if s.is_ProductSet: + yield from _flatten(s.sets) + else: + yield s + return ProductSet(*_flatten(self.sets)) + + + + def _contains(self, element): + """ + ``in`` operator for ProductSets. + + Examples + ======== + + >>> from sympy import Interval + >>> (2, 3) in Interval(0, 5) * Interval(0, 5) + True + + >>> (10, 10) in Interval(0, 5) * Interval(0, 5) + False + + Passes operation on to constituent sets + """ + if element.is_Symbol: + return None + + if not isinstance(element, Tuple) or len(element) != len(self.sets): + return S.false + + return And(*[s.contains(e) for s, e in zip(self.sets, element)]) + + def as_relational(self, *symbols): + symbols = [_sympify(s) for s in symbols] + if len(symbols) != len(self.sets) or not all( + i.is_Symbol for i in symbols): + raise ValueError( + 'number of symbols must match the number of sets') + return And(*[s.as_relational(i) for s, i in zip(self.sets, symbols)]) + + @property + def _boundary(self): + return Union(*(ProductSet(*(b + b.boundary if i != j else b.boundary + for j, b in enumerate(self.sets))) + for i, a in enumerate(self.sets))) + + @property + def is_iterable(self): + """ + A property method which tests whether a set is iterable or not. + Returns True if set is iterable, otherwise returns False. + + Examples + ======== + + >>> from sympy import FiniteSet, Interval + >>> I = Interval(0, 1) + >>> A = FiniteSet(1, 2, 3, 4, 5) + >>> I.is_iterable + False + >>> A.is_iterable + True + + """ + return all(set.is_iterable for set in self.sets) + + def __iter__(self): + """ + A method which implements is_iterable property method. + If self.is_iterable returns True (both constituent sets are iterable), + then return the Cartesian Product. Otherwise, raise TypeError. + """ + return iproduct(*self.sets) + + @property + def is_empty(self): + return fuzzy_or(s.is_empty for s in self.sets) + + @property + def is_finite_set(self): + all_finite = fuzzy_and(s.is_finite_set for s in self.sets) + return fuzzy_or([self.is_empty, all_finite]) + + @property + def _measure(self): + measure = 1 + for s in self.sets: + measure *= s.measure + return measure + + def _kind(self): + return SetKind(TupleKind(*(i.kind.element_kind for i in self.args))) + + def __len__(self): + return reduce(lambda a, b: a*b, (len(s) for s in self.args)) + + def __bool__(self): + return all(self.sets) + + +class Interval(Set): + """ + Represents a real interval as a Set. + + Usage: + Returns an interval with end points ``start`` and ``end``. + + For ``left_open=True`` (default ``left_open`` is ``False``) the interval + will be open on the left. Similarly, for ``right_open=True`` the interval + will be open on the right. + + Examples + ======== + + >>> from sympy import Symbol, Interval + >>> Interval(0, 1) + Interval(0, 1) + >>> Interval.Ropen(0, 1) + Interval.Ropen(0, 1) + >>> Interval.Ropen(0, 1) + Interval.Ropen(0, 1) + >>> Interval.Lopen(0, 1) + Interval.Lopen(0, 1) + >>> Interval.open(0, 1) + Interval.open(0, 1) + + >>> a = Symbol('a', real=True) + >>> Interval(0, a) + Interval(0, a) + + Notes + ===== + - Only real end points are supported + - ``Interval(a, b)`` with $a > b$ will return the empty set + - Use the ``evalf()`` method to turn an Interval into an mpmath + ``mpi`` interval instance + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Interval_%28mathematics%29 + """ + is_Interval = True + + def __new__(cls, start, end, left_open=False, right_open=False): + + start = _sympify(start) + end = _sympify(end) + left_open = _sympify(left_open) + right_open = _sympify(right_open) + + if not all(isinstance(a, (type(true), type(false))) + for a in [left_open, right_open]): + raise NotImplementedError( + "left_open and right_open can have only true/false values, " + "got %s and %s" % (left_open, right_open)) + + # Only allow real intervals + if fuzzy_not(fuzzy_and(i.is_extended_real for i in (start, end, end-start))): + raise ValueError("Non-real intervals are not supported") + + # evaluate if possible + if is_lt(end, start): + return S.EmptySet + elif (end - start).is_negative: + return S.EmptySet + + if end == start and (left_open or right_open): + return S.EmptySet + if end == start and not (left_open or right_open): + if start is S.Infinity or start is S.NegativeInfinity: + return S.EmptySet + return FiniteSet(end) + + # Make sure infinite interval end points are open. + if start is S.NegativeInfinity: + left_open = true + if end is S.Infinity: + right_open = true + if start == S.Infinity or end == S.NegativeInfinity: + return S.EmptySet + + return Basic.__new__(cls, start, end, left_open, right_open) + + @property + def start(self): + """ + The left end point of the interval. + + This property takes the same value as the ``inf`` property. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).start + 0 + + """ + return self._args[0] + + @property + def end(self): + """ + The right end point of the interval. + + This property takes the same value as the ``sup`` property. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).end + 1 + + """ + return self._args[1] + + @property + def left_open(self): + """ + True if interval is left-open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1, left_open=True).left_open + True + >>> Interval(0, 1, left_open=False).left_open + False + + """ + return self._args[2] + + @property + def right_open(self): + """ + True if interval is right-open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1, right_open=True).right_open + True + >>> Interval(0, 1, right_open=False).right_open + False + + """ + return self._args[3] + + @classmethod + def open(cls, a, b): + """Return an interval including neither boundary.""" + return cls(a, b, True, True) + + @classmethod + def Lopen(cls, a, b): + """Return an interval not including the left boundary.""" + return cls(a, b, True, False) + + @classmethod + def Ropen(cls, a, b): + """Return an interval not including the right boundary.""" + return cls(a, b, False, True) + + @property + def _inf(self): + return self.start + + @property + def _sup(self): + return self.end + + @property + def left(self): + return self.start + + @property + def right(self): + return self.end + + @property + def is_empty(self): + if self.left_open or self.right_open: + cond = self.start >= self.end # One/both bounds open + else: + cond = self.start > self.end # Both bounds closed + return fuzzy_bool(cond) + + @property + def is_finite_set(self): + return self.measure.is_zero + + def _complement(self, other): + if other == S.Reals: + a = Interval(S.NegativeInfinity, self.start, + True, not self.left_open) + b = Interval(self.end, S.Infinity, not self.right_open, True) + return Union(a, b) + + if isinstance(other, FiniteSet): + nums = [m for m in other.args if m.is_number] + if nums == []: + return None + + return Set._complement(self, other) + + @property + def _boundary(self): + finite_points = [p for p in (self.start, self.end) + if abs(p) != S.Infinity] + return FiniteSet(*finite_points) + + def _contains(self, other): + if (not isinstance(other, Expr) or other is S.NaN + or other.is_real is False or other.has(S.ComplexInfinity)): + # if an expression has zoo it will be zoo or nan + # and neither of those is real + return false + + if self.start is S.NegativeInfinity and self.end is S.Infinity: + if other.is_real is not None: + return tfn[other.is_real] + + d = Dummy() + return self.as_relational(d).subs(d, other) + + def as_relational(self, x): + """Rewrite an interval in terms of inequalities and logic operators.""" + x = sympify(x) + if self.right_open: + right = x < self.end + else: + right = x <= self.end + if self.left_open: + left = self.start < x + else: + left = self.start <= x + return And(left, right) + + @property + def _measure(self): + return self.end - self.start + + def _kind(self): + return SetKind(NumberKind) + + def to_mpi(self, prec=53): + return mpi(mpf(self.start._eval_evalf(prec)), + mpf(self.end._eval_evalf(prec))) + + def _eval_evalf(self, prec): + return Interval(self.left._evalf(prec), self.right._evalf(prec), + left_open=self.left_open, right_open=self.right_open) + + def _is_comparable(self, other): + is_comparable = self.start.is_comparable + is_comparable &= self.end.is_comparable + is_comparable &= other.start.is_comparable + is_comparable &= other.end.is_comparable + + return is_comparable + + @property + def is_left_unbounded(self): + """Return ``True`` if the left endpoint is negative infinity. """ + return self.left is S.NegativeInfinity or self.left == Float("-inf") + + @property + def is_right_unbounded(self): + """Return ``True`` if the right endpoint is positive infinity. """ + return self.right is S.Infinity or self.right == Float("+inf") + + def _eval_Eq(self, other): + if not isinstance(other, Interval): + if isinstance(other, FiniteSet): + return false + elif isinstance(other, Set): + return None + return false + + +class Union(Set, LatticeOp): + """ + Represents a union of sets as a :class:`Set`. + + Examples + ======== + + >>> from sympy import Union, Interval + >>> Union(Interval(1, 2), Interval(3, 4)) + Union(Interval(1, 2), Interval(3, 4)) + + The Union constructor will always try to merge overlapping intervals, + if possible. For example: + + >>> Union(Interval(1, 2), Interval(2, 3)) + Interval(1, 3) + + See Also + ======== + + Intersection + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Union_%28set_theory%29 + """ + is_Union = True + + @property + def identity(self): + return S.EmptySet + + @property + def zero(self): + return S.UniversalSet + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + + # flatten inputs to merge intersections and iterables + args = _sympify(args) + + # Reduce sets using known rules + if evaluate: + args = list(cls._new_args_filter(args)) + return simplify_union(args) + + args = list(ordered(args, Set._infimum_key)) + + obj = Basic.__new__(cls, *args) + obj._argset = frozenset(args) + return obj + + @property + def args(self): + return self._args + + def _complement(self, universe): + # DeMorgan's Law + return Intersection(s.complement(universe) for s in self.args) + + @property + def _inf(self): + # We use Min so that sup is meaningful in combination with symbolic + # interval end points. + return Min(*[set.inf for set in self.args]) + + @property + def _sup(self): + # We use Max so that sup is meaningful in combination with symbolic + # end points. + return Max(*[set.sup for set in self.args]) + + @property + def is_empty(self): + return fuzzy_and(set.is_empty for set in self.args) + + @property + def is_finite_set(self): + return fuzzy_and(set.is_finite_set for set in self.args) + + @property + def _measure(self): + # Measure of a union is the sum of the measures of the sets minus + # the sum of their pairwise intersections plus the sum of their + # triple-wise intersections minus ... etc... + + # Sets is a collection of intersections and a set of elementary + # sets which made up those intersections (called "sos" for set of sets) + # An example element might of this list might be: + # ( {A,B,C}, A.intersect(B).intersect(C) ) + + # Start with just elementary sets ( ({A}, A), ({B}, B), ... ) + # Then get and subtract ( ({A,B}, (A int B), ... ) while non-zero + sets = [(FiniteSet(s), s) for s in self.args] + measure = 0 + parity = 1 + while sets: + # Add up the measure of these sets and add or subtract it to total + measure += parity * sum(inter.measure for sos, inter in sets) + + # For each intersection in sets, compute the intersection with every + # other set not already part of the intersection. + sets = ((sos + FiniteSet(newset), newset.intersect(intersection)) + for sos, intersection in sets for newset in self.args + if newset not in sos) + + # Clear out sets with no measure + sets = [(sos, inter) for sos, inter in sets if inter.measure != 0] + + # Clear out duplicates + sos_list = [] + sets_list = [] + for _set in sets: + if _set[0] in sos_list: + continue + else: + sos_list.append(_set[0]) + sets_list.append(_set) + sets = sets_list + + # Flip Parity - next time subtract/add if we added/subtracted here + parity *= -1 + return measure + + def _kind(self): + kinds = tuple(arg.kind for arg in self.args if arg is not S.EmptySet) + if not kinds: + return SetKind() + elif all(i == kinds[0] for i in kinds): + return kinds[0] + else: + return SetKind(UndefinedKind) + + @property + def _boundary(self): + def boundary_of_set(i): + """ The boundary of set i minus interior of all other sets """ + b = self.args[i].boundary + for j, a in enumerate(self.args): + if j != i: + b = b - a.interior + return b + return Union(*map(boundary_of_set, range(len(self.args)))) + + def _contains(self, other): + return Or(*[s.contains(other) for s in self.args]) + + def is_subset(self, other): + return fuzzy_and(s.is_subset(other) for s in self.args) + + def as_relational(self, symbol): + """Rewrite a Union in terms of equalities and logic operators. """ + if (len(self.args) == 2 and + all(isinstance(i, Interval) for i in self.args)): + # optimization to give 3 args as (x > 1) & (x < 5) & Ne(x, 3) + # instead of as 4, ((1 <= x) & (x < 3)) | ((x <= 5) & (3 < x)) + # XXX: This should be ideally be improved to handle any number of + # intervals and also not to assume that the intervals are in any + # particular sorted order. + a, b = self.args + if a.sup == b.inf and a.right_open and b.left_open: + mincond = symbol > a.inf if a.left_open else symbol >= a.inf + maxcond = symbol < b.sup if b.right_open else symbol <= b.sup + necond = Ne(symbol, a.sup) + return And(necond, mincond, maxcond) + return Or(*[i.as_relational(symbol) for i in self.args]) + + @property + def is_iterable(self): + return all(arg.is_iterable for arg in self.args) + + def __iter__(self): + return roundrobin(*(iter(arg) for arg in self.args)) + + +class Intersection(Set, LatticeOp): + """ + Represents an intersection of sets as a :class:`Set`. + + Examples + ======== + + >>> from sympy import Intersection, Interval + >>> Intersection(Interval(1, 3), Interval(2, 4)) + Interval(2, 3) + + We often use the .intersect method + + >>> Interval(1,3).intersect(Interval(2,4)) + Interval(2, 3) + + See Also + ======== + + Union + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Intersection_%28set_theory%29 + """ + is_Intersection = True + + @property + def identity(self): + return S.UniversalSet + + @property + def zero(self): + return S.EmptySet + + def __new__(cls, *args , evaluate=None): + if evaluate is None: + evaluate = global_parameters.evaluate + + # flatten inputs to merge intersections and iterables + args = list(ordered(set(_sympify(args)))) + + # Reduce sets using known rules + if evaluate: + args = list(cls._new_args_filter(args)) + return simplify_intersection(args) + + args = list(ordered(args, Set._infimum_key)) + + obj = Basic.__new__(cls, *args) + obj._argset = frozenset(args) + return obj + + @property + def args(self): + return self._args + + @property + def is_iterable(self): + return any(arg.is_iterable for arg in self.args) + + @property + def is_finite_set(self): + if fuzzy_or(arg.is_finite_set for arg in self.args): + return True + + def _kind(self): + kinds = tuple(arg.kind for arg in self.args if arg is not S.UniversalSet) + if not kinds: + return SetKind(UndefinedKind) + elif all(i == kinds[0] for i in kinds): + return kinds[0] + else: + return SetKind() + + @property + def _inf(self): + raise NotImplementedError() + + @property + def _sup(self): + raise NotImplementedError() + + def _contains(self, other): + return And(*[set.contains(other) for set in self.args]) + + def __iter__(self): + sets_sift = sift(self.args, lambda x: x.is_iterable) + + completed = False + candidates = sets_sift[True] + sets_sift[None] + + finite_candidates, others = [], [] + for candidate in candidates: + length = None + try: + length = len(candidate) + except TypeError: + others.append(candidate) + + if length is not None: + finite_candidates.append(candidate) + finite_candidates.sort(key=len) + + for s in finite_candidates + others: + other_sets = set(self.args) - {s} + other = Intersection(*other_sets, evaluate=False) + completed = True + for x in s: + try: + if x in other: + yield x + except TypeError: + completed = False + if completed: + return + + if not completed: + if not candidates: + raise TypeError("None of the constituent sets are iterable") + raise TypeError( + "The computation had not completed because of the " + "undecidable set membership is found in every candidates.") + + @staticmethod + def _handle_finite_sets(args): + '''Simplify intersection of one or more FiniteSets and other sets''' + + # First separate the FiniteSets from the others + fs_args, others = sift(args, lambda x: x.is_FiniteSet, binary=True) + + # Let the caller handle intersection of non-FiniteSets + if not fs_args: + return + + # Convert to Python sets and build the set of all elements + fs_sets = [set(fs) for fs in fs_args] + all_elements = reduce(lambda a, b: a | b, fs_sets, set()) + + # Extract elements that are definitely in or definitely not in the + # intersection. Here we check contains for all of args. + definite = set() + for e in all_elements: + inall = fuzzy_and(s.contains(e) for s in args) + if inall is True: + definite.add(e) + if inall is not None: + for s in fs_sets: + s.discard(e) + + # At this point all elements in all of fs_sets are possibly in the + # intersection. In some cases this is because they are definitely in + # the intersection of the finite sets but it's not clear if they are + # members of others. We might have {m, n}, {m}, and Reals where we + # don't know if m or n is real. We want to remove n here but it is + # possibly in because it might be equal to m. So what we do now is + # extract the elements that are definitely in the remaining finite + # sets iteratively until we end up with {n}, {}. At that point if we + # get any empty set all remaining elements are discarded. + + fs_elements = reduce(lambda a, b: a | b, fs_sets, set()) + + # Need fuzzy containment testing + fs_symsets = [FiniteSet(*s) for s in fs_sets] + + while fs_elements: + for e in fs_elements: + infs = fuzzy_and(s.contains(e) for s in fs_symsets) + if infs is True: + definite.add(e) + if infs is not None: + for n, s in enumerate(fs_sets): + # Update Python set and FiniteSet + if e in s: + s.remove(e) + fs_symsets[n] = FiniteSet(*s) + fs_elements.remove(e) + break + # If we completed the for loop without removing anything we are + # done so quit the outer while loop + else: + break + + # If any of the sets of remainder elements is empty then we discard + # all of them for the intersection. + if not all(fs_sets): + fs_sets = [set()] + + # Here we fold back the definitely included elements into each fs. + # Since they are definitely included they must have been members of + # each FiniteSet to begin with. We could instead fold these in with a + # Union at the end to get e.g. {3}|({x}&{y}) rather than {3,x}&{3,y}. + if definite: + fs_sets = [fs | definite for fs in fs_sets] + + if fs_sets == [set()]: + return S.EmptySet + + sets = [FiniteSet(*s) for s in fs_sets] + + # Any set in others is redundant if it contains all the elements that + # are in the finite sets so we don't need it in the Intersection + all_elements = reduce(lambda a, b: a | b, fs_sets, set()) + is_redundant = lambda o: all(fuzzy_bool(o.contains(e)) for e in all_elements) + others = [o for o in others if not is_redundant(o)] + + if others: + rest = Intersection(*others) + # XXX: Maybe this shortcut should be at the beginning. For large + # FiniteSets it could much more efficient to process the other + # sets first... + if rest is S.EmptySet: + return S.EmptySet + # Flatten the Intersection + if rest.is_Intersection: + sets.extend(rest.args) + else: + sets.append(rest) + + if len(sets) == 1: + return sets[0] + else: + return Intersection(*sets, evaluate=False) + + def as_relational(self, symbol): + """Rewrite an Intersection in terms of equalities and logic operators""" + return And(*[set.as_relational(symbol) for set in self.args]) + + +class Complement(Set): + r"""Represents the set difference or relative complement of a set with + another set. + + $$A - B = \{x \in A \mid x \notin B\}$$ + + + Examples + ======== + + >>> from sympy import Complement, FiniteSet + >>> Complement(FiniteSet(0, 1, 2), FiniteSet(1)) + {0, 2} + + See Also + ========= + + Intersection, Union + + References + ========== + + .. [1] https://mathworld.wolfram.com/ComplementSet.html + """ + + is_Complement = True + + def __new__(cls, a, b, evaluate=True): + a, b = map(_sympify, (a, b)) + if evaluate: + return Complement.reduce(a, b) + + return Basic.__new__(cls, a, b) + + @staticmethod + def reduce(A, B): + """ + Simplify a :class:`Complement`. + + """ + if B == S.UniversalSet or A.is_subset(B): + return S.EmptySet + + if isinstance(B, Union): + return Intersection(*(s.complement(A) for s in B.args)) + + result = B._complement(A) + if result is not None: + return result + else: + return Complement(A, B, evaluate=False) + + def _contains(self, other): + A = self.args[0] + B = self.args[1] + return And(A.contains(other), Not(B.contains(other))) + + def as_relational(self, symbol): + """Rewrite a complement in terms of equalities and logic + operators""" + A, B = self.args + + A_rel = A.as_relational(symbol) + B_rel = Not(B.as_relational(symbol)) + + return And(A_rel, B_rel) + + def _kind(self): + return self.args[0].kind + + @property + def is_iterable(self): + if self.args[0].is_iterable: + return True + + @property + def is_finite_set(self): + A, B = self.args + a_finite = A.is_finite_set + if a_finite is True: + return True + elif a_finite is False and B.is_finite_set: + return False + + def __iter__(self): + A, B = self.args + for a in A: + if a not in B: + yield a + else: + continue + + +class EmptySet(Set, metaclass=Singleton): + """ + Represents the empty set. The empty set is available as a singleton + as ``S.EmptySet``. + + Examples + ======== + + >>> from sympy import S, Interval + >>> S.EmptySet + EmptySet + + >>> Interval(1, 2).intersect(S.EmptySet) + EmptySet + + See Also + ======== + + UniversalSet + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Empty_set + """ + is_empty = True + is_finite_set = True + is_FiniteSet = True + + @property # type: ignore + @deprecated( + """ + The is_EmptySet attribute of Set objects is deprecated. + Use 's is S.EmptySet" or 's.is_empty' instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-is-emptyset", + ) + def is_EmptySet(self): + return True + + @property + def _measure(self): + return 0 + + def _contains(self, other): + return false + + def as_relational(self, symbol): + return false + + def __len__(self): + return 0 + + def __iter__(self): + return iter([]) + + def _eval_powerset(self): + return FiniteSet(self) + + @property + def _boundary(self): + return self + + def _complement(self, other): + return other + + def _kind(self): + return SetKind() + + def _symmetric_difference(self, other): + return other + + +class UniversalSet(Set, metaclass=Singleton): + """ + Represents the set of all things. + The universal set is available as a singleton as ``S.UniversalSet``. + + Examples + ======== + + >>> from sympy import S, Interval + >>> S.UniversalSet + UniversalSet + + >>> Interval(1, 2).intersect(S.UniversalSet) + Interval(1, 2) + + See Also + ======== + + EmptySet + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Universal_set + """ + + is_UniversalSet = True + is_empty = False + is_finite_set = False + + def _complement(self, other): + return S.EmptySet + + def _symmetric_difference(self, other): + return other + + @property + def _measure(self): + return S.Infinity + + def _kind(self): + return SetKind(UndefinedKind) + + def _contains(self, other): + return true + + def as_relational(self, symbol): + return true + + @property + def _boundary(self): + return S.EmptySet + + +class FiniteSet(Set): + """ + Represents a finite set of Sympy expressions. + + Examples + ======== + + >>> from sympy import FiniteSet, Symbol, Interval, Naturals0 + >>> FiniteSet(1, 2, 3, 4) + {1, 2, 3, 4} + >>> 3 in FiniteSet(1, 2, 3, 4) + True + >>> FiniteSet(1, (1, 2), Symbol('x')) + {1, x, (1, 2)} + >>> FiniteSet(Interval(1, 2), Naturals0, {1, 2}) + FiniteSet({1, 2}, Interval(1, 2), Naturals0) + >>> members = [1, 2, 3, 4] + >>> f = FiniteSet(*members) + >>> f + {1, 2, 3, 4} + >>> f - FiniteSet(2) + {1, 3, 4} + >>> f + FiniteSet(2, 5) + {1, 2, 3, 4, 5} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Finite_set + """ + is_FiniteSet = True + is_iterable = True + is_empty = False + is_finite_set = True + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + if evaluate: + args = list(map(sympify, args)) + + if len(args) == 0: + return S.EmptySet + else: + args = list(map(sympify, args)) + + # keep the form of the first canonical arg + dargs = {} + for i in reversed(list(ordered(args))): + if i.is_Symbol: + dargs[i] = i + else: + try: + dargs[i.as_dummy()] = i + except TypeError: + # e.g. i = class without args like `Interval` + dargs[i] = i + _args_set = set(dargs.values()) + args = list(ordered(_args_set, Set._infimum_key)) + obj = Basic.__new__(cls, *args) + obj._args_set = _args_set + return obj + + + def __iter__(self): + return iter(self.args) + + def _complement(self, other): + if isinstance(other, Interval): + # Splitting in sub-intervals is only done for S.Reals; + # other cases that need splitting will first pass through + # Set._complement(). + nums, syms = [], [] + for m in self.args: + if m.is_number and m.is_real: + nums.append(m) + elif m.is_real == False: + pass # drop non-reals + else: + syms.append(m) # various symbolic expressions + if other == S.Reals and nums != []: + nums.sort() + intervals = [] # Build up a list of intervals between the elements + intervals += [Interval(S.NegativeInfinity, nums[0], True, True)] + for a, b in zip(nums[:-1], nums[1:]): + intervals.append(Interval(a, b, True, True)) # both open + intervals.append(Interval(nums[-1], S.Infinity, True, True)) + if syms != []: + return Complement(Union(*intervals, evaluate=False), + FiniteSet(*syms), evaluate=False) + else: + return Union(*intervals, evaluate=False) + elif nums == []: # no splitting necessary or possible: + if syms: + return Complement(other, FiniteSet(*syms), evaluate=False) + else: + return other + + elif isinstance(other, FiniteSet): + unk = [] + for i in self: + c = sympify(other.contains(i)) + if c is not S.true and c is not S.false: + unk.append(i) + unk = FiniteSet(*unk) + if unk == self: + return + not_true = [] + for i in other: + c = sympify(self.contains(i)) + if c is not S.true: + not_true.append(i) + return Complement(FiniteSet(*not_true), unk) + + return Set._complement(self, other) + + def _contains(self, other): + """ + Tests whether an element, other, is in the set. + + Explanation + =========== + + The actual test is for mathematical equality (as opposed to + syntactical equality). In the worst case all elements of the + set must be checked. + + Examples + ======== + + >>> from sympy import FiniteSet + >>> 1 in FiniteSet(1, 2) + True + >>> 5 in FiniteSet(1, 2) + False + + """ + if other in self._args_set: + return S.true + else: + # evaluate=True is needed to override evaluate=False context; + # we need Eq to do the evaluation + return Or(*[Eq(e, other, evaluate=True) for e in self.args]) + + def _eval_is_subset(self, other): + return fuzzy_and(other._contains(e) for e in self.args) + + @property + def _boundary(self): + return self + + @property + def _inf(self): + return Min(*self) + + @property + def _sup(self): + return Max(*self) + + @property + def measure(self): + return 0 + + def _kind(self): + if not self.args: + return SetKind() + elif all(i.kind == self.args[0].kind for i in self.args): + return SetKind(self.args[0].kind) + else: + return SetKind(UndefinedKind) + + def __len__(self): + return len(self.args) + + def as_relational(self, symbol): + """Rewrite a FiniteSet in terms of equalities and logic operators. """ + return Or(*[Eq(symbol, elem) for elem in self]) + + def compare(self, other): + return (hash(self) - hash(other)) + + def _eval_evalf(self, prec): + dps = prec_to_dps(prec) + return FiniteSet(*[elem.evalf(n=dps) for elem in self]) + + def _eval_simplify(self, **kwargs): + from sympy.simplify import simplify + return FiniteSet(*[simplify(elem, **kwargs) for elem in self]) + + @property + def _sorted_args(self): + return self.args + + def _eval_powerset(self): + return self.func(*[self.func(*s) for s in subsets(self.args)]) + + def _eval_rewrite_as_PowerSet(self, *args, **kwargs): + """Rewriting method for a finite set to a power set.""" + from .powerset import PowerSet + + is2pow = lambda n: bool(n and not n & (n - 1)) + if not is2pow(len(self)): + return None + + fs_test = lambda arg: isinstance(arg, Set) and arg.is_FiniteSet + if not all(fs_test(arg) for arg in args): + return None + + biggest = max(args, key=len) + for arg in subsets(biggest.args): + arg_set = FiniteSet(*arg) + if arg_set not in args: + return None + return PowerSet(biggest) + + def __ge__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return other.is_subset(self) + + def __gt__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_proper_superset(other) + + def __le__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_subset(other) + + def __lt__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_proper_subset(other) + + def __eq__(self, other): + if isinstance(other, (set, frozenset)): + return self._args_set == other + return super().__eq__(other) + + __hash__ : Callable[[Basic], Any] = Basic.__hash__ + +_sympy_converter[set] = lambda x: FiniteSet(*x) +_sympy_converter[frozenset] = lambda x: FiniteSet(*x) + + +class SymmetricDifference(Set): + """Represents the set of elements which are in either of the + sets and not in their intersection. + + Examples + ======== + + >>> from sympy import SymmetricDifference, FiniteSet + >>> SymmetricDifference(FiniteSet(1, 2, 3), FiniteSet(3, 4, 5)) + {1, 2, 4, 5} + + See Also + ======== + + Complement, Union + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Symmetric_difference + """ + + is_SymmetricDifference = True + + def __new__(cls, a, b, evaluate=True): + if evaluate: + return SymmetricDifference.reduce(a, b) + + return Basic.__new__(cls, a, b) + + @staticmethod + def reduce(A, B): + result = B._symmetric_difference(A) + if result is not None: + return result + else: + return SymmetricDifference(A, B, evaluate=False) + + def as_relational(self, symbol): + """Rewrite a symmetric_difference in terms of equalities and + logic operators""" + A, B = self.args + + A_rel = A.as_relational(symbol) + B_rel = B.as_relational(symbol) + + return Xor(A_rel, B_rel) + + @property + def is_iterable(self): + if all(arg.is_iterable for arg in self.args): + return True + + def __iter__(self): + + args = self.args + union = roundrobin(*(iter(arg) for arg in args)) + + for item in union: + count = 0 + for s in args: + if item in s: + count += 1 + + if count % 2 == 1: + yield item + + + +class DisjointUnion(Set): + """ Represents the disjoint union (also known as the external disjoint union) + of a finite number of sets. + + Examples + ======== + + >>> from sympy import DisjointUnion, FiniteSet, Interval, Union, Symbol + >>> A = FiniteSet(1, 2, 3) + >>> B = Interval(0, 5) + >>> DisjointUnion(A, B) + DisjointUnion({1, 2, 3}, Interval(0, 5)) + >>> DisjointUnion(A, B).rewrite(Union) + Union(ProductSet({1, 2, 3}, {0}), ProductSet(Interval(0, 5), {1})) + >>> C = FiniteSet(Symbol('x'), Symbol('y'), Symbol('z')) + >>> DisjointUnion(C, C) + DisjointUnion({x, y, z}, {x, y, z}) + >>> DisjointUnion(C, C).rewrite(Union) + ProductSet({x, y, z}, {0, 1}) + + References + ========== + + https://en.wikipedia.org/wiki/Disjoint_union + """ + + def __new__(cls, *sets): + dj_collection = [] + for set_i in sets: + if isinstance(set_i, Set): + dj_collection.append(set_i) + else: + raise TypeError("Invalid input: '%s', input args \ + to DisjointUnion must be Sets" % set_i) + obj = Basic.__new__(cls, *dj_collection) + return obj + + @property + def sets(self): + return self.args + + @property + def is_empty(self): + return fuzzy_and(s.is_empty for s in self.sets) + + @property + def is_finite_set(self): + all_finite = fuzzy_and(s.is_finite_set for s in self.sets) + return fuzzy_or([self.is_empty, all_finite]) + + @property + def is_iterable(self): + if self.is_empty: + return False + iter_flag = True + for set_i in self.sets: + if not set_i.is_empty: + iter_flag = iter_flag and set_i.is_iterable + return iter_flag + + def _eval_rewrite_as_Union(self, *sets, **kwargs): + """ + Rewrites the disjoint union as the union of (``set`` x {``i``}) + where ``set`` is the element in ``sets`` at index = ``i`` + """ + + dj_union = S.EmptySet + index = 0 + for set_i in sets: + if isinstance(set_i, Set): + cross = ProductSet(set_i, FiniteSet(index)) + dj_union = Union(dj_union, cross) + index = index + 1 + return dj_union + + def _contains(self, element): + """ + ``in`` operator for DisjointUnion + + Examples + ======== + + >>> from sympy import Interval, DisjointUnion + >>> D = DisjointUnion(Interval(0, 1), Interval(0, 2)) + >>> (0.5, 0) in D + True + >>> (0.5, 1) in D + True + >>> (1.5, 0) in D + False + >>> (1.5, 1) in D + True + + Passes operation on to constituent sets + """ + if not isinstance(element, Tuple) or len(element) != 2: + return S.false + + if not element[1].is_Integer: + return S.false + + if element[1] >= len(self.sets) or element[1] < 0: + return S.false + + return self.sets[element[1]]._contains(element[0]) + + def _kind(self): + if not self.args: + return SetKind() + elif all(i.kind == self.args[0].kind for i in self.args): + return self.args[0].kind + else: + return SetKind(UndefinedKind) + + def __iter__(self): + if self.is_iterable: + + iters = [] + for i, s in enumerate(self.sets): + iters.append(iproduct(s, {Integer(i)})) + + return iter(roundrobin(*iters)) + else: + raise ValueError("'%s' is not iterable." % self) + + def __len__(self): + """ + Returns the length of the disjoint union, i.e., the number of elements in the set. + + Examples + ======== + + >>> from sympy import FiniteSet, DisjointUnion, EmptySet + >>> D1 = DisjointUnion(FiniteSet(1, 2, 3, 4), EmptySet, FiniteSet(3, 4, 5)) + >>> len(D1) + 7 + >>> D2 = DisjointUnion(FiniteSet(3, 5, 7), EmptySet, FiniteSet(3, 5, 7)) + >>> len(D2) + 6 + >>> D3 = DisjointUnion(EmptySet, EmptySet) + >>> len(D3) + 0 + + Adds up the lengths of the constituent sets. + """ + + if self.is_finite_set: + size = 0 + for set in self.sets: + size += len(set) + return size + else: + raise ValueError("'%s' is not a finite set." % self) + + +def imageset(*args): + r""" + Return an image of the set under transformation ``f``. + + Explanation + =========== + + If this function cannot compute the image, it returns an + unevaluated ImageSet object. + + .. math:: + \{ f(x) \mid x \in \mathrm{self} \} + + Examples + ======== + + >>> from sympy import S, Interval, imageset, sin, Lambda + >>> from sympy.abc import x + + >>> imageset(x, 2*x, Interval(0, 2)) + Interval(0, 4) + + >>> imageset(lambda x: 2*x, Interval(0, 2)) + Interval(0, 4) + + >>> imageset(Lambda(x, sin(x)), Interval(-2, 1)) + ImageSet(Lambda(x, sin(x)), Interval(-2, 1)) + + >>> imageset(sin, Interval(-2, 1)) + ImageSet(Lambda(x, sin(x)), Interval(-2, 1)) + >>> imageset(lambda y: x + y, Interval(-2, 1)) + ImageSet(Lambda(y, x + y), Interval(-2, 1)) + + Expressions applied to the set of Integers are simplified + to show as few negatives as possible and linear expressions + are converted to a canonical form. If this is not desirable + then the unevaluated ImageSet should be used. + + >>> imageset(x, -2*x + 5, S.Integers) + ImageSet(Lambda(x, 2*x + 1), Integers) + + See Also + ======== + + sympy.sets.fancysets.ImageSet + + """ + from .fancysets import ImageSet + from .setexpr import set_function + + if len(args) < 2: + raise ValueError('imageset expects at least 2 args, got: %s' % len(args)) + + if isinstance(args[0], (Symbol, tuple)) and len(args) > 2: + f = Lambda(args[0], args[1]) + set_list = args[2:] + else: + f = args[0] + set_list = args[1:] + + if isinstance(f, Lambda): + pass + elif callable(f): + nargs = getattr(f, 'nargs', {}) + if nargs: + if len(nargs) != 1: + raise NotImplementedError(filldedent(''' + This function can take more than 1 arg + but the potentially complicated set input + has not been analyzed at this point to + know its dimensions. TODO + ''')) + N = nargs.args[0] + if N == 1: + s = 'x' + else: + s = [Symbol('x%i' % i) for i in range(1, N + 1)] + else: + s = inspect.signature(f).parameters + + dexpr = _sympify(f(*[Dummy() for i in s])) + var = tuple(uniquely_named_symbol( + Symbol(i), dexpr) for i in s) + f = Lambda(var, f(*var)) + else: + raise TypeError(filldedent(''' + expecting lambda, Lambda, or FunctionClass, + not \'%s\'.''' % func_name(f))) + + if any(not isinstance(s, Set) for s in set_list): + name = [func_name(s) for s in set_list] + raise ValueError( + 'arguments after mapping should be sets, not %s' % name) + + if len(set_list) == 1: + set = set_list[0] + try: + # TypeError if arg count != set dimensions + r = set_function(f, set) + if r is None: + raise TypeError + if not r: + return r + except TypeError: + r = ImageSet(f, set) + if isinstance(r, ImageSet): + f, set = r.args + + if f.variables[0] == f.expr: + return set + + if isinstance(set, ImageSet): + # XXX: Maybe this should just be: + # f2 = set.lambda + # fun = Lambda(f2.signature, f(*f2.expr)) + # return imageset(fun, *set.base_sets) + if len(set.lamda.variables) == 1 and len(f.variables) == 1: + x = set.lamda.variables[0] + y = f.variables[0] + return imageset( + Lambda(x, f.expr.subs(y, set.lamda.expr)), *set.base_sets) + + if r is not None: + return r + + return ImageSet(f, *set_list) + + +def is_function_invertible_in_set(func, setv): + """ + Checks whether function ``func`` is invertible when the domain is + restricted to set ``setv``. + """ + # Functions known to always be invertible: + if func in (exp, log): + return True + u = Dummy("u") + fdiff = func(u).diff(u) + # monotonous functions: + # TODO: check subsets (`func` in `setv`) + if (fdiff > 0) == True or (fdiff < 0) == True: + return True + # TODO: support more + return None + + +def simplify_union(args): + """ + Simplify a :class:`Union` using known rules. + + Explanation + =========== + + We first start with global rules like 'Merge all FiniteSets' + + Then we iterate through all pairs and ask the constituent sets if they + can simplify themselves with any other constituent. This process depends + on ``union_sets(a, b)`` functions. + """ + from sympy.sets.handlers.union import union_sets + + # ===== Global Rules ===== + if not args: + return S.EmptySet + + for arg in args: + if not isinstance(arg, Set): + raise TypeError("Input args to Union must be Sets") + + # Merge all finite sets + finite_sets = [x for x in args if x.is_FiniteSet] + if len(finite_sets) > 1: + a = (x for set in finite_sets for x in set) + finite_set = FiniteSet(*a) + args = [finite_set] + [x for x in args if not x.is_FiniteSet] + + # ===== Pair-wise Rules ===== + # Here we depend on rules built into the constituent sets + args = set(args) + new_args = True + while new_args: + for s in args: + new_args = False + for t in args - {s}: + new_set = union_sets(s, t) + # This returns None if s does not know how to intersect + # with t. Returns the newly intersected set otherwise + if new_set is not None: + if not isinstance(new_set, set): + new_set = {new_set} + new_args = (args - {s, t}).union(new_set) + break + if new_args: + args = new_args + break + + if len(args) == 1: + return args.pop() + else: + return Union(*args, evaluate=False) + + +def simplify_intersection(args): + """ + Simplify an intersection using known rules. + + Explanation + =========== + + We first start with global rules like + 'if any empty sets return empty set' and 'distribute any unions' + + Then we iterate through all pairs and ask the constituent sets if they + can simplify themselves with any other constituent + """ + + # ===== Global Rules ===== + if not args: + return S.UniversalSet + + for arg in args: + if not isinstance(arg, Set): + raise TypeError("Input args to Union must be Sets") + + # If any EmptySets return EmptySet + if S.EmptySet in args: + return S.EmptySet + + # Handle Finite sets + rv = Intersection._handle_finite_sets(args) + + if rv is not None: + return rv + + # If any of the sets are unions, return a Union of Intersections + for s in args: + if s.is_Union: + other_sets = set(args) - {s} + if len(other_sets) > 0: + other = Intersection(*other_sets) + return Union(*(Intersection(arg, other) for arg in s.args)) + else: + return Union(*s.args) + + for s in args: + if s.is_Complement: + args.remove(s) + other_sets = args + [s.args[0]] + return Complement(Intersection(*other_sets), s.args[1]) + + from sympy.sets.handlers.intersection import intersection_sets + + # At this stage we are guaranteed not to have any + # EmptySets, FiniteSets, or Unions in the intersection + + # ===== Pair-wise Rules ===== + # Here we depend on rules built into the constituent sets + args = set(args) + new_args = True + while new_args: + for s in args: + new_args = False + for t in args - {s}: + new_set = intersection_sets(s, t) + # This returns None if s does not know how to intersect + # with t. Returns the newly intersected set otherwise + + if new_set is not None: + new_args = (args - {s, t}).union({new_set}) + break + if new_args: + args = new_args + break + + if len(args) == 1: + return args.pop() + else: + return Intersection(*args, evaluate=False) + + +def _handle_finite_sets(op, x, y, commutative): + # Handle finite sets: + fs_args, other = sift([x, y], lambda x: isinstance(x, FiniteSet), binary=True) + if len(fs_args) == 2: + return FiniteSet(*[op(i, j) for i in fs_args[0] for j in fs_args[1]]) + elif len(fs_args) == 1: + sets = [_apply_operation(op, other[0], i, commutative) for i in fs_args[0]] + return Union(*sets) + else: + return None + + +def _apply_operation(op, x, y, commutative): + from .fancysets import ImageSet + d = Dummy('d') + + out = _handle_finite_sets(op, x, y, commutative) + if out is None: + out = op(x, y) + + if out is None and commutative: + out = op(y, x) + if out is None: + _x, _y = symbols("x y") + if isinstance(x, Set) and not isinstance(y, Set): + out = ImageSet(Lambda(d, op(d, y)), x).doit() + elif not isinstance(x, Set) and isinstance(y, Set): + out = ImageSet(Lambda(d, op(x, d)), y).doit() + else: + out = ImageSet(Lambda((_x, _y), op(_x, _y)), x, y) + return out + + +def set_add(x, y): + from sympy.sets.handlers.add import _set_add + return _apply_operation(_set_add, x, y, commutative=True) + + +def set_sub(x, y): + from sympy.sets.handlers.add import _set_sub + return _apply_operation(_set_sub, x, y, commutative=False) + + +def set_mul(x, y): + from sympy.sets.handlers.mul import _set_mul + return _apply_operation(_set_mul, x, y, commutative=True) + + +def set_div(x, y): + from sympy.sets.handlers.mul import _set_div + return _apply_operation(_set_div, x, y, commutative=False) + + +def set_pow(x, y): + from sympy.sets.handlers.power import _set_pow + return _apply_operation(_set_pow, x, y, commutative=False) + + +def set_function(f, x): + from sympy.sets.handlers.functions import _set_function + return _set_function(f, x) + + +class SetKind(Kind): + """ + SetKind is kind for all Sets + + Every instance of Set will have kind ``SetKind`` parametrised by the kind + of the elements of the ``Set``. The kind of the elements might be + ``NumberKind``, or ``TupleKind`` or something else. When not all elements + have the same kind then the kind of the elements will be given as + ``UndefinedKind``. + + Parameters + ========== + + element_kind: Kind (optional) + The kind of the elements of the set. In a well defined set all elements + will have the same kind. Otherwise the kind should + :class:`sympy.core.kind.UndefinedKind`. The ``element_kind`` argument is optional but + should only be omitted in the case of ``EmptySet`` whose kind is simply + ``SetKind()`` + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(1, 2).kind + SetKind(NumberKind) + >>> Interval(1,2).kind.element_kind + NumberKind + + See Also + ======== + + sympy.core.kind.NumberKind + sympy.matrices.kind.MatrixKind + sympy.core.containers.TupleKind + """ + def __new__(cls, element_kind=None): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + if not self.element_kind: + return "SetKind()" + else: + return "SetKind(%s)" % self.element_kind diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc64cfb945ad642131158de323e35f8cbf6995c6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb6dafe03325e3752e9fd57831f9622aafedef76 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53f057affc7aa0c24005d40f5fa4a1d8e138f010 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..767ffe20907ca3adfe13311bd793fbd38987d9c8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b54c0a9527c3de4b1695186cbdf2e3ffdd48841 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91a35833513cb5e6b4ef678b11bea2e981fee696 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_conditionset.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_conditionset.py new file mode 100644 index 0000000000000000000000000000000000000000..4818246f306afd46a09a2cbea1faab858a9e7806 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_conditionset.py @@ -0,0 +1,294 @@ +from sympy.core.expr import unchanged +from sympy.sets import (ConditionSet, Intersection, FiniteSet, + EmptySet, Union, Contains, ImageSet) +from sympy.sets.sets import SetKind +from sympy.core.function import (Function, Lambda) +from sympy.core.mod import Mod +from sympy.core.kind import NumberKind +from sympy.core.numbers import (oo, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.trigonometric import (asin, sin) +from sympy.logic.boolalg import And +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.sets.sets import Interval +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +w = Symbol('w') +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +f = Function('f') + + +def test_CondSet(): + sin_sols_principal = ConditionSet(x, Eq(sin(x), 0), + Interval(0, 2*pi, False, True)) + assert pi in sin_sols_principal + assert pi/2 not in sin_sols_principal + assert 3*pi not in sin_sols_principal + assert oo not in sin_sols_principal + assert 5 in ConditionSet(x, x**2 > 4, S.Reals) + assert 1 not in ConditionSet(x, x**2 > 4, S.Reals) + # in this case, 0 is not part of the base set so + # it can't be in any subset selected by the condition + assert 0 not in ConditionSet(x, y > 5, Interval(1, 7)) + # since 'in' requires a true/false, the following raises + # an error because the given value provides no information + # for the condition to evaluate (since the condition does + # not depend on the dummy symbol): the result is `y > 5`. + # In this case, ConditionSet is just acting like + # Piecewise((Interval(1, 7), y > 5), (S.EmptySet, True)). + raises(TypeError, lambda: 6 in ConditionSet(x, y > 5, + Interval(1, 7))) + + X = MatrixSymbol('X', 2, 2) + matrix_set = ConditionSet(X, Eq(X*Matrix([[1, 1], [1, 1]]), X)) + Y = Matrix([[0, 0], [0, 0]]) + assert matrix_set.contains(Y).doit() is S.true + Z = Matrix([[1, 2], [3, 4]]) + assert matrix_set.contains(Z).doit() is S.false + + assert isinstance(ConditionSet(x, x < 1, {x, y}).base_set, + FiniteSet) + raises(TypeError, lambda: ConditionSet(x, x + 1, {x, y})) + raises(TypeError, lambda: ConditionSet(x, x, 1)) + + I = S.Integers + U = S.UniversalSet + C = ConditionSet + assert C(x, False, I) is S.EmptySet + assert C(x, True, I) is I + assert C(x, x < 1, C(x, x < 2, I) + ) == C(x, (x < 1) & (x < 2), I) + assert C(y, y < 1, C(x, y < 2, I) + ) == C(x, (x < 1) & (y < 2), I), C(y, y < 1, C(x, y < 2, I)) + assert C(y, y < 1, C(x, x < 2, I) + ) == C(y, (y < 1) & (y < 2), I) + assert C(y, y < 1, C(x, y < x, I) + ) == C(x, (x < 1) & (y < x), I) + assert unchanged(C, y, x < 1, C(x, y < x, I)) + assert ConditionSet(x, x < 1).base_set is U + # arg checking is not done at instantiation but this + # will raise an error when containment is tested + assert ConditionSet((x,), x < 1).base_set is U + + c = ConditionSet((x, y), x < y, I**2) + assert (1, 2) in c + assert (1, pi) not in c + + raises(TypeError, lambda: C(x, x > 1, C((x, y), x > 1, I**2))) + # signature mismatch since only 3 args are accepted + raises(TypeError, lambda: C((x, y), x + y < 2, U, U)) + + +def test_CondSet_intersect(): + input_conditionset = ConditionSet(x, x**2 > 4, Interval(1, 4, False, + False)) + other_domain = Interval(0, 3, False, False) + output_conditionset = ConditionSet(x, x**2 > 4, Interval( + 1, 3, False, False)) + assert Intersection(input_conditionset, other_domain + ) == output_conditionset + + +def test_issue_9849(): + assert ConditionSet(x, Eq(x, x), S.Naturals + ) is S.Naturals + assert ConditionSet(x, Eq(Abs(sin(x)), -1), S.Naturals + ) == S.EmptySet + + +def test_simplified_FiniteSet_in_CondSet(): + assert ConditionSet(x, And(x < 1, x > -3), FiniteSet(0, 1, 2) + ) == FiniteSet(0) + assert ConditionSet(x, x < 0, FiniteSet(0, 1, 2)) == EmptySet + assert ConditionSet(x, And(x < -3), EmptySet) == EmptySet + y = Symbol('y') + assert (ConditionSet(x, And(x > 0), FiniteSet(-1, 0, 1, y)) == + Union(FiniteSet(1), ConditionSet(x, And(x > 0), FiniteSet(y)))) + assert (ConditionSet(x, Eq(Mod(x, 3), 1), FiniteSet(1, 4, 2, y)) == + Union(FiniteSet(1, 4), ConditionSet(x, Eq(Mod(x, 3), 1), + FiniteSet(y)))) + + +def test_free_symbols(): + assert ConditionSet(x, Eq(y, 0), FiniteSet(z) + ).free_symbols == {y, z} + assert ConditionSet(x, Eq(x, 0), FiniteSet(z) + ).free_symbols == {z} + assert ConditionSet(x, Eq(x, 0), FiniteSet(x, z) + ).free_symbols == {x, z} + assert ConditionSet(x, Eq(x, 0), ImageSet(Lambda(y, y**2), + S.Integers)).free_symbols == set() + + +def test_bound_symbols(): + assert ConditionSet(x, Eq(y, 0), FiniteSet(z) + ).bound_symbols == [x] + assert ConditionSet(x, Eq(x, 0), FiniteSet(x, y) + ).bound_symbols == [x] + assert ConditionSet(x, x < 10, ImageSet(Lambda(y, y**2), S.Integers) + ).bound_symbols == [x] + assert ConditionSet(x, x < 10, ConditionSet(y, y > 1, S.Integers) + ).bound_symbols == [x] + + +def test_as_dummy(): + _0, _1 = symbols('_0 _1') + assert ConditionSet(x, x < 1, Interval(y, oo) + ).as_dummy() == ConditionSet(_0, _0 < 1, Interval(y, oo)) + assert ConditionSet(x, x < 1, Interval(x, oo) + ).as_dummy() == ConditionSet(_0, _0 < 1, Interval(x, oo)) + assert ConditionSet(x, x < 1, ImageSet(Lambda(y, y**2), S.Integers) + ).as_dummy() == ConditionSet( + _0, _0 < 1, ImageSet(Lambda(_0, _0**2), S.Integers)) + e = ConditionSet((x, y), x <= y, S.Reals**2) + assert e.bound_symbols == [x, y] + assert e.as_dummy() == ConditionSet((_0, _1), _0 <= _1, S.Reals**2) + assert e.as_dummy() == ConditionSet((y, x), y <= x, S.Reals**2 + ).as_dummy() + + +def test_subs_CondSet(): + s = FiniteSet(z, y) + c = ConditionSet(x, x < 2, s) + assert c.subs(x, y) == c + assert c.subs(z, y) == ConditionSet(x, x < 2, FiniteSet(y)) + assert c.xreplace({x: y}) == ConditionSet(y, y < 2, s) + + assert ConditionSet(x, x < y, s + ).subs(y, w) == ConditionSet(x, x < w, s.subs(y, w)) + # if the user uses assumptions that cause the condition + # to evaluate, that can't be helped from SymPy's end + n = Symbol('n', negative=True) + assert ConditionSet(n, 0 < n, S.Integers) is S.EmptySet + p = Symbol('p', positive=True) + assert ConditionSet(n, n < y, S.Integers + ).subs(n, x) == ConditionSet(n, n < y, S.Integers) + raises(ValueError, lambda: ConditionSet( + x + 1, x < 1, S.Integers)) + assert ConditionSet( + p, n < x, Interval(-5, 5)).subs(x, p) == Interval(-5, 5), ConditionSet( + p, n < x, Interval(-5, 5)).subs(x, p) + assert ConditionSet( + n, n < x, Interval(-oo, 0)).subs(x, p + ) == Interval(-oo, 0) + + assert ConditionSet(f(x), f(x) < 1, {w, z} + ).subs(f(x), y) == ConditionSet(f(x), f(x) < 1, {w, z}) + + # issue 17341 + k = Symbol('k') + img1 = ImageSet(Lambda(k, 2*k*pi + asin(y)), S.Integers) + img2 = ImageSet(Lambda(k, 2*k*pi + asin(S.One/3)), S.Integers) + assert ConditionSet(x, Contains( + y, Interval(-1,1)), img1).subs(y, S.One/3).dummy_eq(img2) + + assert (0, 1) in ConditionSet((x, y), x + y < 3, S.Integers**2) + + raises(TypeError, lambda: ConditionSet(n, n < -10, Interval(0, 10))) + + +def test_subs_CondSet_tebr(): + with warns_deprecated_sympy(): + assert ConditionSet((x, y), {x + 1, x + y}, S.Reals**2) == \ + ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + + +def test_dummy_eq(): + C = ConditionSet + I = S.Integers + c = C(x, x < 1, I) + assert c.dummy_eq(C(y, y < 1, I)) + assert c.dummy_eq(1) == False + assert c.dummy_eq(C(x, x < 1, S.Reals)) == False + + c1 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + c2 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + c3 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Complexes**2) + assert c1.dummy_eq(c2) + assert c1.dummy_eq(c3) is False + assert c.dummy_eq(c1) is False + assert c1.dummy_eq(c) is False + + # issue 19496 + m = Symbol('m') + n = Symbol('n') + a = Symbol('a') + d1 = ImageSet(Lambda(m, m*pi), S.Integers) + d2 = ImageSet(Lambda(n, n*pi), S.Integers) + c1 = ConditionSet(x, Ne(a, 0), d1) + c2 = ConditionSet(x, Ne(a, 0), d2) + assert c1.dummy_eq(c2) + + +def test_contains(): + assert 6 in ConditionSet(x, x > 5, Interval(1, 7)) + assert (8 in ConditionSet(x, y > 5, Interval(1, 7))) is False + # `in` should give True or False; in this case there is not + # enough information for that result + raises(TypeError, + lambda: 6 in ConditionSet(x, y > 5, Interval(1, 7))) + # here, there is enough information but the comparison is + # not defined + raises(TypeError, lambda: 0 in ConditionSet(x, 1/x >= 0, S.Reals)) + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(6) == (y > 5) + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(8) is S.false + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(w) == And(Contains(w, Interval(1, 7)), y > 5) + # This returns an unevaluated Contains object + # because 1/0 should not be defined for 1 and 0 in the context of + # reals. + assert ConditionSet(x, 1/x >= 0, S.Reals).contains(0) == \ + Contains(0, ConditionSet(x, 1/x >= 0, S.Reals), evaluate=False) + c = ConditionSet((x, y), x + y > 1, S.Integers**2) + assert not c.contains(1) + assert c.contains((2, 1)) + assert not c.contains((0, 1)) + c = ConditionSet((w, (x, y)), w + x + y > 1, S.Integers*S.Integers**2) + assert not c.contains(1) + assert not c.contains((1, 2)) + assert not c.contains(((1, 2), 3)) + assert not c.contains(((1, 2), (3, 4))) + assert c.contains((1, (3, 4))) + + +def test_as_relational(): + assert ConditionSet((x, y), x > 1, S.Integers**2).as_relational((x, y) + ) == (x > 1) & Contains(x, S.Integers) & Contains(y, S.Integers) + assert ConditionSet(x, x > 1, S.Integers).as_relational(x + ) == Contains(x, S.Integers) & (x > 1) + + +def test_flatten(): + """Tests whether there is basic denesting functionality""" + inner = ConditionSet(x, sin(x) + x > 0) + outer = ConditionSet(x, Contains(x, inner), S.Reals) + assert outer == ConditionSet(x, sin(x) + x > 0, S.Reals) + + inner = ConditionSet(y, sin(y) + y > 0) + outer = ConditionSet(x, Contains(y, inner), S.Reals) + assert outer != ConditionSet(x, sin(x) + x > 0, S.Reals) + + inner = ConditionSet(x, sin(x) + x > 0).intersect(Interval(-1, 1)) + outer = ConditionSet(x, Contains(x, inner), S.Reals) + assert outer == ConditionSet(x, sin(x) + x > 0, Interval(-1, 1)) + + +def test_duplicate(): + from sympy.core.function import BadSignatureError + # test coverage for line 95 in conditionset.py, check for duplicates in symbols + dup = symbols('a,a') + raises(BadSignatureError, lambda: ConditionSet(dup, x < 0)) + + +def test_SetKind_ConditionSet(): + assert ConditionSet(x, Eq(sin(x), 0), Interval(0, 2*pi)).kind is SetKind(NumberKind) + assert ConditionSet(x, x < 0).kind is SetKind(NumberKind) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_contains.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6b98940946f98bf377aad6810f5b32eb6dd069 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_contains.py @@ -0,0 +1,52 @@ +from sympy.core.expr import unchanged +from sympy.core.numbers import oo +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.sets.contains import Contains +from sympy.sets.sets import (FiniteSet, Interval) +from sympy.testing.pytest import raises + + +def test_contains_basic(): + raises(TypeError, lambda: Contains(S.Integers, 1)) + assert Contains(2, S.Integers) is S.true + assert Contains(-2, S.Naturals) is S.false + + i = Symbol('i', integer=True) + assert Contains(i, S.Naturals) == Contains(i, S.Naturals, evaluate=False) + + +def test_issue_6194(): + x = Symbol('x') + assert unchanged(Contains, x, Interval(0, 1)) + assert Interval(0, 1).contains(x) == (S.Zero <= x) & (x <= 1) + assert Contains(x, FiniteSet(0)) != S.false + assert Contains(x, Interval(1, 1)) != S.false + assert Contains(x, S.Integers) != S.false + + +def test_issue_10326(): + assert Contains(oo, Interval(-oo, oo)) == False + assert Contains(-oo, Interval(-oo, oo)) == False + + +def test_binary_symbols(): + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + assert Contains(x, FiniteSet(y, Eq(z, True)) + ).binary_symbols == {y, z} + + +def test_as_set(): + x = Symbol('x') + y = Symbol('y') + assert Contains(x, FiniteSet(y)).as_set() == FiniteSet(y) + assert Contains(x, S.Integers).as_set() == S.Integers + assert Contains(x, S.Reals).as_set() == S.Reals + + +def test_type_error(): + # Pass in a parameter not of type "set" + raises(TypeError, lambda: Contains(2, None)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_fancysets.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_fancysets.py new file mode 100644 index 0000000000000000000000000000000000000000..b23c2a99fce0af5bfe7c667185465ee417de19ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_fancysets.py @@ -0,0 +1,1313 @@ + +from sympy.core.expr import unchanged +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ImageSet, Range, normalize_theta_set, + ComplexRegion) +from sympy.sets.sets import (FiniteSet, Interval, Union, imageset, + Intersection, ProductSet, SetKind) +from sympy.sets.conditionset import ConditionSet +from sympy.simplify.simplify import simplify +from sympy.core.basic import Basic +from sympy.core.containers import Tuple, TupleKind +from sympy.core.function import Lambda +from sympy.core.kind import NumberKind +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.logic.boolalg import And +from sympy.matrices.dense import eye +from sympy.testing.pytest import XFAIL, raises +from sympy.abc import x, y, t, z +from sympy.core.mod import Mod + +import itertools + + +def test_naturals(): + N = S.Naturals + assert 5 in N + assert -5 not in N + assert 5.5 not in N + ni = iter(N) + a, b, c, d = next(ni), next(ni), next(ni), next(ni) + assert (a, b, c, d) == (1, 2, 3, 4) + assert isinstance(a, Basic) + + assert N.intersect(Interval(-5, 5)) == Range(1, 6) + assert N.intersect(Interval(-5, 5, True, True)) == Range(1, 5) + + assert N.boundary == N + assert N.is_open == False + assert N.is_closed == True + + assert N.inf == 1 + assert N.sup is oo + assert not N.contains(oo) + for s in (S.Naturals0, S.Naturals): + assert s.intersection(S.Reals) is s + assert s.is_subset(S.Reals) + + assert N.as_relational(x) == And(Eq(floor(x), x), x >= 1, x < oo) + + +def test_naturals0(): + N = S.Naturals0 + assert 0 in N + assert -1 not in N + assert next(iter(N)) == 0 + assert not N.contains(oo) + assert N.contains(sin(x)) == Contains(sin(x), N) + + +def test_integers(): + Z = S.Integers + assert 5 in Z + assert -5 in Z + assert 5.5 not in Z + assert not Z.contains(oo) + assert not Z.contains(-oo) + + zi = iter(Z) + a, b, c, d = next(zi), next(zi), next(zi), next(zi) + assert (a, b, c, d) == (0, 1, -1, 2) + assert isinstance(a, Basic) + + assert Z.intersect(Interval(-5, 5)) == Range(-5, 6) + assert Z.intersect(Interval(-5, 5, True, True)) == Range(-4, 5) + assert Z.intersect(Interval(5, S.Infinity)) == Range(5, S.Infinity) + assert Z.intersect(Interval.Lopen(5, S.Infinity)) == Range(6, S.Infinity) + + assert Z.inf is -oo + assert Z.sup is oo + + assert Z.boundary == Z + assert Z.is_open == False + assert Z.is_closed == True + + assert Z.as_relational(x) == And(Eq(floor(x), x), -oo < x, x < oo) + + +def test_ImageSet(): + raises(ValueError, lambda: ImageSet(x, S.Integers)) + assert ImageSet(Lambda(x, 1), S.Integers) == FiniteSet(1) + assert ImageSet(Lambda(x, y), S.Integers) == {y} + assert ImageSet(Lambda(x, 1), S.EmptySet) == S.EmptySet + empty = Intersection(FiniteSet(log(2)/pi), S.Integers) + assert unchanged(ImageSet, Lambda(x, 1), empty) # issue #17471 + squares = ImageSet(Lambda(x, x**2), S.Naturals) + assert 4 in squares + assert 5 not in squares + assert FiniteSet(*range(10)).intersect(squares) == FiniteSet(1, 4, 9) + + assert 16 not in squares.intersect(Interval(0, 10)) + + si = iter(squares) + a, b, c, d = next(si), next(si), next(si), next(si) + assert (a, b, c, d) == (1, 4, 9, 16) + + harmonics = ImageSet(Lambda(x, 1/x), S.Naturals) + assert Rational(1, 5) in harmonics + assert Rational(.25) in harmonics + assert harmonics.contains(.25) == Contains( + 0.25, ImageSet(Lambda(x, 1/x), S.Naturals), evaluate=False) + assert Rational(.3) not in harmonics + assert (1, 2) not in harmonics + + assert harmonics.is_iterable + + assert imageset(x, -x, Interval(0, 1)) == Interval(-1, 0) + + assert ImageSet(Lambda(x, x**2), Interval(0, 2)).doit() == Interval(0, 4) + assert ImageSet(Lambda((x, y), 2*x), {4}, {3}).doit() == FiniteSet(8) + assert (ImageSet(Lambda((x, y), x+y), {1, 2, 3}, {10, 20, 30}).doit() == + FiniteSet(11, 12, 13, 21, 22, 23, 31, 32, 33)) + + c = Interval(1, 3) * Interval(1, 3) + assert Tuple(2, 6) in ImageSet(Lambda(((x, y),), (x, 2*y)), c) + assert Tuple(2, S.Half) in ImageSet(Lambda(((x, y),), (x, 1/y)), c) + assert Tuple(2, -2) not in ImageSet(Lambda(((x, y),), (x, y**2)), c) + assert Tuple(2, -2) in ImageSet(Lambda(((x, y),), (x, -2)), c) + c3 = ProductSet(Interval(3, 7), Interval(8, 11), Interval(5, 9)) + assert Tuple(8, 3, 9) in ImageSet(Lambda(((t, y, x),), (y, t, x)), c3) + assert Tuple(Rational(1, 8), 3, 9) in ImageSet(Lambda(((t, y, x),), (1/y, t, x)), c3) + assert 2/pi not in ImageSet(Lambda(((x, y),), 2/x), c) + assert 2/S(100) not in ImageSet(Lambda(((x, y),), 2/x), c) + assert Rational(2, 3) in ImageSet(Lambda(((x, y),), 2/x), c) + + S1 = imageset(lambda x, y: x + y, S.Integers, S.Naturals) + assert S1.base_pset == ProductSet(S.Integers, S.Naturals) + assert S1.base_sets == (S.Integers, S.Naturals) + + # Passing a set instead of a FiniteSet shouldn't raise + assert unchanged(ImageSet, Lambda(x, x**2), {1, 2, 3}) + + S2 = ImageSet(Lambda(((x, y),), x+y), {(1, 2), (3, 4)}) + assert 3 in S2.doit() + # FIXME: This doesn't yet work: + #assert 3 in S2 + assert S2._contains(3) is None + + raises(TypeError, lambda: ImageSet(Lambda(x, x**2), 1)) + + +def test_image_is_ImageSet(): + assert isinstance(imageset(x, sqrt(sin(x)), Range(5)), ImageSet) + + +def test_halfcircle(): + r, th = symbols('r, theta', real=True) + L = Lambda(((r, th),), (r*cos(th), r*sin(th))) + halfcircle = ImageSet(L, Interval(0, 1)*Interval(0, pi)) + + assert (1, 0) in halfcircle + assert (0, -1) not in halfcircle + assert (0, 0) in halfcircle + assert halfcircle._contains((r, 0)) is None + assert not halfcircle.is_iterable + + +@XFAIL +def test_halfcircle_fail(): + r, th = symbols('r, theta', real=True) + L = Lambda(((r, th),), (r*cos(th), r*sin(th))) + halfcircle = ImageSet(L, Interval(0, 1)*Interval(0, pi)) + assert (r, 2*pi) not in halfcircle + + +def test_ImageSet_iterator_not_injective(): + L = Lambda(x, x - x % 2) # produces 0, 2, 2, 4, 4, 6, 6, ... + evens = ImageSet(L, S.Naturals) + i = iter(evens) + # No repeats here + assert (next(i), next(i), next(i), next(i)) == (0, 2, 4, 6) + + +def test_inf_Range_len(): + raises(ValueError, lambda: len(Range(0, oo, 2))) + assert Range(0, oo, 2).size is S.Infinity + assert Range(0, -oo, -2).size is S.Infinity + assert Range(oo, 0, -2).size is S.Infinity + assert Range(-oo, 0, 2).size is S.Infinity + + +def test_Range_set(): + empty = Range(0) + + assert Range(5) == Range(0, 5) == Range(0, 5, 1) + + r = Range(10, 20, 2) + assert 12 in r + assert 8 not in r + assert 11 not in r + assert 30 not in r + + assert list(Range(0, 5)) == list(range(5)) + assert list(Range(5, 0, -1)) == list(range(5, 0, -1)) + + + assert Range(5, 15).sup == 14 + assert Range(5, 15).inf == 5 + assert Range(15, 5, -1).sup == 15 + assert Range(15, 5, -1).inf == 6 + assert Range(10, 67, 10).sup == 60 + assert Range(60, 7, -10).inf == 10 + + assert len(Range(10, 38, 10)) == 3 + + assert Range(0, 0, 5) == empty + assert Range(oo, oo, 1) == empty + assert Range(oo, 1, 1) == empty + assert Range(-oo, 1, -1) == empty + assert Range(1, oo, -1) == empty + assert Range(1, -oo, 1) == empty + assert Range(1, -4, oo) == empty + ip = symbols('ip', positive=True) + assert Range(0, ip, -1) == empty + assert Range(0, -ip, 1) == empty + assert Range(1, -4, -oo) == Range(1, 2) + assert Range(1, 4, oo) == Range(1, 2) + assert Range(-oo, oo).size == oo + assert Range(oo, -oo, -1).size == oo + raises(ValueError, lambda: Range(-oo, oo, 2)) + raises(ValueError, lambda: Range(x, pi, y)) + raises(ValueError, lambda: Range(x, y, 0)) + + assert 5 in Range(0, oo, 5) + assert -5 in Range(-oo, 0, 5) + assert oo not in Range(0, oo) + ni = symbols('ni', integer=False) + assert ni not in Range(oo) + u = symbols('u', integer=None) + assert Range(oo).contains(u) is not False + inf = symbols('inf', infinite=True) + assert inf not in Range(-oo, oo) + raises(ValueError, lambda: Range(0, oo, 2)[-1]) + raises(ValueError, lambda: Range(0, -oo, -2)[-1]) + assert Range(-oo, 1, 1)[-1] is S.Zero + assert Range(oo, 1, -1)[-1] == 2 + assert inf not in Range(oo) + assert Range(1, 10, 1)[-1] == 9 + assert all(i.is_Integer for i in Range(0, -1, 1)) + it = iter(Range(-oo, 0, 2)) + raises(TypeError, lambda: next(it)) + + assert empty.intersect(S.Integers) == empty + assert Range(-1, 10, 1).intersect(S.Complexes) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Reals) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Rationals) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Integers) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Naturals) == Range(1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Naturals0) == Range(0, 10, 1) + + # test slicing + assert Range(1, 10, 1)[5] == 6 + assert Range(1, 12, 2)[5] == 11 + assert Range(1, 10, 1)[-1] == 9 + assert Range(1, 10, 3)[-1] == 7 + raises(ValueError, lambda: Range(oo,0,-1)[1:3:0]) + raises(ValueError, lambda: Range(oo,0,-1)[:1]) + raises(ValueError, lambda: Range(1, oo)[-2]) + raises(ValueError, lambda: Range(-oo, 1)[2]) + raises(IndexError, lambda: Range(10)[-20]) + raises(IndexError, lambda: Range(10)[20]) + raises(ValueError, lambda: Range(2, -oo, -2)[2:2:0]) + assert Range(2, -oo, -2)[2:2:2] == empty + assert Range(2, -oo, -2)[:2:2] == Range(2, -2, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[:2:2]) + assert Range(-oo, 4, 2)[::-2] == Range(2, -oo, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[::2]) + assert Range(oo, 2, -2)[::] == Range(oo, 2, -2) + assert Range(-oo, 4, 2)[:-2:-2] == Range(2, 0, -4) + assert Range(-oo, 4, 2)[:-2:2] == Range(-oo, 0, 4) + raises(ValueError, lambda: Range(-oo, 4, 2)[:0:-2]) + raises(ValueError, lambda: Range(-oo, 4, 2)[:2:-2]) + assert Range(-oo, 4, 2)[-2::-2] == Range(0, -oo, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[-2:0:-2]) + raises(ValueError, lambda: Range(-oo, 4, 2)[0::2]) + assert Range(oo, 2, -2)[0::] == Range(oo, 2, -2) + raises(ValueError, lambda: Range(-oo, 4, 2)[0:-2:2]) + assert Range(oo, 2, -2)[0:-2:] == Range(oo, 6, -2) + raises(ValueError, lambda: Range(oo, 2, -2)[0:2:]) + raises(ValueError, lambda: Range(-oo, 4, 2)[2::-1]) + assert Range(-oo, 4, 2)[-2::2] == Range(0, 4, 4) + assert Range(oo, 0, -2)[-10:0:2] == empty + raises(ValueError, lambda: Range(oo, 0, -2)[0]) + raises(ValueError, lambda: Range(oo, 0, -2)[-10:10:2]) + raises(ValueError, lambda: Range(oo, 0, -2)[0::-2]) + assert Range(oo, 0, -2)[0:-4:-2] == empty + assert Range(oo, 0, -2)[:0:2] == empty + raises(ValueError, lambda: Range(oo, 0, -2)[:1:-1]) + + # test empty Range + assert Range(x, x, y) == empty + assert empty.reversed == empty + assert 0 not in empty + assert list(empty) == [] + assert len(empty) == 0 + assert empty.size is S.Zero + assert empty.intersect(FiniteSet(0)) is S.EmptySet + assert bool(empty) is False + raises(IndexError, lambda: empty[0]) + assert empty[:0] == empty + raises(NotImplementedError, lambda: empty.inf) + raises(NotImplementedError, lambda: empty.sup) + assert empty.as_relational(x) is S.false + + AB = [None] + list(range(12)) + for R in [ + Range(1, 10), + Range(1, 10, 2), + ]: + r = list(R) + for a, b, c in itertools.product(AB, AB, [-3, -1, None, 1, 3]): + for reverse in range(2): + r = list(reversed(r)) + R = R.reversed + result = list(R[a:b:c]) + ans = r[a:b:c] + txt = ('\n%s[%s:%s:%s] = %s -> %s' % ( + R, a, b, c, result, ans)) + check = ans == result + assert check, txt + + assert Range(1, 10, 1).boundary == Range(1, 10, 1) + + for r in (Range(1, 10, 2), Range(1, oo, 2)): + rev = r.reversed + assert r.inf == rev.inf and r.sup == rev.sup + assert r.step == -rev.step + + builtin_range = range + + raises(TypeError, lambda: Range(builtin_range(1))) + assert S(builtin_range(10)) == Range(10) + assert S(builtin_range(1000000000000)) == Range(1000000000000) + + # test Range.as_relational + assert Range(1, 4).as_relational(x) == (x >= 1) & (x <= 3) & Eq(Mod(x, 1), 0) + assert Range(oo, 1, -2).as_relational(x) == (x >= 3) & (x < oo) & Eq(Mod(x + 1, -2), 0) + + +def test_Range_symbolic(): + # symbolic Range + xr = Range(x, x + 4, 5) + sr = Range(x, y, t) + i = Symbol('i', integer=True) + ip = Symbol('i', integer=True, positive=True) + ipr = Range(ip) + inr = Range(0, -ip, -1) + ir = Range(i, i + 19, 2) + ir2 = Range(i, i*8, 3*i) + i = Symbol('i', integer=True) + inf = symbols('inf', infinite=True) + raises(ValueError, lambda: Range(inf)) + raises(ValueError, lambda: Range(inf, 0, -1)) + raises(ValueError, lambda: Range(inf, inf, 1)) + raises(ValueError, lambda: Range(1, 1, inf)) + # args + assert xr.args == (x, x + 5, 5) + assert sr.args == (x, y, t) + assert ir.args == (i, i + 20, 2) + assert ir2.args == (i, 10*i, 3*i) + # reversed + raises(ValueError, lambda: xr.reversed) + raises(ValueError, lambda: sr.reversed) + assert ipr.reversed.args == (ip - 1, -1, -1) + assert inr.reversed.args == (-ip + 1, 1, 1) + assert ir.reversed.args == (i + 18, i - 2, -2) + assert ir2.reversed.args == (7*i, -2*i, -3*i) + # contains + assert inf not in sr + assert inf not in ir + assert 0 in ipr + assert 0 in inr + raises(TypeError, lambda: 1 in ipr) + raises(TypeError, lambda: -1 in inr) + assert .1 not in sr + assert .1 not in ir + assert i + 1 not in ir + assert i + 2 in ir + raises(TypeError, lambda: x in xr) # XXX is this what contains is supposed to do? + raises(TypeError, lambda: 1 in sr) # XXX is this what contains is supposed to do? + # iter + raises(ValueError, lambda: next(iter(xr))) + raises(ValueError, lambda: next(iter(sr))) + assert next(iter(ir)) == i + assert next(iter(ir2)) == i + assert sr.intersect(S.Integers) == sr + assert sr.intersect(FiniteSet(x)) == Intersection({x}, sr) + raises(ValueError, lambda: sr[:2]) + raises(ValueError, lambda: xr[0]) + raises(ValueError, lambda: sr[0]) + # len + assert len(ir) == ir.size == 10 + assert len(ir2) == ir2.size == 3 + raises(ValueError, lambda: len(xr)) + raises(ValueError, lambda: xr.size) + raises(ValueError, lambda: len(sr)) + raises(ValueError, lambda: sr.size) + # bool + assert bool(Range(0)) == False + assert bool(xr) + assert bool(ir) + assert bool(ipr) + assert bool(inr) + raises(ValueError, lambda: bool(sr)) + raises(ValueError, lambda: bool(ir2)) + # inf + raises(ValueError, lambda: xr.inf) + raises(ValueError, lambda: sr.inf) + assert ipr.inf == 0 + assert inr.inf == -ip + 1 + assert ir.inf == i + raises(ValueError, lambda: ir2.inf) + # sup + raises(ValueError, lambda: xr.sup) + raises(ValueError, lambda: sr.sup) + assert ipr.sup == ip - 1 + assert inr.sup == 0 + assert ir.inf == i + raises(ValueError, lambda: ir2.sup) + # getitem + raises(ValueError, lambda: xr[0]) + raises(ValueError, lambda: sr[0]) + raises(ValueError, lambda: sr[-1]) + raises(ValueError, lambda: sr[:2]) + assert ir[:2] == Range(i, i + 4, 2) + assert ir[0] == i + assert ir[-2] == i + 16 + assert ir[-1] == i + 18 + assert ir2[:2] == Range(i, 7*i, 3*i) + assert ir2[0] == i + assert ir2[-2] == 4*i + assert ir2[-1] == 7*i + raises(ValueError, lambda: Range(i)[-1]) + assert ipr[0] == ipr.inf == 0 + assert ipr[-1] == ipr.sup == ip - 1 + assert inr[0] == inr.sup == 0 + assert inr[-1] == inr.inf == -ip + 1 + raises(ValueError, lambda: ipr[-2]) + assert ir.inf == i + assert ir.sup == i + 18 + raises(ValueError, lambda: Range(i).inf) + # as_relational + assert ir.as_relational(x) == ((x >= i) & (x <= i + 18) & + Eq(Mod(-i + x, 2), 0)) + assert ir2.as_relational(x) == Eq( + Mod(-i + x, 3*i), 0) & (((x >= i) & (x <= 7*i) & (3*i >= 1)) | + ((x <= i) & (x >= 7*i) & (3*i <= -1))) + assert Range(i, i + 1).as_relational(x) == Eq(x, i) + assert sr.as_relational(z) == Eq( + Mod(t, 1), 0) & Eq(Mod(x, 1), 0) & Eq(Mod(-x + z, t), 0 + ) & (((z >= x) & (z <= -t + y) & (t >= 1)) | + ((z <= x) & (z >= -t + y) & (t <= -1))) + assert xr.as_relational(z) == Eq(z, x) & Eq(Mod(x, 1), 0) + # symbols can clash if user wants (but it must be integer) + assert xr.as_relational(x) == Eq(Mod(x, 1), 0) + # contains() for symbolic values (issue #18146) + e = Symbol('e', integer=True, even=True) + o = Symbol('o', integer=True, odd=True) + assert Range(5).contains(i) == And(i >= 0, i <= 4) + assert Range(1).contains(i) == Eq(i, 0) + assert Range(-oo, 5, 1).contains(i) == (i <= 4) + assert Range(-oo, oo).contains(i) == True + assert Range(0, 8, 2).contains(i) == Contains(i, Range(0, 8, 2)) + assert Range(0, 8, 2).contains(e) == And(e >= 0, e <= 6) + assert Range(0, 8, 2).contains(2*i) == And(2*i >= 0, 2*i <= 6) + assert Range(0, 8, 2).contains(o) == False + assert Range(1, 9, 2).contains(e) == False + assert Range(1, 9, 2).contains(o) == And(o >= 1, o <= 7) + assert Range(8, 0, -2).contains(o) == False + assert Range(9, 1, -2).contains(o) == And(o >= 3, o <= 9) + assert Range(-oo, 8, 2).contains(i) == Contains(i, Range(-oo, 8, 2)) + + +def test_range_range_intersection(): + for a, b, r in [ + (Range(0), Range(1), S.EmptySet), + (Range(3), Range(4, oo), S.EmptySet), + (Range(3), Range(-3, -1), S.EmptySet), + (Range(1, 3), Range(0, 3), Range(1, 3)), + (Range(1, 3), Range(1, 4), Range(1, 3)), + (Range(1, oo, 2), Range(2, oo, 2), S.EmptySet), + (Range(0, oo, 2), Range(oo), Range(0, oo, 2)), + (Range(0, oo, 2), Range(100), Range(0, 100, 2)), + (Range(2, oo, 2), Range(oo), Range(2, oo, 2)), + (Range(0, oo, 2), Range(5, 6), S.EmptySet), + (Range(2, 80, 1), Range(55, 71, 4), Range(55, 71, 4)), + (Range(0, 6, 3), Range(-oo, 5, 3), S.EmptySet), + (Range(0, oo, 2), Range(5, oo, 3), Range(8, oo, 6)), + (Range(4, 6, 2), Range(2, 16, 7), S.EmptySet),]: + assert a.intersect(b) == r + assert a.intersect(b.reversed) == r + assert a.reversed.intersect(b) == r + assert a.reversed.intersect(b.reversed) == r + a, b = b, a + assert a.intersect(b) == r + assert a.intersect(b.reversed) == r + assert a.reversed.intersect(b) == r + assert a.reversed.intersect(b.reversed) == r + + +def test_range_interval_intersection(): + p = symbols('p', positive=True) + assert isinstance(Range(3).intersect(Interval(p, p + 2)), Intersection) + assert Range(4).intersect(Interval(0, 3)) == Range(4) + assert Range(4).intersect(Interval(-oo, oo)) == Range(4) + assert Range(4).intersect(Interval(1, oo)) == Range(1, 4) + assert Range(4).intersect(Interval(1.1, oo)) == Range(2, 4) + assert Range(4).intersect(Interval(0.1, 3)) == Range(1, 4) + assert Range(4).intersect(Interval(0.1, 3.1)) == Range(1, 4) + assert Range(4).intersect(Interval.open(0, 3)) == Range(1, 3) + assert Range(4).intersect(Interval.open(0.1, 0.5)) is S.EmptySet + assert Interval(-1, 5).intersect(S.Complexes) == Interval(-1, 5) + assert Interval(-1, 5).intersect(S.Reals) == Interval(-1, 5) + assert Interval(-1, 5).intersect(S.Integers) == Range(-1, 6) + assert Interval(-1, 5).intersect(S.Naturals) == Range(1, 6) + assert Interval(-1, 5).intersect(S.Naturals0) == Range(0, 6) + + # Null Range intersections + assert Range(0).intersect(Interval(0.2, 0.8)) is S.EmptySet + assert Range(0).intersect(Interval(-oo, oo)) is S.EmptySet + + +def test_range_is_finite_set(): + assert Range(-100, 100).is_finite_set is True + assert Range(2, oo).is_finite_set is False + assert Range(-oo, 50).is_finite_set is False + assert Range(-oo, oo).is_finite_set is False + assert Range(oo, -oo).is_finite_set is True + assert Range(0, 0).is_finite_set is True + assert Range(oo, oo).is_finite_set is True + assert Range(-oo, -oo).is_finite_set is True + n = Symbol('n', integer=True) + m = Symbol('m', integer=True) + assert Range(n, n + 49).is_finite_set is True + assert Range(n, 0).is_finite_set is True + assert Range(-3, n + 7).is_finite_set is True + assert Range(n, m).is_finite_set is True + assert Range(n + m, m - n).is_finite_set is True + assert Range(n, n + m + n).is_finite_set is True + assert Range(n, oo).is_finite_set is False + assert Range(-oo, n).is_finite_set is False + assert Range(n, -oo).is_finite_set is True + assert Range(oo, n).is_finite_set is True + + +def test_Range_is_iterable(): + assert Range(-100, 100).is_iterable is True + assert Range(2, oo).is_iterable is False + assert Range(-oo, 50).is_iterable is False + assert Range(-oo, oo).is_iterable is False + assert Range(oo, -oo).is_iterable is True + assert Range(0, 0).is_iterable is True + assert Range(oo, oo).is_iterable is True + assert Range(-oo, -oo).is_iterable is True + n = Symbol('n', integer=True) + m = Symbol('m', integer=True) + p = Symbol('p', integer=True, positive=True) + assert Range(n, n + 49).is_iterable is True + assert Range(n, 0).is_iterable is False + assert Range(-3, n + 7).is_iterable is False + assert Range(-3, p + 7).is_iterable is False # Should work with better __iter__ + assert Range(n, m).is_iterable is False + assert Range(n + m, m - n).is_iterable is False + assert Range(n, n + m + n).is_iterable is False + assert Range(n, oo).is_iterable is False + assert Range(-oo, n).is_iterable is False + x = Symbol('x') + assert Range(x, x + 49).is_iterable is False + assert Range(x, 0).is_iterable is False + assert Range(-3, x + 7).is_iterable is False + assert Range(x, m).is_iterable is False + assert Range(x + m, m - x).is_iterable is False + assert Range(x, x + m + x).is_iterable is False + assert Range(x, oo).is_iterable is False + assert Range(-oo, x).is_iterable is False + + +def test_Integers_eval_imageset(): + ans = ImageSet(Lambda(x, 2*x + Rational(3, 7)), S.Integers) + im = imageset(Lambda(x, -2*x + Rational(3, 7)), S.Integers) + assert im == ans + im = imageset(Lambda(x, -2*x - Rational(11, 7)), S.Integers) + assert im == ans + y = Symbol('y') + L = imageset(x, 2*x + y, S.Integers) + assert y + 4 in L + a, b, c = 0.092, 0.433, 0.341 + assert a in imageset(x, a + c*x, S.Integers) + assert b in imageset(x, b + c*x, S.Integers) + + _x = symbols('x', negative=True) + eq = _x**2 - _x + 1 + assert imageset(_x, eq, S.Integers).lamda.expr == _x**2 + _x + 1 + eq = 3*_x - 1 + assert imageset(_x, eq, S.Integers).lamda.expr == 3*_x + 2 + + assert imageset(x, (x, 1/x), S.Integers) == \ + ImageSet(Lambda(x, (x, 1/x)), S.Integers) + + +def test_Range_eval_imageset(): + a, b, c = symbols('a b c') + assert imageset(x, a*(x + b) + c, Range(3)) == \ + imageset(x, a*x + a*b + c, Range(3)) + eq = (x + 1)**2 + assert imageset(x, eq, Range(3)).lamda.expr == eq + eq = a*(x + b) + c + r = Range(3, -3, -2) + imset = imageset(x, eq, r) + assert imset.lamda.expr != eq + assert list(imset) == [eq.subs(x, i).expand() for i in list(r)] + + +def test_fun(): + assert (FiniteSet(*ImageSet(Lambda(x, sin(pi*x/4)), + Range(-10, 11))) == FiniteSet(-1, -sqrt(2)/2, 0, sqrt(2)/2, 1)) + + +def test_Range_is_empty(): + i = Symbol('i', integer=True) + n = Symbol('n', negative=True, integer=True) + p = Symbol('p', positive=True, integer=True) + + assert Range(0).is_empty + assert not Range(1).is_empty + assert Range(1, 0).is_empty + assert not Range(-1, 0).is_empty + assert Range(i).is_empty is None + assert Range(n).is_empty + assert Range(p).is_empty is False + assert Range(n, 0).is_empty is False + assert Range(n, p).is_empty is False + assert Range(p, n).is_empty + assert Range(n, -1).is_empty is None + assert Range(p, n, -1).is_empty is False + + +def test_Reals(): + assert 5 in S.Reals + assert S.Pi in S.Reals + assert -sqrt(2) in S.Reals + assert (2, 5) not in S.Reals + assert sqrt(-1) not in S.Reals + assert S.Reals == Interval(-oo, oo) + assert S.Reals != Interval(0, oo) + assert S.Reals.is_subset(Interval(-oo, oo)) + assert S.Reals.intersect(Range(-oo, oo)) == Range(-oo, oo) + assert S.ComplexInfinity not in S.Reals + assert S.NaN not in S.Reals + assert x + S.ComplexInfinity not in S.Reals + + +def test_Complex(): + assert 5 in S.Complexes + assert 5 + 4*I in S.Complexes + assert S.Pi in S.Complexes + assert -sqrt(2) in S.Complexes + assert -I in S.Complexes + assert sqrt(-1) in S.Complexes + assert S.Complexes.intersect(S.Reals) == S.Reals + assert S.Complexes.union(S.Reals) == S.Complexes + assert S.Complexes == ComplexRegion(S.Reals*S.Reals) + assert (S.Complexes == ComplexRegion(Interval(1, 2)*Interval(3, 4))) == False + assert str(S.Complexes) == "Complexes" + assert repr(S.Complexes) == "Complexes" + + +def take(n, iterable): + "Return first n items of the iterable as a list" + return list(itertools.islice(iterable, n)) + + +def test_intersections(): + assert S.Integers.intersect(S.Reals) == S.Integers + assert 5 in S.Integers.intersect(S.Reals) + assert 5 in S.Integers.intersect(S.Reals) + assert -5 not in S.Naturals.intersect(S.Reals) + assert 5.5 not in S.Integers.intersect(S.Reals) + assert 5 in S.Integers.intersect(Interval(3, oo)) + assert -5 in S.Integers.intersect(Interval(-oo, 3)) + assert all(x.is_Integer + for x in take(10, S.Integers.intersect(Interval(3, oo)) )) + + +def test_infinitely_indexed_set_1(): + from sympy.abc import n, m + assert imageset(Lambda(n, n), S.Integers) == imageset(Lambda(m, m), S.Integers) + + assert imageset(Lambda(n, 2*n), S.Integers).intersect( + imageset(Lambda(m, 2*m + 1), S.Integers)) is S.EmptySet + + assert imageset(Lambda(n, 2*n), S.Integers).intersect( + imageset(Lambda(n, 2*n + 1), S.Integers)) is S.EmptySet + + assert imageset(Lambda(m, 2*m), S.Integers).intersect( + imageset(Lambda(n, 3*n), S.Integers)).dummy_eq( + ImageSet(Lambda(t, 6*t), S.Integers)) + + assert imageset(x, x/2 + Rational(1, 3), S.Integers).intersect(S.Integers) is S.EmptySet + assert imageset(x, x/2 + S.Half, S.Integers).intersect(S.Integers) is S.Integers + + # https://github.com/sympy/sympy/issues/17355 + S53 = ImageSet(Lambda(n, 5*n + 3), S.Integers) + assert S53.intersect(S.Integers) == S53 + + +def test_infinitely_indexed_set_2(): + from sympy.abc import n + a = Symbol('a', integer=True) + assert imageset(Lambda(n, n), S.Integers) == \ + imageset(Lambda(n, n + a), S.Integers) + assert imageset(Lambda(n, n + pi), S.Integers) == \ + imageset(Lambda(n, n + a + pi), S.Integers) + assert imageset(Lambda(n, n), S.Integers) == \ + imageset(Lambda(n, -n + a), S.Integers) + assert imageset(Lambda(n, -6*n), S.Integers) == \ + ImageSet(Lambda(n, 6*n), S.Integers) + assert imageset(Lambda(n, 2*n + pi), S.Integers) == \ + ImageSet(Lambda(n, 2*n + pi - 2), S.Integers) + + +def test_imageset_intersect_real(): + from sympy.abc import n + assert imageset(Lambda(n, n + (n - 1)*(n + 1)*I), S.Integers).intersect(S.Reals) == FiniteSet(-1, 1) + im = (n - 1)*(n + S.Half) + assert imageset(Lambda(n, n + im*I), S.Integers + ).intersect(S.Reals) == FiniteSet(1) + assert imageset(Lambda(n, n + im*(n + 1)*I), S.Naturals0 + ).intersect(S.Reals) == FiniteSet(1) + assert imageset(Lambda(n, n/2 + im.expand()*I), S.Integers + ).intersect(S.Reals) == ImageSet(Lambda(x, x/2), ConditionSet( + n, Eq(n**2 - n/2 - S(1)/2, 0), S.Integers)) + assert imageset(Lambda(n, n/(1/n - 1) + im*(n + 1)*I), S.Integers + ).intersect(S.Reals) == FiniteSet(S.Half) + assert imageset(Lambda(n, n/(n - 6) + + (n - 3)*(n + 1)*I/(2*n + 2)), S.Integers).intersect( + S.Reals) == FiniteSet(-1) + assert imageset(Lambda(n, n/(n**2 - 9) + + (n - 3)*(n + 1)*I/(2*n + 2)), S.Integers).intersect( + S.Reals) is S.EmptySet + s = ImageSet( + Lambda(n, -I*(I*(2*pi*n - pi/4) + log(Abs(sqrt(-I))))), + S.Integers) + # s is unevaluated, but after intersection the result + # should be canonical + assert s.intersect(S.Reals) == imageset( + Lambda(n, 2*n*pi - pi/4), S.Integers) == ImageSet( + Lambda(n, 2*pi*n + pi*Rational(7, 4)), S.Integers) + + +def test_imageset_intersect_interval(): + from sympy.abc import n + f1 = ImageSet(Lambda(n, n*pi), S.Integers) + f2 = ImageSet(Lambda(n, 2*n), Interval(0, pi)) + f3 = ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers) + # complex expressions + f4 = ImageSet(Lambda(n, n*I*pi), S.Integers) + f5 = ImageSet(Lambda(n, 2*I*n*pi + pi/2), S.Integers) + # non-linear expressions + f6 = ImageSet(Lambda(n, log(n)), S.Integers) + f7 = ImageSet(Lambda(n, n**2), S.Integers) + f8 = ImageSet(Lambda(n, Abs(n)), S.Integers) + f9 = ImageSet(Lambda(n, exp(n)), S.Naturals0) + + assert f1.intersect(Interval(-1, 1)) == FiniteSet(0) + assert f1.intersect(Interval(0, 2*pi, False, True)) == FiniteSet(0, pi) + assert f2.intersect(Interval(1, 2)) == Interval(1, 2) + assert f3.intersect(Interval(-1, 1)) == S.EmptySet + assert f3.intersect(Interval(-5, 5)) == FiniteSet(pi*Rational(-3, 2), pi/2) + assert f4.intersect(Interval(-1, 1)) == FiniteSet(0) + assert f4.intersect(Interval(1, 2)) == S.EmptySet + assert f5.intersect(Interval(0, 1)) == S.EmptySet + assert f6.intersect(Interval(0, 1)) == FiniteSet(S.Zero, log(2)) + assert f7.intersect(Interval(0, 10)) == Intersection(f7, Interval(0, 10)) + assert f8.intersect(Interval(0, 2)) == Intersection(f8, Interval(0, 2)) + assert f9.intersect(Interval(1, 2)) == Intersection(f9, Interval(1, 2)) + + +def test_imageset_intersect_diophantine(): + from sympy.abc import m, n + # Check that same lambda variable for both ImageSets is handled correctly + img1 = ImageSet(Lambda(n, 2*n + 1), S.Integers) + img2 = ImageSet(Lambda(n, 4*n + 1), S.Integers) + assert img1.intersect(img2) == img2 + # Empty solution set returned by diophantine: + assert ImageSet(Lambda(n, 2*n), S.Integers).intersect( + ImageSet(Lambda(n, 2*n + 1), S.Integers)) == S.EmptySet + # Check intersection with S.Integers: + assert ImageSet(Lambda(n, 9/n + 20*n/3), S.Integers).intersect( + S.Integers) == FiniteSet(-61, -23, 23, 61) + # Single solution (2, 3) for diophantine solution: + assert ImageSet(Lambda(n, (n - 2)**2), S.Integers).intersect( + ImageSet(Lambda(n, -(n - 3)**2), S.Integers)) == FiniteSet(0) + # Single parametric solution for diophantine solution: + assert ImageSet(Lambda(n, n**2 + 5), S.Integers).intersect( + ImageSet(Lambda(m, 2*m), S.Integers)).dummy_eq(ImageSet( + Lambda(n, 4*n**2 + 4*n + 6), S.Integers)) + # 4 non-parametric solution couples for dioph. equation: + assert ImageSet(Lambda(n, n**2 - 9), S.Integers).intersect( + ImageSet(Lambda(m, -m**2), S.Integers)) == FiniteSet(-9, 0) + # Double parametric solution for diophantine solution: + assert ImageSet(Lambda(m, m**2 + 40), S.Integers).intersect( + ImageSet(Lambda(n, 41*n), S.Integers)).dummy_eq(Intersection( + ImageSet(Lambda(m, m**2 + 40), S.Integers), + ImageSet(Lambda(n, 41*n), S.Integers))) + # Check that diophantine returns *all* (8) solutions (permute=True) + assert ImageSet(Lambda(n, n**4 - 2**4), S.Integers).intersect( + ImageSet(Lambda(m, -m**4 + 3**4), S.Integers)) == FiniteSet(0, 65) + assert ImageSet(Lambda(n, pi/12 + n*5*pi/12), S.Integers).intersect( + ImageSet(Lambda(n, 7*pi/12 + n*11*pi/12), S.Integers)).dummy_eq(ImageSet( + Lambda(n, 55*pi*n/12 + 17*pi/4), S.Integers)) + # TypeError raised by diophantine (#18081) + assert ImageSet(Lambda(n, n*log(2)), S.Integers).intersection( + S.Integers).dummy_eq(Intersection(ImageSet( + Lambda(n, n*log(2)), S.Integers), S.Integers)) + # NotImplementedError raised by diophantine (no solver for cubic_thue) + assert ImageSet(Lambda(n, n**3 + 1), S.Integers).intersect( + ImageSet(Lambda(n, n**3), S.Integers)).dummy_eq(Intersection( + ImageSet(Lambda(n, n**3 + 1), S.Integers), + ImageSet(Lambda(n, n**3), S.Integers))) + + +def test_infinitely_indexed_set_3(): + from sympy.abc import n, m + assert imageset(Lambda(m, 2*pi*m), S.Integers).intersect( + imageset(Lambda(n, 3*pi*n), S.Integers)).dummy_eq( + ImageSet(Lambda(t, 6*pi*t), S.Integers)) + assert imageset(Lambda(n, 2*n + 1), S.Integers) == \ + imageset(Lambda(n, 2*n - 1), S.Integers) + assert imageset(Lambda(n, 3*n + 2), S.Integers) == \ + imageset(Lambda(n, 3*n - 1), S.Integers) + + +def test_ImageSet_simplification(): + from sympy.abc import n, m + assert imageset(Lambda(n, n), S.Integers) == S.Integers + assert imageset(Lambda(n, sin(n)), + imageset(Lambda(m, tan(m)), S.Integers)) == \ + imageset(Lambda(m, sin(tan(m))), S.Integers) + assert imageset(n, 1 + 2*n, S.Naturals) == Range(3, oo, 2) + assert imageset(n, 1 + 2*n, S.Naturals0) == Range(1, oo, 2) + assert imageset(n, 1 - 2*n, S.Naturals) == Range(-1, -oo, -2) + + +def test_ImageSet_contains(): + assert (2, S.Half) in imageset(x, (x, 1/x), S.Integers) + assert imageset(x, x + I*3, S.Integers).intersection(S.Reals) is S.EmptySet + i = Dummy(integer=True) + q = imageset(x, x + I*y, S.Integers).intersection(S.Reals) + assert q.subs(y, I*i).intersection(S.Integers) is S.Integers + q = imageset(x, x + I*y/x, S.Integers).intersection(S.Reals) + assert q.subs(y, 0) is S.Integers + assert q.subs(y, I*i*x).intersection(S.Integers) is S.Integers + z = cos(1)**2 + sin(1)**2 - 1 + q = imageset(x, x + I*z, S.Integers).intersection(S.Reals) + assert q is not S.EmptySet + + +def test_ComplexRegion_contains(): + r = Symbol('r', real=True) + # contains in ComplexRegion + a = Interval(2, 3) + b = Interval(4, 6) + c = Interval(7, 9) + c1 = ComplexRegion(a*b) + c2 = ComplexRegion(Union(a*b, c*a)) + assert 2.5 + 4.5*I in c1 + assert 2 + 4*I in c1 + assert 3 + 4*I in c1 + assert 8 + 2.5*I in c2 + assert 2.5 + 6.1*I not in c1 + assert 4.5 + 3.2*I not in c1 + assert c1.contains(x) == Contains(x, c1, evaluate=False) + assert c1.contains(r) == False + assert c2.contains(x) == Contains(x, c2, evaluate=False) + assert c2.contains(r) == False + + r1 = Interval(0, 1) + theta1 = Interval(0, 2*S.Pi) + c3 = ComplexRegion(r1*theta1, polar=True) + assert (0.5 + I*6/10) in c3 + assert (S.Half + I*6/10) in c3 + assert (S.Half + .6*I) in c3 + assert (0.5 + .6*I) in c3 + assert I in c3 + assert 1 in c3 + assert 0 in c3 + assert 1 + I not in c3 + assert 1 - I not in c3 + assert c3.contains(x) == Contains(x, c3, evaluate=False) + assert c3.contains(r + 2*I) == Contains( + r + 2*I, c3, evaluate=False) # is in fact False + assert c3.contains(1/(1 + r**2)) == Contains( + 1/(1 + r**2), c3, evaluate=False) # is in fact True + + r2 = Interval(0, 3) + theta2 = Interval(pi, 2*pi, left_open=True) + c4 = ComplexRegion(r2*theta2, polar=True) + assert c4.contains(0) == True + assert c4.contains(2 + I) == False + assert c4.contains(-2 + I) == False + assert c4.contains(-2 - I) == True + assert c4.contains(2 - I) == True + assert c4.contains(-2) == False + assert c4.contains(2) == True + assert c4.contains(x) == Contains(x, c4, evaluate=False) + assert c4.contains(3/(1 + r**2)) == Contains( + 3/(1 + r**2), c4, evaluate=False) # is in fact True + + raises(ValueError, lambda: ComplexRegion(r1*theta1, polar=2)) + + +def test_symbolic_Range(): + n = Symbol('n') + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + raises(ValueError, lambda: Range(n, n+1)[0]) + raises(ValueError, lambda: Range(n).size) + + n = Symbol('n', integer=True) + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + assert Range(n, n+1)[0] == n + raises(ValueError, lambda: Range(n).size) + assert Range(n, n+1).size == 1 + + n = Symbol('n', integer=True, nonnegative=True) + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + assert Range(n+1)[0] == 0 + assert Range(n, n+1)[0] == n + assert Range(n).size == n + assert Range(n+1).size == n+1 + assert Range(n, n+1).size == 1 + + n = Symbol('n', integer=True, positive=True) + assert Range(n)[0] == 0 + assert Range(n, n+1)[0] == n + assert Range(n).size == n + assert Range(n, n+1).size == 1 + + m = Symbol('m', integer=True, positive=True) + + assert Range(n, n+m)[0] == n + assert Range(n, n+m).size == m + assert Range(n, n+1).size == 1 + assert Range(n, n+m, 2).size == floor(m/2) + + m = Symbol('m', integer=True, positive=True, even=True) + assert Range(n, n+m, 2).size == m/2 + + +def test_issue_18400(): + n = Symbol('n', integer=True) + raises(ValueError, lambda: imageset(lambda x: x*2, Range(n))) + + n = Symbol('n', integer=True, positive=True) + # No exception + assert imageset(lambda x: x*2, Range(n)) == imageset(lambda x: x*2, Range(n)) + + +def test_ComplexRegion_intersect(): + # Polar form + X_axis = ComplexRegion(Interval(0, oo)*FiniteSet(0, S.Pi), polar=True) + + unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + upper_half_unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + upper_half_disk = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi), polar=True) + lower_half_disk = ComplexRegion(Interval(0, oo)*Interval(S.Pi, 2*S.Pi), polar=True) + right_half_disk = ComplexRegion(Interval(0, oo)*Interval(-S.Pi/2, S.Pi/2), polar=True) + first_quad_disk = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi/2), polar=True) + + assert upper_half_disk.intersect(unit_disk) == upper_half_unit_disk + assert right_half_disk.intersect(first_quad_disk) == first_quad_disk + assert upper_half_disk.intersect(right_half_disk) == first_quad_disk + assert upper_half_disk.intersect(lower_half_disk) == X_axis + + c1 = ComplexRegion(Interval(0, 4)*Interval(0, 2*S.Pi), polar=True) + assert c1.intersect(Interval(1, 5)) == Interval(1, 4) + assert c1.intersect(Interval(4, 9)) == FiniteSet(4) + assert c1.intersect(Interval(5, 12)) is S.EmptySet + + # Rectangular form + X_axis = ComplexRegion(Interval(-oo, oo)*FiniteSet(0)) + + unit_square = ComplexRegion(Interval(-1, 1)*Interval(-1, 1)) + upper_half_unit_square = ComplexRegion(Interval(-1, 1)*Interval(0, 1)) + upper_half_plane = ComplexRegion(Interval(-oo, oo)*Interval(0, oo)) + lower_half_plane = ComplexRegion(Interval(-oo, oo)*Interval(-oo, 0)) + right_half_plane = ComplexRegion(Interval(0, oo)*Interval(-oo, oo)) + first_quad_plane = ComplexRegion(Interval(0, oo)*Interval(0, oo)) + + assert upper_half_plane.intersect(unit_square) == upper_half_unit_square + assert right_half_plane.intersect(first_quad_plane) == first_quad_plane + assert upper_half_plane.intersect(right_half_plane) == first_quad_plane + assert upper_half_plane.intersect(lower_half_plane) == X_axis + + c1 = ComplexRegion(Interval(-5, 5)*Interval(-10, 10)) + assert c1.intersect(Interval(2, 7)) == Interval(2, 5) + assert c1.intersect(Interval(5, 7)) == FiniteSet(5) + assert c1.intersect(Interval(6, 9)) is S.EmptySet + + # unevaluated object + C1 = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + C2 = ComplexRegion(Interval(-1, 1)*Interval(-1, 1)) + assert C1.intersect(C2) == Intersection(C1, C2, evaluate=False) + + +def test_ComplexRegion_union(): + # Polar form + c1 = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + c2 = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + c3 = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi), polar=True) + c4 = ComplexRegion(Interval(0, oo)*Interval(S.Pi, 2*S.Pi), polar=True) + + p1 = Union(Interval(0, 1)*Interval(0, 2*S.Pi), Interval(0, 1)*Interval(0, S.Pi)) + p2 = Union(Interval(0, oo)*Interval(0, S.Pi), Interval(0, oo)*Interval(S.Pi, 2*S.Pi)) + + assert c1.union(c2) == ComplexRegion(p1, polar=True) + assert c3.union(c4) == ComplexRegion(p2, polar=True) + + # Rectangular form + c5 = ComplexRegion(Interval(2, 5)*Interval(6, 9)) + c6 = ComplexRegion(Interval(4, 6)*Interval(10, 12)) + c7 = ComplexRegion(Interval(0, 10)*Interval(-10, 0)) + c8 = ComplexRegion(Interval(12, 16)*Interval(14, 20)) + + p3 = Union(Interval(2, 5)*Interval(6, 9), Interval(4, 6)*Interval(10, 12)) + p4 = Union(Interval(0, 10)*Interval(-10, 0), Interval(12, 16)*Interval(14, 20)) + + assert c5.union(c6) == ComplexRegion(p3) + assert c7.union(c8) == ComplexRegion(p4) + + assert c1.union(Interval(2, 4)) == Union(c1, Interval(2, 4), evaluate=False) + assert c5.union(Interval(2, 4)) == Union(c5, ComplexRegion.from_real(Interval(2, 4))) + + +def test_ComplexRegion_from_real(): + c1 = ComplexRegion(Interval(0, 1) * Interval(0, 2 * S.Pi), polar=True) + + raises(ValueError, lambda: c1.from_real(c1)) + assert c1.from_real(Interval(-1, 1)) == ComplexRegion(Interval(-1, 1) * FiniteSet(0), False) + + +def test_ComplexRegion_measure(): + a, b = Interval(2, 5), Interval(4, 8) + theta1, theta2 = Interval(0, 2*S.Pi), Interval(0, S.Pi) + c1 = ComplexRegion(a*b) + c2 = ComplexRegion(Union(a*theta1, b*theta2), polar=True) + + assert c1.measure == 12 + assert c2.measure == 9*pi + + +def test_normalize_theta_set(): + # Interval + assert normalize_theta_set(Interval(pi, 2*pi)) == \ + Union(FiniteSet(0), Interval.Ropen(pi, 2*pi)) + assert normalize_theta_set(Interval(pi*Rational(9, 2), 5*pi)) == Interval(pi/2, pi) + assert normalize_theta_set(Interval(pi*Rational(-3, 2), pi/2)) == Interval.Ropen(0, 2*pi) + assert normalize_theta_set(Interval.open(pi*Rational(-3, 2), pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi)) + assert normalize_theta_set(Interval.open(pi*Rational(-7, 2), pi*Rational(-3, 2))) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi)) + assert normalize_theta_set(Interval(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.open(-pi/2, pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval(-4*pi, 3*pi)) == Interval.Ropen(0, 2*pi) + assert normalize_theta_set(Interval(pi*Rational(-3, 2), -pi/2)) == Interval(pi/2, pi*Rational(3, 2)) + assert normalize_theta_set(Interval.open(0, 2*pi)) == Interval.open(0, 2*pi) + assert normalize_theta_set(Interval.Ropen(-pi/2, pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.Lopen(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.open(4*pi, pi*Rational(9, 2))) == Interval.open(0, pi/2) + assert normalize_theta_set(Interval.Lopen(4*pi, pi*Rational(9, 2))) == Interval.Lopen(0, pi/2) + assert normalize_theta_set(Interval.Ropen(4*pi, pi*Rational(9, 2))) == Interval.Ropen(0, pi/2) + assert normalize_theta_set(Interval.open(3*pi, 5*pi)) == \ + Union(Interval.Ropen(0, pi), Interval.open(pi, 2*pi)) + + # FiniteSet + assert normalize_theta_set(FiniteSet(0, pi, 3*pi)) == FiniteSet(0, pi) + assert normalize_theta_set(FiniteSet(0, pi/2, pi, 2*pi)) == FiniteSet(0, pi/2, pi) + assert normalize_theta_set(FiniteSet(0, -pi/2, -pi, -2*pi)) == FiniteSet(0, pi, pi*Rational(3, 2)) + assert normalize_theta_set(FiniteSet(pi*Rational(-3, 2), pi/2)) == \ + FiniteSet(pi/2) + assert normalize_theta_set(FiniteSet(2*pi)) == FiniteSet(0) + + # Unions + assert normalize_theta_set(Union(Interval(0, pi/3), Interval(pi/2, pi))) == \ + Union(Interval(0, pi/3), Interval(pi/2, pi)) + assert normalize_theta_set(Union(Interval(0, pi), Interval(2*pi, pi*Rational(7, 3)))) == \ + Interval(0, pi) + + # ValueError for non-real sets + raises(ValueError, lambda: normalize_theta_set(S.Complexes)) + + # NotImplementedError for subset of reals + raises(NotImplementedError, lambda: normalize_theta_set(Interval(0, 1))) + + # NotImplementedError without pi as coefficient + raises(NotImplementedError, lambda: normalize_theta_set(Interval(1, 2*pi))) + raises(NotImplementedError, lambda: normalize_theta_set(Interval(2*pi, 10))) + raises(NotImplementedError, lambda: normalize_theta_set(FiniteSet(0, 3, 3*pi))) + + +def test_ComplexRegion_FiniteSet(): + x, y, z, a, b, c = symbols('x y z a b c') + + # Issue #9669 + assert ComplexRegion(FiniteSet(a, b, c)*FiniteSet(x, y, z)) == \ + FiniteSet(a + I*x, a + I*y, a + I*z, b + I*x, b + I*y, + b + I*z, c + I*x, c + I*y, c + I*z) + assert ComplexRegion(FiniteSet(2)*FiniteSet(3)) == FiniteSet(2 + 3*I) + + +def test_union_RealSubSet(): + assert (S.Complexes).union(Interval(1, 2)) == S.Complexes + assert (S.Complexes).union(S.Integers) == S.Complexes + + +def test_SetKind_fancySet(): + G = lambda *args: ImageSet(Lambda(x, x ** 2), *args) + assert G(Interval(1, 4)).kind is SetKind(NumberKind) + assert G(FiniteSet(1, 4)).kind is SetKind(NumberKind) + assert S.Rationals.kind is SetKind(NumberKind) + assert S.Naturals.kind is SetKind(NumberKind) + assert S.Integers.kind is SetKind(NumberKind) + assert Range(3).kind is SetKind(NumberKind) + a = Interval(2, 3) + b = Interval(4, 6) + c1 = ComplexRegion(a*b) + assert c1.kind is SetKind(TupleKind(NumberKind, NumberKind)) + + +def test_issue_9980(): + c1 = ComplexRegion(Interval(1, 2)*Interval(2, 3)) + c2 = ComplexRegion(Interval(1, 5)*Interval(1, 3)) + R = Union(c1, c2) + assert simplify(R) == ComplexRegion(Union(Interval(1, 2)*Interval(2, 3), \ + Interval(1, 5)*Interval(1, 3)), False) + assert c1.func(*c1.args) == c1 + assert R.func(*R.args) == R + + +def test_issue_11732(): + interval12 = Interval(1, 2) + finiteset1234 = FiniteSet(1, 2, 3, 4) + pointComplex = Tuple(1, 5) + + assert (interval12 in S.Naturals) == False + assert (interval12 in S.Naturals0) == False + assert (interval12 in S.Integers) == False + assert (interval12 in S.Complexes) == False + + assert (finiteset1234 in S.Naturals) == False + assert (finiteset1234 in S.Naturals0) == False + assert (finiteset1234 in S.Integers) == False + assert (finiteset1234 in S.Complexes) == False + + assert (pointComplex in S.Naturals) == False + assert (pointComplex in S.Naturals0) == False + assert (pointComplex in S.Integers) == False + assert (pointComplex in S.Complexes) == True + + +def test_issue_11730(): + unit = Interval(0, 1) + square = ComplexRegion(unit ** 2) + + assert Union(S.Complexes, FiniteSet(oo)) != S.Complexes + assert Union(S.Complexes, FiniteSet(eye(4))) != S.Complexes + assert Union(unit, square) == square + assert Intersection(S.Reals, square) == unit + + +def test_issue_11938(): + unit = Interval(0, 1) + ival = Interval(1, 2) + cr1 = ComplexRegion(ival * unit) + + assert Intersection(cr1, S.Reals) == ival + assert Intersection(cr1, unit) == FiniteSet(1) + + arg1 = Interval(0, S.Pi) + arg2 = FiniteSet(S.Pi) + arg3 = Interval(S.Pi / 4, 3 * S.Pi / 4) + cp1 = ComplexRegion(unit * arg1, polar=True) + cp2 = ComplexRegion(unit * arg2, polar=True) + cp3 = ComplexRegion(unit * arg3, polar=True) + + assert Intersection(cp1, S.Reals) == Interval(-1, 1) + assert Intersection(cp2, S.Reals) == Interval(-1, 0) + assert Intersection(cp3, S.Reals) == FiniteSet(0) + + +def test_issue_11914(): + a, b = Interval(0, 1), Interval(0, pi) + c, d = Interval(2, 3), Interval(pi, 3 * pi / 2) + cp1 = ComplexRegion(a * b, polar=True) + cp2 = ComplexRegion(c * d, polar=True) + + assert -3 in cp1.union(cp2) + assert -3 in cp2.union(cp1) + assert -5 not in cp1.union(cp2) + + +def test_issue_9543(): + assert ImageSet(Lambda(x, x**2), S.Naturals).is_subset(S.Reals) + + +def test_issue_16871(): + assert ImageSet(Lambda(x, x), FiniteSet(1)) == {1} + assert ImageSet(Lambda(x, x - 3), S.Integers + ).intersection(S.Integers) is S.Integers + + +@XFAIL +def test_issue_16871b(): + assert ImageSet(Lambda(x, x - 3), S.Integers).is_subset(S.Integers) + + +def test_issue_18050(): + assert imageset(Lambda(x, I*x + 1), S.Integers + ) == ImageSet(Lambda(x, I*x + 1), S.Integers) + assert imageset(Lambda(x, 3*I*x + 4 + 8*I), S.Integers + ) == ImageSet(Lambda(x, 3*I*x + 4 + 2*I), S.Integers) + # no 'Mod' for next 2 tests: + assert imageset(Lambda(x, 2*x + 3*I), S.Integers + ) == ImageSet(Lambda(x, 2*x + 3*I), S.Integers) + r = Symbol('r', positive=True) + assert imageset(Lambda(x, r*x + 10), S.Integers + ) == ImageSet(Lambda(x, r*x + 10), S.Integers) + # reduce real part: + assert imageset(Lambda(x, 3*x + 8 + 5*I), S.Integers + ) == ImageSet(Lambda(x, 3*x + 2 + 5*I), S.Integers) + + +def test_Rationals(): + assert S.Integers.is_subset(S.Rationals) + assert S.Naturals.is_subset(S.Rationals) + assert S.Naturals0.is_subset(S.Rationals) + assert S.Rationals.is_subset(S.Reals) + assert S.Rationals.inf is -oo + assert S.Rationals.sup is oo + it = iter(S.Rationals) + assert [next(it) for i in range(12)] == [ + 0, 1, -1, S.Half, 2, Rational(-1, 2), -2, + Rational(1, 3), 3, Rational(-1, 3), -3, Rational(2, 3)] + assert Basic() not in S.Rationals + assert S.Half in S.Rationals + assert S.Rationals.contains(0.5) == Contains( + 0.5, S.Rationals, evaluate=False) + assert 2 in S.Rationals + r = symbols('r', rational=True) + assert r in S.Rationals + raises(TypeError, lambda: x in S.Rationals) + # issue #18134: + assert S.Rationals.boundary == S.Reals + assert S.Rationals.closure == S.Reals + assert S.Rationals.is_open == False + assert S.Rationals.is_closed == False + + +def test_NZQRC_unions(): + # check that all trivial number set unions are simplified: + nbrsets = (S.Naturals, S.Naturals0, S.Integers, S.Rationals, + S.Reals, S.Complexes) + unions = (Union(a, b) for a in nbrsets for b in nbrsets) + assert all(u.is_Union is False for u in unions) + + +def test_imageset_intersection(): + n = Dummy() + s = ImageSet(Lambda(n, -I*(I*(2*pi*n - pi/4) + + log(Abs(sqrt(-I))))), S.Integers) + assert s.intersect(S.Reals) == ImageSet( + Lambda(n, 2*pi*n + pi*Rational(7, 4)), S.Integers) + + +def test_issue_17858(): + assert 1 in Range(-oo, oo) + assert 0 in Range(oo, -oo, -1) + assert oo not in Range(-oo, oo) + assert -oo not in Range(-oo, oo) + +def test_issue_17859(): + r = Range(-oo,oo) + raises(ValueError,lambda: r[::2]) + raises(ValueError, lambda: r[::-2]) + r = Range(oo,-oo,-1) + raises(ValueError,lambda: r[::2]) + raises(ValueError, lambda: r[::-2]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_ordinals.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_ordinals.py new file mode 100644 index 0000000000000000000000000000000000000000..973ca329586f3e904f9377c44022c266f81c805c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_ordinals.py @@ -0,0 +1,67 @@ +from sympy.sets.ordinals import Ordinal, OmegaPower, ord0, omega +from sympy.testing.pytest import raises + +def test_string_ordinals(): + assert str(omega) == 'w' + assert str(Ordinal(OmegaPower(5, 3), OmegaPower(3, 2))) == 'w**5*3 + w**3*2' + assert str(Ordinal(OmegaPower(5, 3), OmegaPower(0, 5))) == 'w**5*3 + 5' + assert str(Ordinal(OmegaPower(1, 3), OmegaPower(0, 5))) == 'w*3 + 5' + assert str(Ordinal(OmegaPower(omega + 1, 1), OmegaPower(3, 2))) == 'w**(w + 1) + w**3*2' + +def test_addition_with_integers(): + assert 3 + Ordinal(OmegaPower(5, 3)) == Ordinal(OmegaPower(5, 3)) + assert Ordinal(OmegaPower(5, 3))+3 == Ordinal(OmegaPower(5, 3), OmegaPower(0, 3)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(0, 2))+3 == \ + Ordinal(OmegaPower(5, 3), OmegaPower(0, 5)) + + +def test_addition_with_ordinals(): + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + Ordinal(OmegaPower(3, 3)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(3, 5)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + Ordinal(OmegaPower(4, 2)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(4, 2)) + assert Ordinal(OmegaPower(omega, 2), OmegaPower(3, 2)) + Ordinal(OmegaPower(4, 2)) == \ + Ordinal(OmegaPower(omega, 2), OmegaPower(4, 2)) + +def test_comparison(): + assert Ordinal(OmegaPower(5, 3)) > Ordinal(OmegaPower(4, 3), OmegaPower(2, 1)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) < Ordinal(OmegaPower(5, 4)) + assert Ordinal(OmegaPower(5, 4)) < Ordinal(OmegaPower(5, 5), OmegaPower(4, 1)) + + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + assert not Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) == Ordinal(OmegaPower(5, 3)) + assert Ordinal(OmegaPower(omega, 3)) > Ordinal(OmegaPower(5, 3)) + +def test_multiplication_with_integers(): + w = omega + assert 3*w == w + assert w*9 == Ordinal(OmegaPower(1, 9)) + +def test_multiplication(): + w = omega + assert w*(w + 1) == w*w + w + assert (w + 1)*(w + 1) == w*w + w + 1 + assert w*1 == w + assert 1*w == w + assert w*ord0 == ord0 + assert ord0*w == ord0 + assert w**w == w * w**w + assert (w**w)*w*w == w**(w + 2) + +def test_exponentiation(): + w = omega + assert w**2 == w*w + assert w**3 == w*w*w + assert w**(w + 1) == Ordinal(OmegaPower(omega + 1, 1)) + assert (w**w)*(w**w) == w**(w*2) + +def test_comapre_not_instance(): + w = OmegaPower(omega + 1, 1) + assert(not (w == None)) + assert(not (w < 5)) + raises(TypeError, lambda: w < 6.66) + +def test_is_successort(): + w = Ordinal(OmegaPower(5, 1)) + assert not w.is_successor_ordinal diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_powerset.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_powerset.py new file mode 100644 index 0000000000000000000000000000000000000000..2e3a407d565f6b9537a296af103ec0a4e137cff9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_powerset.py @@ -0,0 +1,141 @@ +from sympy.core.expr import unchanged +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Interval +from sympy.sets.powerset import PowerSet +from sympy.sets.sets import FiniteSet +from sympy.testing.pytest import raises, XFAIL + + +def test_powerset_creation(): + assert unchanged(PowerSet, FiniteSet(1, 2)) + assert unchanged(PowerSet, S.EmptySet) + raises(ValueError, lambda: PowerSet(123)) + assert unchanged(PowerSet, S.Reals) + assert unchanged(PowerSet, S.Integers) + + +def test_powerset_rewrite_FiniteSet(): + assert PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) == \ + FiniteSet(S.EmptySet, FiniteSet(1), FiniteSet(2), FiniteSet(1, 2)) + assert PowerSet(S.EmptySet).rewrite(FiniteSet) == FiniteSet(S.EmptySet) + assert PowerSet(S.Naturals).rewrite(FiniteSet) == PowerSet(S.Naturals) + + +def test_finiteset_rewrite_powerset(): + assert FiniteSet(S.EmptySet).rewrite(PowerSet) == PowerSet(S.EmptySet) + assert FiniteSet( + S.EmptySet, FiniteSet(1), + FiniteSet(2), FiniteSet(1, 2)).rewrite(PowerSet) == \ + PowerSet(FiniteSet(1, 2)) + assert FiniteSet(1, 2, 3).rewrite(PowerSet) == FiniteSet(1, 2, 3) + + +def test_powerset__contains__(): + subset_series = [ + S.EmptySet, + FiniteSet(1, 2), + S.Naturals, + S.Naturals0, + S.Integers, + S.Rationals, + S.Reals, + S.Complexes] + + l = len(subset_series) + for i in range(l): + for j in range(l): + if i <= j: + assert subset_series[i] in \ + PowerSet(subset_series[j], evaluate=False) + else: + assert subset_series[i] not in \ + PowerSet(subset_series[j], evaluate=False) + + +@XFAIL +def test_failing_powerset__contains__(): + # XXX These are failing when evaluate=True, + # but using unevaluated PowerSet works fine. + assert FiniteSet(1, 2) not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Naturals0 not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals0 not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Integers not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Integers not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Rationals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Rationals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Reals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Reals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Complexes not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Complexes not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + + +def test_powerset__len__(): + A = PowerSet(S.EmptySet, evaluate=False) + assert len(A) == 1 + A = PowerSet(A, evaluate=False) + assert len(A) == 2 + A = PowerSet(A, evaluate=False) + assert len(A) == 4 + A = PowerSet(A, evaluate=False) + assert len(A) == 16 + + +def test_powerset__iter__(): + a = PowerSet(FiniteSet(1, 2)).__iter__() + assert next(a) == S.EmptySet + assert next(a) == FiniteSet(1) + assert next(a) == FiniteSet(2) + assert next(a) == FiniteSet(1, 2) + + a = PowerSet(S.Naturals).__iter__() + assert next(a) == S.EmptySet + assert next(a) == FiniteSet(1) + assert next(a) == FiniteSet(2) + assert next(a) == FiniteSet(1, 2) + assert next(a) == FiniteSet(3) + assert next(a) == FiniteSet(1, 3) + assert next(a) == FiniteSet(2, 3) + assert next(a) == FiniteSet(1, 2, 3) + + +def test_powerset_contains(): + A = PowerSet(FiniteSet(1), evaluate=False) + assert A.contains(2) == Contains(2, A) + + x = Symbol('x') + + A = PowerSet(FiniteSet(x), evaluate=False) + assert A.contains(FiniteSet(1)) == Contains(FiniteSet(1), A) + + +def test_powerset_method(): + # EmptySet + A = FiniteSet() + pset = A.powerset() + assert len(pset) == 1 + assert pset == FiniteSet(S.EmptySet) + + # FiniteSets + A = FiniteSet(1, 2) + pset = A.powerset() + assert len(pset) == 2**len(A) + assert pset == FiniteSet(FiniteSet(), FiniteSet(1), + FiniteSet(2), A) + # Not finite sets + A = Interval(0, 1) + assert A.powerset() == PowerSet(A) + +def test_is_subset(): + # covers line 101-102 + # initialize powerset(1), which is a subset of powerset(1,2) + subset = PowerSet(FiniteSet(1)) + pset = PowerSet(FiniteSet(1, 2)) + bad_set = PowerSet(FiniteSet(2, 3)) + # assert "subset" is subset of pset == True + assert subset.is_subset(pset) + # assert "bad_set" is subset of pset == False + assert not pset.is_subset(bad_set) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_setexpr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_setexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..faab1261c8d3e86901b04d30e8bc94de31642b93 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_setexpr.py @@ -0,0 +1,317 @@ +from sympy.sets.setexpr import SetExpr +from sympy.sets import Interval, FiniteSet, Intersection, ImageSet, Union + +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.numbers import (I, Rational, oo) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.trigonometric import cos +from sympy.sets.sets import Set + + +a, x = symbols("a, x") +_d = Dummy("d") + + +def test_setexpr(): + se = SetExpr(Interval(0, 1)) + assert isinstance(se.set, Set) + assert isinstance(se, Expr) + + +def test_scalar_funcs(): + assert SetExpr(Interval(0, 1)).set == Interval(0, 1) + a, b = Symbol('a', real=True), Symbol('b', real=True) + a, b = 1, 2 + # TODO: add support for more functions in the future: + for f in [exp, log]: + input_se = f(SetExpr(Interval(a, b))) + output = input_se.set + expected = Interval(Min(f(a), f(b)), Max(f(a), f(b))) + assert output == expected + + +def test_Add_Mul(): + assert (SetExpr(Interval(0, 1)) + 1).set == Interval(1, 2) + assert (SetExpr(Interval(0, 1))*2).set == Interval(0, 2) + + +def test_Pow(): + assert (SetExpr(Interval(0, 2))**2).set == Interval(0, 4) + + +def test_compound(): + assert (exp(SetExpr(Interval(0, 1))*2 + 1)).set == \ + Interval(exp(1), exp(3)) + + +def test_Interval_Interval(): + assert (SetExpr(Interval(1, 2)) + SetExpr(Interval(10, 20))).set == \ + Interval(11, 22) + assert (SetExpr(Interval(1, 2))*SetExpr(Interval(10, 20))).set == \ + Interval(10, 40) + + +def test_FiniteSet_FiniteSet(): + assert (SetExpr(FiniteSet(1, 2, 3)) + SetExpr(FiniteSet(1, 2))).set == \ + FiniteSet(2, 3, 4, 5) + assert (SetExpr(FiniteSet(1, 2, 3))*SetExpr(FiniteSet(1, 2))).set == \ + FiniteSet(1, 2, 3, 4, 6) + + +def test_Interval_FiniteSet(): + assert (SetExpr(FiniteSet(1, 2)) + SetExpr(Interval(0, 10))).set == \ + Interval(1, 12) + + +def test_Many_Sets(): + assert (SetExpr(Interval(0, 1)) + + SetExpr(Interval(2, 3)) + + SetExpr(FiniteSet(10, 11, 12))).set == Interval(12, 16) + + +def test_same_setexprs_are_not_identical(): + a = SetExpr(FiniteSet(0, 1)) + b = SetExpr(FiniteSet(0, 1)) + assert (a + b).set == FiniteSet(0, 1, 2) + + # Cannot detect the set being the same: + # assert (a + a).set == FiniteSet(0, 2) + + +def test_Interval_arithmetic(): + i12cc = SetExpr(Interval(1, 2)) + i12lo = SetExpr(Interval.Lopen(1, 2)) + i12ro = SetExpr(Interval.Ropen(1, 2)) + i12o = SetExpr(Interval.open(1, 2)) + + n23cc = SetExpr(Interval(-2, 3)) + n23lo = SetExpr(Interval.Lopen(-2, 3)) + n23ro = SetExpr(Interval.Ropen(-2, 3)) + n23o = SetExpr(Interval.open(-2, 3)) + + n3n2cc = SetExpr(Interval(-3, -2)) + + assert i12cc + i12cc == SetExpr(Interval(2, 4)) + assert i12cc - i12cc == SetExpr(Interval(-1, 1)) + assert i12cc*i12cc == SetExpr(Interval(1, 4)) + assert i12cc/i12cc == SetExpr(Interval(S.Half, 2)) + assert i12cc**2 == SetExpr(Interval(1, 4)) + assert i12cc**3 == SetExpr(Interval(1, 8)) + + assert i12lo + i12ro == SetExpr(Interval.open(2, 4)) + assert i12lo - i12ro == SetExpr(Interval.Lopen(-1, 1)) + assert i12lo*i12ro == SetExpr(Interval.open(1, 4)) + assert i12lo/i12ro == SetExpr(Interval.Lopen(S.Half, 2)) + assert i12lo + i12lo == SetExpr(Interval.Lopen(2, 4)) + assert i12lo - i12lo == SetExpr(Interval.open(-1, 1)) + assert i12lo*i12lo == SetExpr(Interval.Lopen(1, 4)) + assert i12lo/i12lo == SetExpr(Interval.open(S.Half, 2)) + assert i12lo + i12cc == SetExpr(Interval.Lopen(2, 4)) + assert i12lo - i12cc == SetExpr(Interval.Lopen(-1, 1)) + assert i12lo*i12cc == SetExpr(Interval.Lopen(1, 4)) + assert i12lo/i12cc == SetExpr(Interval.Lopen(S.Half, 2)) + assert i12lo + i12o == SetExpr(Interval.open(2, 4)) + assert i12lo - i12o == SetExpr(Interval.open(-1, 1)) + assert i12lo*i12o == SetExpr(Interval.open(1, 4)) + assert i12lo/i12o == SetExpr(Interval.open(S.Half, 2)) + assert i12lo**2 == SetExpr(Interval.Lopen(1, 4)) + assert i12lo**3 == SetExpr(Interval.Lopen(1, 8)) + + assert i12ro + i12ro == SetExpr(Interval.Ropen(2, 4)) + assert i12ro - i12ro == SetExpr(Interval.open(-1, 1)) + assert i12ro*i12ro == SetExpr(Interval.Ropen(1, 4)) + assert i12ro/i12ro == SetExpr(Interval.open(S.Half, 2)) + assert i12ro + i12cc == SetExpr(Interval.Ropen(2, 4)) + assert i12ro - i12cc == SetExpr(Interval.Ropen(-1, 1)) + assert i12ro*i12cc == SetExpr(Interval.Ropen(1, 4)) + assert i12ro/i12cc == SetExpr(Interval.Ropen(S.Half, 2)) + assert i12ro + i12o == SetExpr(Interval.open(2, 4)) + assert i12ro - i12o == SetExpr(Interval.open(-1, 1)) + assert i12ro*i12o == SetExpr(Interval.open(1, 4)) + assert i12ro/i12o == SetExpr(Interval.open(S.Half, 2)) + assert i12ro**2 == SetExpr(Interval.Ropen(1, 4)) + assert i12ro**3 == SetExpr(Interval.Ropen(1, 8)) + + assert i12o + i12lo == SetExpr(Interval.open(2, 4)) + assert i12o - i12lo == SetExpr(Interval.open(-1, 1)) + assert i12o*i12lo == SetExpr(Interval.open(1, 4)) + assert i12o/i12lo == SetExpr(Interval.open(S.Half, 2)) + assert i12o + i12ro == SetExpr(Interval.open(2, 4)) + assert i12o - i12ro == SetExpr(Interval.open(-1, 1)) + assert i12o*i12ro == SetExpr(Interval.open(1, 4)) + assert i12o/i12ro == SetExpr(Interval.open(S.Half, 2)) + assert i12o + i12cc == SetExpr(Interval.open(2, 4)) + assert i12o - i12cc == SetExpr(Interval.open(-1, 1)) + assert i12o*i12cc == SetExpr(Interval.open(1, 4)) + assert i12o/i12cc == SetExpr(Interval.open(S.Half, 2)) + assert i12o**2 == SetExpr(Interval.open(1, 4)) + assert i12o**3 == SetExpr(Interval.open(1, 8)) + + assert n23cc + n23cc == SetExpr(Interval(-4, 6)) + assert n23cc - n23cc == SetExpr(Interval(-5, 5)) + assert n23cc*n23cc == SetExpr(Interval(-6, 9)) + assert n23cc/n23cc == SetExpr(Interval.open(-oo, oo)) + assert n23cc + n23ro == SetExpr(Interval.Ropen(-4, 6)) + assert n23cc - n23ro == SetExpr(Interval.Lopen(-5, 5)) + assert n23cc*n23ro == SetExpr(Interval.Ropen(-6, 9)) + assert n23cc/n23ro == SetExpr(Interval.Lopen(-oo, oo)) + assert n23cc + n23lo == SetExpr(Interval.Lopen(-4, 6)) + assert n23cc - n23lo == SetExpr(Interval.Ropen(-5, 5)) + assert n23cc*n23lo == SetExpr(Interval(-6, 9)) + assert n23cc/n23lo == SetExpr(Interval.open(-oo, oo)) + assert n23cc + n23o == SetExpr(Interval.open(-4, 6)) + assert n23cc - n23o == SetExpr(Interval.open(-5, 5)) + assert n23cc*n23o == SetExpr(Interval.open(-6, 9)) + assert n23cc/n23o == SetExpr(Interval.open(-oo, oo)) + assert n23cc**2 == SetExpr(Interval(0, 9)) + assert n23cc**3 == SetExpr(Interval(-8, 27)) + + n32cc = SetExpr(Interval(-3, 2)) + n32lo = SetExpr(Interval.Lopen(-3, 2)) + n32ro = SetExpr(Interval.Ropen(-3, 2)) + assert n32cc*n32lo == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc*n32cc == SetExpr(Interval(-6, 9)) + assert n32lo*n32cc == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc*n32ro == SetExpr(Interval(-6, 9)) + assert n32lo*n32ro == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc/n32lo == SetExpr(Interval.Ropen(-oo, oo)) + assert i12cc/n32lo == SetExpr(Interval.Ropen(-oo, oo)) + + assert n3n2cc**2 == SetExpr(Interval(4, 9)) + assert n3n2cc**3 == SetExpr(Interval(-27, -8)) + + assert n23cc + i12cc == SetExpr(Interval(-1, 5)) + assert n23cc - i12cc == SetExpr(Interval(-4, 2)) + assert n23cc*i12cc == SetExpr(Interval(-4, 6)) + assert n23cc/i12cc == SetExpr(Interval(-2, 3)) + + +def test_SetExpr_Intersection(): + x, y, z, w = symbols("x y z w") + set1 = Interval(x, y) + set2 = Interval(w, z) + inter = Intersection(set1, set2) + se = SetExpr(inter) + assert exp(se).set == Intersection( + ImageSet(Lambda(x, exp(x)), set1), + ImageSet(Lambda(x, exp(x)), set2)) + assert cos(se).set == ImageSet(Lambda(x, cos(x)), inter) + + +def test_SetExpr_Interval_div(): + # TODO: some expressions cannot be calculated due to bugs (currently + # commented): + assert SetExpr(Interval(-3, -2))/SetExpr(Interval(-2, 1)) == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(2, 3))/SetExpr(Interval(-2, 2)) == SetExpr(Interval(-oo, oo)) + + assert SetExpr(Interval(-3, -2))/SetExpr(Interval(0, 4)) == SetExpr(Interval(-oo, Rational(-1, 2))) + assert SetExpr(Interval(2, 4))/SetExpr(Interval(-3, 0)) == SetExpr(Interval(-oo, Rational(-2, 3))) + assert SetExpr(Interval(2, 4))/SetExpr(Interval(0, 3)) == SetExpr(Interval(Rational(2, 3), oo)) + + # assert SetExpr(Interval(0, 1))/SetExpr(Interval(0, 1)) == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-1, 0))/SetExpr(Interval(0, 1)) == SetExpr(Interval(-oo, 0)) + assert SetExpr(Interval(-1, 2))/SetExpr(Interval(-2, 2)) == SetExpr(Interval(-oo, oo)) + + assert 1/SetExpr(Interval(-1, 2)) == SetExpr(Union(Interval(-oo, -1), Interval(S.Half, oo))) + + assert 1/SetExpr(Interval(0, 2)) == SetExpr(Interval(S.Half, oo)) + assert (-1)/SetExpr(Interval(0, 2)) == SetExpr(Interval(-oo, Rational(-1, 2))) + assert 1/SetExpr(Interval(-oo, 0)) == SetExpr(Interval.open(-oo, 0)) + assert 1/SetExpr(Interval(-1, 0)) == SetExpr(Interval(-oo, -1)) + # assert (-2)/SetExpr(Interval(-oo, 0)) == SetExpr(Interval(0, oo)) + # assert 1/SetExpr(Interval(-oo, -1)) == SetExpr(Interval(-1, 0)) + + # assert SetExpr(Interval(1, 2))/a == Mul(SetExpr(Interval(1, 2)), 1/a, evaluate=False) + + # assert SetExpr(Interval(1, 2))/0 == SetExpr(Interval(1, 2))*zoo + # assert SetExpr(Interval(1, oo))/oo == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(1, oo))/(-oo) == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, -1))/oo == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, -1))/(-oo) == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-oo, oo))/oo == SetExpr(Interval(-oo, oo)) + # assert SetExpr(Interval(-oo, oo))/(-oo) == SetExpr(Interval(-oo, oo)) + # assert SetExpr(Interval(-1, oo))/oo == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-1, oo))/(-oo) == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, 1))/oo == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, 1))/(-oo) == SetExpr(Interval(0, oo)) + + +def test_SetExpr_Interval_pow(): + assert SetExpr(Interval(0, 2))**2 == SetExpr(Interval(0, 4)) + assert SetExpr(Interval(-1, 1))**2 == SetExpr(Interval(0, 1)) + assert SetExpr(Interval(1, 2))**2 == SetExpr(Interval(1, 4)) + assert SetExpr(Interval(-1, 2))**3 == SetExpr(Interval(-1, 8)) + assert SetExpr(Interval(-1, 1))**0 == SetExpr(FiniteSet(1)) + + + assert SetExpr(Interval(1, 2))**Rational(5, 2) == SetExpr(Interval(1, 4*sqrt(2))) + #assert SetExpr(Interval(-1, 2))**Rational(1, 3) == SetExpr(Interval(-1, 2**Rational(1, 3))) + #assert SetExpr(Interval(0, 2))**S.Half == SetExpr(Interval(0, sqrt(2))) + + #assert SetExpr(Interval(-4, 2))**Rational(2, 3) == SetExpr(Interval(0, 2*2**Rational(1, 3))) + + #assert SetExpr(Interval(-1, 5))**S.Half == SetExpr(Interval(0, sqrt(5))) + #assert SetExpr(Interval(-oo, 2))**S.Half == SetExpr(Interval(0, sqrt(2))) + #assert SetExpr(Interval(-2, 3))**(Rational(-1, 4)) == SetExpr(Interval(0, oo)) + + assert SetExpr(Interval(1, 5))**(-2) == SetExpr(Interval(Rational(1, 25), 1)) + assert SetExpr(Interval(-1, 3))**(-2) == SetExpr(Interval(0, oo)) + + assert SetExpr(Interval(0, 2))**(-2) == SetExpr(Interval(Rational(1, 4), oo)) + assert SetExpr(Interval(-1, 2))**(-3) == SetExpr(Union(Interval(-oo, -1), Interval(Rational(1, 8), oo))) + assert SetExpr(Interval(-3, -2))**(-3) == SetExpr(Interval(Rational(-1, 8), Rational(-1, 27))) + assert SetExpr(Interval(-3, -2))**(-2) == SetExpr(Interval(Rational(1, 9), Rational(1, 4))) + #assert SetExpr(Interval(0, oo))**S.Half == SetExpr(Interval(0, oo)) + #assert SetExpr(Interval(-oo, -1))**Rational(1, 3) == SetExpr(Interval(-oo, -1)) + #assert SetExpr(Interval(-2, 3))**(Rational(-1, 3)) == SetExpr(Interval(-oo, oo)) + + assert SetExpr(Interval(-oo, 0))**(-2) == SetExpr(Interval.open(0, oo)) + assert SetExpr(Interval(-2, 0))**(-2) == SetExpr(Interval(Rational(1, 4), oo)) + + assert SetExpr(Interval(Rational(1, 3), S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(0, S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(S.Half, 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(0, 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(2, 3))**oo == SetExpr(FiniteSet(oo)) + assert SetExpr(Interval(1, 2))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(S.Half, 3))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(Rational(-1, 3), Rational(-1, 4)))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(-1, Rational(-1, 2)))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-3, -2))**oo == SetExpr(FiniteSet(-oo, oo)) + assert SetExpr(Interval(-2, -1))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-2, Rational(-1, 2)))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(Rational(-1, 2), S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(Rational(-1, 2), 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(Rational(-2, 3), 2))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(-1, 1))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-1, S.Half))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-1, 2))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-2, S.Half))**oo == SetExpr(Interval(-oo, oo)) + + assert (SetExpr(Interval(1, 2))**x).dummy_eq(SetExpr(ImageSet(Lambda(_d, _d**x), Interval(1, 2)))) + + assert SetExpr(Interval(2, 3))**(-oo) == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(0, 2))**(-oo) == SetExpr(Interval(0, oo)) + assert (SetExpr(Interval(-1, 2))**(-oo)).dummy_eq(SetExpr(ImageSet(Lambda(_d, _d**(-oo)), Interval(-1, 2)))) + + +def test_SetExpr_Integers(): + assert SetExpr(S.Integers) + 1 == SetExpr(S.Integers) + assert (SetExpr(S.Integers) + I).dummy_eq( + SetExpr(ImageSet(Lambda(_d, _d + I), S.Integers))) + assert SetExpr(S.Integers)*(-1) == SetExpr(S.Integers) + assert (SetExpr(S.Integers)*2).dummy_eq( + SetExpr(ImageSet(Lambda(_d, 2*_d), S.Integers))) + assert (SetExpr(S.Integers)*I).dummy_eq( + SetExpr(ImageSet(Lambda(_d, I*_d), S.Integers))) + # issue #18050: + assert SetExpr(S.Integers)._eval_func(Lambda(x, I*x + 1)).dummy_eq( + SetExpr(ImageSet(Lambda(_d, I*_d + 1), S.Integers))) + # needs improvement: + assert (SetExpr(S.Integers)*I + 1).dummy_eq( + SetExpr(ImageSet(Lambda(x, x + 1), + ImageSet(Lambda(_d, _d*I), S.Integers)))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_sets.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..657ab19a90eb88ca48f266f7a5cf050504caed43 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/sets/tests/test_sets.py @@ -0,0 +1,1753 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.containers import TupleKind +from sympy.core.function import Lambda +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.core.numbers import (Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.logic.boolalg import (false, true) +from sympy.matrices.kind import MatrixKind +from sympy.matrices.dense import Matrix +from sympy.polys.rootoftools import rootof +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ImageSet, Range) +from sympy.sets.sets import (Complement, DisjointUnion, FiniteSet, Intersection, Interval, ProductSet, Set, SymmetricDifference, Union, imageset, SetKind) +from mpmath import mpi + +from sympy.core.expr import unchanged +from sympy.core.relational import Eq, Ne, Le, Lt, LessThan +from sympy.logic import And, Or, Xor +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy +from sympy.utilities.iterables import cartes + +from sympy.abc import x, y, z, m, n + +EmptySet = S.EmptySet + +def test_imageset(): + ints = S.Integers + assert imageset(x, x - 1, S.Naturals) is S.Naturals0 + assert imageset(x, x + 1, S.Naturals0) is S.Naturals + assert imageset(x, abs(x), S.Naturals0) is S.Naturals0 + assert imageset(x, abs(x), S.Naturals) is S.Naturals + assert imageset(x, abs(x), S.Integers) is S.Naturals0 + # issue 16878a + r = symbols('r', real=True) + assert imageset(x, (x, x), S.Reals)._contains((1, r)) == None + assert imageset(x, (x, x), S.Reals)._contains((1, 2)) == False + assert (r, r) in imageset(x, (x, x), S.Reals) + assert 1 + I in imageset(x, x + I, S.Reals) + assert {1} not in imageset(x, (x,), S.Reals) + assert (1, 1) not in imageset(x, (x,), S.Reals) + raises(TypeError, lambda: imageset(x, ints)) + raises(ValueError, lambda: imageset(x, y, z, ints)) + raises(ValueError, lambda: imageset(Lambda(x, cos(x)), y)) + assert (1, 2) in imageset(Lambda((x, y), (x, y)), ints, ints) + raises(ValueError, lambda: imageset(Lambda(x, x), ints, ints)) + assert imageset(cos, ints) == ImageSet(Lambda(x, cos(x)), ints) + def f(x): + return cos(x) + assert imageset(f, ints) == imageset(x, cos(x), ints) + f = lambda x: cos(x) + assert imageset(f, ints) == ImageSet(Lambda(x, cos(x)), ints) + assert imageset(x, 1, ints) == FiniteSet(1) + assert imageset(x, y, ints) == {y} + assert imageset((x, y), (1, z), ints, S.Reals) == {(1, z)} + clash = Symbol('x', integer=true) + assert (str(imageset(lambda x: x + clash, Interval(-2, 1)).lamda.expr) + in ('x0 + x', 'x + x0')) + x1, x2 = symbols("x1, x2") + assert imageset(lambda x, y: + Add(x, y), Interval(1, 2), Interval(2, 3)).dummy_eq( + ImageSet(Lambda((x1, x2), x1 + x2), + Interval(1, 2), Interval(2, 3))) + + +def test_is_empty(): + for s in [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, + S.UniversalSet]: + assert s.is_empty is False + + assert S.EmptySet.is_empty is True + + +def test_is_finiteset(): + for s in [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, + S.UniversalSet]: + assert s.is_finite_set is False + + assert S.EmptySet.is_finite_set is True + + assert FiniteSet(1, 2).is_finite_set is True + assert Interval(1, 2).is_finite_set is False + assert Interval(x, y).is_finite_set is None + assert ProductSet(FiniteSet(1), FiniteSet(2)).is_finite_set is True + assert ProductSet(FiniteSet(1), Interval(1, 2)).is_finite_set is False + assert ProductSet(FiniteSet(1), Interval(x, y)).is_finite_set is None + assert Union(Interval(0, 1), Interval(2, 3)).is_finite_set is False + assert Union(FiniteSet(1), Interval(2, 3)).is_finite_set is False + assert Union(FiniteSet(1), FiniteSet(2)).is_finite_set is True + assert Union(FiniteSet(1), Interval(x, y)).is_finite_set is None + assert Intersection(Interval(x, y), FiniteSet(1)).is_finite_set is True + assert Intersection(Interval(x, y), Interval(1, 2)).is_finite_set is None + assert Intersection(FiniteSet(x), FiniteSet(y)).is_finite_set is True + assert Complement(FiniteSet(1), Interval(x, y)).is_finite_set is True + assert Complement(Interval(x, y), FiniteSet(1)).is_finite_set is None + assert Complement(Interval(1, 2), FiniteSet(x)).is_finite_set is False + assert DisjointUnion(Interval(-5, 3), FiniteSet(x, y)).is_finite_set is False + assert DisjointUnion(S.EmptySet, FiniteSet(x, y), S.EmptySet).is_finite_set is True + + +def test_deprecated_is_EmptySet(): + with warns_deprecated_sympy(): + S.EmptySet.is_EmptySet + + with warns_deprecated_sympy(): + FiniteSet(1).is_EmptySet + + +def test_interval_arguments(): + assert Interval(0, oo) == Interval(0, oo, False, True) + assert Interval(0, oo).right_open is true + assert Interval(-oo, 0) == Interval(-oo, 0, True, False) + assert Interval(-oo, 0).left_open is true + assert Interval(oo, -oo) == S.EmptySet + assert Interval(oo, oo) == S.EmptySet + assert Interval(-oo, -oo) == S.EmptySet + assert Interval(oo, x) == S.EmptySet + assert Interval(oo, oo) == S.EmptySet + assert Interval(x, -oo) == S.EmptySet + assert Interval(x, x) == {x} + + assert isinstance(Interval(1, 1), FiniteSet) + e = Sum(x, (x, 1, 3)) + assert isinstance(Interval(e, e), FiniteSet) + + assert Interval(1, 0) == S.EmptySet + assert Interval(1, 1).measure == 0 + + assert Interval(1, 1, False, True) == S.EmptySet + assert Interval(1, 1, True, False) == S.EmptySet + assert Interval(1, 1, True, True) == S.EmptySet + + + assert isinstance(Interval(0, Symbol('a')), Interval) + assert Interval(Symbol('a', positive=True), 0) == S.EmptySet + raises(ValueError, lambda: Interval(0, S.ImaginaryUnit)) + raises(ValueError, lambda: Interval(0, Symbol('z', extended_real=False))) + raises(ValueError, lambda: Interval(x, x + S.ImaginaryUnit)) + + raises(NotImplementedError, lambda: Interval(0, 1, And(x, y))) + raises(NotImplementedError, lambda: Interval(0, 1, False, And(x, y))) + raises(NotImplementedError, lambda: Interval(0, 1, z, And(x, y))) + + +def test_interval_symbolic_end_points(): + a = Symbol('a', real=True) + + assert Union(Interval(0, a), Interval(0, 3)).sup == Max(a, 3) + assert Union(Interval(a, 0), Interval(-3, 0)).inf == Min(-3, a) + + assert Interval(0, a).contains(1) == LessThan(1, a) + + +def test_interval_is_empty(): + x, y = symbols('x, y') + r = Symbol('r', real=True) + p = Symbol('p', positive=True) + n = Symbol('n', negative=True) + nn = Symbol('nn', nonnegative=True) + assert Interval(1, 2).is_empty == False + assert Interval(3, 3).is_empty == False # FiniteSet + assert Interval(r, r).is_empty == False # FiniteSet + assert Interval(r, r + nn).is_empty == False + assert Interval(x, x).is_empty == False + assert Interval(1, oo).is_empty == False + assert Interval(-oo, oo).is_empty == False + assert Interval(-oo, 1).is_empty == False + assert Interval(x, y).is_empty == None + assert Interval(r, oo).is_empty == False # real implies finite + assert Interval(n, 0).is_empty == False + assert Interval(n, 0, left_open=True).is_empty == False + assert Interval(p, 0).is_empty == True # EmptySet + assert Interval(nn, 0).is_empty == None + assert Interval(n, p).is_empty == False + assert Interval(0, p, left_open=True).is_empty == False + assert Interval(0, p, right_open=True).is_empty == False + assert Interval(0, nn, left_open=True).is_empty == None + assert Interval(0, nn, right_open=True).is_empty == None + + +def test_union(): + assert Union(Interval(1, 2), Interval(2, 3)) == Interval(1, 3) + assert Union(Interval(1, 2), Interval(2, 3, True)) == Interval(1, 3) + assert Union(Interval(1, 3), Interval(2, 4)) == Interval(1, 4) + assert Union(Interval(1, 2), Interval(1, 3)) == Interval(1, 3) + assert Union(Interval(1, 3), Interval(1, 2)) == Interval(1, 3) + assert Union(Interval(1, 3, False, True), Interval(1, 2)) == \ + Interval(1, 3, False, True) + assert Union(Interval(1, 3), Interval(1, 2, False, True)) == Interval(1, 3) + assert Union(Interval(1, 2, True), Interval(1, 3)) == Interval(1, 3) + assert Union(Interval(1, 2, True), Interval(1, 3, True)) == \ + Interval(1, 3, True) + assert Union(Interval(1, 2, True), Interval(1, 3, True, True)) == \ + Interval(1, 3, True, True) + assert Union(Interval(1, 2, True, True), Interval(1, 3, True)) == \ + Interval(1, 3, True) + assert Union(Interval(1, 3), Interval(2, 3)) == Interval(1, 3) + assert Union(Interval(1, 3, False, True), Interval(2, 3)) == \ + Interval(1, 3) + assert Union(Interval(1, 2, False, True), Interval(2, 3, True)) != \ + Interval(1, 3) + assert Union(Interval(1, 2), S.EmptySet) == Interval(1, 2) + assert Union(S.EmptySet) == S.EmptySet + + assert Union(Interval(0, 1), *[FiniteSet(1.0/n) for n in range(1, 10)]) == \ + Interval(0, 1) + # issue #18241: + x = Symbol('x') + assert Union(Interval(0, 1), FiniteSet(1, x)) == Union( + Interval(0, 1), FiniteSet(x)) + assert unchanged(Union, Interval(0, 1), FiniteSet(2, x)) + + assert Interval(1, 2).union(Interval(2, 3)) == \ + Interval(1, 2) + Interval(2, 3) + + assert Interval(1, 2).union(Interval(2, 3)) == Interval(1, 3) + + assert Union(Set()) == Set() + + assert FiniteSet(1) + FiniteSet(2) + FiniteSet(3) == FiniteSet(1, 2, 3) + assert FiniteSet('ham') + FiniteSet('eggs') == FiniteSet('ham', 'eggs') + assert FiniteSet(1, 2, 3) + S.EmptySet == FiniteSet(1, 2, 3) + + assert FiniteSet(1, 2, 3) & FiniteSet(2, 3, 4) == FiniteSet(2, 3) + assert FiniteSet(1, 2, 3) | FiniteSet(2, 3, 4) == FiniteSet(1, 2, 3, 4) + + assert FiniteSet(1, 2, 3) & S.EmptySet == S.EmptySet + assert FiniteSet(1, 2, 3) | S.EmptySet == FiniteSet(1, 2, 3) + + x = Symbol("x") + y = Symbol("y") + z = Symbol("z") + assert S.EmptySet | FiniteSet(x, FiniteSet(y, z)) == \ + FiniteSet(x, FiniteSet(y, z)) + + # Test that Intervals and FiniteSets play nicely + assert Interval(1, 3) + FiniteSet(2) == Interval(1, 3) + assert Interval(1, 3, True, True) + FiniteSet(3) == \ + Interval(1, 3, True, False) + X = Interval(1, 3) + FiniteSet(5) + Y = Interval(1, 2) + FiniteSet(3) + XandY = X.intersect(Y) + assert 2 in X and 3 in X and 3 in XandY + assert XandY.is_subset(X) and XandY.is_subset(Y) + + raises(TypeError, lambda: Union(1, 2, 3)) + + assert X.is_iterable is False + + # issue 7843 + assert Union(S.EmptySet, FiniteSet(-sqrt(-I), sqrt(-I))) == \ + FiniteSet(-sqrt(-I), sqrt(-I)) + + assert Union(S.Reals, S.Integers) == S.Reals + + +def test_union_iter(): + # Use Range because it is ordered + u = Union(Range(3), Range(5), Range(4), evaluate=False) + + # Round robin + assert list(u) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4] + + +def test_union_is_empty(): + assert (Interval(x, y) + FiniteSet(1)).is_empty == False + assert (Interval(x, y) + Interval(-x, y)).is_empty == None + + +def test_difference(): + assert Interval(1, 3) - Interval(1, 2) == Interval(2, 3, True) + assert Interval(1, 3) - Interval(2, 3) == Interval(1, 2, False, True) + assert Interval(1, 3, True) - Interval(2, 3) == Interval(1, 2, True, True) + assert Interval(1, 3, True) - Interval(2, 3, True) == \ + Interval(1, 2, True, False) + assert Interval(0, 2) - FiniteSet(1) == \ + Union(Interval(0, 1, False, True), Interval(1, 2, True, False)) + + # issue #18119 + assert S.Reals - FiniteSet(I) == S.Reals + assert S.Reals - FiniteSet(-I, I) == S.Reals + assert Interval(0, 10) - FiniteSet(-I, I) == Interval(0, 10) + assert Interval(0, 10) - FiniteSet(1, I) == Union( + Interval.Ropen(0, 1), Interval.Lopen(1, 10)) + assert S.Reals - FiniteSet(1, 2 + I, x, y**2) == Complement( + Union(Interval.open(-oo, 1), Interval.open(1, oo)), FiniteSet(x, y**2), + evaluate=False) + + assert FiniteSet(1, 2, 3) - FiniteSet(2) == FiniteSet(1, 3) + assert FiniteSet('ham', 'eggs') - FiniteSet('eggs') == FiniteSet('ham') + assert FiniteSet(1, 2, 3, 4) - Interval(2, 10, True, False) == \ + FiniteSet(1, 2) + assert FiniteSet(1, 2, 3, 4) - S.EmptySet == FiniteSet(1, 2, 3, 4) + assert Union(Interval(0, 2), FiniteSet(2, 3, 4)) - Interval(1, 3) == \ + Union(Interval(0, 1, False, True), FiniteSet(4)) + + assert -1 in S.Reals - S.Naturals + + +def test_Complement(): + A = FiniteSet(1, 3, 4) + B = FiniteSet(3, 4) + C = Interval(1, 3) + D = Interval(1, 2) + + assert Complement(A, B, evaluate=False).is_iterable is True + assert Complement(A, C, evaluate=False).is_iterable is True + assert Complement(C, D, evaluate=False).is_iterable is None + + assert FiniteSet(*Complement(A, B, evaluate=False)) == FiniteSet(1) + assert FiniteSet(*Complement(A, C, evaluate=False)) == FiniteSet(4) + raises(TypeError, lambda: FiniteSet(*Complement(C, A, evaluate=False))) + + assert Complement(Interval(1, 3), Interval(1, 2)) == Interval(2, 3, True) + assert Complement(FiniteSet(1, 3, 4), FiniteSet(3, 4)) == FiniteSet(1) + assert Complement(Union(Interval(0, 2), FiniteSet(2, 3, 4)), + Interval(1, 3)) == \ + Union(Interval(0, 1, False, True), FiniteSet(4)) + + assert 3 not in Complement(Interval(0, 5), Interval(1, 4), evaluate=False) + assert -1 in Complement(S.Reals, S.Naturals, evaluate=False) + assert 1 not in Complement(S.Reals, S.Naturals, evaluate=False) + + assert Complement(S.Integers, S.UniversalSet) == EmptySet + assert S.UniversalSet.complement(S.Integers) == EmptySet + + assert (0 not in S.Reals.intersect(S.Integers - FiniteSet(0))) + + assert S.EmptySet - S.Integers == S.EmptySet + + assert (S.Integers - FiniteSet(0)) - FiniteSet(1) == S.Integers - FiniteSet(0, 1) + + assert S.Reals - Union(S.Naturals, FiniteSet(pi)) == \ + Intersection(S.Reals - S.Naturals, S.Reals - FiniteSet(pi)) + # issue 12712 + assert Complement(FiniteSet(x, y, 2), Interval(-10, 10)) == \ + Complement(FiniteSet(x, y), Interval(-10, 10)) + + A = FiniteSet(*symbols('a:c')) + B = FiniteSet(*symbols('d:f')) + assert unchanged(Complement, ProductSet(A, A), B) + + A2 = ProductSet(A, A) + B3 = ProductSet(B, B, B) + assert A2 - B3 == A2 + assert B3 - A2 == B3 + + +def test_set_operations_nonsets(): + '''Tests that e.g. FiniteSet(1) * 2 raises TypeError''' + ops = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + lambda a, b: a / b, + lambda a, b: a // b, + lambda a, b: a | b, + lambda a, b: a & b, + lambda a, b: a ^ b, + # FiniteSet(1) ** 2 gives a ProductSet + #lambda a, b: a ** b, + ] + Sx = FiniteSet(x) + Sy = FiniteSet(y) + sets = [ + {1}, + FiniteSet(1), + Interval(1, 2), + Union(Sx, Interval(1, 2)), + Intersection(Sx, Sy), + Complement(Sx, Sy), + ProductSet(Sx, Sy), + S.EmptySet, + ] + nums = [0, 1, 2, S(0), S(1), S(2)] + + for si in sets: + for ni in nums: + for op in ops: + raises(TypeError, lambda : op(si, ni)) + raises(TypeError, lambda : op(ni, si)) + raises(TypeError, lambda: si ** object()) + raises(TypeError, lambda: si ** {1}) + + +def test_complement(): + assert Complement({1, 2}, {1}) == {2} + assert Interval(0, 1).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, oo, True, True)) + assert Interval(0, 1, True, False).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, False), Interval(1, oo, True, True)) + assert Interval(0, 1, False, True).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, oo, False, True)) + assert Interval(0, 1, True, True).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, False), Interval(1, oo, False, True)) + + assert S.UniversalSet.complement(S.EmptySet) == S.EmptySet + assert S.UniversalSet.complement(S.Reals) == S.EmptySet + assert S.UniversalSet.complement(S.UniversalSet) == S.EmptySet + + assert S.EmptySet.complement(S.Reals) == S.Reals + + assert Union(Interval(0, 1), Interval(2, 3)).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, 2, True, True), + Interval(3, oo, True, True)) + + assert FiniteSet(0).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(0, oo, True, True)) + + assert (FiniteSet(5) + Interval(S.NegativeInfinity, + 0)).complement(S.Reals) == \ + Interval(0, 5, True, True) + Interval(5, S.Infinity, True, True) + + assert FiniteSet(1, 2, 3).complement(S.Reals) == \ + Interval(S.NegativeInfinity, 1, True, True) + \ + Interval(1, 2, True, True) + Interval(2, 3, True, True) +\ + Interval(3, S.Infinity, True, True) + + assert FiniteSet(x).complement(S.Reals) == Complement(S.Reals, FiniteSet(x)) + + assert FiniteSet(0, x).complement(S.Reals) == Complement(Interval(-oo, 0, True, True) + + Interval(0, oo, True, True) + , FiniteSet(x), evaluate=False) + + square = Interval(0, 1) * Interval(0, 1) + notsquare = square.complement(S.Reals*S.Reals) + + assert all(pt in square for pt in [(0, 0), (.5, .5), (1, 0), (1, 1)]) + assert not any( + pt in notsquare for pt in [(0, 0), (.5, .5), (1, 0), (1, 1)]) + assert not any(pt in square for pt in [(-1, 0), (1.5, .5), (10, 10)]) + assert all(pt in notsquare for pt in [(-1, 0), (1.5, .5), (10, 10)]) + + +def test_intersect1(): + assert all(S.Integers.intersection(i) is i for i in + (S.Naturals, S.Naturals0)) + assert all(i.intersection(S.Integers) is i for i in + (S.Naturals, S.Naturals0)) + s = S.Naturals0 + assert S.Naturals.intersection(s) is S.Naturals + assert s.intersection(S.Naturals) is S.Naturals + x = Symbol('x') + assert Interval(0, 2).intersect(Interval(1, 2)) == Interval(1, 2) + assert Interval(0, 2).intersect(Interval(1, 2, True)) == \ + Interval(1, 2, True) + assert Interval(0, 2, True).intersect(Interval(1, 2)) == \ + Interval(1, 2, False, False) + assert Interval(0, 2, True, True).intersect(Interval(1, 2)) == \ + Interval(1, 2, False, True) + assert Interval(0, 2).intersect(Union(Interval(0, 1), Interval(2, 3))) == \ + Union(Interval(0, 1), Interval(2, 2)) + + assert FiniteSet(1, 2).intersect(FiniteSet(1, 2, 3)) == FiniteSet(1, 2) + assert FiniteSet(1, 2, x).intersect(FiniteSet(x)) == FiniteSet(x) + assert FiniteSet('ham', 'eggs').intersect(FiniteSet('ham')) == \ + FiniteSet('ham') + assert FiniteSet(1, 2, 3, 4, 5).intersect(S.EmptySet) == S.EmptySet + + assert Interval(0, 5).intersect(FiniteSet(1, 3)) == FiniteSet(1, 3) + assert Interval(0, 1, True, True).intersect(FiniteSet(1)) == S.EmptySet + + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(1, 2)) == \ + Union(Interval(1, 1), Interval(2, 2)) + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(0, 2)) == \ + Union(Interval(0, 1), Interval(2, 2)) + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(1, 2, True, True)) == \ + S.EmptySet + assert Union(Interval(0, 1), Interval(2, 3)).intersect(S.EmptySet) == \ + S.EmptySet + assert Union(Interval(0, 5), FiniteSet('ham')).intersect(FiniteSet(2, 3, 4, 5, 6)) == \ + Intersection(FiniteSet(2, 3, 4, 5, 6), Union(FiniteSet('ham'), Interval(0, 5))) + assert Intersection(FiniteSet(1, 2, 3), Interval(2, x), Interval(3, y)) == \ + Intersection(FiniteSet(3), Interval(2, x), Interval(3, y), evaluate=False) + assert Intersection(FiniteSet(1, 2), Interval(0, 3), Interval(x, y)) == \ + Intersection({1, 2}, Interval(x, y), evaluate=False) + assert Intersection(FiniteSet(1, 2, 4), Interval(0, 3), Interval(x, y)) == \ + Intersection({1, 2}, Interval(x, y), evaluate=False) + # XXX: Is the real=True necessary here? + # https://github.com/sympy/sympy/issues/17532 + m, n = symbols('m, n', real=True) + assert Intersection(FiniteSet(m), FiniteSet(m, n), Interval(m, m+1)) == \ + FiniteSet(m) + + # issue 8217 + assert Intersection(FiniteSet(x), FiniteSet(y)) == \ + Intersection(FiniteSet(x), FiniteSet(y), evaluate=False) + assert FiniteSet(x).intersect(S.Reals) == \ + Intersection(S.Reals, FiniteSet(x), evaluate=False) + + # tests for the intersection alias + assert Interval(0, 5).intersection(FiniteSet(1, 3)) == FiniteSet(1, 3) + assert Interval(0, 1, True, True).intersection(FiniteSet(1)) == S.EmptySet + + assert Union(Interval(0, 1), Interval(2, 3)).intersection(Interval(1, 2)) == \ + Union(Interval(1, 1), Interval(2, 2)) + + # canonical boundary selected + a = sqrt(2*sqrt(6) + 5) + b = sqrt(2) + sqrt(3) + assert Interval(a, 4).intersection(Interval(b, 5)) == Interval(b, 4) + assert Interval(1, a).intersection(Interval(0, b)) == Interval(1, b) + + +def test_intersection_interval_float(): + # intersection of Intervals with mixed Rational/Float boundaries should + # lead to Float boundaries in all cases regardless of which Interval is + # open or closed. + typs = [ + (Interval, Interval, Interval), + (Interval, Interval.open, Interval.open), + (Interval, Interval.Lopen, Interval.Lopen), + (Interval, Interval.Ropen, Interval.Ropen), + (Interval.open, Interval.open, Interval.open), + (Interval.open, Interval.Lopen, Interval.open), + (Interval.open, Interval.Ropen, Interval.open), + (Interval.Lopen, Interval.Lopen, Interval.Lopen), + (Interval.Lopen, Interval.Ropen, Interval.open), + (Interval.Ropen, Interval.Ropen, Interval.Ropen), + ] + + as_float = lambda a1, a2: a2 if isinstance(a2, float) else a1 + + for t1, t2, t3 in typs: + for t1i, t2i in [(t1, t2), (t2, t1)]: + for a1, a2, b1, b2 in cartes([2, 2.0], [2, 2.0], [3, 3.0], [3, 3.0]): + I1 = t1(a1, b1) + I2 = t2(a2, b2) + I3 = t3(as_float(a1, a2), as_float(b1, b2)) + assert I1.intersect(I2) == I3 + + +def test_intersection(): + # iterable + i = Intersection(FiniteSet(1, 2, 3), Interval(2, 5), evaluate=False) + assert i.is_iterable + assert set(i) == {S(2), S(3)} + + # challenging intervals + x = Symbol('x', real=True) + i = Intersection(Interval(0, 3), Interval(x, 6)) + assert (5 in i) is False + raises(TypeError, lambda: 2 in i) + + # Singleton special cases + assert Intersection(Interval(0, 1), S.EmptySet) == S.EmptySet + assert Intersection(Interval(-oo, oo), Interval(-oo, x)) == Interval(-oo, x) + + # Products + line = Interval(0, 5) + i = Intersection(line**2, line**3, evaluate=False) + assert (2, 2) not in i + assert (2, 2, 2) not in i + raises(TypeError, lambda: list(i)) + + a = Intersection(Intersection(S.Integers, S.Naturals, evaluate=False), S.Reals, evaluate=False) + assert a._argset == frozenset([Intersection(S.Naturals, S.Integers, evaluate=False), S.Reals]) + + assert Intersection(S.Complexes, FiniteSet(S.ComplexInfinity)) == S.EmptySet + + # issue 12178 + assert Intersection() == S.UniversalSet + + # issue 16987 + assert Intersection({1}, {1}, {x}) == Intersection({1}, {x}) + + +def test_issue_9623(): + n = Symbol('n') + + a = S.Reals + b = Interval(0, oo) + c = FiniteSet(n) + + assert Intersection(a, b, c) == Intersection(b, c) + assert Intersection(Interval(1, 2), Interval(3, 4), FiniteSet(n)) == EmptySet + + +def test_is_disjoint(): + assert Interval(0, 2).is_disjoint(Interval(1, 2)) == False + assert Interval(0, 2).is_disjoint(Interval(3, 4)) == True + + +def test_ProductSet__len__(): + A = FiniteSet(1, 2) + B = FiniteSet(1, 2, 3) + assert ProductSet(A).__len__() == 2 + assert ProductSet(A).__len__() is not S(2) + assert ProductSet(A, B).__len__() == 6 + assert ProductSet(A, B).__len__() is not S(6) + + +def test_ProductSet(): + # ProductSet is always a set of Tuples + assert ProductSet(S.Reals) == S.Reals ** 1 + assert ProductSet(S.Reals, S.Reals) == S.Reals ** 2 + assert ProductSet(S.Reals, S.Reals, S.Reals) == S.Reals ** 3 + + assert ProductSet(S.Reals) != S.Reals + assert ProductSet(S.Reals, S.Reals) == S.Reals * S.Reals + assert ProductSet(S.Reals, S.Reals, S.Reals) != S.Reals * S.Reals * S.Reals + assert ProductSet(S.Reals, S.Reals, S.Reals) == (S.Reals * S.Reals * S.Reals).flatten() + + assert 1 not in ProductSet(S.Reals) + assert (1,) in ProductSet(S.Reals) + + assert 1 not in ProductSet(S.Reals, S.Reals) + assert (1, 2) in ProductSet(S.Reals, S.Reals) + assert (1, I) not in ProductSet(S.Reals, S.Reals) + + assert (1, 2, 3) in ProductSet(S.Reals, S.Reals, S.Reals) + assert (1, 2, 3) in S.Reals ** 3 + assert (1, 2, 3) not in S.Reals * S.Reals * S.Reals + assert ((1, 2), 3) in S.Reals * S.Reals * S.Reals + assert (1, (2, 3)) not in S.Reals * S.Reals * S.Reals + assert (1, (2, 3)) in S.Reals * (S.Reals * S.Reals) + + assert ProductSet() == FiniteSet(()) + assert ProductSet(S.Reals, S.EmptySet) == S.EmptySet + + # See GH-17458 + + for ni in range(5): + Rn = ProductSet(*(S.Reals,) * ni) + assert (1,) * ni in Rn + assert 1 not in Rn + + assert (S.Reals * S.Reals) * S.Reals != S.Reals * (S.Reals * S.Reals) + + S1 = S.Reals + S2 = S.Integers + x1 = pi + x2 = 3 + assert x1 in S1 + assert x2 in S2 + assert (x1, x2) in S1 * S2 + S3 = S1 * S2 + x3 = (x1, x2) + assert x3 in S3 + assert (x3, x3) in S3 * S3 + assert x3 + x3 not in S3 * S3 + + raises(ValueError, lambda: S.Reals**-1) + with warns_deprecated_sympy(): + ProductSet(FiniteSet(s) for s in range(2)) + raises(TypeError, lambda: ProductSet(None)) + + S1 = FiniteSet(1, 2) + S2 = FiniteSet(3, 4) + S3 = ProductSet(S1, S2) + assert (S3.as_relational(x, y) + == And(S1.as_relational(x), S2.as_relational(y)) + == And(Or(Eq(x, 1), Eq(x, 2)), Or(Eq(y, 3), Eq(y, 4)))) + raises(ValueError, lambda: S3.as_relational(x)) + raises(ValueError, lambda: S3.as_relational(x, 1)) + raises(ValueError, lambda: ProductSet(Interval(0, 1)).as_relational(x, y)) + + Z2 = ProductSet(S.Integers, S.Integers) + assert Z2.contains((1, 2)) is S.true + assert Z2.contains((1,)) is S.false + assert Z2.contains(x) == Contains(x, Z2, evaluate=False) + assert Z2.contains(x).subs(x, 1) is S.false + assert Z2.contains((x, 1)).subs(x, 2) is S.true + assert Z2.contains((x, y)) == Contains(x, S.Integers) & Contains(y, S.Integers) + assert unchanged(Contains, (x, y), Z2) + assert Contains((1, 2), Z2) is S.true + + +def test_ProductSet_of_single_arg_is_not_arg(): + assert unchanged(ProductSet, Interval(0, 1)) + assert unchanged(ProductSet, ProductSet(Interval(0, 1))) + + +def test_ProductSet_is_empty(): + assert ProductSet(S.Integers, S.Reals).is_empty == False + assert ProductSet(Interval(x, 1), S.Reals).is_empty == None + + +def test_interval_subs(): + a = Symbol('a', real=True) + + assert Interval(0, a).subs(a, 2) == Interval(0, 2) + assert Interval(a, 0).subs(a, 2) == S.EmptySet + + +def test_interval_to_mpi(): + assert Interval(0, 1).to_mpi() == mpi(0, 1) + assert Interval(0, 1, True, False).to_mpi() == mpi(0, 1) + assert type(Interval(0, 1).to_mpi()) == type(mpi(0, 1)) + + +def test_set_evalf(): + assert Interval(S(11)/64, S.Half).evalf() == Interval( + Float('0.171875'), Float('0.5')) + assert Interval(x, S.Half, right_open=True).evalf() == Interval( + x, Float('0.5'), right_open=True) + assert Interval(-oo, S.Half).evalf() == Interval(-oo, Float('0.5')) + assert FiniteSet(2, x).evalf() == FiniteSet(Float('2.0'), x) + + +def test_measure(): + a = Symbol('a', real=True) + + assert Interval(1, 3).measure == 2 + assert Interval(0, a).measure == a + assert Interval(1, a).measure == a - 1 + + assert Union(Interval(1, 2), Interval(3, 4)).measure == 2 + assert Union(Interval(1, 2), Interval(3, 4), FiniteSet(5, 6, 7)).measure \ + == 2 + + assert FiniteSet(1, 2, oo, a, -oo, -5).measure == 0 + + assert S.EmptySet.measure == 0 + + square = Interval(0, 10) * Interval(0, 10) + offsetsquare = Interval(5, 15) * Interval(5, 15) + band = Interval(-oo, oo) * Interval(2, 4) + + assert square.measure == offsetsquare.measure == 100 + assert (square + offsetsquare).measure == 175 # there is some overlap + assert (square - offsetsquare).measure == 75 + assert (square * FiniteSet(1, 2, 3)).measure == 0 + assert (square.intersect(band)).measure == 20 + assert (square + band).measure is oo + assert (band * FiniteSet(1, 2, 3)).measure is nan + + +def test_is_subset(): + assert Interval(0, 1).is_subset(Interval(0, 2)) is True + assert Interval(0, 3).is_subset(Interval(0, 2)) is False + assert Interval(0, 1).is_subset(FiniteSet(0, 1)) is False + + assert FiniteSet(1, 2).is_subset(FiniteSet(1, 2, 3, 4)) + assert FiniteSet(4, 5).is_subset(FiniteSet(1, 2, 3, 4)) is False + assert FiniteSet(1).is_subset(Interval(0, 2)) + assert FiniteSet(1, 2).is_subset(Interval(0, 2, True, True)) is False + assert (Interval(1, 2) + FiniteSet(3)).is_subset( + Interval(0, 2, False, True) + FiniteSet(2, 3)) + + assert Interval(3, 4).is_subset(Union(Interval(0, 1), Interval(2, 5))) is True + assert Interval(3, 6).is_subset(Union(Interval(0, 1), Interval(2, 5))) is False + + assert FiniteSet(1, 2, 3, 4).is_subset(Interval(0, 5)) is True + assert S.EmptySet.is_subset(FiniteSet(1, 2, 3)) is True + + assert Interval(0, 1).is_subset(S.EmptySet) is False + assert S.EmptySet.is_subset(S.EmptySet) is True + + raises(ValueError, lambda: S.EmptySet.is_subset(1)) + + # tests for the issubset alias + assert FiniteSet(1, 2, 3, 4).issubset(Interval(0, 5)) is True + assert S.EmptySet.issubset(FiniteSet(1, 2, 3)) is True + + assert S.Naturals.is_subset(S.Integers) + assert S.Naturals0.is_subset(S.Integers) + + assert FiniteSet(x).is_subset(FiniteSet(y)) is None + assert FiniteSet(x).is_subset(FiniteSet(y).subs(y, x)) is True + assert FiniteSet(x).is_subset(FiniteSet(y).subs(y, x+1)) is False + + assert Interval(0, 1).is_subset(Interval(0, 1, left_open=True)) is False + assert Interval(-2, 3).is_subset(Union(Interval(-oo, -2), Interval(3, oo))) is False + + n = Symbol('n', integer=True) + assert Range(-3, 4, 1).is_subset(FiniteSet(-10, 10)) is False + assert Range(S(10)**100).is_subset(FiniteSet(0, 1, 2)) is False + assert Range(6, 0, -2).is_subset(FiniteSet(2, 4, 6)) is True + assert Range(1, oo).is_subset(FiniteSet(1, 2)) is False + assert Range(-oo, 1).is_subset(FiniteSet(1)) is False + assert Range(3).is_subset(FiniteSet(0, 1, n)) is None + assert Range(n, n + 2).is_subset(FiniteSet(n, n + 1)) is True + assert Range(5).is_subset(Interval(0, 4, right_open=True)) is False + #issue 19513 + assert imageset(Lambda(n, 1/n), S.Integers).is_subset(S.Reals) is None + +def test_is_proper_subset(): + assert Interval(0, 1).is_proper_subset(Interval(0, 2)) is True + assert Interval(0, 3).is_proper_subset(Interval(0, 2)) is False + assert S.EmptySet.is_proper_subset(FiniteSet(1, 2, 3)) is True + + raises(ValueError, lambda: Interval(0, 1).is_proper_subset(0)) + + +def test_is_superset(): + assert Interval(0, 1).is_superset(Interval(0, 2)) == False + assert Interval(0, 3).is_superset(Interval(0, 2)) + + assert FiniteSet(1, 2).is_superset(FiniteSet(1, 2, 3, 4)) == False + assert FiniteSet(4, 5).is_superset(FiniteSet(1, 2, 3, 4)) == False + assert FiniteSet(1).is_superset(Interval(0, 2)) == False + assert FiniteSet(1, 2).is_superset(Interval(0, 2, True, True)) == False + assert (Interval(1, 2) + FiniteSet(3)).is_superset( + Interval(0, 2, False, True) + FiniteSet(2, 3)) == False + + assert Interval(3, 4).is_superset(Union(Interval(0, 1), Interval(2, 5))) == False + + assert FiniteSet(1, 2, 3, 4).is_superset(Interval(0, 5)) == False + assert S.EmptySet.is_superset(FiniteSet(1, 2, 3)) == False + + assert Interval(0, 1).is_superset(S.EmptySet) == True + assert S.EmptySet.is_superset(S.EmptySet) == True + + raises(ValueError, lambda: S.EmptySet.is_superset(1)) + + # tests for the issuperset alias + assert Interval(0, 1).issuperset(S.EmptySet) == True + assert S.EmptySet.issuperset(S.EmptySet) == True + + +def test_is_proper_superset(): + assert Interval(0, 1).is_proper_superset(Interval(0, 2)) is False + assert Interval(0, 3).is_proper_superset(Interval(0, 2)) is True + assert FiniteSet(1, 2, 3).is_proper_superset(S.EmptySet) is True + + raises(ValueError, lambda: Interval(0, 1).is_proper_superset(0)) + + +def test_contains(): + assert Interval(0, 2).contains(1) is S.true + assert Interval(0, 2).contains(3) is S.false + assert Interval(0, 2, True, False).contains(0) is S.false + assert Interval(0, 2, True, False).contains(2) is S.true + assert Interval(0, 2, False, True).contains(0) is S.true + assert Interval(0, 2, False, True).contains(2) is S.false + assert Interval(0, 2, True, True).contains(0) is S.false + assert Interval(0, 2, True, True).contains(2) is S.false + + assert (Interval(0, 2) in Interval(0, 2)) is False + + assert FiniteSet(1, 2, 3).contains(2) is S.true + assert FiniteSet(1, 2, Symbol('x')).contains(Symbol('x')) is S.true + + assert FiniteSet(y)._contains(x) == Eq(y, x, evaluate=False) + raises(TypeError, lambda: x in FiniteSet(y)) + assert FiniteSet({x, y})._contains({x}) == Eq({x, y}, {x}, evaluate=False) + assert FiniteSet({x, y}).subs(y, x)._contains({x}) is S.true + assert FiniteSet({x, y}).subs(y, x+1)._contains({x}) is S.false + + # issue 8197 + from sympy.abc import a, b + assert FiniteSet(b).contains(-a) == Eq(b, -a) + assert FiniteSet(b).contains(a) == Eq(b, a) + assert FiniteSet(a).contains(1) == Eq(a, 1) + raises(TypeError, lambda: 1 in FiniteSet(a)) + + # issue 8209 + rad1 = Pow(Pow(2, Rational(1, 3)) - 1, Rational(1, 3)) + rad2 = Pow(Rational(1, 9), Rational(1, 3)) - Pow(Rational(2, 9), Rational(1, 3)) + Pow(Rational(4, 9), Rational(1, 3)) + s1 = FiniteSet(rad1) + s2 = FiniteSet(rad2) + assert s1 - s2 == S.EmptySet + + items = [1, 2, S.Infinity, S('ham'), -1.1] + fset = FiniteSet(*items) + assert all(item in fset for item in items) + assert all(fset.contains(item) is S.true for item in items) + + assert Union(Interval(0, 1), Interval(2, 5)).contains(3) is S.true + assert Union(Interval(0, 1), Interval(2, 5)).contains(6) is S.false + assert Union(Interval(0, 1), FiniteSet(2, 5)).contains(3) is S.false + + assert S.EmptySet.contains(1) is S.false + assert FiniteSet(rootof(x**3 + x - 1, 0)).contains(S.Infinity) is S.false + + assert rootof(x**5 + x**3 + 1, 0) in S.Reals + assert not rootof(x**5 + x**3 + 1, 1) in S.Reals + + # non-bool results + assert Union(Interval(1, 2), Interval(3, 4)).contains(x) == \ + Or(And(S.One <= x, x <= 2), And(S(3) <= x, x <= 4)) + assert Intersection(Interval(1, x), Interval(2, 3)).contains(y) == \ + And(y <= 3, y <= x, S.One <= y, S(2) <= y) + + assert (S.Complexes).contains(S.ComplexInfinity) == S.false + + +def test_interval_symbolic(): + x = Symbol('x') + e = Interval(0, 1) + assert e.contains(x) == And(S.Zero <= x, x <= 1) + raises(TypeError, lambda: x in e) + e = Interval(0, 1, True, True) + assert e.contains(x) == And(S.Zero < x, x < 1) + c = Symbol('c', real=False) + assert Interval(x, x + 1).contains(c) == False + e = Symbol('e', extended_real=True) + assert Interval(-oo, oo).contains(e) == And( + S.NegativeInfinity < e, e < S.Infinity) + + +def test_union_contains(): + x = Symbol('x') + i1 = Interval(0, 1) + i2 = Interval(2, 3) + i3 = Union(i1, i2) + assert i3.as_relational(x) == Or(And(S.Zero <= x, x <= 1), And(S(2) <= x, x <= 3)) + raises(TypeError, lambda: x in i3) + e = i3.contains(x) + assert e == i3.as_relational(x) + assert e.subs(x, -0.5) is false + assert e.subs(x, 0.5) is true + assert e.subs(x, 1.5) is false + assert e.subs(x, 2.5) is true + assert e.subs(x, 3.5) is false + + U = Interval(0, 2, True, True) + Interval(10, oo) + FiniteSet(-1, 2, 5, 6) + assert all(el not in U for el in [0, 4, -oo]) + assert all(el in U for el in [2, 5, 10]) + + +def test_is_number(): + assert Interval(0, 1).is_number is False + assert Set().is_number is False + + +def test_Interval_is_left_unbounded(): + assert Interval(3, 4).is_left_unbounded is False + assert Interval(-oo, 3).is_left_unbounded is True + assert Interval(Float("-inf"), 3).is_left_unbounded is True + + +def test_Interval_is_right_unbounded(): + assert Interval(3, 4).is_right_unbounded is False + assert Interval(3, oo).is_right_unbounded is True + assert Interval(3, Float("+inf")).is_right_unbounded is True + + +def test_Interval_as_relational(): + x = Symbol('x') + + assert Interval(-1, 2, False, False).as_relational(x) == \ + And(Le(-1, x), Le(x, 2)) + assert Interval(-1, 2, True, False).as_relational(x) == \ + And(Lt(-1, x), Le(x, 2)) + assert Interval(-1, 2, False, True).as_relational(x) == \ + And(Le(-1, x), Lt(x, 2)) + assert Interval(-1, 2, True, True).as_relational(x) == \ + And(Lt(-1, x), Lt(x, 2)) + + assert Interval(-oo, 2, right_open=False).as_relational(x) == And(Lt(-oo, x), Le(x, 2)) + assert Interval(-oo, 2, right_open=True).as_relational(x) == And(Lt(-oo, x), Lt(x, 2)) + + assert Interval(-2, oo, left_open=False).as_relational(x) == And(Le(-2, x), Lt(x, oo)) + assert Interval(-2, oo, left_open=True).as_relational(x) == And(Lt(-2, x), Lt(x, oo)) + + assert Interval(-oo, oo).as_relational(x) == And(Lt(-oo, x), Lt(x, oo)) + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert Interval(x, y).as_relational(x) == (x <= y) + assert Interval(y, x).as_relational(x) == (y <= x) + + +def test_Finite_as_relational(): + x = Symbol('x') + y = Symbol('y') + + assert FiniteSet(1, 2).as_relational(x) == Or(Eq(x, 1), Eq(x, 2)) + assert FiniteSet(y, -5).as_relational(x) == Or(Eq(x, y), Eq(x, -5)) + + +def test_Union_as_relational(): + x = Symbol('x') + assert (Interval(0, 1) + FiniteSet(2)).as_relational(x) == \ + Or(And(Le(0, x), Le(x, 1)), Eq(x, 2)) + assert (Interval(0, 1, True, True) + FiniteSet(1)).as_relational(x) == \ + And(Lt(0, x), Le(x, 1)) + assert Or(x < 0, x > 0).as_set().as_relational(x) == \ + And((x > -oo), (x < oo), Ne(x, 0)) + assert (Interval.Ropen(1, 3) + Interval.Lopen(3, 5) + ).as_relational(x) == And(Ne(x,3),(x>=1),(x<=5)) + + +def test_Intersection_as_relational(): + x = Symbol('x') + assert (Intersection(Interval(0, 1), FiniteSet(2), + evaluate=False).as_relational(x) + == And(And(Le(0, x), Le(x, 1)), Eq(x, 2))) + + +def test_Complement_as_relational(): + x = Symbol('x') + expr = Complement(Interval(0, 1), FiniteSet(2), evaluate=False) + assert expr.as_relational(x) == \ + And(Le(0, x), Le(x, 1), Ne(x, 2)) + + +@XFAIL +def test_Complement_as_relational_fail(): + x = Symbol('x') + expr = Complement(Interval(0, 1), FiniteSet(2), evaluate=False) + # XXX This example fails because 0 <= x changes to x >= 0 + # during the evaluation. + assert expr.as_relational(x) == \ + (0 <= x) & (x <= 1) & Ne(x, 2) + + +def test_SymmetricDifference_as_relational(): + x = Symbol('x') + expr = SymmetricDifference(Interval(0, 1), FiniteSet(2), evaluate=False) + assert expr.as_relational(x) == Xor(Eq(x, 2), Le(0, x) & Le(x, 1)) + + +def test_EmptySet(): + assert S.EmptySet.as_relational(Symbol('x')) is S.false + assert S.EmptySet.intersect(S.UniversalSet) == S.EmptySet + assert S.EmptySet.boundary == S.EmptySet + + +def test_finite_basic(): + x = Symbol('x') + A = FiniteSet(1, 2, 3) + B = FiniteSet(3, 4, 5) + AorB = Union(A, B) + AandB = A.intersect(B) + assert A.is_subset(AorB) and B.is_subset(AorB) + assert AandB.is_subset(A) + assert AandB == FiniteSet(3) + + assert A.inf == 1 and A.sup == 3 + assert AorB.inf == 1 and AorB.sup == 5 + assert FiniteSet(x, 1, 5).sup == Max(x, 5) + assert FiniteSet(x, 1, 5).inf == Min(x, 1) + + # issue 7335 + assert FiniteSet(S.EmptySet) != S.EmptySet + assert FiniteSet(FiniteSet(1, 2, 3)) != FiniteSet(1, 2, 3) + assert FiniteSet((1, 2, 3)) != FiniteSet(1, 2, 3) + + # Ensure a variety of types can exist in a FiniteSet + assert FiniteSet((1, 2), A, -5, x, 'eggs', x**2) + + assert (A > B) is False + assert (A >= B) is False + assert (A < B) is False + assert (A <= B) is False + assert AorB > A and AorB > B + assert AorB >= A and AorB >= B + assert A >= A and A <= A + assert A >= AandB and B >= AandB + assert A > AandB and B > AandB + + +def test_product_basic(): + H, T = 'H', 'T' + unit_line = Interval(0, 1) + d6 = FiniteSet(1, 2, 3, 4, 5, 6) + d4 = FiniteSet(1, 2, 3, 4) + coin = FiniteSet(H, T) + + square = unit_line * unit_line + + assert (0, 0) in square + assert 0 not in square + assert (H, T) in coin ** 2 + assert (.5, .5, .5) in (square * unit_line).flatten() + assert ((.5, .5), .5) in square * unit_line + assert (H, 3, 3) in (coin * d6 * d6).flatten() + assert ((H, 3), 3) in coin * d6 * d6 + HH, TT = sympify(H), sympify(T) + assert set(coin**2) == {(HH, HH), (HH, TT), (TT, HH), (TT, TT)} + + assert (d4*d4).is_subset(d6*d6) + + assert square.complement(Interval(-oo, oo)*Interval(-oo, oo)) == Union( + (Interval(-oo, 0, True, True) + + Interval(1, oo, True, True))*Interval(-oo, oo), + Interval(-oo, oo)*(Interval(-oo, 0, True, True) + + Interval(1, oo, True, True))) + + assert (Interval(-5, 5)**3).is_subset(Interval(-10, 10)**3) + assert not (Interval(-10, 10)**3).is_subset(Interval(-5, 5)**3) + assert not (Interval(-5, 5)**2).is_subset(Interval(-10, 10)**3) + + assert (Interval(.2, .5)*FiniteSet(.5)).is_subset(square) # segment in square + + assert len(coin*coin*coin) == 8 + assert len(S.EmptySet*S.EmptySet) == 0 + assert len(S.EmptySet*coin) == 0 + raises(TypeError, lambda: len(coin*Interval(0, 2))) + + +def test_real(): + x = Symbol('x', real=True) + + I = Interval(0, 5) + J = Interval(10, 20) + A = FiniteSet(1, 2, 30, x, S.Pi) + B = FiniteSet(-4, 0) + C = FiniteSet(100) + D = FiniteSet('Ham', 'Eggs') + + assert all(s.is_subset(S.Reals) for s in [I, J, A, B, C]) + assert not D.is_subset(S.Reals) + assert all((a + b).is_subset(S.Reals) for a in [I, J, A, B, C] for b in [I, J, A, B, C]) + assert not any((a + D).is_subset(S.Reals) for a in [I, J, A, B, C, D]) + + assert not (I + A + D).is_subset(S.Reals) + + +def test_supinf(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + + assert (Interval(0, 1) + FiniteSet(2)).sup == 2 + assert (Interval(0, 1) + FiniteSet(2)).inf == 0 + assert (Interval(0, 1) + FiniteSet(x)).sup == Max(1, x) + assert (Interval(0, 1) + FiniteSet(x)).inf == Min(0, x) + assert FiniteSet(5, 1, x).sup == Max(5, x) + assert FiniteSet(5, 1, x).inf == Min(1, x) + assert FiniteSet(5, 1, x, y).sup == Max(5, x, y) + assert FiniteSet(5, 1, x, y).inf == Min(1, x, y) + assert FiniteSet(5, 1, x, y, S.Infinity, S.NegativeInfinity).sup == \ + S.Infinity + assert FiniteSet(5, 1, x, y, S.Infinity, S.NegativeInfinity).inf == \ + S.NegativeInfinity + assert FiniteSet('Ham', 'Eggs').sup == Max('Ham', 'Eggs') + + +def test_universalset(): + U = S.UniversalSet + x = Symbol('x') + assert U.as_relational(x) is S.true + assert U.union(Interval(2, 4)) == U + + assert U.intersect(Interval(2, 4)) == Interval(2, 4) + assert U.measure is S.Infinity + assert U.boundary == S.EmptySet + assert U.contains(0) is S.true + + +def test_Union_of_ProductSets_shares(): + line = Interval(0, 2) + points = FiniteSet(0, 1, 2) + assert Union(line * line, line * points) == line * line + + +def test_Interval_free_symbols(): + # issue 6211 + assert Interval(0, 1).free_symbols == set() + x = Symbol('x', real=True) + assert Interval(0, x).free_symbols == {x} + + +def test_image_interval(): + x = Symbol('x', real=True) + a = Symbol('a', real=True) + assert imageset(x, 2*x, Interval(-2, 1)) == Interval(-4, 2) + assert imageset(x, 2*x, Interval(-2, 1, True, False)) == \ + Interval(-4, 2, True, False) + assert imageset(x, x**2, Interval(-2, 1, True, False)) == \ + Interval(0, 4, False, True) + assert imageset(x, x**2, Interval(-2, 1)) == Interval(0, 4) + assert imageset(x, x**2, Interval(-2, 1, True, False)) == \ + Interval(0, 4, False, True) + assert imageset(x, x**2, Interval(-2, 1, True, True)) == \ + Interval(0, 4, False, True) + assert imageset(x, (x - 2)**2, Interval(1, 3)) == Interval(0, 1) + assert imageset(x, 3*x**4 - 26*x**3 + 78*x**2 - 90*x, Interval(0, 4)) == \ + Interval(-35, 0) # Multiple Maxima + assert imageset(x, x + 1/x, Interval(-oo, oo)) == Interval(-oo, -2) \ + + Interval(2, oo) # Single Infinite discontinuity + assert imageset(x, 1/x + 1/(x-1)**2, Interval(0, 2, True, False)) == \ + Interval(Rational(3, 2), oo, False) # Multiple Infinite discontinuities + + # Test for Python lambda + assert imageset(lambda x: 2*x, Interval(-2, 1)) == Interval(-4, 2) + + assert imageset(Lambda(x, a*x), Interval(0, 1)) == \ + ImageSet(Lambda(x, a*x), Interval(0, 1)) + + assert imageset(Lambda(x, sin(cos(x))), Interval(0, 1)) == \ + ImageSet(Lambda(x, sin(cos(x))), Interval(0, 1)) + + +def test_image_piecewise(): + f = Piecewise((x, x <= -1), (1/x**2, x <= 5), (x**3, True)) + f1 = Piecewise((0, x <= 1), (1, x <= 2), (2, True)) + assert imageset(x, f, Interval(-5, 5)) == Union(Interval(-5, -1), Interval(Rational(1, 25), oo)) + assert imageset(x, f1, Interval(1, 2)) == FiniteSet(0, 1) + + +@XFAIL # See: https://github.com/sympy/sympy/pull/2723#discussion_r8659826 +def test_image_Intersection(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert imageset(x, x**2, Interval(-2, 0).intersect(Interval(x, y))) == \ + Interval(0, 4).intersect(Interval(Min(x**2, y**2), Max(x**2, y**2))) + + +def test_image_FiniteSet(): + x = Symbol('x', real=True) + assert imageset(x, 2*x, FiniteSet(1, 2, 3)) == FiniteSet(2, 4, 6) + + +def test_image_Union(): + x = Symbol('x', real=True) + assert imageset(x, x**2, Interval(-2, 0) + FiniteSet(1, 2, 3)) == \ + (Interval(0, 4) + FiniteSet(9)) + + +def test_image_EmptySet(): + x = Symbol('x', real=True) + assert imageset(x, 2*x, S.EmptySet) == S.EmptySet + + +def test_issue_5724_7680(): + assert I not in S.Reals # issue 7680 + assert Interval(-oo, oo).contains(I) is S.false + + +def test_boundary(): + assert FiniteSet(1).boundary == FiniteSet(1) + assert all(Interval(0, 1, left_open, right_open).boundary == FiniteSet(0, 1) + for left_open in (true, false) for right_open in (true, false)) + + +def test_boundary_Union(): + assert (Interval(0, 1) + Interval(2, 3)).boundary == FiniteSet(0, 1, 2, 3) + assert ((Interval(0, 1, False, True) + + Interval(1, 2, True, False)).boundary == FiniteSet(0, 1, 2)) + + assert (Interval(0, 1) + FiniteSet(2)).boundary == FiniteSet(0, 1, 2) + assert Union(Interval(0, 10), Interval(5, 15), evaluate=False).boundary \ + == FiniteSet(0, 15) + + assert Union(Interval(0, 10), Interval(0, 1), evaluate=False).boundary \ + == FiniteSet(0, 10) + assert Union(Interval(0, 10, True, True), + Interval(10, 15, True, True), evaluate=False).boundary \ + == FiniteSet(0, 10, 15) + + +@XFAIL +def test_union_boundary_of_joining_sets(): + """ Testing the boundary of unions is a hard problem """ + assert Union(Interval(0, 10), Interval(10, 15), evaluate=False).boundary \ + == FiniteSet(0, 15) + + +def test_boundary_ProductSet(): + open_square = Interval(0, 1, True, True) ** 2 + assert open_square.boundary == (FiniteSet(0, 1) * Interval(0, 1) + + Interval(0, 1) * FiniteSet(0, 1)) + + second_square = Interval(1, 2, True, True) * Interval(0, 1, True, True) + assert (open_square + second_square).boundary == ( + FiniteSet(0, 1) * Interval(0, 1) + + FiniteSet(1, 2) * Interval(0, 1) + + Interval(0, 1) * FiniteSet(0, 1) + + Interval(1, 2) * FiniteSet(0, 1)) + + +def test_boundary_ProductSet_line(): + line_in_r2 = Interval(0, 1) * FiniteSet(0) + assert line_in_r2.boundary == line_in_r2 + + +def test_is_open(): + assert Interval(0, 1, False, False).is_open is False + assert Interval(0, 1, True, False).is_open is False + assert Interval(0, 1, True, True).is_open is True + assert FiniteSet(1, 2, 3).is_open is False + + +def test_is_closed(): + assert Interval(0, 1, False, False).is_closed is True + assert Interval(0, 1, True, False).is_closed is False + assert FiniteSet(1, 2, 3).is_closed is True + + +def test_closure(): + assert Interval(0, 1, False, True).closure == Interval(0, 1, False, False) + + +def test_interior(): + assert Interval(0, 1, False, True).interior == Interval(0, 1, True, True) + + +def test_issue_7841(): + raises(TypeError, lambda: x in S.Reals) + + +def test_Eq(): + assert Eq(Interval(0, 1), Interval(0, 1)) + assert Eq(Interval(0, 1), Interval(0, 2)) == False + + s1 = FiniteSet(0, 1) + s2 = FiniteSet(1, 2) + + assert Eq(s1, s1) + assert Eq(s1, s2) == False + + assert Eq(s1*s2, s1*s2) + assert Eq(s1*s2, s2*s1) == False + + assert unchanged(Eq, FiniteSet({x, y}), FiniteSet({x})) + assert Eq(FiniteSet({x, y}).subs(y, x), FiniteSet({x})) is S.true + assert Eq(FiniteSet({x, y}), FiniteSet({x})).subs(y, x) is S.true + assert Eq(FiniteSet({x, y}).subs(y, x+1), FiniteSet({x})) is S.false + assert Eq(FiniteSet({x, y}), FiniteSet({x})).subs(y, x+1) is S.false + + assert Eq(ProductSet({1}, {2}), Interval(1, 2)) is S.false + assert Eq(ProductSet({1}), ProductSet({1}, {2})) is S.false + + assert Eq(FiniteSet(()), FiniteSet(1)) is S.false + assert Eq(ProductSet(), FiniteSet(1)) is S.false + + i1 = Interval(0, 1) + i2 = Interval(x, y) + assert unchanged(Eq, ProductSet(i1, i1), ProductSet(i2, i2)) + + +def test_SymmetricDifference(): + A = FiniteSet(0, 1, 2, 3, 4, 5) + B = FiniteSet(2, 4, 6, 8, 10) + C = Interval(8, 10) + + assert SymmetricDifference(A, B, evaluate=False).is_iterable is True + assert SymmetricDifference(A, C, evaluate=False).is_iterable is None + assert FiniteSet(*SymmetricDifference(A, B, evaluate=False)) == \ + FiniteSet(0, 1, 3, 5, 6, 8, 10) + raises(TypeError, + lambda: FiniteSet(*SymmetricDifference(A, C, evaluate=False))) + + assert SymmetricDifference(FiniteSet(0, 1, 2, 3, 4, 5), \ + FiniteSet(2, 4, 6, 8, 10)) == FiniteSet(0, 1, 3, 5, 6, 8, 10) + assert SymmetricDifference(FiniteSet(2, 3, 4), FiniteSet(2, 3, 4 ,5)) \ + == FiniteSet(5) + assert FiniteSet(1, 2, 3, 4, 5) ^ FiniteSet(1, 2, 5, 6) == \ + FiniteSet(3, 4, 6) + assert Set(S(1), S(2), S(3)) ^ Set(S(2), S(3), S(4)) == Union(Set(S(1), S(2), S(3)) - Set(S(2), S(3), S(4)), \ + Set(S(2), S(3), S(4)) - Set(S(1), S(2), S(3))) + assert Interval(0, 4) ^ Interval(2, 5) == Union(Interval(0, 4) - \ + Interval(2, 5), Interval(2, 5) - Interval(0, 4)) + + +def test_issue_9536(): + from sympy.functions.elementary.exponential import log + a = Symbol('a', real=True) + assert FiniteSet(log(a)).intersect(S.Reals) == Intersection(S.Reals, FiniteSet(log(a))) + + +def test_issue_9637(): + n = Symbol('n') + a = FiniteSet(n) + b = FiniteSet(2, n) + assert Complement(S.Reals, a) == Complement(S.Reals, a, evaluate=False) + assert Complement(Interval(1, 3), a) == Complement(Interval(1, 3), a, evaluate=False) + assert Complement(Interval(1, 3), b) == \ + Complement(Union(Interval(1, 2, False, True), Interval(2, 3, True, False)), a) + assert Complement(a, S.Reals) == Complement(a, S.Reals, evaluate=False) + assert Complement(a, Interval(1, 3)) == Complement(a, Interval(1, 3), evaluate=False) + + +def test_issue_9808(): + # See https://github.com/sympy/sympy/issues/16342 + assert Complement(FiniteSet(y), FiniteSet(1)) == Complement(FiniteSet(y), FiniteSet(1), evaluate=False) + assert Complement(FiniteSet(1, 2, x), FiniteSet(x, y, 2, 3)) == \ + Complement(FiniteSet(1), FiniteSet(y), evaluate=False) + + +def test_issue_9956(): + assert Union(Interval(-oo, oo), FiniteSet(1)) == Interval(-oo, oo) + assert Interval(-oo, oo).contains(1) is S.true + + +def test_issue_Symbol_inter(): + i = Interval(0, oo) + r = S.Reals + mat = Matrix([0, 0, 0]) + assert Intersection(r, i, FiniteSet(m), FiniteSet(m, n)) == \ + Intersection(i, FiniteSet(m)) + assert Intersection(FiniteSet(1, m, n), FiniteSet(m, n, 2), i) == \ + Intersection(i, FiniteSet(m, n)) + assert Intersection(FiniteSet(m, n, x), FiniteSet(m, z), r) == \ + Intersection(Intersection({m, z}, {m, n, x}), r) + assert Intersection(FiniteSet(m, n, 3), FiniteSet(m, n, x), r) == \ + Intersection(FiniteSet(3, m, n), FiniteSet(m, n, x), r, evaluate=False) + assert Intersection(FiniteSet(m, n, 3), FiniteSet(m, n, 2, 3), r) == \ + Intersection(FiniteSet(3, m, n), r) + assert Intersection(r, FiniteSet(mat, 2, n), FiniteSet(0, mat, n)) == \ + Intersection(r, FiniteSet(n)) + assert Intersection(FiniteSet(sin(x), cos(x)), FiniteSet(sin(x), cos(x), 1), r) == \ + Intersection(r, FiniteSet(sin(x), cos(x))) + assert Intersection(FiniteSet(x**2, 1, sin(x)), FiniteSet(x**2, 2, sin(x)), r) == \ + Intersection(r, FiniteSet(x**2, sin(x))) + + +def test_issue_11827(): + assert S.Naturals0**4 + + +def test_issue_10113(): + f = x**2/(x**2 - 4) + assert imageset(x, f, S.Reals) == Union(Interval(-oo, 0), Interval(1, oo, True, True)) + assert imageset(x, f, Interval(-2, 2)) == Interval(-oo, 0) + assert imageset(x, f, Interval(-2, 3)) == Union(Interval(-oo, 0), Interval(Rational(9, 5), oo)) + + +def test_issue_10248(): + raises( + TypeError, lambda: list(Intersection(S.Reals, FiniteSet(x))) + ) + A = Symbol('A', real=True) + assert list(Intersection(S.Reals, FiniteSet(A))) == [A] + + +def test_issue_9447(): + a = Interval(0, 1) + Interval(2, 3) + assert Complement(S.UniversalSet, a) == Complement( + S.UniversalSet, Union(Interval(0, 1), Interval(2, 3)), evaluate=False) + assert Complement(S.Naturals, a) == Complement( + S.Naturals, Union(Interval(0, 1), Interval(2, 3)), evaluate=False) + + +def test_issue_10337(): + assert (FiniteSet(2) == 3) is False + assert (FiniteSet(2) != 3) is True + raises(TypeError, lambda: FiniteSet(2) < 3) + raises(TypeError, lambda: FiniteSet(2) <= 3) + raises(TypeError, lambda: FiniteSet(2) > 3) + raises(TypeError, lambda: FiniteSet(2) >= 3) + + +def test_issue_10326(): + bad = [ + EmptySet, + FiniteSet(1), + Interval(1, 2), + S.ComplexInfinity, + S.ImaginaryUnit, + S.Infinity, + S.NaN, + S.NegativeInfinity, + ] + interval = Interval(0, 5) + for i in bad: + assert i not in interval + + x = Symbol('x', real=True) + nr = Symbol('nr', extended_real=False) + assert x + 1 in Interval(x, x + 4) + assert nr not in Interval(x, x + 4) + assert Interval(1, 2) in FiniteSet(Interval(0, 5), Interval(1, 2)) + assert Interval(-oo, oo).contains(oo) is S.false + assert Interval(-oo, oo).contains(-oo) is S.false + + +def test_issue_2799(): + U = S.UniversalSet + a = Symbol('a', real=True) + inf_interval = Interval(a, oo) + R = S.Reals + + assert U + inf_interval == inf_interval + U + assert U + R == R + U + assert R + inf_interval == inf_interval + R + + +def test_issue_9706(): + assert Interval(-oo, 0).closure == Interval(-oo, 0, True, False) + assert Interval(0, oo).closure == Interval(0, oo, False, True) + assert Interval(-oo, oo).closure == Interval(-oo, oo) + + +def test_issue_8257(): + reals_plus_infinity = Union(Interval(-oo, oo), FiniteSet(oo)) + reals_plus_negativeinfinity = Union(Interval(-oo, oo), FiniteSet(-oo)) + assert Interval(-oo, oo) + FiniteSet(oo) == reals_plus_infinity + assert FiniteSet(oo) + Interval(-oo, oo) == reals_plus_infinity + assert Interval(-oo, oo) + FiniteSet(-oo) == reals_plus_negativeinfinity + assert FiniteSet(-oo) + Interval(-oo, oo) == reals_plus_negativeinfinity + + +def test_issue_10931(): + assert S.Integers - S.Integers == EmptySet + assert S.Integers - S.Reals == EmptySet + + +def test_issue_11174(): + soln = Intersection(Interval(-oo, oo), FiniteSet(-x), evaluate=False) + assert Intersection(FiniteSet(-x), S.Reals) == soln + + soln = Intersection(S.Reals, FiniteSet(x), evaluate=False) + assert Intersection(FiniteSet(x), S.Reals) == soln + + +def test_issue_18505(): + assert ImageSet(Lambda(n, sqrt(pi*n/2 - 1 + pi/2)), S.Integers).contains(0) == \ + Contains(0, ImageSet(Lambda(n, sqrt(pi*n/2 - 1 + pi/2)), S.Integers)) + + +def test_finite_set_intersection(): + # The following should not produce recursion errors + # Note: some of these are not completely correct. See + # https://github.com/sympy/sympy/issues/16342. + assert Intersection(FiniteSet(-oo, x), FiniteSet(x)) == FiniteSet(x) + assert Intersection._handle_finite_sets([FiniteSet(-oo, x), FiniteSet(0, x)]) == FiniteSet(x) + + assert Intersection._handle_finite_sets([FiniteSet(-oo, x), FiniteSet(x)]) == FiniteSet(x) + assert Intersection._handle_finite_sets([FiniteSet(2, 3, x, y), FiniteSet(1, 2, x)]) == \ + Intersection._handle_finite_sets([FiniteSet(1, 2, x), FiniteSet(2, 3, x, y)]) == \ + Intersection(FiniteSet(1, 2, x), FiniteSet(2, 3, x, y)) == \ + Intersection(FiniteSet(1, 2, x), FiniteSet(2, x, y)) + + assert FiniteSet(1+x-y) & FiniteSet(1) == \ + FiniteSet(1) & FiniteSet(1+x-y) == \ + Intersection(FiniteSet(1+x-y), FiniteSet(1), evaluate=False) + + assert FiniteSet(1) & FiniteSet(x) == FiniteSet(x) & FiniteSet(1) == \ + Intersection(FiniteSet(1), FiniteSet(x), evaluate=False) + + assert FiniteSet({x}) & FiniteSet({x, y}) == \ + Intersection(FiniteSet({x}), FiniteSet({x, y}), evaluate=False) + + +def test_union_intersection_constructor(): + # The actual exception does not matter here, so long as these fail + sets = [FiniteSet(1), FiniteSet(2)] + raises(Exception, lambda: Union(sets)) + raises(Exception, lambda: Intersection(sets)) + raises(Exception, lambda: Union(tuple(sets))) + raises(Exception, lambda: Intersection(tuple(sets))) + raises(Exception, lambda: Union(i for i in sets)) + raises(Exception, lambda: Intersection(i for i in sets)) + + # Python sets are treated the same as FiniteSet + # The union of a single set (of sets) is the set (of sets) itself + assert Union(set(sets)) == FiniteSet(*sets) + assert Intersection(set(sets)) == FiniteSet(*sets) + + assert Union({1}, {2}) == FiniteSet(1, 2) + assert Intersection({1, 2}, {2, 3}) == FiniteSet(2) + + +def test_Union_contains(): + assert zoo not in Union( + Interval.open(-oo, 0), Interval.open(0, oo)) + + +@XFAIL +def test_issue_16878b(): + # in intersection_sets for (ImageSet, Set) there is no code + # that handles the base_set of S.Reals like there is + # for Integers + assert imageset(x, (x, x), S.Reals).is_subset(S.Reals**2) is True + +def test_DisjointUnion(): + assert DisjointUnion(FiniteSet(1, 2, 3), FiniteSet(1, 2, 3), FiniteSet(1, 2, 3)).rewrite(Union) == (FiniteSet(1, 2, 3) * FiniteSet(0, 1, 2)) + assert DisjointUnion(Interval(1, 3), Interval(2, 4)).rewrite(Union) == Union(Interval(1, 3) * FiniteSet(0), Interval(2, 4) * FiniteSet(1)) + assert DisjointUnion(Interval(0, 5), Interval(0, 5)).rewrite(Union) == Union(Interval(0, 5) * FiniteSet(0), Interval(0, 5) * FiniteSet(1)) + assert DisjointUnion(Interval(-1, 2), S.EmptySet, S.EmptySet).rewrite(Union) == Interval(-1, 2) * FiniteSet(0) + assert DisjointUnion(Interval(-1, 2)).rewrite(Union) == Interval(-1, 2) * FiniteSet(0) + assert DisjointUnion(S.EmptySet, Interval(-1, 2), S.EmptySet).rewrite(Union) == Interval(-1, 2) * FiniteSet(1) + assert DisjointUnion(Interval(-oo, oo)).rewrite(Union) == Interval(-oo, oo) * FiniteSet(0) + assert DisjointUnion(S.EmptySet).rewrite(Union) == S.EmptySet + assert DisjointUnion().rewrite(Union) == S.EmptySet + raises(TypeError, lambda: DisjointUnion(Symbol('n'))) + + x = Symbol("x") + y = Symbol("y") + z = Symbol("z") + assert DisjointUnion(FiniteSet(x), FiniteSet(y, z)).rewrite(Union) == (FiniteSet(x) * FiniteSet(0)) + (FiniteSet(y, z) * FiniteSet(1)) + +def test_DisjointUnion_is_empty(): + assert DisjointUnion(S.EmptySet).is_empty is True + assert DisjointUnion(S.EmptySet, S.EmptySet).is_empty is True + assert DisjointUnion(S.EmptySet, FiniteSet(1, 2, 3)).is_empty is False + +def test_DisjointUnion_is_iterable(): + assert DisjointUnion(S.Integers, S.Naturals, S.Rationals).is_iterable is True + assert DisjointUnion(S.EmptySet, S.Reals).is_iterable is False + assert DisjointUnion(FiniteSet(1, 2, 3), S.EmptySet, FiniteSet(x, y)).is_iterable is True + assert DisjointUnion(S.EmptySet, S.EmptySet).is_iterable is False + +def test_DisjointUnion_contains(): + assert (0, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 1, 2) not in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 0.5) not in DisjointUnion(FiniteSet(0.5)) + assert (0, 5) not in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (x, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (y, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (z, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (y, 2) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (0.5, 0) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (0.5, 1) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (1.5, 0) not in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (1.5, 1) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + +def test_DisjointUnion_iter(): + D = DisjointUnion(FiniteSet(3, 5, 7, 9), FiniteSet(x, y, z)) + it = iter(D) + L1 = [(x, 1), (y, 1), (z, 1)] + L2 = [(3, 0), (5, 0), (7, 0), (9, 0)] + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + raises(StopIteration, lambda: next(it)) + + raises(ValueError, lambda: iter(DisjointUnion(Interval(0, 1), S.EmptySet))) + +def test_DisjointUnion_len(): + assert len(DisjointUnion(FiniteSet(3, 5, 7, 9), FiniteSet(x, y, z))) == 7 + assert len(DisjointUnion(S.EmptySet, S.EmptySet, FiniteSet(x, y, z), S.EmptySet)) == 3 + raises(ValueError, lambda: len(DisjointUnion(Interval(0, 1), S.EmptySet))) + +def test_SetKind_ProductSet(): + p = ProductSet(FiniteSet(Matrix([1, 2])), FiniteSet(Matrix([1, 2]))) + mk = MatrixKind(NumberKind) + k = SetKind(TupleKind(mk, mk)) + assert p.kind is k + assert ProductSet(Interval(1, 2), FiniteSet(Matrix([1, 2]))).kind is SetKind(TupleKind(NumberKind, mk)) + +def test_SetKind_Interval(): + assert Interval(1, 2).kind is SetKind(NumberKind) + +def test_SetKind_EmptySet_UniversalSet(): + assert S.UniversalSet.kind is SetKind(UndefinedKind) + assert EmptySet.kind is SetKind() + +def test_SetKind_FiniteSet(): + assert FiniteSet(1, Matrix([1, 2])).kind is SetKind(UndefinedKind) + assert FiniteSet(1, 2).kind is SetKind(NumberKind) + +def test_SetKind_Unions(): + assert Union(FiniteSet(Matrix([1, 2])), Interval(1, 2)).kind is SetKind(UndefinedKind) + assert Union(Interval(1, 2), Interval(1, 7)).kind is SetKind(NumberKind) + +def test_SetKind_DisjointUnion(): + A = FiniteSet(1, 2, 3) + B = Interval(0, 5) + assert DisjointUnion(A, B).kind is SetKind(NumberKind) + +def test_SetKind_evaluate_False(): + U = lambda *args: Union(*args, evaluate=False) + assert U({1}, EmptySet).kind is SetKind(NumberKind) + assert U(Interval(1, 2), EmptySet).kind is SetKind(NumberKind) + assert U({1}, S.UniversalSet).kind is SetKind(UndefinedKind) + assert U(Interval(1, 2), Interval(4, 5), + FiniteSet(1)).kind is SetKind(NumberKind) + I = lambda *args: Intersection(*args, evaluate=False) + assert I({1}, S.UniversalSet).kind is SetKind(NumberKind) + assert I({1}, EmptySet).kind is SetKind() + C = lambda *args: Complement(*args, evaluate=False) + assert C(S.UniversalSet, {1, 2, 4, 5}).kind is SetKind(UndefinedKind) + assert C({1, 2, 3, 4, 5}, EmptySet).kind is SetKind(NumberKind) + assert C(EmptySet, {1, 2, 3, 4, 5}).kind is SetKind() + +def test_SetKind_ImageSet_Special(): + f = ImageSet(Lambda(n, n ** 2), Interval(1, 4)) + assert (f - FiniteSet(3)).kind is SetKind(NumberKind) + assert (f + Interval(16, 17)).kind is SetKind(NumberKind) + assert (f + FiniteSet(17)).kind is SetKind(NumberKind) + +def test_issue_20089(): + B = FiniteSet(FiniteSet(1, 2), FiniteSet(1)) + assert 1 not in B + assert 1.0 not in B + assert not Eq(1, FiniteSet(1, 2)) + assert FiniteSet(1) in B + A = FiniteSet(1, 2) + assert A in B + assert B.issubset(B) + assert not A.issubset(B) + assert 1 in A + C = FiniteSet(FiniteSet(1, 2), FiniteSet(1), 1, 2) + assert A.issubset(C) + assert B.issubset(C) + +def test_issue_19378(): + a = FiniteSet(1, 2) + b = ProductSet(a, a) + c = FiniteSet((1, 1), (1, 2), (2, 1), (2, 2)) + assert b.is_subset(c) is True + d = FiniteSet(1) + assert b.is_subset(d) is False + assert Eq(c, b).simplify() is S.true + assert Eq(a, c).simplify() is S.false + assert Eq({1}, {x}).simplify() == Eq({1}, {x}) + +def test_intersection_symbolic(): + n = Symbol('n') + # These should not throw an error + assert isinstance(Intersection(Range(n), Range(100)), Intersection) + assert isinstance(Intersection(Range(n), Interval(1, 100)), Intersection) + assert isinstance(Intersection(Range(100), Interval(1, n)), Intersection) + + +@XFAIL +def test_intersection_symbolic_failing(): + n = Symbol('n', integer=True, positive=True) + assert Intersection(Range(10, n), Range(4, 500, 5)) == Intersection( + Range(14, n), Range(14, 500, 5)) + assert Intersection(Interval(10, n), Range(4, 500, 5)) == Intersection( + Interval(14, n), Range(14, 500, 5)) + + +def test_issue_20379(): + #https://github.com/sympy/sympy/issues/20379 + x = pi - 3.14159265358979 + assert FiniteSet(x).evalf(2) == FiniteSet(Float('3.23108914886517e-15', 2)) + +def test_finiteset_simplify(): + S = FiniteSet(1, cos(1)**2 + sin(1)**2) + assert S.simplify() == {1} + +def test_issue_14336(): + #https://github.com/sympy/sympy/issues/14336 + U = S.Complexes + x = Symbol("x") + U -= U.intersect(Ne(x, 1).as_set()) + U -= U.intersect(S.true.as_set()) + +def test_issue_9855(): + #https://github.com/sympy/sympy/issues/9855 + x, y, z = symbols('x, y, z', real=True) + s1 = Interval(1, x) & Interval(y, 2) + s2 = Interval(1, 2) + assert s1.is_subset(s2) == None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0619d1c3ebbd6c6a7d663093c7ed2202114148af --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__init__.py @@ -0,0 +1,60 @@ +"""The module helps converting SymPy expressions into shorter forms of them. + +for example: +the expression E**(pi*I) will be converted into -1 +the expression (x+x)**2 will be converted into 4*x**2 +""" +from .simplify import (simplify, hypersimp, hypersimilar, + logcombine, separatevars, posify, besselsimp, kroneckersimp, + signsimp, nsimplify) + +from .fu import FU, fu + +from .sqrtdenest import sqrtdenest + +from .cse_main import cse + +from .epathtools import epath, EPath + +from .hyperexpand import hyperexpand + +from .radsimp import collect, rcollect, radsimp, collect_const, fraction, numer, denom + +from .trigsimp import trigsimp, exptrigsimp + +from .powsimp import powsimp, powdenest + +from .combsimp import combsimp + +from .gammasimp import gammasimp + +from .ratsimp import ratsimp, ratsimpmodprime + +__all__ = [ + 'simplify', 'hypersimp', 'hypersimilar', 'logcombine', 'separatevars', + 'posify', 'besselsimp', 'kroneckersimp', 'signsimp', + 'nsimplify', + + 'FU', 'fu', + + 'sqrtdenest', + + 'cse', + + 'epath', 'EPath', + + 'hyperexpand', + + 'collect', 'rcollect', 'radsimp', 'collect_const', 'fraction', 'numer', + 'denom', + + 'trigsimp', 'exptrigsimp', + + 'powsimp', 'powdenest', + + 'combsimp', + + 'gammasimp', + + 'ratsimp', 'ratsimpmodprime', +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6afe425eadb735b56863604d380362dc1398daad Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/_cse_diff.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/_cse_diff.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87e0377db54b1c61b46fe7138fa1794bad572aeb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/_cse_diff.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/combsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/combsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee5374a992901de6540e22c53a1f46fbdbeb8f27 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/combsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/cse_main.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/cse_main.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90d01a96412ffde702db56016b6099f5d7c93253 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/cse_main.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bed8065231bbec2c6f43ac927819a0df33c8316 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/epathtools.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/epathtools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..791b8337251085eb27742c9da43e805f10a6ecf4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/epathtools.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/fu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/fu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..940bbdaf8c164ae180cf6f81f5b6544eab6b14f6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/fu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ff0db28d26df2961cff00a3a436af71e6c29c73 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..121d41a35f8b3665c6b0ffa57b0a32ebbc0fae8d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/powsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/powsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..996b0111dc6e0ad1a8ed6661f1ad954a928c8d16 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/powsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/radsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/radsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53138ce187e7e8a7e821158d6d04b06cb86714be Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/radsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b024ef046a838423e943b8df16233b3c5c290ef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/simplify.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/simplify.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9256b1935af21363194d177221dc62d08224a484 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/simplify.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..120ef355654e7b77b767d670d0e1cd9762858bc9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22532b210755ccda4c717c992c82fca0a42540d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94f85e0ac12c58a9902a10259ca5ecda33b944dc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/_cse_diff.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/_cse_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..3496ad3b31a4f45312cac002429be40aa9aa0868 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/_cse_diff.py @@ -0,0 +1,291 @@ +"""Module for differentiation using CSE.""" + +from sympy import cse, Matrix, Derivative, MatrixBase +from sympy.utilities.iterables import iterable + + +def _remove_cse_from_derivative(replacements, reduced_expressions): + """ + This function is designed to postprocess the output of a common subexpression + elimination (CSE) operation. Specifically, it removes any CSE replacement + symbols from the arguments of ``Derivative`` terms in the expression. This + is necessary to ensure that the forward Jacobian function correctly handles + derivative terms. + + Parameters + ========== + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. + + reduced_expressions : list of SymPy expressions + The reduced expressions with all the replacements from the + replacements list above. + + Returns + ======= + + processed_replacements : list of (Symbol, expression) pairs + Processed replacement list, in the same format of the + ``replacements`` input list. + + processed_reduced : list of SymPy expressions + Processed reduced list, in the same format of the + ``reduced_expressions`` input list. + """ + + def traverse(node, repl_dict): + if isinstance(node, Derivative): + return replace_all(node, repl_dict) + if not node.args: + return node + new_args = [traverse(arg, repl_dict) for arg in node.args] + return node.func(*new_args) + + def replace_all(node, repl_dict): + result = node + while True: + free_symbols = result.free_symbols + symbols_dict = {k: repl_dict[k] for k in free_symbols if k in repl_dict} + if not symbols_dict: + break + result = result.xreplace(symbols_dict) + return result + + repl_dict = dict(replacements) + processed_replacements = [ + (rep_sym, traverse(sub_exp, repl_dict)) + for rep_sym, sub_exp in replacements + ] + processed_reduced = [ + red_exp.__class__([traverse(exp, repl_dict) for exp in red_exp]) + for red_exp in reduced_expressions + ] + + return processed_replacements, processed_reduced + + +def _forward_jacobian_cse(replacements, reduced_expr, wrt): + """ + Core function to compute the Jacobian of an input Matrix of expressions + through forward accumulation. Takes directly the output of a CSE operation + (replacements and reduced_expr), and an iterable of variables (wrt) with + respect to which to differentiate the reduced expression and returns the + reduced Jacobian matrix and the ``replacements`` list. + + The function also returns a list of precomputed free symbols for each + subexpression, which are useful in the substitution process. + + Parameters + ========== + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. + + reduced_expr : list of SymPy expressions + The reduced expressions with all the replacements from the + replacements list above. + + wrt : iterable + Iterable of expressions with respect to which to compute the + Jacobian matrix. + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. Compared to the input replacement list, + the output one doesn't contain replacement symbols inside + ``Derivative``'s arguments. + + jacobian : list of SymPy expressions + The list only contains one element, which is the Jacobian matrix with + elements in reduced form (replacement symbols are present). + + precomputed_fs: list + List of sets, which store the free symbols present in each sub-expression. + Useful in the substitution process. + """ + + if not isinstance(reduced_expr[0], MatrixBase): + raise TypeError("``expr`` must be of matrix type") + + if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1): + raise TypeError("``expr`` must be a row or a column matrix") + + if not iterable(wrt): + raise TypeError("``wrt`` must be an iterable of variables") + + elif not isinstance(wrt, MatrixBase): + wrt = Matrix(wrt) + + if not (wrt.shape[0] == 1 or wrt.shape[1] == 1): + raise TypeError("``wrt`` must be a row or a column matrix") + + replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr) + + if replacements: + rep_sym, sub_expr = map(Matrix, zip(*replacements)) + else: + rep_sym, sub_expr = Matrix([]), Matrix([]) + + l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(reduced_expr[0]) + + f1 = reduced_expr[0].__class__.from_dok(l_red, l_wrt, + { + (i, j): diff_value + for i, r in enumerate(reduced_expr[0]) + for j, w in enumerate(wrt) + if (diff_value := r.diff(w)) != 0 + }, + ) + + if not replacements: + return [], [f1], [] + + f2 = Matrix.from_dok(l_red, l_sub, + { + (i, j): diff_value + for i, (r, fs) in enumerate([(r, r.free_symbols) for r in reduced_expr[0]]) + for j, s in enumerate(rep_sym) + if s in fs and (diff_value := r.diff(s)) != 0 + }, + ) + + rep_sym_set = set(rep_sym) + precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr ] + + c_matrix = Matrix.from_dok(1, l_wrt, + {(0, j): diff_value for j, w in enumerate(wrt) + if (diff_value := sub_expr[0].diff(w)) != 0}) + + for i in range(1, l_sub): + + bi_matrix = Matrix.from_dok(1, i, + {(0, j): diff_value for j in range(i + 1) + if rep_sym[j] in precomputed_fs[i] + and (diff_value := sub_expr[i].diff(rep_sym[j])) != 0}) + + ai_matrix = Matrix.from_dok(1, l_wrt, + {(0, j): diff_value for j, w in enumerate(wrt) + if (diff_value := sub_expr[i].diff(w)) != 0}) + + if bi_matrix._rep.nnz(): + ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix) + c_matrix = Matrix.vstack(c_matrix, ci_matrix) + else: + c_matrix = Matrix.vstack(c_matrix, ai_matrix) + + jacobian = f2.multiply(c_matrix).add(f1) + jacobian = [reduced_expr[0].__class__(jacobian)] + + return replacements, jacobian, precomputed_fs + + +def _forward_jacobian_norm_in_cse_out(expr, wrt): + """ + Function to compute the Jacobian of an input Matrix of expressions through + forward accumulation. Takes a sympy Matrix of expressions (expr) as input + and an iterable of variables (wrt) with respect to which to compute the + Jacobian matrix. The matrix is returned in reduced form (containing + replacement symbols) along with the ``replacements`` list. + + The function also returns a list of precomputed free symbols for each + subexpression, which are useful in the substitution process. + + Parameters + ========== + + expr : Matrix + The vector to be differentiated. + + wrt : iterable + The vector with respect to which to perform the differentiation. + Can be a matrix or an iterable of variables. + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. The output replacement list doesn't + contain replacement symbols inside ``Derivative``'s arguments. + + jacobian : list of SymPy expressions + The list only contains one element, which is the Jacobian matrix with + elements in reduced form (replacement symbols are present). + + precomputed_fs: list + List of sets, which store the free symbols present in each + sub-expression. Useful in the substitution process. + """ + + replacements, reduced_expr = cse(expr) + replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt) + + return replacements, jacobian, precomputed_fs + + +def _forward_jacobian(expr, wrt): + """ + Function to compute the Jacobian of an input Matrix of expressions through + forward accumulation. Takes a sympy Matrix of expressions (expr) as input + and an iterable of variables (wrt) with respect to which to compute the + Jacobian matrix. + + Explanation + =========== + + Expressions often contain repeated subexpressions. Using a tree structure, + these subexpressions are duplicated and differentiated multiple times, + leading to inefficiency. + + Instead, if a data structure called a directed acyclic graph (DAG) is used + then each of these repeated subexpressions will only exist a single time. + This function uses a combination of representing the expression as a DAG and + a forward accumulation algorithm (repeated application of the chain rule + symbolically) to more efficiently calculate the Jacobian matrix of a target + expression ``expr`` with respect to an expression or set of expressions + ``wrt``. + + Note that this function is intended to improve performance when + differentiating large expressions that contain many common subexpressions. + For small and simple expressions it is likely less performant than using + SymPy's standard differentiation functions and methods. + + Parameters + ========== + + expr : Matrix + The vector to be differentiated. + + wrt : iterable + The vector with respect to which to do the differentiation. + Can be a matrix or an iterable of variables. + + See Also + ======== + + Direct Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph + """ + + replacements, reduced_expr = cse(expr) + + if replacements: + rep_sym, _ = map(Matrix, zip(*replacements)) + else: + rep_sym = Matrix([]) + + replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt) + + if not replacements: return jacobian[0] + + sub_rep = dict(replacements) + for i, ik in enumerate(precomputed_fs): + sub_dict = {j: sub_rep[j] for j in ik} + sub_rep[rep_sym[i]] = sub_rep[rep_sym[i]].xreplace(sub_dict) + + return jacobian[0].xreplace(sub_rep) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/combsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/combsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0b3cefcba11b4b7759b7d3ec3c2d4415cfd849 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/combsimp.py @@ -0,0 +1,114 @@ +from sympy.core import Mul +from sympy.core.function import count_ops +from sympy.core.traversal import preorder_traversal, bottom_up +from sympy.functions.combinatorial.factorials import binomial, factorial +from sympy.functions import gamma +from sympy.simplify.gammasimp import gammasimp, _gammasimp + +from sympy.utilities.timeutils import timethis + + +@timethis('combsimp') +def combsimp(expr): + r""" + Simplify combinatorial expressions. + + Explanation + =========== + + This function takes as input an expression containing factorials, + binomials, Pochhammer symbol and other "combinatorial" functions, + and tries to minimize the number of those functions and reduce + the size of their arguments. + + The algorithm works by rewriting all combinatorial functions as + gamma functions and applying gammasimp() except simplification + steps that may make an integer argument non-integer. See docstring + of gammasimp for more information. + + Then it rewrites expression in terms of factorials and binomials by + rewriting gammas as factorials and converting (a+b)!/a!b! into + binomials. + + If expression has gamma functions or combinatorial functions + with non-integer argument, it is automatically passed to gammasimp. + + Examples + ======== + + >>> from sympy.simplify import combsimp + >>> from sympy import factorial, binomial, symbols + >>> n, k = symbols('n k', integer = True) + + >>> combsimp(factorial(n)/factorial(n - 3)) + n*(n - 2)*(n - 1) + >>> combsimp(binomial(n+1, k+1)/binomial(n, k)) + (n + 1)/(k + 1) + + """ + + expr = expr.rewrite(gamma, piecewise=False) + if any(isinstance(node, gamma) and not node.args[0].is_integer + for node in preorder_traversal(expr)): + return gammasimp(expr) + + expr = _gammasimp(expr, as_comb = True) + expr = _gamma_as_comb(expr) + return expr + + +def _gamma_as_comb(expr): + """ + Helper function for combsimp. + + Rewrites expression in terms of factorials and binomials + """ + + expr = expr.rewrite(factorial) + + def f(rv): + if not rv.is_Mul: + return rv + rvd = rv.as_powers_dict() + nd_fact_args = [[], []] # numerator, denominator + + for k in rvd: + if isinstance(k, factorial) and rvd[k].is_Integer: + if rvd[k].is_positive: + nd_fact_args[0].extend([k.args[0]]*rvd[k]) + else: + nd_fact_args[1].extend([k.args[0]]*-rvd[k]) + rvd[k] = 0 + if not nd_fact_args[0] or not nd_fact_args[1]: + return rv + + hit = False + for m in range(2): + i = 0 + while i < len(nd_fact_args[m]): + ai = nd_fact_args[m][i] + for j in range(i + 1, len(nd_fact_args[m])): + aj = nd_fact_args[m][j] + + sum = ai + aj + if sum in nd_fact_args[1 - m]: + hit = True + + nd_fact_args[1 - m].remove(sum) + del nd_fact_args[m][j] + del nd_fact_args[m][i] + + rvd[binomial(sum, ai if count_ops(ai) < + count_ops(aj) else aj)] += ( + -1 if m == 0 else 1) + break + else: + i += 1 + + if hit: + return Mul(*([k**rvd[k] for k in rvd] + [factorial(k) + for k in nd_fact_args[0]]))/Mul(*[factorial(k) + for k in nd_fact_args[1]]) + return rv + + return bottom_up(expr, f) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/cse_main.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/cse_main.py new file mode 100644 index 0000000000000000000000000000000000000000..bcd1b2e50adae8c3d3400d6c489e63a44ae1e59b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/cse_main.py @@ -0,0 +1,945 @@ +""" Tools for doing common subexpression elimination. +""" +from collections import defaultdict + +from sympy.core import Basic, Mul, Add, Pow, sympify +from sympy.core.containers import Tuple, OrderedSet +from sympy.core.exprtools import factor_terms +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import symbols, Symbol +from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix, + SparseMatrix, ImmutableSparseMatrix) +from sympy.matrices.expressions import (MatrixExpr, MatrixSymbol, MatMul, + MatAdd, MatPow, Inverse) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.polys.rootoftools import RootOf +from sympy.utilities.iterables import numbered_symbols, sift, \ + topological_sort, iterable + +from . import cse_opts + +# (preprocessor, postprocessor) pairs which are commonly useful. They should +# each take a SymPy expression and return a possibly transformed expression. +# When used in the function ``cse()``, the target expressions will be transformed +# by each of the preprocessor functions in order. After the common +# subexpressions are eliminated, each resulting expression will have the +# postprocessor functions transform them in *reverse* order in order to undo the +# transformation if necessary. This allows the algorithm to operate on +# a representation of the expressions that allows for more optimization +# opportunities. +# ``None`` can be used to specify no transformation for either the preprocessor or +# postprocessor. + + +basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post), + (factor_terms, None)] + +# sometimes we want the output in a different format; non-trivial +# transformations can be put here for users +# =============================================================== + + +def reps_toposort(r): + """Sort replacements ``r`` so (k1, v1) appears before (k2, v2) + if k2 is in v1's free symbols. This orders items in the + way that cse returns its results (hence, in order to use the + replacements in a substitution option it would make sense + to reverse the order). + + Examples + ======== + + >>> from sympy.simplify.cse_main import reps_toposort + >>> from sympy.abc import x, y + >>> from sympy import Eq + >>> for l, r in reps_toposort([(x, y + 1), (y, 2)]): + ... print(Eq(l, r)) + ... + Eq(y, 2) + Eq(x, y + 1) + + """ + r = sympify(r) + E = [] + for c1, (k1, v1) in enumerate(r): + for c2, (k2, v2) in enumerate(r): + if k1 in v2.free_symbols: + E.append((c1, c2)) + return [r[i] for i in topological_sort((range(len(r)), E))] + + +def cse_separate(r, e): + """Move expressions that are in the form (symbol, expr) out of the + expressions and sort them into the replacements using the reps_toposort. + + Examples + ======== + + >>> from sympy.simplify.cse_main import cse_separate + >>> from sympy.abc import x, y, z + >>> from sympy import cos, exp, cse, Eq, symbols + >>> x0, x1 = symbols('x:2') + >>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [ + ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)], + ... [x1 + exp(x1/x0) + cos(x0), z - 2]], + ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)], + ... [x0 + exp(x0/x1) + cos(x1), z - 2]]] + ... + True + """ + d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol) + r = r + [w.args for w in d[True]] + e = d[False] + return [reps_toposort(r), e] + + +def cse_release_variables(r, e): + """ + Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is + either an expression or None. The value of None is used when a + symbol is no longer needed for subsequent expressions. + + Use of such output can reduce the memory footprint of lambdified + expressions that contain large, repeated subexpressions. + + Examples + ======== + + >>> from sympy import cse + >>> from sympy.simplify.cse_main import cse_release_variables + >>> from sympy.abc import x, y + >>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)] + >>> defs, rvs = cse_release_variables(*cse(eqs)) + >>> for i in defs: + ... print(i) + ... + (x0, x + y) + (x1, (x0 - 1)**2) + (x2, 2*x + 1) + (_3, x0/x2 + x1) + (_4, x2**x0) + (x2, None) + (_0, x1) + (x1, None) + (_2, x0) + (x0, None) + (_1, x) + >>> print(rvs) + (_0, _1, _2, _3, _4) + """ + if not r: + return r, e + + s, p = zip(*r) + esyms = symbols('_:%d' % len(e)) + syms = list(esyms) + s = list(s) + in_use = set(s) + p = list(p) + # sort e so those with most sub-expressions appear first + e = [(e[i], syms[i]) for i in range(len(e))] + e, syms = zip(*sorted(e, + key=lambda x: -sum(p[s.index(i)].count_ops() + for i in x[0].free_symbols & in_use))) + syms = list(syms) + p += e + rv = [] + i = len(p) - 1 + while i >= 0: + _p = p.pop() + c = in_use & _p.free_symbols + if c: # sorting for canonical results + rv.extend([(s, None) for s in sorted(c, key=str)]) + if i >= len(r): + rv.append((syms.pop(), _p)) + else: + rv.append((s[i], _p)) + in_use -= c + i -= 1 + rv.reverse() + return rv, esyms + + +# ====end of cse postprocess idioms=========================== + + +def preprocess_for_cse(expr, optimizations): + """ Preprocess an expression to optimize for common subexpression + elimination. + + Parameters + ========== + + expr : SymPy expression + The target expression to optimize. + optimizations : list of (callable, callable) pairs + The (preprocessor, postprocessor) pairs. + + Returns + ======= + + expr : SymPy expression + The transformed expression. + """ + for pre, post in optimizations: + if pre is not None: + expr = pre(expr) + return expr + + +def postprocess_for_cse(expr, optimizations): + """Postprocess an expression after common subexpression elimination to + return the expression to canonical SymPy form. + + Parameters + ========== + + expr : SymPy expression + The target expression to transform. + optimizations : list of (callable, callable) pairs, optional + The (preprocessor, postprocessor) pairs. The postprocessors will be + applied in reversed order to undo the effects of the preprocessors + correctly. + + Returns + ======= + + expr : SymPy expression + The transformed expression. + """ + for pre, post in reversed(optimizations): + if post is not None: + expr = post(expr) + return expr + + +class FuncArgTracker: + """ + A class which manages a mapping from functions to arguments and an inverse + mapping from arguments to functions. + """ + + def __init__(self, funcs): + # To minimize the number of symbolic comparisons, all function arguments + # get assigned a value number. + self.value_numbers = {} + self.value_number_to_value = [] + + # Both of these maps use integer indices for arguments / functions. + self.arg_to_funcset = [] + self.func_to_argset = [] + + for func_i, func in enumerate(funcs): + func_argset = OrderedSet() + + for func_arg in func.args: + arg_number = self.get_or_add_value_number(func_arg) + func_argset.add(arg_number) + self.arg_to_funcset[arg_number].add(func_i) + + self.func_to_argset.append(func_argset) + + def get_args_in_value_order(self, argset): + """ + Return the list of arguments in sorted order according to their value + numbers. + """ + return [self.value_number_to_value[argn] for argn in sorted(argset)] + + def get_or_add_value_number(self, value): + """ + Return the value number for the given argument. + """ + nvalues = len(self.value_numbers) + value_number = self.value_numbers.setdefault(value, nvalues) + if value_number == nvalues: + self.value_number_to_value.append(value) + self.arg_to_funcset.append(OrderedSet()) + return value_number + + def stop_arg_tracking(self, func_i): + """ + Remove the function func_i from the argument to function mapping. + """ + for arg in self.func_to_argset[func_i]: + self.arg_to_funcset[arg].remove(func_i) + + + def get_common_arg_candidates(self, argset, min_func_i=0): + """Return a dict whose keys are function numbers. The entries of the dict are + the number of arguments said function has in common with + ``argset``. Entries have at least 2 items in common. All keys have + value at least ``min_func_i``. + """ + count_map = defaultdict(lambda: 0) + if not argset: + return count_map + + funcsets = [self.arg_to_funcset[arg] for arg in argset] + # As an optimization below, we handle the largest funcset separately from + # the others. + largest_funcset = max(funcsets, key=len) + + for funcset in funcsets: + if largest_funcset is funcset: + continue + for func_i in funcset: + if func_i >= min_func_i: + count_map[func_i] += 1 + + # We pick the smaller of the two containers (count_map, largest_funcset) + # to iterate over to reduce the number of iterations needed. + (smaller_funcs_container, + larger_funcs_container) = sorted( + [largest_funcset, count_map], + key=len) + + for func_i in smaller_funcs_container: + # Not already in count_map? It can't possibly be in the output, so + # skip it. + if count_map[func_i] < 1: + continue + + if func_i in larger_funcs_container: + count_map[func_i] += 1 + + return {k: v for k, v in count_map.items() if v >= 2} + + def get_subset_candidates(self, argset, restrict_to_funcset=None): + """ + Return a set of functions each of which whose argument list contains + ``argset``, optionally filtered only to contain functions in + ``restrict_to_funcset``. + """ + iarg = iter(argset) + + indices = OrderedSet( + fi for fi in self.arg_to_funcset[next(iarg)]) + + if restrict_to_funcset is not None: + indices &= restrict_to_funcset + + for arg in iarg: + indices &= self.arg_to_funcset[arg] + + return indices + + def update_func_argset(self, func_i, new_argset): + """ + Update a function with a new set of arguments. + """ + new_args = OrderedSet(new_argset) + old_args = self.func_to_argset[func_i] + + for deleted_arg in old_args - new_args: + self.arg_to_funcset[deleted_arg].remove(func_i) + for added_arg in new_args - old_args: + self.arg_to_funcset[added_arg].add(func_i) + + self.func_to_argset[func_i].clear() + self.func_to_argset[func_i].update(new_args) + + +class Unevaluated: + + def __init__(self, func, args): + self.func = func + self.args = args + + def __str__(self): + return "Uneval<{}>({})".format( + self.func, ", ".join(str(a) for a in self.args)) + + def as_unevaluated_basic(self): + return self.func(*self.args, evaluate=False) + + @property + def free_symbols(self): + return set().union(*[a.free_symbols for a in self.args]) + + __repr__ = __str__ + + +def match_common_args(func_class, funcs, opt_subs): + """ + Recognize and extract common subexpressions of function arguments within a + set of function calls. For instance, for the following function calls:: + + x + z + y + sin(x + y) + + this will extract a common subexpression of `x + y`:: + + w = x + y + w + z + sin(w) + + The function we work with is assumed to be associative and commutative. + + Parameters + ========== + + func_class: class + The function class (e.g. Add, Mul) + funcs: list of functions + A list of function calls. + opt_subs: dict + A dictionary of substitutions which this function may update. + """ + + # Sort to ensure that whole-function subexpressions come before the items + # that use them. + funcs = sorted(funcs, key=lambda f: len(f.args)) + arg_tracker = FuncArgTracker(funcs) + + changed = OrderedSet() + + for i in range(len(funcs)): + common_arg_candidates_counts = arg_tracker.get_common_arg_candidates( + arg_tracker.func_to_argset[i], min_func_i=i + 1) + + # Sort the candidates in order of match size. + # This makes us try combining smaller matches first. + common_arg_candidates = OrderedSet(sorted( + common_arg_candidates_counts.keys(), + key=lambda k: (common_arg_candidates_counts[k], k))) + + while common_arg_candidates: + j = common_arg_candidates.pop(last=False) + + com_args = arg_tracker.func_to_argset[i].intersection( + arg_tracker.func_to_argset[j]) + + if len(com_args) <= 1: + # This may happen if a set of common arguments was already + # combined in a previous iteration. + continue + + # For all sets, replace the common symbols by the function + # over them, to allow recursive matches. + + diff_i = arg_tracker.func_to_argset[i].difference(com_args) + if diff_i: + # com_func needs to be unevaluated to allow for recursive matches. + com_func = Unevaluated( + func_class, arg_tracker.get_args_in_value_order(com_args)) + com_func_number = arg_tracker.get_or_add_value_number(com_func) + arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number])) + changed.add(i) + else: + # Treat the whole expression as a CSE. + # + # The reason this needs to be done is somewhat subtle. Within + # tree_cse(), to_eliminate only contains expressions that are + # seen more than once. The problem is unevaluated expressions + # do not compare equal to the evaluated equivalent. So + # tree_cse() won't mark funcs[i] as a CSE if we use an + # unevaluated version. + com_func_number = arg_tracker.get_or_add_value_number(funcs[i]) + + diff_j = arg_tracker.func_to_argset[j].difference(com_args) + arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number])) + changed.add(j) + + for k in arg_tracker.get_subset_candidates( + com_args, common_arg_candidates): + diff_k = arg_tracker.func_to_argset[k].difference(com_args) + arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number])) + changed.add(k) + + if i in changed: + opt_subs[funcs[i]] = Unevaluated(func_class, + arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i])) + + arg_tracker.stop_arg_tracking(i) + + +def opt_cse(exprs, order='canonical'): + """Find optimization opportunities in Adds, Muls, Pows and negative + coefficient Muls. + + Parameters + ========== + + exprs : list of SymPy expressions + The expressions to optimize. + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. For large + expressions where speed is a concern, use the setting order='none'. + + Returns + ======= + + opt_subs : dictionary of expression substitutions + The expression substitutions which can be useful to optimize CSE. + + Examples + ======== + + >>> from sympy.simplify.cse_main import opt_cse + >>> from sympy.abc import x + >>> opt_subs = opt_cse([x**-2]) + >>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0] + >>> print((k, v.as_unevaluated_basic())) + (x**(-2), 1/(x**2)) + """ + opt_subs = {} + + adds = OrderedSet() + muls = OrderedSet() + + seen_subexp = set() + collapsible_subexp = set() + + def _find_opts(expr): + + if not isinstance(expr, (Basic, Unevaluated)): + return + + if expr.is_Atom or expr.is_Order: + return + + if iterable(expr): + list(map(_find_opts, expr)) + return + + if expr in seen_subexp: + return expr + seen_subexp.add(expr) + + list(map(_find_opts, expr.args)) + + if not isinstance(expr, MatrixExpr) and expr.could_extract_minus_sign(): + # XXX -expr does not always work rigorously for some expressions + # containing UnevaluatedExpr. + # https://github.com/sympy/sympy/issues/24818 + if isinstance(expr, Add): + neg_expr = Add(*(-i for i in expr.args)) + else: + neg_expr = -expr + + if not neg_expr.is_Atom: + opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr)) + seen_subexp.add(neg_expr) + expr = neg_expr + + if isinstance(expr, (Mul, MatMul)): + if len(expr.args) == 1: + collapsible_subexp.add(expr) + else: + muls.add(expr) + + elif isinstance(expr, (Add, MatAdd)): + if len(expr.args) == 1: + collapsible_subexp.add(expr) + else: + adds.add(expr) + + elif isinstance(expr, Inverse): + # Do not want to treat `Inverse` as a `MatPow` + pass + + elif isinstance(expr, (Pow, MatPow)): + base, exp = expr.base, expr.exp + if exp.could_extract_minus_sign(): + opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1)) + + for e in exprs: + if isinstance(e, (Basic, Unevaluated)): + _find_opts(e) + + # Handle collapsing of multinary operations with single arguments + edges = [(s, s.args[0]) for s in collapsible_subexp + if s.args[0] in collapsible_subexp] + for e in reversed(topological_sort((collapsible_subexp, edges))): + opt_subs[e] = opt_subs.get(e.args[0], e.args[0]) + + # split muls into commutative + commutative_muls = OrderedSet() + for m in muls: + c, nc = m.args_cnc(cset=False) + if c: + c_mul = m.func(*c) + if nc: + if c_mul == 1: + new_obj = m.func(*nc) + else: + if isinstance(m, MatMul): + new_obj = m.func(c_mul, *nc, evaluate=False) + else: + new_obj = m.func(c_mul, m.func(*nc), evaluate=False) + opt_subs[m] = new_obj + if len(c) > 1: + commutative_muls.add(c_mul) + + match_common_args(Add, adds, opt_subs) + match_common_args(Mul, commutative_muls, opt_subs) + + return opt_subs + + +def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()): + """Perform raw CSE on expression tree, taking opt_subs into account. + + Parameters + ========== + + exprs : list of SymPy expressions + The expressions to reduce. + symbols : infinite iterator yielding unique Symbols + The symbols used to label the common subexpressions which are pulled + out. + opt_subs : dictionary of expression substitutions + The expressions to be substituted before any CSE action is performed. + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. For large + expressions where speed is a concern, use the setting order='none'. + ignore : iterable of Symbols + Substitutions containing any Symbol from ``ignore`` will be ignored. + """ + if opt_subs is None: + opt_subs = {} + + ## Find repeated sub-expressions + + to_eliminate = set() + + seen_subexp = set() + excluded_symbols = set() + + def _find_repeated(expr): + if not isinstance(expr, (Basic, Unevaluated)): + return + + if isinstance(expr, RootOf): + return + + if isinstance(expr, Basic) and ( + expr.is_Atom or + expr.is_Order or + isinstance(expr, (MatrixSymbol, MatrixElement))): + if expr.is_Symbol: + excluded_symbols.add(expr.name) + return + + if iterable(expr): + args = expr + + else: + if expr in seen_subexp: + for ign in ignore: + if ign in expr.free_symbols: + break + else: + to_eliminate.add(expr) + return + + seen_subexp.add(expr) + + if expr in opt_subs: + expr = opt_subs[expr] + + args = expr.args + + list(map(_find_repeated, args)) + + for e in exprs: + if isinstance(e, Basic): + _find_repeated(e) + + ## Rebuild tree + + # Remove symbols from the generator that conflict with names in the expressions. + symbols = (_ for _ in symbols if _.name not in excluded_symbols) + + replacements = [] + + subs = {} + + def _rebuild(expr): + if not isinstance(expr, (Basic, Unevaluated)): + return expr + + if not expr.args: + return expr + + if iterable(expr): + new_args = [_rebuild(arg) for arg in expr.args] + return expr.func(*new_args) + + if expr in subs: + return subs[expr] + + orig_expr = expr + if expr in opt_subs: + expr = opt_subs[expr] + + # If enabled, parse Muls and Adds arguments by order to ensure + # replacement order independent from hashes + if order != 'none': + if isinstance(expr, (Mul, MatMul)): + c, nc = expr.args_cnc() + if c == [1]: + args = nc + else: + args = list(ordered(c)) + nc + elif isinstance(expr, (Add, MatAdd)): + args = list(ordered(expr.args)) + else: + args = expr.args + else: + args = expr.args + + new_args = list(map(_rebuild, args)) + if isinstance(expr, Unevaluated) or new_args != args: + new_expr = expr.func(*new_args) + else: + new_expr = expr + + if orig_expr in to_eliminate: + try: + sym = next(symbols) + except StopIteration: + raise ValueError("Symbols iterator ran out of symbols.") + + if isinstance(orig_expr, MatrixExpr): + sym = MatrixSymbol(sym.name, orig_expr.rows, + orig_expr.cols) + + subs[orig_expr] = sym + replacements.append((sym, new_expr)) + return sym + + else: + return new_expr + + reduced_exprs = [] + for e in exprs: + if isinstance(e, Basic): + reduced_e = _rebuild(e) + else: + reduced_e = e + reduced_exprs.append(reduced_e) + return replacements, reduced_exprs + + +def cse(exprs, symbols=None, optimizations=None, postprocess=None, + order='canonical', ignore=(), list=True): + """ Perform common subexpression elimination on an expression. + + Parameters + ========== + + exprs : list of SymPy expressions, or a single SymPy expression + The expressions to reduce. + symbols : infinite iterator yielding unique Symbols + The symbols used to label the common subexpressions which are pulled + out. The ``numbered_symbols`` generator is useful. The default is a + stream of symbols of the form "x0", "x1", etc. This must be an + infinite iterator. + optimizations : list of (callable, callable) pairs + The (preprocessor, postprocessor) pairs of external optimization + functions. Optionally 'basic' can be passed for a set of predefined + basic optimizations. Such 'basic' optimizations were used by default + in old implementation, however they can be really slow on larger + expressions. Now, no pre or post optimizations are made by default. + postprocess : a function which accepts the two return values of cse and + returns the desired form of output from cse, e.g. if you want the + replacements reversed the function might be the following lambda: + lambda r, e: return reversed(r), e + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. If set to + 'canonical', arguments will be canonically ordered. If set to 'none', + ordering will be faster but dependent on expressions hashes, thus + machine dependent and variable. For large expressions where speed is a + concern, use the setting order='none'. + ignore : iterable of Symbols + Substitutions containing any Symbol from ``ignore`` will be ignored. + list : bool, (default True) + Returns expression in list or else with same type as input (when False). + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + All of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this + list. + reduced_exprs : list of SymPy expressions + The reduced expressions with all of the replacements above. + + Examples + ======== + + >>> from sympy import cse, SparseMatrix + >>> from sympy.abc import x, y, z, w + >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3) + ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3]) + + + List of expressions with recursive substitutions: + + >>> m = SparseMatrix([x + y, x + y + z]) + >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m]) + ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([ + [x0], + [x1]])]) + + Note: the type and mutability of input matrices is retained. + + >>> isinstance(_[1][-1], SparseMatrix) + True + + The user may disallow substitutions containing certain symbols: + + >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,)) + ([(x0, x + 1)], [x0*y**2, 3*x0*y**2]) + + The default return value for the reduced expression(s) is a list, even if there is only + one expression. The `list` flag preserves the type of the input in the output: + + >>> cse(x) + ([], [x]) + >>> cse(x, list=False) + ([], x) + """ + if not list: + return _cse_homogeneous(exprs, + symbols=symbols, optimizations=optimizations, + postprocess=postprocess, order=order, ignore=ignore) + + if isinstance(exprs, (int, float)): + exprs = sympify(exprs) + + # Handle the case if just one expression was passed. + if isinstance(exprs, (Basic, MatrixBase)): + exprs = [exprs] + + copy = exprs + temp = [] + for e in exprs: + if isinstance(e, (Matrix, ImmutableMatrix)): + temp.append(Tuple(*e.flat())) + elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)): + temp.append(Tuple(*e.todok().items())) + else: + temp.append(e) + exprs = temp + del temp + + if optimizations is None: + optimizations = [] + elif optimizations == 'basic': + optimizations = basic_optimizations + + # Preprocess the expressions to give us better optimization opportunities. + reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs] + + if symbols is None: + symbols = numbered_symbols(cls=Symbol) + else: + # In case we get passed an iterable with an __iter__ method instead of + # an actual iterator. + symbols = iter(symbols) + + # Find other optimization opportunities. + opt_subs = opt_cse(reduced_exprs, order) + + # Main CSE algorithm. + replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs, + order, ignore) + + # Postprocess the expressions to return the expressions to canonical form. + exprs = copy + replacements = [(sym, postprocess_for_cse(subtree, optimizations)) + for sym, subtree in replacements] + reduced_exprs = [postprocess_for_cse(e, optimizations) + for e in reduced_exprs] + + # Get the matrices back + for i, e in enumerate(exprs): + if isinstance(e, (Matrix, ImmutableMatrix)): + reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i]) + if isinstance(e, ImmutableMatrix): + reduced_exprs[i] = reduced_exprs[i].as_immutable() + elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)): + m = SparseMatrix(e.rows, e.cols, {}) + for k, v in reduced_exprs[i]: + m[k] = v + if isinstance(e, ImmutableSparseMatrix): + m = m.as_immutable() + reduced_exprs[i] = m + + if postprocess is None: + return replacements, reduced_exprs + + return postprocess(replacements, reduced_exprs) + + +def _cse_homogeneous(exprs, **kwargs): + """ + Same as ``cse`` but the ``reduced_exprs`` are returned + with the same type as ``exprs`` or a sympified version of the same. + + Parameters + ========== + + exprs : an Expr, iterable of Expr or dictionary with Expr values + the expressions in which repeated subexpressions will be identified + kwargs : additional arguments for the ``cse`` function + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + All of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this + list. + reduced_exprs : list of SymPy expressions + The reduced expressions with all of the replacements above. + + Examples + ======== + + >>> from sympy.simplify.cse_main import cse + >>> from sympy import cos, Tuple, Matrix + >>> from sympy.abc import x + >>> output = lambda x: type(cse(x, list=False)[1]) + >>> output(1) + + >>> output('cos(x)') + + >>> output(cos(x)) + cos + >>> output(Tuple(1, x)) + + >>> output(Matrix([[1,0], [0,1]])) + + >>> output([1, x]) + + >>> output((1, x)) + + >>> output({1, x}) + + """ + if isinstance(exprs, str): + replacements, reduced_exprs = _cse_homogeneous( + sympify(exprs), **kwargs) + return replacements, repr(reduced_exprs) + if isinstance(exprs, (list, tuple, set)): + replacements, reduced_exprs = cse(exprs, **kwargs) + return replacements, type(exprs)(reduced_exprs) + if isinstance(exprs, dict): + keys = list(exprs.keys()) # In order to guarantee the order of the elements. + replacements, values = cse([exprs[k] for k in keys], **kwargs) + reduced_exprs = dict(zip(keys, values)) + return replacements, reduced_exprs + + try: + replacements, (reduced_exprs,) = cse(exprs, **kwargs) + except TypeError: # For example 'mpf' objects + return [], exprs + else: + return replacements, reduced_exprs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/cse_opts.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/cse_opts.py new file mode 100644 index 0000000000000000000000000000000000000000..36a59857411de740ae47423442af88b118a3395d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/cse_opts.py @@ -0,0 +1,52 @@ +""" Optimizations of the expression tree representation for better CSE +opportunities. +""" +from sympy.core import Add, Basic, Mul +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.traversal import preorder_traversal + + +def sub_pre(e): + """ Replace y - x with -(x - y) if -1 can be extracted from y - x. + """ + # replacing Add, A, from which -1 can be extracted with -1*-A + adds = [a for a in e.atoms(Add) if a.could_extract_minus_sign()] + reps = {} + ignore = set() + for a in adds: + na = -a + if na.is_Mul: # e.g. MatExpr + ignore.add(a) + continue + reps[a] = Mul._from_args([S.NegativeOne, na]) + + e = e.xreplace(reps) + + # repeat again for persisting Adds but mark these with a leading 1, -1 + # e.g. y - x -> 1*-1*(x - y) + if isinstance(e, Basic): + negs = {} + for a in sorted(e.atoms(Add), key=default_sort_key): + if a in ignore: + continue + if a in reps: + negs[a] = reps[a] + elif a.could_extract_minus_sign(): + negs[a] = Mul._from_args([S.One, S.NegativeOne, -a]) + e = e.xreplace(negs) + return e + + +def sub_post(e): + """ Replace 1*-1*x with -x. + """ + replacements = [] + for node in preorder_traversal(e): + if isinstance(node, Mul) and \ + node.args[0] is S.One and node.args[1] is S.NegativeOne: + replacements.append((node, -Mul._from_args(node.args[2:]))) + for node, replacement in replacements: + e = e.xreplace({node: replacement}) + + return e diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/epathtools.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/epathtools.py new file mode 100644 index 0000000000000000000000000000000000000000..7be983ada63fd39d7d467acf9afd62b3a41a2d85 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/epathtools.py @@ -0,0 +1,352 @@ +"""Tools for manipulation of expressions using paths. """ + +from sympy.core import Basic + + +class EPath: + r""" + Manipulate expressions using paths. + + EPath grammar in EBNF notation:: + + literal ::= /[A-Za-z_][A-Za-z_0-9]*/ + number ::= /-?\d+/ + type ::= literal + attribute ::= literal "?" + all ::= "*" + slice ::= "[" number? (":" number? (":" number?)?)? "]" + range ::= all | slice + query ::= (type | attribute) ("|" (type | attribute))* + selector ::= range | query range? + path ::= "/" selector ("/" selector)* + + See the docstring of the epath() function. + + """ + + __slots__ = ("_path", "_epath") + + def __new__(cls, path): + """Construct new EPath. """ + if isinstance(path, EPath): + return path + + if not path: + raise ValueError("empty EPath") + + _path = path + + if path[0] == '/': + path = path[1:] + else: + raise NotImplementedError("non-root EPath") + + epath = [] + + for selector in path.split('/'): + selector = selector.strip() + + if not selector: + raise ValueError("empty selector") + + index = 0 + + for c in selector: + if c.isalnum() or c in ('_', '|', '?'): + index += 1 + else: + break + + attrs = [] + types = [] + + if index: + elements = selector[:index] + selector = selector[index:] + + for element in elements.split('|'): + element = element.strip() + + if not element: + raise ValueError("empty element") + + if element.endswith('?'): + attrs.append(element[:-1]) + else: + types.append(element) + + span = None + + if selector == '*': + pass + else: + if selector.startswith('['): + try: + i = selector.index(']') + except ValueError: + raise ValueError("expected ']', got EOL") + + _span, span = selector[1:i], [] + + if ':' not in _span: + span = int(_span) + else: + for elt in _span.split(':', 3): + if not elt: + span.append(None) + else: + span.append(int(elt)) + + span = slice(*span) + + selector = selector[i + 1:] + + if selector: + raise ValueError("trailing characters in selector") + + epath.append((attrs, types, span)) + + obj = object.__new__(cls) + + obj._path = _path + obj._epath = epath + + return obj + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._path) + + def _get_ordered_args(self, expr): + """Sort ``expr.args`` using printing order. """ + if expr.is_Add: + return expr.as_ordered_terms() + elif expr.is_Mul: + return expr.as_ordered_factors() + else: + return expr.args + + def _hasattrs(self, expr, attrs) -> bool: + """Check if ``expr`` has any of ``attrs``. """ + return all(hasattr(expr, attr) for attr in attrs) + + def _hastypes(self, expr, types): + """Check if ``expr`` is any of ``types``. """ + _types = [ cls.__name__ for cls in expr.__class__.mro() ] + return bool(set(_types).intersection(types)) + + def _has(self, expr, attrs, types): + """Apply ``_hasattrs`` and ``_hastypes`` to ``expr``. """ + if not (attrs or types): + return True + + if attrs and self._hasattrs(expr, attrs): + return True + + if types and self._hastypes(expr, types): + return True + + return False + + def apply(self, expr, func, args=None, kwargs=None): + """ + Modify parts of an expression selected by a path. + + Examples + ======== + + >>> from sympy.simplify.epathtools import EPath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = EPath("/*/[0]/Symbol") + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> path.apply(expr, lambda expr: expr**2) + [((x**2, 1), 2), ((3, y**2), z)] + + >>> path = EPath("/*/*/Symbol") + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> path.apply(expr, lambda expr: 2*expr) + t + sin(2*x + 1) + cos(2*x + 2*y + E) + + """ + def _apply(path, expr, func): + if not path: + return func(expr) + else: + selector, path = path[0], path[1:] + attrs, types, span = selector + + if isinstance(expr, Basic): + if not expr.is_Atom: + args, basic = self._get_ordered_args(expr), True + else: + return expr + elif hasattr(expr, '__iter__'): + args, basic = expr, False + else: + return expr + + args = list(args) + + if span is not None: + if isinstance(span, slice): + indices = range(*span.indices(len(args))) + else: + indices = [span] + else: + indices = range(len(args)) + + for i in indices: + try: + arg = args[i] + except IndexError: + continue + + if self._has(arg, attrs, types): + args[i] = _apply(path, arg, func) + + if basic: + return expr.func(*args) + else: + return expr.__class__(args) + + _args, _kwargs = args or (), kwargs or {} + _func = lambda expr: func(expr, *_args, **_kwargs) + + return _apply(self._epath, expr, _func) + + def select(self, expr): + """ + Retrieve parts of an expression selected by a path. + + Examples + ======== + + >>> from sympy.simplify.epathtools import EPath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = EPath("/*/[0]/Symbol") + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> path.select(expr) + [x, y] + + >>> path = EPath("/*/*/Symbol") + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> path.select(expr) + [x, x, y] + + """ + result = [] + + def _select(path, expr): + if not path: + result.append(expr) + else: + selector, path = path[0], path[1:] + attrs, types, span = selector + + if isinstance(expr, Basic): + args = self._get_ordered_args(expr) + elif hasattr(expr, '__iter__'): + args = expr + else: + return + + if span is not None: + if isinstance(span, slice): + args = args[span] + else: + try: + args = [args[span]] + except IndexError: + return + + for arg in args: + if self._has(arg, attrs, types): + _select(path, arg) + + _select(self._epath, expr) + return result + + +def epath(path, expr=None, func=None, args=None, kwargs=None): + r""" + Manipulate parts of an expression selected by a path. + + Explanation + =========== + + This function allows to manipulate large nested expressions in single + line of code, utilizing techniques to those applied in XML processing + standards (e.g. XPath). + + If ``func`` is ``None``, :func:`epath` retrieves elements selected by + the ``path``. Otherwise it applies ``func`` to each matching element. + + Note that it is more efficient to create an EPath object and use the select + and apply methods of that object, since this will compile the path string + only once. This function should only be used as a convenient shortcut for + interactive use. + + This is the supported syntax: + + * select all: ``/*`` + Equivalent of ``for arg in args:``. + * select slice: ``/[0]`` or ``/[1:5]`` or ``/[1:5:2]`` + Supports standard Python's slice syntax. + * select by type: ``/list`` or ``/list|tuple`` + Emulates ``isinstance()``. + * select by attribute: ``/__iter__?`` + Emulates ``hasattr()``. + + Parameters + ========== + + path : str | EPath + A path as a string or a compiled EPath. + expr : Basic | iterable + An expression or a container of expressions. + func : callable (optional) + A callable that will be applied to matching parts. + args : tuple (optional) + Additional positional arguments to ``func``. + kwargs : dict (optional) + Additional keyword arguments to ``func``. + + Examples + ======== + + >>> from sympy.simplify.epathtools import epath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = "/*/[0]/Symbol" + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> epath(path, expr) + [x, y] + >>> epath(path, expr, lambda expr: expr**2) + [((x**2, 1), 2), ((3, y**2), z)] + + >>> path = "/*/*/Symbol" + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> epath(path, expr) + [x, x, y] + >>> epath(path, expr, lambda expr: 2*expr) + t + sin(2*x + 1) + cos(2*x + 2*y + E) + + """ + _epath = EPath(path) + + if expr is None: + return _epath + if func is None: + return _epath.select(expr) + else: + return _epath.apply(expr, func, args, kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/fu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/fu.py new file mode 100644 index 0000000000000000000000000000000000000000..a26706edca98385df0009a8ee41476a17d36420c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/fu.py @@ -0,0 +1,2112 @@ +from collections import defaultdict + +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.expr import Expr +from sympy.core.exprtools import Factors, gcd_terms, factor_terms +from sympy.core.function import expand_mul +from sympy.core.mul import Mul +from sympy.core.numbers import pi, I +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.core.traversal import bottom_up +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.hyperbolic import ( + cosh, sinh, tanh, coth, sech, csch, HyperbolicFunction) +from sympy.functions.elementary.trigonometric import ( + cos, sin, tan, cot, sec, csc, sqrt, TrigonometricFunction) +from sympy.ntheory.factor_ import perfect_power +from sympy.polys.polytools import factor +from sympy.strategies.tree import greedy +from sympy.strategies.core import identity, debug + +from sympy import SYMPY_DEBUG + + +# ================== Fu-like tools =========================== + + +def TR0(rv): + """Simplification of rational polynomials, trying to simplify + the expression, e.g. combine things like 3*x + 2*x, etc.... + """ + # although it would be nice to use cancel, it doesn't work + # with noncommutatives + return rv.normal().factor().expand() + + +def TR1(rv): + """Replace sec, csc with 1/cos, 1/sin + + Examples + ======== + + >>> from sympy.simplify.fu import TR1, sec, csc + >>> from sympy.abc import x + >>> TR1(2*csc(x) + sec(x)) + 1/cos(x) + 2/sin(x) + """ + + def f(rv): + if isinstance(rv, sec): + a = rv.args[0] + return S.One/cos(a) + elif isinstance(rv, csc): + a = rv.args[0] + return S.One/sin(a) + return rv + + return bottom_up(rv, f) + + +def TR2(rv): + """Replace tan and cot with sin/cos and cos/sin + + Examples + ======== + + >>> from sympy.simplify.fu import TR2 + >>> from sympy.abc import x + >>> from sympy import tan, cot, sin, cos + >>> TR2(tan(x)) + sin(x)/cos(x) + >>> TR2(cot(x)) + cos(x)/sin(x) + >>> TR2(tan(tan(x) - sin(x)/cos(x))) + 0 + + """ + + def f(rv): + if isinstance(rv, tan): + a = rv.args[0] + return sin(a)/cos(a) + elif isinstance(rv, cot): + a = rv.args[0] + return cos(a)/sin(a) + return rv + + return bottom_up(rv, f) + + +def TR2i(rv, half=False): + """Converts ratios involving sin and cos as follows:: + sin(x)/cos(x) -> tan(x) + sin(x)/(cos(x) + 1) -> tan(x/2) if half=True + + Examples + ======== + + >>> from sympy.simplify.fu import TR2i + >>> from sympy.abc import x, a + >>> from sympy import sin, cos + >>> TR2i(sin(x)/cos(x)) + tan(x) + + Powers of the numerator and denominator are also recognized + + >>> TR2i(sin(x)**2/(cos(x) + 1)**2, half=True) + tan(x/2)**2 + + The transformation does not take place unless assumptions allow + (i.e. the base must be positive or the exponent must be an integer + for both numerator and denominator) + + >>> TR2i(sin(x)**a/(cos(x) + 1)**a) + sin(x)**a/(cos(x) + 1)**a + + """ + + def f(rv): + if not rv.is_Mul: + return rv + + n, d = rv.as_numer_denom() + if n.is_Atom or d.is_Atom: + return rv + + def ok(k, e): + # initial filtering of factors + return ( + (e.is_integer or k.is_positive) and ( + k.func in (sin, cos) or (half and + k.is_Add and + len(k.args) >= 2 and + any(any(isinstance(ai, cos) or ai.is_Pow and ai.base is cos + for ai in Mul.make_args(a)) for a in k.args)))) + + n = n.as_powers_dict() + ndone = [(k, n.pop(k)) for k in list(n.keys()) if not ok(k, n[k])] + if not n: + return rv + + d = d.as_powers_dict() + ddone = [(k, d.pop(k)) for k in list(d.keys()) if not ok(k, d[k])] + if not d: + return rv + + # factoring if necessary + + def factorize(d, ddone): + newk = [] + for k in d: + if k.is_Add and len(k.args) > 1: + knew = factor(k) if half else factor_terms(k) + if knew != k: + newk.append((k, knew)) + if newk: + for i, (k, knew) in enumerate(newk): + del d[k] + newk[i] = knew + newk = Mul(*newk).as_powers_dict() + for k in newk: + v = d[k] + newk[k] + if ok(k, v): + d[k] = v + else: + ddone.append((k, v)) + del newk + factorize(n, ndone) + factorize(d, ddone) + + # joining + t = [] + for k in n: + if isinstance(k, sin): + a = cos(k.args[0], evaluate=False) + if a in d and d[a] == n[k]: + t.append(tan(k.args[0])**n[k]) + n[k] = d[a] = None + elif half: + a1 = 1 + a + if a1 in d and d[a1] == n[k]: + t.append((tan(k.args[0]/2))**n[k]) + n[k] = d[a1] = None + elif isinstance(k, cos): + a = sin(k.args[0], evaluate=False) + if a in d and d[a] == n[k]: + t.append(tan(k.args[0])**-n[k]) + n[k] = d[a] = None + elif half and k.is_Add and k.args[0] is S.One and \ + isinstance(k.args[1], cos): + a = sin(k.args[1].args[0], evaluate=False) + if a in d and d[a] == n[k] and (d[a].is_integer or \ + a.is_positive): + t.append(tan(a.args[0]/2)**-n[k]) + n[k] = d[a] = None + + if t: + rv = Mul(*(t + [b**e for b, e in n.items() if e]))/\ + Mul(*[b**e for b, e in d.items() if e]) + rv *= Mul(*[b**e for b, e in ndone])/Mul(*[b**e for b, e in ddone]) + + return rv + + return bottom_up(rv, f) + + +def TR3(rv): + """Induced formula: example sin(-a) = -sin(a) + + Examples + ======== + + >>> from sympy.simplify.fu import TR3 + >>> from sympy.abc import x, y + >>> from sympy import pi + >>> from sympy import cos + >>> TR3(cos(y - x*(y - x))) + cos(x*(x - y) + y) + >>> cos(pi/2 + x) + -sin(x) + >>> cos(30*pi/2 + x) + -cos(x) + + """ + from sympy.simplify.simplify import signsimp + + # Negative argument (already automatic for funcs like sin(-x) -> -sin(x) + # but more complicated expressions can use it, too). Also, trig angles + # between pi/4 and pi/2 are not reduced to an angle between 0 and pi/4. + # The following are automatically handled: + # Argument of type: pi/2 +/- angle + # Argument of type: pi +/- angle + # Argument of type : 2k*pi +/- angle + + def f(rv): + if not isinstance(rv, TrigonometricFunction): + return rv + rv = rv.func(signsimp(rv.args[0])) + if not isinstance(rv, TrigonometricFunction): + return rv + if (rv.args[0] - S.Pi/4).is_positive is (S.Pi/2 - rv.args[0]).is_positive is True: + fmap = {cos: sin, sin: cos, tan: cot, cot: tan, sec: csc, csc: sec} + rv = fmap[type(rv)](S.Pi/2 - rv.args[0]) + return rv + + # touch numbers iside of trig functions to let them automatically update + rv = rv.replace( + lambda x: isinstance(x, TrigonometricFunction), + lambda x: x.replace( + lambda n: n.is_number and n.is_Mul, + lambda n: n.func(*n.args))) + + return bottom_up(rv, f) + + +def TR4(rv): + """Identify values of special angles. + + a= 0 pi/6 pi/4 pi/3 pi/2 + ---------------------------------------------------- + sin(a) 0 1/2 sqrt(2)/2 sqrt(3)/2 1 + cos(a) 1 sqrt(3)/2 sqrt(2)/2 1/2 0 + tan(a) 0 sqt(3)/3 1 sqrt(3) -- + + Examples + ======== + + >>> from sympy import pi + >>> from sympy import cos, sin, tan, cot + >>> for s in (0, pi/6, pi/4, pi/3, pi/2): + ... print('%s %s %s %s' % (cos(s), sin(s), tan(s), cot(s))) + ... + 1 0 0 zoo + sqrt(3)/2 1/2 sqrt(3)/3 sqrt(3) + sqrt(2)/2 sqrt(2)/2 1 1 + 1/2 sqrt(3)/2 sqrt(3) sqrt(3)/3 + 0 1 zoo 0 + """ + # special values at 0, pi/6, pi/4, pi/3, pi/2 already handled + return rv.replace( + lambda x: + isinstance(x, TrigonometricFunction) and + (r:=x.args[0]/pi).is_Rational and r.q in (1, 2, 3, 4, 6), + lambda x: + x.func(x.args[0].func(*x.args[0].args))) + + +def _TR56(rv, f, g, h, max, pow): + """Helper for TR5 and TR6 to replace f**2 with h(g**2) + + Options + ======= + + max : controls size of exponent that can appear on f + e.g. if max=4 then f**4 will be changed to h(g**2)**2. + pow : controls whether the exponent must be a perfect power of 2 + e.g. if pow=True (and max >= 6) then f**6 will not be changed + but f**8 will be changed to h(g**2)**4 + + >>> from sympy.simplify.fu import _TR56 as T + >>> from sympy.abc import x + >>> from sympy import sin, cos + >>> h = lambda x: 1 - x + >>> T(sin(x)**3, sin, cos, h, 4, False) + (1 - cos(x)**2)*sin(x) + >>> T(sin(x)**6, sin, cos, h, 6, False) + (1 - cos(x)**2)**3 + >>> T(sin(x)**6, sin, cos, h, 6, True) + sin(x)**6 + >>> T(sin(x)**8, sin, cos, h, 10, True) + (1 - cos(x)**2)**4 + """ + + def _f(rv): + # I'm not sure if this transformation should target all even powers + # or only those expressible as powers of 2. Also, should it only + # make the changes in powers that appear in sums -- making an isolated + # change is not going to allow a simplification as far as I can tell. + if not (rv.is_Pow and rv.base.func == f): + return rv + if not rv.exp.is_real: + return rv + + if (rv.exp < 0) == True: + return rv + if (rv.exp > max) == True: + return rv + if rv.exp == 1: + return rv + if rv.exp == 2: + return h(g(rv.base.args[0])**2) + else: + if rv.exp % 2 == 1: + e = rv.exp//2 + return f(rv.base.args[0])*h(g(rv.base.args[0])**2)**e + elif rv.exp == 4: + e = 2 + elif not pow: + if rv.exp % 2: + return rv + e = rv.exp//2 + else: + p = perfect_power(rv.exp) + if not p: + return rv + e = rv.exp//2 + return h(g(rv.base.args[0])**2)**e + + return bottom_up(rv, _f) + + +def TR5(rv, max=4, pow=False): + """Replacement of sin**2 with 1 - cos(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR5 + >>> from sympy.abc import x + >>> from sympy import sin + >>> TR5(sin(x)**2) + 1 - cos(x)**2 + >>> TR5(sin(x)**-2) # unchanged + sin(x)**(-2) + >>> TR5(sin(x)**4) + (1 - cos(x)**2)**2 + """ + return _TR56(rv, sin, cos, lambda x: 1 - x, max=max, pow=pow) + + +def TR6(rv, max=4, pow=False): + """Replacement of cos**2 with 1 - sin(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR6 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR6(cos(x)**2) + 1 - sin(x)**2 + >>> TR6(cos(x)**-2) #unchanged + cos(x)**(-2) + >>> TR6(cos(x)**4) + (1 - sin(x)**2)**2 + """ + return _TR56(rv, cos, sin, lambda x: 1 - x, max=max, pow=pow) + + +def TR7(rv): + """Lowering the degree of cos(x)**2. + + Examples + ======== + + >>> from sympy.simplify.fu import TR7 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR7(cos(x)**2) + cos(2*x)/2 + 1/2 + >>> TR7(cos(x)**2 + 1) + cos(2*x)/2 + 3/2 + + """ + + def f(rv): + if not (rv.is_Pow and rv.base.func == cos and rv.exp == 2): + return rv + return (1 + cos(2*rv.base.args[0]))/2 + + return bottom_up(rv, f) + + +def TR8(rv, first=True): + """Converting products of ``cos`` and/or ``sin`` to a sum or + difference of ``cos`` and or ``sin`` terms. + + Examples + ======== + + >>> from sympy.simplify.fu import TR8 + >>> from sympy import cos, sin + >>> TR8(cos(2)*cos(3)) + cos(5)/2 + cos(1)/2 + >>> TR8(cos(2)*sin(3)) + sin(5)/2 + sin(1)/2 + >>> TR8(sin(2)*sin(3)) + -cos(5)/2 + cos(1)/2 + """ + + def f(rv): + if not ( + rv.is_Mul or + rv.is_Pow and + rv.base.func in (cos, sin) and + (rv.exp.is_integer or rv.base.is_positive)): + return rv + + if first: + n, d = [expand_mul(i) for i in rv.as_numer_denom()] + newn = TR8(n, first=False) + newd = TR8(d, first=False) + if newn != n or newd != d: + rv = gcd_terms(newn/newd) + if rv.is_Mul and rv.args[0].is_Rational and \ + len(rv.args) == 2 and rv.args[1].is_Add: + rv = Mul(*rv.as_coeff_Mul()) + return rv + + args = {cos: [], sin: [], None: []} + for a in Mul.make_args(rv): + if a.func in (cos, sin): + args[type(a)].append(a.args[0]) + elif (a.is_Pow and a.exp.is_Integer and a.exp > 0 and \ + a.base.func in (cos, sin)): + # XXX this is ok but pathological expression could be handled + # more efficiently as in TRmorrie + args[type(a.base)].extend([a.base.args[0]]*a.exp) + else: + args[None].append(a) + c = args[cos] + s = args[sin] + if not (c and s or len(c) > 1 or len(s) > 1): + return rv + + args = args[None] + n = min(len(c), len(s)) + for i in range(n): + a1 = s.pop() + a2 = c.pop() + args.append((sin(a1 + a2) + sin(a1 - a2))/2) + while len(c) > 1: + a1 = c.pop() + a2 = c.pop() + args.append((cos(a1 + a2) + cos(a1 - a2))/2) + if c: + args.append(cos(c.pop())) + while len(s) > 1: + a1 = s.pop() + a2 = s.pop() + args.append((-cos(a1 + a2) + cos(a1 - a2))/2) + if s: + args.append(sin(s.pop())) + return TR8(expand_mul(Mul(*args))) + + return bottom_up(rv, f) + + +def TR9(rv): + """Sum of ``cos`` or ``sin`` terms as a product of ``cos`` or ``sin``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR9 + >>> from sympy import cos, sin + >>> TR9(cos(1) + cos(2)) + 2*cos(1/2)*cos(3/2) + >>> TR9(cos(1) + 2*sin(1) + 2*sin(2)) + cos(1) + 4*sin(3/2)*cos(1/2) + + If no change is made by TR9, no re-arrangement of the + expression will be made. For example, though factoring + of common term is attempted, if the factored expression + was not changed, the original expression will be returned: + + >>> TR9(cos(3) + cos(3)*cos(2)) + cos(3) + cos(2)*cos(3) + + """ + + def f(rv): + if not rv.is_Add: + return rv + + def do(rv, first=True): + # cos(a)+/-cos(b) can be combined into a product of cosines and + # sin(a)+/-sin(b) can be combined into a product of cosine and + # sine. + # + # If there are more than two args, the pairs which "work" will + # have a gcd extractable and the remaining two terms will have + # the above structure -- all pairs must be checked to find the + # ones that work. args that don't have a common set of symbols + # are skipped since this doesn't lead to a simpler formula and + # also has the arbitrariness of combining, for example, the x + # and y term instead of the y and z term in something like + # cos(x) + cos(y) + cos(z). + + if not rv.is_Add: + return rv + + args = list(ordered(rv.args)) + if len(args) != 2: + hit = False + for i in range(len(args)): + ai = args[i] + if ai is None: + continue + for j in range(i + 1, len(args)): + aj = args[j] + if aj is None: + continue + was = ai + aj + new = do(was) + if new != was: + args[i] = new # update in place + args[j] = None + hit = True + break # go to next i + if hit: + rv = Add(*[_f for _f in args if _f]) + if rv.is_Add: + rv = do(rv) + + return rv + + # two-arg Add + split = trig_split(*args) + if not split: + return rv + gcd, n1, n2, a, b, iscos = split + + # application of rule if possible + if iscos: + if n1 == n2: + return gcd*n1*2*cos((a + b)/2)*cos((a - b)/2) + if n1 < 0: + a, b = b, a + return -2*gcd*sin((a + b)/2)*sin((a - b)/2) + else: + if n1 == n2: + return gcd*n1*2*sin((a + b)/2)*cos((a - b)/2) + if n1 < 0: + a, b = b, a + return 2*gcd*cos((a + b)/2)*sin((a - b)/2) + + return process_common_addends(rv, do) # DON'T sift by free symbols + + return bottom_up(rv, f) + + +def TR10(rv, first=True): + """Separate sums in ``cos`` and ``sin``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR10 + >>> from sympy.abc import a, b, c + >>> from sympy import cos, sin + >>> TR10(cos(a + b)) + -sin(a)*sin(b) + cos(a)*cos(b) + >>> TR10(sin(a + b)) + sin(a)*cos(b) + sin(b)*cos(a) + >>> TR10(sin(a + b + c)) + (-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \ + (sin(a)*cos(b) + sin(b)*cos(a))*cos(c) + """ + + def f(rv): + if rv.func not in (cos, sin): + return rv + + f = rv.func + arg = rv.args[0] + if arg.is_Add: + if first: + args = list(ordered(arg.args)) + else: + args = list(arg.args) + a = args.pop() + b = Add._from_args(args) + if b.is_Add: + if f == sin: + return sin(a)*TR10(cos(b), first=False) + \ + cos(a)*TR10(sin(b), first=False) + else: + return cos(a)*TR10(cos(b), first=False) - \ + sin(a)*TR10(sin(b), first=False) + else: + if f == sin: + return sin(a)*cos(b) + cos(a)*sin(b) + else: + return cos(a)*cos(b) - sin(a)*sin(b) + return rv + + return bottom_up(rv, f) + + +def TR10i(rv): + """Sum of products to function of sum. + + Examples + ======== + + >>> from sympy.simplify.fu import TR10i + >>> from sympy import cos, sin, sqrt + >>> from sympy.abc import x + + >>> TR10i(cos(1)*cos(3) + sin(1)*sin(3)) + cos(2) + >>> TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) + cos(3) + sin(4) + >>> TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) + 2*sqrt(2)*x*sin(x + pi/6) + + """ + def f(rv): + if not rv.is_Add: + return rv + + def do(rv, first=True): + # args which can be expressed as A*(cos(a)*cos(b)+/-sin(a)*sin(b)) + # or B*(cos(a)*sin(b)+/-cos(b)*sin(a)) can be combined into + # A*f(a+/-b) where f is either sin or cos. + # + # If there are more than two args, the pairs which "work" will have + # a gcd extractable and the remaining two terms will have the above + # structure -- all pairs must be checked to find the ones that + # work. + + if not rv.is_Add: + return rv + + args = list(ordered(rv.args)) + if len(args) != 2: + hit = False + for i in range(len(args)): + ai = args[i] + if ai is None: + continue + for j in range(i + 1, len(args)): + aj = args[j] + if aj is None: + continue + was = ai + aj + new = do(was) + if new != was: + args[i] = new # update in place + args[j] = None + hit = True + break # go to next i + if hit: + rv = Add(*[_f for _f in args if _f]) + if rv.is_Add: + rv = do(rv) + + return rv + + # two-arg Add + split = trig_split(*args, two=True) + if not split: + return rv + gcd, n1, n2, a, b, same = split + + # identify and get c1 to be cos then apply rule if possible + if same: # coscos, sinsin + gcd = n1*gcd + if n1 == n2: + return gcd*cos(a - b) + return gcd*cos(a + b) + else: #cossin, cossin + gcd = n1*gcd + if n1 == n2: + return gcd*sin(a + b) + return gcd*sin(b - a) + + rv = process_common_addends( + rv, do, lambda x: tuple(ordered(x.free_symbols))) + + # need to check for inducible pairs in ratio of sqrt(3):1 that + # appeared in different lists when sorting by coefficient + while rv.is_Add: + byrad = defaultdict(list) + for a in rv.args: + hit = 0 + if a.is_Mul: + for ai in a.args: + if ai.is_Pow and ai.exp is S.Half and \ + ai.base.is_Integer: + byrad[ai].append(a) + hit = 1 + break + if not hit: + byrad[S.One].append(a) + + # no need to check all pairs -- just check for the onees + # that have the right ratio + args = [] + for a in byrad: + for b in [_ROOT3()*a, _invROOT3()]: + if b in byrad: + for i in range(len(byrad[a])): + if byrad[a][i] is None: + continue + for j in range(len(byrad[b])): + if byrad[b][j] is None: + continue + was = Add(byrad[a][i] + byrad[b][j]) + new = do(was) + if new != was: + args.append(new) + byrad[a][i] = None + byrad[b][j] = None + break + if args: + rv = Add(*(args + [Add(*[_f for _f in v if _f]) + for v in byrad.values()])) + else: + rv = do(rv) # final pass to resolve any new inducible pairs + break + + return rv + + return bottom_up(rv, f) + + +def TR11(rv, base=None): + """Function of double angle to product. The ``base`` argument can be used + to indicate what is the un-doubled argument, e.g. if 3*pi/7 is the base + then cosine and sine functions with argument 6*pi/7 will be replaced. + + Examples + ======== + + >>> from sympy.simplify.fu import TR11 + >>> from sympy import cos, sin, pi + >>> from sympy.abc import x + >>> TR11(sin(2*x)) + 2*sin(x)*cos(x) + >>> TR11(cos(2*x)) + -sin(x)**2 + cos(x)**2 + >>> TR11(sin(4*x)) + 4*(-sin(x)**2 + cos(x)**2)*sin(x)*cos(x) + >>> TR11(sin(4*x/3)) + 4*(-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3) + + If the arguments are simply integers, no change is made + unless a base is provided: + + >>> TR11(cos(2)) + cos(2) + >>> TR11(cos(4), 2) + -sin(2)**2 + cos(2)**2 + + There is a subtle issue here in that autosimplification will convert + some higher angles to lower angles + + >>> cos(6*pi/7) + cos(3*pi/7) + -cos(pi/7) + cos(3*pi/7) + + The 6*pi/7 angle is now pi/7 but can be targeted with TR11 by supplying + the 3*pi/7 base: + + >>> TR11(_, 3*pi/7) + -sin(3*pi/7)**2 + cos(3*pi/7)**2 + cos(3*pi/7) + + """ + + def f(rv): + if rv.func not in (cos, sin): + return rv + + if base: + f = rv.func + t = f(base*2) + co = S.One + if t.is_Mul: + co, t = t.as_coeff_Mul() + if t.func not in (cos, sin): + return rv + if rv.args[0] == t.args[0]: + c = cos(base) + s = sin(base) + if f is cos: + return (c**2 - s**2)/co + else: + return 2*c*s/co + return rv + + elif not rv.args[0].is_Number: + # make a change if the leading coefficient's numerator is + # divisible by 2 + c, m = rv.args[0].as_coeff_Mul(rational=True) + if c.p % 2 == 0: + arg = c.p//2*m/c.q + c = TR11(cos(arg)) + s = TR11(sin(arg)) + if rv.func == sin: + rv = 2*s*c + else: + rv = c**2 - s**2 + return rv + + return bottom_up(rv, f) + + +def _TR11(rv): + """ + Helper for TR11 to find half-arguments for sin in factors of + num/den that appear in cos or sin factors in the den/num. + + Examples + ======== + + >>> from sympy.simplify.fu import TR11, _TR11 + >>> from sympy import cos, sin + >>> from sympy.abc import x + >>> TR11(sin(x/3)/(cos(x/6))) + sin(x/3)/cos(x/6) + >>> _TR11(sin(x/3)/(cos(x/6))) + 2*sin(x/6) + >>> TR11(sin(x/6)/(sin(x/3))) + sin(x/6)/sin(x/3) + >>> _TR11(sin(x/6)/(sin(x/3))) + 1/(2*cos(x/6)) + + """ + def f(rv): + if not isinstance(rv, Expr): + return rv + + def sincos_args(flat): + # find arguments of sin and cos that + # appears as bases in args of flat + # and have Integer exponents + args = defaultdict(set) + for fi in Mul.make_args(flat): + b, e = fi.as_base_exp() + if e.is_Integer and e > 0: + if b.func in (cos, sin): + args[type(b)].add(b.args[0]) + return args + num_args, den_args = map(sincos_args, rv.as_numer_denom()) + def handle_match(rv, num_args, den_args): + # for arg in sin args of num_args, look for arg/2 + # in den_args and pass this half-angle to TR11 + # for handling in rv + for narg in num_args[sin]: + half = narg/2 + if half in den_args[cos]: + func = cos + elif half in den_args[sin]: + func = sin + else: + continue + rv = TR11(rv, half) + den_args[func].remove(half) + return rv + # sin in num, sin or cos in den + rv = handle_match(rv, num_args, den_args) + # sin in den, sin or cos in num + rv = handle_match(rv, den_args, num_args) + return rv + + return bottom_up(rv, f) + + +def TR12(rv, first=True): + """Separate sums in ``tan``. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import tan + >>> from sympy.simplify.fu import TR12 + >>> TR12(tan(x + y)) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1) + """ + + def f(rv): + if not rv.func == tan: + return rv + + arg = rv.args[0] + if arg.is_Add: + if first: + args = list(ordered(arg.args)) + else: + args = list(arg.args) + a = args.pop() + b = Add._from_args(args) + if b.is_Add: + tb = TR12(tan(b), first=False) + else: + tb = tan(b) + return (tan(a) + tb)/(1 - tan(a)*tb) + return rv + + return bottom_up(rv, f) + + +def TR12i(rv): + """Combine tan arguments as + (tan(y) + tan(x))/(tan(x)*tan(y) - 1) -> -tan(x + y). + + Examples + ======== + + >>> from sympy.simplify.fu import TR12i + >>> from sympy import tan + >>> from sympy.abc import a, b, c + >>> ta, tb, tc = [tan(i) for i in (a, b, c)] + >>> TR12i((ta + tb)/(-ta*tb + 1)) + tan(a + b) + >>> TR12i((ta + tb)/(ta*tb - 1)) + -tan(a + b) + >>> TR12i((-ta - tb)/(ta*tb - 1)) + tan(a + b) + >>> eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1)) + >>> TR12i(eq.expand()) + -3*tan(a + b)*tan(a + c)/(2*(tan(a) + tan(b) - 1)) + """ + def f(rv): + if not (rv.is_Add or rv.is_Mul or rv.is_Pow): + return rv + + n, d = rv.as_numer_denom() + if not d.args or not n.args: + return rv + + dok = {} + + def ok(di): + m = as_f_sign_1(di) + if m: + g, f, s = m + if s is S.NegativeOne and f.is_Mul and len(f.args) == 2 and \ + all(isinstance(fi, tan) for fi in f.args): + return g, f + + d_args = list(Mul.make_args(d)) + for i, di in enumerate(d_args): + m = ok(di) + if m: + g, t = m + s = Add(*[_.args[0] for _ in t.args]) + dok[s] = S.One + d_args[i] = g + continue + if di.is_Add: + di = factor(di) + if di.is_Mul: + d_args.extend(di.args) + d_args[i] = S.One + elif di.is_Pow and (di.exp.is_integer or di.base.is_positive): + m = ok(di.base) + if m: + g, t = m + s = Add(*[_.args[0] for _ in t.args]) + dok[s] = di.exp + d_args[i] = g**di.exp + else: + di = factor(di) + if di.is_Mul: + d_args.extend(di.args) + d_args[i] = S.One + if not dok: + return rv + + def ok(ni): + if ni.is_Add and len(ni.args) == 2: + a, b = ni.args + if isinstance(a, tan) and isinstance(b, tan): + return a, b + n_args = list(Mul.make_args(factor_terms(n))) + hit = False + for i, ni in enumerate(n_args): + m = ok(ni) + if not m: + m = ok(-ni) + if m: + n_args[i] = S.NegativeOne + else: + if ni.is_Add: + ni = factor(ni) + if ni.is_Mul: + n_args.extend(ni.args) + n_args[i] = S.One + continue + elif ni.is_Pow and ( + ni.exp.is_integer or ni.base.is_positive): + m = ok(ni.base) + if m: + n_args[i] = S.One + else: + ni = factor(ni) + if ni.is_Mul: + n_args.extend(ni.args) + n_args[i] = S.One + continue + else: + continue + else: + n_args[i] = S.One + hit = True + s = Add(*[_.args[0] for _ in m]) + ed = dok[s] + newed = ed.extract_additively(S.One) + if newed is not None: + if newed: + dok[s] = newed + else: + dok.pop(s) + n_args[i] *= -tan(s) + + if hit: + rv = Mul(*n_args)/Mul(*d_args)/Mul(*[(Add(*[ + tan(a) for a in i.args]) - 1)**e for i, e in dok.items()]) + + return rv + + return bottom_up(rv, f) + + +def TR13(rv): + """Change products of ``tan`` or ``cot``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR13 + >>> from sympy import tan, cot + >>> TR13(tan(3)*tan(2)) + -tan(2)/tan(5) - tan(3)/tan(5) + 1 + >>> TR13(cot(3)*cot(2)) + cot(2)*cot(5) + 1 + cot(3)*cot(5) + """ + + def f(rv): + if not rv.is_Mul: + return rv + + # XXX handle products of powers? or let power-reducing handle it? + args = {tan: [], cot: [], None: []} + for a in Mul.make_args(rv): + if a.func in (tan, cot): + args[type(a)].append(a.args[0]) + else: + args[None].append(a) + t = args[tan] + c = args[cot] + if len(t) < 2 and len(c) < 2: + return rv + args = args[None] + while len(t) > 1: + t1 = t.pop() + t2 = t.pop() + args.append(1 - (tan(t1)/tan(t1 + t2) + tan(t2)/tan(t1 + t2))) + if t: + args.append(tan(t.pop())) + while len(c) > 1: + t1 = c.pop() + t2 = c.pop() + args.append(1 + cot(t1)*cot(t1 + t2) + cot(t2)*cot(t1 + t2)) + if c: + args.append(cot(c.pop())) + return Mul(*args) + + return bottom_up(rv, f) + + +def TRmorrie(rv): + """Returns cos(x)*cos(2*x)*...*cos(2**(k-1)*x) -> sin(2**k*x)/(2**k*sin(x)) + + Examples + ======== + + >>> from sympy.simplify.fu import TRmorrie, TR8, TR3 + >>> from sympy.abc import x + >>> from sympy import Mul, cos, pi + >>> TRmorrie(cos(x)*cos(2*x)) + sin(4*x)/(4*sin(x)) + >>> TRmorrie(7*Mul(*[cos(x) for x in range(10)])) + 7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3)) + + Sometimes autosimplification will cause a power to be + not recognized. e.g. in the following, cos(4*pi/7) automatically + simplifies to -cos(3*pi/7) so only 2 of the 3 terms are + recognized: + + >>> TRmorrie(cos(pi/7)*cos(2*pi/7)*cos(4*pi/7)) + -sin(3*pi/7)*cos(3*pi/7)/(4*sin(pi/7)) + + A touch by TR8 resolves the expression to a Rational + + >>> TR8(_) + -1/8 + + In this case, if eq is unsimplified, the answer is obtained + directly: + + >>> eq = cos(pi/9)*cos(2*pi/9)*cos(3*pi/9)*cos(4*pi/9) + >>> TRmorrie(eq) + 1/16 + + But if angles are made canonical with TR3 then the answer + is not simplified without further work: + + >>> TR3(eq) + sin(pi/18)*cos(pi/9)*cos(2*pi/9)/2 + >>> TRmorrie(_) + sin(pi/18)*sin(4*pi/9)/(8*sin(pi/9)) + >>> TR8(_) + cos(7*pi/18)/(16*sin(pi/9)) + >>> TR3(_) + 1/16 + + The original expression would have resolve to 1/16 directly with TR8, + however: + + >>> TR8(eq) + 1/16 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Morrie%27s_law + + """ + + def f(rv, first=True): + if not rv.is_Mul: + return rv + if first: + n, d = rv.as_numer_denom() + return f(n, 0)/f(d, 0) + + args = defaultdict(list) + coss = {} + other = [] + for c in rv.args: + b, e = c.as_base_exp() + if e.is_Integer and isinstance(b, cos): + co, a = b.args[0].as_coeff_Mul() + args[a].append(co) + coss[b] = e + else: + other.append(c) + + new = [] + for a in args: + c = args[a] + c.sort() + while c: + k = 0 + cc = ci = c[0] + while cc in c: + k += 1 + cc *= 2 + if k > 1: + newarg = sin(2**k*ci*a)/2**k/sin(ci*a) + # see how many times this can be taken + take = None + ccs = [] + for i in range(k): + cc /= 2 + key = cos(a*cc, evaluate=False) + ccs.append(cc) + take = min(coss[key], take or coss[key]) + # update exponent counts + for i in range(k): + cc = ccs.pop() + key = cos(a*cc, evaluate=False) + coss[key] -= take + if not coss[key]: + c.remove(cc) + new.append(newarg**take) + else: + b = cos(c.pop(0)*a) + other.append(b**coss[b]) + + if new: + rv = Mul(*(new + other + [ + cos(k*a, evaluate=False) for a in args for k in args[a]])) + + return rv + + return bottom_up(rv, f) + + +def TR14(rv, first=True): + """Convert factored powers of sin and cos identities into simpler + expressions. + + Examples + ======== + + >>> from sympy.simplify.fu import TR14 + >>> from sympy.abc import x, y + >>> from sympy import cos, sin + >>> TR14((cos(x) - 1)*(cos(x) + 1)) + -sin(x)**2 + >>> TR14((sin(x) - 1)*(sin(x) + 1)) + -cos(x)**2 + >>> p1 = (cos(x) + 1)*(cos(x) - 1) + >>> p2 = (cos(y) - 1)*2*(cos(y) + 1) + >>> p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1)) + >>> TR14(p1*p2*p3*(x - 1)) + -18*(x - 1)*sin(x)**2*sin(y)**4 + + """ + + def f(rv): + if not rv.is_Mul: + return rv + + if first: + # sort them by location in numerator and denominator + # so the code below can just deal with positive exponents + n, d = rv.as_numer_denom() + if d is not S.One: + newn = TR14(n, first=False) + newd = TR14(d, first=False) + if newn != n or newd != d: + rv = newn/newd + return rv + + other = [] + process = [] + for a in rv.args: + if a.is_Pow: + b, e = a.as_base_exp() + if not (e.is_integer or b.is_positive): + other.append(a) + continue + a = b + else: + e = S.One + m = as_f_sign_1(a) + if not m or m[1].func not in (cos, sin): + if e is S.One: + other.append(a) + else: + other.append(a**e) + continue + g, f, si = m + process.append((g, e.is_Number, e, f, si, a)) + + # sort them to get like terms next to each other + process = list(ordered(process)) + + # keep track of whether there was any change + nother = len(other) + + # access keys + keys = (g, t, e, f, si, a) = list(range(6)) + + while process: + A = process.pop(0) + if process: + B = process[0] + + if A[e].is_Number and B[e].is_Number: + # both exponents are numbers + if A[f] == B[f]: + if A[si] != B[si]: + B = process.pop(0) + take = min(A[e], B[e]) + + # reinsert any remainder + # the B will likely sort after A so check it first + if B[e] != take: + rem = [B[i] for i in keys] + rem[e] -= take + process.insert(0, rem) + elif A[e] != take: + rem = [A[i] for i in keys] + rem[e] -= take + process.insert(0, rem) + + if isinstance(A[f], cos): + t = sin + else: + t = cos + other.append((-A[g]*B[g]*t(A[f].args[0])**2)**take) + continue + + elif A[e] == B[e]: + # both exponents are equal symbols + if A[f] == B[f]: + if A[si] != B[si]: + B = process.pop(0) + take = A[e] + if isinstance(A[f], cos): + t = sin + else: + t = cos + other.append((-A[g]*B[g]*t(A[f].args[0])**2)**take) + continue + + # either we are done or neither condition above applied + other.append(A[a]**A[e]) + + if len(other) != nother: + rv = Mul(*other) + + return rv + + return bottom_up(rv, f) + + +def TR15(rv, max=4, pow=False): + """Convert sin(x)**-2 to 1 + cot(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR15 + >>> from sympy.abc import x + >>> from sympy import sin + >>> TR15(1 - 1/sin(x)**2) + -cot(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, sin)): + return rv + + e = rv.exp + if e % 2 == 1: + return TR15(rv.base**(e + 1))/rv.base + + ia = 1/rv + a = _TR56(ia, sin, cot, lambda x: 1 + x, max=max, pow=pow) + if a != ia: + rv = a + return rv + + return bottom_up(rv, f) + + +def TR16(rv, max=4, pow=False): + """Convert cos(x)**-2 to 1 + tan(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR16 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR16(1 - 1/cos(x)**2) + -tan(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, cos)): + return rv + + e = rv.exp + if e % 2 == 1: + return TR15(rv.base**(e + 1))/rv.base + + ia = 1/rv + a = _TR56(ia, cos, tan, lambda x: 1 + x, max=max, pow=pow) + if a != ia: + rv = a + return rv + + return bottom_up(rv, f) + + +def TR111(rv): + """Convert f(x)**-i to g(x)**i where either ``i`` is an integer + or the base is positive and f, g are: tan, cot; sin, csc; or cos, sec. + + Examples + ======== + + >>> from sympy.simplify.fu import TR111 + >>> from sympy.abc import x + >>> from sympy import tan + >>> TR111(1 - 1/tan(x)**2) + 1 - cot(x)**2 + + """ + + def f(rv): + if not ( + isinstance(rv, Pow) and + (rv.base.is_positive or rv.exp.is_integer and rv.exp.is_negative)): + return rv + + if isinstance(rv.base, tan): + return cot(rv.base.args[0])**-rv.exp + elif isinstance(rv.base, sin): + return csc(rv.base.args[0])**-rv.exp + elif isinstance(rv.base, cos): + return sec(rv.base.args[0])**-rv.exp + return rv + + return bottom_up(rv, f) + + +def TR22(rv, max=4, pow=False): + """Convert tan(x)**2 to sec(x)**2 - 1 and cot(x)**2 to csc(x)**2 - 1. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR22 + >>> from sympy.abc import x + >>> from sympy import tan, cot + >>> TR22(1 + tan(x)**2) + sec(x)**2 + >>> TR22(1 + cot(x)**2) + csc(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and rv.base.func in (cot, tan)): + return rv + + rv = _TR56(rv, tan, sec, lambda x: x - 1, max=max, pow=pow) + rv = _TR56(rv, cot, csc, lambda x: x - 1, max=max, pow=pow) + return rv + + return bottom_up(rv, f) + + +def TRpower(rv): + """Convert sin(x)**n and cos(x)**n with positive n to sums. + + Examples + ======== + + >>> from sympy.simplify.fu import TRpower + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> TRpower(sin(x)**6) + -15*cos(2*x)/32 + 3*cos(4*x)/16 - cos(6*x)/32 + 5/16 + >>> TRpower(sin(x)**3*cos(2*x)**4) + (3*sin(x)/4 - sin(3*x)/4)*(cos(4*x)/2 + cos(8*x)/8 + 3/8) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/List_of_trigonometric_identities#Power-reduction_formulae + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, (sin, cos))): + return rv + b, n = rv.as_base_exp() + x = b.args[0] + if n.is_Integer and n.is_positive: + if n.is_odd and isinstance(b, cos): + rv = 2**(1-n)*Add(*[binomial(n, k)*cos((n - 2*k)*x) + for k in range((n + 1)/2)]) + elif n.is_odd and isinstance(b, sin): + rv = 2**(1-n)*S.NegativeOne**((n-1)/2)*Add(*[binomial(n, k)* + S.NegativeOne**k*sin((n - 2*k)*x) for k in range((n + 1)/2)]) + elif n.is_even and isinstance(b, cos): + rv = 2**(1-n)*Add(*[binomial(n, k)*cos((n - 2*k)*x) + for k in range(n/2)]) + elif n.is_even and isinstance(b, sin): + rv = 2**(1-n)*S.NegativeOne**(n/2)*Add(*[binomial(n, k)* + S.NegativeOne**k*cos((n - 2*k)*x) for k in range(n/2)]) + if n.is_even: + rv += 2**(-n)*binomial(n, n/2) + return rv + + return bottom_up(rv, f) + + +def L(rv): + """Return count of trigonometric functions in expression. + + Examples + ======== + + >>> from sympy.simplify.fu import L + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> L(cos(x)+sin(x)) + 2 + """ + return S(rv.count(TrigonometricFunction)) + + +# ============== end of basic Fu-like tools ===================== + +if SYMPY_DEBUG: + (TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TR10, TR11, TR12, TR13, + TR2i, TRmorrie, TR14, TR15, TR16, TR12i, TR111, TR22 + )= list(map(debug, + (TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TR10, TR11, TR12, TR13, + TR2i, TRmorrie, TR14, TR15, TR16, TR12i, TR111, TR22))) + + +# tuples are chains -- (f, g) -> lambda x: g(f(x)) +# lists are choices -- [f, g] -> lambda x: min(f(x), g(x), key=objective) + +CTR1 = [(TR5, TR0), (TR6, TR0), identity] + +CTR2 = (TR11, [(TR5, TR0), (TR6, TR0), TR0]) + +CTR3 = [(TRmorrie, TR8, TR0), (TRmorrie, TR8, TR10i, TR0), identity] + +CTR4 = [(TR4, TR10i), identity] + +RL1 = (TR4, TR3, TR4, TR12, TR4, TR13, TR4, TR0) + + +# XXX it's a little unclear how this one is to be implemented +# see Fu paper of reference, page 7. What is the Union symbol referring to? +# The diagram shows all these as one chain of transformations, but the +# text refers to them being applied independently. Also, a break +# if L starts to increase has not been implemented. +RL2 = [ + (TR4, TR3, TR10, TR4, TR3, TR11), + (TR5, TR7, TR11, TR4), + (CTR3, CTR1, TR9, CTR2, TR4, TR9, TR9, CTR4), + identity, + ] + + +def fu(rv, measure=lambda x: (L(x), x.count_ops())): + """Attempt to simplify expression by using transformation rules given + in the algorithm by Fu et al. + + :func:`fu` will try to minimize the objective function ``measure``. + By default this first minimizes the number of trig terms and then minimizes + the number of total operations. + + Examples + ======== + + >>> from sympy.simplify.fu import fu + >>> from sympy import cos, sin, tan, pi, S, sqrt + >>> from sympy.abc import x, y, a, b + + >>> fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) + 3/2 + >>> fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) + 2*sqrt(2)*sin(x + pi/3) + + CTR1 example + + >>> eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2 + >>> fu(eq) + cos(x)**4 - 2*cos(y)**2 + 2 + + CTR2 example + + >>> fu(S.Half - cos(2*x)/2) + sin(x)**2 + + CTR3 example + + >>> fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) + sqrt(2)*sin(a + b + pi/4) + + CTR4 example + + >>> fu(sqrt(3)*cos(x)/2 + sin(x)/2) + sin(x + pi/3) + + Example 1 + + >>> fu(1-sin(2*x)**2/4-sin(y)**2-cos(x)**4) + -cos(x)**2 + cos(y)**2 + + Example 2 + + >>> fu(cos(4*pi/9)) + sin(pi/18) + >>> fu(cos(pi/9)*cos(2*pi/9)*cos(3*pi/9)*cos(4*pi/9)) + 1/16 + + Example 3 + + >>> fu(tan(7*pi/18)+tan(5*pi/18)-sqrt(3)*tan(5*pi/18)*tan(7*pi/18)) + -sqrt(3) + + Objective function example + + >>> fu(sin(x)/cos(x)) # default objective function + tan(x) + >>> fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) # maximize op count + sin(x)/cos(x) + + References + ========== + + .. [1] https://www.sciencedirect.com/science/article/pii/S0895717706001609 + """ + fRL1 = greedy(RL1, measure) + fRL2 = greedy(RL2, measure) + + was = rv + rv = sympify(rv) + if not isinstance(rv, Expr): + return rv.func(*[fu(a, measure=measure) for a in rv.args]) + rv = TR1(rv) + if rv.has(tan, cot): + rv1 = fRL1(rv) + if (measure(rv1) < measure(rv)): + rv = rv1 + if rv.has(tan, cot): + rv = TR2(rv) + if rv.has(sin, cos): + rv1 = fRL2(rv) + rv2 = TR8(TRmorrie(rv1)) + rv = min([was, rv, rv1, rv2], key=measure) + return min(TR2i(rv), rv, key=measure) + + +def process_common_addends(rv, do, key2=None, key1=True): + """Apply ``do`` to addends of ``rv`` that (if ``key1=True``) share at least + a common absolute value of their coefficient and the value of ``key2`` when + applied to the argument. If ``key1`` is False ``key2`` must be supplied and + will be the only key applied. + """ + + # collect by absolute value of coefficient and key2 + absc = defaultdict(list) + if key1: + for a in rv.args: + c, a = a.as_coeff_Mul() + if c < 0: + c = -c + a = -a # put the sign on `a` + absc[(c, key2(a) if key2 else 1)].append(a) + elif key2: + for a in rv.args: + absc[(S.One, key2(a))].append(a) + else: + raise ValueError('must have at least one key') + + args = [] + hit = False + for k in absc: + v = absc[k] + c, _ = k + if len(v) > 1: + e = Add(*v, evaluate=False) + new = do(e) + if new != e: + e = new + hit = True + args.append(c*e) + else: + args.append(c*v[0]) + if hit: + rv = Add(*args) + + return rv + + +fufuncs = ''' + TR0 TR1 TR2 TR3 TR4 TR5 TR6 TR7 TR8 TR9 TR10 TR10i TR11 + TR12 TR13 L TR2i TRmorrie TR12i + TR14 TR15 TR16 TR111 TR22'''.split() +FU = dict(list(zip(fufuncs, list(map(locals().get, fufuncs))))) + + +@cacheit +def _ROOT2(): + return sqrt(2) + + +@cacheit +def _ROOT3(): + return sqrt(3) + + +@cacheit +def _invROOT3(): + return 1/sqrt(3) + + +def trig_split(a, b, two=False): + """Return the gcd, s1, s2, a1, a2, bool where + + If two is False (default) then:: + a + b = gcd*(s1*f(a1) + s2*f(a2)) where f = cos if bool else sin + else: + if bool, a + b was +/- cos(a1)*cos(a2) +/- sin(a1)*sin(a2) and equals + n1*gcd*cos(a - b) if n1 == n2 else + n1*gcd*cos(a + b) + else a + b was +/- cos(a1)*sin(a2) +/- sin(a1)*cos(a2) and equals + n1*gcd*sin(a + b) if n1 = n2 else + n1*gcd*sin(b - a) + + Examples + ======== + + >>> from sympy.simplify.fu import trig_split + >>> from sympy.abc import x, y, z + >>> from sympy import cos, sin, sqrt + + >>> trig_split(cos(x), cos(y)) + (1, 1, 1, x, y, True) + >>> trig_split(2*cos(x), -2*cos(y)) + (2, 1, -1, x, y, True) + >>> trig_split(cos(x)*sin(y), cos(y)*sin(y)) + (sin(y), 1, 1, x, y, True) + + >>> trig_split(cos(x), -sqrt(3)*sin(x), two=True) + (2, 1, -1, x, pi/6, False) + >>> trig_split(cos(x), sin(x), two=True) + (sqrt(2), 1, 1, x, pi/4, False) + >>> trig_split(cos(x), -sin(x), two=True) + (sqrt(2), 1, -1, x, pi/4, False) + >>> trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) + (2*sqrt(2), 1, -1, x, pi/6, False) + >>> trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) + (-2*sqrt(2), 1, 1, x, pi/3, False) + >>> trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) + (sqrt(6)/3, 1, 1, x, pi/6, False) + >>> trig_split(-sqrt(6)*cos(x)*sin(y), -sqrt(2)*sin(x)*sin(y), two=True) + (-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False) + + >>> trig_split(cos(x), sin(x)) + >>> trig_split(cos(x), sin(z)) + >>> trig_split(2*cos(x), -sin(x)) + >>> trig_split(cos(x), -sqrt(3)*sin(x)) + >>> trig_split(cos(x)*cos(y), sin(x)*sin(z)) + >>> trig_split(cos(x)*cos(y), sin(x)*sin(y)) + >>> trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) + """ + a, b = [Factors(i) for i in (a, b)] + ua, ub = a.normal(b) + gcd = a.gcd(b).as_expr() + n1 = n2 = 1 + if S.NegativeOne in ua.factors: + ua = ua.quo(S.NegativeOne) + n1 = -n1 + elif S.NegativeOne in ub.factors: + ub = ub.quo(S.NegativeOne) + n2 = -n2 + a, b = [i.as_expr() for i in (ua, ub)] + + def pow_cos_sin(a, two): + """Return ``a`` as a tuple (r, c, s) such that + ``a = (r or 1)*(c or 1)*(s or 1)``. + + Three arguments are returned (radical, c-factor, s-factor) as + long as the conditions set by ``two`` are met; otherwise None is + returned. If ``two`` is True there will be one or two non-None + values in the tuple: c and s or c and r or s and r or s or c with c + being a cosine function (if possible) else a sine, and s being a sine + function (if possible) else oosine. If ``two`` is False then there + will only be a c or s term in the tuple. + + ``two`` also require that either two cos and/or sin be present (with + the condition that if the functions are the same the arguments are + different or vice versa) or that a single cosine or a single sine + be present with an optional radical. + + If the above conditions dictated by ``two`` are not met then None + is returned. + """ + c = s = None + co = S.One + if a.is_Mul: + co, a = a.as_coeff_Mul() + if len(a.args) > 2 or not two: + return None + if a.is_Mul: + args = list(a.args) + else: + args = [a] + a = args.pop(0) + if isinstance(a, cos): + c = a + elif isinstance(a, sin): + s = a + elif a.is_Pow and a.exp is S.Half: # autoeval doesn't allow -1/2 + co *= a + else: + return None + if args: + b = args[0] + if isinstance(b, cos): + if c: + s = b + else: + c = b + elif isinstance(b, sin): + if s: + c = b + else: + s = b + elif b.is_Pow and b.exp is S.Half: + co *= b + else: + return None + return co if co is not S.One else None, c, s + elif isinstance(a, cos): + c = a + elif isinstance(a, sin): + s = a + if c is None and s is None: + return + co = co if co is not S.One else None + return co, c, s + + # get the parts + m = pow_cos_sin(a, two) + if m is None: + return + coa, ca, sa = m + m = pow_cos_sin(b, two) + if m is None: + return + cob, cb, sb = m + + # check them + if (not ca) and cb or ca and isinstance(ca, sin): + coa, ca, sa, cob, cb, sb = cob, cb, sb, coa, ca, sa + n1, n2 = n2, n1 + if not two: # need cos(x) and cos(y) or sin(x) and sin(y) + c = ca or sa + s = cb or sb + if not isinstance(c, s.func): + return None + return gcd, n1, n2, c.args[0], s.args[0], isinstance(c, cos) + else: + if not coa and not cob: + if (ca and cb and sa and sb): + if isinstance(ca, sa.func) is not isinstance(cb, sb.func): + return + args = {j.args for j in (ca, sa)} + if not all(i.args in args for i in (cb, sb)): + return + return gcd, n1, n2, ca.args[0], sa.args[0], isinstance(ca, sa.func) + if ca and sa or cb and sb or \ + two and (ca is None and sa is None or cb is None and sb is None): + return + c = ca or sa + s = cb or sb + if c.args != s.args: + return + if not coa: + coa = S.One + if not cob: + cob = S.One + if coa is cob: + gcd *= _ROOT2() + return gcd, n1, n2, c.args[0], pi/4, False + elif coa/cob == _ROOT3(): + gcd *= 2*cob + return gcd, n1, n2, c.args[0], pi/3, False + elif coa/cob == _invROOT3(): + gcd *= 2*coa + return gcd, n1, n2, c.args[0], pi/6, False + + +def as_f_sign_1(e): + """If ``e`` is a sum that can be written as ``g*(a + s)`` where + ``s`` is ``+/-1``, return ``g``, ``a``, and ``s`` where ``a`` does + not have a leading negative coefficient. + + Examples + ======== + + >>> from sympy.simplify.fu import as_f_sign_1 + >>> from sympy.abc import x + >>> as_f_sign_1(x + 1) + (1, x, 1) + >>> as_f_sign_1(x - 1) + (1, x, -1) + >>> as_f_sign_1(-x + 1) + (-1, x, -1) + >>> as_f_sign_1(-x - 1) + (-1, x, 1) + >>> as_f_sign_1(2*x + 2) + (2, x, 1) + """ + if not e.is_Add or len(e.args) != 2: + return + # exact match + a, b = e.args + if a in (S.NegativeOne, S.One): + g = S.One + if b.is_Mul and b.args[0].is_Number and b.args[0] < 0: + a, b = -a, -b + g = -g + return g, b, a + # gcd match + a, b = [Factors(i) for i in e.args] + ua, ub = a.normal(b) + gcd = a.gcd(b).as_expr() + if S.NegativeOne in ua.factors: + ua = ua.quo(S.NegativeOne) + n1 = -1 + n2 = 1 + elif S.NegativeOne in ub.factors: + ub = ub.quo(S.NegativeOne) + n1 = 1 + n2 = -1 + else: + n1 = n2 = 1 + a, b = [i.as_expr() for i in (ua, ub)] + if a is S.One: + a, b = b, a + n1, n2 = n2, n1 + if n1 == -1: + gcd = -gcd + n2 = -n2 + + if b is S.One: + return gcd, a, n2 + + +def _osborne(e, d): + """Replace all hyperbolic functions with trig functions using + the Osborne rule. + + Notes + ===== + + ``d`` is a dummy variable to prevent automatic evaluation + of trigonometric/hyperbolic functions. + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + + def f(rv): + if not isinstance(rv, HyperbolicFunction): + return rv + a = rv.args[0] + a = a*d if not a.is_Add else Add._from_args([i*d for i in a.args]) + if isinstance(rv, sinh): + return I*sin(a) + elif isinstance(rv, cosh): + return cos(a) + elif isinstance(rv, tanh): + return I*tan(a) + elif isinstance(rv, coth): + return cot(a)/I + elif isinstance(rv, sech): + return sec(a) + elif isinstance(rv, csch): + return csc(a)/I + else: + raise NotImplementedError('unhandled %s' % rv.func) + + return bottom_up(e, f) + + +def _osbornei(e, d): + """Replace all trig functions with hyperbolic functions using + the Osborne rule. + + Notes + ===== + + ``d`` is a dummy variable to prevent automatic evaluation + of trigonometric/hyperbolic functions. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + + def f(rv): + if not isinstance(rv, TrigonometricFunction): + return rv + const, x = rv.args[0].as_independent(d, as_Add=True) + a = x.xreplace({d: S.One}) + const*I + if isinstance(rv, sin): + return sinh(a)/I + elif isinstance(rv, cos): + return cosh(a) + elif isinstance(rv, tan): + return tanh(a)/I + elif isinstance(rv, cot): + return coth(a)*I + elif isinstance(rv, sec): + return sech(a) + elif isinstance(rv, csc): + return csch(a)*I + else: + raise NotImplementedError('unhandled %s' % rv.func) + + return bottom_up(e, f) + + +def hyper_as_trig(rv): + """Return an expression containing hyperbolic functions in terms + of trigonometric functions. Any trigonometric functions initially + present are replaced with Dummy symbols and the function to undo + the masking and the conversion back to hyperbolics is also returned. It + should always be true that:: + + t, f = hyper_as_trig(expr) + expr == f(t) + + Examples + ======== + + >>> from sympy.simplify.fu import hyper_as_trig, fu + >>> from sympy.abc import x + >>> from sympy import cosh, sinh + >>> eq = sinh(x)**2 + cosh(x)**2 + >>> t, f = hyper_as_trig(eq) + >>> f(fu(t)) + cosh(2*x) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + from sympy.simplify.simplify import signsimp + from sympy.simplify.radsimp import collect + + # mask off trig functions + trigs = rv.atoms(TrigonometricFunction) + reps = [(t, Dummy()) for t in trigs] + masked = rv.xreplace(dict(reps)) + + # get inversion substitutions in place + reps = [(v, k) for k, v in reps] + + d = Dummy() + + return _osborne(masked, d), lambda x: collect(signsimp( + _osbornei(x, d).xreplace(dict(reps))), S.ImaginaryUnit) + + +def sincos_to_sum(expr): + """Convert products and powers of sin and cos to sums. + + Explanation + =========== + + Applied power reduction TRpower first, then expands products, and + converts products to sums with TR8. + + Examples + ======== + + >>> from sympy.simplify.fu import sincos_to_sum + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> sincos_to_sum(16*sin(x)**3*cos(2*x)**2) + 7*sin(x) - 5*sin(3*x) + 3*sin(5*x) - sin(7*x) + """ + + if not expr.has(cos, sin): + return expr + else: + return TR8(expand_mul(TRpower(expr))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/gammasimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/gammasimp.py new file mode 100644 index 0000000000000000000000000000000000000000..aec20c56eb60efb8e1aadfb5bff3d1ba1ab51869 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/gammasimp.py @@ -0,0 +1,493 @@ +from sympy.core import Function, S, Mul, Pow, Add +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.function import expand_func +from sympy.core.symbol import Dummy +from sympy.functions import gamma, sqrt, sin +from sympy.polys import factor, cancel +from sympy.utilities.iterables import sift, uniq + + +def gammasimp(expr): + r""" + Simplify expressions with gamma functions. + + Explanation + =========== + + This function takes as input an expression containing gamma + functions or functions that can be rewritten in terms of gamma + functions and tries to minimize the number of those functions and + reduce the size of their arguments. + + The algorithm works by rewriting all gamma functions as expressions + involving rising factorials (Pochhammer symbols) and applies + recurrence relations and other transformations applicable to rising + factorials, to reduce their arguments, possibly letting the resulting + rising factorial to cancel. Rising factorials with the second argument + being an integer are expanded into polynomial forms and finally all + other rising factorial are rewritten in terms of gamma functions. + + Then the following two steps are performed. + + 1. Reduce the number of gammas by applying the reflection theorem + gamma(x)*gamma(1-x) == pi/sin(pi*x). + 2. Reduce the number of gammas by applying the multiplication theorem + gamma(x)*gamma(x+1/n)*...*gamma(x+(n-1)/n) == C*gamma(n*x). + + It then reduces the number of prefactors by absorbing them into gammas + where possible and expands gammas with rational argument. + + All transformation rules can be found (or were derived from) here: + + .. [1] https://functions.wolfram.com/GammaBetaErf/Pochhammer/17/01/02/ + .. [2] https://functions.wolfram.com/GammaBetaErf/Pochhammer/27/01/0005/ + + Examples + ======== + + >>> from sympy.simplify import gammasimp + >>> from sympy import gamma, Symbol + >>> from sympy.abc import x + >>> n = Symbol('n', integer = True) + + >>> gammasimp(gamma(x)/gamma(x - 3)) + (x - 3)*(x - 2)*(x - 1) + >>> gammasimp(gamma(n + 3)) + gamma(n + 3) + + """ + + expr = expr.rewrite(gamma) + + # compute_ST will be looking for Functions and we don't want + # it looking for non-gamma functions: issue 22606 + # so we mask free, non-gamma functions + f = expr.atoms(Function) + # take out gammas + gammas = {i for i in f if isinstance(i, gamma)} + if not gammas: + return expr # avoid side effects like factoring + f -= gammas + # keep only those without bound symbols + f = f & expr.as_dummy().atoms(Function) + if f: + dum, fun, simp = zip(*[ + (Dummy(), fi, fi.func(*[ + _gammasimp(a, as_comb=False) for a in fi.args])) + for fi in ordered(f)]) + d = expr.xreplace(dict(zip(fun, dum))) + return _gammasimp(d, as_comb=False).xreplace(dict(zip(dum, simp))) + + return _gammasimp(expr, as_comb=False) + + +def _gammasimp(expr, as_comb): + """ + Helper function for gammasimp and combsimp. + + Explanation + =========== + + Simplifies expressions written in terms of gamma function. If + as_comb is True, it tries to preserve integer arguments. See + docstring of gammasimp for more information. This was part of + combsimp() in combsimp.py. + """ + expr = expr.replace(gamma, + lambda n: _rf(1, (n - 1).expand())) + + if as_comb: + expr = expr.replace(_rf, + lambda a, b: gamma(b + 1)) + else: + expr = expr.replace(_rf, + lambda a, b: gamma(a + b)/gamma(a)) + + def rule_gamma(expr, level=0): + """ Simplify products of gamma functions further. """ + + if expr.is_Atom: + return expr + + def gamma_rat(x): + # helper to simplify ratios of gammas + was = x.count(gamma) + xx = x.replace(gamma, lambda n: _rf(1, (n - 1).expand() + ).replace(_rf, lambda a, b: gamma(a + b)/gamma(a))) + if xx.count(gamma) < was: + x = xx + return x + + def gamma_factor(x): + # return True if there is a gamma factor in shallow args + if isinstance(x, gamma): + return True + if x.is_Add or x.is_Mul: + return any(gamma_factor(xi) for xi in x.args) + if x.is_Pow and (x.exp.is_integer or x.base.is_positive): + return gamma_factor(x.base) + return False + + # recursion step + if level == 0: + expr = expr.func(*[rule_gamma(x, level + 1) for x in expr.args]) + level += 1 + + if not expr.is_Mul: + return expr + + # non-commutative step + if level == 1: + args, nc = expr.args_cnc() + if not args: + return expr + if nc: + return rule_gamma(Mul._from_args(args), level + 1)*Mul._from_args(nc) + level += 1 + + # pure gamma handling, not factor absorption + if level == 2: + T, F = sift(expr.args, gamma_factor, binary=True) + gamma_ind = Mul(*F) + d = Mul(*T) + + nd, dd = d.as_numer_denom() + for ipass in range(2): + args = list(ordered(Mul.make_args(nd))) + for i, ni in enumerate(args): + if ni.is_Add: + ni, dd = Add(*[ + rule_gamma(gamma_rat(a/dd), level + 1) for a in ni.args] + ).as_numer_denom() + args[i] = ni + if not dd.has(gamma): + break + nd = Mul(*args) + if ipass == 0 and not gamma_factor(nd): + break + nd, dd = dd, nd # now process in reversed order + expr = gamma_ind*nd/dd + if not (expr.is_Mul and (gamma_factor(dd) or gamma_factor(nd))): + return expr + level += 1 + + # iteration until constant + if level == 3: + while True: + was = expr + expr = rule_gamma(expr, 4) + if expr == was: + return expr + + numer_gammas = [] + denom_gammas = [] + numer_others = [] + denom_others = [] + def explicate(p): + if p is S.One: + return None, [] + b, e = p.as_base_exp() + if e.is_Integer: + if isinstance(b, gamma): + return True, [b.args[0]]*e + else: + return False, [b]*e + else: + return False, [p] + + newargs = list(ordered(expr.args)) + while newargs: + n, d = newargs.pop().as_numer_denom() + isg, l = explicate(n) + if isg: + numer_gammas.extend(l) + elif isg is False: + numer_others.extend(l) + isg, l = explicate(d) + if isg: + denom_gammas.extend(l) + elif isg is False: + denom_others.extend(l) + + # =========== level 2 work: pure gamma manipulation ========= + + if not as_comb: + # Try to reduce the number of gamma factors by applying the + # reflection formula gamma(x)*gamma(1-x) = pi/sin(pi*x) + for gammas, numer, denom in [( + numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + new = [] + while gammas: + g1 = gammas.pop() + if g1.is_integer: + new.append(g1) + continue + for i, g2 in enumerate(gammas): + n = g1 + g2 - 1 + if not n.is_Integer: + continue + numer.append(S.Pi) + denom.append(sin(S.Pi*g1)) + gammas.pop(i) + if n > 0: + numer.extend(1 - g1 + k for k in range(n)) + elif n < 0: + denom.extend(-g1 - k for k in range(-n)) + break + else: + new.append(g1) + # /!\ updating IN PLACE + gammas[:] = new + + # Try to reduce the number of gammas by using the duplication + # theorem to cancel an upper and lower: gamma(2*s)/gamma(s) = + # 2**(2*s + 1)/(4*sqrt(pi))*gamma(s + 1/2). Although this could + # be done with higher argument ratios like gamma(3*x)/gamma(x), + # this would not reduce the number of gammas as in this case. + for ng, dg, no, do in [(numer_gammas, denom_gammas, numer_others, + denom_others), + (denom_gammas, numer_gammas, denom_others, + numer_others)]: + + while True: + for x in ng: + for y in dg: + n = x - 2*y + if n.is_Integer: + break + else: + continue + break + else: + break + ng.remove(x) + dg.remove(y) + if n > 0: + no.extend(2*y + k for k in range(n)) + elif n < 0: + do.extend(2*y - 1 - k for k in range(-n)) + ng.append(y + S.Half) + no.append(2**(2*y - 1)) + do.append(sqrt(S.Pi)) + + # Try to reduce the number of gamma factors by applying the + # multiplication theorem (used when n gammas with args differing + # by 1/n mod 1 are encountered). + # + # run of 2 with args differing by 1/2 + # + # >>> gammasimp(gamma(x)*gamma(x+S.Half)) + # 2*sqrt(2)*2**(-2*x - 1/2)*sqrt(pi)*gamma(2*x) + # + # run of 3 args differing by 1/3 (mod 1) + # + # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(2)/3)) + # 6*3**(-3*x - 1/2)*pi*gamma(3*x) + # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(5)/3)) + # 2*3**(-3*x - 1/2)*pi*(3*x + 2)*gamma(3*x) + # + def _run(coeffs): + # find runs in coeffs such that the difference in terms (mod 1) + # of t1, t2, ..., tn is 1/n + u = list(uniq(coeffs)) + for i in range(len(u)): + dj = ([((u[j] - u[i]) % 1, j) for j in range(i + 1, len(u))]) + for one, j in dj: + if one.p == 1 and one.q != 1: + n = one.q + got = [i] + get = list(range(1, n)) + for d, j in dj: + m = n*d + if m.is_Integer and m in get: + get.remove(m) + got.append(j) + if not get: + break + else: + continue + for i, j in enumerate(got): + c = u[j] + coeffs.remove(c) + got[i] = c + return one.q, got[0], got[1:] + + def _mult_thm(gammas, numer, denom): + # pull off and analyze the leading coefficient from each gamma arg + # looking for runs in those Rationals + + # expr -> coeff + resid -> rats[resid] = coeff + rats = {} + for g in gammas: + c, resid = g.as_coeff_Add() + rats.setdefault(resid, []).append(c) + + # look for runs in Rationals for each resid + keys = sorted(rats, key=default_sort_key) + for resid in keys: + coeffs = sorted(rats[resid]) + new = [] + while True: + run = _run(coeffs) + if run is None: + break + + # process the sequence that was found: + # 1) convert all the gamma functions to have the right + # argument (could be off by an integer) + # 2) append the factors corresponding to the theorem + # 3) append the new gamma function + + n, ui, other = run + + # (1) + for u in other: + con = resid + u - 1 + for k in range(int(u - ui)): + numer.append(con - k) + + con = n*(resid + ui) # for (2) and (3) + + # (2) + numer.append((2*S.Pi)**(S(n - 1)/2)* + n**(S.Half - con)) + # (3) + new.append(con) + + # restore resid to coeffs + rats[resid] = [resid + c for c in coeffs] + new + + # rebuild the gamma arguments + g = [] + for resid in keys: + g += rats[resid] + # /!\ updating IN PLACE + gammas[:] = g + + for l, numer, denom in [(numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + _mult_thm(l, numer, denom) + + # =========== level >= 2 work: factor absorption ========= + + if level >= 2: + # Try to absorb factors into the gammas: x*gamma(x) -> gamma(x + 1) + # and gamma(x)/(x - 1) -> gamma(x - 1) + # This code (in particular repeated calls to find_fuzzy) can be very + # slow. + def find_fuzzy(l, x): + if not l: + return + S1, T1 = compute_ST(x) + for y in l: + S2, T2 = inv[y] + if T1 != T2 or (not S1.intersection(S2) and + (S1 != set() or S2 != set())): + continue + # XXX we want some simplification (e.g. cancel or + # simplify) but no matter what it's slow. + a = len(cancel(x/y).free_symbols) + b = len(x.free_symbols) + c = len(y.free_symbols) + # TODO is there a better heuristic? + if a == 0 and (b > 0 or c > 0): + return y + + # We thus try to avoid expensive calls by building the following + # "invariants": For every factor or gamma function argument + # - the set of free symbols S + # - the set of functional components T + # We will only try to absorb if T1==T2 and (S1 intersect S2 != emptyset + # or S1 == S2 == emptyset) + inv = {} + + def compute_ST(expr): + if expr in inv: + return inv[expr] + return (expr.free_symbols, expr.atoms(Function).union( + {e.exp for e in expr.atoms(Pow)})) + + def update_ST(expr): + inv[expr] = compute_ST(expr) + for expr in numer_gammas + denom_gammas + numer_others + denom_others: + update_ST(expr) + + for gammas, numer, denom in [( + numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + new = [] + while gammas: + g = gammas.pop() + cont = True + while cont: + cont = False + y = find_fuzzy(numer, g) + if y is not None: + numer.remove(y) + if y != g: + numer.append(y/g) + update_ST(y/g) + g += 1 + cont = True + y = find_fuzzy(denom, g - 1) + if y is not None: + denom.remove(y) + if y != g - 1: + numer.append((g - 1)/y) + update_ST((g - 1)/y) + g -= 1 + cont = True + new.append(g) + # /!\ updating IN PLACE + gammas[:] = new + + # =========== rebuild expr ================================== + + return Mul(*[gamma(g) for g in numer_gammas]) \ + / Mul(*[gamma(g) for g in denom_gammas]) \ + * Mul(*numer_others) / Mul(*denom_others) + + was = factor(expr) + # (for some reason we cannot use Basic.replace in this case) + expr = rule_gamma(was) + if expr != was: + expr = factor(expr) + + expr = expr.replace(gamma, + lambda n: expand_func(gamma(n)) if n.is_Rational else gamma(n)) + + return expr + + +class _rf(Function): + @classmethod + def eval(cls, a, b): + if b.is_Integer: + if not b: + return S.One + + n = int(b) + + if n > 0: + return Mul(*[a + i for i in range(n)]) + elif n < 0: + return 1/Mul(*[a - i for i in range(1, -n + 1)]) + else: + if b.is_Add: + c, _b = b.as_coeff_Add() + + if c.is_Integer: + if c > 0: + return _rf(a, _b)*_rf(a + _b, c) + elif c < 0: + return _rf(a, _b)/_rf(a + _b + c, -c) + + if a.is_Add: + c, _a = a.as_coeff_Add() + + if c.is_Integer: + if c > 0: + return _rf(_a, b)*_rf(_a + b, c)/_rf(_a, c) + elif c < 0: + return _rf(_a, b)*_rf(_a + c, -c)/_rf(_a + b + c, -c) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/hyperexpand.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/hyperexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..c070aa2e44b92794107b3e33df897813a54307b9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/hyperexpand.py @@ -0,0 +1,2494 @@ +""" +Expand Hypergeometric (and Meijer G) functions into named +special functions. + +The algorithm for doing this uses a collection of lookup tables of +hypergeometric functions, and various of their properties, to expand +many hypergeometric functions in terms of special functions. + +It is based on the following paper: + Kelly B. Roach. Meijer G Function Representations. + In: Proceedings of the 1997 International Symposium on Symbolic and + Algebraic Computation, pages 205-211, New York, 1997. ACM. + +It is described in great(er) detail in the Sphinx documentation. +""" +# SUMMARY OF EXTENSIONS FOR MEIJER G FUNCTIONS +# +# o z**rho G(ap, bq; z) = G(ap + rho, bq + rho; z) +# +# o denote z*d/dz by D +# +# o It is helpful to keep in mind that ap and bq play essentially symmetric +# roles: G(1/z) has slightly altered parameters, with ap and bq interchanged. +# +# o There are four shift operators: +# A_J = b_J - D, J = 1, ..., n +# B_J = 1 - a_j + D, J = 1, ..., m +# C_J = -b_J + D, J = m+1, ..., q +# D_J = a_J - 1 - D, J = n+1, ..., p +# +# A_J, C_J increment b_J +# B_J, D_J decrement a_J +# +# o The corresponding four inverse-shift operators are defined if there +# is no cancellation. Thus e.g. an index a_J (upper or lower) can be +# incremented if a_J != b_i for i = 1, ..., q. +# +# o Order reduction: if b_j - a_i is a non-negative integer, where +# j <= m and i > n, the corresponding quotient of gamma functions reduces +# to a polynomial. Hence the G function can be expressed using a G-function +# of lower order. +# Similarly if j > m and i <= n. +# +# Secondly, there are paired index theorems [Adamchik, The evaluation of +# integrals of Bessel functions via G-function identities]. Suppose there +# are three parameters a, b, c, where a is an a_i, i <= n, b is a b_j, +# j <= m and c is a denominator parameter (i.e. a_i, i > n or b_j, j > m). +# Suppose further all three differ by integers. +# Then the order can be reduced. +# TODO work this out in detail. +# +# o An index quadruple is called suitable if its order cannot be reduced. +# If there exists a sequence of shift operators transforming one index +# quadruple into another, we say one is reachable from the other. +# +# o Deciding if one index quadruple is reachable from another is tricky. For +# this reason, we use hand-built routines to match and instantiate formulas. +# +from collections import defaultdict +from itertools import product +from functools import reduce +from math import prod + +from sympy import SYMPY_DEBUG +from sympy.core import (S, Dummy, symbols, sympify, Tuple, expand, I, pi, Mul, + EulerGamma, oo, zoo, expand_func, Add, nan, Expr, Rational) +from sympy.core.mod import Mod +from sympy.core.sorting import default_sort_key +from sympy.functions import (exp, sqrt, root, log, lowergamma, cos, + besseli, gamma, uppergamma, expint, erf, sin, besselj, Ei, Ci, Si, Shi, + sinh, cosh, Chi, fresnels, fresnelc, polar_lift, exp_polar, floor, ceiling, + rf, factorial, lerchphi, Piecewise, re, elliptic_k, elliptic_e) +from sympy.functions.elementary.complexes import polarify, unpolarify +from sympy.functions.special.hyper import (hyper, HyperRep_atanh, + HyperRep_power1, HyperRep_power2, HyperRep_log1, HyperRep_asin1, + HyperRep_asin2, HyperRep_sqrts1, HyperRep_sqrts2, HyperRep_log2, + HyperRep_cosasin, HyperRep_sinasin, meijerg) +from sympy.matrices import Matrix, eye, zeros +from sympy.polys import apart, poly, Poly +from sympy.series import residue +from sympy.simplify.powsimp import powdenest +from sympy.utilities.iterables import sift + +# function to define "buckets" +def _mod1(x): + # TODO see if this can work as Mod(x, 1); this will require + # different handling of the "buckets" since these need to + # be sorted and that fails when there is a mixture of + # integers and expressions with parameters. With the current + # Mod behavior, Mod(k, 1) == Mod(1, 1) == 0 if k is an integer. + # Although the sorting can be done with Basic.compare, this may + # still require different handling of the sorted buckets. + if x.is_Number: + return Mod(x, 1) + c, x = x.as_coeff_Add() + return Mod(c, 1) + x + + +# leave add formulae at the top for easy reference +def add_formulae(formulae): + """ Create our knowledge base. """ + a, b, c, z = symbols('a b c, z', cls=Dummy) + + def add(ap, bq, res): + func = Hyper_Function(ap, bq) + formulae.append(Formula(func, z, res, (a, b, c))) + + def addb(ap, bq, B, C, M): + func = Hyper_Function(ap, bq) + formulae.append(Formula(func, z, None, (a, b, c), B, C, M)) + + # Luke, Y. L. (1969), The Special Functions and Their Approximations, + # Volume 1, section 6.2 + + # 0F0 + add((), (), exp(z)) + + # 1F0 + add((a, ), (), HyperRep_power1(-a, z)) + + # 2F1 + addb((a, a - S.Half), (2*a, ), + Matrix([HyperRep_power2(a, z), + HyperRep_power2(a + S.Half, z)/2]), + Matrix([[1, 0]]), + Matrix([[(a - S.Half)*z/(1 - z), (S.Half - a)*z/(1 - z)], + [a/(1 - z), a*(z - 2)/(1 - z)]])) + addb((1, 1), (2, ), + Matrix([HyperRep_log1(z), 1]), Matrix([[-1/z, 0]]), + Matrix([[0, z/(z - 1)], [0, 0]])) + addb((S.Half, 1), (S('3/2'), ), + Matrix([HyperRep_atanh(z), 1]), + Matrix([[1, 0]]), + Matrix([[Rational(-1, 2), 1/(1 - z)/2], [0, 0]])) + addb((S.Half, S.Half), (S('3/2'), ), + Matrix([HyperRep_asin1(z), HyperRep_power1(Rational(-1, 2), z)]), + Matrix([[1, 0]]), + Matrix([[Rational(-1, 2), S.Half], [0, z/(1 - z)/2]])) + addb((a, S.Half + a), (S.Half, ), + Matrix([HyperRep_sqrts1(-a, z), -HyperRep_sqrts2(-a - S.Half, z)]), + Matrix([[1, 0]]), + Matrix([[0, -a], + [z*(-2*a - 1)/2/(1 - z), S.Half - z*(-2*a - 1)/(1 - z)]])) + + # A. P. Prudnikov, Yu. A. Brychkov and O. I. Marichev (1990). + # Integrals and Series: More Special Functions, Vol. 3,. + # Gordon and Breach Science Publisher + addb([a, -a], [S.Half], + Matrix([HyperRep_cosasin(a, z), HyperRep_sinasin(a, z)]), + Matrix([[1, 0]]), + Matrix([[0, -a], [a*z/(1 - z), 1/(1 - z)/2]])) + addb([1, 1], [3*S.Half], + Matrix([HyperRep_asin2(z), 1]), Matrix([[1, 0]]), + Matrix([[(z - S.Half)/(1 - z), 1/(1 - z)/2], [0, 0]])) + + # Complete elliptic integrals K(z) and E(z), both a 2F1 function + addb([S.Half, S.Half], [S.One], + Matrix([elliptic_k(z), elliptic_e(z)]), + Matrix([[2/pi, 0]]), + Matrix([[Rational(-1, 2), -1/(2*z-2)], + [Rational(-1, 2), S.Half]])) + addb([Rational(-1, 2), S.Half], [S.One], + Matrix([elliptic_k(z), elliptic_e(z)]), + Matrix([[0, 2/pi]]), + Matrix([[Rational(-1, 2), -1/(2*z-2)], + [Rational(-1, 2), S.Half]])) + + # 3F2 + addb([Rational(-1, 2), 1, 1], [S.Half, 2], + Matrix([z*HyperRep_atanh(z), HyperRep_log1(z), 1]), + Matrix([[Rational(-2, 3), -S.One/(3*z), Rational(2, 3)]]), + Matrix([[S.Half, 0, z/(1 - z)/2], + [0, 0, z/(z - 1)], + [0, 0, 0]])) + # actually the formula for 3/2 is much nicer ... + addb([Rational(-1, 2), 1, 1], [2, 2], + Matrix([HyperRep_power1(S.Half, z), HyperRep_log2(z), 1]), + Matrix([[Rational(4, 9) - 16/(9*z), 4/(3*z), 16/(9*z)]]), + Matrix([[z/2/(z - 1), 0, 0], [1/(2*(z - 1)), 0, S.Half], [0, 0, 0]])) + + # 1F1 + addb([1], [b], Matrix([z**(1 - b) * exp(z) * lowergamma(b - 1, z), 1]), + Matrix([[b - 1, 0]]), Matrix([[1 - b + z, 1], [0, 0]])) + addb([a], [2*a], + Matrix([z**(S.Half - a)*exp(z/2)*besseli(a - S.Half, z/2) + * gamma(a + S.Half)/4**(S.Half - a), + z**(S.Half - a)*exp(z/2)*besseli(a + S.Half, z/2) + * gamma(a + S.Half)/4**(S.Half - a)]), + Matrix([[1, 0]]), + Matrix([[z/2, z/2], [z/2, (z/2 - 2*a)]])) + mz = polar_lift(-1)*z + addb([a], [a + 1], + Matrix([mz**(-a)*a*lowergamma(a, mz), a*exp(z)]), + Matrix([[1, 0]]), + Matrix([[-a, 1], [0, z]])) + # This one is redundant. + add([Rational(-1, 2)], [S.Half], exp(z) - sqrt(pi*z)*(-I)*erf(I*sqrt(z))) + + # Added to get nice results for Laplace transform of Fresnel functions + # https://functions.wolfram.com/07.22.03.6437.01 + # Basic rule + #add([1], [Rational(3, 4), Rational(5, 4)], + # sqrt(pi) * (cos(2*sqrt(polar_lift(-1)*z))*fresnelc(2*root(polar_lift(-1)*z,4)/sqrt(pi)) + + # sin(2*sqrt(polar_lift(-1)*z))*fresnels(2*root(polar_lift(-1)*z,4)/sqrt(pi))) + # / (2*root(polar_lift(-1)*z,4))) + # Manually tuned rule + addb([1], [Rational(3, 4), Rational(5, 4)], + Matrix([ sqrt(pi)*(I*sinh(2*sqrt(z))*fresnels(2*root(z, 4)*exp(I*pi/4)/sqrt(pi)) + + cosh(2*sqrt(z))*fresnelc(2*root(z, 4)*exp(I*pi/4)/sqrt(pi))) + * exp(-I*pi/4)/(2*root(z, 4)), + sqrt(pi)*root(z, 4)*(sinh(2*sqrt(z))*fresnelc(2*root(z, 4)*exp(I*pi/4)/sqrt(pi)) + + I*cosh(2*sqrt(z))*fresnels(2*root(z, 4)*exp(I*pi/4)/sqrt(pi))) + *exp(-I*pi/4)/2, + 1 ]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 4), 1, Rational(1, 4)], + [ z, Rational(1, 4), 0], + [ 0, 0, 0]])) + + # 2F2 + addb([S.Half, a], [Rational(3, 2), a + 1], + Matrix([a/(2*a - 1)*(-I)*sqrt(pi/z)*erf(I*sqrt(z)), + a/(2*a - 1)*(polar_lift(-1)*z)**(-a)* + lowergamma(a, polar_lift(-1)*z), + a/(2*a - 1)*exp(z)]), + Matrix([[1, -1, 0]]), + Matrix([[Rational(-1, 2), 0, 1], [0, -a, 1], [0, 0, z]])) + # We make a "basis" of four functions instead of three, and give EulerGamma + # an extra slot (it could just be a coefficient to 1). The advantage is + # that this way Polys will not see multivariate polynomials (it treats + # EulerGamma as an indeterminate), which is *way* faster. + addb([1, 1], [2, 2], + Matrix([Ei(z) - log(z), exp(z), 1, EulerGamma]), + Matrix([[1/z, 0, 0, -1/z]]), + Matrix([[0, 1, -1, 0], [0, z, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])) + + # 0F1 + add((), (S.Half, ), cosh(2*sqrt(z))) + addb([], [b], + Matrix([gamma(b)*z**((1 - b)/2)*besseli(b - 1, 2*sqrt(z)), + gamma(b)*z**(1 - b/2)*besseli(b, 2*sqrt(z))]), + Matrix([[1, 0]]), Matrix([[0, 1], [z, (1 - b)]])) + + # 0F3 + x = 4*z**Rational(1, 4) + + def fp(a, z): + return besseli(a, x) + besselj(a, x) + + def fm(a, z): + return besseli(a, x) - besselj(a, x) + + # TODO branching + addb([], [S.Half, a, a + S.Half], + Matrix([fp(2*a - 1, z), fm(2*a, z)*z**Rational(1, 4), + fm(2*a - 1, z)*sqrt(z), fp(2*a, z)*z**Rational(3, 4)]) + * 2**(-2*a)*gamma(2*a)*z**((1 - 2*a)/4), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, 1, 0, 0], + [0, S.Half - a, 1, 0], + [0, 0, S.Half, 1], + [z, 0, 0, 1 - a]])) + x = 2*(4*z)**Rational(1, 4)*exp_polar(I*pi/4) + addb([], [a, a + S.Half, 2*a], + (2*sqrt(polar_lift(-1)*z))**(1 - 2*a)*gamma(2*a)**2 * + Matrix([besselj(2*a - 1, x)*besseli(2*a - 1, x), + x*(besseli(2*a, x)*besselj(2*a - 1, x) + - besseli(2*a - 1, x)*besselj(2*a, x)), + x**2*besseli(2*a, x)*besselj(2*a, x), + x**3*(besseli(2*a, x)*besselj(2*a - 1, x) + + besseli(2*a - 1, x)*besselj(2*a, x))]), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, Rational(1, 4), 0, 0], + [0, (1 - 2*a)/2, Rational(-1, 2), 0], + [0, 0, 1 - 2*a, Rational(1, 4)], + [-32*z, 0, 0, 1 - a]])) + + # 1F2 + addb([a], [a - S.Half, 2*a], + Matrix([z**(S.Half - a)*besseli(a - S.Half, sqrt(z))**2, + z**(1 - a)*besseli(a - S.Half, sqrt(z)) + *besseli(a - Rational(3, 2), sqrt(z)), + z**(Rational(3, 2) - a)*besseli(a - Rational(3, 2), sqrt(z))**2]), + Matrix([[-gamma(a + S.Half)**2/4**(S.Half - a), + 2*gamma(a - S.Half)*gamma(a + S.Half)/4**(1 - a), + 0]]), + Matrix([[1 - 2*a, 1, 0], [z/2, S.Half - a, S.Half], [0, z, 0]])) + addb([S.Half], [b, 2 - b], + pi*(1 - b)/sin(pi*b)* + Matrix([besseli(1 - b, sqrt(z))*besseli(b - 1, sqrt(z)), + sqrt(z)*(besseli(-b, sqrt(z))*besseli(b - 1, sqrt(z)) + + besseli(1 - b, sqrt(z))*besseli(b, sqrt(z))), + besseli(-b, sqrt(z))*besseli(b, sqrt(z))]), + Matrix([[1, 0, 0]]), + Matrix([[b - 1, S.Half, 0], + [z, 0, z], + [0, S.Half, -b]])) + addb([S.Half], [Rational(3, 2), Rational(3, 2)], + Matrix([Shi(2*sqrt(z))/2/sqrt(z), sinh(2*sqrt(z))/2/sqrt(z), + cosh(2*sqrt(z))]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 2), S.Half, 0], [0, Rational(-1, 2), S.Half], [0, 2*z, 0]])) + + # FresnelS + # Basic rule + #add([Rational(3, 4)], [Rational(3, 2),Rational(7, 4)], 6*fresnels( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) / ( pi * (exp(pi*I/4)*root(z,4)*2/sqrt(pi))**3 ) ) + # Manually tuned rule + addb([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)], + Matrix( + [ fresnels( + exp( + pi*I/4)*root( + z, 4)*2/sqrt( + pi) ) / ( + pi * (exp(pi*I/4)*root(z, 4)*2/sqrt(pi))**3 ), + sinh(2*sqrt(z))/sqrt(z), + cosh(2*sqrt(z)) ]), + Matrix([[6, 0, 0]]), + Matrix([[Rational(-3, 4), Rational(1, 16), 0], + [ 0, Rational(-1, 2), 1], + [ 0, z, 0]])) + + # FresnelC + # Basic rule + #add([Rational(1, 4)], [S.Half,Rational(5, 4)], fresnelc( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) / ( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) ) + # Manually tuned rule + addb([Rational(1, 4)], [S.Half, Rational(5, 4)], + Matrix( + [ sqrt( + pi)*exp( + -I*pi/4)*fresnelc( + 2*root(z, 4)*exp(I*pi/4)/sqrt(pi))/(2*root(z, 4)), + cosh(2*sqrt(z)), + sinh(2*sqrt(z))*sqrt(z) ]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 4), Rational(1, 4), 0 ], + [ 0, 0, 1 ], + [ 0, z, S.Half]])) + + # 2F3 + # XXX with this five-parameter formula is pretty slow with the current + # Formula.find_instantiations (creates 2!*3!*3**(2+3) ~ 3000 + # instantiations ... But it's not too bad. + addb([a, a + S.Half], [2*a, b, 2*a - b + 1], + gamma(b)*gamma(2*a - b + 1) * (sqrt(z)/2)**(1 - 2*a) * + Matrix([besseli(b - 1, sqrt(z))*besseli(2*a - b, sqrt(z)), + sqrt(z)*besseli(b, sqrt(z))*besseli(2*a - b, sqrt(z)), + sqrt(z)*besseli(b - 1, sqrt(z))*besseli(2*a - b + 1, sqrt(z)), + besseli(b, sqrt(z))*besseli(2*a - b + 1, sqrt(z))]), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, S.Half, S.Half, 0], + [z/2, 1 - b, 0, z/2], + [z/2, 0, b - 2*a, z/2], + [0, S.Half, S.Half, -2*a]])) + # (C/f above comment about eulergamma in the basis). + addb([1, 1], [2, 2, Rational(3, 2)], + Matrix([Chi(2*sqrt(z)) - log(2*sqrt(z)), + cosh(2*sqrt(z)), sqrt(z)*sinh(2*sqrt(z)), 1, EulerGamma]), + Matrix([[1/z, 0, 0, 0, -1/z]]), + Matrix([[0, S.Half, 0, Rational(-1, 2), 0], + [0, 0, 1, 0, 0], + [0, z, S.Half, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]])) + + # 3F3 + # This is rule: https://functions.wolfram.com/07.31.03.0134.01 + # Initial reason to add it was a nice solution for + # integrate(erf(a*z)/z**2, z) and same for erfc and erfi. + # Basic rule + # add([1, 1, a], [2, 2, a+1], (a/(z*(a-1)**2)) * + # (1 - (-z)**(1-a) * (gamma(a) - uppergamma(a,-z)) + # - (a-1) * (EulerGamma + uppergamma(0,-z) + log(-z)) + # - exp(z))) + # Manually tuned rule + addb([1, 1, a], [2, 2, a+1], + Matrix([a*(log(-z) + expint(1, -z) + EulerGamma)/(z*(a**2 - 2*a + 1)), + a*(-z)**(-a)*(gamma(a) - uppergamma(a, -z))/(a - 1)**2, + a*exp(z)/(a**2 - 2*a + 1), + a/(z*(a**2 - 2*a + 1))]), + Matrix([[1-a, 1, -1/z, 1]]), + Matrix([[-1,0,-1/z,1], + [0,-a,1,0], + [0,0,z,0], + [0,0,0,-1]])) + + +def add_meijerg_formulae(formulae): + a, b, c, z = list(map(Dummy, 'abcz')) + rho = Dummy('rho') + + def add(an, ap, bm, bq, B, C, M, matcher): + formulae.append(MeijerFormula(an, ap, bm, bq, z, [a, b, c, rho], + B, C, M, matcher)) + + def detect_uppergamma(func): + x = func.an[0] + y, z = func.bm + swapped = False + if not _mod1((x - y).simplify()): + swapped = True + (y, z) = (z, y) + if _mod1((x - z).simplify()) or x - z > 0: + return None + l = [y, x] + if swapped: + l = [x, y] + return {rho: y, a: x - y}, G_Function([x], [], l, []) + + add([a + rho], [], [rho, a + rho], [], + Matrix([gamma(1 - a)*z**rho*exp(z)*uppergamma(a, z), + gamma(1 - a)*z**(a + rho)]), + Matrix([[1, 0]]), + Matrix([[rho + z, -1], [0, a + rho]]), + detect_uppergamma) + + def detect_3113(func): + """https://functions.wolfram.com/07.34.03.0984.01""" + x = func.an[0] + u, v, w = func.bm + if _mod1((u - v).simplify()) == 0: + if _mod1((v - w).simplify()) == 0: + return + sig = (S.Half, S.Half, S.Zero) + x1, x2, y = u, v, w + else: + if _mod1((x - u).simplify()) == 0: + sig = (S.Half, S.Zero, S.Half) + x1, y, x2 = u, v, w + else: + sig = (S.Zero, S.Half, S.Half) + y, x1, x2 = u, v, w + + if (_mod1((x - x1).simplify()) != 0 or + _mod1((x - x2).simplify()) != 0 or + _mod1((x - y).simplify()) != S.Half or + x - x1 > 0 or x - x2 > 0): + return + + return {a: x}, G_Function([x], [], [x - S.Half + t for t in sig], []) + + s = sin(2*sqrt(z)) + c_ = cos(2*sqrt(z)) + S_ = Si(2*sqrt(z)) - pi/2 + C = Ci(2*sqrt(z)) + add([a], [], [a, a, a - S.Half], [], + Matrix([sqrt(pi)*z**(a - S.Half)*(c_*S_ - s*C), + sqrt(pi)*z**a*(s*S_ + c_*C), + sqrt(pi)*z**a]), + Matrix([[-2, 0, 0]]), + Matrix([[a - S.Half, -1, 0], [z, a, S.Half], [0, 0, a]]), + detect_3113) + + +def make_simp(z): + """ Create a function that simplifies rational functions in ``z``. """ + + def simp(expr): + """ Efficiently simplify the rational function ``expr``. """ + numer, denom = expr.as_numer_denom() + numer = numer.expand() + # denom = denom.expand() # is this needed? + c, numer, denom = poly(numer, z).cancel(poly(denom, z)) + return c * numer.as_expr() / denom.as_expr() + + return simp + + +def debug(*args): + if SYMPY_DEBUG: + for a in args: + print(a, end="") + print() + + +class Hyper_Function(Expr): + """ A generalized hypergeometric function. """ + + def __new__(cls, ap, bq): + obj = super().__new__(cls) + obj.ap = Tuple(*list(map(expand, ap))) + obj.bq = Tuple(*list(map(expand, bq))) + return obj + + @property + def args(self): + return (self.ap, self.bq) + + @property + def sizes(self): + return (len(self.ap), len(self.bq)) + + @property + def gamma(self): + """ + Number of upper parameters that are negative integers + + This is a transformation invariant. + """ + return sum(bool(x.is_integer and x.is_negative) for x in self.ap) + + def _hashable_content(self): + return super()._hashable_content() + (self.ap, + self.bq) + + def __call__(self, arg): + return hyper(self.ap, self.bq, arg) + + def build_invariants(self): + """ + Compute the invariant vector. + + Explanation + =========== + + The invariant vector is: + (gamma, ((s1, n1), ..., (sk, nk)), ((t1, m1), ..., (tr, mr))) + where gamma is the number of integer a < 0, + s1 < ... < sk + nl is the number of parameters a_i congruent to sl mod 1 + t1 < ... < tr + ml is the number of parameters b_i congruent to tl mod 1 + + If the index pair contains parameters, then this is not truly an + invariant, since the parameters cannot be sorted uniquely mod1. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import Hyper_Function + >>> from sympy import S + >>> ap = (S.Half, S.One/3, S(-1)/2, -2) + >>> bq = (1, 2) + + Here gamma = 1, + k = 3, s1 = 0, s2 = 1/3, s3 = 1/2 + n1 = 1, n2 = 1, n2 = 2 + r = 1, t1 = 0 + m1 = 2: + + >>> Hyper_Function(ap, bq).build_invariants() + (1, ((0, 1), (1/3, 1), (1/2, 2)), ((0, 2),)) + """ + abuckets, bbuckets = sift(self.ap, _mod1), sift(self.bq, _mod1) + + def tr(bucket): + bucket = list(bucket.items()) + if not any(isinstance(x[0], Mod) for x in bucket): + bucket.sort(key=lambda x: default_sort_key(x[0])) + bucket = tuple([(mod, len(values)) for mod, values in bucket if + values]) + return bucket + + return (self.gamma, tr(abuckets), tr(bbuckets)) + + def difficulty(self, func): + """ Estimate how many steps it takes to reach ``func`` from self. + Return -1 if impossible. """ + if self.gamma != func.gamma: + return -1 + oabuckets, obbuckets, abuckets, bbuckets = [sift(params, _mod1) for + params in (self.ap, self.bq, func.ap, func.bq)] + + diff = 0 + for bucket, obucket in [(abuckets, oabuckets), (bbuckets, obbuckets)]: + for mod in set(list(bucket.keys()) + list(obucket.keys())): + if (mod not in bucket) or (mod not in obucket) \ + or len(bucket[mod]) != len(obucket[mod]): + return -1 + l1 = list(bucket[mod]) + l2 = list(obucket[mod]) + l1.sort() + l2.sort() + for i, j in zip(l1, l2): + diff += abs(i - j) + + return diff + + def _is_suitable_origin(self): + """ + Decide if ``self`` is a suitable origin. + + Explanation + =========== + + A function is a suitable origin iff: + * none of the ai equals bj + n, with n a non-negative integer + * none of the ai is zero + * none of the bj is a non-positive integer + + Note that this gives meaningful results only when none of the indices + are symbolic. + + """ + for a in self.ap: + for b in self.bq: + if (a - b).is_integer and (a - b).is_negative is False: + return False + for a in self.ap: + if a == 0: + return False + for b in self.bq: + if b.is_integer and b.is_nonpositive: + return False + return True + + +class G_Function(Expr): + """ A Meijer G-function. """ + + def __new__(cls, an, ap, bm, bq): + obj = super().__new__(cls) + obj.an = Tuple(*list(map(expand, an))) + obj.ap = Tuple(*list(map(expand, ap))) + obj.bm = Tuple(*list(map(expand, bm))) + obj.bq = Tuple(*list(map(expand, bq))) + return obj + + @property + def args(self): + return (self.an, self.ap, self.bm, self.bq) + + def _hashable_content(self): + return super()._hashable_content() + self.args + + def __call__(self, z): + return meijerg(self.an, self.ap, self.bm, self.bq, z) + + def compute_buckets(self): + """ + Compute buckets for the fours sets of parameters. + + Explanation + =========== + + We guarantee that any two equal Mod objects returned are actually the + same, and that the buckets are sorted by real part (an and bq + descendending, bm and ap ascending). + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import G_Function + >>> from sympy.abc import y + >>> from sympy import S + + >>> a, b = [1, 3, 2, S(3)/2], [1 + y, y, 2, y + 3] + >>> G_Function(a, b, [2], [y]).compute_buckets() + ({0: [3, 2, 1], 1/2: [3/2]}, + {0: [2], y: [y, y + 1, y + 3]}, {0: [2]}, {y: [y]}) + + """ + dicts = pan, pap, pbm, pbq = [defaultdict(list) for i in range(4)] + for dic, lis in zip(dicts, (self.an, self.ap, self.bm, self.bq)): + for x in lis: + dic[_mod1(x)].append(x) + + for dic, flip in zip(dicts, (True, False, False, True)): + for m, items in dic.items(): + x0 = items[0] + items.sort(key=lambda x: x - x0, reverse=flip) + dic[m] = items + + return tuple([dict(w) for w in dicts]) + + @property + def signature(self): + return (len(self.an), len(self.ap), len(self.bm), len(self.bq)) + + +# Dummy variable. +_x = Dummy('x') + +class Formula: + """ + This class represents hypergeometric formulae. + + Explanation + =========== + + Its data members are: + - z, the argument + - closed_form, the closed form expression + - symbols, the free symbols (parameters) in the formula + - func, the function + - B, C, M (see _compute_basis) + + Examples + ======== + + >>> from sympy.abc import a, b, z + >>> from sympy.simplify.hyperexpand import Formula, Hyper_Function + >>> func = Hyper_Function((a/2, a/3 + b, (1+a)/2), (a, b, (a+b)/7)) + >>> f = Formula(func, z, None, [a, b]) + + """ + + def _compute_basis(self, closed_form): + """ + Compute a set of functions B=(f1, ..., fn), a nxn matrix M + and a 1xn matrix C such that: + closed_form = C B + z d/dz B = M B. + """ + afactors = [_x + a for a in self.func.ap] + bfactors = [_x + b - 1 for b in self.func.bq] + expr = _x*Mul(*bfactors) - self.z*Mul(*afactors) + poly = Poly(expr, _x) + + n = poly.degree() - 1 + b = [closed_form] + for _ in range(n): + b.append(self.z*b[-1].diff(self.z)) + + self.B = Matrix(b) + self.C = Matrix([[1] + [0]*n]) + + m = eye(n) + m = m.col_insert(0, zeros(n, 1)) + l = poly.all_coeffs()[1:] + l.reverse() + self.M = m.row_insert(n, -Matrix([l])/poly.all_coeffs()[0]) + + def __init__(self, func, z, res, symbols, B=None, C=None, M=None): + z = sympify(z) + res = sympify(res) + symbols = [x for x in sympify(symbols) if func.has(x)] + + self.z = z + self.symbols = symbols + self.B = B + self.C = C + self.M = M + self.func = func + + # TODO with symbolic parameters, it could be advantageous + # (for prettier answers) to compute a basis only *after* + # instantiation + if res is not None: + self._compute_basis(res) + + @property + def closed_form(self): + return reduce(lambda s,m: s+m[0]*m[1], zip(self.C, self.B), S.Zero) + + def find_instantiations(self, func): + """ + Find substitutions of the free symbols that match ``func``. + + Return the substitution dictionaries as a list. Note that the returned + instantiations need not actually match, or be valid! + + """ + from sympy.solvers import solve + ap = func.ap + bq = func.bq + if len(ap) != len(self.func.ap) or len(bq) != len(self.func.bq): + raise TypeError('Cannot instantiate other number of parameters') + symbol_values = [] + for a in self.symbols: + if a in self.func.ap.args: + symbol_values.append(ap) + elif a in self.func.bq.args: + symbol_values.append(bq) + else: + raise ValueError("At least one of the parameters of the " + "formula must be equal to %s" % (a,)) + base_repl = [dict(list(zip(self.symbols, values))) + for values in product(*symbol_values)] + abuckets, bbuckets = [sift(params, _mod1) for params in [ap, bq]] + a_inv, b_inv = [{a: len(vals) for a, vals in bucket.items()} + for bucket in [abuckets, bbuckets]] + critical_values = [[0] for _ in self.symbols] + result = [] + _n = Dummy() + for repl in base_repl: + symb_a, symb_b = [sift(params, lambda x: _mod1(x.xreplace(repl))) + for params in [self.func.ap, self.func.bq]] + for bucket, obucket in [(abuckets, symb_a), (bbuckets, symb_b)]: + for mod in set(list(bucket.keys()) + list(obucket.keys())): + if (mod not in bucket) or (mod not in obucket) \ + or len(bucket[mod]) != len(obucket[mod]): + break + for a, vals in zip(self.symbols, critical_values): + if repl[a].free_symbols: + continue + exprs = [expr for expr in obucket[mod] if expr.has(a)] + repl0 = repl.copy() + repl0[a] += _n + for expr in exprs: + for target in bucket[mod]: + n0, = solve(expr.xreplace(repl0) - target, _n) + if n0.free_symbols: + raise ValueError("Value should not be true") + vals.append(n0) + else: + values = [] + for a, vals in zip(self.symbols, critical_values): + a0 = repl[a] + min_ = floor(min(vals)) + max_ = ceiling(max(vals)) + values.append([a0 + n for n in range(min_, max_ + 1)]) + result.extend(dict(list(zip(self.symbols, l))) for l in product(*values)) + return result + + + + +class FormulaCollection: + """ A collection of formulae to use as origins. """ + + def __init__(self): + """ Doing this globally at module init time is a pain ... """ + self.symbolic_formulae = {} + self.concrete_formulae = {} + self.formulae = [] + + add_formulae(self.formulae) + + # Now process the formulae into a helpful form. + # These dicts are indexed by (p, q). + + for f in self.formulae: + sizes = f.func.sizes + if len(f.symbols) > 0: + self.symbolic_formulae.setdefault(sizes, []).append(f) + else: + inv = f.func.build_invariants() + self.concrete_formulae.setdefault(sizes, {})[inv] = f + + def lookup_origin(self, func): + """ + Given the suitable target ``func``, try to find an origin in our + knowledge base. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (FormulaCollection, + ... Hyper_Function) + >>> f = FormulaCollection() + >>> f.lookup_origin(Hyper_Function((), ())).closed_form + exp(_z) + >>> f.lookup_origin(Hyper_Function([1], ())).closed_form + HyperRep_power1(-1, _z) + + >>> from sympy import S + >>> i = Hyper_Function([S('1/4'), S('3/4 + 4')], [S.Half]) + >>> f.lookup_origin(i).closed_form + HyperRep_sqrts1(-1/4, _z) + """ + inv = func.build_invariants() + sizes = func.sizes + if sizes in self.concrete_formulae and \ + inv in self.concrete_formulae[sizes]: + return self.concrete_formulae[sizes][inv] + + # We don't have a concrete formula. Try to instantiate. + if sizes not in self.symbolic_formulae: + return None # Too bad... + + possible = [] + for f in self.symbolic_formulae[sizes]: + repls = f.find_instantiations(func) + for repl in repls: + func2 = f.func.xreplace(repl) + if not func2._is_suitable_origin(): + continue + diff = func2.difficulty(func) + if diff == -1: + continue + possible.append((diff, repl, f, func2)) + + # find the nearest origin + possible.sort(key=lambda x: x[0]) + for _, repl, f, func2 in possible: + f2 = Formula(func2, f.z, None, [], f.B.subs(repl), + f.C.subs(repl), f.M.subs(repl)) + if not any(e.has(S.NaN, oo, -oo, zoo) for e in [f2.B, f2.M, f2.C]): + return f2 + + return None + + +class MeijerFormula: + """ + This class represents a Meijer G-function formula. + + Its data members are: + - z, the argument + - symbols, the free symbols (parameters) in the formula + - func, the function + - B, C, M (c/f ordinary Formula) + """ + + def __init__(self, an, ap, bm, bq, z, symbols, B, C, M, matcher): + an, ap, bm, bq = [Tuple(*list(map(expand, w))) for w in [an, ap, bm, bq]] + self.func = G_Function(an, ap, bm, bq) + self.z = z + self.symbols = symbols + self._matcher = matcher + self.B = B + self.C = C + self.M = M + + @property + def closed_form(self): + return reduce(lambda s,m: s+m[0]*m[1], zip(self.C, self.B), S.Zero) + + def try_instantiate(self, func): + """ + Try to instantiate the current formula to (almost) match func. + This uses the _matcher passed on init. + """ + if func.signature != self.func.signature: + return None + res = self._matcher(func) + if res is not None: + subs, newfunc = res + return MeijerFormula(newfunc.an, newfunc.ap, newfunc.bm, newfunc.bq, + self.z, [], + self.B.subs(subs), self.C.subs(subs), + self.M.subs(subs), None) + + +class MeijerFormulaCollection: + """ + This class holds a collection of meijer g formulae. + """ + + def __init__(self): + formulae = [] + add_meijerg_formulae(formulae) + self.formulae = defaultdict(list) + for formula in formulae: + self.formulae[formula.func.signature].append(formula) + self.formulae = dict(self.formulae) + + def lookup_origin(self, func): + """ Try to find a formula that matches func. """ + if func.signature not in self.formulae: + return None + for formula in self.formulae[func.signature]: + res = formula.try_instantiate(func) + if res is not None: + return res + + +class Operator: + """ + Base class for operators to be applied to our functions. + + Explanation + =========== + + These operators are differential operators. They are by convention + expressed in the variable D = z*d/dz (although this base class does + not actually care). + Note that when the operator is applied to an object, we typically do + *not* blindly differentiate but instead use a different representation + of the z*d/dz operator (see make_derivative_operator). + + To subclass from this, define a __init__ method that initializes a + self._poly variable. This variable stores a polynomial. By convention + the generator is z*d/dz, and acts to the right of all coefficients. + + Thus this poly + x**2 + 2*z*x + 1 + represents the differential operator + (z*d/dz)**2 + 2*z**2*d/dz. + + This class is used only in the implementation of the hypergeometric + function expansion algorithm. + """ + + def apply(self, obj, op): + """ + Apply ``self`` to the object ``obj``, where the generator is ``op``. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import Operator + >>> from sympy.polys.polytools import Poly + >>> from sympy.abc import x, y, z + >>> op = Operator() + >>> op._poly = Poly(x**2 + z*x + y, x) + >>> op.apply(z**7, lambda f: f.diff(z)) + y*z**7 + 7*z**7 + 42*z**5 + """ + coeffs = self._poly.all_coeffs() + coeffs.reverse() + diffs = [obj] + for c in coeffs[1:]: + diffs.append(op(diffs[-1])) + r = coeffs[0]*diffs[0] + for c, d in zip(coeffs[1:], diffs[1:]): + r += c*d + return r + + +class MultOperator(Operator): + """ Simply multiply by a "constant" """ + + def __init__(self, p): + self._poly = Poly(p, _x) + + +class ShiftA(Operator): + """ Increment an upper index. """ + + def __init__(self, ai): + ai = sympify(ai) + if ai == 0: + raise ValueError('Cannot increment zero upper index.') + self._poly = Poly(_x/ai + 1, _x) + + def __str__(self): + return '' % (1/self._poly.all_coeffs()[0]) + + +class ShiftB(Operator): + """ Decrement a lower index. """ + + def __init__(self, bi): + bi = sympify(bi) + if bi == 1: + raise ValueError('Cannot decrement unit lower index.') + self._poly = Poly(_x/(bi - 1) + 1, _x) + + def __str__(self): + return '' % (1/self._poly.all_coeffs()[0] + 1) + + +class UnShiftA(Operator): + """ Decrement an upper index. """ + + def __init__(self, ap, bq, i, z): + """ Note: i counts from zero! """ + ap, bq, i = list(map(sympify, [ap, bq, i])) + + self._ap = ap + self._bq = bq + self._i = i + + ap = list(ap) + bq = list(bq) + ai = ap.pop(i) - 1 + + if ai == 0: + raise ValueError('Cannot decrement unit upper index.') + + m = Poly(z*ai, _x) + for a in ap: + m *= Poly(_x + a, _x) + + A = Dummy('A') + n = D = Poly(ai*A - ai, A) + for b in bq: + n *= D + (b - 1).as_poly(A) + + b0 = -n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement upper index: ' + 'cancels with lower') + + n = Poly(Poly(n.all_coeffs()[:-1], A).as_expr().subs(A, _x/ai + 1), _x) + + self._poly = Poly((n - m)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._ap, self._bq) + + +class UnShiftB(Operator): + """ Increment a lower index. """ + + def __init__(self, ap, bq, i, z): + """ Note: i counts from zero! """ + ap, bq, i = list(map(sympify, [ap, bq, i])) + + self._ap = ap + self._bq = bq + self._i = i + + ap = list(ap) + bq = list(bq) + bi = bq.pop(i) + 1 + + if bi == 0: + raise ValueError('Cannot increment -1 lower index.') + + m = Poly(_x*(bi - 1), _x) + for b in bq: + m *= Poly(_x + b - 1, _x) + + B = Dummy('B') + D = Poly((bi - 1)*B - bi + 1, B) + n = Poly(z, B) + for a in ap: + n *= (D + a.as_poly(B)) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment index: cancels with upper') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, _x/(bi - 1) + 1), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._ap, self._bq) + + +class MeijerShiftA(Operator): + """ Increment an upper b index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(bi - _x, _x) + + def __str__(self): + return '' % (self._poly.all_coeffs()[1]) + + +class MeijerShiftB(Operator): + """ Decrement an upper a index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(1 - bi + _x, _x) + + def __str__(self): + return '' % (1 - self._poly.all_coeffs()[1]) + + +class MeijerShiftC(Operator): + """ Increment a lower b index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(-bi + _x, _x) + + def __str__(self): + return '' % (-self._poly.all_coeffs()[1]) + + +class MeijerShiftD(Operator): + """ Decrement a lower a index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(bi - 1 - _x, _x) + + def __str__(self): + return '' % (self._poly.all_coeffs()[1] + 1) + + +class MeijerUnShiftA(Operator): + """ Decrement an upper b index. """ + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + bi = bm.pop(i) - 1 + + m = Poly(1, _x) * prod(Poly(b - _x, _x) for b in bm) * prod(Poly(_x - b, _x) for b in bq) + + A = Dummy('A') + D = Poly(bi - A, A) + n = Poly(z, A) * prod((D + 1 - a) for a in an) * prod((-D + a - 1) for a in ap) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement upper b index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], A).as_expr().subs(A, bi - _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftB(Operator): + """ Increment an upper a index. """ + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + ai = an.pop(i) + 1 + + m = Poly(z, _x) + for a in an: + m *= Poly(1 - a + _x, _x) + for a in ap: + m *= Poly(a - 1 - _x, _x) + + B = Dummy('B') + D = Poly(B + ai - 1, B) + n = Poly(1, B) + for b in bm: + n *= (-D + b) + for b in bq: + n *= (D - b) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment upper a index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, 1 - ai + _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftC(Operator): + """ Decrement a lower b index. """ + # XXX this is "essentially" the same as MeijerUnShiftA. This "essentially" + # can be made rigorous using the functional equation G(1/z) = G'(z), + # where G' denotes a G function of slightly altered parameters. + # However, sorting out the details seems harder than just coding it + # again. + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + bi = bq.pop(i) - 1 + + m = Poly(1, _x) + for b in bm: + m *= Poly(b - _x, _x) + for b in bq: + m *= Poly(_x - b, _x) + + C = Dummy('C') + D = Poly(bi + C, C) + n = Poly(z, C) + for a in an: + n *= (D + 1 - a) + for a in ap: + n *= (-D + a - 1) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement lower b index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], C).as_expr().subs(C, _x - bi), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftD(Operator): + """ Increment a lower a index. """ + # XXX This is essentially the same as MeijerUnShiftA. + # See comment at MeijerUnShiftC. + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + ai = ap.pop(i) + 1 + + m = Poly(z, _x) + for a in an: + m *= Poly(1 - a + _x, _x) + for a in ap: + m *= Poly(a - 1 - _x, _x) + + B = Dummy('B') # - this is the shift operator `D_I` + D = Poly(ai - 1 - B, B) + n = Poly(1, B) + for b in bm: + n *= (-D + b) + for b in bq: + n *= (D - b) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment lower a index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, ai - 1 - _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class ReduceOrder(Operator): + """ Reduce Order by cancelling an upper and a lower index. """ + + def __new__(cls, ai, bj): + """ For convenience if reduction is not possible, return None. """ + ai = sympify(ai) + bj = sympify(bj) + n = ai - bj + if not n.is_Integer or n < 0: + return None + if bj.is_integer and bj.is_nonpositive: + return None + + expr = Operator.__new__(cls) + + p = S.One + for k in range(n): + p *= (_x + bj + k)/(bj + k) + + expr._poly = Poly(p, _x) + expr._a = ai + expr._b = bj + + return expr + + @classmethod + def _meijer(cls, b, a, sign): + """ Cancel b + sign*s and a + sign*s + This is for meijer G functions. """ + b = sympify(b) + a = sympify(a) + n = b - a + if n.is_negative or not n.is_Integer: + return None + + expr = Operator.__new__(cls) + + p = S.One + for k in range(n): + p *= (sign*_x + a + k) + + expr._poly = Poly(p, _x) + if sign == -1: + expr._a = b + expr._b = a + else: + expr._b = Add(1, a - 1, evaluate=False) + expr._a = Add(1, b - 1, evaluate=False) + + return expr + + @classmethod + def meijer_minus(cls, b, a): + return cls._meijer(b, a, -1) + + @classmethod + def meijer_plus(cls, a, b): + return cls._meijer(1 - a, 1 - b, 1) + + def __str__(self): + return '' % \ + (self._a, self._b) + + +def _reduce_order(ap, bq, gen, key): + """ Order reduction algorithm used in Hypergeometric and Meijer G """ + ap = list(ap) + bq = list(bq) + + ap.sort(key=key) + bq.sort(key=key) + + nap = [] + # we will edit bq in place + operators = [] + for a in ap: + op = None + for i in range(len(bq)): + op = gen(a, bq[i]) + if op is not None: + bq.pop(i) + break + if op is None: + nap.append(a) + else: + operators.append(op) + + return nap, bq, operators + + +def reduce_order(func): + """ + Given the hypergeometric function ``func``, find a sequence of operators to + reduces order as much as possible. + + Explanation + =========== + + Return (newfunc, [operators]), where applying the operators to the + hypergeometric function newfunc yields func. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import reduce_order, Hyper_Function + >>> reduce_order(Hyper_Function((1, 2), (3, 4))) + (Hyper_Function((1, 2), (3, 4)), []) + >>> reduce_order(Hyper_Function((1,), (1,))) + (Hyper_Function((), ()), []) + >>> reduce_order(Hyper_Function((2, 4), (3, 3))) + (Hyper_Function((2,), (3,)), []) + """ + nap, nbq, operators = _reduce_order(func.ap, func.bq, ReduceOrder, default_sort_key) + + return Hyper_Function(Tuple(*nap), Tuple(*nbq)), operators + + +def reduce_order_meijer(func): + """ + Given the Meijer G function parameters, ``func``, find a sequence of + operators that reduces order as much as possible. + + Return newfunc, [operators]. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (reduce_order_meijer, + ... G_Function) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [3, 4], [1, 2]))[0] + G_Function((4, 3), (5, 6), (3, 4), (2, 1)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [3, 4], [1, 8]))[0] + G_Function((3,), (5, 6), (3, 4), (1,)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [7, 5], [1, 5]))[0] + G_Function((3,), (), (), (1,)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [7, 5], [5, 3]))[0] + G_Function((), (), (), ()) + """ + + nan, nbq, ops1 = _reduce_order(func.an, func.bq, ReduceOrder.meijer_plus, + lambda x: default_sort_key(-x)) + nbm, nap, ops2 = _reduce_order(func.bm, func.ap, ReduceOrder.meijer_minus, + default_sort_key) + + return G_Function(nan, nap, nbm, nbq), ops1 + ops2 + + +def make_derivative_operator(M, z): + """ Create a derivative operator, to be passed to Operator.apply. """ + def doit(C): + r = z*C.diff(z) + C*M + r = r.applyfunc(make_simp(z)) + return r + return doit + + +def apply_operators(obj, ops, op): + """ + Apply the list of operators ``ops`` to object ``obj``, substituting + ``op`` for the generator. + """ + res = obj + for o in reversed(ops): + res = o.apply(res, op) + return res + + +def devise_plan(target, origin, z): + """ + Devise a plan (consisting of shift and un-shift operators) to be applied + to the hypergeometric function ``target`` to yield ``origin``. + Returns a list of operators. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import devise_plan, Hyper_Function + >>> from sympy.abc import z + + Nothing to do: + + >>> devise_plan(Hyper_Function((1, 2), ()), Hyper_Function((1, 2), ()), z) + [] + >>> devise_plan(Hyper_Function((), (1, 2)), Hyper_Function((), (1, 2)), z) + [] + + Very simple plans: + + >>> devise_plan(Hyper_Function((2,), ()), Hyper_Function((1,), ()), z) + [] + >>> devise_plan(Hyper_Function((), (2,)), Hyper_Function((), (1,)), z) + [] + + Several buckets: + + >>> from sympy import S + >>> devise_plan(Hyper_Function((1, S.Half), ()), + ... Hyper_Function((2, S('3/2')), ()), z) #doctest: +NORMALIZE_WHITESPACE + [, + ] + + A slightly more complicated plan: + + >>> devise_plan(Hyper_Function((1, 3), ()), Hyper_Function((2, 2), ()), z) + [, ] + + Another more complicated plan: (note that the ap have to be shifted first!) + + >>> devise_plan(Hyper_Function((1, -1), (2,)), Hyper_Function((3, -2), (4,)), z) + [, , + , + , ] + """ + abuckets, bbuckets, nabuckets, nbbuckets = [sift(params, _mod1) for + params in (target.ap, target.bq, origin.ap, origin.bq)] + + if len(list(abuckets.keys())) != len(list(nabuckets.keys())) or \ + len(list(bbuckets.keys())) != len(list(nbbuckets.keys())): + raise ValueError('%s not reachable from %s' % (target, origin)) + + ops = [] + + def do_shifts(fro, to, inc, dec): + ops = [] + for i in range(len(fro)): + if to[i] - fro[i] > 0: + sh = inc + ch = 1 + else: + sh = dec + ch = -1 + + while to[i] != fro[i]: + ops += [sh(fro, i)] + fro[i] += ch + + return ops + + def do_shifts_a(nal, nbk, al, aother, bother): + """ Shift us from (nal, nbk) to (al, nbk). """ + return do_shifts(nal, al, lambda p, i: ShiftA(p[i]), + lambda p, i: UnShiftA(p + aother, nbk + bother, i, z)) + + def do_shifts_b(nal, nbk, bk, aother, bother): + """ Shift us from (nal, nbk) to (nal, bk). """ + return do_shifts(nbk, bk, + lambda p, i: UnShiftB(nal + aother, p + bother, i, z), + lambda p, i: ShiftB(p[i])) + + for r in sorted(list(abuckets.keys()) + list(bbuckets.keys()), key=default_sort_key): + al = () + nal = () + bk = () + nbk = () + if r in abuckets: + al = abuckets[r] + nal = nabuckets[r] + if r in bbuckets: + bk = bbuckets[r] + nbk = nbbuckets[r] + if len(al) != len(nal) or len(bk) != len(nbk): + raise ValueError('%s not reachable from %s' % (target, origin)) + + al, nal, bk, nbk = [sorted(w, key=default_sort_key) + for w in [al, nal, bk, nbk]] + + def others(dic, key): + l = [] + for k in dic: + if k != key: + l.extend(dic[k]) + return l + aother = others(nabuckets, r) + bother = others(nbbuckets, r) + + if len(al) == 0: + # there can be no complications, just shift the bs as we please + ops += do_shifts_b([], nbk, bk, aother, bother) + elif len(bk) == 0: + # there can be no complications, just shift the as as we please + ops += do_shifts_a(nal, [], al, aother, bother) + else: + namax = nal[-1] + amax = al[-1] + + if nbk[0] - namax <= 0 or bk[0] - amax <= 0: + raise ValueError('Non-suitable parameters.') + + if namax - amax > 0: + # we are going to shift down - first do the as, then the bs + ops += do_shifts_a(nal, nbk, al, aother, bother) + ops += do_shifts_b(al, nbk, bk, aother, bother) + else: + # we are going to shift up - first do the bs, then the as + ops += do_shifts_b(nal, nbk, bk, aother, bother) + ops += do_shifts_a(nal, bk, al, aother, bother) + + nabuckets[r] = al + nbbuckets[r] = bk + + ops.reverse() + return ops + + +def try_shifted_sum(func, z): + """ Try to recognise a hypergeometric sum that starts from k > 0. """ + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + if len(abuckets[S.Zero]) != 1: + return None + r = abuckets[S.Zero][0] + if r <= 0: + return None + if S.Zero not in bbuckets: + return None + l = list(bbuckets[S.Zero]) + l.sort() + k = l[0] + if k <= 0: + return None + + nap = list(func.ap) + nap.remove(r) + nbq = list(func.bq) + nbq.remove(k) + k -= 1 + nap = [x - k for x in nap] + nbq = [x - k for x in nbq] + + ops = [] + for n in range(r - 1): + ops.append(ShiftA(n + 1)) + ops.reverse() + + fac = factorial(k)/z**k + fac *= Mul(*[rf(b, k) for b in nbq]) + fac /= Mul(*[rf(a, k) for a in nap]) + + ops += [MultOperator(fac)] + + p = 0 + for n in range(k): + m = z**n/factorial(n) + m *= Mul(*[rf(a, n) for a in nap]) + m /= Mul(*[rf(b, n) for b in nbq]) + p += m + + return Hyper_Function(nap, nbq), ops, -p + + +def try_polynomial(func, z): + """ Recognise polynomial cases. Returns None if not such a case. + Requires order to be fully reduced. """ + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + a0 = abuckets[S.Zero] + b0 = bbuckets[S.Zero] + a0.sort() + b0.sort() + al0 = [x for x in a0 if x <= 0] + bl0 = [x for x in b0 if x <= 0] + + if bl0 and all(a < bl0[-1] for a in al0): + return oo + if not al0: + return None + + a = al0[-1] + fac = 1 + res = S.One + for n in Tuple(*list(range(-a))): + fac *= z + fac /= n + 1 + fac *= Mul(*[a + n for a in func.ap]) + fac /= Mul(*[b + n for b in func.bq]) + res += fac + return res + + +def try_lerchphi(func): + """ + Try to find an expression for Hyper_Function ``func`` in terms of Lerch + Transcendents. + + Return None if no such expression can be found. + """ + # This is actually quite simple, and is described in Roach's paper, + # section 18. + # We don't need to implement the reduction to polylog here, this + # is handled by expand_func. + + # First we need to figure out if the summation coefficient is a rational + # function of the summation index, and construct that rational function. + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + + paired = {} + for key, value in abuckets.items(): + if key != 0 and key not in bbuckets: + return None + bvalue = bbuckets[key] + paired[key] = (list(value), list(bvalue)) + bbuckets.pop(key, None) + if bbuckets != {}: + return None + if S.Zero not in abuckets: + return None + aints, bints = paired[S.Zero] + # Account for the additional n! in denominator + paired[S.Zero] = (aints, bints + [1]) + + t = Dummy('t') + numer = S.One + denom = S.One + for key, (avalue, bvalue) in paired.items(): + if len(avalue) != len(bvalue): + return None + # Note that since order has been reduced fully, all the b are + # bigger than all the a they differ from by an integer. In particular + # if there are any negative b left, this function is not well-defined. + for a, b in zip(avalue, bvalue): + if (a - b).is_positive: + k = a - b + numer *= rf(b + t, k) + denom *= rf(b, k) + else: + k = b - a + numer *= rf(a, k) + denom *= rf(a + t, k) + + # Now do a partial fraction decomposition. + # We assemble two structures: a list monomials of pairs (a, b) representing + # a*t**b (b a non-negative integer), and a dict terms, where + # terms[a] = [(b, c)] means that there is a term b/(t-a)**c. + part = apart(numer/denom, t) + args = Add.make_args(part) + monomials = [] + terms = {} + for arg in args: + numer, denom = arg.as_numer_denom() + if not denom.has(t): + p = Poly(numer, t) + if not p.is_monomial: + raise TypeError("p should be monomial") + ((b, ), a) = p.LT() + monomials += [(a/denom, b)] + continue + if numer.has(t): + raise NotImplementedError('Need partial fraction decomposition' + ' with linear denominators') + indep, [dep] = denom.as_coeff_mul(t) + n = 1 + if dep.is_Pow: + n = dep.exp + dep = dep.base + if dep == t: + a = 0 + elif dep.is_Add: + a, tmp = dep.as_independent(t) + b = 1 + if tmp != t: + b, _ = tmp.as_independent(t) + if dep != b*t + a: + raise NotImplementedError('unrecognised form %s' % dep) + a /= b + indep *= b**n + else: + raise NotImplementedError('unrecognised form of partial fraction') + terms.setdefault(a, []).append((numer/indep, n)) + + # Now that we have this information, assemble our formula. All the + # monomials yield rational functions and go into one basis element. + # The terms[a] are related by differentiation. If the largest exponent is + # n, we need lerchphi(z, k, a) for k = 1, 2, ..., n. + # deriv maps a basis to its derivative, expressed as a C(z)-linear + # combination of other basis elements. + deriv = {} + coeffs = {} + z = Dummy('z') + monomials.sort(key=lambda x: x[1]) + mon = {0: 1/(1 - z)} + if monomials: + for k in range(monomials[-1][1]): + mon[k + 1] = z*mon[k].diff(z) + for a, n in monomials: + coeffs.setdefault(S.One, []).append(a*mon[n]) + for a, l in terms.items(): + for c, k in l: + coeffs.setdefault(lerchphi(z, k, a), []).append(c) + l.sort(key=lambda x: x[1]) + for k in range(2, l[-1][1] + 1): + deriv[lerchphi(z, k, a)] = [(-a, lerchphi(z, k, a)), + (1, lerchphi(z, k - 1, a))] + deriv[lerchphi(z, 1, a)] = [(-a, lerchphi(z, 1, a)), + (1/(1 - z), S.One)] + trans = {} + for n, b in enumerate([S.One] + list(deriv.keys())): + trans[b] = n + basis = [expand_func(b) for (b, _) in sorted(trans.items(), + key=lambda x:x[1])] + B = Matrix(basis) + C = Matrix([[0]*len(B)]) + for b, c in coeffs.items(): + C[trans[b]] = Add(*c) + M = zeros(len(B)) + for b, l in deriv.items(): + for c, b2 in l: + M[trans[b], trans[b2]] = c + return Formula(func, z, None, [], B, C, M) + + +def build_hypergeometric_formula(func): + """ + Create a formula object representing the hypergeometric function ``func``. + + """ + # We know that no `ap` are negative integers, otherwise "detect poly" + # would have kicked in. However, `ap` could be empty. In this case we can + # use a different basis. + # I'm not aware of a basis that works in all cases. + z = Dummy('z') + if func.ap: + afactors = [_x + a for a in func.ap] + bfactors = [_x + b - 1 for b in func.bq] + expr = _x*Mul(*bfactors) - z*Mul(*afactors) + poly = Poly(expr, _x) + n = poly.degree() + basis = [] + M = zeros(n) + for k in range(n): + a = func.ap[0] + k + basis += [hyper([a] + list(func.ap[1:]), func.bq, z)] + if k < n - 1: + M[k, k] = -a + M[k, k + 1] = a + B = Matrix(basis) + C = Matrix([[1] + [0]*(n - 1)]) + derivs = [eye(n)] + for k in range(n): + derivs.append(M*derivs[k]) + l = poly.all_coeffs() + l.reverse() + res = [0]*n + for k, c in enumerate(l): + for r, d in enumerate(C*derivs[k]): + res[r] += c*d + for k, c in enumerate(res): + M[n - 1, k] = -c/derivs[n - 1][0, n - 1]/poly.all_coeffs()[0] + return Formula(func, z, None, [], B, C, M) + else: + # Since there are no `ap`, none of the `bq` can be non-positive + # integers. + basis = [] + bq = list(func.bq[:]) + for i in range(len(bq)): + basis += [hyper([], bq, z)] + bq[i] += 1 + basis += [hyper([], bq, z)] + B = Matrix(basis) + n = len(B) + C = Matrix([[1] + [0]*(n - 1)]) + M = zeros(n) + M[0, n - 1] = z/Mul(*func.bq) + for k in range(1, n): + M[k, k - 1] = func.bq[k - 1] + M[k, k] = -func.bq[k - 1] + return Formula(func, z, None, [], B, C, M) + + +def hyperexpand_special(ap, bq, z): + """ + Try to find a closed-form expression for hyper(ap, bq, z), where ``z`` + is supposed to be a "special" value, e.g. 1. + + This function tries various of the classical summation formulae + (Gauss, Saalschuetz, etc). + """ + # This code is very ad-hoc. There are many clever algorithms + # (notably Zeilberger's) related to this problem. + # For now we just want a few simple cases to work. + p, q = len(ap), len(bq) + z_ = z + z = unpolarify(z) + if z == 0: + return S.One + from sympy.simplify.simplify import simplify + if p == 2 and q == 1: + # 2F1 + a, b, c = ap + bq + if z == 1: + # Gauss + return gamma(c - a - b)*gamma(c)/gamma(c - a)/gamma(c - b) + if z == -1 and simplify(b - a + c) == 1: + b, a = a, b + if z == -1 and simplify(a - b + c) == 1: + # Kummer + if b.is_integer and b.is_negative: + return 2*cos(pi*b/2)*gamma(-b)*gamma(b - a + 1) \ + /gamma(-b/2)/gamma(b/2 - a + 1) + else: + return gamma(b/2 + 1)*gamma(b - a + 1) \ + /gamma(b + 1)/gamma(b/2 - a + 1) + # TODO tons of more formulae + # investigate what algorithms exist + return hyper(ap, bq, z_) + +_collection = None + + +def _hyperexpand(func, z, ops0=[], z0=Dummy('z0'), premult=1, prem=0, + rewrite='default'): + """ + Try to find an expression for the hypergeometric function ``func``. + + Explanation + =========== + + The result is expressed in terms of a dummy variable ``z0``. Then it + is multiplied by ``premult``. Then ``ops0`` is applied. + ``premult`` must be a*z**prem for some a independent of ``z``. + """ + + if z.is_zero: + return S.One + + from sympy.simplify.simplify import simplify + + z = polarify(z, subs=False) + if rewrite == 'default': + rewrite = 'nonrepsmall' + + def carryout_plan(f, ops): + C = apply_operators(f.C.subs(f.z, z0), ops, + make_derivative_operator(f.M.subs(f.z, z0), z0)) + C = apply_operators(C, ops0, + make_derivative_operator(f.M.subs(f.z, z0) + + prem*eye(f.M.shape[0]), z0)) + + if premult == 1: + C = C.applyfunc(make_simp(z0)) + r = reduce(lambda s,m: s+m[0]*m[1], zip(C, f.B.subs(f.z, z0)), S.Zero)*premult + res = r.subs(z0, z) + if rewrite: + res = res.rewrite(rewrite) + return res + + # TODO + # The following would be possible: + # *) PFD Duplication (see Kelly Roach's paper) + # *) In a similar spirit, try_lerchphi() can be generalised considerably. + + global _collection + if _collection is None: + _collection = FormulaCollection() + + debug('Trying to expand hypergeometric function ', func) + + # First reduce order as much as possible. + func, ops = reduce_order(func) + if ops: + debug(' Reduced order to ', func) + else: + debug(' Could not reduce order.') + + # Now try polynomial cases + res = try_polynomial(func, z0) + if res is not None: + debug(' Recognised polynomial.') + p = apply_operators(res, ops, lambda f: z0*f.diff(z0)) + p = apply_operators(p*premult, ops0, lambda f: z0*f.diff(z0)) + return unpolarify(simplify(p).subs(z0, z)) + + # Try to recognise a shifted sum. + p = S.Zero + res = try_shifted_sum(func, z0) + if res is not None: + func, nops, p = res + debug(' Recognised shifted sum, reduced order to ', func) + ops += nops + + # apply the plan for poly + p = apply_operators(p, ops, lambda f: z0*f.diff(z0)) + p = apply_operators(p*premult, ops0, lambda f: z0*f.diff(z0)) + p = simplify(p).subs(z0, z) + + # Try special expansions early. + if unpolarify(z) in [1, -1] and (len(func.ap), len(func.bq)) == (2, 1): + f = build_hypergeometric_formula(func) + r = carryout_plan(f, ops).replace(hyper, hyperexpand_special) + if not r.has(hyper): + return r + p + + # Try to find a formula in our collection + formula = _collection.lookup_origin(func) + + # Now try a lerch phi formula + if formula is None: + formula = try_lerchphi(func) + + if formula is None: + debug(' Could not find an origin. ', + 'Will return answer in terms of ' + 'simpler hypergeometric functions.') + formula = build_hypergeometric_formula(func) + + debug(' Found an origin: ', formula.closed_form, ' ', formula.func) + + # We need to find the operators that convert formula into func. + ops += devise_plan(func, formula.func, z0) + + # Now carry out the plan. + r = carryout_plan(formula, ops) + p + + return powdenest(r, polar=True).replace(hyper, hyperexpand_special) + + +def devise_plan_meijer(fro, to, z): + """ + Find operators to convert G-function ``fro`` into G-function ``to``. + + Explanation + =========== + + It is assumed that ``fro`` and ``to`` have the same signatures, and that in fact + any corresponding pair of parameters differs by integers, and a direct path + is possible. I.e. if there are parameters a1 b1 c1 and a2 b2 c2 it is + assumed that a1 can be shifted to a2, etc. The only thing this routine + determines is the order of shifts to apply, nothing clever will be tried. + It is also assumed that ``fro`` is suitable. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (devise_plan_meijer, + ... G_Function) + >>> from sympy.abc import z + + Empty plan: + + >>> devise_plan_meijer(G_Function([1], [2], [3], [4]), + ... G_Function([1], [2], [3], [4]), z) + [] + + Very simple plans: + + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([1], [], [], []), z) + [] + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([-1], [], [], []), z) + [] + >>> devise_plan_meijer(G_Function([], [1], [], []), + ... G_Function([], [2], [], []), z) + [] + + Slightly more complicated plans: + + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([2], [], [], []), z) + [, + ] + >>> devise_plan_meijer(G_Function([0], [], [0], []), + ... G_Function([-1], [], [1], []), z) + [, ] + + Order matters: + + >>> devise_plan_meijer(G_Function([0], [], [0], []), + ... G_Function([1], [], [1], []), z) + [, ] + """ + # TODO for now, we use the following simple heuristic: inverse-shift + # when possible, shift otherwise. Give up if we cannot make progress. + + def try_shift(f, t, shifter, diff, counter): + """ Try to apply ``shifter`` in order to bring some element in ``f`` + nearer to its counterpart in ``to``. ``diff`` is +/- 1 and + determines the effect of ``shifter``. Counter is a list of elements + blocking the shift. + + Return an operator if change was possible, else None. + """ + for idx, (a, b) in enumerate(zip(f, t)): + if ( + (a - b).is_integer and (b - a)/diff > 0 and + all(a != x for x in counter)): + sh = shifter(idx) + f[idx] += diff + return sh + fan = list(fro.an) + fap = list(fro.ap) + fbm = list(fro.bm) + fbq = list(fro.bq) + ops = [] + change = True + while change: + change = False + op = try_shift(fan, to.an, + lambda i: MeijerUnShiftB(fan, fap, fbm, fbq, i, z), + 1, fbm + fbq) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fap, to.ap, + lambda i: MeijerUnShiftD(fan, fap, fbm, fbq, i, z), + 1, fbm + fbq) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbm, to.bm, + lambda i: MeijerUnShiftA(fan, fap, fbm, fbq, i, z), + -1, fan + fap) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbq, to.bq, + lambda i: MeijerUnShiftC(fan, fap, fbm, fbq, i, z), + -1, fan + fap) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fan, to.an, lambda i: MeijerShiftB(fan[i]), -1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fap, to.ap, lambda i: MeijerShiftD(fap[i]), -1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbm, to.bm, lambda i: MeijerShiftA(fbm[i]), 1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbq, to.bq, lambda i: MeijerShiftC(fbq[i]), 1, []) + if op is not None: + ops += [op] + change = True + continue + if fan != list(to.an) or fap != list(to.ap) or fbm != list(to.bm) or \ + fbq != list(to.bq): + raise NotImplementedError('Could not devise plan.') + ops.reverse() + return ops + +_meijercollection = None + + +def _meijergexpand(func, z0, allow_hyper=False, rewrite='default', + place=None): + """ + Try to find an expression for the Meijer G function specified + by the G_Function ``func``. If ``allow_hyper`` is True, then returning + an expression in terms of hypergeometric functions is allowed. + + Currently this just does Slater's theorem. + If expansions exist both at zero and at infinity, ``place`` + can be set to ``0`` or ``zoo`` for the preferred choice. + """ + global _meijercollection + if _meijercollection is None: + _meijercollection = MeijerFormulaCollection() + if rewrite == 'default': + rewrite = None + + func0 = func + debug('Try to expand Meijer G function corresponding to ', func) + + # We will play games with analytic continuation - rather use a fresh symbol + z = Dummy('z') + + func, ops = reduce_order_meijer(func) + if ops: + debug(' Reduced order to ', func) + else: + debug(' Could not reduce order.') + + # Try to find a direct formula + f = _meijercollection.lookup_origin(func) + if f is not None: + debug(' Found a Meijer G formula: ', f.func) + ops += devise_plan_meijer(f.func, func, z) + + # Now carry out the plan. + C = apply_operators(f.C.subs(f.z, z), ops, + make_derivative_operator(f.M.subs(f.z, z), z)) + + C = C.applyfunc(make_simp(z)) + r = C*f.B.subs(f.z, z) + r = r[0].subs(z, z0) + return powdenest(r, polar=True) + + debug(" Could not find a direct formula. Trying Slater's theorem.") + + # TODO the following would be possible: + # *) Paired Index Theorems + # *) PFD Duplication + # (See Kelly Roach's paper for details on either.) + # + # TODO Also, we tend to create combinations of gamma functions that can be + # simplified. + + def can_do(pbm, pap): + """ Test if slater applies. """ + for i in pbm: + if len(pbm[i]) > 1: + l = 0 + if i in pap: + l = len(pap[i]) + if l + 1 < len(pbm[i]): + return False + return True + + def do_slater(an, bm, ap, bq, z, zfinal): + # zfinal is the value that will eventually be substituted for z. + # We pass it to _hyperexpand to improve performance. + func = G_Function(an, bm, ap, bq) + _, pbm, pap, _ = func.compute_buckets() + if not can_do(pbm, pap): + return S.Zero, False + + cond = len(an) + len(ap) < len(bm) + len(bq) + if len(an) + len(ap) == len(bm) + len(bq): + cond = abs(z) < 1 + if cond is False: + return S.Zero, False + + res = S.Zero + for m in pbm: + if len(pbm[m]) == 1: + bh = pbm[m][0] + fac = 1 + bo = list(bm) + bo.remove(bh) + for bj in bo: + fac *= gamma(bj - bh) + for aj in an: + fac *= gamma(1 + bh - aj) + for bj in bq: + fac /= gamma(1 + bh - bj) + for aj in ap: + fac /= gamma(aj - bh) + nap = [1 + bh - a for a in list(an) + list(ap)] + nbq = [1 + bh - b for b in list(bo) + list(bq)] + + k = polar_lift(S.NegativeOne**(len(ap) - len(bm))) + harg = k*zfinal + # NOTE even though k "is" +-1, this has to be t/k instead of + # t*k ... we are using polar numbers for consistency! + premult = (t/k)**bh + hyp = _hyperexpand(Hyper_Function(nap, nbq), harg, ops, + t, premult, bh, rewrite=None) + res += fac * hyp + else: + b_ = pbm[m][0] + ki = [bi - b_ for bi in pbm[m][1:]] + u = len(ki) + li = [ai - b_ for ai in pap[m][:u + 1]] + bo = list(bm) + for b in pbm[m]: + bo.remove(b) + ao = list(ap) + for a in pap[m][:u]: + ao.remove(a) + lu = li[-1] + di = [l - k for (l, k) in zip(li, ki)] + + # We first work out the integrand: + s = Dummy('s') + integrand = z**s + for b in bm: + if not Mod(b, 1) and b.is_Number: + b = int(round(b)) + integrand *= gamma(b - s) + for a in an: + integrand *= gamma(1 - a + s) + for b in bq: + integrand /= gamma(1 - b + s) + for a in ap: + integrand /= gamma(a - s) + + # Now sum the finitely many residues: + # XXX This speeds up some cases - is it a good idea? + integrand = expand_func(integrand) + for r in range(int(round(lu))): + resid = residue(integrand, s, b_ + r) + resid = apply_operators(resid, ops, lambda f: z*f.diff(z)) + res -= resid + + # Now the hypergeometric term. + au = b_ + lu + k = polar_lift(S.NegativeOne**(len(ao) + len(bo) + 1)) + harg = k*zfinal + premult = (t/k)**au + nap = [1 + au - a for a in list(an) + list(ap)] + [1] + nbq = [1 + au - b for b in list(bm) + list(bq)] + + hyp = _hyperexpand(Hyper_Function(nap, nbq), harg, ops, + t, premult, au, rewrite=None) + + C = S.NegativeOne**(lu)/factorial(lu) + for i in range(u): + C *= S.NegativeOne**di[i]/rf(lu - li[i] + 1, di[i]) + for a in an: + C *= gamma(1 - a + au) + for b in bo: + C *= gamma(b - au) + for a in ao: + C /= gamma(a - au) + for b in bq: + C /= gamma(1 - b + au) + + res += C*hyp + + return res, cond + + t = Dummy('t') + slater1, cond1 = do_slater(func.an, func.bm, func.ap, func.bq, z, z0) + + def tr(l): + return [1 - x for x in l] + + for op in ops: + op._poly = Poly(op._poly.subs({z: 1/t, _x: -_x}), _x) + slater2, cond2 = do_slater(tr(func.bm), tr(func.an), tr(func.bq), tr(func.ap), + t, 1/z0) + + slater1 = powdenest(slater1.subs(z, z0), polar=True) + slater2 = powdenest(slater2.subs(t, 1/z0), polar=True) + if not isinstance(cond2, bool): + cond2 = cond2.subs(t, 1/z) + + m = func(z) + if m.delta > 0 or \ + (m.delta == 0 and len(m.ap) == len(m.bq) and + (re(m.nu) < -1) is not False and polar_lift(z0) == polar_lift(1)): + # The condition delta > 0 means that the convergence region is + # connected. Any expression we find can be continued analytically + # to the entire convergence region. + # The conditions delta==0, p==q, re(nu) < -1 imply that G is continuous + # on the positive reals, so the values at z=1 agree. + if cond1 is not False: + cond1 = True + if cond2 is not False: + cond2 = True + + if cond1 is True: + slater1 = slater1.rewrite(rewrite or 'nonrep') + else: + slater1 = slater1.rewrite(rewrite or 'nonrepsmall') + if cond2 is True: + slater2 = slater2.rewrite(rewrite or 'nonrep') + else: + slater2 = slater2.rewrite(rewrite or 'nonrepsmall') + + if cond1 is not False and cond2 is not False: + # If one condition is False, there is no choice. + if place == 0: + cond2 = False + if place == zoo: + cond1 = False + + if not isinstance(cond1, bool): + cond1 = cond1.subs(z, z0) + if not isinstance(cond2, bool): + cond2 = cond2.subs(z, z0) + + def weight(expr, cond): + if cond is True: + c0 = 0 + elif cond is False: + c0 = 1 + else: + c0 = 2 + if expr.has(oo, zoo, -oo, nan): + # XXX this actually should not happen, but consider + # S('meijerg(((0, -1/2, 0, -1/2, 1/2), ()), ((0,), + # (-1/2, -1/2, -1/2, -1)), exp_polar(I*pi))/4') + c0 = 3 + return (c0, expr.count(hyper), expr.count_ops()) + + w1 = weight(slater1, cond1) + w2 = weight(slater2, cond2) + if min(w1, w2) <= (0, 1, oo): + if w1 < w2: + return slater1 + else: + return slater2 + if max(w1[0], w2[0]) <= 1 and max(w1[1], w2[1]) <= 1: + return Piecewise((slater1, cond1), (slater2, cond2), (func0(z0), True)) + + # We couldn't find an expression without hypergeometric functions. + # TODO it would be helpful to give conditions under which the integral + # is known to diverge. + r = Piecewise((slater1, cond1), (slater2, cond2), (func0(z0), True)) + if r.has(hyper) and not allow_hyper: + debug(' Could express using hypergeometric functions, ' + 'but not allowed.') + if not r.has(hyper) or allow_hyper: + return r + + return func0(z0) + + +def hyperexpand(f, allow_hyper=False, rewrite='default', place=None): + """ + Expand hypergeometric functions. If allow_hyper is True, allow partial + simplification (that is a result different from input, + but still containing hypergeometric functions). + + If a G-function has expansions both at zero and at infinity, + ``place`` can be set to ``0`` or ``zoo`` to indicate the + preferred choice. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import hyperexpand + >>> from sympy.functions import hyper + >>> from sympy.abc import z + >>> hyperexpand(hyper([], [], z)) + exp(z) + + Non-hyperegeometric parts of the expression and hypergeometric expressions + that are not recognised are left unchanged: + + >>> hyperexpand(1 + hyper([1, 1, 1], [], z)) + hyper((1, 1, 1), (), z) + 1 + """ + f = sympify(f) + + def do_replace(ap, bq, z): + r = _hyperexpand(Hyper_Function(ap, bq), z, rewrite=rewrite) + if r is None: + return hyper(ap, bq, z) + else: + return r + + def do_meijer(ap, bq, z): + r = _meijergexpand(G_Function(ap[0], ap[1], bq[0], bq[1]), z, + allow_hyper, rewrite=rewrite, place=place) + if not r.has(nan, zoo, oo, -oo): + return r + return f.replace(hyper, do_replace).replace(meijerg, do_meijer) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/hyperexpand_doc.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/hyperexpand_doc.py new file mode 100644 index 0000000000000000000000000000000000000000..a18377f3aede5214036fbf628825536611001584 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/hyperexpand_doc.py @@ -0,0 +1,18 @@ +""" This module cooks up a docstring when imported. Its only purpose is to + be displayed in the sphinx documentation. """ + +from sympy.core.relational import Eq +from sympy.functions.special.hyper import hyper +from sympy.printing.latex import latex +from sympy.simplify.hyperexpand import FormulaCollection + +c = FormulaCollection() + +doc = "" + +for f in c.formulae: + obj = Eq(hyper(f.func.ap, f.func.bq, f.z), + f.closed_form.rewrite('nonrepsmall')) + doc += ".. math::\n %s\n" % latex(obj) + +__doc__ = doc diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/powsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/powsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..f72dfeb072e0d0d4737ace310eda5c2a3a082c16 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/powsimp.py @@ -0,0 +1,718 @@ +from collections import defaultdict +from functools import reduce +from math import prod + +from sympy.core.function import expand_log, count_ops, _coeff_isneg +from sympy.core import sympify, Basic, Dummy, S, Add, Mul, Pow, expand_mul, factor_terms +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.numbers import Integer, Rational, equal_valued +from sympy.core.mul import _keep_coeff +from sympy.core.rules import Transform +from sympy.functions import exp_polar, exp, log, root, polarify, unpolarify +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys import lcm, gcd +from sympy.ntheory.factor_ import multiplicity + + + +def powsimp(expr, deep=False, combine='all', force=False, measure=count_ops): + """ + Reduce expression by combining powers with similar bases and exponents. + + Explanation + =========== + + If ``deep`` is ``True`` then powsimp() will also simplify arguments of + functions. By default ``deep`` is set to ``False``. + + If ``force`` is ``True`` then bases will be combined without checking for + assumptions, e.g. sqrt(x)*sqrt(y) -> sqrt(x*y) which is not true + if x and y are both negative. + + You can make powsimp() only combine bases or only combine exponents by + changing combine='base' or combine='exp'. By default, combine='all', + which does both. combine='base' will only combine:: + + a a a 2x x + x * y => (x*y) as well as things like 2 => 4 + + and combine='exp' will only combine + :: + + a b (a + b) + x * x => x + + combine='exp' will strictly only combine exponents in the way that used + to be automatic. Also use deep=True if you need the old behavior. + + When combine='all', 'exp' is evaluated first. Consider the first + example below for when there could be an ambiguity relating to this. + This is done so things like the second example can be completely + combined. If you want 'base' combined first, do something like + powsimp(powsimp(expr, combine='base'), combine='exp'). + + Examples + ======== + + >>> from sympy import powsimp, exp, log, symbols + >>> from sympy.abc import x, y, z, n + >>> powsimp(x**y*x**z*y**z, combine='all') + x**(y + z)*y**z + >>> powsimp(x**y*x**z*y**z, combine='exp') + x**(y + z)*y**z + >>> powsimp(x**y*x**z*y**z, combine='base', force=True) + x**y*(x*y)**z + + >>> powsimp(x**z*x**y*n**z*n**y, combine='all', force=True) + (n*x)**(y + z) + >>> powsimp(x**z*x**y*n**z*n**y, combine='exp') + n**(y + z)*x**(y + z) + >>> powsimp(x**z*x**y*n**z*n**y, combine='base', force=True) + (n*x)**y*(n*x)**z + + >>> x, y = symbols('x y', positive=True) + >>> powsimp(log(exp(x)*exp(y))) + log(exp(x)*exp(y)) + >>> powsimp(log(exp(x)*exp(y)), deep=True) + x + y + + Radicals with Mul bases will be combined if combine='exp' + + >>> from sympy import sqrt + >>> x, y = symbols('x y') + + Two radicals are automatically joined through Mul: + + >>> a=sqrt(x*sqrt(y)) + >>> a*a**3 == a**4 + True + + But if an integer power of that radical has been + autoexpanded then Mul does not join the resulting factors: + + >>> a**4 # auto expands to a Mul, no longer a Pow + x**2*y + >>> _*a # so Mul doesn't combine them + x**2*y*sqrt(x*sqrt(y)) + >>> powsimp(_) # but powsimp will + (x*sqrt(y))**(5/2) + >>> powsimp(x*y*a) # but won't when doing so would violate assumptions + x*y*sqrt(x*sqrt(y)) + + """ + def recurse(arg, **kwargs): + _deep = kwargs.get('deep', deep) + _combine = kwargs.get('combine', combine) + _force = kwargs.get('force', force) + _measure = kwargs.get('measure', measure) + return powsimp(arg, _deep, _combine, _force, _measure) + + expr = sympify(expr) + + if (not isinstance(expr, Basic) or isinstance(expr, MatrixSymbol) or ( + expr.is_Atom or expr in (exp_polar(0), exp_polar(1)))): + return expr + + if deep or expr.is_Add or expr.is_Mul and _y not in expr.args: + expr = expr.func(*[recurse(w) for w in expr.args]) + + if expr.is_Pow: + return recurse(expr*_y, deep=False)/_y + + if not expr.is_Mul: + return expr + + # handle the Mul + if combine in ('exp', 'all'): + # Collect base/exp data, while maintaining order in the + # non-commutative parts of the product + c_powers = defaultdict(list) + nc_part = [] + newexpr = [] + coeff = S.One + for term in expr.args: + if term.is_Rational: + coeff *= term + continue + if term.is_Pow: + term = _denest_pow(term) + if term.is_commutative: + b, e = term.as_base_exp() + if deep: + b, e = [recurse(i) for i in [b, e]] + if b.is_Pow or isinstance(b, exp): + # don't let smthg like sqrt(x**a) split into x**a, 1/2 + # or else it will be joined as x**(a/2) later + b, e = b**e, S.One + c_powers[b].append(e) + else: + # This is the logic that combines exponents for equal, + # but non-commutative bases: A**x*A**y == A**(x+y). + if nc_part: + b1, e1 = nc_part[-1].as_base_exp() + b2, e2 = term.as_base_exp() + if (b1 == b2 and + e1.is_commutative and e2.is_commutative): + nc_part[-1] = Pow(b1, Add(e1, e2)) + continue + nc_part.append(term) + + # add up exponents of common bases + for b, e in ordered(iter(c_powers.items())): + # allow 2**x/4 -> 2**(x - 2); don't do this when b and e are + # Numbers since autoevaluation will undo it, e.g. + # 2**(1/3)/4 -> 2**(1/3 - 2) -> 2**(1/3)/4 + if (b and b.is_Rational and not all(ei.is_Number for ei in e) and \ + coeff is not S.One and + b not in (S.One, S.NegativeOne)): + m = multiplicity(abs(b), abs(coeff)) + if m: + e.append(m) + coeff /= b**m + c_powers[b] = Add(*e) + if coeff is not S.One: + if coeff in c_powers: + c_powers[coeff] += S.One + else: + c_powers[coeff] = S.One + + # convert to plain dictionary + c_powers = dict(c_powers) + + # check for base and inverted base pairs + be = list(c_powers.items()) + skip = set() # skip if we already saw them + for b, e in be: + if b in skip: + continue + bpos = b.is_positive or b.is_polar + if bpos: + binv = 1/b + #Special case for float 1 + if b.is_Float and equal_valued(b, 1): + c_powers[b] = S.One + continue + if b != binv and binv in c_powers: + if b.as_numer_denom()[0] is S.One: + c_powers.pop(b) + c_powers[binv] -= e + else: + skip.add(binv) + e = c_powers.pop(binv) + c_powers[b] -= e + + # check for base and negated base pairs + be = list(c_powers.items()) + _n = S.NegativeOne + for b, e in be: + if (b.is_Symbol or b.is_Add) and -b in c_powers and b in c_powers: + if (b.is_positive is not None or e.is_integer): + if e.is_integer or b.is_negative: + c_powers[-b] += c_powers.pop(b) + else: # (-b).is_positive so use its e + e = c_powers.pop(-b) + c_powers[b] += e + if _n in c_powers: + c_powers[_n] += e + else: + c_powers[_n] = e + + # filter c_powers and convert to a list + c_powers = [(b, e) for b, e in c_powers.items() if e] + + # ============================================================== + # check for Mul bases of Rational powers that can be combined with + # separated bases, e.g. x*sqrt(x*y)*sqrt(x*sqrt(x*y)) -> + # (x*sqrt(x*y))**(3/2) + # ---------------- helper functions + + def ratq(x): + '''Return Rational part of x's exponent as it appears in the bkey. + ''' + return bkey(x)[0][1] + + def bkey(b, e=None): + '''Return (b**s, c.q), c.p where e -> c*s. If e is not given then + it will be taken by using as_base_exp() on the input b. + e.g. + x**3/2 -> (x, 2), 3 + x**y -> (x**y, 1), 1 + x**(2*y/3) -> (x**y, 3), 2 + exp(x/2) -> (exp(a), 2), 1 + + ''' + if e is not None: # coming from c_powers or from below + if e.is_Integer: + return (b, S.One), e + elif e.is_Rational: + return (b, Integer(e.q)), Integer(e.p) + else: + c, m = e.as_coeff_Mul(rational=True) + if c is not S.One: + if m.is_integer: + return (b, Integer(c.q)), m*Integer(c.p) + return (b**m, Integer(c.q)), Integer(c.p) + else: + return (b**e, S.One), S.One + else: + return bkey(*b.as_base_exp()) + + def update(b): + '''Decide what to do with base, b. If its exponent is now an + integer multiple of the Rational denominator, then remove it + and put the factors of its base in the common_b dictionary or + update the existing bases if necessary. If it has been zeroed + out, simply remove the base. + ''' + newe, r = divmod(common_b[b], b[1]) + if not r: + common_b.pop(b) + if newe: + for m in Mul.make_args(b[0]**newe): + b, e = bkey(m) + if b not in common_b: + common_b[b] = 0 + common_b[b] += e + if b[1] != 1: + bases.append(b) + # ---------------- end of helper functions + + # assemble a dictionary of the factors having a Rational power + common_b = {} + done = [] + bases = [] + for b, e in c_powers: + b, e = bkey(b, e) + if b in common_b: + common_b[b] = common_b[b] + e + else: + common_b[b] = e + if b[1] != 1 and b[0].is_Mul: + bases.append(b) + bases.sort(key=default_sort_key) # this makes tie-breaking canonical + bases.sort(key=measure, reverse=True) # handle longest first + for base in bases: + if base not in common_b: # it may have been removed already + continue + b, exponent = base + last = False # True when no factor of base is a radical + qlcm = 1 # the lcm of the radical denominators + while True: + bstart = b + qstart = qlcm + + bb = [] # list of factors + ee = [] # (factor's expo. and it's current value in common_b) + for bi in Mul.make_args(b): + bib, bie = bkey(bi) + if bib not in common_b or common_b[bib] < bie: + ee = bb = [] # failed + break + ee.append([bie, common_b[bib]]) + bb.append(bib) + if ee: + # find the number of integral extractions possible + # e.g. [(1, 2), (2, 2)] -> min(2/1, 2/2) -> 1 + min1 = ee[0][1]//ee[0][0] + for i in range(1, len(ee)): + rat = ee[i][1]//ee[i][0] + if rat < 1: + break + min1 = min(min1, rat) + else: + # update base factor counts + # e.g. if ee = [(2, 5), (3, 6)] then min1 = 2 + # and the new base counts will be 5-2*2 and 6-2*3 + for i in range(len(bb)): + common_b[bb[i]] -= min1*ee[i][0] + update(bb[i]) + # update the count of the base + # e.g. x**2*y*sqrt(x*sqrt(y)) the count of x*sqrt(y) + # will increase by 4 to give bkey (x*sqrt(y), 2, 5) + common_b[base] += min1*qstart*exponent + if (last # no more radicals in base + or len(common_b) == 1 # nothing left to join with + or all(k[1] == 1 for k in common_b) # no rad's in common_b + ): + break + # see what we can exponentiate base by to remove any radicals + # so we know what to search for + # e.g. if base were x**(1/2)*y**(1/3) then we should + # exponentiate by 6 and look for powers of x and y in the ratio + # of 2 to 3 + qlcm = lcm([ratq(bi) for bi in Mul.make_args(bstart)]) + if qlcm == 1: + break # we are done + b = bstart**qlcm + qlcm *= qstart + if all(ratq(bi) == 1 for bi in Mul.make_args(b)): + last = True # we are going to be done after this next pass + # this base no longer can find anything to join with and + # since it was longer than any other we are done with it + b, q = base + done.append((b, common_b.pop(base)*Rational(1, q))) + + # update c_powers and get ready to continue with powsimp + c_powers = done + # there may be terms still in common_b that were bases that were + # identified as needing processing, so remove those, too + for (b, q), e in common_b.items(): + if (b.is_Pow or isinstance(b, exp)) and \ + q is not S.One and not b.exp.is_Rational: + b, be = b.as_base_exp() + b = b**(be/q) + else: + b = root(b, q) + c_powers.append((b, e)) + check = len(c_powers) + c_powers = dict(c_powers) + assert len(c_powers) == check # there should have been no duplicates + # ============================================================== + + # rebuild the expression + newexpr = expr.func(*(newexpr + [Pow(b, e) for b, e in c_powers.items()])) + if combine == 'exp': + return expr.func(newexpr, expr.func(*nc_part)) + else: + return recurse(expr.func(*nc_part), combine='base') * \ + recurse(newexpr, combine='base') + + elif combine == 'base': + + # Build c_powers and nc_part. These must both be lists not + # dicts because exp's are not combined. + c_powers = [] + nc_part = [] + for term in expr.args: + if term.is_commutative: + c_powers.append(list(term.as_base_exp())) + else: + nc_part.append(term) + + # Pull out numerical coefficients from exponent if assumptions allow + # e.g., 2**(2*x) => 4**x + for i in range(len(c_powers)): + b, e = c_powers[i] + if not (all(x.is_nonnegative for x in b.as_numer_denom()) or e.is_integer or force or b.is_polar): + continue + exp_c, exp_t = e.as_coeff_Mul(rational=True) + if exp_c is not S.One and exp_t is not S.One: + c_powers[i] = [Pow(b, exp_c), exp_t] + + # Combine bases whenever they have the same exponent and + # assumptions allow + # first gather the potential bases under the common exponent + c_exp = defaultdict(list) + for b, e in c_powers: + if deep: + e = recurse(e) + if e.is_Add and (b.is_positive or e.is_integer): + e = factor_terms(e) + if _coeff_isneg(e): + e = -e + b = 1/b + c_exp[e].append(b) + del c_powers + + # Merge back in the results of the above to form a new product + c_powers = defaultdict(list) + for e in c_exp: + bases = c_exp[e] + + # calculate the new base for e + + if len(bases) == 1: + new_base = bases[0] + elif e.is_integer or force: + new_base = expr.func(*bases) + else: + # see which ones can be joined + unk = [] + nonneg = [] + neg = [] + for bi in bases: + if bi.is_negative: + neg.append(bi) + elif bi.is_nonnegative: + nonneg.append(bi) + elif bi.is_polar: + nonneg.append( + bi) # polar can be treated like non-negative + else: + unk.append(bi) + if len(unk) == 1 and not neg or len(neg) == 1 and not unk: + # a single neg or a single unk can join the rest + nonneg.extend(unk + neg) + unk = neg = [] + elif neg: + # their negative signs cancel in groups of 2*q if we know + # that e = p/q else we have to treat them as unknown + israt = False + if e.is_Rational: + israt = True + else: + p, d = e.as_numer_denom() + if p.is_integer and d.is_integer: + israt = True + if israt: + neg = [-w for w in neg] + unk.extend([S.NegativeOne]*len(neg)) + else: + unk.extend(neg) + neg = [] + del israt + + # these shouldn't be joined + for b in unk: + c_powers[b].append(e) + # here is a new joined base + new_base = expr.func(*(nonneg + neg)) + # if there are positive parts they will just get separated + # again unless some change is made + + def _terms(e): + # return the number of terms of this expression + # when multiplied out -- assuming no joining of terms + if e.is_Add: + return sum(_terms(ai) for ai in e.args) + if e.is_Mul: + return prod([_terms(mi) for mi in e.args]) + return 1 + xnew_base = expand_mul(new_base, deep=False) + if len(Add.make_args(xnew_base)) < _terms(new_base): + new_base = factor_terms(xnew_base) + + c_powers[new_base].append(e) + + # break out the powers from c_powers now + c_part = [Pow(b, ei) for b, e in c_powers.items() for ei in e] + + # we're done + return expr.func(*(c_part + nc_part)) + + else: + raise ValueError("combine must be one of ('all', 'exp', 'base').") + + +def powdenest(eq, force=False, polar=False): + r""" + Collect exponents on powers as assumptions allow. + + Explanation + =========== + + Given ``(bb**be)**e``, this can be simplified as follows: + * if ``bb`` is positive, or + * ``e`` is an integer, or + * ``|be| < 1`` then this simplifies to ``bb**(be*e)`` + + Given a product of powers raised to a power, ``(bb1**be1 * + bb2**be2...)**e``, simplification can be done as follows: + + - if e is positive, the gcd of all bei can be joined with e; + - all non-negative bb can be separated from those that are negative + and their gcd can be joined with e; autosimplification already + handles this separation. + - integer factors from powers that have integers in the denominator + of the exponent can be removed from any term and the gcd of such + integers can be joined with e + + Setting ``force`` to ``True`` will make symbols that are not explicitly + negative behave as though they are positive, resulting in more + denesting. + + Setting ``polar`` to ``True`` will do simplifications on the Riemann surface of + the logarithm, also resulting in more denestings. + + When there are sums of logs in exp() then a product of powers may be + obtained e.g. ``exp(3*(log(a) + 2*log(b)))`` - > ``a**3*b**6``. + + Examples + ======== + + >>> from sympy.abc import a, b, x, y, z + >>> from sympy import Symbol, exp, log, sqrt, symbols, powdenest + + >>> powdenest((x**(2*a/3))**(3*x)) + (x**(2*a/3))**(3*x) + >>> powdenest(exp(3*x*log(2))) + 2**(3*x) + + Assumptions may prevent expansion: + + >>> powdenest(sqrt(x**2)) + sqrt(x**2) + + >>> p = symbols('p', positive=True) + >>> powdenest(sqrt(p**2)) + p + + No other expansion is done. + + >>> i, j = symbols('i,j', integer=True) + >>> powdenest((x**x)**(i + j)) # -X-> (x**x)**i*(x**x)**j + x**(x*(i + j)) + + But exp() will be denested by moving all non-log terms outside of + the function; this may result in the collapsing of the exp to a power + with a different base: + + >>> powdenest(exp(3*y*log(x))) + x**(3*y) + >>> powdenest(exp(y*(log(a) + log(b)))) + (a*b)**y + >>> powdenest(exp(3*(log(a) + log(b)))) + a**3*b**3 + + If assumptions allow, symbols can also be moved to the outermost exponent: + + >>> i = Symbol('i', integer=True) + >>> powdenest(((x**(2*i))**(3*y))**x) + ((x**(2*i))**(3*y))**x + >>> powdenest(((x**(2*i))**(3*y))**x, force=True) + x**(6*i*x*y) + + >>> powdenest(((x**(2*a/3))**(3*y/i))**x) + ((x**(2*a/3))**(3*y/i))**x + >>> powdenest((x**(2*i)*y**(4*i))**z, force=True) + (x*y**2)**(2*i*z) + + >>> n = Symbol('n', negative=True) + + >>> powdenest((x**i)**y, force=True) + x**(i*y) + >>> powdenest((n**i)**x, force=True) + (n**i)**x + + """ + from sympy.simplify.simplify import posify + + if force: + def _denest(b, e): + if not isinstance(b, (Pow, exp)): + return b.is_positive, Pow(b, e, evaluate=False) + return _denest(b.base, b.exp*e) + reps = [] + for p in eq.atoms(Pow, exp): + if isinstance(p.base, (Pow, exp)): + ok, dp = _denest(*p.args) + if ok is not False: + reps.append((p, dp)) + if reps: + eq = eq.subs(reps) + eq, reps = posify(eq) + return powdenest(eq, force=False, polar=polar).xreplace(reps) + + if polar: + eq, rep = polarify(eq) + return unpolarify(powdenest(unpolarify(eq, exponents_only=True)), rep) + + new = powsimp(eq) + return new.xreplace(Transform( + _denest_pow, filter=lambda m: m.is_Pow or isinstance(m, exp))) + +_y = Dummy('y') + + +def _denest_pow(eq): + """ + Denest powers. + + This is a helper function for powdenest that performs the actual + transformation. + """ + from sympy.simplify.simplify import logcombine + + b, e = eq.as_base_exp() + if b.is_Pow or isinstance(b, exp) and e != 1: + new = b._eval_power(e) + if new is not None: + eq = new + b, e = new.as_base_exp() + + # denest exp with log terms in exponent + if b is S.Exp1 and e.is_Mul: + logs = [] + other = [] + for ei in e.args: + if any(isinstance(ai, log) for ai in Add.make_args(ei)): + logs.append(ei) + else: + other.append(ei) + logs = logcombine(Mul(*logs)) + return Pow(exp(logs), Mul(*other)) + + _, be = b.as_base_exp() + if be is S.One and not (b.is_Mul or + b.is_Rational and b.q != 1 or + b.is_positive): + return eq + + # denest eq which is either pos**e or Pow**e or Mul**e or + # Mul(b1**e1, b2**e2) + + # handle polar numbers specially + polars, nonpolars = [], [] + for bb in Mul.make_args(b): + if bb.is_polar: + polars.append(bb.as_base_exp()) + else: + nonpolars.append(bb) + if len(polars) == 1 and not polars[0][0].is_Mul: + return Pow(polars[0][0], polars[0][1]*e)*powdenest(Mul(*nonpolars)**e) + elif polars: + return Mul(*[powdenest(bb**(ee*e)) for (bb, ee) in polars]) \ + *powdenest(Mul(*nonpolars)**e) + + if b.is_Integer: + # use log to see if there is a power here + logb = expand_log(log(b)) + if logb.is_Mul: + c, logb = logb.args + e *= c + base = logb.args[0] + return Pow(base, e) + + # if b is not a Mul or any factor is an atom then there is nothing to do + if not b.is_Mul or any(s.is_Atom for s in Mul.make_args(b)): + return eq + + # let log handle the case of the base of the argument being a Mul, e.g. + # sqrt(x**(2*i)*y**(6*i)) -> x**i*y**(3**i) if x and y are positive; we + # will take the log, expand it, and then factor out the common powers that + # now appear as coefficient. We do this manually since terms_gcd pulls out + # fractions, terms_gcd(x+x*y/2) -> x*(y + 2)/2 and we don't want the 1/2; + # gcd won't pull out numerators from a fraction: gcd(3*x, 9*x/2) -> x but + # we want 3*x. Neither work with noncommutatives. + + def nc_gcd(aa, bb): + a, b = [i.as_coeff_Mul() for i in [aa, bb]] + c = gcd(a[0], b[0]).as_numer_denom()[0] + g = Mul(*(a[1].args_cnc(cset=True)[0] & b[1].args_cnc(cset=True)[0])) + return _keep_coeff(c, g) + + glogb = expand_log(log(b)) + if glogb.is_Add: + args = glogb.args + g = reduce(nc_gcd, args) + if g != 1: + cg, rg = g.as_coeff_Mul() + glogb = _keep_coeff(cg, rg*Add(*[a/g for a in args])) + + # now put the log back together again + if isinstance(glogb, log) or not glogb.is_Mul: + if glogb.args[0].is_Pow or isinstance(glogb.args[0], exp): + glogb = _denest_pow(glogb.args[0]) + if (abs(glogb.exp) < 1) == True: + return Pow(glogb.base, glogb.exp*e) + return eq + + # the log(b) was a Mul so join any adds with logcombine + add = [] + other = [] + for a in glogb.args: + if a.is_Add: + add.append(a) + else: + other.append(a) + return Pow(exp(logcombine(Mul(*add))), e*Mul(*other)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/radsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/radsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..c878168ebfbc29fc632577d6325befc120c26f56 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/radsimp.py @@ -0,0 +1,1234 @@ +from collections import defaultdict + +from sympy.core import sympify, S, Mul, Derivative, Pow +from sympy.core.add import _unevaluated_Add, Add +from sympy.core.assumptions import assumptions +from sympy.core.exprtools import Factors, gcd_terms +from sympy.core.function import _mexpand, expand_mul, expand_power_base +from sympy.core.mul import _keep_coeff, _unevaluated_Mul, _mulsort +from sympy.core.numbers import Rational, zoo, nan +from sympy.core.parameters import global_parameters +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.symbol import Dummy, Wild, symbols +from sympy.functions import exp, sqrt, log +from sympy.functions.elementary.complexes import Abs +from sympy.polys import gcd +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.utilities.iterables import iterable, sift + + + + +def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True): + """ + Collect additive terms of an expression. + + Explanation + =========== + + This function collects additive terms of an expression with respect + to a list of expression up to powers with rational exponents. By the + term symbol here are meant arbitrary expressions, which can contain + powers, products, sums etc. In other words symbol is a pattern which + will be searched for in the expression's terms. + + The input expression is not expanded by :func:`collect`, so user is + expected to provide an expression in an appropriate form. This makes + :func:`collect` more predictable as there is no magic happening behind the + scenes. However, it is important to note, that powers of products are + converted to products of powers using the :func:`~.expand_power_base` + function. + + There are two possible types of output. First, if ``evaluate`` flag is + set, this function will return an expression with collected terms or + else it will return a dictionary with expressions up to rational powers + as keys and collected coefficients as values. + + Examples + ======== + + >>> from sympy import S, collect, expand, factor, Wild + >>> from sympy.abc import a, b, c, x, y + + This function can collect symbolic coefficients in polynomials or + rational expressions. It will manage to find all integer or rational + powers of collection variable:: + + >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x) + c + x**2*(a + b) + x*(a - b) + + The same result can be achieved in dictionary form:: + + >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False) + >>> d[x**2] + a + b + >>> d[x] + a - b + >>> d[S.One] + c + + You can also work with multivariate polynomials. However, remember that + this function is greedy so it will care only about a single symbol at time, + in specification order:: + + >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y]) + x**2*(y + 1) + x*y + y*(a + 1) + + Also more complicated expressions can be used as patterns:: + + >>> from sympy import sin, log + >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x)) + (a + b)*sin(2*x) + + >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x)) + x*(a + b)*log(x) + + You can use wildcards in the pattern:: + + >>> w = Wild('w1') + >>> collect(a*x**y - b*x**y, w**y) + x**y*(a - b) + + It is also possible to work with symbolic powers, although it has more + complicated behavior, because in this case power's base and symbolic part + of the exponent are treated as a single symbol:: + + >>> collect(a*x**c + b*x**c, x) + a*x**c + b*x**c + >>> collect(a*x**c + b*x**c, x**c) + x**c*(a + b) + + However if you incorporate rationals to the exponents, then you will get + well known behavior:: + + >>> collect(a*x**(2*c) + b*x**(2*c), x**c) + x**(2*c)*(a + b) + + Note also that all previously stated facts about :func:`collect` function + apply to the exponential function, so you can get:: + + >>> from sympy import exp + >>> collect(a*exp(2*x) + b*exp(2*x), exp(x)) + (a + b)*exp(2*x) + + If you are interested only in collecting specific powers of some symbols + then set ``exact`` flag to True:: + + >>> collect(a*x**7 + b*x**7, x, exact=True) + a*x**7 + b*x**7 + >>> collect(a*x**7 + b*x**7, x**7, exact=True) + x**7*(a + b) + + If you want to collect on any object containing symbols, set + ``exact`` to None: + + >>> collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None) + x*exp(x) + 3*x + (y + 2)*sin(x) + >>> collect(a*x*y + x*y + b*x + x, [x, y], exact=None) + x*y*(a + 1) + x*(b + 1) + + You can also apply this function to differential equations, where + derivatives of arbitrary order can be collected. Note that if you + collect with respect to a function or a derivative of a function, all + derivatives of that function will also be collected. Use + ``exact=True`` to prevent this from happening:: + + >>> from sympy import Derivative as D, collect, Function + >>> f = Function('f') (x) + + >>> collect(a*D(f,x) + b*D(f,x), D(f,x)) + (a + b)*Derivative(f(x), x) + + >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f) + (a + b)*Derivative(f(x), (x, 2)) + + >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True) + a*Derivative(f(x), (x, 2)) + b*Derivative(f(x), (x, 2)) + + >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f) + (a + b)*f(x) + (a + b)*Derivative(f(x), x) + + Or you can even match both derivative order and exponent at the same time:: + + >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x)) + (a + b)*Derivative(f(x), (x, 2))**2 + + Finally, you can apply a function to each of the collected coefficients. + For example you can factorize symbolic coefficients of polynomial:: + + >>> f = expand((x + a + 1)**3) + + >>> collect(f, x, factor) + x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3 + + .. note:: Arguments are expected to be in expanded form, so you might have + to call :func:`~.expand` prior to calling this function. + + See Also + ======== + + collect_const, collect_sqrt, rcollect + """ + expr = sympify(expr) + syms = [sympify(i) for i in (syms if iterable(syms) else [syms])] + + # replace syms[i] if it is not x, -x or has Wild symbols + cond = lambda x: x.is_Symbol or (-x).is_Symbol or bool( + x.atoms(Wild)) + _, nonsyms = sift(syms, cond, binary=True) + if nonsyms: + reps = dict(zip(nonsyms, [Dummy(**assumptions(i)) for i in nonsyms])) + syms = [reps.get(s, s) for s in syms] + rv = collect(expr.subs(reps), syms, + func=func, evaluate=evaluate, exact=exact, + distribute_order_term=distribute_order_term) + urep = {v: k for k, v in reps.items()} + if not isinstance(rv, dict): + return rv.xreplace(urep) + else: + return {urep.get(k, k).xreplace(urep): v.xreplace(urep) + for k, v in rv.items()} + + # see if other expressions should be considered + if exact is None: + _syms = set() + for i in Add.make_args(expr): + if not i.has_free(*syms) or i in syms: + continue + if not i.is_Mul and i not in syms: + _syms.add(i) + else: + # identify compound generators + g = i._new_rawargs(*i.as_coeff_mul(*syms)[1]) + if g not in syms: + _syms.add(g) + simple = all(i.is_Pow and i.base in syms for i in _syms) + syms = syms + list(ordered(_syms)) + if not simple: + return collect(expr, syms, + func=func, evaluate=evaluate, exact=False, + distribute_order_term=distribute_order_term) + + if evaluate is None: + evaluate = global_parameters.evaluate + + def make_expression(terms): + product = [] + + for term, rat, sym, deriv in terms: + if deriv is not None: + var, order = deriv + for _ in range(order): + term = Derivative(term, var) + + if sym is None: + if rat is S.One: + product.append(term) + else: + product.append(Pow(term, rat)) + else: + product.append(Pow(term, rat*sym)) + + return Mul(*product) + + def parse_derivative(deriv): + # scan derivatives tower in the input expression and return + # underlying function and maximal differentiation order + expr, sym, order = deriv.expr, deriv.variables[0], 1 + + for s in deriv.variables[1:]: + if s == sym: + order += 1 + else: + raise NotImplementedError( + 'Improve MV Derivative support in collect') + + while isinstance(expr, Derivative): + s0 = expr.variables[0] + + if any(s != s0 for s in expr.variables): + raise NotImplementedError( + 'Improve MV Derivative support in collect') + + if s0 == sym: + expr, order = expr.expr, order + len(expr.variables) + else: + break + + return expr, (sym, Rational(order)) + + def parse_term(expr): + """Parses expression expr and outputs tuple (sexpr, rat_expo, + sym_expo, deriv) + where: + - sexpr is the base expression + - rat_expo is the rational exponent that sexpr is raised to + - sym_expo is the symbolic exponent that sexpr is raised to + - deriv contains the derivatives of the expression + + For example, the output of x would be (x, 1, None, None) + the output of 2**x would be (2, 1, x, None). + """ + rat_expo, sym_expo = S.One, None + sexpr, deriv = expr, None + + if expr.is_Pow: + if isinstance(expr.base, Derivative): + sexpr, deriv = parse_derivative(expr.base) + else: + sexpr = expr.base + + if expr.base == S.Exp1: + arg = expr.exp + if arg.is_Rational: + sexpr, rat_expo = S.Exp1, arg + elif arg.is_Mul: + coeff, tail = arg.as_coeff_Mul(rational=True) + sexpr, rat_expo = exp(tail), coeff + + elif expr.exp.is_Number: + rat_expo = expr.exp + else: + coeff, tail = expr.exp.as_coeff_Mul() + + if coeff.is_Number: + rat_expo, sym_expo = coeff, tail + else: + sym_expo = expr.exp + elif isinstance(expr, exp): + arg = expr.exp + if arg.is_Rational: + sexpr, rat_expo = S.Exp1, arg + elif arg.is_Mul: + coeff, tail = arg.as_coeff_Mul(rational=True) + sexpr, rat_expo = exp(tail), coeff + elif isinstance(expr, Derivative): + sexpr, deriv = parse_derivative(expr) + + return sexpr, rat_expo, sym_expo, deriv + + def parse_expression(terms, pattern): + """Parse terms searching for a pattern. + Terms is a list of tuples as returned by parse_terms; + Pattern is an expression treated as a product of factors. + """ + pattern = Mul.make_args(pattern) + + if len(terms) < len(pattern): + # pattern is longer than matched product + # so no chance for positive parsing result + return None + else: + pattern = [parse_term(elem) for elem in pattern] + + terms = terms[:] # need a copy + elems, common_expo, has_deriv = [], None, False + + for elem, e_rat, e_sym, e_ord in pattern: + + if elem.is_Number and e_rat == 1 and e_sym is None: + # a constant is a match for everything + continue + + for j in range(len(terms)): + if terms[j] is None: + continue + + term, t_rat, t_sym, t_ord = terms[j] + + # keeping track of whether one of the terms had + # a derivative or not as this will require rebuilding + # the expression later + if t_ord is not None: + has_deriv = True + + if (term.match(elem) is not None and + (t_sym == e_sym or t_sym is not None and + e_sym is not None and + t_sym.match(e_sym) is not None)): + if exact is False: + # we don't have to be exact so find common exponent + # for both expression's term and pattern's element + expo = t_rat / e_rat + + if common_expo is None: + # first time + common_expo = expo + else: + # common exponent was negotiated before so + # there is no chance for a pattern match unless + # common and current exponents are equal + if common_expo != expo: + common_expo = 1 + else: + # we ought to be exact so all fields of + # interest must match in every details + if e_rat != t_rat or e_ord != t_ord: + continue + + # found common term so remove it from the expression + # and try to match next element in the pattern + elems.append(terms[j]) + terms[j] = None + + break + + else: + # pattern element not found + return None + + return [_f for _f in terms if _f], elems, common_expo, has_deriv + + if evaluate: + if expr.is_Add: + o = expr.getO() or 0 + expr = expr.func(*[ + collect(a, syms, func, True, exact, distribute_order_term) + for a in expr.args if a != o]) + o + elif expr.is_Mul: + return expr.func(*[ + collect(term, syms, func, True, exact, distribute_order_term) + for term in expr.args]) + elif expr.is_Pow: + b = collect( + expr.base, syms, func, True, exact, distribute_order_term) + return Pow(b, expr.exp) + + syms = [expand_power_base(i, deep=False) for i in syms] + + order_term = None + + if distribute_order_term: + order_term = expr.getO() + + if order_term is not None: + if order_term.has(*syms): + order_term = None + else: + expr = expr.removeO() + + summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)] + + collected, disliked = defaultdict(list), S.Zero + for product in summa: + c, nc = product.args_cnc(split_1=False) + args = list(ordered(c)) + nc + terms = [parse_term(i) for i in args] + small_first = True + + for symbol in syms: + if isinstance(symbol, Derivative) and small_first: + terms = list(reversed(terms)) + small_first = not small_first + result = parse_expression(terms, symbol) + + if result is not None: + if not symbol.is_commutative: + raise AttributeError("Can not collect noncommutative symbol") + + terms, elems, common_expo, has_deriv = result + + # when there was derivative in current pattern we + # will need to rebuild its expression from scratch + if not has_deriv: + margs = [] + for elem in elems: + if elem[2] is None: + e = elem[1] + else: + e = elem[1]*elem[2] + margs.append(Pow(elem[0], e)) + index = Mul(*margs) + else: + index = make_expression(elems) + terms = expand_power_base(make_expression(terms), deep=False) + index = expand_power_base(index, deep=False) + collected[index].append(terms) + break + else: + # none of the patterns matched + disliked += product + # add terms now for each key + collected = {k: Add(*v) for k, v in collected.items()} + + if disliked is not S.Zero: + collected[S.One] = disliked + + if order_term is not None: + for key, val in collected.items(): + collected[key] = val + order_term + + if func is not None: + collected = { + key: func(val) for key, val in collected.items()} + + if evaluate: + return Add(*[key*val for key, val in collected.items()]) + else: + return collected + + +def rcollect(expr, *vars): + """ + Recursively collect sums in an expression. + + Examples + ======== + + >>> from sympy.simplify import rcollect + >>> from sympy.abc import x, y + + >>> expr = (x**2*y + x*y + x + y)/(x + y) + + >>> rcollect(expr, y) + (x + y*(x**2 + x + 1))/(x + y) + + See Also + ======== + + collect, collect_const, collect_sqrt + """ + if expr.is_Atom or not expr.has(*vars): + return expr + else: + expr = expr.__class__(*[rcollect(arg, *vars) for arg in expr.args]) + + if expr.is_Add: + return collect(expr, vars) + else: + return expr + + +def collect_sqrt(expr, evaluate=None): + """Return expr with terms having common square roots collected together. + If ``evaluate`` is False a count indicating the number of sqrt-containing + terms will be returned and, if non-zero, the terms of the Add will be + returned, else the expression itself will be returned as a single term. + If ``evaluate`` is True, the expression with any collected terms will be + returned. + + Note: since I = sqrt(-1), it is collected, too. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import collect_sqrt + >>> from sympy.abc import a, b + + >>> r2, r3, r5 = [sqrt(i) for i in [2, 3, 5]] + >>> collect_sqrt(a*r2 + b*r2) + sqrt(2)*(a + b) + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r3) + sqrt(2)*(a + b) + sqrt(3)*(a + b) + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5) + sqrt(3)*a + sqrt(5)*b + sqrt(2)*(a + b) + + If evaluate is False then the arguments will be sorted and + returned as a list and a count of the number of sqrt-containing + terms will be returned: + + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5, evaluate=False) + ((sqrt(3)*a, sqrt(5)*b, sqrt(2)*(a + b)), 3) + >>> collect_sqrt(a*sqrt(2) + b, evaluate=False) + ((b, sqrt(2)*a), 1) + >>> collect_sqrt(a + b, evaluate=False) + ((a + b,), 0) + + See Also + ======== + + collect, collect_const, rcollect + """ + if evaluate is None: + evaluate = global_parameters.evaluate + # this step will help to standardize any complex arguments + # of sqrts + coeff, expr = expr.as_content_primitive() + vars = set() + for a in Add.make_args(expr): + for m in a.args_cnc()[0]: + if m.is_number and ( + m.is_Pow and m.exp.is_Rational and m.exp.q == 2 or + m is S.ImaginaryUnit): + vars.add(m) + + # we only want radicals, so exclude Number handling; in this case + # d will be evaluated + d = collect_const(expr, *vars, Numbers=False) + hit = expr != d + + if not evaluate: + nrad = 0 + # make the evaluated args canonical + args = list(ordered(Add.make_args(d))) + for i, m in enumerate(args): + c, nc = m.args_cnc() + for ci in c: + # XXX should this be restricted to ci.is_number as above? + if ci.is_Pow and ci.exp.is_Rational and ci.exp.q == 2 or \ + ci is S.ImaginaryUnit: + nrad += 1 + break + args[i] *= coeff + if not (hit or nrad): + args = [Add(*args)] + return tuple(args), nrad + + return coeff*d + + +def collect_abs(expr): + """Return ``expr`` with arguments of multiple Abs in a term collected + under a single instance. + + Examples + ======== + + >>> from sympy.simplify.radsimp import collect_abs + >>> from sympy.abc import x + >>> collect_abs(abs(x + 1)/abs(x**2 - 1)) + Abs((x + 1)/(x**2 - 1)) + >>> collect_abs(abs(1/x)) + Abs(1/x) + """ + def _abs(mul): + c, nc = mul.args_cnc() + a = [] + o = [] + for i in c: + if isinstance(i, Abs): + a.append(i.args[0]) + elif isinstance(i, Pow) and isinstance(i.base, Abs) and i.exp.is_real: + a.append(i.base.args[0]**i.exp) + else: + o.append(i) + if len(a) < 2 and not any(i.exp.is_negative for i in a if isinstance(i, Pow)): + return mul + absarg = Mul(*a) + A = Abs(absarg) + args = [A] + args.extend(o) + if not A.has(Abs): + args.extend(nc) + return Mul(*args) + if not isinstance(A, Abs): + # reevaluate and make it unevaluated + A = Abs(absarg, evaluate=False) + args[0] = A + _mulsort(args) + args.extend(nc) # nc always go last + return Mul._from_args(args, is_commutative=not nc) + + return expr.replace( + lambda x: isinstance(x, Mul), + lambda x: _abs(x)).replace( + lambda x: isinstance(x, Pow), + lambda x: _abs(x)) + + +def collect_const(expr, *vars, Numbers=True): + """A non-greedy collection of terms with similar number coefficients in + an Add expr. If ``vars`` is given then only those constants will be + targeted. Although any Number can also be targeted, if this is not + desired set ``Numbers=False`` and no Float or Rational will be collected. + + Parameters + ========== + + expr : SymPy expression + This parameter defines the expression the expression from which + terms with similar coefficients are to be collected. A non-Add + expression is returned as it is. + + vars : variable length collection of Numbers, optional + Specifies the constants to target for collection. Can be multiple in + number. + + Numbers : bool + Specifies to target all instance of + :class:`sympy.core.numbers.Number` class. If ``Numbers=False``, then + no Float or Rational will be collected. + + Returns + ======= + + expr : Expr + Returns an expression with similar coefficient terms collected. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.abc import s, x, y, z + >>> from sympy.simplify.radsimp import collect_const + >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2))) + sqrt(3)*(sqrt(2) + 2) + >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7)) + (sqrt(3) + sqrt(7))*(s + 1) + >>> s = sqrt(2) + 2 + >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7)) + (sqrt(2) + 3)*(sqrt(3) + sqrt(7)) + >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3)) + sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2) + + The collection is sign-sensitive, giving higher precedence to the + unsigned values: + + >>> collect_const(x - y - z) + x - (y + z) + >>> collect_const(-y - z) + -(y + z) + >>> collect_const(2*x - 2*y - 2*z, 2) + 2*(x - y - z) + >>> collect_const(2*x - 2*y - 2*z, -2) + 2*x - 2*(y + z) + + See Also + ======== + + collect, collect_sqrt, rcollect + """ + if not expr.is_Add: + return expr + + recurse = False + + if not vars: + recurse = True + vars = set() + for a in expr.args: + for m in Mul.make_args(a): + if m.is_number: + vars.add(m) + else: + vars = sympify(vars) + if not Numbers: + vars = [v for v in vars if not v.is_Number] + + vars = list(ordered(vars)) + for v in vars: + terms = defaultdict(list) + Fv = Factors(v) + for m in Add.make_args(expr): + f = Factors(m) + q, r = f.div(Fv) + if r.is_one: + # only accept this as a true factor if + # it didn't change an exponent from an Integer + # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2) + # -- we aren't looking for this sort of change + fwas = f.factors.copy() + fnow = q.factors + if not any(k in fwas and fwas[k].is_Integer and not + fnow[k].is_Integer for k in fnow): + terms[v].append(q.as_expr()) + continue + terms[S.One].append(m) + + args = [] + hit = False + uneval = False + for k in ordered(terms): + v = terms[k] + if k is S.One: + args.extend(v) + continue + + if len(v) > 1: + v = Add(*v) + hit = True + if recurse and v != expr: + vars.append(v) + else: + v = v[0] + + # be careful not to let uneval become True unless + # it must be because it's going to be more expensive + # to rebuild the expression as an unevaluated one + if Numbers and k.is_Number and v.is_Add: + args.append(_keep_coeff(k, v, sign=True)) + uneval = True + else: + args.append(k*v) + + if hit: + if uneval: + expr = _unevaluated_Add(*args) + else: + expr = Add(*args) + if not expr.is_Add: + break + + return expr + + +def radsimp(expr, symbolic=True, max_terms=4): + r""" + Rationalize the denominator by removing square roots. + + Explanation + =========== + + The expression returned from radsimp must be used with caution + since if the denominator contains symbols, it will be possible to make + substitutions that violate the assumptions of the simplification process: + that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If + there are no symbols, this assumptions is made valid by collecting terms + of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If + you do not want the simplification to occur for symbolic denominators, set + ``symbolic`` to False. + + If there are more than ``max_terms`` radical terms then the expression is + returned unchanged. + + Examples + ======== + + >>> from sympy import radsimp, sqrt, Symbol, pprint + >>> from sympy import factor_terms, fraction, signsimp + >>> from sympy.simplify.radsimp import collect_sqrt + >>> from sympy.abc import a, b, c + + >>> radsimp(1/(2 + sqrt(2))) + (2 - sqrt(2))/2 + >>> x,y = map(Symbol, 'xy') + >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2)) + >>> radsimp(e) + sqrt(2)*(x + y) + + No simplification beyond removal of the gcd is done. One might + want to polish the result a little, however, by collecting + square root terms: + + >>> r2 = sqrt(2) + >>> r5 = sqrt(5) + >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans) + ___ ___ ___ ___ + \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y + ------------------------------------------ + 2 2 2 2 + 5*a + 10*a*b + 5*b - 2*x - 4*x*y - 2*y + + >>> n, d = fraction(ans) + >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True)) + ___ ___ + \/ 5 *(a + b) - \/ 2 *(x + y) + ------------------------------------------ + 2 2 2 2 + 5*a + 10*a*b + 5*b - 2*x - 4*x*y - 2*y + + If radicals in the denominator cannot be removed or there is no denominator, + the original expression will be returned. + + >>> radsimp(sqrt(2)*x + sqrt(2)) + sqrt(2)*x + sqrt(2) + + Results with symbols will not always be valid for all substitutions: + + >>> eq = 1/(a + b*sqrt(c)) + >>> eq.subs(a, b*sqrt(c)) + 1/(2*b*sqrt(c)) + >>> radsimp(eq).subs(a, b*sqrt(c)) + nan + + If ``symbolic=False``, symbolic denominators will not be transformed (but + numeric denominators will still be processed): + + >>> radsimp(eq, symbolic=False) + 1/(a + b*sqrt(c)) + + """ + from sympy.core.expr import Expr + from sympy.simplify.simplify import signsimp + + syms = symbols("a:d A:D") + def _num(rterms): + # return the multiplier that will simplify the expression described + # by rterms [(sqrt arg, coeff), ... ] + a, b, c, d, A, B, C, D = syms + if len(rterms) == 2: + reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i]))) + return ( + sqrt(A)*a - sqrt(B)*b).xreplace(reps) + if len(rterms) == 3: + reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i]))) + return ( + (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 - + B*b**2 + C*c**2)).xreplace(reps) + elif len(rterms) == 4: + reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i]))) + return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b + - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 + + D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 - + 2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 - + 2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 + + D**2*d**4)).xreplace(reps) + elif len(rterms) == 1: + return sqrt(rterms[0][0]) + else: + raise NotImplementedError + + def ispow2(d, log2=False): + if not d.is_Pow: + return False + e = d.exp + if e.is_Rational and e.q == 2 or symbolic and denom(e) == 2: + return True + if log2: + q = 1 + if e.is_Rational: + q = e.q + elif symbolic: + d = denom(e) + if d.is_Integer: + q = d + if q != 1 and log(q, 2).is_Integer: + return True + return False + + def handle(expr): + # Handle first reduces to the case + # expr = 1/d, where d is an add, or d is base**p/2. + # We do this by recursively calling handle on each piece. + from sympy.simplify.simplify import nsimplify + + if expr.is_Atom: + return expr + elif not isinstance(expr, Expr): + return expr.func(*[handle(a) for a in expr.args]) + + n, d = fraction(expr) + + if d.is_Atom and n.is_Atom: + return expr + elif not n.is_Atom: + n = n.func(*[handle(a) for a in n.args]) + return _unevaluated_Mul(n, handle(1/d)) + elif n is not S.One: + return _unevaluated_Mul(n, handle(1/d)) + elif d.is_Mul: + return _unevaluated_Mul(*[handle(1/d) for d in d.args]) + + # By this step, expr is 1/d, and d is not a mul. + if not symbolic and d.free_symbols: + return expr + + if ispow2(d): + d2 = sqrtdenest(sqrt(d.base))**numer(d.exp) + if d2 != d: + return handle(1/d2) + elif d.is_Pow and (d.exp.is_integer or d.base.is_positive): + # (1/d**i) = (1/d)**i + return handle(1/d.base)**d.exp + + if not (d.is_Add or ispow2(d)): + return 1/d.func(*[handle(a) for a in d.args]) + + # handle 1/d treating d as an Add (though it may not be) + + keep = True # keep changes that are made + + # flatten it and collect radicals after checking for special + # conditions + d = _mexpand(d) + + # did it change? + if d.is_Atom: + return 1/d + + # is it a number that might be handled easily? + if d.is_number: + _d = nsimplify(d) + if _d.is_Number and _d.equals(d): + return 1/_d + + while True: + # collect similar terms + collected = defaultdict(list) + for m in Add.make_args(d): # d might have become non-Add + p2 = [] + other = [] + for i in Mul.make_args(m): + if ispow2(i, log2=True): + p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp)) + elif i is S.ImaginaryUnit: + p2.append(S.NegativeOne) + else: + other.append(i) + collected[tuple(ordered(p2))].append(Mul(*other)) + rterms = list(ordered(list(collected.items()))) + rterms = [(Mul(*i), Add(*j)) for i, j in rterms] + nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0) + if nrad < 1: + break + elif nrad > max_terms: + # there may have been invalid operations leading to this point + # so don't keep changes, e.g. this expression is troublesome + # in collecting terms so as not to raise the issue of 2834: + # r = sqrt(sqrt(5) + 5) + # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r) + keep = False + break + if len(rterms) > 4: + # in general, only 4 terms can be removed with repeated squaring + # but other considerations can guide selection of radical terms + # so that radicals are removed + if all(x.is_Integer and (y**2).is_Rational for x, y in rterms): + nd, d = rad_rationalize(S.One, Add._from_args( + [sqrt(x)*y for x, y in rterms])) + n *= nd + else: + # is there anything else that might be attempted? + keep = False + break + from sympy.simplify.powsimp import powsimp, powdenest + + num = powsimp(_num(rterms)) + n *= num + d *= num + d = powdenest(_mexpand(d), force=symbolic) + if d.has(S.Zero, nan, zoo): + return expr + if d.is_Atom: + break + + if not keep: + return expr + return _unevaluated_Mul(n, 1/d) + + if not isinstance(expr, Expr): + return expr.func(*[radsimp(a, symbolic=symbolic, max_terms=max_terms) for a in expr.args]) + + coeff, expr = expr.as_coeff_Add() + expr = expr.normal() + old = fraction(expr) + n, d = fraction(handle(expr)) + if old != (n, d): + if not d.is_Atom: + was = (n, d) + n = signsimp(n, evaluate=False) + d = signsimp(d, evaluate=False) + u = Factors(_unevaluated_Mul(n, 1/d)) + u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()]) + n, d = fraction(u) + if old == (n, d): + n, d = was + n = expand_mul(n) + if d.is_Number or d.is_Add: + n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d))) + if d2.is_Number or (d2.count_ops() <= d.count_ops()): + n, d = [signsimp(i) for i in (n2, d2)] + if n.is_Mul and n.args[0].is_Number: + n = n.func(*n.args) + + return coeff + _unevaluated_Mul(n, 1/d) + + +def rad_rationalize(num, den): + """ + Rationalize ``num/den`` by removing square roots in the denominator; + num and den are sum of terms whose squares are positive rationals. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import rad_rationalize + >>> rad_rationalize(sqrt(3), 1 + sqrt(2)/3) + (-sqrt(3) + sqrt(6)/3, -7/9) + """ + if not den.is_Add: + return num, den + g, a, b = split_surds(den) + a = a*sqrt(g) + num = _mexpand((a - b)*num) + den = _mexpand(a**2 - b**2) + return rad_rationalize(num, den) + + +def fraction(expr, exact=False): + """Returns a pair with expression's numerator and denominator. + If the given expression is not a fraction then this function + will return the tuple (expr, 1). + + This function will not make any attempt to simplify nested + fractions or to do any term rewriting at all. + + If only one of the numerator/denominator pair is needed then + use numer(expr) or denom(expr) functions respectively. + + >>> from sympy import fraction, Rational, Symbol + >>> from sympy.abc import x, y + + >>> fraction(x/y) + (x, y) + >>> fraction(x) + (x, 1) + + >>> fraction(1/y**2) + (1, y**2) + + >>> fraction(x*y/2) + (x*y, 2) + >>> fraction(Rational(1, 2)) + (1, 2) + + This function will also work fine with assumptions: + + >>> k = Symbol('k', negative=True) + >>> fraction(x * y**k) + (x, y**(-k)) + + If we know nothing about sign of some exponent and ``exact`` + flag is unset, then the exponent's structure will + be analyzed and pretty fraction will be returned: + + >>> from sympy import exp, Mul + >>> fraction(2*x**(-y)) + (2, x**y) + + >>> fraction(exp(-x)) + (1, exp(x)) + + >>> fraction(exp(-x), exact=True) + (exp(-x), 1) + + The ``exact`` flag will also keep any unevaluated Muls from + being evaluated: + + >>> u = Mul(2, x + 1, evaluate=False) + >>> fraction(u) + (2*x + 2, 1) + >>> fraction(u, exact=True) + (2*(x + 1), 1) + """ + expr = sympify(expr) + + numer, denom = [], [] + + for term in Mul.make_args(expr): + if term.is_commutative and (term.is_Pow or isinstance(term, exp)): + b, ex = term.as_base_exp() + if ex.is_negative: + if ex is S.NegativeOne: + denom.append(b) + elif exact: + if ex.is_constant(): + denom.append(Pow(b, -ex)) + else: + numer.append(term) + else: + denom.append(Pow(b, -ex)) + elif ex.is_positive: + numer.append(term) + elif not exact and ex.is_Mul: + n, d = term.as_numer_denom() # this will cause evaluation + if n != 1: + numer.append(n) + denom.append(d) + else: + numer.append(term) + elif term.is_Rational and not term.is_Integer: + if term.p != 1: + numer.append(term.p) + denom.append(term.q) + else: + numer.append(term) + return Mul(*numer, evaluate=not exact), Mul(*denom, evaluate=not exact) + + +def numer(expr, exact=False): # default matches fraction's default + return fraction(expr, exact=exact)[0] + + +def denom(expr, exact=False): # default matches fraction's default + return fraction(expr, exact=exact)[1] + + +def fraction_expand(expr, **hints): + return expr.expand(frac=True, **hints) + + +def numer_expand(expr, **hints): + # default matches fraction's default + a, b = fraction(expr, exact=hints.get('exact', False)) + return a.expand(numer=True, **hints) / b + + +def denom_expand(expr, **hints): + # default matches fraction's default + a, b = fraction(expr, exact=hints.get('exact', False)) + return a / b.expand(denom=True, **hints) + + +expand_numer = numer_expand +expand_denom = denom_expand +expand_fraction = fraction_expand + + +def split_surds(expr): + """ + Split an expression with terms whose squares are positive rationals + into a sum of terms whose surds squared have gcd equal to g + and a sum of terms with surds squared prime with g. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import split_surds + >>> split_surds(3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15)) + (3, sqrt(2) + sqrt(5) + 3, sqrt(5)/7 + sqrt(10)) + """ + args = sorted(expr.args, key=default_sort_key) + coeff_muls = [x.as_coeff_Mul() for x in args] + surds = [x[1]**2 for x in coeff_muls if x[1].is_Pow] + surds.sort(key=default_sort_key) + g, b1, b2 = _split_gcd(*surds) + g2 = g + if not b2 and len(b1) >= 2: + b1n = [x/g for x in b1] + b1n = [x for x in b1n if x != 1] + # only a common factor has been factored; split again + g1, b1n, b2 = _split_gcd(*b1n) + g2 = g*g1 + a1v, a2v = [], [] + for c, s in coeff_muls: + if s.is_Pow and s.exp == S.Half: + s1 = s.base + if s1 in b1: + a1v.append(c*sqrt(s1/g2)) + else: + a2v.append(c*s) + else: + a2v.append(c*s) + a = Add(*a1v) + b = Add(*a2v) + return g2, a, b + + +def _split_gcd(*a): + """ + Split the list of integers ``a`` into a list of integers, ``a1`` having + ``g = gcd(a1)``, and a list ``a2`` whose elements are not divisible by + ``g``. Returns ``g, a1, a2``. + + Examples + ======== + + >>> from sympy.simplify.radsimp import _split_gcd + >>> _split_gcd(55, 35, 22, 14, 77, 10) + (5, [55, 35, 10], [22, 14, 77]) + """ + g = a[0] + b1 = [g] + b2 = [] + for x in a[1:]: + g1 = gcd(g, x) + if g1 == 1: + b2.append(x) + else: + g = g1 + b1.append(x) + return g, b1, b2 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/ratsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/ratsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..95751fab47f585d3ae2e1289f014fba0f2708224 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/ratsimp.py @@ -0,0 +1,222 @@ +from itertools import combinations_with_replacement +from sympy.core import symbols, Add, Dummy +from sympy.core.numbers import Rational +from sympy.polys import cancel, ComputationFailed, parallel_poly_from_expr, reduced, Poly +from sympy.polys.monomials import Monomial, monomial_div +from sympy.polys.polyerrors import DomainError, PolificationFailed +from sympy.utilities.misc import debug, debugf + +def ratsimp(expr): + """ + Put an expression over a common denominator, cancel and reduce. + + Examples + ======== + + >>> from sympy import ratsimp + >>> from sympy.abc import x, y + >>> ratsimp(1/x + 1/y) + (x + y)/(x*y) + """ + + f, g = cancel(expr).as_numer_denom() + try: + Q, r = reduced(f, [g], field=True, expand=False) + except ComputationFailed: + return f/g + + return Add(*Q) + cancel(r/g) + + +def ratsimpmodprime(expr, G, *gens, quick=True, polynomial=False, **args): + """ + Simplifies a rational expression ``expr`` modulo the prime ideal + generated by ``G``. ``G`` should be a Groebner basis of the + ideal. + + Examples + ======== + + >>> from sympy.simplify.ratsimp import ratsimpmodprime + >>> from sympy.abc import x, y + >>> eq = (x + y**5 + y)/(x - y) + >>> ratsimpmodprime(eq, [x*y**5 - x - y], x, y, order='lex') + (-x**2 - x*y - x - y)/(-x**2 + x*y) + + If ``polynomial`` is ``False``, the algorithm computes a rational + simplification which minimizes the sum of the total degrees of + the numerator and the denominator. + + If ``polynomial`` is ``True``, this function just brings numerator and + denominator into a canonical form. This is much faster, but has + potentially worse results. + + References + ========== + + .. [1] M. Monagan, R. Pearce, Rational Simplification Modulo a Polynomial + Ideal, https://dl.acm.org/doi/pdf/10.1145/1145768.1145809 + (specifically, the second algorithm) + """ + from sympy.solvers.solvers import solve + + debug('ratsimpmodprime', expr) + + # usual preparation of polynomials: + + num, denom = cancel(expr).as_numer_denom() + + try: + polys, opt = parallel_poly_from_expr([num, denom] + G, *gens, **args) + except PolificationFailed: + return expr + + domain = opt.domain + + if domain.has_assoc_Field: + opt.domain = domain.get_field() + else: + raise DomainError( + "Cannot compute rational simplification over %s" % domain) + + # compute only once + leading_monomials = [g.LM(opt.order) for g in polys[2:]] + tested = set() + + def staircase(n): + """ + Compute all monomials with degree less than ``n`` that are + not divisible by any element of ``leading_monomials``. + """ + if n == 0: + return [1] + S = [] + for mi in combinations_with_replacement(range(len(opt.gens)), n): + m = [0]*len(opt.gens) + for i in mi: + m[i] += 1 + if all(monomial_div(m, lmg) is None for lmg in + leading_monomials): + S.append(m) + + return [Monomial(s).as_expr(*opt.gens) for s in S] + staircase(n - 1) + + def _ratsimpmodprime(a, b, allsol, N=0, D=0): + r""" + Computes a rational simplification of ``a/b`` which minimizes + the sum of the total degrees of the numerator and the denominator. + + Explanation + =========== + + The algorithm proceeds by looking at ``a * d - b * c`` modulo + the ideal generated by ``G`` for some ``c`` and ``d`` with degree + less than ``a`` and ``b`` respectively. + The coefficients of ``c`` and ``d`` are indeterminates and thus + the coefficients of the normalform of ``a * d - b * c`` are + linear polynomials in these indeterminates. + If these linear polynomials, considered as system of + equations, have a nontrivial solution, then `\frac{a}{b} + \equiv \frac{c}{d}` modulo the ideal generated by ``G``. So, + by construction, the degree of ``c`` and ``d`` is less than + the degree of ``a`` and ``b``, so a simpler representation + has been found. + After a simpler representation has been found, the algorithm + tries to reduce the degree of the numerator and denominator + and returns the result afterwards. + + As an extension, if quick=False, we look at all possible degrees such + that the total degree is less than *or equal to* the best current + solution. We retain a list of all solutions of minimal degree, and try + to find the best one at the end. + """ + c, d = a, b + steps = 0 + + maxdeg = a.total_degree() + b.total_degree() + if quick: + bound = maxdeg - 1 + else: + bound = maxdeg + while N + D <= bound: + if (N, D) in tested: + break + tested.add((N, D)) + + M1 = staircase(N) + M2 = staircase(D) + debugf('%s / %s: %s, %s', (N, D, M1, M2)) + + Cs = symbols("c:%d" % len(M1), cls=Dummy) + Ds = symbols("d:%d" % len(M2), cls=Dummy) + ng = Cs + Ds + + c_hat = Poly( + sum(Cs[i] * M1[i] for i in range(len(M1))), opt.gens + ng) + d_hat = Poly( + sum(Ds[i] * M2[i] for i in range(len(M2))), opt.gens + ng) + + r = reduced(a * d_hat - b * c_hat, G, opt.gens + ng, + order=opt.order, polys=True)[1] + + S = Poly(r, gens=opt.gens).coeffs() + sol = solve(S, Cs + Ds, particular=True, quick=True) + + if sol and not all(s == 0 for s in sol.values()): + c = c_hat.subs(sol) + d = d_hat.subs(sol) + + # The "free" variables occurring before as parameters + # might still be in the substituted c, d, so set them + # to the value chosen before: + c = c.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds)))))) + d = d.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds)))))) + + c = Poly(c, opt.gens) + d = Poly(d, opt.gens) + if d == 0: + raise ValueError('Ideal not prime?') + + allsol.append((c_hat, d_hat, S, Cs + Ds)) + if N + D != maxdeg: + allsol = [allsol[-1]] + + break + + steps += 1 + N += 1 + D += 1 + + if steps > 0: + c, d, allsol = _ratsimpmodprime(c, d, allsol, N, D - steps) + c, d, allsol = _ratsimpmodprime(c, d, allsol, N - steps, D) + + return c, d, allsol + + # preprocessing. this improves performance a bit when deg(num) + # and deg(denom) are large: + num = reduced(num, G, opt.gens, order=opt.order)[1] + denom = reduced(denom, G, opt.gens, order=opt.order)[1] + + if polynomial: + return (num/denom).cancel() + + c, d, allsol = _ratsimpmodprime( + Poly(num, opt.gens, domain=opt.domain), Poly(denom, opt.gens, domain=opt.domain), []) + if not quick and allsol: + debugf('Looking for best minimal solution. Got: %s', len(allsol)) + newsol = [] + for c_hat, d_hat, S, ng in allsol: + sol = solve(S, ng, particular=True, quick=False) + # all values of sol should be numbers; if not, solve is broken + newsol.append((c_hat.subs(sol), d_hat.subs(sol))) + c, d = min(newsol, key=lambda x: len(x[0].terms()) + len(x[1].terms())) + + if not domain.is_Field: + cn, c = c.clear_denoms(convert=True) + dn, d = d.clear_denoms(convert=True) + r = Rational(cn, dn) + else: + r = Rational(1) + + return (c*r.q)/(d*r.p) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/simplify.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..8b315cc20c19fc10c37b903d16129a7f5579ecd3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/simplify.py @@ -0,0 +1,2164 @@ +from __future__ import annotations + +from typing import overload + +from collections import defaultdict + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core import (Basic, S, Add, Mul, Pow, Symbol, sympify, + expand_func, Function, Dummy, Expr, factor_terms, + expand_power_exp, Eq) +from sympy.core.exprtools import factor_nc +from sympy.core.parameters import global_parameters +from sympy.core.function import (expand_log, count_ops, _mexpand, + nfloat, expand_mul, expand) +from sympy.core.numbers import Float, I, pi, Rational, equal_valued +from sympy.core.relational import Relational +from sympy.core.rules import Transform +from sympy.core.sorting import ordered +from sympy.core.sympify import _sympify +from sympy.core.traversal import bottom_up as _bottom_up, walk as _walk +from sympy.functions import gamma, exp, sqrt, log, exp_polar, re +from sympy.functions.combinatorial.factorials import CombinatorialFunction +from sympy.functions.elementary.complexes import unpolarify, Abs, sign +from sympy.functions.elementary.exponential import ExpBase +from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.piecewise import (Piecewise, piecewise_fold, + piecewise_simplify) +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.functions.special.bessel import (BesselBase, besselj, besseli, + besselk, bessely, jn) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import Boolean +from sympy.matrices.expressions import (MatrixExpr, MatAdd, MatMul, + MatPow, MatrixSymbol) +from sympy.polys import together, cancel, factor +from sympy.polys.numberfields.minpoly import _is_sum_surds, _minimal_polynomial_sq +from sympy.sets.sets import Set +from sympy.simplify.combsimp import combsimp +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.simplify.hyperexpand import hyperexpand +from sympy.simplify.powsimp import powsimp +from sympy.simplify.radsimp import radsimp, fraction, collect_abs +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.trigsimp import trigsimp, exptrigsimp +from sympy.utilities.decorator import deprecated +from sympy.utilities.iterables import has_variety, sift, subsets, iterable +from sympy.utilities.misc import as_int + +import mpmath + + +def separatevars(expr, symbols=[], dict=False, force=False): + """ + Separates variables in an expression, if possible. By + default, it separates with respect to all symbols in an + expression and collects constant coefficients that are + independent of symbols. + + Explanation + =========== + + If ``dict=True`` then the separated terms will be returned + in a dictionary keyed to their corresponding symbols. + By default, all symbols in the expression will appear as + keys; if symbols are provided, then all those symbols will + be used as keys, and any terms in the expression containing + other symbols or non-symbols will be returned keyed to the + string 'coeff'. (Passing None for symbols will return the + expression in a dictionary keyed to 'coeff'.) + + If ``force=True``, then bases of powers will be separated regardless + of assumptions on the symbols involved. + + Notes + ===== + + The order of the factors is determined by Mul, so that the + separated expressions may not necessarily be grouped together. + + Although factoring is necessary to separate variables in some + expressions, it is not necessary in all cases, so one should not + count on the returned factors being factored. + + Examples + ======== + + >>> from sympy.abc import x, y, z, alpha + >>> from sympy import separatevars, sin + >>> separatevars((x*y)**y) + (x*y)**y + >>> separatevars((x*y)**y, force=True) + x**y*y**y + + >>> e = 2*x**2*z*sin(y)+2*z*x**2 + >>> separatevars(e) + 2*x**2*z*(sin(y) + 1) + >>> separatevars(e, symbols=(x, y), dict=True) + {'coeff': 2*z, x: x**2, y: sin(y) + 1} + >>> separatevars(e, [x, y, alpha], dict=True) + {'coeff': 2*z, alpha: 1, x: x**2, y: sin(y) + 1} + + If the expression is not really separable, or is only partially + separable, separatevars will do the best it can to separate it + by using factoring. + + >>> separatevars(x + x*y - 3*x**2) + -x*(3*x - y - 1) + + If the expression is not separable then expr is returned unchanged + or (if dict=True) then None is returned. + + >>> eq = 2*x + y*sin(x) + >>> separatevars(eq) == eq + True + >>> separatevars(2*x + y*sin(x), symbols=(x, y), dict=True) is None + True + + """ + expr = sympify(expr) + if dict: + return _separatevars_dict(_separatevars(expr, force), symbols) + else: + return _separatevars(expr, force) + + +def _separatevars(expr, force): + if isinstance(expr, Abs): + arg = expr.args[0] + if arg.is_Mul and not arg.is_number: + s = separatevars(arg, dict=True, force=force) + if s is not None: + return Mul(*map(expr.func, s.values())) + else: + return expr + + if len(expr.free_symbols) < 2: + return expr + + # don't destroy a Mul since much of the work may already be done + if expr.is_Mul: + args = list(expr.args) + changed = False + for i, a in enumerate(args): + args[i] = separatevars(a, force) + changed = changed or args[i] != a + if changed: + expr = expr.func(*args) + return expr + + # get a Pow ready for expansion + if expr.is_Pow and expr.base != S.Exp1: + expr = Pow(separatevars(expr.base, force=force), expr.exp) + + # First try other expansion methods + expr = expr.expand(mul=False, multinomial=False, force=force) + + _expr, reps = posify(expr) if force else (expr, {}) + expr = factor(_expr).subs(reps) + + if not expr.is_Add: + return expr + + # Find any common coefficients to pull out + args = list(expr.args) + commonc = args[0].args_cnc(cset=True, warn=False)[0] + for i in args[1:]: + commonc &= i.args_cnc(cset=True, warn=False)[0] + commonc = Mul(*commonc) + commonc = commonc.as_coeff_Mul()[1] # ignore constants + commonc_set = commonc.args_cnc(cset=True, warn=False)[0] + + # remove them + for i, a in enumerate(args): + c, nc = a.args_cnc(cset=True, warn=False) + c = c - commonc_set + args[i] = Mul(*c)*Mul(*nc) + nonsepar = Add(*args) + + if len(nonsepar.free_symbols) > 1: + _expr = nonsepar + _expr, reps = posify(_expr) if force else (_expr, {}) + _expr = (factor(_expr)).subs(reps) + + if not _expr.is_Add: + nonsepar = _expr + + return commonc*nonsepar + + +def _separatevars_dict(expr, symbols): + if symbols: + if not all(t.is_Atom for t in symbols): + raise ValueError("symbols must be Atoms.") + symbols = list(symbols) + elif symbols is None: + return {'coeff': expr} + else: + symbols = list(expr.free_symbols) + if not symbols: + return None + + ret = {i: [] for i in symbols + ['coeff']} + + for i in Mul.make_args(expr): + expsym = i.free_symbols + intersection = set(symbols).intersection(expsym) + if len(intersection) > 1: + return None + if len(intersection) == 0: + # There are no symbols, so it is part of the coefficient + ret['coeff'].append(i) + else: + ret[intersection.pop()].append(i) + + # rebuild + for k, v in ret.items(): + ret[k] = Mul(*v) + + return ret + + +def posify(eq): + """Return ``eq`` (with generic symbols made positive) and a + dictionary containing the mapping between the old and new + symbols. + + Explanation + =========== + + Any symbol that has positive=None will be replaced with a positive dummy + symbol having the same name. This replacement will allow more symbolic + processing of expressions, especially those involving powers and + logarithms. + + A dictionary that can be sent to subs to restore ``eq`` to its original + symbols is also returned. + + >>> from sympy import posify, Symbol, log, solve + >>> from sympy.abc import x + >>> posify(x + Symbol('p', positive=True) + Symbol('n', negative=True)) + (_x + n + p, {_x: x}) + + >>> eq = 1/x + >>> log(eq).expand() + log(1/x) + >>> log(posify(eq)[0]).expand() + -log(_x) + >>> p, rep = posify(eq) + >>> log(p).expand().subs(rep) + -log(x) + + It is possible to apply the same transformations to an iterable + of expressions: + + >>> eq = x**2 - 4 + >>> solve(eq, x) + [-2, 2] + >>> eq_x, reps = posify([eq, x]); eq_x + [_x**2 - 4, _x] + >>> solve(*eq_x) + [2] + """ + eq = sympify(eq) + if not isinstance(eq, Basic) and iterable(eq): + f = type(eq) + eq = list(eq) + syms = set() + for e in eq: + syms = syms.union(e.atoms(Symbol)) + reps = {} + for s in syms: + reps.update({v: k for k, v in posify(s)[1].items()}) + for i, e in enumerate(eq): + eq[i] = e.subs(reps) + return f(eq), {r: s for s, r in reps.items()} + + reps = {s: Dummy(s.name, positive=True, **s.assumptions0) + for s in eq.free_symbols if s.is_positive is None} + eq = eq.subs(reps) + return eq, {r: s for s, r in reps.items()} + + +def hypersimp(f, k): + """Given combinatorial term f(k) simplify its consecutive term ratio + i.e. f(k+1)/f(k). The input term can be composed of functions and + integer sequences which have equivalent representation in terms + of gamma special function. + + Explanation + =========== + + The algorithm performs three basic steps: + + 1. Rewrite all functions in terms of gamma, if possible. + + 2. Rewrite all occurrences of gamma in terms of products + of gamma and rising factorial with integer, absolute + constant exponent. + + 3. Perform simplification of nested fractions, powers + and if the resulting expression is a quotient of + polynomials, reduce their total degree. + + If f(k) is hypergeometric then as result we arrive with a + quotient of polynomials of minimal degree. Otherwise None + is returned. + + For more information on the implemented algorithm refer to: + + 1. W. Koepf, Algorithms for m-fold Hypergeometric Summation, + Journal of Symbolic Computation (1995) 20, 399-417 + """ + f = sympify(f) + + g = f.subs(k, k + 1) / f + + g = g.rewrite(gamma) + if g.has(Piecewise): + g = piecewise_fold(g) + g = g.args[-1][0] + g = expand_func(g) + g = powsimp(g, deep=True, combine='exp') + + if g.is_rational_function(k): + return simplify(g, ratio=S.Infinity) + else: + return None + + +def hypersimilar(f, g, k): + """ + Returns True if ``f`` and ``g`` are hyper-similar. + + Explanation + =========== + + Similarity in hypergeometric sense means that a quotient of + f(k) and g(k) is a rational function in ``k``. This procedure + is useful in solving recurrence relations. + + For more information see hypersimp(). + + """ + f, g = list(map(sympify, (f, g))) + + h = (f/g).rewrite(gamma) + h = h.expand(func=True, basic=False) + + return h.is_rational_function(k) + + +def signsimp(expr, evaluate=None): + """Make all Add sub-expressions canonical wrt sign. + + Explanation + =========== + + If an Add subexpression, ``a``, can have a sign extracted, + as determined by could_extract_minus_sign, it is replaced + with Mul(-1, a, evaluate=False). This allows signs to be + extracted from powers and products. + + Examples + ======== + + >>> from sympy import signsimp, exp, symbols + >>> from sympy.abc import x, y + >>> i = symbols('i', odd=True) + >>> n = -1 + 1/x + >>> n/x/(-n)**2 - 1/n/x + (-1 + 1/x)/(x*(1 - 1/x)**2) - 1/(x*(-1 + 1/x)) + >>> signsimp(_) + 0 + >>> x*n + x*-n + x*(-1 + 1/x) + x*(1 - 1/x) + >>> signsimp(_) + 0 + + Since powers automatically handle leading signs + + >>> (-2)**i + -2**i + + signsimp can be used to put the base of a power with an integer + exponent into canonical form: + + >>> n**i + (-1 + 1/x)**i + + By default, signsimp does not leave behind any hollow simplification: + if making an Add canonical wrt sign didn't change the expression, the + original Add is restored. If this is not desired then the keyword + ``evaluate`` can be set to False: + + >>> e = exp(y - x) + >>> signsimp(e) == e + True + >>> signsimp(e, evaluate=False) + exp(-(x - y)) + + """ + if evaluate is None: + evaluate = global_parameters.evaluate + expr = sympify(expr) + if not isinstance(expr, (Expr, Relational)) or expr.is_Atom: + return expr + # get rid of an pre-existing unevaluation regarding sign + e = expr.replace(lambda x: x.is_Mul and -(-x) != x, lambda x: -(-x)) + e = sub_post(sub_pre(e)) + if not isinstance(e, (Expr, Relational)) or e.is_Atom: + return e + if e.is_Add: + rv = e.func(*[signsimp(a) for a in e.args]) + if not evaluate and isinstance(rv, Add + ) and rv.could_extract_minus_sign(): + return Mul(S.NegativeOne, -rv, evaluate=False) + return rv + if evaluate: + e = e.replace(lambda x: x.is_Mul and -(-x) != x, lambda x: -(-x)) + return e + + +@overload +def simplify(expr: Expr, **kwargs) -> Expr: ... +@overload +def simplify(expr: Boolean, **kwargs) -> Boolean: ... +@overload +def simplify(expr: Set, **kwargs) -> Set: ... +@overload +def simplify(expr: Basic, **kwargs) -> Basic: ... + +def simplify(expr, ratio=1.7, measure=count_ops, rational=False, inverse=False, doit=True, **kwargs): + """Simplifies the given expression. + + Explanation + =========== + + Simplification is not a well defined term and the exact strategies + this function tries can change in the future versions of SymPy. If + your algorithm relies on "simplification" (whatever it is), try to + determine what you need exactly - is it powsimp()?, radsimp()?, + together()?, logcombine()?, or something else? And use this particular + function directly, because those are well defined and thus your algorithm + will be robust. + + Nonetheless, especially for interactive use, or when you do not know + anything about the structure of the expression, simplify() tries to apply + intelligent heuristics to make the input expression "simpler". For + example: + + >>> from sympy import simplify, cos, sin + >>> from sympy.abc import x, y + >>> a = (x + x**2)/(x*sin(y)**2 + x*cos(y)**2) + >>> a + (x**2 + x)/(x*sin(y)**2 + x*cos(y)**2) + >>> simplify(a) + x + 1 + + Note that we could have obtained the same result by using specific + simplification functions: + + >>> from sympy import trigsimp, cancel + >>> trigsimp(a) + (x**2 + x)/x + >>> cancel(_) + x + 1 + + In some cases, applying :func:`simplify` may actually result in some more + complicated expression. The default ``ratio=1.7`` prevents more extreme + cases: if (result length)/(input length) > ratio, then input is returned + unmodified. The ``measure`` parameter lets you specify the function used + to determine how complex an expression is. The function should take a + single argument as an expression and return a number such that if + expression ``a`` is more complex than expression ``b``, then + ``measure(a) > measure(b)``. The default measure function is + :func:`~.count_ops`, which returns the total number of operations in the + expression. + + For example, if ``ratio=1``, ``simplify`` output cannot be longer + than input. + + :: + + >>> from sympy import sqrt, simplify, count_ops, oo + >>> root = 1/(sqrt(2)+3) + + Since ``simplify(root)`` would result in a slightly longer expression, + root is returned unchanged instead:: + + >>> simplify(root, ratio=1) == root + True + + If ``ratio=oo``, simplify will be applied anyway:: + + >>> count_ops(simplify(root, ratio=oo)) > count_ops(root) + True + + Note that the shortest expression is not necessary the simplest, so + setting ``ratio`` to 1 may not be a good idea. + Heuristically, the default value ``ratio=1.7`` seems like a reasonable + choice. + + You can easily define your own measure function based on what you feel + should represent the "size" or "complexity" of the input expression. Note + that some choices, such as ``lambda expr: len(str(expr))`` may appear to be + good metrics, but have other problems (in this case, the measure function + may slow down simplify too much for very large expressions). If you do not + know what a good metric would be, the default, ``count_ops``, is a good + one. + + For example: + + >>> from sympy import symbols, log + >>> a, b = symbols('a b', positive=True) + >>> g = log(a) + log(b) + log(a)*log(1/b) + >>> h = simplify(g) + >>> h + log(a*b**(1 - log(a))) + >>> count_ops(g) + 8 + >>> count_ops(h) + 5 + + So you can see that ``h`` is simpler than ``g`` using the count_ops metric. + However, we may not like how ``simplify`` (in this case, using + ``logcombine``) has created the ``b**(log(1/a) + 1)`` term. A simple way + to reduce this would be to give more weight to powers as operations in + ``count_ops``. We can do this by using the ``visual=True`` option: + + >>> print(count_ops(g, visual=True)) + 2*ADD + DIV + 4*LOG + MUL + >>> print(count_ops(h, visual=True)) + 2*LOG + MUL + POW + SUB + + >>> from sympy import Symbol, S + >>> def my_measure(expr): + ... POW = Symbol('POW') + ... # Discourage powers by giving POW a weight of 10 + ... count = count_ops(expr, visual=True).subs(POW, 10) + ... # Every other operation gets a weight of 1 (the default) + ... count = count.replace(Symbol, type(S.One)) + ... return count + >>> my_measure(g) + 8 + >>> my_measure(h) + 14 + >>> 15./8 > 1.7 # 1.7 is the default ratio + True + >>> simplify(g, measure=my_measure) + -log(a)*log(b) + log(a) + log(b) + + Note that because ``simplify()`` internally tries many different + simplification strategies and then compares them using the measure + function, we get a completely different result that is still different + from the input expression by doing this. + + If ``rational=True``, Floats will be recast as Rationals before simplification. + If ``rational=None``, Floats will be recast as Rationals but the result will + be recast as Floats. If rational=False(default) then nothing will be done + to the Floats. + + If ``inverse=True``, it will be assumed that a composition of inverse + functions, such as sin and asin, can be cancelled in any order. + For example, ``asin(sin(x))`` will yield ``x`` without checking whether + x belongs to the set where this relation is true. The default is + False. + + Note that ``simplify()`` automatically calls ``doit()`` on the final + expression. You can avoid this behavior by passing ``doit=False`` as + an argument. + + Also, it should be noted that simplifying a boolean expression is not + well defined. If the expression prefers automatic evaluation (such as + :obj:`~.Eq()` or :obj:`~.Or()`), simplification will return ``True`` or + ``False`` if truth value can be determined. If the expression is not + evaluated by default (such as :obj:`~.Predicate()`), simplification will + not reduce it and you should use :func:`~.refine` or :func:`~.ask` + function. This inconsistency will be resolved in future version. + + See Also + ======== + + sympy.assumptions.refine.refine : Simplification using assumptions. + sympy.assumptions.ask.ask : Query for boolean expressions using assumptions. + """ + + def shorter(*choices): + """ + Return the choice that has the fewest ops. In case of a tie, + the expression listed first is selected. + """ + if not has_variety(choices): + return choices[0] + return min(choices, key=measure) + + def done(e): + rv = e.doit() if doit else e + return shorter(rv, collect_abs(rv)) + + expr = sympify(expr, rational=rational) + kwargs = { + "ratio": kwargs.get('ratio', ratio), + "measure": kwargs.get('measure', measure), + "rational": kwargs.get('rational', rational), + "inverse": kwargs.get('inverse', inverse), + "doit": kwargs.get('doit', doit)} + # no routine for Expr needs to check for is_zero + if isinstance(expr, Expr) and expr.is_zero: + return S.Zero if not expr.is_Number else expr + + _eval_simplify = getattr(expr, '_eval_simplify', None) + if _eval_simplify is not None: + return _eval_simplify(**kwargs) + + original_expr = expr = collect_abs(signsimp(expr)) + + if not isinstance(expr, Basic) or not expr.args: # XXX: temporary hack + return expr + + if inverse and expr.has(Function): + expr = inversecombine(expr) + if not expr.args: # simplified to atomic + return expr + + # do deep simplification + handled = Add, Mul, Pow, ExpBase + expr = expr.replace( + # here, checking for x.args is not enough because Basic has + # args but Basic does not always play well with replace, e.g. + # when simultaneous is True found expressions will be masked + # off with a Dummy but not all Basic objects in an expression + # can be replaced with a Dummy + lambda x: isinstance(x, Expr) and x.args and not isinstance( + x, handled), + lambda x: x.func(*[simplify(i, **kwargs) for i in x.args]), + simultaneous=False) + if not isinstance(expr, handled): + return done(expr) + + if not expr.is_commutative: + expr = nc_simplify(expr) + + # TODO: Apply different strategies, considering expression pattern: + # is it a purely rational function? Is there any trigonometric function?... + # See also https://github.com/sympy/sympy/pull/185. + + # rationalize Floats + floats = False + if rational is not False and expr.has(Float): + floats = True + expr = nsimplify(expr, rational=True) + + expr = _bottom_up(expr, lambda w: getattr(w, 'normal', lambda: w)()) + expr = Mul(*powsimp(expr).as_content_primitive()) + _e = cancel(expr) + expr1 = shorter(_e, _mexpand(_e).cancel()) # issue 6829 + expr2 = shorter(together(expr, deep=True), together(expr1, deep=True)) + + if ratio is S.Infinity: + expr = expr2 + else: + expr = shorter(expr2, expr1, expr) + if not isinstance(expr, Basic): # XXX: temporary hack + return expr + + expr = factor_terms(expr, sign=False) + + # must come before `Piecewise` since this introduces more `Piecewise` terms + if expr.has(sign): + expr = expr.rewrite(Abs) + + # Deal with Piecewise separately to avoid recursive growth of expressions + if expr.has(Piecewise): + # Fold into a single Piecewise + expr = piecewise_fold(expr) + # Apply doit, if doit=True + expr = done(expr) + # Still a Piecewise? + if expr.has(Piecewise): + # Fold into a single Piecewise, in case doit lead to some + # expressions being Piecewise + expr = piecewise_fold(expr) + # kroneckersimp also affects Piecewise + if expr.has(KroneckerDelta): + expr = kroneckersimp(expr) + # Still a Piecewise? + if expr.has(Piecewise): + # Do not apply doit on the segments as it has already + # been done above, but simplify + expr = piecewise_simplify(expr, deep=True, doit=False) + # Still a Piecewise? + if expr.has(Piecewise): + # Try factor common terms + expr = shorter(expr, factor_terms(expr)) + # As all expressions have been simplified above with the + # complete simplify, nothing more needs to be done here + return expr + + # hyperexpand automatically only works on hypergeometric terms + # Do this after the Piecewise part to avoid recursive expansion + expr = hyperexpand(expr) + + if expr.has(KroneckerDelta): + expr = kroneckersimp(expr) + + if expr.has(BesselBase): + expr = besselsimp(expr) + + if expr.has(TrigonometricFunction, HyperbolicFunction): + expr = trigsimp(expr, deep=True) + + if expr.has(log): + expr = shorter(expand_log(expr, deep=True), logcombine(expr)) + + if expr.has(CombinatorialFunction, gamma): + # expression with gamma functions or non-integer arguments is + # automatically passed to gammasimp + expr = combsimp(expr) + + if expr.has(Sum): + expr = sum_simplify(expr, **kwargs) + + if expr.has(Integral): + expr = expr.xreplace({ + i: factor_terms(i) for i in expr.atoms(Integral)}) + + if expr.has(Product): + expr = product_simplify(expr, **kwargs) + + from sympy.physics.units import Quantity + + if expr.has(Quantity): + from sympy.physics.units.util import quantity_simplify + expr = quantity_simplify(expr) + + short = shorter(powsimp(expr, combine='exp', deep=True), powsimp(expr), expr) + short = shorter(short, cancel(short)) + short = shorter(short, factor_terms(short), expand_power_exp(expand_mul(short))) + if short.has(TrigonometricFunction, HyperbolicFunction, ExpBase, exp): + short = exptrigsimp(short) + + # get rid of hollow 2-arg Mul factorization + hollow_mul = Transform( + lambda x: Mul(*x.args), + lambda x: + x.is_Mul and + len(x.args) == 2 and + x.args[0].is_Number and + x.args[1].is_Add and + x.is_commutative) + expr = short.xreplace(hollow_mul) + + numer, denom = expr.as_numer_denom() + if denom.is_Add: + n, d = fraction(radsimp(1/denom, symbolic=False, max_terms=1)) + if n is not S.One: + expr = (numer*n).expand()/d + + if expr.could_extract_minus_sign(): + n, d = fraction(expr) + if d != 0: + expr = signsimp(-n/(-d)) + + if measure(expr) > ratio*measure(original_expr): + expr = original_expr + + # restore floats + if floats and rational is None: + expr = nfloat(expr, exponent=False) + + return done(expr) + + +def sum_simplify(s, **kwargs): + """Main function for Sum simplification""" + if not isinstance(s, Add): + s = s.xreplace({a: sum_simplify(a, **kwargs) + for a in s.atoms(Add) if a.has(Sum)}) + s = expand(s) + if not isinstance(s, Add): + return s + + terms = s.args + s_t = [] # Sum Terms + o_t = [] # Other Terms + + for term in terms: + sum_terms, other = sift(Mul.make_args(term), + lambda i: isinstance(i, Sum), binary=True) + if not sum_terms: + o_t.append(term) + continue + other = [Mul(*other)] + s_t.append(Mul(*(other + [s._eval_simplify(**kwargs) for s in sum_terms]))) + + result = Add(sum_combine(s_t), *o_t) + + return result + + +def sum_combine(s_t): + """Helper function for Sum simplification + + Attempts to simplify a list of sums, by combining limits / sum function's + returns the simplified sum + """ + used = [False] * len(s_t) + + for method in range(2): + for i, s_term1 in enumerate(s_t): + if not used[i]: + for j, s_term2 in enumerate(s_t): + if not used[j] and i != j: + temp = sum_add(s_term1, s_term2, method) + if isinstance(temp, (Sum, Mul)): + s_t[i] = temp + s_term1 = s_t[i] + used[j] = True + + result = S.Zero + for i, s_term in enumerate(s_t): + if not used[i]: + result = Add(result, s_term) + + return result + +def factor_sum(self, limits=None, radical=False, clear=False, fraction=False, sign=True): + """Return Sum with constant factors extracted. + + If ``limits`` is specified then ``self`` is the summand; the other + keywords are passed to ``factor_terms``. + + Examples + ======== + + >>> from sympy import Sum + >>> from sympy.abc import x, y + >>> from sympy.simplify.simplify import factor_sum + >>> s = Sum(x*y, (x, 1, 3)) + >>> factor_sum(s) + y*Sum(x, (x, 1, 3)) + >>> factor_sum(s.function, s.limits) + y*Sum(x, (x, 1, 3)) + """ + + # XXX deprecate in favor of direct call to factor_terms + kwargs = {"radical": radical, "clear": clear, + "fraction": fraction, "sign": sign} + expr = Sum(self, *limits) if limits else self + return factor_terms(expr, **kwargs) + + +def sum_add(self, other, method=0): + """Helper function for Sum simplification""" + #we know this is something in terms of a constant * a sum + #so we temporarily put the constants inside for simplification + #then simplify the result + def __refactor(val): + args = Mul.make_args(val) + sumv = next(x for x in args if isinstance(x, Sum)) + constant = Mul(*[x for x in args if x != sumv]) + return Sum(constant * sumv.function, *sumv.limits) + + if isinstance(self, Mul): + rself = __refactor(self) + else: + rself = self + + if isinstance(other, Mul): + rother = __refactor(other) + else: + rother = other + + if type(rself) is type(rother): + if method == 0: + if rself.limits == rother.limits: + return factor_sum(Sum(rself.function + rother.function, *rself.limits)) + elif method == 1: + if simplify(rself.function - rother.function) == 0: + if len(rself.limits) == len(rother.limits) == 1: + i = rself.limits[0][0] + x1 = rself.limits[0][1] + y1 = rself.limits[0][2] + j = rother.limits[0][0] + x2 = rother.limits[0][1] + y2 = rother.limits[0][2] + + if i == j: + if x2 == y1 + 1: + return factor_sum(Sum(rself.function, (i, x1, y2))) + elif x1 == y2 + 1: + return factor_sum(Sum(rself.function, (i, x2, y1))) + + return Add(self, other) + + +def product_simplify(s, **kwargs): + """Main function for Product simplification""" + terms = Mul.make_args(s) + p_t = [] # Product Terms + o_t = [] # Other Terms + + deep = kwargs.get('deep', True) + for term in terms: + if isinstance(term, Product): + if deep: + p_t.append(Product(term.function.simplify(**kwargs), + *term.limits)) + else: + p_t.append(term) + else: + o_t.append(term) + + used = [False] * len(p_t) + + for method in range(2): + for i, p_term1 in enumerate(p_t): + if not used[i]: + for j, p_term2 in enumerate(p_t): + if not used[j] and i != j: + tmp_prod = product_mul(p_term1, p_term2, method) + if isinstance(tmp_prod, Product): + p_t[i] = tmp_prod + used[j] = True + + result = Mul(*o_t) + + for i, p_term in enumerate(p_t): + if not used[i]: + result = Mul(result, p_term) + + return result + + +def product_mul(self, other, method=0): + """Helper function for Product simplification""" + if type(self) is type(other): + if method == 0: + if self.limits == other.limits: + return Product(self.function * other.function, *self.limits) + elif method == 1: + if simplify(self.function - other.function) == 0: + if len(self.limits) == len(other.limits) == 1: + i = self.limits[0][0] + x1 = self.limits[0][1] + y1 = self.limits[0][2] + j = other.limits[0][0] + x2 = other.limits[0][1] + y2 = other.limits[0][2] + + if i == j: + if x2 == y1 + 1: + return Product(self.function, (i, x1, y2)) + elif x1 == y2 + 1: + return Product(self.function, (i, x2, y1)) + + return Mul(self, other) + + +def _nthroot_solve(p, n, prec): + """ + helper function for ``nthroot`` + It denests ``p**Rational(1, n)`` using its minimal polynomial + """ + from sympy.solvers import solve + while n % 2 == 0: + p = sqrtdenest(sqrt(p)) + n = n // 2 + if n == 1: + return p + pn = p**Rational(1, n) + x = Symbol('x') + f = _minimal_polynomial_sq(p, n, x) + if f is None: + return None + sols = solve(f, x) + for sol in sols: + if abs(sol - pn).n() < 1./10**prec: + sol = sqrtdenest(sol) + if _mexpand(sol**n) == p: + return sol + + +def logcombine(expr, force=False): + """ + Takes logarithms and combines them using the following rules: + + - log(x) + log(y) == log(x*y) if both are positive + - a*log(x) == log(x**a) if x is positive and a is real + + If ``force`` is ``True`` then the assumptions above will be assumed to hold if + there is no assumption already in place on a quantity. For example, if + ``a`` is imaginary or the argument negative, force will not perform a + combination but if ``a`` is a symbol with no assumptions the change will + take place. + + Examples + ======== + + >>> from sympy import Symbol, symbols, log, logcombine, I + >>> from sympy.abc import a, x, y, z + >>> logcombine(a*log(x) + log(y) - log(z)) + a*log(x) + log(y) - log(z) + >>> logcombine(a*log(x) + log(y) - log(z), force=True) + log(x**a*y/z) + >>> x,y,z = symbols('x,y,z', positive=True) + >>> a = Symbol('a', real=True) + >>> logcombine(a*log(x) + log(y) - log(z)) + log(x**a*y/z) + + The transformation is limited to factors and/or terms that + contain logs, so the result depends on the initial state of + expansion: + + >>> eq = (2 + 3*I)*log(x) + >>> logcombine(eq, force=True) == eq + True + >>> logcombine(eq.expand(), force=True) + log(x**2) + I*log(x**3) + + See Also + ======== + + posify: replace all symbols with symbols having positive assumptions + sympy.core.function.expand_log: expand the logarithms of products + and powers; the opposite of logcombine + + """ + + def f(rv): + if not (rv.is_Add or rv.is_Mul): + return rv + + def gooda(a): + # bool to tell whether the leading ``a`` in ``a*log(x)`` + # could appear as log(x**a) + return (a is not S.NegativeOne and # -1 *could* go, but we disallow + (a.is_extended_real or force and a.is_extended_real is not False)) + + def goodlog(l): + # bool to tell whether log ``l``'s argument can combine with others + a = l.args[0] + return a.is_positive or force and a.is_nonpositive is not False + + other = [] + logs = [] + log1 = defaultdict(list) + for a in Add.make_args(rv): + if isinstance(a, log) and goodlog(a): + log1[()].append(([], a)) + elif not a.is_Mul: + other.append(a) + else: + ot = [] + co = [] + lo = [] + for ai in a.args: + if ai.is_Rational and ai < 0: + ot.append(S.NegativeOne) + co.append(-ai) + elif isinstance(ai, log) and goodlog(ai): + lo.append(ai) + elif gooda(ai): + co.append(ai) + else: + ot.append(ai) + if len(lo) > 1: + logs.append((ot, co, lo)) + elif lo: + log1[tuple(ot)].append((co, lo[0])) + else: + other.append(a) + + # if there is only one log in other, put it with the + # good logs + if len(other) == 1 and isinstance(other[0], log): + log1[()].append(([], other.pop())) + # if there is only one log at each coefficient and none have + # an exponent to place inside the log then there is nothing to do + if not logs and all(len(log1[k]) == 1 and log1[k][0] == [] for k in log1): + return rv + + # collapse multi-logs as far as possible in a canonical way + # TODO: see if x*log(a)+x*log(a)*log(b) -> x*log(a)*(1+log(b))? + # -- in this case, it's unambiguous, but if it were were a log(c) in + # each term then it's arbitrary whether they are grouped by log(a) or + # by log(c). So for now, just leave this alone; it's probably better to + # let the user decide + for o, e, l in logs: + l = list(ordered(l)) + e = log(l.pop(0).args[0]**Mul(*e)) + while l: + li = l.pop(0) + e = log(li.args[0]**e) + c, l = Mul(*o), e + if isinstance(l, log): # it should be, but check to be sure + log1[(c,)].append(([], l)) + else: + other.append(c*l) + + # logs that have the same coefficient can multiply + for k in list(log1.keys()): + log1[Mul(*k)] = log(logcombine(Mul(*[ + l.args[0]**Mul(*c) for c, l in log1.pop(k)]), + force=force), evaluate=False) + + # logs that have oppositely signed coefficients can divide + for k in ordered(list(log1.keys())): + if k not in log1: # already popped as -k + continue + if -k in log1: + # figure out which has the minus sign; the one with + # more op counts should be the one + num, den = k, -k + if num.count_ops() > den.count_ops(): + num, den = den, num + other.append( + num*log(log1.pop(num).args[0]/log1.pop(den).args[0], + evaluate=False)) + else: + other.append(k*log1.pop(k)) + + return Add(*other) + + return _bottom_up(expr, f) + + +def inversecombine(expr): + """Simplify the composition of a function and its inverse. + + Explanation + =========== + + No attention is paid to whether the inverse is a left inverse or a + right inverse; thus, the result will in general not be equivalent + to the original expression. + + Examples + ======== + + >>> from sympy.simplify.simplify import inversecombine + >>> from sympy import asin, sin, log, exp + >>> from sympy.abc import x + >>> inversecombine(asin(sin(x))) + x + >>> inversecombine(2*log(exp(3*x))) + 6*x + """ + + def f(rv): + if isinstance(rv, log): + if isinstance(rv.args[0], exp) or (rv.args[0].is_Pow and rv.args[0].base == S.Exp1): + rv = rv.args[0].exp + elif rv.is_Function and hasattr(rv, "inverse"): + if (len(rv.args) == 1 and len(rv.args[0].args) == 1 and + isinstance(rv.args[0], rv.inverse(argindex=1))): + rv = rv.args[0].args[0] + if rv.is_Pow and rv.base == S.Exp1: + if isinstance(rv.exp, log): + rv = rv.exp.args[0] + return rv + + return _bottom_up(expr, f) + + +def kroneckersimp(expr): + """ + Simplify expressions with KroneckerDelta. + + The only simplification currently attempted is to identify multiplicative cancellation: + + Examples + ======== + + >>> from sympy import KroneckerDelta, kroneckersimp + >>> from sympy.abc import i + >>> kroneckersimp(1 + KroneckerDelta(0, i) * KroneckerDelta(1, i)) + 1 + """ + def args_cancel(args1, args2): + for i1 in range(2): + for i2 in range(2): + a1 = args1[i1] + a2 = args2[i2] + a3 = args1[(i1 + 1) % 2] + a4 = args2[(i2 + 1) % 2] + if Eq(a1, a2) is S.true and Eq(a3, a4) is S.false: + return True + return False + + def cancel_kronecker_mul(m): + args = m.args + deltas = [a for a in args if isinstance(a, KroneckerDelta)] + for delta1, delta2 in subsets(deltas, 2): + args1 = delta1.args + args2 = delta2.args + if args_cancel(args1, args2): + return S.Zero * m # In case of oo etc + return m + + if not expr.has(KroneckerDelta): + return expr + + if expr.has(Piecewise): + expr = expr.rewrite(KroneckerDelta) + + newexpr = expr + expr = None + + while newexpr != expr: + expr = newexpr + newexpr = expr.replace(lambda e: isinstance(e, Mul), cancel_kronecker_mul) + + return expr + + +def besselsimp(expr): + """ + Simplify bessel-type functions. + + Explanation + =========== + + This routine tries to simplify bessel-type functions. Currently it only + works on the Bessel J and I functions, however. It works by looking at all + such functions in turn, and eliminating factors of "I" and "-1" (actually + their polar equivalents) in front of the argument. Then, functions of + half-integer order are rewritten using trigonometric functions and + functions of integer order (> 1) are rewritten using functions + of low order. Finally, if the expression was changed, compute + factorization of the result with factor(). + + >>> from sympy import besselj, besseli, besselsimp, polar_lift, I, S + >>> from sympy.abc import z, nu + >>> besselsimp(besselj(nu, z*polar_lift(-1))) + exp(I*pi*nu)*besselj(nu, z) + >>> besselsimp(besseli(nu, z*polar_lift(-I))) + exp(-I*pi*nu/2)*besselj(nu, z) + >>> besselsimp(besseli(S(-1)/2, z)) + sqrt(2)*cosh(z)/(sqrt(pi)*sqrt(z)) + >>> besselsimp(z*besseli(0, z) + z*(besseli(2, z))/2 + besseli(1, z)) + 3*z*besseli(0, z)/2 + """ + # TODO + # - better algorithm? + # - simplify (cos(pi*b)*besselj(b,z) - besselj(-b,z))/sin(pi*b) ... + # - use contiguity relations? + + def replacer(fro, to, factors): + factors = set(factors) + + def repl(nu, z): + if factors.intersection(Mul.make_args(z)): + return to(nu, z) + return fro(nu, z) + return repl + + def torewrite(fro, to): + def tofunc(nu, z): + return fro(nu, z).rewrite(to) + return tofunc + + def tominus(fro): + def tofunc(nu, z): + return exp(I*pi*nu)*fro(nu, exp_polar(-I*pi)*z) + return tofunc + + orig_expr = expr + + ifactors = [I, exp_polar(I*pi/2), exp_polar(-I*pi/2)] + expr = expr.replace( + besselj, replacer(besselj, + torewrite(besselj, besseli), ifactors)) + expr = expr.replace( + besseli, replacer(besseli, + torewrite(besseli, besselj), ifactors)) + + minusfactors = [-1, exp_polar(I*pi)] + expr = expr.replace( + besselj, replacer(besselj, tominus(besselj), minusfactors)) + expr = expr.replace( + besseli, replacer(besseli, tominus(besseli), minusfactors)) + + z0 = Dummy('z') + + def expander(fro): + def repl(nu, z): + if (nu % 1) == S.Half: + return simplify(trigsimp(unpolarify( + fro(nu, z0).rewrite(besselj).rewrite(jn).expand( + func=True)).subs(z0, z))) + elif nu.is_Integer and nu > 1: + return fro(nu, z).expand(func=True) + return fro(nu, z) + return repl + + expr = expr.replace(besselj, expander(besselj)) + expr = expr.replace(bessely, expander(bessely)) + expr = expr.replace(besseli, expander(besseli)) + expr = expr.replace(besselk, expander(besselk)) + + def _bessel_simp_recursion(expr): + + def _use_recursion(bessel, expr): + while True: + bessels = expr.find(lambda x: isinstance(x, bessel)) + try: + for ba in sorted(bessels, key=lambda x: re(x.args[0])): + a, x = ba.args + bap1 = bessel(a+1, x) + bap2 = bessel(a+2, x) + if expr.has(bap1) and expr.has(bap2): + expr = expr.subs(ba, 2*(a+1)/x*bap1 - bap2) + break + else: + return expr + except (ValueError, TypeError): + return expr + if expr.has(besselj): + expr = _use_recursion(besselj, expr) + if expr.has(bessely): + expr = _use_recursion(bessely, expr) + return expr + + expr = _bessel_simp_recursion(expr) + if expr != orig_expr: + expr = expr.factor() + + return expr + + +def nthroot(expr, n, max_len=4, prec=15): + """ + Compute a real nth-root of a sum of surds. + + Parameters + ========== + + expr : sum of surds + n : integer + max_len : maximum number of surds passed as constants to ``nsimplify`` + + Algorithm + ========= + + First ``nsimplify`` is used to get a candidate root; if it is not a + root the minimal polynomial is computed; the answer is one of its + roots. + + Examples + ======== + + >>> from sympy.simplify.simplify import nthroot + >>> from sympy import sqrt + >>> nthroot(90 + 34*sqrt(7), 3) + sqrt(7) + 3 + + """ + expr = sympify(expr) + n = sympify(n) + p = expr**Rational(1, n) + if not n.is_integer: + return p + if not _is_sum_surds(expr): + return p + surds = [] + coeff_muls = [x.as_coeff_Mul() for x in expr.args] + for x, y in coeff_muls: + if not x.is_rational: + return p + if y is S.One: + continue + if not (y.is_Pow and y.exp == S.Half and y.base.is_integer): + return p + surds.append(y) + surds.sort() + surds = surds[:max_len] + if expr < 0 and n % 2 == 1: + p = (-expr)**Rational(1, n) + a = nsimplify(p, constants=surds) + res = a if _mexpand(a**n) == _mexpand(-expr) else p + return -res + a = nsimplify(p, constants=surds) + if _mexpand(a) is not _mexpand(p) and _mexpand(a**n) == _mexpand(expr): + return _mexpand(a) + expr = _nthroot_solve(expr, n, prec) + if expr is None: + return p + return expr + + +def nsimplify(expr, constants=(), tolerance=None, full=False, rational=None, + rational_conversion='base10'): + """ + Find a simple representation for a number or, if there are free symbols or + if ``rational=True``, then replace Floats with their Rational equivalents. If + no change is made and rational is not False then Floats will at least be + converted to Rationals. + + Explanation + =========== + + For numerical expressions, a simple formula that numerically matches the + given numerical expression is sought (and the input should be possible + to evalf to a precision of at least 30 digits). + + Optionally, a list of (rationally independent) constants to + include in the formula may be given. + + A lower tolerance may be set to find less exact matches. If no tolerance + is given then the least precise value will set the tolerance (e.g. Floats + default to 15 digits of precision, so would be tolerance=10**-15). + + With ``full=True``, a more extensive search is performed + (this is useful to find simpler numbers when the tolerance + is set low). + + When converting to rational, if rational_conversion='base10' (the default), then + convert floats to rationals using their base-10 (string) representation. + When rational_conversion='exact' it uses the exact, base-2 representation. + + Examples + ======== + + >>> from sympy import nsimplify, sqrt, GoldenRatio, exp, I, pi + >>> nsimplify(4/(1+sqrt(5)), [GoldenRatio]) + -2 + 2*GoldenRatio + >>> nsimplify((1/(exp(3*pi*I/5)+1))) + 1/2 - I*sqrt(sqrt(5)/10 + 1/4) + >>> nsimplify(I**I, [pi]) + exp(-pi/2) + >>> nsimplify(pi, tolerance=0.01) + 22/7 + + >>> nsimplify(0.333333333333333, rational=True, rational_conversion='exact') + 6004799503160655/18014398509481984 + >>> nsimplify(0.333333333333333, rational=True) + 1/3 + + See Also + ======== + + sympy.core.function.nfloat + + """ + try: + return sympify(as_int(expr)) + except (TypeError, ValueError): + pass + expr = sympify(expr).xreplace({ + Float('inf'): S.Infinity, + Float('-inf'): S.NegativeInfinity, + }) + if expr is S.Infinity or expr is S.NegativeInfinity: + return expr + if rational or expr.free_symbols: + return _real_to_rational(expr, tolerance, rational_conversion) + + # SymPy's default tolerance for Rationals is 15; other numbers may have + # lower tolerances set, so use them to pick the largest tolerance if None + # was given + if tolerance is None: + tolerance = 10**-min([15] + + [mpmath.libmp.libmpf.prec_to_dps(n._prec) + for n in expr.atoms(Float)]) + # XXX should prec be set independent of tolerance or should it be computed + # from tolerance? + prec = 30 + bprec = int(prec*3.33) + + constants_dict = {} + for constant in constants: + constant = sympify(constant) + v = constant.evalf(prec) + if not v.is_Float: + raise ValueError("constants must be real-valued") + constants_dict[str(constant)] = v._to_mpmath(bprec) + + exprval = expr.evalf(prec, chop=True) + re, im = exprval.as_real_imag() + + # safety check to make sure that this evaluated to a number + if not (re.is_Number and im.is_Number): + return expr + + def nsimplify_real(x): + orig = mpmath.mp.dps + xv = x._to_mpmath(bprec) + try: + # We'll be happy with low precision if a simple fraction + if not (tolerance or full): + mpmath.mp.dps = 15 + rat = mpmath.pslq([xv, 1]) + if rat is not None: + return Rational(-int(rat[1]), int(rat[0])) + mpmath.mp.dps = prec + newexpr = mpmath.identify(xv, constants=constants_dict, + tol=tolerance, full=full) + if not newexpr: + raise ValueError + if full: + newexpr = newexpr[0] + expr = sympify(newexpr) + if x and not expr: # don't let x become 0 + raise ValueError + if expr.is_finite is False and xv not in [mpmath.inf, mpmath.ninf]: + raise ValueError + return expr + finally: + # even though there are returns above, this is executed + # before leaving + mpmath.mp.dps = orig + try: + if re: + re = nsimplify_real(re) + if im: + im = nsimplify_real(im) + except ValueError: + if rational is None: + return _real_to_rational(expr, rational_conversion=rational_conversion) + return expr + + rv = re + im*S.ImaginaryUnit + # if there was a change or rational is explicitly not wanted + # return the value, else return the Rational representation + if rv != expr or rational is False: + return rv + return _real_to_rational(expr, rational_conversion=rational_conversion) + + +def _real_to_rational(expr, tolerance=None, rational_conversion='base10'): + """ + Replace all reals in expr with rationals. + + Examples + ======== + + >>> from sympy.simplify.simplify import _real_to_rational + >>> from sympy.abc import x + + >>> _real_to_rational(.76 + .1*x**.5) + sqrt(x)/10 + 19/25 + + If rational_conversion='base10', this uses the base-10 string. If + rational_conversion='exact', the exact, base-2 representation is used. + + >>> _real_to_rational(0.333333333333333, rational_conversion='exact') + 6004799503160655/18014398509481984 + >>> _real_to_rational(0.333333333333333) + 1/3 + + """ + expr = _sympify(expr) + inf = Float('inf') + p = expr + reps = {} + reduce_num = None + if tolerance is not None and tolerance < 1: + reduce_num = ceiling(1/tolerance) + for fl in p.atoms(Float): + key = fl + if reduce_num is not None: + r = Rational(fl).limit_denominator(reduce_num) + elif (tolerance is not None and tolerance >= 1 and + fl.is_Integer is False): + r = Rational(tolerance*round(fl/tolerance) + ).limit_denominator(int(tolerance)) + else: + if rational_conversion == 'exact': + r = Rational(fl) + reps[key] = r + continue + elif rational_conversion != 'base10': + raise ValueError("rational_conversion must be 'base10' or 'exact'") + + r = nsimplify(fl, rational=False) + # e.g. log(3).n() -> log(3) instead of a Rational + if fl and not r: + r = Rational(fl) + elif not r.is_Rational: + if fl in (inf, -inf): + r = S.ComplexInfinity + elif fl < 0: + fl = -fl + d = Pow(10, int(mpmath.log(fl)/mpmath.log(10))) + r = -Rational(str(fl/d))*d + elif fl > 0: + d = Pow(10, int(mpmath.log(fl)/mpmath.log(10))) + r = Rational(str(fl/d))*d + else: + r = S.Zero + reps[key] = r + return p.subs(reps, simultaneous=True) + + +def clear_coefficients(expr, rhs=S.Zero): + """Return `p, r` where `p` is the expression obtained when Rational + additive and multiplicative coefficients of `expr` have been stripped + away in a naive fashion (i.e. without simplification). The operations + needed to remove the coefficients will be applied to `rhs` and returned + as `r`. + + Examples + ======== + + >>> from sympy.simplify.simplify import clear_coefficients + >>> from sympy.abc import x, y + >>> from sympy import Dummy + >>> expr = 4*y*(6*x + 3) + >>> clear_coefficients(expr - 2) + (y*(2*x + 1), 1/6) + + When solving 2 or more expressions like `expr = a`, + `expr = b`, etc..., it is advantageous to provide a Dummy symbol + for `rhs` and simply replace it with `a`, `b`, etc... in `r`. + + >>> rhs = Dummy('rhs') + >>> clear_coefficients(expr, rhs) + (y*(2*x + 1), _rhs/12) + >>> _[1].subs(rhs, 2) + 1/6 + """ + was = None + free = expr.free_symbols + if expr.is_Rational: + return (S.Zero, rhs - expr) + while expr and was != expr: + was = expr + m, expr = ( + expr.as_content_primitive() + if free else + factor_terms(expr).as_coeff_Mul(rational=True)) + rhs /= m + c, expr = expr.as_coeff_Add(rational=True) + rhs -= c + expr = signsimp(expr, evaluate = False) + if expr.could_extract_minus_sign(): + expr = -expr + rhs = -rhs + return expr, rhs + +def nc_simplify(expr, deep=True): + ''' + Simplify a non-commutative expression composed of multiplication + and raising to a power by grouping repeated subterms into one power. + Priority is given to simplifications that give the fewest number + of arguments in the end (for example, in a*b*a*b*c*a*b*c simplifying + to (a*b)**2*c*a*b*c gives 5 arguments while a*b*(a*b*c)**2 has 3). + If ``expr`` is a sum of such terms, the sum of the simplified terms + is returned. + + Keyword argument ``deep`` controls whether or not subexpressions + nested deeper inside the main expression are simplified. See examples + below. Setting `deep` to `False` can save time on nested expressions + that do not need simplifying on all levels. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.simplify.simplify import nc_simplify + >>> a, b, c = symbols("a b c", commutative=False) + >>> nc_simplify(a*b*a*b*c*a*b*c) + a*b*(a*b*c)**2 + >>> expr = a**2*b*a**4*b*a**4 + >>> nc_simplify(expr) + a**2*(b*a**4)**2 + >>> nc_simplify(a*b*a*b*c**2*(a*b)**2*c**2) + ((a*b)**2*c**2)**2 + >>> nc_simplify(a*b*a*b + 2*a*c*a**2*c*a**2*c*a) + (a*b)**2 + 2*(a*c*a)**3 + >>> nc_simplify(b**-1*a**-1*(a*b)**2) + a*b + >>> nc_simplify(a**-1*b**-1*c*a) + (b*a)**(-1)*c*a + >>> expr = (a*b*a*b)**2*a*c*a*c + >>> nc_simplify(expr) + (a*b)**4*(a*c)**2 + >>> nc_simplify(expr, deep=False) + (a*b*a*b)**2*(a*c)**2 + + ''' + if isinstance(expr, MatrixExpr): + expr = expr.doit(inv_expand=False) + _Add, _Mul, _Pow, _Symbol = MatAdd, MatMul, MatPow, MatrixSymbol + else: + _Add, _Mul, _Pow, _Symbol = Add, Mul, Pow, Symbol + + # =========== Auxiliary functions ======================== + def _overlaps(args): + # Calculate a list of lists m such that m[i][j] contains the lengths + # of all possible overlaps between args[:i+1] and args[i+1+j:]. + # An overlap is a suffix of the prefix that matches a prefix + # of the suffix. + # For example, let expr=c*a*b*a*b*a*b*a*b. Then m[3][0] contains + # the lengths of overlaps of c*a*b*a*b with a*b*a*b. The overlaps + # are a*b*a*b, a*b and the empty word so that m[3][0]=[4,2,0]. + # All overlaps rather than only the longest one are recorded + # because this information helps calculate other overlap lengths. + m = [[([1, 0] if a == args[0] else [0]) for a in args[1:]]] + for i in range(1, len(args)): + overlaps = [] + j = 0 + for j in range(len(args) - i - 1): + overlap = [] + for v in m[i-1][j+1]: + if j + i + 1 + v < len(args) and args[i] == args[j+i+1+v]: + overlap.append(v + 1) + overlap += [0] + overlaps.append(overlap) + m.append(overlaps) + return m + + def _reduce_inverses(_args): + # replace consecutive negative powers by an inverse + # of a product of positive powers, e.g. a**-1*b**-1*c + # will simplify to (a*b)**-1*c; + # return that new args list and the number of negative + # powers in it (inv_tot) + inv_tot = 0 # total number of inverses + inverses = [] + args = [] + for arg in _args: + if isinstance(arg, _Pow) and arg.args[1].is_extended_negative: + inverses = [arg**-1] + inverses + inv_tot += 1 + else: + if len(inverses) == 1: + args.append(inverses[0]**-1) + elif len(inverses) > 1: + args.append(_Pow(_Mul(*inverses), -1)) + inv_tot -= len(inverses) - 1 + inverses = [] + args.append(arg) + if inverses: + args.append(_Pow(_Mul(*inverses), -1)) + inv_tot -= len(inverses) - 1 + return inv_tot, tuple(args) + + def get_score(s): + # compute the number of arguments of s + # (including in nested expressions) overall + # but ignore exponents + if isinstance(s, _Pow): + return get_score(s.args[0]) + elif isinstance(s, (_Add, _Mul)): + return sum(get_score(a) for a in s.args) + return 1 + + def compare(s, alt_s): + # compare two possible simplifications and return a + # "better" one + if s != alt_s and get_score(alt_s) < get_score(s): + return alt_s + return s + # ======================================================== + + if not isinstance(expr, (_Add, _Mul, _Pow)) or expr.is_commutative: + return expr + args = expr.args[:] + if isinstance(expr, _Pow): + if deep: + return _Pow(nc_simplify(args[0]), args[1]).doit() + else: + return expr + elif isinstance(expr, _Add): + return _Add(*[nc_simplify(a, deep=deep) for a in args]).doit() + else: + # get the non-commutative part + c_args, args = expr.args_cnc() + com_coeff = Mul(*c_args) + if not equal_valued(com_coeff, 1): + return com_coeff*nc_simplify(expr/com_coeff, deep=deep) + + inv_tot, args = _reduce_inverses(args) + # if most arguments are negative, work with the inverse + # of the expression, e.g. a**-1*b*a**-1*c**-1 will become + # (c*a*b**-1*a)**-1 at the end so can work with c*a*b**-1*a + invert = False + if inv_tot > len(args)/2: + invert = True + args = [a**-1 for a in args[::-1]] + + if deep: + args = tuple(nc_simplify(a) for a in args) + + m = _overlaps(args) + + # simps will be {subterm: end} where `end` is the ending + # index of a sequence of repetitions of subterm; + # this is for not wasting time with subterms that are part + # of longer, already considered sequences + simps = {} + + post = 1 + pre = 1 + + # the simplification coefficient is the number of + # arguments by which contracting a given sequence + # would reduce the word; e.g. in a*b*a*b*c*a*b*c, + # contracting a*b*a*b to (a*b)**2 removes 3 arguments + # while a*b*c*a*b*c to (a*b*c)**2 removes 6. It's + # better to contract the latter so simplification + # with a maximum simplification coefficient will be chosen + max_simp_coeff = 0 + simp = None # information about future simplification + + for i in range(1, len(args)): + simp_coeff = 0 + l = 0 # length of a subterm + p = 0 # the power of a subterm + if i < len(args) - 1: + rep = m[i][0] + start = i # starting index of the repeated sequence + end = i+1 # ending index of the repeated sequence + if i == len(args)-1 or rep == [0]: + # no subterm is repeated at this stage, at least as + # far as the arguments are concerned - there may be + # a repetition if powers are taken into account + if (isinstance(args[i], _Pow) and + not isinstance(args[i].args[0], _Symbol)): + subterm = args[i].args[0].args + l = len(subterm) + if args[i-l:i] == subterm: + # e.g. a*b in a*b*(a*b)**2 is not repeated + # in args (= [a, b, (a*b)**2]) but it + # can be matched here + p += 1 + start -= l + if args[i+1:i+1+l] == subterm: + # e.g. a*b in (a*b)**2*a*b + p += 1 + end += l + if p: + p += args[i].args[1] + else: + continue + else: + l = rep[0] # length of the longest repeated subterm at this point + start -= l - 1 + subterm = args[start:end] + p = 2 + end += l + + if subterm in simps and simps[subterm] >= start: + # the subterm is part of a sequence that + # has already been considered + continue + + # count how many times it's repeated + while end < len(args): + if l in m[end-1][0]: + p += 1 + end += l + elif isinstance(args[end], _Pow) and args[end].args[0].args == subterm: + # for cases like a*b*a*b*(a*b)**2*a*b + p += args[end].args[1] + end += 1 + else: + break + + # see if another match can be made, e.g. + # for b*a**2 in b*a**2*b*a**3 or a*b in + # a**2*b*a*b + + pre_exp = 0 + pre_arg = 1 + if start - l >= 0 and args[start-l+1:start] == subterm[1:]: + if isinstance(subterm[0], _Pow): + pre_arg = subterm[0].args[0] + exp = subterm[0].args[1] + else: + pre_arg = subterm[0] + exp = 1 + if isinstance(args[start-l], _Pow) and args[start-l].args[0] == pre_arg: + pre_exp = args[start-l].args[1] - exp + start -= l + p += 1 + elif args[start-l] == pre_arg: + pre_exp = 1 - exp + start -= l + p += 1 + + post_exp = 0 + post_arg = 1 + if end + l - 1 < len(args) and args[end:end+l-1] == subterm[:-1]: + if isinstance(subterm[-1], _Pow): + post_arg = subterm[-1].args[0] + exp = subterm[-1].args[1] + else: + post_arg = subterm[-1] + exp = 1 + if isinstance(args[end+l-1], _Pow) and args[end+l-1].args[0] == post_arg: + post_exp = args[end+l-1].args[1] - exp + end += l + p += 1 + elif args[end+l-1] == post_arg: + post_exp = 1 - exp + end += l + p += 1 + + # Consider a*b*a**2*b*a**2*b*a: + # b*a**2 is explicitly repeated, but note + # that in this case a*b*a is also repeated + # so there are two possible simplifications: + # a*(b*a**2)**3*a**-1 or (a*b*a)**3 + # The latter is obviously simpler. + # But in a*b*a**2*b**2*a**2 the simplifications are + # a*(b*a**2)**2 and (a*b*a)**3*a in which case + # it's better to stick with the shorter subterm + if post_exp and exp % 2 == 0 and start > 0: + exp = exp/2 + _pre_exp = 1 + _post_exp = 1 + if isinstance(args[start-1], _Pow) and args[start-1].args[0] == post_arg: + _post_exp = post_exp + exp + _pre_exp = args[start-1].args[1] - exp + elif args[start-1] == post_arg: + _post_exp = post_exp + exp + _pre_exp = 1 - exp + if _pre_exp == 0 or _post_exp == 0: + if not pre_exp: + start -= 1 + post_exp = _post_exp + pre_exp = _pre_exp + pre_arg = post_arg + subterm = (post_arg**exp,) + subterm[:-1] + (post_arg**exp,) + + simp_coeff += end-start + + if post_exp: + simp_coeff -= 1 + if pre_exp: + simp_coeff -= 1 + + simps[subterm] = end + + if simp_coeff > max_simp_coeff: + max_simp_coeff = simp_coeff + simp = (start, _Mul(*subterm), p, end, l) + pre = pre_arg**pre_exp + post = post_arg**post_exp + + if simp: + subterm = _Pow(nc_simplify(simp[1], deep=deep), simp[2]) + pre = nc_simplify(_Mul(*args[:simp[0]])*pre, deep=deep) + post = post*nc_simplify(_Mul(*args[simp[3]:]), deep=deep) + simp = pre*subterm*post + if pre != 1 or post != 1: + # new simplifications may be possible but no need + # to recurse over arguments + simp = nc_simplify(simp, deep=False) + else: + simp = _Mul(*args) + + if invert: + simp = _Pow(simp, -1) + + # see if factor_nc(expr) is simplified better + if not isinstance(expr, MatrixExpr): + f_expr = factor_nc(expr) + if f_expr != expr: + alt_simp = nc_simplify(f_expr, deep=deep) + simp = compare(simp, alt_simp) + else: + simp = simp.doit(inv_expand=False) + return simp + + +def dotprodsimp(expr, withsimp=False): + """Simplification for a sum of products targeted at the kind of blowup that + occurs during summation of products. Intended to reduce expression blowup + during matrix multiplication or other similar operations. Only works with + algebraic expressions and does not recurse into non. + + Parameters + ========== + + withsimp : bool, optional + Specifies whether a flag should be returned along with the expression + to indicate roughly whether simplification was successful. It is used + in ``MatrixArithmetic._eval_pow_by_recursion`` to avoid attempting to + simplify an expression repetitively which does not simplify. + """ + + def count_ops_alg(expr): + """Optimized count algebraic operations with no recursion into + non-algebraic args that ``core.function.count_ops`` does. Also returns + whether rational functions may be present according to negative + exponents of powers or non-number fractions. + + Returns + ======= + + ops, ratfunc : int, bool + ``ops`` is the number of algebraic operations starting at the top + level expression (not recursing into non-alg children). ``ratfunc`` + specifies whether the expression MAY contain rational functions + which ``cancel`` MIGHT optimize. + """ + + ops = 0 + args = [expr] + ratfunc = False + + while args: + a = args.pop() + + if not isinstance(a, Basic): + continue + + if a.is_Rational: + if a is not S.One: # -1/3 = NEG + DIV + ops += bool (a.p < 0) + bool (a.q != 1) + + elif a.is_Mul: + if a.could_extract_minus_sign(): + ops += 1 + if a.args[0] is S.NegativeOne: + a = a.as_two_terms()[1] + else: + a = -a + + n, d = fraction(a) + + if n.is_Integer: + ops += 1 + bool (n < 0) + args.append(d) # won't be -Mul but could be Add + + elif d is not S.One: + if not d.is_Integer: + args.append(d) + ratfunc=True + + ops += 1 + args.append(n) # could be -Mul + + else: + ops += len(a.args) - 1 + args.extend(a.args) + + elif a.is_Add: + laargs = len(a.args) + negs = 0 + + for ai in a.args: + if ai.could_extract_minus_sign(): + negs += 1 + ai = -ai + args.append(ai) + + ops += laargs - (negs != laargs) # -x - y = NEG + SUB + + elif a.is_Pow: + ops += 1 + args.append(a.base) + + if not ratfunc: + ratfunc = a.exp.is_negative is not False + + return ops, ratfunc + + def nonalg_subs_dummies(expr, dummies): + """Substitute dummy variables for non-algebraic expressions to avoid + evaluation of non-algebraic terms that ``polys.polytools.cancel`` does. + """ + + if not expr.args: + return expr + + if expr.is_Add or expr.is_Mul or expr.is_Pow: + args = None + + for i, a in enumerate(expr.args): + c = nonalg_subs_dummies(a, dummies) + + if c is a: + continue + + if args is None: + args = list(expr.args) + + args[i] = c + + if args is None: + return expr + + return expr.func(*args) + + return dummies.setdefault(expr, Dummy()) + + simplified = False # doesn't really mean simplified, rather "can simplify again" + + if isinstance(expr, Basic) and (expr.is_Add or expr.is_Mul or expr.is_Pow): + expr2 = expr.expand(deep=True, modulus=None, power_base=False, + power_exp=False, mul=True, log=False, multinomial=True, basic=False) + + if expr2 != expr: + expr = expr2 + simplified = True + + exprops, ratfunc = count_ops_alg(expr) + + if exprops >= 6: # empirically tested cutoff for expensive simplification + if ratfunc: + dummies = {} + expr2 = nonalg_subs_dummies(expr, dummies) + + if expr2 is expr or count_ops_alg(expr2)[0] >= 6: # check again after substitution + expr3 = cancel(expr2) + + if expr3 != expr2: + expr = expr3.subs([(d, e) for e, d in dummies.items()]) + simplified = True + + # very special case: x/(x-1) - 1/(x-1) -> 1 + elif (exprops == 5 and expr.is_Add and expr.args [0].is_Mul and + expr.args [1].is_Mul and expr.args [0].args [-1].is_Pow and + expr.args [1].args [-1].is_Pow and + expr.args [0].args [-1].exp is S.NegativeOne and + expr.args [1].args [-1].exp is S.NegativeOne): + + expr2 = together (expr) + expr2ops = count_ops_alg(expr2)[0] + + if expr2ops < exprops: + expr = expr2 + simplified = True + + else: + simplified = True + + return (expr, simplified) if withsimp else expr + + +bottom_up = deprecated( + """ + Using bottom_up from the sympy.simplify.simplify submodule is + deprecated. + + Instead, use bottom_up from the top-level sympy namespace, like + + sympy.bottom_up + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved", +)(_bottom_up) + + +# XXX: This function really should either be private API or exported in the +# top-level sympy/__init__.py +walk = deprecated( + """ + Using walk from the sympy.simplify.simplify submodule is + deprecated. + + Instead, use walk from sympy.core.traversal.walk + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved", +)(_walk) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/sqrtdenest.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/sqrtdenest.py new file mode 100644 index 0000000000000000000000000000000000000000..d266de7e62a4b7d37a2109f7091ff91e4df7c79d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/sqrtdenest.py @@ -0,0 +1,678 @@ +from sympy.core import Add, Expr, Mul, S, sympify +from sympy.core.function import _mexpand, count_ops, expand_mul +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy +from sympy.functions import root, sign, sqrt +from sympy.polys import Poly, PolynomialError + + +def is_sqrt(expr): + """Return True if expr is a sqrt, otherwise False.""" + + return expr.is_Pow and expr.exp.is_Rational and abs(expr.exp) is S.Half + + +def sqrt_depth(p) -> int: + """Return the maximum depth of any square root argument of p. + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import sqrt_depth + + Neither of these square roots contains any other square roots + so the depth is 1: + + >>> sqrt_depth(1 + sqrt(2)*(1 + sqrt(3))) + 1 + + The sqrt(3) is contained within a square root so the depth is + 2: + + >>> sqrt_depth(1 + sqrt(2)*sqrt(1 + sqrt(3))) + 2 + """ + if p is S.ImaginaryUnit: + return 1 + if p.is_Atom: + return 0 + if p.is_Add or p.is_Mul: + return max(sqrt_depth(x) for x in p.args) + if is_sqrt(p): + return sqrt_depth(p.base) + 1 + return 0 + + +def is_algebraic(p): + """Return True if p is comprised of only Rationals or square roots + of Rationals and algebraic operations. + + Examples + ======== + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import is_algebraic + >>> from sympy import cos + >>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*sqrt(2)))) + True + >>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*cos(2)))) + False + """ + + if p.is_Rational: + return True + elif p.is_Atom: + return False + elif is_sqrt(p) or p.is_Pow and p.exp.is_Integer: + return is_algebraic(p.base) + elif p.is_Add or p.is_Mul: + return all(is_algebraic(x) for x in p.args) + else: + return False + + +def _subsets(n): + """ + Returns all possible subsets of the set (0, 1, ..., n-1) except the + empty set, listed in reversed lexicographical order according to binary + representation, so that the case of the fourth root is treated last. + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import _subsets + >>> _subsets(2) + [[1, 0], [0, 1], [1, 1]] + + """ + if n == 1: + a = [[1]] + elif n == 2: + a = [[1, 0], [0, 1], [1, 1]] + elif n == 3: + a = [[1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]] + else: + b = _subsets(n - 1) + a0 = [x + [0] for x in b] + a1 = [x + [1] for x in b] + a = a0 + [[0]*(n - 1) + [1]] + a1 + return a + + +def sqrtdenest(expr, max_iter=3): + """Denests sqrts in an expression that contain other square roots + if possible, otherwise returns the expr unchanged. This is based on the + algorithms of [1]. + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import sqrtdenest + >>> from sympy import sqrt + >>> sqrtdenest(sqrt(5 + 2 * sqrt(6))) + sqrt(2) + sqrt(3) + + See Also + ======== + + sympy.solvers.solvers.unrad + + References + ========== + + .. [1] https://web.archive.org/web/20210806201615/https://researcher.watson.ibm.com/researcher/files/us-fagin/symb85.pdf + + .. [2] D. J. Jeffrey and A. D. Rich, 'Symplifying Square Roots of Square Roots + by Denesting' (available at https://www.cybertester.com/data/denest.pdf) + + """ + expr = expand_mul(expr) + for i in range(max_iter): + z = _sqrtdenest0(expr) + if expr == z: + return expr + expr = z + return expr + + +def _sqrt_match(p): + """Return [a, b, r] for p.match(a + b*sqrt(r)) where, in addition to + matching, sqrt(r) also has then maximal sqrt_depth among addends of p. + + Examples + ======== + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrt_match + >>> _sqrt_match(1 + sqrt(2) + sqrt(2)*sqrt(3) + 2*sqrt(1+sqrt(5))) + [1 + sqrt(2) + sqrt(6), 2, 1 + sqrt(5)] + """ + from sympy.simplify.radsimp import split_surds + + p = _mexpand(p) + if p.is_Number: + res = (p, S.Zero, S.Zero) + elif p.is_Add: + pargs = sorted(p.args, key=default_sort_key) + sqargs = [x**2 for x in pargs] + if all(sq.is_Rational and sq.is_positive for sq in sqargs): + r, b, a = split_surds(p) + res = a, b, r + return list(res) + # to make the process canonical, the argument is included in the tuple + # so when the max is selected, it will be the largest arg having a + # given depth + v = [(sqrt_depth(x), x, i) for i, x in enumerate(pargs)] + nmax = max(v, key=default_sort_key) + if nmax[0] == 0: + res = [] + else: + # select r + depth, _, i = nmax + r = pargs.pop(i) + v.pop(i) + b = S.One + if r.is_Mul: + bv = [] + rv = [] + for x in r.args: + if sqrt_depth(x) < depth: + bv.append(x) + else: + rv.append(x) + b = Mul._from_args(bv) + r = Mul._from_args(rv) + # collect terms containing r + a1 = [] + b1 = [b] + for x in v: + if x[0] < depth: + a1.append(x[1]) + else: + x1 = x[1] + if x1 == r: + b1.append(1) + else: + if x1.is_Mul: + x1args = list(x1.args) + if r in x1args: + x1args.remove(r) + b1.append(Mul(*x1args)) + else: + a1.append(x[1]) + else: + a1.append(x[1]) + a = Add(*a1) + b = Add(*b1) + res = (a, b, r**2) + else: + b, r = p.as_coeff_Mul() + if is_sqrt(r): + res = (S.Zero, b, r**2) + else: + res = [] + return list(res) + + +class SqrtdenestStopIteration(StopIteration): + pass + + +def _sqrtdenest0(expr): + """Returns expr after denesting its arguments.""" + + if is_sqrt(expr): + n, d = expr.as_numer_denom() + if d is S.One: # n is a square root + if n.base.is_Add: + args = sorted(n.base.args, key=default_sort_key) + if len(args) > 2 and all((x**2).is_Integer for x in args): + try: + return _sqrtdenest_rec(n) + except SqrtdenestStopIteration: + pass + expr = sqrt(_mexpand(Add(*[_sqrtdenest0(x) for x in args]))) + return _sqrtdenest1(expr) + else: + n, d = [_sqrtdenest0(i) for i in (n, d)] + return n/d + + if isinstance(expr, Add): + cs = [] + args = [] + for arg in expr.args: + c, a = arg.as_coeff_Mul() + cs.append(c) + args.append(a) + + if all(c.is_Rational for c in cs) and all(is_sqrt(arg) for arg in args): + return _sqrt_ratcomb(cs, args) + + if isinstance(expr, Expr): + args = expr.args + if args: + return expr.func(*[_sqrtdenest0(a) for a in args]) + return expr + + +def _sqrtdenest_rec(expr): + """Helper that denests the square root of three or more surds. + + Explanation + =========== + + It returns the denested expression; if it cannot be denested it + throws SqrtdenestStopIteration + + Algorithm: expr.base is in the extension Q_m = Q(sqrt(r_1),..,sqrt(r_k)); + split expr.base = a + b*sqrt(r_k), where `a` and `b` are on + Q_(m-1) = Q(sqrt(r_1),..,sqrt(r_(k-1))); then a**2 - b**2*r_k is + on Q_(m-1); denest sqrt(a**2 - b**2*r_k) and so on. + See [1], section 6. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrtdenest_rec + >>> _sqrtdenest_rec(sqrt(-72*sqrt(2) + 158*sqrt(5) + 498)) + -sqrt(10) + sqrt(2) + 9 + 9*sqrt(5) + >>> w=-6*sqrt(55)-6*sqrt(35)-2*sqrt(22)-2*sqrt(14)+2*sqrt(77)+6*sqrt(10)+65 + >>> _sqrtdenest_rec(sqrt(w)) + -sqrt(11) - sqrt(7) + sqrt(2) + 3*sqrt(5) + """ + from sympy.simplify.radsimp import radsimp, rad_rationalize, split_surds + if not expr.is_Pow: + return sqrtdenest(expr) + if expr.base < 0: + return sqrt(-1)*_sqrtdenest_rec(sqrt(-expr.base)) + g, a, b = split_surds(expr.base) + a = a*sqrt(g) + if a < b: + a, b = b, a + c2 = _mexpand(a**2 - b**2) + if len(c2.args) > 2: + g, a1, b1 = split_surds(c2) + a1 = a1*sqrt(g) + if a1 < b1: + a1, b1 = b1, a1 + c2_1 = _mexpand(a1**2 - b1**2) + c_1 = _sqrtdenest_rec(sqrt(c2_1)) + d_1 = _sqrtdenest_rec(sqrt(a1 + c_1)) + num, den = rad_rationalize(b1, d_1) + c = _mexpand(d_1/sqrt(2) + num/(den*sqrt(2))) + else: + c = _sqrtdenest1(sqrt(c2)) + + if sqrt_depth(c) > 1: + raise SqrtdenestStopIteration + ac = a + c + if len(ac.args) >= len(expr.args): + if count_ops(ac) >= count_ops(expr.base): + raise SqrtdenestStopIteration + d = sqrtdenest(sqrt(ac)) + if sqrt_depth(d) > 1: + raise SqrtdenestStopIteration + num, den = rad_rationalize(b, d) + r = d/sqrt(2) + num/(den*sqrt(2)) + r = radsimp(r) + return _mexpand(r) + + +def _sqrtdenest1(expr, denester=True): + """Return denested expr after denesting with simpler methods or, that + failing, using the denester.""" + + from sympy.simplify.simplify import radsimp + + if not is_sqrt(expr): + return expr + + a = expr.base + if a.is_Atom: + return expr + val = _sqrt_match(a) + if not val: + return expr + + a, b, r = val + # try a quick numeric denesting + d2 = _mexpand(a**2 - b**2*r) + if d2.is_Rational: + if d2.is_positive: + z = _sqrt_numeric_denest(a, b, r, d2) + if z is not None: + return z + else: + # fourth root case + # sqrtdenest(sqrt(3 + 2*sqrt(3))) = + # sqrt(2)*3**(1/4)/2 + sqrt(2)*3**(3/4)/2 + dr2 = _mexpand(-d2*r) + dr = sqrt(dr2) + if dr.is_Rational: + z = _sqrt_numeric_denest(_mexpand(b*r), a, r, dr2) + if z is not None: + return z/root(r, 4) + + else: + z = _sqrt_symbolic_denest(a, b, r) + if z is not None: + return z + + if not denester or not is_algebraic(expr): + return expr + + res = sqrt_biquadratic_denest(expr, a, b, r, d2) + if res: + return res + + # now call to the denester + av0 = [a, b, r, d2] + z = _denester([radsimp(expr**2)], av0, 0, sqrt_depth(expr))[0] + if av0[1] is None: + return expr + if z is not None: + if sqrt_depth(z) == sqrt_depth(expr) and count_ops(z) > count_ops(expr): + return expr + return z + return expr + + +def _sqrt_symbolic_denest(a, b, r): + """Given an expression, sqrt(a + b*sqrt(b)), return the denested + expression or None. + + Explanation + =========== + + If r = ra + rb*sqrt(rr), try replacing sqrt(rr) in ``a`` with + (y**2 - ra)/rb, and if the result is a quadratic, ca*y**2 + cb*y + cc, and + (cb + b)**2 - 4*ca*cc is 0, then sqrt(a + b*sqrt(r)) can be rewritten as + sqrt(ca*(sqrt(r) + (cb + b)/(2*ca))**2). + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import _sqrt_symbolic_denest, sqrtdenest + >>> from sympy import sqrt, Symbol + >>> from sympy.abc import x + + >>> a, b, r = 16 - 2*sqrt(29), 2, -10*sqrt(29) + 55 + >>> _sqrt_symbolic_denest(a, b, r) + sqrt(11 - 2*sqrt(29)) + sqrt(5) + + If the expression is numeric, it will be simplified: + + >>> w = sqrt(sqrt(sqrt(3) + 1) + 1) + 1 + sqrt(2) + >>> sqrtdenest(sqrt((w**2).expand())) + 1 + sqrt(2) + sqrt(1 + sqrt(1 + sqrt(3))) + + Otherwise, it will only be simplified if assumptions allow: + + >>> w = w.subs(sqrt(3), sqrt(x + 3)) + >>> sqrtdenest(sqrt((w**2).expand())) + sqrt((sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2))**2) + + Notice that the argument of the sqrt is a square. If x is made positive + then the sqrt of the square is resolved: + + >>> _.subs(x, Symbol('x', positive=True)) + sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2) + """ + + a, b, r = map(sympify, (a, b, r)) + rval = _sqrt_match(r) + if not rval: + return None + ra, rb, rr = rval + if rb: + y = Dummy('y', positive=True) + try: + newa = Poly(a.subs(sqrt(rr), (y**2 - ra)/rb), y) + except PolynomialError: + return None + if newa.degree() == 2: + ca, cb, cc = newa.all_coeffs() + cb += b + if _mexpand(cb**2 - 4*ca*cc).equals(0): + z = sqrt(ca*(sqrt(r) + cb/(2*ca))**2) + if z.is_number: + z = _mexpand(Mul._from_args(z.as_content_primitive())) + return z + + +def _sqrt_numeric_denest(a, b, r, d2): + r"""Helper that denest + $\sqrt{a + b \sqrt{r}}, d^2 = a^2 - b^2 r > 0$ + + If it cannot be denested, it returns ``None``. + """ + d = sqrt(d2) + s = a + d + # sqrt_depth(res) <= sqrt_depth(s) + 1 + # sqrt_depth(expr) = sqrt_depth(r) + 2 + # there is denesting if sqrt_depth(s) + 1 < sqrt_depth(r) + 2 + # if s**2 is Number there is a fourth root + if sqrt_depth(s) < sqrt_depth(r) + 1 or (s**2).is_Rational: + s1, s2 = sign(s), sign(b) + if s1 == s2 == -1: + s1 = s2 = 1 + res = (s1 * sqrt(a + d) + s2 * sqrt(a - d)) * sqrt(2) / 2 + return res.expand() + + +def sqrt_biquadratic_denest(expr, a, b, r, d2): + """denest expr = sqrt(a + b*sqrt(r)) + where a, b, r are linear combinations of square roots of + positive rationals on the rationals (SQRR) and r > 0, b != 0, + d2 = a**2 - b**2*r > 0 + + If it cannot denest it returns None. + + Explanation + =========== + + Search for a solution A of type SQRR of the biquadratic equation + 4*A**4 - 4*a*A**2 + b**2*r = 0 (1) + sqd = sqrt(a**2 - b**2*r) + Choosing the sqrt to be positive, the possible solutions are + A = sqrt(a/2 +/- sqd/2) + Since a, b, r are SQRR, then a**2 - b**2*r is a SQRR, + so if sqd can be denested, it is done by + _sqrtdenest_rec, and the result is a SQRR. + Similarly for A. + Examples of solutions (in both cases a and sqd are positive): + + Example of expr with solution sqrt(a/2 + sqd/2) but not + solution sqrt(a/2 - sqd/2): + expr = sqrt(-sqrt(15) - sqrt(2)*sqrt(-sqrt(5) + 5) - sqrt(3) + 8) + a = -sqrt(15) - sqrt(3) + 8; sqd = -2*sqrt(5) - 2 + 4*sqrt(3) + + Example of expr with solution sqrt(a/2 - sqd/2) but not + solution sqrt(a/2 + sqd/2): + w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3) + expr = sqrt((w**2).expand()) + a = 4*sqrt(6) + 8*sqrt(2) + 47 + 28*sqrt(3) + sqd = 29 + 20*sqrt(3) + + Define B = b/2*A; eq.(1) implies a = A**2 + B**2*r; then + expr**2 = a + b*sqrt(r) = (A + B*sqrt(r))**2 + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrt_match, sqrt_biquadratic_denest + >>> z = sqrt((2*sqrt(2) + 4)*sqrt(2 + sqrt(2)) + 5*sqrt(2) + 8) + >>> a, b, r = _sqrt_match(z**2) + >>> d2 = a**2 - b**2*r + >>> sqrt_biquadratic_denest(z, a, b, r, d2) + sqrt(2) + sqrt(sqrt(2) + 2) + 2 + """ + from sympy.simplify.radsimp import radsimp, rad_rationalize + if r <= 0 or d2 < 0 or not b or sqrt_depth(expr.base) < 2: + return None + for x in (a, b, r): + for y in x.args: + y2 = y**2 + if not y2.is_Integer or not y2.is_positive: + return None + sqd = _mexpand(sqrtdenest(sqrt(radsimp(d2)))) + if sqrt_depth(sqd) > 1: + return None + x1, x2 = [a/2 + sqd/2, a/2 - sqd/2] + # look for a solution A with depth 1 + for x in (x1, x2): + A = sqrtdenest(sqrt(x)) + if sqrt_depth(A) > 1: + continue + Bn, Bd = rad_rationalize(b, _mexpand(2*A)) + B = Bn/Bd + z = A + B*sqrt(r) + if z < 0: + z = -z + return _mexpand(z) + return None + + +def _denester(nested, av0, h, max_depth_level): + """Denests a list of expressions that contain nested square roots. + + Explanation + =========== + + Algorithm based on . + + It is assumed that all of the elements of 'nested' share the same + bottom-level radicand. (This is stated in the paper, on page 177, in + the paragraph immediately preceding the algorithm.) + + When evaluating all of the arguments in parallel, the bottom-level + radicand only needs to be denested once. This means that calling + _denester with x arguments results in a recursive invocation with x+1 + arguments; hence _denester has polynomial complexity. + + However, if the arguments were evaluated separately, each call would + result in two recursive invocations, and the algorithm would have + exponential complexity. + + This is discussed in the paper in the middle paragraph of page 179. + """ + from sympy.simplify.simplify import radsimp + if h > max_depth_level: + return None, None + if av0[1] is None: + return None, None + if (av0[0] is None and + all(n.is_Number for n in nested)): # no arguments are nested + for f in _subsets(len(nested)): # test subset 'f' of nested + p = _mexpand(Mul(*[nested[i] for i in range(len(f)) if f[i]])) + if f.count(1) > 1 and f[-1]: + p = -p + sqp = sqrt(p) + if sqp.is_Rational: + return sqp, f # got a perfect square so return its square root. + # Otherwise, return the radicand from the previous invocation. + return sqrt(nested[-1]), [0]*len(nested) + else: + R = None + if av0[0] is not None: + values = [av0[:2]] + R = av0[2] + nested2 = [av0[3], R] + av0[0] = None + else: + values = list(filter(None, [_sqrt_match(expr) for expr in nested])) + for v in values: + if v[2]: # Since if b=0, r is not defined + if R is not None: + if R != v[2]: + av0[1] = None + return None, None + else: + R = v[2] + if R is None: + # return the radicand from the previous invocation + return sqrt(nested[-1]), [0]*len(nested) + nested2 = [_mexpand(v[0]**2) - + _mexpand(R*v[1]**2) for v in values] + [R] + d, f = _denester(nested2, av0, h + 1, max_depth_level) + if not f: + return None, None + if not any(f[i] for i in range(len(nested))): + v = values[-1] + return sqrt(v[0] + _mexpand(v[1]*d)), f + else: + p = Mul(*[nested[i] for i in range(len(nested)) if f[i]]) + v = _sqrt_match(p) + if 1 in f and f.index(1) < len(nested) - 1 and f[len(nested) - 1]: + v[0] = -v[0] + v[1] = -v[1] + if not f[len(nested)]: # Solution denests with square roots + vad = _mexpand(v[0] + d) + if vad <= 0: + # return the radicand from the previous invocation. + return sqrt(nested[-1]), [0]*len(nested) + if not(sqrt_depth(vad) <= sqrt_depth(R) + 1 or + (vad**2).is_Number): + av0[1] = None + return None, None + + sqvad = _sqrtdenest1(sqrt(vad), denester=False) + if not (sqrt_depth(sqvad) <= sqrt_depth(R) + 1): + av0[1] = None + return None, None + sqvad1 = radsimp(1/sqvad) + res = _mexpand(sqvad/sqrt(2) + (v[1]*sqrt(R)*sqvad1/sqrt(2))) + return res, f + + # sign(v[1])*sqrt(_mexpand(v[1]**2*R*vad1/2))), f + else: # Solution requires a fourth root + s2 = _mexpand(v[1]*R) + d + if s2 <= 0: + return sqrt(nested[-1]), [0]*len(nested) + FR, s = root(_mexpand(R), 4), sqrt(s2) + return _mexpand(s/(sqrt(2)*FR) + v[0]*FR/(sqrt(2)*s)), f + + +def _sqrt_ratcomb(cs, args): + """Denest rational combinations of radicals. + + Based on section 5 of [1]. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import sqrtdenest + >>> z = sqrt(1+sqrt(3)) + sqrt(3+3*sqrt(3)) - sqrt(10+6*sqrt(3)) + >>> sqrtdenest(z) + 0 + """ + from sympy.simplify.radsimp import radsimp + + # check if there exists a pair of sqrt that can be denested + def find(a): + n = len(a) + for i in range(n - 1): + for j in range(i + 1, n): + s1 = a[i].base + s2 = a[j].base + p = _mexpand(s1 * s2) + s = sqrtdenest(sqrt(p)) + if s != sqrt(p): + return s, i, j + + indices = find(args) + if indices is None: + return Add(*[c * arg for c, arg in zip(cs, args)]) + + s, i1, i2 = indices + + c2 = cs.pop(i2) + args.pop(i2) + a1 = args[i1] + + # replace a2 by s/a1 + cs[i1] += radsimp(c2 * s / a1.base) + + return _sqrt_ratcomb(cs, args) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156f96acdc966268c64d6c0fb219181b01a19b88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f144fc79152edb0138b6fee5deb78eb489b9052d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c493b83327b90b08cd340eaf0fdf8dc81f978b8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_cse_diff.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_cse_diff.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf8cbbbfd4495ac2437b55bdd9061c9f62323fb9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_cse_diff.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f3cf35e19ba5381424602c09f497b3f674c4eeb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4774917f0a36acfea6bb18530a4393bd8a9859d0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..062b45f5e9820458fc31725e64e391f467af5984 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39702a675b87c8394e68693f3c0bb338e20544ba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5206f93e2d89ac9f87261be0d82d6fde0d84f1c5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f8351b46a2d1ba15afe294d647bb1c30983bcf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..272675a2758399ba41d79806a8b6fbac86ee8d9f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f561737cd37e12d48f69d1203b0d87dec643239c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b899cfd931c1c10834f4b5513a756a843b9c52bb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71622b812976973f8062eeb3827f8ad319de216f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fc68789c208cc0dda268689b1b7a5384f941fcc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97419eedab4d54f787bf9c300e8a796823bcda44 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_combsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_combsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..e56758a005fbb013c2b6ea4121b16c3434a54b03 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_combsimp.py @@ -0,0 +1,75 @@ +from sympy.core.numbers import Rational +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial) +from sympy.functions.special.gamma_functions import gamma +from sympy.simplify.combsimp import combsimp +from sympy.abc import x + + +def test_combsimp(): + k, m, n = symbols('k m n', integer = True) + + assert combsimp(factorial(n)) == factorial(n) + assert combsimp(binomial(n, k)) == binomial(n, k) + + assert combsimp(factorial(n)/factorial(n - 3)) == n*(-1 + n)*(-2 + n) + assert combsimp(binomial(n + 1, k + 1)/binomial(n, k)) == (1 + n)/(1 + k) + + assert combsimp(binomial(3*n + 4, n + 1)/binomial(3*n + 1, n)) == \ + Rational(3, 2)*((3*n + 2)*(3*n + 4)/((n + 1)*(2*n + 3))) + + assert combsimp(factorial(n)**2/factorial(n - 3)) == \ + factorial(n)*n*(-1 + n)*(-2 + n) + assert combsimp(factorial(n)*binomial(n + 1, k + 1)/binomial(n, k)) == \ + factorial(n + 1)/(1 + k) + + assert combsimp(gamma(n + 3)) == factorial(n + 2) + + assert combsimp(factorial(x)) == gamma(x + 1) + + # issue 9699 + assert combsimp((n + 1)*factorial(n)) == factorial(n + 1) + assert combsimp(factorial(n)/n) == factorial(n-1) + + # issue 6658 + assert combsimp(binomial(n, n - k)) == binomial(n, k) + + # issue 6341, 7135 + assert combsimp(factorial(n)/(factorial(k)*factorial(n - k))) == \ + binomial(n, k) + assert combsimp(factorial(k)*factorial(n - k)/factorial(n)) == \ + 1/binomial(n, k) + assert combsimp(factorial(2*n)/factorial(n)**2) == binomial(2*n, n) + assert combsimp(factorial(2*n)*factorial(k)*factorial(n - k)/ + factorial(n)**3) == binomial(2*n, n)/binomial(n, k) + + assert combsimp(factorial(n*(1 + n) - n**2 - n)) == 1 + + assert combsimp(6*FallingFactorial(-4, n)/factorial(n)) == \ + (-1)**n*(n + 1)*(n + 2)*(n + 3) + assert combsimp(6*FallingFactorial(-4, n - 1)/factorial(n - 1)) == \ + (-1)**(n - 1)*n*(n + 1)*(n + 2) + assert combsimp(6*FallingFactorial(-4, n - 3)/factorial(n - 3)) == \ + (-1)**(n - 3)*n*(n - 1)*(n - 2) + assert combsimp(6*FallingFactorial(-4, -n - 1)/factorial(-n - 1)) == \ + -(-1)**(-n - 1)*n*(n - 1)*(n - 2) + + assert combsimp(6*RisingFactorial(4, n)/factorial(n)) == \ + (n + 1)*(n + 2)*(n + 3) + assert combsimp(6*RisingFactorial(4, n - 1)/factorial(n - 1)) == \ + n*(n + 1)*(n + 2) + assert combsimp(6*RisingFactorial(4, n - 3)/factorial(n - 3)) == \ + n*(n - 1)*(n - 2) + assert combsimp(6*RisingFactorial(4, -n - 1)/factorial(-n - 1)) == \ + -n*(n - 1)*(n - 2) + + +def test_issue_6878(): + n = symbols('n', integer=True) + assert combsimp(RisingFactorial(-10, n)) == 3628800*(-1)**n/factorial(10 - n) + + +def test_issue_14528(): + p = symbols("p", integer=True, positive=True) + assert combsimp(binomial(1,p)) == 1/(factorial(p)*factorial(1-p)) + assert combsimp(factorial(2-p)) == factorial(2-p) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_cse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a34dfb0e227547bd41bed2491284fd7150d0b6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_cse.py @@ -0,0 +1,761 @@ +from functools import reduce +import itertools +from operator import add + +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose +from sympy.polys.rootoftools import CRootOf +from sympy.series.order import O +from sympy.simplify.cse_main import cse +from sympy.simplify.simplify import signsimp +from sympy.tensor.indexed import (Idx, IndexedBase) + +from sympy.core.function import count_ops +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.functions.special.hyper import meijerg +from sympy.simplify import cse_main, cse_opts +from sympy.utilities.iterables import subsets +from sympy.testing.pytest import XFAIL, raises +from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix, + ImmutableDenseMatrix, ImmutableSparseMatrix) +from sympy.matrices.expressions import MatrixSymbol + + +w, x, y, z = symbols('w,x,y,z') +x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13') + + +def test_numbered_symbols(): + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)] + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)] + ns = cse_main.numbered_symbols() + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)] + +# Dummy "optimization" functions for testing. + + +def opt1(expr): + return expr + y + + +def opt2(expr): + return expr*z + + +def test_preprocess_for_cse(): + assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y + assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x + assert cse_main.preprocess_for_cse(x, [(None, None)]) == x + assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y + assert cse_main.preprocess_for_cse( + x, [(opt1, None), (opt2, None)]) == (x + y)*z + + +def test_postprocess_for_cse(): + assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x + assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y + assert cse_main.postprocess_for_cse(x, [(None, None)]) == x + assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z + # Note the reverse order of application. + assert cse_main.postprocess_for_cse( + x, [(None, opt1), (None, opt2)]) == x*z + y + + +def test_cse_single(): + # Simple substitution. + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + + subst42, (red42,) = cse([42]) # issue_15082 + assert len(subst42) == 0 and red42 == 42 + subst_half, (red_half,) = cse([0.5]) + assert len(subst_half) == 0 and red_half == 0.5 + + +def test_cse_single2(): + # Simple substitution, test for being able to pass the expression directly + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse(e) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + substs, reduced = cse(Matrix([[1]])) + assert isinstance(reduced[0], Matrix) + + subst42, (red42,) = cse(42) # issue 15082 + assert len(subst42) == 0 and red42 == 42 + subst_half, (red_half,) = cse(0.5) # issue 15082 + assert len(subst_half) == 0 and red_half == 0.5 + + +def test_cse_not_possible(): + # No substitution possible. + e = Add(x, y) + substs, reduced = cse([e]) + assert substs == [] + assert reduced == [x + y] + # issue 6329 + eq = (meijerg((1, 2), (y, 4), (5,), [], x) + + meijerg((1, 3), (y, 4), (5,), [], x)) + assert cse(eq) == ([], [eq]) + + +def test_nested_substitution(): + # Substitution within a substitution. + e = Add(Pow(w*x + y, 2), sqrt(w*x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, w*x + y)] + assert reduced == [sqrt(x0) + x0**2] + + +def test_subtraction_opt(): + # Make sure subtraction is optimized. + e = (x - y)*(z - y) + exp((x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [-x0 + exp(-x0)] + e = -(x - y)*(z - y) + exp(-(x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [x0 + exp(x0)] + # issue 4077 + n = -1 + 1/x + e = n/x/(-n)**2 - 1/n/x + assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \ + ([], [0]) + assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \ + ([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3]) + + +def test_multiple_expressions(): + e1 = (x + y)*z + e2 = (x + y)*w + substs, reduced = cse([e1, e2]) + assert substs == [(x0, x + y)] + assert reduced == [x0*z, x0*w] + l = [w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [z + x*x0, x0] + l = [w*x*y, w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [x1, x1 + z, x0] + l = [(x - z)*(y - z), x - z, y - z] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)] + assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)] + assert reduced == [x1*x2, x1, x2] + l = [w*y + w + x + y + z, w*x*y] + assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0]) + assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0]) + assert cse([x + y, x + z]) == ([], [x + y, x + z]) + assert cse([x*y, z + x*y, x*y*z + 3]) == \ + ([(x0, x*y)], [x0, z + x0, 3 + x0*z]) + + +@XFAIL # CSE of non-commutative Mul terms is disabled +def test_non_commutative_cse(): + A, B, C = symbols('A B C', commutative=False) + l = [A*B*C, A*C] + assert cse(l) == ([], l) + l = [A*B*C, A*B] + assert cse(l) == ([(x0, A*B)], [x0*C, x0]) + + +# Test if CSE of non-commutative Mul terms is disabled +def test_bypass_non_commutatives(): + A, B, C = symbols('A B C', commutative=False) + l = [A*B*C, A*C] + assert cse(l) == ([], l) + l = [A*B*C, A*B] + assert cse(l) == ([], l) + l = [B*C, A*B*C] + assert cse(l) == ([], l) + + +@XFAIL # CSE fails when replacing non-commutative sub-expressions +def test_non_commutative_order(): + A, B, C = symbols('A B C', commutative=False) + x0 = symbols('x0', commutative=False) + l = [B+C, A*(B+C)] + assert cse(l) == ([(x0, B+C)], [x0, A*x0]) + + +@XFAIL # Worked in gh-11232, but was reverted due to performance considerations +def test_issue_10228(): + assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0]) + assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0]) + assert cse((w + 2*x + y + z, w + x + 1)) == ( + [(x0, w + x)], [x0 + x + y + z, x0 + 1]) + assert cse(((w + x + y + z)*(w - x))/(w + x)) == ( + [(x0, w + x)], [(x0 + y + z)*(w - x)/x0]) + a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m') + exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2) + assert cse(exprs) == ( + [(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1] +) + +@XFAIL +def test_powers(): + assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0]) + + +def test_issue_4498(): + assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \ + ([], [(w - z)/(x - y)]) + + +def test_issue_4020(): + assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \ + == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)]) + + +def test_issue_4203(): + assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) + + +def test_issue_6263(): + e = Eq(x*(-x + 1) + x*(x - 1), 0) + assert cse(e, optimizations='basic') == ([], [True]) + + +def test_issue_25043(): + c = symbols("c") + x = symbols("x0", real=True) + cse_expr = cse(c*x**2 + c*(x**4 - x**2))[-1][-1] + free = cse_expr.free_symbols + assert len(free) == len({i.name for i in free}) + + +def test_dont_cse_tuples(): + from sympy.core.function import Subs + f = Function("f") + g = Function("g") + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + assert name_val == [] + assert expr == (Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, x + y)) + + Subs(g(x, y), (x, y), (0, x + y))) + + assert name_val == [(x0, x + y)] + assert expr == Subs(f(x, y), (x, y), (0, x0)) + \ + Subs(g(x, y), (x, y), (0, x0)) + + +def test_pow_invpow(): + assert cse(1/x**2 + x**2) == \ + ([(x0, x**2)], [x0 + 1/x0]) + assert cse(x**2 + (1 + 1/x**2)/x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)]) + assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1]) + assert cse(cos(1/x**2) + sin(1/x**2)) == \ + ([(x0, x**(-2))], [sin(x0) + cos(x0)]) + assert cse(cos(x**2) + sin(x**2)) == \ + ([(x0, x**2)], [sin(x0) + cos(x0)]) + assert cse(y/(2 + x**2) + z/x**2/y) == \ + ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)]) + assert cse(exp(x**2) + x**2*cos(1/x**2)) == \ + ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)]) + assert cse((1 + 1/x**2)/x**2) == \ + ([(x0, x**(-2))], [x0*(x0 + 1)]) + assert cse(x**(2*y) + x**(-2*y)) == \ + ([(x0, x**(2*y))], [x0 + 1/x0]) + + +def test_postprocess(): + eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)], + postprocess=cse_main.cse_separate) == \ + [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)], + [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]] + + +def test_issue_4499(): + # previously, this gave 16 constants + from sympy.abc import a, b + B = Function('B') + G = Function('G') + t = Tuple(* + (a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a - + b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1), + sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b, + sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1, + sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), + (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, + sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b, + -2*a)) + c = cse(t) + ans = ( + [(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)), + (x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)), + (x8, B(b, x4)), (x9, x6*B(x2, x4))], + [(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9, + 1, 0, S.Half, z/2, -x3, -x1, -x0)]) + assert ans == c + + +def test_issue_6169(): + r = CRootOf(x**6 - 4*x**5 - 2, 1) + assert cse(r) == ([], [r]) + # and a check that the right thing is done with the new + # mechanism + assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y + + +def test_cse_Indexed(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + i = Idx('i', len_y-1) + + expr1 = (y[i+1]-y[i])/(x[i+1]-x[i]) + expr2 = 1/(x[i+1]-x[i]) + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + +def test_cse_MatrixSymbol(): + # MatrixSymbols have non-Basic args, so make sure that works + A = MatrixSymbol("A", 3, 3) + assert cse(A) == ([], [A]) + + n = symbols('n', integer=True) + B = MatrixSymbol("B", n, n) + assert cse(B) == ([], [B]) + + assert cse(A[0] * A[0]) == ([], [A[0]*A[0]]) + + assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0]) + +def test_cse_MatrixExpr(): + A = MatrixSymbol('A', 3, 3) + y = MatrixSymbol('y', 3, 1) + + expr1 = (A.T*A).I * A * y + expr2 = (A.T*A) * A * y + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + replacements, reduced_exprs = cse([expr1 + expr2, expr1]) + assert replacements + + replacements, reduced_exprs = cse([A**2, A + A**2]) + assert replacements + + +def test_Piecewise(): + f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True)) + ans = cse(f) + actual_ans = ([(x0, x*y)], + [Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))]) + assert ans == actual_ans + + +def test_ignore_order_terms(): + eq = exp(x).series(x,0,3) + sin(y+x**3) - 1 + assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)]) + + +def test_name_conflict(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_name_conflict_cust_symbols(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l, symbols("x:10")) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_symbols_exhausted_error(): + l = cos(x+y)+x+y+cos(w+y)+sin(w+y) + sym = [x, y, z] + with raises(ValueError): + cse(l, symbols=sym) + + +def test_issue_7840(): + # daveknippers' example + C393 = sympify( \ + 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \ + C391 > 2.35), (C392, True)), True))' + ) + C391 = sympify( \ + 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))' + ) + C393 = C393.subs('C391',C391) + # simple substitution + sub = {} + sub['C390'] = 0.703451854 + sub['C392'] = 1.01417794 + ss_answer = C393.subs(sub) + # cse + substitutions,new_eqn = cse(C393) + for pair in substitutions: + sub[pair[0].name] = pair[1].subs(sub) + cse_answer = new_eqn[0].subs(sub) + # both methods should be the same + assert ss_answer == cse_answer + + # GitRay's example + expr = sympify( + "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \ + (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \ + Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \ + Symbol('AUTO'))), (Symbol('OFF'), true)), true))" + ) + substitutions, new_eqn = cse(expr) + # this Piecewise should be exactly the same + assert new_eqn[0] == expr + # there should not be any replacements + assert len(substitutions) < 1 + + +def test_issue_8891(): + for cls in (MutableDenseMatrix, MutableSparseMatrix, + ImmutableDenseMatrix, ImmutableSparseMatrix): + m = cls(2, 2, [x + y, 0, 0, 0]) + res = cse([x + y, m]) + ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])]) + assert res == ans + assert isinstance(res[1][-1], cls) + + +def test_issue_11230(): + # a specific test that always failed + a, b, f, k, l, i = symbols('a b f k l i') + p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l] + R, C = cse(p) + assert not any(i.is_Mul for a in C for i in a.args) + + # random tests for the issue + from sympy.core.random import choice + from sympy.core.function import expand_mul + s = symbols('a:m') + # 35 Mul tests, none of which should ever fail + ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)] + for p in subsets(ex, 3): + p = list(p) + R, C = cse(p) + assert not any(i.is_Mul for a in C for i in a.args) + for ri in reversed(R): + for i in range(len(C)): + C[i] = C[i].subs(*ri) + assert p == C + # 35 Add tests, none of which should ever fail + ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)] + for p in subsets(ex, 3): + p = list(p) + R, C = cse(p) + assert not any(i.is_Add for a in C for i in a.args) + for ri in reversed(R): + for i in range(len(C)): + C[i] = C[i].subs(*ri) + # use expand_mul to handle cases like this: + # p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g] + # x0 = 2*(b + e) is identified giving a rebuilt p that + # is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]` + assert p == [expand_mul(i) for i in C] + + +@XFAIL +def test_issue_11577(): + def check(eq): + r, c = cse(eq) + assert eq.count_ops() >= \ + len(r) + sum(i[1].count_ops() for i in r) + \ + count_ops(c) + + eq = x**5*y**2 + x**5*y + x**5 + assert cse(eq) == ( + [(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1]) + # ([(x0, x**5*y)], [x0*y + x0 + x**5]) or + # ([(x0, x**5)], [x0*y**2 + x0*y + x0]) + check(eq) + + eq = x**2/(y + 1)**2 + x/(y + 1) + assert cse(eq) == ( + [(x0, y + 1)], [x**2/x0**2 + x/x0]) + # ([(x0, x/(y + 1))], [x0**2 + x0]) + check(eq) + + +def test_hollow_rejection(): + eq = [x + 3, x + 4] + assert cse(eq) == ([], eq) + + +def test_cse_ignore(): + exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))] + subst1, red1 = cse(exprs) + assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y" + + subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions + assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored" + assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression" + + +def test_cse_ignore_issue_15002(): + l = [ + w*exp(x)*exp(-z), + exp(y)*exp(x)*exp(-z) + ] + substs, reduced = cse(l, ignore=(x,)) + rl = [e.subs(reversed(substs)) for e in reduced] + assert rl == l + + +def test_cse_unevaluated(): + xp1 = UnevaluatedExpr(x + 1) + # This used to cause RecursionError + [(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)]) + if ue == xp1: + assert red == (-1 - x0) / (1 - x0) + elif ue == -xp1: + assert red == (-1 + x0) / (1 + x0) + else: + msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}' + assert False, msg + + +def test_cse__performance(): + nexprs, nterms = 3, 20 + x = symbols('x:%d' % nterms) + exprs = [ + reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)]) + for i in range(nexprs) + ] + assert (exprs[0] + exprs[1]).simplify() == 0 + subst, red = cse(exprs) + assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE" + for i, e in enumerate(red): + assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0 + + +def test_issue_12070(): + exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z] + subst, red = cse(exprs) + assert 6 >= (len(subst) + sum(v.count_ops() for k, v in subst) + + count_ops(red)) + + +def test_issue_13000(): + eq = x/(-4*x**2 + y**2) + cse_eq = cse(eq)[1][0] + assert cse_eq == eq + + +def test_issue_18203(): + eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1) + assert cse(eq) == ([], [eq]) + + +def test_unevaluated_mul(): + eq = Mul(x + y, x + y, evaluate=False) + assert cse(eq) == ([(x0, x + y)], [x0**2]) + + +def test_cse_release_variables(): + from sympy.simplify.cse_main import cse_release_variables + _0, _1, _2, _3, _4 = symbols('_:5') + eqs = [(x + y - 1)**2, x, + x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, + (2*x + 1)**(x + y)] + r, e = cse(eqs, postprocess=cse_release_variables) + # this can change in keeping with the intention of the function + assert r, e == ([ + (x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1), + (_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1), + (x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4)) + r.reverse() + r = [(s, v) for s, v in r if v is not None] + assert eqs == [i.subs(r) for i in e] + + +def test_cse_list(): + _cse = lambda x: cse(x, list=False) + assert _cse(x) == ([], x) + assert _cse('x') == ([], 'x') + it = [x] + for c in (list, tuple, set): + assert _cse(c(it)) == ([], c(it)) + #Tuple works different from tuple: + assert _cse(Tuple(*it)) == ([], Tuple(*it)) + d = {x: 1} + assert _cse(d) == ([], d) + +def test_issue_18991(): + A = MatrixSymbol('A', 2, 2) + assert signsimp(-A * A - A) == -A * A - A + + +def test_unevaluated_Mul(): + m = [Mul(1, 2, evaluate=False)] + assert cse(m) == ([], m) + + +def test_cse_matrix_expression_inverse(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = Inverse(A) + cse_expr = cse(x) + assert cse_expr == ([], [Inverse(A)]) + + +def test_cse_matrix_expression_matmul_inverse(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + b = ImmutableDenseMatrix(symbols('b:2')) + x = MatMul(Inverse(A), b) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_negate_matrix(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, A) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_negate_matmul_not_extracted(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, A, B) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +@XFAIL # No simplification rule for nested associative operations +def test_cse_matrix_nested_matmul_collapsed(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, MatMul(A, B)) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)]) + + +def test_cse_matrix_optimize_out_single_argument_mul(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(MatMul(MatMul(A))) + cse_expr = cse(x) + assert cse_expr == ([], [A]) + + +@XFAIL # Multiple simplification passed not supported in CSE +def test_cse_matrix_optimize_out_single_argument_mul_combined(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(4, A)]) + + +def test_cse_matrix_optimize_out_single_argument_add(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatAdd(MatAdd(MatAdd(MatAdd(A)))) + cse_expr = cse(x) + assert cse_expr == ([], [A]) + + +@XFAIL # Multiple simplification passed not supported in CSE +def test_cse_matrix_optimize_out_single_argument_add_combined(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(4, A)]) + + +def test_cse_matrix_expression_matrix_solve(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + b = ImmutableDenseMatrix(symbols('b:2')) + x = MatrixSolve(A, b) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_matrix_expression(): + X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2) + y = ImmutableDenseMatrix(symbols('y:2')) + b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y) + cse_expr = cse(b) + x0 = MatrixSymbol('x0', 2, 2) + reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y) + assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected]) + + +def test_cse_matrix_kalman_filter(): + """Kalman Filter example from Matthew Rocklin's SciPy 2013 talk. + + Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation" + + Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html + + Notes + ===== + + Equations are: + + new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data) + = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))) + new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma + = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma)) + + """ + N = 2 + mu = ImmutableDenseMatrix(symbols(f'mu:{N}')) + Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N) + H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N) + R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N) + data = ImmutableDenseMatrix(symbols(f'data:{N}')) + new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))) + new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma)) + cse_expr = cse([new_mu, new_Sigma]) + x0 = MatrixSymbol('x0', N, N) + x1 = MatrixSymbol('x1', N, N) + replacements_expected = [ + (x0, Transpose(H)), + (x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))), + ] + reduced_exprs_expected = [ + MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))), + MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)), + ] + assert cse_expr == (replacements_expected, reduced_exprs_expected) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_cse_diff.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_cse_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..92b2d3d6bbaafb838a5e75f32a214511a1d39567 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_cse_diff.py @@ -0,0 +1,206 @@ +"""Tests for the ``sympy.simplify._cse_diff.py`` module.""" + +import pytest + +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.numbers import Integer +from sympy.core.function import Function +from sympy.core import Derivative +from sympy.functions.elementary.exponential import exp +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.physics.mechanics import dynamicsymbols +from sympy.simplify._cse_diff import (_forward_jacobian, + _remove_cse_from_derivative, + _forward_jacobian_cse, + _forward_jacobian_norm_in_cse_out) +from sympy.simplify.simplify import simplify +from sympy.matrices import Matrix, eye + +from sympy.testing.pytest import raises +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.simplify.trigsimp import trigsimp + +from sympy import cse + + +w = Symbol('w') +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') + +q1, q2, q3 = dynamicsymbols('q1 q2 q3') + +# Define the custom functions +k = Function('k')(x, y) +f = Function('f')(k, z) + +zero = Integer(0) +one = Integer(1) +two = Integer(2) +neg_one = Integer(-1) + + +@pytest.mark.parametrize( + 'expr, wrt', + [ + ([zero], [x]), + ([one], [x]), + ([two], [x]), + ([neg_one], [x]), + ([x], [x]), + ([y], [x]), + ([x + y], [x]), + ([x*y], [x]), + ([x**2], [x]), + ([x**y], [x]), + ([exp(x)], [x]), + ([sin(x)], [x]), + ([tan(x)], [x]), + ([zero, one, x, y, x*y, x + y], [x, y]), + ([((x/y) + sin(x/y) - exp(y))*((x/y) - exp(y))], [x, y]), + ([w*tan(y*z)/(x - tan(y*z)), w*x*tan(y*z)/(x - tan(y*z))], [w, x, y, z]), + ([q1**2 + q2, q2**2 + q3, q3**2 + q1], [q1, q2, q3]), + ([f + Derivative(f, x) + k + 2*x], [x]) + ] +) + + +def test_forward_jacobian(expr, wrt): + expr = ImmutableDenseMatrix([expr]).T + wrt = ImmutableDenseMatrix([wrt]).T + jacobian = _forward_jacobian(expr, wrt) + zeros = ImmutableDenseMatrix.zeros(*jacobian.shape) + assert simplify(jacobian - expr.jacobian(wrt)) == zeros + + +def test_process_cse(): + x, y, z = symbols('x y z') + f = Function('f') + k = Function('k') + expr = Matrix([f(k(x,y), z) + Derivative(f(k(x,y), z), x) + k(x,y) + 2*x]) + repl, reduced = cse(expr) + p_repl, p_reduced = _remove_cse_from_derivative(repl, reduced) + + x0 = symbols('x0') + x1 = symbols('x1') + + expected_output = ( + [(x0, k(x, y)), (x1, f(x0, z))], + [Matrix([2 * x + x0 + x1 + Derivative(f(k(x, y), z), x)])] + ) + + assert p_repl == expected_output[0], f"Expected {expected_output[0]}, but got {p_repl}" + assert p_reduced == expected_output[1], f"Expected {expected_output[1]}, but got {p_reduced}" + + +def test_io_matrix_type(): + x, y, z = symbols('x y z') + expr = ImmutableDenseMatrix([ + x * y + y * z + x * y * z, + x ** 2 + y ** 2 + z ** 2, + x * y + x * z + y * z + ]) + wrt = ImmutableDenseMatrix([x, y, z]) + + replacements, reduced_expr = cse(expr) + + # Test _forward_jacobian_core + replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt) + assert isinstance(jacobian_core[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input" + + # Test _forward_jacobian_norm_in_dag_out + replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out( + expr, wrt) + assert isinstance(jacobian_norm[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input" + + # Test _forward_jacobian + jacobian = _forward_jacobian(expr, wrt) + assert isinstance(jacobian, type(expr)), "Jacobian should be a Matrix of the same type as the input" + + +def test_forward_jacobian_input_output(): + x, y, z = symbols('x y z') + expr = Matrix([ + x * y + y * z + x * y * z, + x ** 2 + y ** 2 + z ** 2, + x * y + x * z + y * z + ]) + wrt = Matrix([x, y, z]) + + replacements, reduced_expr = cse(expr) + + # Test _forward_jacobian_core + replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt) + assert isinstance(replacements_core, type(replacements)), "Replacements should be a list" + assert isinstance(jacobian_core, type(reduced_expr)), "Jacobian should be a list" + assert isinstance(precomputed_fs_core, list), "Precomputed free symbols should be a list" + assert len(replacements_core) == len(replacements), "Length of replacements does not match" + assert len(jacobian_core) == 1, "Jacobian should have one element" + assert len(precomputed_fs_core) == len(replacements), "Length of precomputed free symbols does not match" + + # Test _forward_jacobian_norm_in_dag_out + replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out(expr, wrt) + assert isinstance(replacements_norm, type(replacements)), "Replacements should be a list" + assert isinstance(jacobian_norm, type(reduced_expr)), "Jacobian should be a list" + assert isinstance(precomputed_fs_norm, list), "Precomputed free symbols should be a list" + assert len(replacements_norm) == len(replacements), "Length of replacements does not match" + assert len(jacobian_norm) == 1, "Jacobian should have one element" + assert len(precomputed_fs_norm) == len(replacements), "Length of precomputed free symbols does not match" + + +def test_jacobian_hessian(): + L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert _forward_jacobian(L, syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = Matrix(1, 2, [x, x**2*y**3]) + assert _forward_jacobian(L, syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + +def test_jacobian_metrics(): + rho, phi = symbols("rho,phi") + X = Matrix([rho * cos(phi), rho * sin(phi)]) + Y = Matrix([rho, phi]) + J = _forward_jacobian(X, Y) + assert J == X.jacobian(Y.T) + assert J == (X.T).jacobian(Y) + assert J == (X.T).jacobian(Y.T) + g = J.T * eye(J.shape[0]) * J + g = g.applyfunc(trigsimp) + assert g == Matrix([[1, 0], [0, rho ** 2]]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = Matrix([rho * cos(phi), rho * sin(phi), rho ** 2]) + Y = Matrix([rho, phi]) + J = Matrix([ + [cos(phi), -rho * sin(phi)], + [sin(phi), rho * cos(phi)], + [2 * rho, 0], + ]) + assert _forward_jacobian(X, Y) == J + + +def test_issue_4564(): + X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)]) + Y = Matrix([x, y, z]) + for i in range(1, 3): + for j in range(1, 3): + X_slice = X[:i, :] + Y_slice = Y[:j, :] + J = _forward_jacobian(X_slice, Y_slice) + assert J.rows == i + assert J.cols == j + for k in range(j): + assert J[:, k] == X_slice + + +def test_nonvectorJacobian(): + X = Matrix([[exp(x + y + z), exp(x + y + z)], + [exp(x + y + z), exp(x + y + z)]]) + raises(TypeError, lambda: _forward_jacobian(X, Matrix([x, y, z]))) + X = X[0, :] + Y = Matrix([[x, y], [x, z]]) + raises(TypeError, lambda: _forward_jacobian(X, Y)) + raises(TypeError, lambda: _forward_jacobian(X, Matrix([[x, y], [x, z]]))) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_epathtools.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_epathtools.py new file mode 100644 index 0000000000000000000000000000000000000000..a8bb47b2f2ff624077ab9905677b181c587ab5a7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_epathtools.py @@ -0,0 +1,90 @@ +"""Tests for tools for manipulation of expressions using paths. """ + +from sympy.simplify.epathtools import epath, EPath +from sympy.testing.pytest import raises + +from sympy.core.numbers import E +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.abc import x, y, z, t + + +def test_epath_select(): + expr = [((x, 1, t), 2), ((3, y, 4), z)] + + assert epath("/*", expr) == [((x, 1, t), 2), ((3, y, 4), z)] + assert epath("/*/*", expr) == [(x, 1, t), 2, (3, y, 4), z] + assert epath("/*/*/*", expr) == [x, 1, t, 3, y, 4] + assert epath("/*/*/*/*", expr) == [] + + assert epath("/[:]", expr) == [((x, 1, t), 2), ((3, y, 4), z)] + assert epath("/[:]/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z] + assert epath("/[:]/[:]/[:]", expr) == [x, 1, t, 3, y, 4] + assert epath("/[:]/[:]/[:]/[:]", expr) == [] + + assert epath("/*/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/[0]", expr) == [(x, 1, t), (3, y, 4)] + assert epath("/*/[1]", expr) == [2, z] + assert epath("/*/[2]", expr) == [] + + assert epath("/*/int", expr) == [2] + assert epath("/*/Symbol", expr) == [z] + assert epath("/*/tuple", expr) == [(x, 1, t), (3, y, 4)] + assert epath("/*/__iter__?", expr) == [(x, 1, t), (3, y, 4)] + + assert epath("/*/int|tuple", expr) == [(x, 1, t), 2, (3, y, 4)] + assert epath("/*/Symbol|tuple", expr) == [(x, 1, t), (3, y, 4), z] + assert epath("/*/int|Symbol|tuple", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/int|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4)] + assert epath("/*/Symbol|__iter__?", expr) == [(x, 1, t), (3, y, 4), z] + assert epath( + "/*/int|Symbol|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/[0]/int", expr) == [1, 3, 4] + assert epath("/*/[0]/Symbol", expr) == [x, t, y] + + assert epath("/*/[0]/int[1:]", expr) == [1, 4] + assert epath("/*/[0]/Symbol[1:]", expr) == [t, y] + + assert epath("/Symbol", x + y + z + 1) == [x, y, z] + assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E)) == [x, x, y] + + +def test_epath_apply(): + expr = [((x, 1, t), 2), ((3, y, 4), z)] + func = lambda expr: expr**2 + + assert epath("/*", expr, list) == [[(x, 1, t), 2], [(3, y, 4), z]] + + assert epath("/*/[0]", expr, list) == [([x, 1, t], 2), ([3, y, 4], z)] + assert epath("/*/[1]", expr, func) == [((x, 1, t), 4), ((3, y, 4), z**2)] + assert epath("/*/[2]", expr, list) == expr + + assert epath("/*/[0]/int", expr, func) == [((x, 1, t), 2), ((9, y, 16), z)] + assert epath("/*/[0]/Symbol", expr, func) == [((x**2, 1, t**2), 2), + ((3, y**2, 4), z)] + assert epath( + "/*/[0]/int[1:]", expr, func) == [((x, 1, t), 2), ((3, y, 16), z)] + assert epath("/*/[0]/Symbol[1:]", expr, func) == [((x, 1, t**2), + 2), ((3, y**2, 4), z)] + + assert epath("/Symbol", x + y + z + 1, func) == x**2 + y**2 + z**2 + 1 + assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E), func) == \ + t + sin(x**2 + 1) + cos(x**2 + y**2 + E) + + +def test_EPath(): + assert EPath("/*/[0]")._path == "/*/[0]" + assert EPath(EPath("/*/[0]"))._path == "/*/[0]" + assert isinstance(epath("/*/[0]"), EPath) is True + + assert repr(EPath("/*/[0]")) == "EPath('/*/[0]')" + + raises(ValueError, lambda: EPath("")) + raises(ValueError, lambda: EPath("/")) + raises(ValueError, lambda: EPath("/|x")) + raises(ValueError, lambda: EPath("/[")) + raises(ValueError, lambda: EPath("/[0]%")) + + raises(NotImplementedError, lambda: EPath("Symbol")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_fu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_fu.py new file mode 100644 index 0000000000000000000000000000000000000000..2de2126b7333195fceeffe72dc9cb642e7eba9a9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_fu.py @@ -0,0 +1,492 @@ +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.parameters import evaluate +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.hyperbolic import (cosh, coth, csch, sech, sinh, tanh) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan) +from sympy.simplify.powsimp import powsimp +from sympy.simplify.fu import ( + L, TR1, TR10, TR10i, TR11, _TR11, TR12, TR12i, TR13, TR14, TR15, TR16, + TR111, TR2, TR2i, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T, + TRpower, hyper_as_trig, fu, process_common_addends, trig_split, + as_f_sign_1) +from sympy.core.random import verify_numerically +from sympy.abc import a, b, c, x, y, z + + +def test_TR1(): + assert TR1(2*csc(x) + sec(x)) == 1/cos(x) + 2/sin(x) + + +def test_TR2(): + assert TR2(tan(x)) == sin(x)/cos(x) + assert TR2(cot(x)) == cos(x)/sin(x) + assert TR2(tan(tan(x) - sin(x)/cos(x))) == 0 + + +def test_TR2i(): + # just a reminder that ratios of powers only simplify if both + # numerator and denominator satisfy the condition that each + # has a positive base or an integer exponent; e.g. the following, + # at y=-1, x=1/2 gives sqrt(2)*I != -sqrt(2)*I + assert powsimp(2**x/y**x) != (2/y)**x + + assert TR2i(sin(x)/cos(x)) == tan(x) + assert TR2i(sin(x)*sin(y)/cos(x)) == tan(x)*sin(y) + assert TR2i(1/(sin(x)/cos(x))) == 1/tan(x) + assert TR2i(1/(sin(x)*sin(y)/cos(x))) == 1/tan(x)/sin(y) + assert TR2i(sin(x)/2/(cos(x) + 1)) == sin(x)/(cos(x) + 1)/2 + + assert TR2i(sin(x)/2/(cos(x) + 1), half=True) == tan(x/2)/2 + assert TR2i(sin(1)/(cos(1) + 1), half=True) == tan(S.Half) + assert TR2i(sin(2)/(cos(2) + 1), half=True) == tan(1) + assert TR2i(sin(4)/(cos(4) + 1), half=True) == tan(2) + assert TR2i(sin(5)/(cos(5) + 1), half=True) == tan(5*S.Half) + assert TR2i((cos(1) + 1)/sin(1), half=True) == 1/tan(S.Half) + assert TR2i((cos(2) + 1)/sin(2), half=True) == 1/tan(1) + assert TR2i((cos(4) + 1)/sin(4), half=True) == 1/tan(2) + assert TR2i((cos(5) + 1)/sin(5), half=True) == 1/tan(5*S.Half) + assert TR2i((cos(1) + 1)**(-a)*sin(1)**a, half=True) == tan(S.Half)**a + assert TR2i((cos(2) + 1)**(-a)*sin(2)**a, half=True) == tan(1)**a + assert TR2i((cos(4) + 1)**(-a)*sin(4)**a, half=True) == (cos(4) + 1)**(-a)*sin(4)**a + assert TR2i((cos(5) + 1)**(-a)*sin(5)**a, half=True) == (cos(5) + 1)**(-a)*sin(5)**a + assert TR2i((cos(1) + 1)**a*sin(1)**(-a), half=True) == tan(S.Half)**(-a) + assert TR2i((cos(2) + 1)**a*sin(2)**(-a), half=True) == tan(1)**(-a) + assert TR2i((cos(4) + 1)**a*sin(4)**(-a), half=True) == (cos(4) + 1)**a*sin(4)**(-a) + assert TR2i((cos(5) + 1)**a*sin(5)**(-a), half=True) == (cos(5) + 1)**a*sin(5)**(-a) + + i = symbols('i', integer=True) + assert TR2i(((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**(-i) + assert TR2i(1/((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**i + + +def test_TR3(): + assert TR3(cos(y - x*(y - x))) == cos(x*(x - y) + y) + assert cos(pi/2 + x) == -sin(x) + assert cos(30*pi/2 + x) == -cos(x) + + for f in (cos, sin, tan, cot, csc, sec): + i = f(pi*Rational(3, 7)) + j = TR3(i) + assert verify_numerically(i, j) and i.func != j.func + + with evaluate(False): + eq = cos(9*pi/22) + assert eq.has(9*pi) and TR3(eq) == sin(pi/11) + + +def test_TR4(): + for i in [0, pi/6, pi/4, pi/3, pi/2]: + with evaluate(False): + eq = cos(i) + assert isinstance(eq, cos) and TR4(eq) == cos(i) + + +def test__TR56(): + h = lambda x: 1 - x + assert T(sin(x)**3, sin, cos, h, 4, False) == sin(x)*(-cos(x)**2 + 1) + assert T(sin(x)**10, sin, cos, h, 4, False) == sin(x)**10 + assert T(sin(x)**6, sin, cos, h, 6, False) == (-cos(x)**2 + 1)**3 + assert T(sin(x)**6, sin, cos, h, 6, True) == sin(x)**6 + assert T(sin(x)**8, sin, cos, h, 10, True) == (-cos(x)**2 + 1)**4 + + # issue 17137 + assert T(sin(x)**I, sin, cos, h, 4, True) == sin(x)**I + assert T(sin(x)**(2*I + 1), sin, cos, h, 4, True) == sin(x)**(2*I + 1) + + +def test_TR5(): + assert TR5(sin(x)**2) == -cos(x)**2 + 1 + assert TR5(sin(x)**-2) == sin(x)**(-2) + assert TR5(sin(x)**4) == (-cos(x)**2 + 1)**2 + + +def test_TR6(): + assert TR6(cos(x)**2) == -sin(x)**2 + 1 + assert TR6(cos(x)**-2) == cos(x)**(-2) + assert TR6(cos(x)**4) == (-sin(x)**2 + 1)**2 + + +def test_TR7(): + assert TR7(cos(x)**2) == cos(2*x)/2 + S.Half + assert TR7(cos(x)**2 + 1) == cos(2*x)/2 + Rational(3, 2) + + +def test_TR8(): + assert TR8(cos(2)*cos(3)) == cos(5)/2 + cos(1)/2 + assert TR8(cos(2)*sin(3)) == sin(5)/2 + sin(1)/2 + assert TR8(sin(2)*sin(3)) == -cos(5)/2 + cos(1)/2 + assert TR8(sin(1)*sin(2)*sin(3)) == sin(4)/4 - sin(6)/4 + sin(2)/4 + assert TR8(cos(2)*cos(3)*cos(4)*cos(5)) == \ + cos(4)/4 + cos(10)/8 + cos(2)/8 + cos(8)/8 + cos(14)/8 + \ + cos(6)/8 + Rational(1, 8) + assert TR8(cos(2)*cos(3)*cos(4)*cos(5)*cos(6)) == \ + cos(10)/8 + cos(4)/8 + 3*cos(2)/16 + cos(16)/16 + cos(8)/8 + \ + cos(14)/16 + cos(20)/16 + cos(12)/16 + Rational(1, 16) + cos(6)/8 + assert TR8(sin(pi*Rational(3, 7))**2*cos(pi*Rational(3, 7))**2/(16*sin(pi/7)**2)) == Rational(1, 64) + +def test_TR9(): + a = S.Half + b = 3*a + assert TR9(a) == a + assert TR9(cos(1) + cos(2)) == 2*cos(a)*cos(b) + assert TR9(cos(1) - cos(2)) == 2*sin(a)*sin(b) + assert TR9(sin(1) - sin(2)) == -2*sin(a)*cos(b) + assert TR9(sin(1) + sin(2)) == 2*sin(b)*cos(a) + assert TR9(cos(1) + 2*sin(1) + 2*sin(2)) == cos(1) + 4*sin(b)*cos(a) + assert TR9(cos(4) + cos(2) + 2*cos(1)*cos(3)) == 4*cos(1)*cos(3) + assert TR9((cos(4) + cos(2))/cos(3)/2 + cos(3)) == 2*cos(1)*cos(2) + assert TR9(cos(3) + cos(4) + cos(5) + cos(6)) == \ + 4*cos(S.Half)*cos(1)*cos(Rational(9, 2)) + assert TR9(cos(3) + cos(3)*cos(2)) == cos(3) + cos(2)*cos(3) + assert TR9(-cos(y) + cos(x*y)) == -2*sin(x*y/2 - y/2)*sin(x*y/2 + y/2) + assert TR9(-sin(y) + sin(x*y)) == 2*sin(x*y/2 - y/2)*cos(x*y/2 + y/2) + c = cos(x) + s = sin(x) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for a in ((c, s), (s, c), (cos(x), cos(x*y)), (sin(x), sin(x*y))): + args = zip(si, a) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR9(ex) + assert not (a[0].func == a[1].func and ( + not verify_numerically(ex, t.expand(trig=True)) or t.is_Add) + or a[1].func != a[0].func and ex != t) + + +def test_TR10(): + assert TR10(cos(a + b)) == -sin(a)*sin(b) + cos(a)*cos(b) + assert TR10(sin(a + b)) == sin(a)*cos(b) + sin(b)*cos(a) + assert TR10(sin(a + b + c)) == \ + (-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \ + (sin(a)*cos(b) + sin(b)*cos(a))*cos(c) + assert TR10(cos(a + b + c)) == \ + (-sin(a)*sin(b) + cos(a)*cos(b))*cos(c) - \ + (sin(a)*cos(b) + sin(b)*cos(a))*sin(c) + + +def test_TR10i(): + assert TR10i(cos(1)*cos(3) + sin(1)*sin(3)) == cos(2) + assert TR10i(cos(1)*cos(3) - sin(1)*sin(3)) == cos(4) + assert TR10i(cos(1)*sin(3) - sin(1)*cos(3)) == sin(2) + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3)) == sin(4) + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + 7) == sin(4) + 7 + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) == cos(3) + sin(4) + assert TR10i(2*cos(1)*sin(3) + 2*sin(1)*cos(3) + cos(3)) == \ + 2*sin(4) + cos(3) + assert TR10i(cos(2)*cos(3) + sin(2)*(cos(1)*sin(2) + cos(2)*sin(1))) == \ + cos(1) + eq = (cos(2)*cos(3) + sin(2)*( + cos(1)*sin(2) + cos(2)*sin(1)))*cos(5) + sin(1)*sin(5) + assert TR10i(eq) == TR10i(eq.expand()) == cos(4) + assert TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) == \ + 2*sqrt(2)*x*sin(x + pi/6) + assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) + + cos(x)/sqrt(6)/3 + sin(x)/sqrt(2)/3) == 4*sqrt(6)*sin(x + pi/6)/9 + assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) + + cos(y)/sqrt(6)/3 + sin(y)/sqrt(2)/3) == \ + sqrt(6)*sin(x + pi/6)/3 + sqrt(6)*sin(y + pi/6)/9 + assert TR10i(cos(x) + sqrt(3)*sin(x) + 2*sqrt(3)*cos(x + pi/6)) == 4*cos(x) + assert TR10i(cos(x) + sqrt(3)*sin(x) + + 2*sqrt(3)*cos(x + pi/6) + 4*sin(x)) == 4*sqrt(2)*sin(x + pi/4) + assert TR10i(cos(2)*sin(3) + sin(2)*cos(4)) == \ + sin(2)*cos(4) + sin(3)*cos(2) + + A = Symbol('A', commutative=False) + assert TR10i(sqrt(2)*cos(x)*A + sqrt(6)*sin(x)*A) == \ + 2*sqrt(2)*sin(x + pi/6)*A + + + c = cos(x) + s = sin(x) + h = sin(y) + r = cos(y) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for argsi in ((c*r, s*h), (c*h, s*r)): # explicit 2-args + args = zip(si, argsi) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR10i(ex) + assert not (ex - t.expand(trig=True) or t.is_Add) + + c = cos(x) + s = sin(x) + h = sin(pi/6) + r = cos(pi/6) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for argsi in ((c*r, s*h), (c*h, s*r)): # induced + args = zip(si, argsi) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR10i(ex) + assert not (ex - t.expand(trig=True) or t.is_Add) + + +def test_TR11(): + + assert TR11(sin(2*x)) == 2*sin(x)*cos(x) + assert TR11(sin(4*x)) == 4*((-sin(x)**2 + cos(x)**2)*sin(x)*cos(x)) + assert TR11(sin(x*Rational(4, 3))) == \ + 4*((-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3)) + + assert TR11(cos(2*x)) == -sin(x)**2 + cos(x)**2 + assert TR11(cos(4*x)) == \ + (-sin(x)**2 + cos(x)**2)**2 - 4*sin(x)**2*cos(x)**2 + + assert TR11(cos(2)) == cos(2) + + assert TR11(cos(pi*Rational(3, 7)), pi*Rational(2, 7)) == -cos(pi*Rational(2, 7))**2 + sin(pi*Rational(2, 7))**2 + assert TR11(cos(4), 2) == -sin(2)**2 + cos(2)**2 + assert TR11(cos(6), 2) == cos(6) + assert TR11(sin(x)/cos(x/2), x/2) == 2*sin(x/2) + +def test__TR11(): + + assert _TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) == \ + 4*sin(x/8)*sin(x/6)*sin(2*x),_TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) + assert _TR11(sin(x/3)/cos(x/6)) == 2*sin(x/6) + + assert _TR11(cos(x/6)/sin(x/3)) == 1/(2*sin(x/6)) + assert _TR11(sin(2*x)*cos(x/8)/sin(x/4)) == sin(2*x)/(2*sin(x/8)), _TR11(sin(2*x)*cos(x/8)/sin(x/4)) + assert _TR11(sin(x)/sin(x/2)) == 2*cos(x/2) + + +def test_TR12(): + assert TR12(tan(x + y)) == (tan(x) + tan(y))/(-tan(x)*tan(y) + 1) + assert TR12(tan(x + y + z)) ==\ + (tan(z) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1))/( + 1 - (tan(x) + tan(y))*tan(z)/(-tan(x)*tan(y) + 1)) + assert TR12(tan(x*y)) == tan(x*y) + + +def test_TR13(): + assert TR13(tan(3)*tan(2)) == -tan(2)/tan(5) - tan(3)/tan(5) + 1 + assert TR13(cot(3)*cot(2)) == 1 + cot(3)*cot(5) + cot(2)*cot(5) + assert TR13(tan(1)*tan(2)*tan(3)) == \ + (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*tan(1) + assert TR13(tan(1)*tan(2)*cot(3)) == \ + (-tan(2)/tan(3) + 1 - tan(1)/tan(3))*cot(3) + + +def test_L(): + assert L(cos(x) + sin(x)) == 2 + + +def test_fu(): + + assert fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) == Rational(3, 2) + assert fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) == 2*sqrt(2)*sin(x + pi/3) + + + eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2 + assert fu(eq) == cos(x)**4 - 2*cos(y)**2 + 2 + + assert fu(S.Half - cos(2*x)/2) == sin(x)**2 + + assert fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) == \ + sqrt(2)*sin(a + b + pi/4) + + assert fu(sqrt(3)*cos(x)/2 + sin(x)/2) == sin(x + pi/3) + + assert fu(1 - sin(2*x)**2/4 - sin(y)**2 - cos(x)**4) == \ + -cos(x)**2 + cos(y)**2 + + assert fu(cos(pi*Rational(4, 9))) == sin(pi/18) + assert fu(cos(pi/9)*cos(pi*Rational(2, 9))*cos(pi*Rational(3, 9))*cos(pi*Rational(4, 9))) == Rational(1, 16) + + assert fu( + tan(pi*Rational(7, 18)) + tan(pi*Rational(5, 18)) - sqrt(3)*tan(pi*Rational(5, 18))*tan(pi*Rational(7, 18))) == \ + -sqrt(3) + + assert fu(tan(1)*tan(2)) == tan(1)*tan(2) + + expr = Mul(*[cos(2**i) for i in range(10)]) + assert fu(expr) == sin(1024)/(1024*sin(1)) + + # issue #18059: + assert fu(cos(x) + sqrt(sin(x)**2)) == cos(x) + sqrt(sin(x)**2) + + assert fu((-14*sin(x)**3 + 35*sin(x) + 6*sqrt(3)*cos(x)**3 + 9*sqrt(3)*cos(x))/((cos(2*x) + 4))) == \ + 7*sin(x) + 3*sqrt(3)*cos(x) + + +def test_objective(): + assert fu(sin(x)/cos(x), measure=lambda x: x.count_ops()) == \ + tan(x) + assert fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) == \ + sin(x)/cos(x) + + +def test_process_common_addends(): + # this tests that the args are not evaluated as they are given to do + # and that key2 works when key1 is False + do = lambda x: Add(*[i**(i%2) for i in x.args]) + assert process_common_addends(Add(*[1, 2, 3, 4], evaluate=False), do, + key2=lambda x: x%2, key1=False) == 1**1 + 3**1 + 2**0 + 4**0 + + +def test_trig_split(): + assert trig_split(cos(x), cos(y)) == (1, 1, 1, x, y, True) + assert trig_split(2*cos(x), -2*cos(y)) == (2, 1, -1, x, y, True) + assert trig_split(cos(x)*sin(y), cos(y)*sin(y)) == \ + (sin(y), 1, 1, x, y, True) + + assert trig_split(cos(x), -sqrt(3)*sin(x), two=True) == \ + (2, 1, -1, x, pi/6, False) + assert trig_split(cos(x), sin(x), two=True) == \ + (sqrt(2), 1, 1, x, pi/4, False) + assert trig_split(cos(x), -sin(x), two=True) == \ + (sqrt(2), 1, -1, x, pi/4, False) + assert trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) == \ + (2*sqrt(2), 1, -1, x, pi/6, False) + assert trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) == \ + (-2*sqrt(2), 1, 1, x, pi/3, False) + assert trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) == \ + (sqrt(6)/3, 1, 1, x, pi/6, False) + assert trig_split(-sqrt(6)*cos(x)*sin(y), + -sqrt(2)*sin(x)*sin(y), two=True) == \ + (-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False) + + assert trig_split(cos(x), sin(x)) is None + assert trig_split(cos(x), sin(z)) is None + assert trig_split(2*cos(x), -sin(x)) is None + assert trig_split(cos(x), -sqrt(3)*sin(x)) is None + assert trig_split(cos(x)*cos(y), sin(x)*sin(z)) is None + assert trig_split(cos(x)*cos(y), sin(x)*sin(y)) is None + assert trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) is \ + None + + assert trig_split(sqrt(3)*sqrt(x), cos(3), two=True) is None + assert trig_split(sqrt(3)*root(x, 3), sin(3)*cos(2), two=True) is None + assert trig_split(cos(5)*cos(6), cos(7)*sin(5), two=True) is None + + +def test_TRmorrie(): + assert TRmorrie(7*Mul(*[cos(i) for i in range(10)])) == \ + 7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3)) + assert TRmorrie(x) == x + assert TRmorrie(2*x) == 2*x + e = cos(pi/7)*cos(pi*Rational(2, 7))*cos(pi*Rational(4, 7)) + assert TR8(TRmorrie(e)) == Rational(-1, 8) + e = Mul(*[cos(2**i*pi/17) for i in range(1, 17)]) + assert TR8(TR3(TRmorrie(e))) == Rational(1, 65536) + # issue 17063 + eq = cos(x)/cos(x/2) + assert TRmorrie(eq) == eq + # issue #20430 + eq = cos(x/2)*sin(x/2)*cos(x)**3 + assert TRmorrie(eq) == sin(2*x)*cos(x)**2/4 + + +def test_TRpower(): + assert TRpower(1/sin(x)**2) == 1/sin(x)**2 + assert TRpower(cos(x)**3*sin(x/2)**4) == \ + (3*cos(x)/4 + cos(3*x)/4)*(-cos(x)/2 + cos(2*x)/8 + Rational(3, 8)) + for k in range(2, 8): + assert verify_numerically(sin(x)**k, TRpower(sin(x)**k)) + assert verify_numerically(cos(x)**k, TRpower(cos(x)**k)) + + +def test_hyper_as_trig(): + from sympy.simplify.fu import _osborne, _osbornei + + eq = sinh(x)**2 + cosh(x)**2 + t, f = hyper_as_trig(eq) + assert f(fu(t)) == cosh(2*x) + e, f = hyper_as_trig(tanh(x + y)) + assert f(TR12(e)) == (tanh(x) + tanh(y))/(tanh(x)*tanh(y) + 1) + + d = Dummy() + assert _osborne(sinh(x), d) == I*sin(x*d) + assert _osborne(tanh(x), d) == I*tan(x*d) + assert _osborne(coth(x), d) == cot(x*d)/I + assert _osborne(cosh(x), d) == cos(x*d) + assert _osborne(sech(x), d) == sec(x*d) + assert _osborne(csch(x), d) == csc(x*d)/I + for func in (sinh, cosh, tanh, coth, sech, csch): + h = func(pi) + assert _osbornei(_osborne(h, d), d) == h + # /!\ the _osborne functions are not meant to work + # in the o(i(trig, d), d) direction so we just check + # that they work as they are supposed to work + assert _osbornei(cos(x*y + z), y) == cosh(x + z*I) + assert _osbornei(sin(x*y + z), y) == sinh(x + z*I)/I + assert _osbornei(tan(x*y + z), y) == tanh(x + z*I)/I + assert _osbornei(cot(x*y + z), y) == coth(x + z*I)*I + assert _osbornei(sec(x*y + z), y) == sech(x + z*I) + assert _osbornei(csc(x*y + z), y) == csch(x + z*I)*I + + +def test_TR12i(): + ta, tb, tc = [tan(i) for i in (a, b, c)] + assert TR12i((ta + tb)/(-ta*tb + 1)) == tan(a + b) + assert TR12i((ta + tb)/(ta*tb - 1)) == -tan(a + b) + assert TR12i((-ta - tb)/(ta*tb - 1)) == tan(a + b) + eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1)) + assert TR12i(eq.expand()) == \ + -3*tan(a + b)*tan(a + c)/(tan(a) + tan(b) - 1)/2 + assert TR12i(tan(x)/sin(x)) == tan(x)/sin(x) + eq = (ta + cos(2))/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = (ta + tb + 2)**2/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = ta/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = (((ta + tb)*(a + 1)).expand())**2/(ta*tb - 1) + assert TR12i(eq) == -(a + 1)**2*tan(a + b) + + +def test_TR14(): + eq = (cos(x) - 1)*(cos(x) + 1) + ans = -sin(x)**2 + assert TR14(eq) == ans + assert TR14(1/eq) == 1/ans + assert TR14((cos(x) - 1)**2*(cos(x) + 1)**2) == ans**2 + assert TR14((cos(x) - 1)**2*(cos(x) + 1)**3) == ans**2*(cos(x) + 1) + assert TR14((cos(x) - 1)**3*(cos(x) + 1)**2) == ans**2*(cos(x) - 1) + eq = (cos(x) - 1)**y*(cos(x) + 1)**y + assert TR14(eq) == eq + eq = (cos(x) - 2)**y*(cos(x) + 1) + assert TR14(eq) == eq + eq = (tan(x) - 2)**2*(cos(x) + 1) + assert TR14(eq) == eq + i = symbols('i', integer=True) + assert TR14((cos(x) - 1)**i*(cos(x) + 1)**i) == ans**i + assert TR14((sin(x) - 1)**i*(sin(x) + 1)**i) == (-cos(x)**2)**i + # could use extraction in this case + eq = (cos(x) - 1)**(i + 1)*(cos(x) + 1)**i + assert TR14(eq) in [(cos(x) - 1)*ans**i, eq] + + assert TR14((sin(x) - 1)*(sin(x) + 1)) == -cos(x)**2 + p1 = (cos(x) + 1)*(cos(x) - 1) + p2 = (cos(y) - 1)*2*(cos(y) + 1) + p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1)) + assert TR14(p1*p2*p3*(x - 1)) == -18*((x - 1)*sin(x)**2*sin(y)**4) + + +def test_TR15_16_17(): + assert TR15(1 - 1/sin(x)**2) == -cot(x)**2 + assert TR16(1 - 1/cos(x)**2) == -tan(x)**2 + assert TR111(1 - 1/tan(x)**2) == 1 - cot(x)**2 + + +def test_as_f_sign_1(): + assert as_f_sign_1(x + 1) == (1, x, 1) + assert as_f_sign_1(x - 1) == (1, x, -1) + assert as_f_sign_1(-x + 1) == (-1, x, -1) + assert as_f_sign_1(-x - 1) == (-1, x, 1) + assert as_f_sign_1(2*x + 2) == (2, x, 1) + assert as_f_sign_1(x*y - y) == (y, x, -1) + assert as_f_sign_1(-x*y + y) == (-y, x, -1) + + +def test_issue_25590(): + A = Symbol('A', commutative=False) + B = Symbol('B', commutative=False) + + assert TR8(2*cos(x)*sin(x)*B*A) == sin(2*x)*B*A + assert TR13(tan(2)*tan(3)*B*A) == (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*B*A + + # XXX The result may not be optimal than + # sin(2*x)*B*A + cos(x)**2 and may change in the future + assert (2*cos(x)*sin(x)*B*A + cos(x)**2).simplify() == sin(2*x)*B*A + cos(2*x)/2 + S.One/2 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_function.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_function.py new file mode 100644 index 0000000000000000000000000000000000000000..441b9faf1bb3c5e7f2279b2a61066d050e45f773 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_function.py @@ -0,0 +1,54 @@ +""" Unit tests for Hyper_Function""" +from sympy.core import symbols, Dummy, Tuple, S, Rational +from sympy.functions import hyper + +from sympy.simplify.hyperexpand import Hyper_Function + +def test_attrs(): + a, b = symbols('a, b', cls=Dummy) + f = Hyper_Function([2, a], [b]) + assert f.ap == Tuple(2, a) + assert f.bq == Tuple(b) + assert f.args == (Tuple(2, a), Tuple(b)) + assert f.sizes == (2, 1) + +def test_call(): + a, b, x = symbols('a, b, x', cls=Dummy) + f = Hyper_Function([2, a], [b]) + assert f(x) == hyper([2, a], [b], x) + +def test_has(): + a, b, c = symbols('a, b, c', cls=Dummy) + f = Hyper_Function([2, -a], [b]) + assert f.has(a) + assert f.has(Tuple(b)) + assert not f.has(c) + +def test_eq(): + assert Hyper_Function([1], []) == Hyper_Function([1], []) + assert (Hyper_Function([1], []) != Hyper_Function([1], [])) is False + assert Hyper_Function([1], []) != Hyper_Function([2], []) + assert Hyper_Function([1], []) != Hyper_Function([1, 2], []) + assert Hyper_Function([1], []) != Hyper_Function([1], [2]) + +def test_gamma(): + assert Hyper_Function([2, 3], [-1]).gamma == 0 + assert Hyper_Function([-2, -3], [-1]).gamma == 2 + n = Dummy(integer=True) + assert Hyper_Function([-1, n, 1], []).gamma == 1 + assert Hyper_Function([-1, -n, 1], []).gamma == 1 + p = Dummy(integer=True, positive=True) + assert Hyper_Function([-1, p, 1], []).gamma == 1 + assert Hyper_Function([-1, -p, 1], []).gamma == 2 + +def test_suitable_origin(): + assert Hyper_Function((S.Half,), (Rational(3, 2),))._is_suitable_origin() is True + assert Hyper_Function((S.Half,), (S.Half,))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (Rational(-1, 2),))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (0,))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (-1, 1,))._is_suitable_origin() is False + assert Hyper_Function((S.Half, 0), (1,))._is_suitable_origin() is False + assert Hyper_Function((S.Half, 1), + (2, Rational(-2, 3)))._is_suitable_origin() is True + assert Hyper_Function((S.Half, 1), + (2, Rational(-2, 3), Rational(3, 2)))._is_suitable_origin() is True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_gammasimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_gammasimp.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c73093250b279510e3c2274db22818a9adffd8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_gammasimp.py @@ -0,0 +1,127 @@ +from sympy.core.function import Function +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (rf, binomial, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.simplify.gammasimp import gammasimp +from sympy.simplify.powsimp import powsimp +from sympy.simplify.simplify import simplify + +from sympy.abc import x, y, n, k + + +def test_gammasimp(): + R = Rational + + # was part of test_combsimp_gamma() in test_combsimp.py + assert gammasimp(gamma(x)) == gamma(x) + assert gammasimp(gamma(x + 1)/x) == gamma(x) + assert gammasimp(gamma(x)/(x - 1)) == gamma(x - 1) + assert gammasimp(x*gamma(x)) == gamma(x + 1) + assert gammasimp((x + 1)*gamma(x + 1)) == gamma(x + 2) + assert gammasimp(gamma(x + y)*(x + y)) == gamma(x + y + 1) + assert gammasimp(x/gamma(x + 1)) == 1/gamma(x) + assert gammasimp((x + 1)**2/gamma(x + 2)) == (x + 1)/gamma(x + 1) + assert gammasimp(x*gamma(x) + gamma(x + 3)/(x + 2)) == \ + (x + 2)*gamma(x + 1) + + assert gammasimp(gamma(2*x)*x) == gamma(2*x + 1)/2 + assert gammasimp(gamma(2*x)/(x - S.Half)) == 2*gamma(2*x - 1) + + assert gammasimp(gamma(x)*gamma(1 - x)) == pi/sin(pi*x) + assert gammasimp(gamma(x)*gamma(-x)) == -pi/(x*sin(pi*x)) + assert gammasimp(1/gamma(x + 3)/gamma(1 - x)) == \ + sin(pi*x)/(pi*x*(x + 1)*(x + 2)) + + assert gammasimp(factorial(n + 2)) == gamma(n + 3) + assert gammasimp(binomial(n, k)) == \ + gamma(n + 1)/(gamma(k + 1)*gamma(-k + n + 1)) + + assert powsimp(gammasimp( + gamma(x)*gamma(x + S.Half)*gamma(y)/gamma(x + y))) == \ + 2**(-2*x + 1)*sqrt(pi)*gamma(2*x)*gamma(y)/gamma(x + y) + assert gammasimp(1/gamma(x)/gamma(x - Rational(1, 3))/gamma(x + Rational(1, 3))) == \ + 3**(3*x - Rational(3, 2))/(2*pi*gamma(3*x - 1)) + assert simplify( + gamma(S.Half + x/2)*gamma(1 + x/2)/gamma(1 + x)/sqrt(pi)*2**x) == 1 + assert gammasimp(gamma(Rational(-1, 4))*gamma(Rational(-3, 4))) == 16*sqrt(2)*pi/3 + + assert powsimp(gammasimp(gamma(2*x)/gamma(x))) == \ + 2**(2*x - 1)*gamma(x + S.Half)/sqrt(pi) + + # issue 6792 + e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2 + assert gammasimp(e) == -k + assert gammasimp(1/e) == -1/k + e = (gamma(x) + gamma(x + 1))/gamma(x) + assert gammasimp(e) == x + 1 + assert gammasimp(1/e) == 1/(x + 1) + e = (gamma(x) + gamma(x + 2))*(gamma(x - 1) + gamma(x))/gamma(x) + assert gammasimp(e) == (x**2 + x + 1)*gamma(x + 1)/(x - 1) + e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2 + assert gammasimp(e**2) == k**2 + assert gammasimp(e**2/gamma(k + 1)) == k/gamma(k) + a = R(1, 2) + R(1, 3) + b = a + R(1, 3) + assert gammasimp(gamma(2*k)/gamma(k)*gamma(k + a)*gamma(k + b) + ) == 3*2**(2*k + 1)*3**(-3*k - 2)*sqrt(pi)*gamma(3*k + R(3, 2))/2 + + # issue 9699 + assert gammasimp((x + 1)*factorial(x)/gamma(y)) == gamma(x + 2)/gamma(y) + assert gammasimp(rf(x + n, k)*binomial(n, k)).simplify() == Piecewise( + (gamma(n + 1)*gamma(k + n + x)/(gamma(k + 1)*gamma(n + x)*gamma(-k + n + 1)), n > -x), + ((-1)**k*gamma(n + 1)*gamma(-n - x + 1)/(gamma(k + 1)*gamma(-k + n + 1)*gamma(-k - n - x + 1)), True)) + + A, B = symbols('A B', commutative=False) + assert gammasimp(e*B*A) == gammasimp(e)*B*A + + # check iteration + assert gammasimp(gamma(2*k)/gamma(k)*gamma(-k - R(1, 2))) == ( + -2**(2*k + 1)*sqrt(pi)/(2*((2*k + 1)*cos(pi*k)))) + assert gammasimp( + gamma(k)*gamma(k + R(1, 3))*gamma(k + R(2, 3))/gamma(k*R(3, 2))) == ( + 3*2**(3*k + 1)*3**(-3*k - S.Half)*sqrt(pi)*gamma(k*R(3, 2) + S.Half)/2) + + # issue 6153 + assert gammasimp(gamma(Rational(1, 4))/gamma(Rational(5, 4))) == 4 + + # was part of test_combsimp() in test_combsimp.py + assert gammasimp(binomial(n + 2, k + S.Half)) == gamma(n + 3)/ \ + (gamma(k + R(3, 2))*gamma(-k + n + R(5, 2))) + assert gammasimp(binomial(n + 2, k + 2.0)) == \ + gamma(n + 3)/(gamma(k + 3.0)*gamma(-k + n + 1)) + + # issue 11548 + assert gammasimp(binomial(0, x)) == sin(pi*x)/(pi*x) + + e = gamma(n + Rational(1, 3))*gamma(n + R(2, 3)) + assert gammasimp(e) == e + assert gammasimp(gamma(4*n + S.Half)/gamma(2*n - R(3, 4))) == \ + 2**(4*n - R(5, 2))*(8*n - 3)*gamma(2*n + R(3, 4))/sqrt(pi) + + i, m = symbols('i m', integer = True) + e = gamma(exp(i)) + assert gammasimp(e) == e + e = gamma(m + 3) + assert gammasimp(e) == e + e = gamma(m + 1)/(gamma(i + 1)*gamma(-i + m + 1)) + assert gammasimp(e) == e + + p = symbols("p", integer=True, positive=True) + assert gammasimp(gamma(-p + 4)) == gamma(-p + 4) + + +def test_issue_22606(): + fx = Function('f')(x) + eq = x + gamma(y) + # seems like ans should be `eq`, not `(x*y + gamma(y + 1))/y` + ans = gammasimp(eq) + assert gammasimp(eq.subs(x, fx)).subs(fx, x) == ans + assert gammasimp(eq.subs(x, cos(x))).subs(cos(x), x) == ans + assert 1/gammasimp(1/eq) == ans + assert gammasimp(fx.subs(x, eq)).args[0] == ans diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_hyperexpand.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_hyperexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..c703c228a13201de13cfd4c3413fc75a2cf5bdb6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_hyperexpand.py @@ -0,0 +1,1063 @@ +from sympy.core.random import randrange + +from sympy.simplify.hyperexpand import (ShiftA, ShiftB, UnShiftA, UnShiftB, + MeijerShiftA, MeijerShiftB, MeijerShiftC, MeijerShiftD, + MeijerUnShiftA, MeijerUnShiftB, MeijerUnShiftC, + MeijerUnShiftD, + ReduceOrder, reduce_order, apply_operators, + devise_plan, make_derivative_operator, Formula, + hyperexpand, Hyper_Function, G_Function, + reduce_order_meijer, + build_hypergeometric_formula) +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.abc import z, a, b, c +from sympy.testing.pytest import XFAIL, raises, slow, tooslow +from sympy.core.random import verify_numerically as tn + +from sympy.core.numbers import (Rational, pi) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import atanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (asin, cos, sin) +from sympy.functions.special.bessel import besseli +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import (gamma, lowergamma) + + +def test_branch_bug(): + assert hyperexpand(hyper((Rational(-1, 3), S.Half), (Rational(2, 3), Rational(3, 2)), -z)) == \ + -z**S('1/3')*lowergamma(exp_polar(I*pi)/3, z)/5 \ + + sqrt(pi)*erf(sqrt(z))/(5*sqrt(z)) + assert hyperexpand(meijerg([Rational(7, 6), 1], [], [Rational(2, 3)], [Rational(1, 6), 0], z)) == \ + 2*z**S('2/3')*(2*sqrt(pi)*erf(sqrt(z))/sqrt(z) - 2*lowergamma( + Rational(2, 3), z)/z**S('2/3'))*gamma(Rational(2, 3))/gamma(Rational(5, 3)) + + +def test_hyperexpand(): + # Luke, Y. L. (1969), The Special Functions and Their Approximations, + # Volume 1, section 6.2 + + assert hyperexpand(hyper([], [], z)) == exp(z) + assert hyperexpand(hyper([1, 1], [2], -z)*z) == log(1 + z) + assert hyperexpand(hyper([], [S.Half], -z**2/4)) == cos(z) + assert hyperexpand(z*hyper([], [S('3/2')], -z**2/4)) == sin(z) + assert hyperexpand(hyper([S('1/2'), S('1/2')], [S('3/2')], z**2)*z) \ + == asin(z) + assert isinstance(Sum(binomial(2, z)*z**2, (z, 0, a)).doit(), Expr) + + +def can_do(ap, bq, numerical=True, div=1, lowerplane=False): + r = hyperexpand(hyper(ap, bq, z)) + if r.has(hyper): + return False + if not numerical: + return True + repl = {} + randsyms = r.free_symbols - {z} + while randsyms: + # Only randomly generated parameters are checked. + for n, ai in enumerate(randsyms): + repl[ai] = randcplx(n)/div + if not any(b.is_Integer and b <= 0 for b in Tuple(*bq).subs(repl)): + break + [a, b, c, d] = [2, -1, 3, 1] + if lowerplane: + [a, b, c, d] = [2, -2, 3, -1] + return tn( + hyper(ap, bq, z).subs(repl), + r.replace(exp_polar, exp).subs(repl), + z, a=a, b=b, c=c, d=d) + + +def test_roach(): + # Kelly B. Roach. Meijer G Function Representations. + # Section "Gallery" + assert can_do([S.Half], [Rational(9, 2)]) + assert can_do([], [1, Rational(5, 2), 4]) + assert can_do([Rational(-1, 2), 1, 2], [3, 4]) + assert can_do([Rational(1, 3)], [Rational(-2, 3), Rational(-1, 2), S.Half, 1]) + assert can_do([Rational(-3, 2), Rational(-1, 2)], [Rational(-5, 2), 1]) + assert can_do([Rational(-3, 2), ], [Rational(-1, 2), S.Half]) # shine-integral + assert can_do([Rational(-3, 2), Rational(-1, 2)], [2]) # elliptic integrals + + +@XFAIL +def test_roach_fail(): + assert can_do([Rational(-1, 2), 1], [Rational(1, 4), S.Half, Rational(3, 4)]) # PFDD + assert can_do([Rational(3, 2)], [Rational(5, 2), 5]) # struve function + assert can_do([Rational(-1, 2), S.Half, 1], [Rational(3, 2), Rational(5, 2)]) # polylog, pfdd + assert can_do([1, 2, 3], [S.Half, 4]) # XXX ? + assert can_do([S.Half], [Rational(-1, 3), Rational(-1, 2), Rational(-2, 3)]) # PFDD ? + +# For the long table tests, see end of file + + +def test_polynomial(): + from sympy.core.numbers import oo + assert hyperexpand(hyper([], [-1], z)) is oo + assert hyperexpand(hyper([-2], [-1], z)) is oo + assert hyperexpand(hyper([0, 0], [-1], z)) == 1 + assert can_do([-5, -2, randcplx(), randcplx()], [-10, randcplx()]) + assert hyperexpand(hyper((-1, 1), (-2,), z)) == 1 + z/2 + + +def test_hyperexpand_bases(): + assert hyperexpand(hyper([2], [a], z)) == \ + a + z**(-a + 1)*(-a**2 + 3*a + z*(a - 1) - 2)*exp(z)* \ + lowergamma(a - 1, z) - 1 + # TODO [a+1, aRational(-1, 2)], [2*a] + assert hyperexpand(hyper([1, 2], [3], z)) == -2/z - 2*log(-z + 1)/z**2 + assert hyperexpand(hyper([S.Half, 2], [Rational(3, 2)], z)) == \ + -1/(2*z - 2) + atanh(sqrt(z))/sqrt(z)/2 + assert hyperexpand(hyper([S.Half, S.Half], [Rational(5, 2)], z)) == \ + (-3*z + 3)/4/(z*sqrt(-z + 1)) \ + + (6*z - 3)*asin(sqrt(z))/(4*z**Rational(3, 2)) + assert hyperexpand(hyper([1, 2], [Rational(3, 2)], z)) == -1/(2*z - 2) \ + - asin(sqrt(z))/(sqrt(z)*(2*z - 2)*sqrt(-z + 1)) + assert hyperexpand(hyper([Rational(-1, 2) - 1, 1, 2], [S.Half, 3], z)) == \ + sqrt(z)*(z*Rational(6, 7) - Rational(6, 5))*atanh(sqrt(z)) \ + + (-30*z**2 + 32*z - 6)/35/z - 6*log(-z + 1)/(35*z**2) + assert hyperexpand(hyper([1 + S.Half, 1, 1], [2, 2], z)) == \ + -4*log(sqrt(-z + 1)/2 + S.Half)/z + # TODO hyperexpand(hyper([a], [2*a + 1], z)) + # TODO [S.Half, a], [Rational(3, 2), a+1] + assert hyperexpand(hyper([2], [b, 1], z)) == \ + z**(-b/2 + S.Half)*besseli(b - 1, 2*sqrt(z))*gamma(b) \ + + z**(-b/2 + 1)*besseli(b, 2*sqrt(z))*gamma(b) + # TODO [a], [a - S.Half, 2*a] + + +def test_hyperexpand_parametric(): + assert hyperexpand(hyper([a, S.Half + a], [S.Half], z)) \ + == (1 + sqrt(z))**(-2*a)/2 + (1 - sqrt(z))**(-2*a)/2 + assert hyperexpand(hyper([a, Rational(-1, 2) + a], [2*a], z)) \ + == 2**(2*a - 1)*((-z + 1)**S.Half + 1)**(-2*a + 1) + + +def test_shifted_sum(): + from sympy.simplify.simplify import simplify + assert simplify(hyperexpand(z**4*hyper([2], [3, S('3/2')], -z**2))) \ + == z*sin(2*z) + (-z**2 + S.Half)*cos(2*z) - S.Half + + +def _randrat(): + """ Steer clear of integers. """ + return S(randrange(25) + 10)/50 + + +def randcplx(offset=-1): + """ Polys is not good with real coefficients. """ + return _randrat() + I*_randrat() + I*(1 + offset) + + +@slow +def test_formulae(): + from sympy.simplify.hyperexpand import FormulaCollection + formulae = FormulaCollection().formulae + for formula in formulae: + h = formula.func(formula.z) + rep = {} + for n, sym in enumerate(formula.symbols): + rep[sym] = randcplx(n) + + # NOTE hyperexpand returns truly branched functions. We know we are + # on the main sheet, but numerical evaluation can still go wrong + # (e.g. if exp_polar cannot be evalf'd). + # Just replace all exp_polar by exp, this usually works. + + # first test if the closed-form is actually correct + h = h.subs(rep) + closed_form = formula.closed_form.subs(rep).rewrite('nonrepsmall') + z = formula.z + assert tn(h, closed_form.replace(exp_polar, exp), z) + + # now test the computed matrix + cl = (formula.C * formula.B)[0].subs(rep).rewrite('nonrepsmall') + assert tn(closed_form.replace( + exp_polar, exp), cl.replace(exp_polar, exp), z) + deriv1 = z*formula.B.applyfunc(lambda t: t.rewrite( + 'nonrepsmall')).diff(z) + deriv2 = formula.M * formula.B + for d1, d2 in zip(deriv1, deriv2): + assert tn(d1.subs(rep).replace(exp_polar, exp), + d2.subs(rep).rewrite('nonrepsmall').replace(exp_polar, exp), z) + + +def test_meijerg_formulae(): + from sympy.simplify.hyperexpand import MeijerFormulaCollection + formulae = MeijerFormulaCollection().formulae + for sig in formulae: + for formula in formulae[sig]: + g = meijerg(formula.func.an, formula.func.ap, + formula.func.bm, formula.func.bq, + formula.z) + rep = {} + for sym in formula.symbols: + rep[sym] = randcplx() + + # first test if the closed-form is actually correct + g = g.subs(rep) + closed_form = formula.closed_form.subs(rep) + z = formula.z + assert tn(g, closed_form, z) + + # now test the computed matrix + cl = (formula.C * formula.B)[0].subs(rep) + assert tn(closed_form, cl, z) + deriv1 = z*formula.B.diff(z) + deriv2 = formula.M * formula.B + for d1, d2 in zip(deriv1, deriv2): + assert tn(d1.subs(rep), d2.subs(rep), z) + + +def op(f): + return z*f.diff(z) + + +def test_plan(): + assert devise_plan(Hyper_Function([0], ()), + Hyper_Function([0], ()), z) == [] + with raises(ValueError): + devise_plan(Hyper_Function([1], ()), Hyper_Function((), ()), z) + with raises(ValueError): + devise_plan(Hyper_Function([2], [1]), Hyper_Function([2], [2]), z) + with raises(ValueError): + devise_plan(Hyper_Function([2], []), Hyper_Function([S("1/2")], []), z) + + # We cannot use pi/(10000 + n) because polys is insanely slow. + a1, a2, b1 = (randcplx(n) for n in range(3)) + b1 += 2*I + h = hyper([a1, a2], [b1], z) + + h2 = hyper((a1 + 1, a2), [b1], z) + assert tn(apply_operators(h, + devise_plan(Hyper_Function((a1 + 1, a2), [b1]), + Hyper_Function((a1, a2), [b1]), z), op), + h2, z) + + h2 = hyper((a1 + 1, a2 - 1), [b1], z) + assert tn(apply_operators(h, + devise_plan(Hyper_Function((a1 + 1, a2 - 1), [b1]), + Hyper_Function((a1, a2), [b1]), z), op), + h2, z) + + +def test_plan_derivatives(): + a1, a2, a3 = 1, 2, S('1/2') + b1, b2 = 3, S('5/2') + h = Hyper_Function((a1, a2, a3), (b1, b2)) + h2 = Hyper_Function((a1 + 1, a2 + 1, a3 + 2), (b1 + 1, b2 + 1)) + ops = devise_plan(h2, h, z) + f = Formula(h, z, h(z), []) + deriv = make_derivative_operator(f.M, z) + assert tn((apply_operators(f.C, ops, deriv)*f.B)[0], h2(z), z) + + h2 = Hyper_Function((a1, a2 - 1, a3 - 2), (b1 - 1, b2 - 1)) + ops = devise_plan(h2, h, z) + assert tn((apply_operators(f.C, ops, deriv)*f.B)[0], h2(z), z) + + +def test_reduction_operators(): + a1, a2, b1 = (randcplx(n) for n in range(3)) + h = hyper([a1], [b1], z) + + assert ReduceOrder(2, 0) is None + assert ReduceOrder(2, -1) is None + assert ReduceOrder(1, S('1/2')) is None + + h2 = hyper((a1, a2), (b1, a2), z) + assert tn(ReduceOrder(a2, a2).apply(h, op), h2, z) + + h2 = hyper((a1, a2 + 1), (b1, a2), z) + assert tn(ReduceOrder(a2 + 1, a2).apply(h, op), h2, z) + + h2 = hyper((a2 + 4, a1), (b1, a2), z) + assert tn(ReduceOrder(a2 + 4, a2).apply(h, op), h2, z) + + # test several step order reduction + ap = (a2 + 4, a1, b1 + 1) + bq = (a2, b1, b1) + func, ops = reduce_order(Hyper_Function(ap, bq)) + assert func.ap == (a1,) + assert func.bq == (b1,) + assert tn(apply_operators(h, ops, op), hyper(ap, bq, z), z) + + +def test_shift_operators(): + a1, a2, b1, b2, b3 = (randcplx(n) for n in range(5)) + h = hyper((a1, a2), (b1, b2, b3), z) + + raises(ValueError, lambda: ShiftA(0)) + raises(ValueError, lambda: ShiftB(1)) + + assert tn(ShiftA(a1).apply(h, op), hyper((a1 + 1, a2), (b1, b2, b3), z), z) + assert tn(ShiftA(a2).apply(h, op), hyper((a1, a2 + 1), (b1, b2, b3), z), z) + assert tn(ShiftB(b1).apply(h, op), hyper((a1, a2), (b1 - 1, b2, b3), z), z) + assert tn(ShiftB(b2).apply(h, op), hyper((a1, a2), (b1, b2 - 1, b3), z), z) + assert tn(ShiftB(b3).apply(h, op), hyper((a1, a2), (b1, b2, b3 - 1), z), z) + + +def test_ushift_operators(): + a1, a2, b1, b2, b3 = (randcplx(n) for n in range(5)) + h = hyper((a1, a2), (b1, b2, b3), z) + + raises(ValueError, lambda: UnShiftA((1,), (), 0, z)) + raises(ValueError, lambda: UnShiftB((), (-1,), 0, z)) + raises(ValueError, lambda: UnShiftA((1,), (0, -1, 1), 0, z)) + raises(ValueError, lambda: UnShiftB((0, 1), (1,), 0, z)) + + s = UnShiftA((a1, a2), (b1, b2, b3), 0, z) + assert tn(s.apply(h, op), hyper((a1 - 1, a2), (b1, b2, b3), z), z) + s = UnShiftA((a1, a2), (b1, b2, b3), 1, z) + assert tn(s.apply(h, op), hyper((a1, a2 - 1), (b1, b2, b3), z), z) + + s = UnShiftB((a1, a2), (b1, b2, b3), 0, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1 + 1, b2, b3), z), z) + s = UnShiftB((a1, a2), (b1, b2, b3), 1, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1, b2 + 1, b3), z), z) + s = UnShiftB((a1, a2), (b1, b2, b3), 2, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1, b2, b3 + 1), z), z) + + +def can_do_meijer(a1, a2, b1, b2, numeric=True): + """ + This helper function tries to hyperexpand() the meijer g-function + corresponding to the parameters a1, a2, b1, b2. + It returns False if this expansion still contains g-functions. + If numeric is True, it also tests the so-obtained formula numerically + (at random values) and returns False if the test fails. + Else it returns True. + """ + from sympy.core.function import expand + from sympy.functions.elementary.complexes import unpolarify + r = hyperexpand(meijerg(a1, a2, b1, b2, z)) + if r.has(meijerg): + return False + # NOTE hyperexpand() returns a truly branched function, whereas numerical + # evaluation only works on the main branch. Since we are evaluating on + # the main branch, this should not be a problem, but expressions like + # exp_polar(I*pi/2*x)**a are evaluated incorrectly. We thus have to get + # rid of them. The expand heuristically does this... + r = unpolarify(expand(r, force=True, power_base=True, power_exp=False, + mul=False, log=False, multinomial=False, basic=False)) + + if not numeric: + return True + + repl = {} + for n, ai in enumerate(meijerg(a1, a2, b1, b2, z).free_symbols - {z}): + repl[ai] = randcplx(n) + return tn(meijerg(a1, a2, b1, b2, z).subs(repl), r.subs(repl), z) + + +@slow +def test_meijerg_expand(): + from sympy.simplify.gammasimp import gammasimp + from sympy.simplify.simplify import simplify + # from mpmath docs + assert hyperexpand(meijerg([[], []], [[0], []], -z)) == exp(z) + + assert hyperexpand(meijerg([[1, 1], []], [[1], [0]], z)) == \ + log(z + 1) + assert hyperexpand(meijerg([[1, 1], []], [[1], [1]], z)) == \ + z/(z + 1) + assert hyperexpand(meijerg([[], []], [[S.Half], [0]], (z/2)**2)) \ + == sin(z)/sqrt(pi) + assert hyperexpand(meijerg([[], []], [[0], [S.Half]], (z/2)**2)) \ + == cos(z)/sqrt(pi) + assert can_do_meijer([], [a], [a - 1, a - S.Half], []) + assert can_do_meijer([], [], [a/2], [-a/2], False) # branches... + assert can_do_meijer([a], [b], [a], [b, a - 1]) + + # wikipedia + assert hyperexpand(meijerg([1], [], [], [0], z)) == \ + Piecewise((0, abs(z) < 1), (1, abs(1/z) < 1), + (meijerg([1], [], [], [0], z), True)) + assert hyperexpand(meijerg([], [1], [0], [], z)) == \ + Piecewise((1, abs(z) < 1), (0, abs(1/z) < 1), + (meijerg([], [1], [0], [], z), True)) + + # The Special Functions and their Approximations + assert can_do_meijer([], [], [a + b/2], [a, a - b/2, a + S.Half]) + assert can_do_meijer( + [], [], [a], [b], False) # branches only agree for small z + assert can_do_meijer([], [S.Half], [a], [-a]) + assert can_do_meijer([], [], [a, b], []) + assert can_do_meijer([], [], [a, b], []) + assert can_do_meijer([], [], [a, a + S.Half], [b, b + S.Half]) + assert can_do_meijer([], [], [a, -a], [0, S.Half], False) # dito + assert can_do_meijer([], [], [a, a + S.Half, b, b + S.Half], []) + assert can_do_meijer([S.Half], [], [0], [a, -a]) + assert can_do_meijer([S.Half], [], [a], [0, -a], False) # dito + assert can_do_meijer([], [a - S.Half], [a, b], [a - S.Half], False) + assert can_do_meijer([], [a + S.Half], [a + b, a - b, a], [], False) + assert can_do_meijer([a + S.Half], [], [b, 2*a - b, a], [], False) + + # This for example is actually zero. + assert can_do_meijer([], [], [], [a, b]) + + # Testing a bug: + assert hyperexpand(meijerg([0, 2], [], [], [-1, 1], z)) == \ + Piecewise((0, abs(z) < 1), + (z*(1 - 1/z**2)/2, abs(1/z) < 1), + (meijerg([0, 2], [], [], [-1, 1], z), True)) + + # Test that the simplest possible answer is returned: + assert gammasimp(simplify(hyperexpand( + meijerg([1], [1 - a], [-a/2, -a/2 + S.Half], [], 1/z)))) == \ + -2*sqrt(pi)*(sqrt(z + 1) + 1)**a/a + + # Test that hyper is returned + assert hyperexpand(meijerg([1], [], [a], [0, 0], z)) == hyper( + (a,), (a + 1, a + 1), z*exp_polar(I*pi))*z**a*gamma(a)/gamma(a + 1)**2 + + # Test place option + f = meijerg(((0, 1), ()), ((S.Half,), (0,)), z**2) + assert hyperexpand(f) == sqrt(pi)/sqrt(1 + z**(-2)) + assert hyperexpand(f, place=0) == sqrt(pi)*z/sqrt(z**2 + 1) + + +def test_meijerg_lookup(): + from sympy.functions.special.error_functions import (Ci, Si) + from sympy.functions.special.gamma_functions import uppergamma + assert hyperexpand(meijerg([a], [], [b, a], [], z)) == \ + z**b*exp(z)*gamma(-a + b + 1)*uppergamma(a - b, z) + assert hyperexpand(meijerg([0], [], [0, 0], [], z)) == \ + exp(z)*uppergamma(0, z) + assert can_do_meijer([a], [], [b, a + 1], []) + assert can_do_meijer([a], [], [b + 2, a], []) + assert can_do_meijer([a], [], [b - 2, a], []) + + assert hyperexpand(meijerg([a], [], [a, a, a - S.Half], [], z)) == \ + -sqrt(pi)*z**(a - S.Half)*(2*cos(2*sqrt(z))*(Si(2*sqrt(z)) - pi/2) + - 2*sin(2*sqrt(z))*Ci(2*sqrt(z))) == \ + hyperexpand(meijerg([a], [], [a, a - S.Half, a], [], z)) == \ + hyperexpand(meijerg([a], [], [a - S.Half, a, a], [], z)) + assert can_do_meijer([a - 1], [], [a + 2, a - Rational(3, 2), a + 1], []) + + +@XFAIL +def test_meijerg_expand_fail(): + # These basically test hyper([], [1/2 - a, 1/2 + 1, 1/2], z), + # which is *very* messy. But since the meijer g actually yields a + # sum of bessel functions, things can sometimes be simplified a lot and + # are then put into tables... + assert can_do_meijer([], [], [a + S.Half], [a, a - b/2, a + b/2]) + assert can_do_meijer([], [], [0, S.Half], [a, -a]) + assert can_do_meijer([], [], [3*a - S.Half, a, -a - S.Half], [a - S.Half]) + assert can_do_meijer([], [], [0, a - S.Half, -a - S.Half], [S.Half]) + assert can_do_meijer([], [], [a, b + S.Half, b], [2*b - a]) + assert can_do_meijer([], [], [a, b + S.Half, b, 2*b - a]) + assert can_do_meijer([S.Half], [], [-a, a], [0]) + + +@slow +def test_meijerg(): + # carefully set up the parameters. + # NOTE: this used to fail sometimes. I believe it is fixed, but if you + # hit an inexplicable test failure here, please let me know the seed. + a1, a2 = (randcplx(n) - 5*I - n*I for n in range(2)) + b1, b2 = (randcplx(n) + 5*I + n*I for n in range(2)) + b3, b4, b5, a3, a4, a5 = (randcplx() for n in range(6)) + g = meijerg([a1], [a3, a4], [b1], [b3, b4], z) + + assert ReduceOrder.meijer_minus(3, 4) is None + assert ReduceOrder.meijer_plus(4, 3) is None + + g2 = meijerg([a1, a2], [a3, a4], [b1], [b3, b4, a2], z) + assert tn(ReduceOrder.meijer_plus(a2, a2).apply(g, op), g2, z) + + g2 = meijerg([a1, a2], [a3, a4], [b1], [b3, b4, a2 + 1], z) + assert tn(ReduceOrder.meijer_plus(a2, a2 + 1).apply(g, op), g2, z) + + g2 = meijerg([a1, a2 - 1], [a3, a4], [b1], [b3, b4, a2 + 2], z) + assert tn(ReduceOrder.meijer_plus(a2 - 1, a2 + 2).apply(g, op), g2, z) + + g2 = meijerg([a1], [a3, a4, b2 - 1], [b1, b2 + 2], [b3, b4], z) + assert tn(ReduceOrder.meijer_minus( + b2 + 2, b2 - 1).apply(g, op), g2, z, tol=1e-6) + + # test several-step reduction + an = [a1, a2] + bq = [b3, b4, a2 + 1] + ap = [a3, a4, b2 - 1] + bm = [b1, b2 + 1] + niq, ops = reduce_order_meijer(G_Function(an, ap, bm, bq)) + assert niq.an == (a1,) + assert set(niq.ap) == {a3, a4} + assert niq.bm == (b1,) + assert set(niq.bq) == {b3, b4} + assert tn(apply_operators(g, ops, op), meijerg(an, ap, bm, bq, z), z) + + +def test_meijerg_shift_operators(): + # carefully set up the parameters. XXX this still fails sometimes + a1, a2, a3, a4, a5, b1, b2, b3, b4, b5 = (randcplx(n) for n in range(10)) + g = meijerg([a1], [a3, a4], [b1], [b3, b4], z) + + assert tn(MeijerShiftA(b1).apply(g, op), + meijerg([a1], [a3, a4], [b1 + 1], [b3, b4], z), z) + assert tn(MeijerShiftB(a1).apply(g, op), + meijerg([a1 - 1], [a3, a4], [b1], [b3, b4], z), z) + assert tn(MeijerShiftC(b3).apply(g, op), + meijerg([a1], [a3, a4], [b1], [b3 + 1, b4], z), z) + assert tn(MeijerShiftD(a3).apply(g, op), + meijerg([a1], [a3 - 1, a4], [b1], [b3, b4], z), z) + + s = MeijerUnShiftA([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3, a4], [b1 - 1], [b3, b4], z), z) + + s = MeijerUnShiftC([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3, a4], [b1], [b3 - 1, b4], z), z) + + s = MeijerUnShiftB([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1 + 1], [a3, a4], [b1], [b3, b4], z), z) + + s = MeijerUnShiftD([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3 + 1, a4], [b1], [b3, b4], z), z) + + +@slow +def test_meijerg_confluence(): + def t(m, a, b): + from sympy.core.sympify import sympify + a, b = sympify([a, b]) + m_ = m + m = hyperexpand(m) + if not m == Piecewise((a, abs(z) < 1), (b, abs(1/z) < 1), (m_, True)): + return False + if not (m.args[0].args[0] == a and m.args[1].args[0] == b): + return False + z0 = randcplx()/10 + if abs(m.subs(z, z0).n() - a.subs(z, z0).n()).n() > 1e-10: + return False + if abs(m.subs(z, 1/z0).n() - b.subs(z, 1/z0).n()).n() > 1e-10: + return False + return True + + assert t(meijerg([], [1, 1], [0, 0], [], z), -log(z), 0) + assert t(meijerg( + [], [3, 1], [0, 0], [], z), -z**2/4 + z - log(z)/2 - Rational(3, 4), 0) + assert t(meijerg([], [3, 1], [-1, 0], [], z), + z**2/12 - z/2 + log(z)/2 + Rational(1, 4) + 1/(6*z), 0) + assert t(meijerg([], [1, 1, 1, 1], [0, 0, 0, 0], [], z), -log(z)**3/6, 0) + assert t(meijerg([1, 1], [], [], [0, 0], z), 0, -log(1/z)) + assert t(meijerg([1, 1], [2, 2], [1, 1], [0, 0], z), + -z*log(z) + 2*z, -log(1/z) + 2) + assert t(meijerg([S.Half], [1, 1], [0, 0], [Rational(3, 2)], z), log(z)/2 - 1, 0) + + def u(an, ap, bm, bq): + m = meijerg(an, ap, bm, bq, z) + m2 = hyperexpand(m, allow_hyper=True) + if m2.has(meijerg) and not (m2.is_Piecewise and len(m2.args) == 3): + return False + return tn(m, m2, z) + assert u([], [1], [0, 0], []) + assert u([1, 1], [], [], [0]) + assert u([1, 1], [2, 2, 5], [1, 1, 6], [0, 0]) + assert u([1, 1], [2, 2, 5], [1, 1, 6], [0]) + + +def test_meijerg_with_Floats(): + # see issue #10681 + from sympy.polys.domains.realfield import RR + f = meijerg(((3.0, 1), ()), ((Rational(3, 2),), (0,)), z) + a = -2.3632718012073 + g = a*z**Rational(3, 2)*hyper((-0.5, Rational(3, 2)), (Rational(5, 2),), z*exp_polar(I*pi)) + assert RR.almosteq((hyperexpand(f)/g).n(), 1.0, 1e-12) + + +def test_lerchphi(): + from sympy.functions.special.zeta_functions import (lerchphi, polylog) + from sympy.simplify.gammasimp import gammasimp + assert hyperexpand(hyper([1, a], [a + 1], z)/a) == lerchphi(z, 1, a) + assert hyperexpand( + hyper([1, a, a], [a + 1, a + 1], z)/a**2) == lerchphi(z, 2, a) + assert hyperexpand(hyper([1, a, a, a], [a + 1, a + 1, a + 1], z)/a**3) == \ + lerchphi(z, 3, a) + assert hyperexpand(hyper([1] + [a]*10, [a + 1]*10, z)/a**10) == \ + lerchphi(z, 10, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a], [], [0], + [-a], exp_polar(-I*pi)*z))) == lerchphi(z, 1, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a, 1 - a], [], [0], + [-a, -a], exp_polar(-I*pi)*z))) == lerchphi(z, 2, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a, 1 - a, 1 - a], [], [0], + [-a, -a, -a], exp_polar(-I*pi)*z))) == lerchphi(z, 3, a) + + assert hyperexpand(z*hyper([1, 1], [2], z)) == -log(1 + -z) + assert hyperexpand(z*hyper([1, 1, 1], [2, 2], z)) == polylog(2, z) + assert hyperexpand(z*hyper([1, 1, 1, 1], [2, 2, 2], z)) == polylog(3, z) + + assert hyperexpand(hyper([1, a, 1 + S.Half], [a + 1, S.Half], z)) == \ + -2*a/(z - 1) + (-2*a**2 + a)*lerchphi(z, 1, a) + + # Now numerical tests. These make sure reductions etc are carried out + # correctly + + # a rational function (polylog at negative integer order) + assert can_do([2, 2, 2], [1, 1]) + + # NOTE these contain log(1-x) etc ... better make sure we have |z| < 1 + # reduction of order for polylog + assert can_do([1, 1, 1, b + 5], [2, 2, b], div=10) + + # reduction of order for lerchphi + # XXX lerchphi in mpmath is flaky + assert can_do( + [1, a, a, a, b + 5], [a + 1, a + 1, a + 1, b], numerical=False) + + # test a bug + from sympy.functions.elementary.complexes import Abs + assert hyperexpand(hyper([S.Half, S.Half, S.Half, 1], + [Rational(3, 2), Rational(3, 2), Rational(3, 2)], Rational(1, 4))) == \ + Abs(-polylog(3, exp_polar(I*pi)/2) + polylog(3, S.Half)) + + +def test_partial_simp(): + # First test that hypergeometric function formulae work. + a, b, c, d, e = (randcplx() for _ in range(5)) + for func in [Hyper_Function([a, b, c], [d, e]), + Hyper_Function([], [a, b, c, d, e])]: + f = build_hypergeometric_formula(func) + z = f.z + assert f.closed_form == func(z) + deriv1 = f.B.diff(z)*z + deriv2 = f.M*f.B + for func1, func2 in zip(deriv1, deriv2): + assert tn(func1, func2, z) + + # Now test that formulae are partially simplified. + a, b, z = symbols('a b z') + assert hyperexpand(hyper([3, a], [1, b], z)) == \ + (-a*b/2 + a*z/2 + 2*a)*hyper([a + 1], [b], z) \ + + (a*b/2 - 2*a + 1)*hyper([a], [b], z) + assert tn( + hyperexpand(hyper([3, d], [1, e], z)), hyper([3, d], [1, e], z), z) + assert hyperexpand(hyper([3], [1, a, b], z)) == \ + hyper((), (a, b), z) \ + + z*hyper((), (a + 1, b), z)/(2*a) \ + - z*(b - 4)*hyper((), (a + 1, b + 1), z)/(2*a*b) + assert tn( + hyperexpand(hyper([3], [1, d, e], z)), hyper([3], [1, d, e], z), z) + + +def test_hyperexpand_special(): + assert hyperexpand(hyper([a, b], [c], 1)) == \ + gamma(c)*gamma(c - a - b)/gamma(c - a)/gamma(c - b) + assert hyperexpand(hyper([a, b], [1 + a - b], -1)) == \ + gamma(1 + a/2)*gamma(1 + a - b)/gamma(1 + a)/gamma(1 + a/2 - b) + assert hyperexpand(hyper([a, b], [1 + b - a], -1)) == \ + gamma(1 + b/2)*gamma(1 + b - a)/gamma(1 + b)/gamma(1 + b/2 - a) + assert hyperexpand(meijerg([1 - z - a/2], [1 - z + a/2], [b/2], [-b/2], 1)) == \ + gamma(1 - 2*z)*gamma(z + a/2 + b/2)/gamma(1 - z + a/2 - b/2) \ + /gamma(1 - z - a/2 + b/2)/gamma(1 - z + a/2 + b/2) + assert hyperexpand(hyper([a], [b], 0)) == 1 + assert hyper([a], [b], 0) != 0 + + +def test_Mod1_behavior(): + from sympy.core.symbol import Symbol + from sympy.simplify.simplify import simplify + n = Symbol('n', integer=True) + # Note: this should not hang. + assert simplify(hyperexpand(meijerg([1], [], [n + 1], [0], z))) == \ + lowergamma(n + 1, z) + + +@slow +def test_prudnikov_misc(): + assert can_do([1, (3 + I)/2, (3 - I)/2], [Rational(3, 2), 2]) + assert can_do([S.Half, a - 1], [Rational(3, 2), a + 1], lowerplane=True) + assert can_do([], [b + 1]) + assert can_do([a], [a - 1, b + 1]) + + assert can_do([a], [a - S.Half, 2*a]) + assert can_do([a], [a - S.Half, 2*a + 1]) + assert can_do([a], [a - S.Half, 2*a - 1]) + assert can_do([a], [a + S.Half, 2*a]) + assert can_do([a], [a + S.Half, 2*a + 1]) + assert can_do([a], [a + S.Half, 2*a - 1]) + assert can_do([S.Half], [b, 2 - b]) + assert can_do([S.Half], [b, 3 - b]) + assert can_do([1], [2, b]) + + assert can_do([a, a + S.Half], [2*a, b, 2*a - b + 1]) + assert can_do([a, a + S.Half], [S.Half, 2*a, 2*a + S.Half]) + assert can_do([a], [a + 1], lowerplane=True) # lowergamma + + +def test_prudnikov_1(): + # A. P. Prudnikov, Yu. A. Brychkov and O. I. Marichev (1990). + # Integrals and Series: More Special Functions, Vol. 3,. + # Gordon and Breach Science Publisher + + # 7.3.1 + assert can_do([a, -a], [S.Half]) + assert can_do([a, 1 - a], [S.Half]) + assert can_do([a, 1 - a], [Rational(3, 2)]) + assert can_do([a, 2 - a], [S.Half]) + assert can_do([a, 2 - a], [Rational(3, 2)]) + assert can_do([a, 2 - a], [Rational(3, 2)]) + assert can_do([a, a + S.Half], [2*a - 1]) + assert can_do([a, a + S.Half], [2*a]) + assert can_do([a, a + S.Half], [2*a + 1]) + assert can_do([a, a + S.Half], [S.Half]) + assert can_do([a, a + S.Half], [Rational(3, 2)]) + assert can_do([a, a/2 + 1], [a/2]) + assert can_do([1, b], [2]) + assert can_do([1, b], [b + 1], numerical=False) # Lerch Phi + # NOTE: branches are complicated for |z| > 1 + + assert can_do([a], [2*a]) + assert can_do([a], [2*a + 1]) + assert can_do([a], [2*a - 1]) + + +@slow +def test_prudnikov_2(): + h = S.Half + assert can_do([-h, -h], [h]) + assert can_do([-h, h], [3*h]) + assert can_do([-h, h], [5*h]) + assert can_do([-h, h], [7*h]) + assert can_do([-h, 1], [h]) + + for p in [-h, h]: + for n in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: + for m in [-h, h, 3*h, 5*h, 7*h]: + assert can_do([p, n], [m]) + for n in [1, 2, 3, 4]: + for m in [1, 2, 3, 4]: + assert can_do([p, n], [m]) + + +def test_prudnikov_3(): + h = S.Half + assert can_do([Rational(1, 4), Rational(3, 4)], [h]) + assert can_do([Rational(1, 4), Rational(3, 4)], [3*h]) + assert can_do([Rational(1, 3), Rational(2, 3)], [3*h]) + assert can_do([Rational(3, 4), Rational(5, 4)], [h]) + assert can_do([Rational(3, 4), Rational(5, 4)], [3*h]) + + +@tooslow +def test_prudnikov_3_slow(): + # XXX: This is marked as tooslow and hence skipped in CI. None of the + # individual cases below fails or hangs. Some cases are slow and the loops + # below generate 280 different cases. Is it really necessary to test all + # 280 cases here? + h = S.Half + for p in [1, 2, 3, 4]: + for n in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4, 9*h]: + for m in [1, 3*h, 2, 5*h, 3, 7*h, 4]: + assert can_do([p, m], [n]) + + +@slow +def test_prudnikov_4(): + h = S.Half + for p in [3*h, 5*h, 7*h]: + for n in [-h, h, 3*h, 5*h, 7*h]: + for m in [3*h, 2, 5*h, 3, 7*h, 4]: + assert can_do([p, m], [n]) + for n in [1, 2, 3, 4]: + for m in [2, 3, 4]: + assert can_do([p, m], [n]) + + +@slow +def test_prudnikov_5(): + h = S.Half + + for p in [1, 2, 3]: + for q in range(p, 4): + for r in [1, 2, 3]: + for s in range(r, 4): + assert can_do([-h, p, q], [r, s]) + + for p in [h, 1, 3*h, 2, 5*h, 3]: + for q in [h, 3*h, 5*h]: + for r in [h, 3*h, 5*h]: + for s in [h, 3*h, 5*h]: + if s <= q and s <= r: + assert can_do([-h, p, q], [r, s]) + + for p in [h, 1, 3*h, 2, 5*h, 3]: + for q in [1, 2, 3]: + for r in [h, 3*h, 5*h]: + for s in [1, 2, 3]: + assert can_do([-h, p, q], [r, s]) + + +@slow +def test_prudnikov_6(): + h = S.Half + + for m in [3*h, 5*h]: + for n in [1, 2, 3]: + for q in [h, 1, 2]: + for p in [1, 2, 3]: + assert can_do([h, q, p], [m, n]) + for q in [1, 2, 3]: + for p in [3*h, 5*h]: + assert can_do([h, q, p], [m, n]) + + for q in [1, 2]: + for p in [1, 2, 3]: + for m in [1, 2, 3]: + for n in [1, 2, 3]: + assert can_do([h, q, p], [m, n]) + + assert can_do([h, h, 5*h], [3*h, 3*h]) + assert can_do([h, 1, 5*h], [3*h, 3*h]) + assert can_do([h, 2, 2], [1, 3]) + + # pages 435 to 457 contain more PFDD and stuff like this + + +@slow +def test_prudnikov_7(): + assert can_do([3], [6]) + + h = S.Half + for n in [h, 3*h, 5*h, 7*h]: + assert can_do([-h], [n]) + for m in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: # HERE + for n in [-h, h, 3*h, 5*h, 7*h, 1, 2, 3, 4]: + assert can_do([m], [n]) + + +@slow +def test_prudnikov_8(): + h = S.Half + + # 7.12.2 + for ai in [1, 2, 3]: + for bi in [1, 2, 3]: + for ci in range(1, ai + 1): + for di in [h, 1, 3*h, 2, 5*h, 3]: + assert can_do([ai, bi], [ci, di]) + for bi in [3*h, 5*h]: + for ci in [h, 1, 3*h, 2, 5*h, 3]: + for di in [1, 2, 3]: + assert can_do([ai, bi], [ci, di]) + + for ai in [-h, h, 3*h, 5*h]: + for bi in [1, 2, 3]: + for ci in [h, 1, 3*h, 2, 5*h, 3]: + for di in [1, 2, 3]: + assert can_do([ai, bi], [ci, di]) + for bi in [h, 3*h, 5*h]: + for ci in [h, 3*h, 5*h, 3]: + for di in [h, 1, 3*h, 2, 5*h, 3]: + if ci <= bi: + assert can_do([ai, bi], [ci, di]) + + +def test_prudnikov_9(): + # 7.13.1 [we have a general formula ... so this is a bit pointless] + for i in range(9): + assert can_do([], [(S(i) + 1)/2]) + for i in range(5): + assert can_do([], [-(2*S(i) + 1)/2]) + + +@slow +def test_prudnikov_10(): + # 7.14.2 + h = S.Half + for p in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: + for m in [1, 2, 3, 4]: + for n in range(m, 5): + assert can_do([p], [m, n]) + + for p in [1, 2, 3, 4]: + for n in [h, 3*h, 5*h, 7*h]: + for m in [1, 2, 3, 4]: + assert can_do([p], [n, m]) + + for p in [3*h, 5*h, 7*h]: + for m in [h, 1, 2, 5*h, 3, 7*h, 4]: + assert can_do([p], [h, m]) + assert can_do([p], [3*h, m]) + + for m in [h, 1, 2, 5*h, 3, 7*h, 4]: + assert can_do([7*h], [5*h, m]) + + assert can_do([Rational(-1, 2)], [S.Half, S.Half]) # shine-integral shi + + +def test_prudnikov_11(): + # 7.15 + assert can_do([a, a + S.Half], [2*a, b, 2*a - b]) + assert can_do([a, a + S.Half], [Rational(3, 2), 2*a, 2*a - S.Half]) + + assert can_do([Rational(1, 4), Rational(3, 4)], [S.Half, S.Half, 1]) + assert can_do([Rational(5, 4), Rational(3, 4)], [Rational(3, 2), S.Half, 2]) + assert can_do([Rational(5, 4), Rational(3, 4)], [Rational(3, 2), Rational(3, 2), 1]) + assert can_do([Rational(5, 4), Rational(7, 4)], [Rational(3, 2), Rational(5, 2), 2]) + + assert can_do([1, 1], [Rational(3, 2), 2, 2]) # cosh-integral chi + + +def test_prudnikov_12(): + # 7.16 + assert can_do( + [], [a, a + S.Half, 2*a], False) # branches only agree for some z! + assert can_do([], [a, a + S.Half, 2*a + 1], False) # dito + assert can_do([], [S.Half, a, a + S.Half]) + assert can_do([], [Rational(3, 2), a, a + S.Half]) + + assert can_do([], [Rational(1, 4), S.Half, Rational(3, 4)]) + assert can_do([], [S.Half, S.Half, 1]) + assert can_do([], [S.Half, Rational(3, 2), 1]) + assert can_do([], [Rational(3, 4), Rational(3, 2), Rational(5, 4)]) + assert can_do([], [1, 1, Rational(3, 2)]) + assert can_do([], [1, 2, Rational(3, 2)]) + assert can_do([], [1, Rational(3, 2), Rational(3, 2)]) + assert can_do([], [Rational(5, 4), Rational(3, 2), Rational(7, 4)]) + assert can_do([], [2, Rational(3, 2), Rational(3, 2)]) + + +@slow +def test_prudnikov_2F1(): + h = S.Half + # Elliptic integrals + for p in [-h, h]: + for m in [h, 3*h, 5*h, 7*h]: + for n in [1, 2, 3, 4]: + assert can_do([p, m], [n]) + + +@XFAIL +def test_prudnikov_fail_2F1(): + assert can_do([a, b], [b + 1]) # incomplete beta function + assert can_do([-1, b], [c]) # Poly. also -2, -3 etc + + # TODO polys + + # Legendre functions: + assert can_do([a, b], [a + b + S.Half]) + assert can_do([a, b], [a + b - S.Half]) + assert can_do([a, b], [a + b + Rational(3, 2)]) + assert can_do([a, b], [(a + b + 1)/2]) + assert can_do([a, b], [(a + b)/2 + 1]) + assert can_do([a, b], [a - b + 1]) + assert can_do([a, b], [a - b + 2]) + assert can_do([a, b], [2*b]) + assert can_do([a, b], [S.Half]) + assert can_do([a, b], [Rational(3, 2)]) + assert can_do([a, 1 - a], [c]) + assert can_do([a, 2 - a], [c]) + assert can_do([a, 3 - a], [c]) + assert can_do([a, a + S.Half], [c]) + assert can_do([1, b], [c]) + assert can_do([1, b], [Rational(3, 2)]) + + assert can_do([Rational(1, 4), Rational(3, 4)], [1]) + + # PFDD + o = S.One + assert can_do([o/8, 1], [o/8*9]) + assert can_do([o/6, 1], [o/6*7]) + assert can_do([o/6, 1], [o/6*13]) + assert can_do([o/5, 1], [o/5*6]) + assert can_do([o/5, 1], [o/5*11]) + assert can_do([o/4, 1], [o/4*5]) + assert can_do([o/4, 1], [o/4*9]) + assert can_do([o/3, 1], [o/3*4]) + assert can_do([o/3, 1], [o/3*7]) + assert can_do([o/8*3, 1], [o/8*11]) + assert can_do([o/5*2, 1], [o/5*7]) + assert can_do([o/5*2, 1], [o/5*12]) + assert can_do([o/5*3, 1], [o/5*8]) + assert can_do([o/5*3, 1], [o/5*13]) + assert can_do([o/8*5, 1], [o/8*13]) + assert can_do([o/4*3, 1], [o/4*7]) + assert can_do([o/4*3, 1], [o/4*11]) + assert can_do([o/3*2, 1], [o/3*5]) + assert can_do([o/3*2, 1], [o/3*8]) + assert can_do([o/5*4, 1], [o/5*9]) + assert can_do([o/5*4, 1], [o/5*14]) + assert can_do([o/6*5, 1], [o/6*11]) + assert can_do([o/6*5, 1], [o/6*17]) + assert can_do([o/8*7, 1], [o/8*15]) + + +@XFAIL +def test_prudnikov_fail_3F2(): + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(1, 3), Rational(2, 3)]) + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(2, 3), Rational(4, 3)]) + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(4, 3), Rational(5, 3)]) + + # page 421 + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [a*Rational(3, 2), (3*a + 1)/2]) + + # pages 422 ... + assert can_do([Rational(-1, 2), S.Half, S.Half], [1, 1]) # elliptic integrals + assert can_do([Rational(-1, 2), S.Half, 1], [Rational(3, 2), Rational(3, 2)]) + # TODO LOTS more + + # PFDD + assert can_do([Rational(1, 8), Rational(3, 8), 1], [Rational(9, 8), Rational(11, 8)]) + assert can_do([Rational(1, 8), Rational(5, 8), 1], [Rational(9, 8), Rational(13, 8)]) + assert can_do([Rational(1, 8), Rational(7, 8), 1], [Rational(9, 8), Rational(15, 8)]) + assert can_do([Rational(1, 6), Rational(1, 3), 1], [Rational(7, 6), Rational(4, 3)]) + assert can_do([Rational(1, 6), Rational(2, 3), 1], [Rational(7, 6), Rational(5, 3)]) + assert can_do([Rational(1, 6), Rational(2, 3), 1], [Rational(5, 3), Rational(13, 6)]) + assert can_do([S.Half, 1, 1], [Rational(1, 4), Rational(3, 4)]) + # LOTS more + + +@XFAIL +def test_prudnikov_fail_other(): + # 7.11.2 + + # 7.12.1 + assert can_do([1, a], [b, 1 - 2*a + b]) # ??? + + # 7.14.2 + assert can_do([Rational(-1, 2)], [S.Half, 1]) # struve + assert can_do([1], [S.Half, S.Half]) # struve + assert can_do([Rational(1, 4)], [S.Half, Rational(5, 4)]) # PFDD + assert can_do([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)]) # PFDD + assert can_do([1], [Rational(1, 4), Rational(3, 4)]) # PFDD + assert can_do([1], [Rational(3, 4), Rational(5, 4)]) # PFDD + assert can_do([1], [Rational(5, 4), Rational(7, 4)]) # PFDD + # TODO LOTS more + + # 7.15.2 + assert can_do([S.Half, 1], [Rational(3, 4), Rational(5, 4), Rational(3, 2)]) # PFDD + assert can_do([S.Half, 1], [Rational(7, 4), Rational(5, 4), Rational(3, 2)]) # PFDD + + # 7.16.1 + assert can_do([], [Rational(1, 3), S(2/3)]) # PFDD + assert can_do([], [Rational(2, 3), S(4/3)]) # PFDD + assert can_do([], [Rational(5, 3), S(4/3)]) # PFDD + + # XXX this does not *evaluate* right?? + assert can_do([], [a, a + S.Half, 2*a - 1]) + + +def test_bug(): + h = hyper([-1, 1], [z], -1) + assert hyperexpand(h) == (z + 1)/z + + +def test_omgissue_203(): + h = hyper((-5, -3, -4), (-6, -6), 1) + assert hyperexpand(h) == Rational(1, 30) + h = hyper((-6, -7, -5), (-6, -6), 1) + assert hyperexpand(h) == Rational(-1, 6) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_powsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_powsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..61bdc93d052baf4b1e80da8f5864cf22b8fa383e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_powsimp.py @@ -0,0 +1,368 @@ +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (E, I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import sin +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.simplify.powsimp import (powdenest, powsimp) +from sympy.simplify.simplify import (signsimp, simplify) +from sympy.core.symbol import Str + +from sympy.abc import x, y, z, a, b + + +def test_powsimp(): + x, y, z, n = symbols('x,y,z,n') + f = Function('f') + assert powsimp( 4**x * 2**(-x) * 2**(-x) ) == 1 + assert powsimp( (-4)**x * (-2)**(-x) * 2**(-x) ) == 1 + + assert powsimp( + f(4**x * 2**(-x) * 2**(-x)) ) == f(4**x * 2**(-x) * 2**(-x)) + assert powsimp( f(4**x * 2**(-x) * 2**(-x)), deep=True ) == f(1) + assert exp(x)*exp(y) == exp(x)*exp(y) + assert powsimp(exp(x)*exp(y)) == exp(x + y) + assert powsimp(exp(x)*exp(y)*2**x*2**y) == (2*E)**(x + y) + assert powsimp(exp(x)*exp(y)*2**x*2**y, combine='exp') == \ + exp(x + y)*2**(x + y) + assert powsimp(exp(x)*exp(y)*exp(2)*sin(x) + sin(y) + 2**x*2**y) == \ + exp(2 + x + y)*sin(x) + sin(y) + 2**(x + y) + assert powsimp(sin(exp(x)*exp(y))) == sin(exp(x)*exp(y)) + assert powsimp(sin(exp(x)*exp(y)), deep=True) == sin(exp(x + y)) + assert powsimp(x**2*x**y) == x**(2 + y) + # This should remain factored, because 'exp' with deep=True is supposed + # to act like old automatic exponent combining. + assert powsimp((1 + E*exp(E))*exp(-E), combine='exp', deep=True) == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), deep=True) == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E)) == (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), combine='exp') == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), combine='base') == \ + (1 + E*exp(E))*exp(-E) + x, y = symbols('x,y', nonnegative=True) + n = Symbol('n', real=True) + assert powsimp(y**n * (y/x)**(-n)) == x**n + assert powsimp(x**(x**(x*y)*y**(x*y))*y**(x**(x*y)*y**(x*y)), deep=True) \ + == (x*y)**(x*y)**(x*y) + assert powsimp(2**(2**(2*x)*x), deep=False) == 2**(2**(2*x)*x) + assert powsimp(2**(2**(2*x)*x), deep=True) == 2**(x*4**x) + assert powsimp( + exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \ + exp(-x + exp(-x)*exp(-x*log(x))) + assert powsimp( + exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \ + exp(-x + exp(-x)*exp(-x*log(x))) + assert powsimp((x + y)/(3*z), deep=False, combine='exp') == (x + y)/(3*z) + assert powsimp((x/3 + y/3)/z, deep=True, combine='exp') == (x/3 + y/3)/z + assert powsimp(exp(x)/(1 + exp(x)*exp(y)), deep=True) == \ + exp(x)/(1 + exp(x + y)) + assert powsimp(x*y**(z**x*z**y), deep=True) == x*y**(z**(x + y)) + assert powsimp((z**x*z**y)**x, deep=True) == (z**(x + y))**x + assert powsimp(x*(z**x*z**y)**x, deep=True) == x*(z**(x + y))**x + p = symbols('p', positive=True) + assert powsimp((1/x)**log(2)/x) == (1/x)**(1 + log(2)) + assert powsimp((1/p)**log(2)/p) == p**(-1 - log(2)) + + # coefficient of exponent can only be simplified for positive bases + assert powsimp(2**(2*x)) == 4**x + assert powsimp((-1)**(2*x)) == (-1)**(2*x) + i = symbols('i', integer=True) + assert powsimp((-1)**(2*i)) == 1 + assert powsimp((-1)**(-x)) != (-1)**x # could be 1/((-1)**x), but is not + # force=True overrides assumptions + assert powsimp((-1)**(2*x), force=True) == 1 + + # rational exponents allow combining of negative terms + w, n, m = symbols('w n m', negative=True) + e = i/a # not a rational exponent if `a` is unknown + ex = w**e*n**e*m**e + assert powsimp(ex) == m**(i/a)*n**(i/a)*w**(i/a) + e = i/3 + ex = w**e*n**e*m**e + assert powsimp(ex) == (-1)**i*(-m*n*w)**(i/3) + e = (3 + i)/i + ex = w**e*n**e*m**e + assert powsimp(ex) == (-1)**(3*e)*(-m*n*w)**e + + eq = x**(a*Rational(2, 3)) + # eq != (x**a)**(2/3) (try x = -1 and a = 3 to see) + assert powsimp(eq).exp == eq.exp == a*Rational(2, 3) + # powdenest goes the other direction + assert powsimp(2**(2*x)) == 4**x + + assert powsimp(exp(p/2)) == exp(p/2) + + # issue 6368 + eq = Mul(*[sqrt(Dummy(imaginary=True)) for i in range(3)]) + assert powsimp(eq) == eq and eq.is_Mul + + assert all(powsimp(e) == e for e in (sqrt(x**a), sqrt(x**2))) + + # issue 8836 + assert str( powsimp(exp(I*pi/3)*root(-1,3)) ) == '(-1)**(2/3)' + + # issue 9183 + assert powsimp(-0.1**x) == -0.1**x + + # issue 10095 + assert powsimp((1/(2*E))**oo) == (exp(-1)/2)**oo + + # PR 13131 + eq = sin(2*x)**2*sin(2.0*x)**2 + assert powsimp(eq) == eq + + # issue 14615 + assert powsimp(x**2*y**3*(x*y**2)**Rational(3, 2) + ) == x*y*(x*y**2)**Rational(5, 2) + + #issue 27380 + assert powsimp(1.0**(x+1)/1.0**x) == 1.0 + +def test_powsimp_negated_base(): + assert powsimp((-x + y)/sqrt(x - y)) == -sqrt(x - y) + assert powsimp((-x + y)*(-z + y)/sqrt(x - y)/sqrt(z - y)) == sqrt(x - y)*sqrt(z - y) + p = symbols('p', positive=True) + reps = {p: 2, a: S.Half} + assert powsimp((-p)**a/p**a).subs(reps) == ((-1)**a).subs(reps) + assert powsimp((-p)**a*p**a).subs(reps) == ((-p**2)**a).subs(reps) + n = symbols('n', negative=True) + reps = {p: -2, a: S.Half} + assert powsimp((-n)**a/n**a).subs(reps) == (-1)**(-a).subs(a, S.Half) + assert powsimp((-n)**a*n**a).subs(reps) == ((-n**2)**a).subs(reps) + # if x is 0 then the lhs is 0**a*oo**a which is not (-1)**a + eq = (-x)**a/x**a + assert powsimp(eq) == eq + + +def test_powsimp_nc(): + x, y, z = symbols('x,y,z') + A, B, C = symbols('A B C', commutative=False) + + assert powsimp(A**x*A**y, combine='all') == A**(x + y) + assert powsimp(A**x*A**y, combine='base') == A**x*A**y + assert powsimp(A**x*A**y, combine='exp') == A**(x + y) + + assert powsimp(A**x*B**x, combine='all') == A**x*B**x + assert powsimp(A**x*B**x, combine='base') == A**x*B**x + assert powsimp(A**x*B**x, combine='exp') == A**x*B**x + + assert powsimp(B**x*A**x, combine='all') == B**x*A**x + assert powsimp(B**x*A**x, combine='base') == B**x*A**x + assert powsimp(B**x*A**x, combine='exp') == B**x*A**x + + assert powsimp(A**x*A**y*A**z, combine='all') == A**(x + y + z) + assert powsimp(A**x*A**y*A**z, combine='base') == A**x*A**y*A**z + assert powsimp(A**x*A**y*A**z, combine='exp') == A**(x + y + z) + + assert powsimp(A**x*B**x*C**x, combine='all') == A**x*B**x*C**x + assert powsimp(A**x*B**x*C**x, combine='base') == A**x*B**x*C**x + assert powsimp(A**x*B**x*C**x, combine='exp') == A**x*B**x*C**x + + assert powsimp(B**x*A**x*C**x, combine='all') == B**x*A**x*C**x + assert powsimp(B**x*A**x*C**x, combine='base') == B**x*A**x*C**x + assert powsimp(B**x*A**x*C**x, combine='exp') == B**x*A**x*C**x + + +def test_issue_6440(): + assert powsimp(16*2**a*8**b) == 2**(a + 3*b + 4) + + +def test_powdenest(): + x, y = symbols('x,y') + p, q = symbols('p q', positive=True) + i, j = symbols('i,j', integer=True) + + assert powdenest(x) == x + assert powdenest(x + 2*(x**(a*Rational(2, 3)))**(3*x)) == (x + 2*(x**(a*Rational(2, 3)))**(3*x)) + assert powdenest((exp(a*Rational(2, 3)))**(3*x)) # -X-> (exp(a/3))**(6*x) + assert powdenest((x**(a*Rational(2, 3)))**(3*x)) == ((x**(a*Rational(2, 3)))**(3*x)) + assert powdenest(exp(3*x*log(2))) == 2**(3*x) + assert powdenest(sqrt(p**2)) == p + eq = p**(2*i)*q**(4*i) + assert powdenest(eq) == (p*q**2)**(2*i) + # -X-> (x**x)**i*(x**x)**j == x**(x*(i + j)) + assert powdenest((x**x)**(i + j)) + assert powdenest(exp(3*y*log(x))) == x**(3*y) + assert powdenest(exp(y*(log(a) + log(b)))) == (a*b)**y + assert powdenest(exp(3*(log(a) + log(b)))) == a**3*b**3 + assert powdenest(((x**(2*i))**(3*y))**x) == ((x**(2*i))**(3*y))**x + assert powdenest(((x**(2*i))**(3*y))**x, force=True) == x**(6*i*x*y) + assert powdenest(((x**(a*Rational(2, 3)))**(3*y/i))**x) == \ + (((x**(a*Rational(2, 3)))**(3*y/i))**x) + assert powdenest((x**(2*i)*y**(4*i))**z, force=True) == (x*y**2)**(2*i*z) + assert powdenest((p**(2*i)*q**(4*i))**j) == (p*q**2)**(2*i*j) + e = ((p**(2*a))**(3*y))**x + assert powdenest(e) == e + e = ((x**2*y**4)**a)**(x*y) + assert powdenest(e) == e + e = (((x**2*y**4)**a)**(x*y))**3 + assert powdenest(e) == ((x**2*y**4)**a)**(3*x*y) + assert powdenest((((x**2*y**4)**a)**(x*y)), force=True) == \ + (x*y**2)**(2*a*x*y) + assert powdenest((((x**2*y**4)**a)**(x*y))**3, force=True) == \ + (x*y**2)**(6*a*x*y) + assert powdenest((x**2*y**6)**i) != (x*y**3)**(2*i) + x, y = symbols('x,y', positive=True) + assert powdenest((x**2*y**6)**i) == (x*y**3)**(2*i) + + assert powdenest((x**(i*Rational(2, 3))*y**(i/2))**(2*i)) == (x**Rational(4, 3)*y)**(i**2) + assert powdenest(sqrt(x**(2*i)*y**(6*i))) == (x*y**3)**i + + assert powdenest(4**x) == 2**(2*x) + assert powdenest((4**x)**y) == 2**(2*x*y) + assert powdenest(4**x*y) == 2**(2*x)*y + + +def test_powdenest_polar(): + x, y, z = symbols('x y z', polar=True) + a, b, c = symbols('a b c') + assert powdenest((x*y*z)**a) == x**a*y**a*z**a + assert powdenest((x**a*y**b)**c) == x**(a*c)*y**(b*c) + assert powdenest(((x**a)**b*y**c)**c) == x**(a*b*c)*y**(c**2) + + +def test_issue_5805(): + arg = ((gamma(x)*hyper((), (), x))*pi)**2 + assert powdenest(arg) == (pi*gamma(x)*hyper((), (), x))**2 + assert arg.is_positive is None + + +def test_issue_9324_powsimp_on_matrix_symbol(): + M = MatrixSymbol('M', 10, 10) + expr = powsimp(M, deep=True) + assert expr == M + assert expr.args[0] == Str('M') + + +def test_issue_6367(): + z = -5*sqrt(2)/(2*sqrt(2*sqrt(29) + 29)) + sqrt(-sqrt(29)/29 + S.Half) + assert Mul(*[powsimp(a) for a in Mul.make_args(z.normal())]) == 0 + assert powsimp(z.normal()) == 0 + assert simplify(z) == 0 + assert powsimp(sqrt(2 + sqrt(3))*sqrt(2 - sqrt(3)) + 1) == 2 + assert powsimp(z) != 0 + + +def test_powsimp_polar(): + from sympy.functions.elementary.complexes import polar_lift + from sympy.functions.elementary.exponential import exp_polar + x, y, z = symbols('x y z') + p, q, r = symbols('p q r', polar=True) + + assert (polar_lift(-1))**(2*x) == exp_polar(2*pi*I*x) + assert powsimp(p**x * q**x) == (p*q)**x + assert p**x * (1/p)**x == 1 + assert (1/p)**x == p**(-x) + + assert exp_polar(x)*exp_polar(y) == exp_polar(x)*exp_polar(y) + assert powsimp(exp_polar(x)*exp_polar(y)) == exp_polar(x + y) + assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y) == \ + (p*exp_polar(1))**(x + y) + assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y, combine='exp') == \ + exp_polar(x + y)*p**(x + y) + assert powsimp( + exp_polar(x)*exp_polar(y)*exp_polar(2)*sin(x) + sin(y) + p**x*p**y) \ + == p**(x + y) + sin(x)*exp_polar(2 + x + y) + sin(y) + assert powsimp(sin(exp_polar(x)*exp_polar(y))) == \ + sin(exp_polar(x)*exp_polar(y)) + assert powsimp(sin(exp_polar(x)*exp_polar(y)), deep=True) == \ + sin(exp_polar(x + y)) + + +def test_issue_5728(): + b = x*sqrt(y) + a = sqrt(b) + c = sqrt(sqrt(x)*y) + assert powsimp(a*b) == sqrt(b)**3 + assert powsimp(a*b**2*sqrt(y)) == sqrt(y)*a**5 + assert powsimp(a*x**2*c**3*y) == c**3*a**5 + assert powsimp(a*x*c**3*y**2) == c**7*a + assert powsimp(x*c**3*y**2) == c**7 + assert powsimp(x*c**3*y) == x*y*c**3 + assert powsimp(sqrt(x)*c**3*y) == c**5 + assert powsimp(sqrt(x)*a**3*sqrt(y)) == sqrt(x)*sqrt(y)*a**3 + assert powsimp(Mul(sqrt(x)*c**3*sqrt(y), y, evaluate=False)) == \ + sqrt(x)*sqrt(y)**3*c**3 + assert powsimp(a**2*a*x**2*y) == a**7 + + # symbolic powers work, too + b = x**y*y + a = b*sqrt(b) + assert a.is_Mul is True + assert powsimp(a) == sqrt(b)**3 + + # as does exp + a = x*exp(y*Rational(2, 3)) + assert powsimp(a*sqrt(a)) == sqrt(a)**3 + assert powsimp(a**2*sqrt(a)) == sqrt(a)**5 + assert powsimp(a**2*sqrt(sqrt(a))) == sqrt(sqrt(a))**9 + + +def test_issue_from_PR1599(): + n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True) + assert (powsimp(sqrt(n1)*sqrt(n2)*sqrt(n3)) == + -I*sqrt(-n1)*sqrt(-n2)*sqrt(-n3)) + assert (powsimp(root(n1, 3)*root(n2, 3)*root(n3, 3)*root(n4, 3)) == + -(-1)**Rational(1, 3)* + (-n1)**Rational(1, 3)*(-n2)**Rational(1, 3)*(-n3)**Rational(1, 3)*(-n4)**Rational(1, 3)) + + +def test_issue_10195(): + a = Symbol('a', integer=True) + l = Symbol('l', even=True, nonzero=True) + n = Symbol('n', odd=True) + e_x = (-1)**(n/2 - S.Half) - (-1)**(n*Rational(3, 2) - S.Half) + assert powsimp((-1)**(l/2)) == I**l + assert powsimp((-1)**(n/2)) == I**n + assert powsimp((-1)**(n*Rational(3, 2))) == -I**n + assert powsimp(e_x) == (-1)**(n/2 - S.Half) + (-1)**(n*Rational(3, 2) + + S.Half) + assert powsimp((-1)**(a*Rational(3, 2))) == (-I)**a + +def test_issue_15709(): + assert powsimp(3**x*Rational(2, 3)) == 2*3**(x-1) + assert powsimp(2*3**x/3) == 2*3**(x-1) + + +def test_issue_11981(): + x, y = symbols('x y', commutative=False) + assert powsimp((x*y)**2 * (y*x)**2) == (x*y)**2 * (y*x)**2 + + +def test_issue_17524(): + a = symbols("a", real=True) + e = (-1 - a**2)*sqrt(1 + a**2) + assert signsimp(powsimp(e)) == signsimp(e) == -(a**2 + 1)**(S(3)/2) + + +def test_issue_19627(): + # if you use force the user must verify + assert powdenest(sqrt(sin(x)**2), force=True) == sin(x) + assert powdenest((x**(S.Half/y))**(2*y), force=True) == x + from sympy.core.function import expand_power_base + e = 1 - a + expr = (exp(z/e)*x**(b/e)*y**((1 - b)/e))**e + assert powdenest(expand_power_base(expr, force=True), force=True + ) == x**b*y**(1 - b)*exp(z) + + +def test_issue_22546(): + p1, p2 = symbols('p1, p2', positive=True) + ref = powsimp(p1**z/p2**z) + e = z + 1 + ans = ref.subs(z, e) + assert ans.is_Pow + assert powsimp(p1**e/p2**e) == ans + i = symbols('i', integer=True) + ref = powsimp(x**i/y**i) + e = i + 1 + ans = ref.subs(i, e) + assert ans.is_Pow + assert powsimp(x**e/y**e) == ans diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_radsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_radsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ff955e48a34536c1752c565c0864dedae6a214 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_radsimp.py @@ -0,0 +1,498 @@ +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.polys.polytools import factor +from sympy.series.order import O +from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect) + +from sympy.core.expr import unchanged +from sympy.core.mul import _unevaluated_Mul as umul +from sympy.simplify.radsimp import (_unevaluated_Add, + collect_sqrt, fraction_expand, collect_abs) +from sympy.testing.pytest import raises + +from sympy.abc import x, y, z, a, b, c, d + + +def test_radsimp(): + r2 = sqrt(2) + r3 = sqrt(3) + r5 = sqrt(5) + r7 = sqrt(7) + assert fraction(radsimp(1/r2)) == (sqrt(2), 2) + assert radsimp(1/(1 + r2)) == \ + -1 + sqrt(2) + assert radsimp(1/(r2 + r3)) == \ + -sqrt(2) + sqrt(3) + assert fraction(radsimp(1/(1 + r2 + r3))) == \ + (-sqrt(6) + sqrt(2) + 2, 4) + assert fraction(radsimp(1/(r2 + r3 + r5))) == \ + (-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12) + assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == ( + (-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) + + 93 + 46*sqrt(6) + 53*sqrt(5), 71)) + assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == ( + (-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105) + + 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215)) + z = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7)) + assert len((3616791619821680643598*z).args) == 16 + assert radsimp(1/z) == 1/z + assert radsimp(1/z, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7 + assert radsimp(1/(r2*3)) == \ + sqrt(2)/6 + assert radsimp(1/(r2*a + r3 + r5 + r7)) == ( + (8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 - + 180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5 + - 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 + + 116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 - + 8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 - + 302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 - + 795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a - + 118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 - + 480*a**6 + 3128*a**4 - 6360*a**2 + 3481)) + assert radsimp(1/(r2*a + r2*b + r3 + r7)) == ( + (sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a + + b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a + + b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 - + 20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8)) + assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \ + sqrt(2)/(2*a + 2*b + 2*c + 2*d) + assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == ( + (sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b + + 4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1)) + assert radsimp((y**2 - x)/(y - sqrt(x))) == \ + sqrt(x) + y + assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \ + -(sqrt(x) + y) + assert radsimp(1/(1 - I + a*I)) == \ + (-I*a + 1 + I)/(a**2 - 2*a + 2) + assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \ + (-x - sqrt(y))/((x - y)*(x**2 - y)) + e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y)) + assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y)) + assert radsimp(1/e) == ( + (-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 - + 9*y))) + assert radsimp(1 + 1/(1 + sqrt(3))) == \ + Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1 + A = symbols("A", commutative=False) + assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \ + x**2 + sqrt(2)*x**2 - sqrt(2)*x*A + assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3) + assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3 + + # issue 6532 + assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x) + assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3) + assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6) + + # issue 5994 + e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/' + '(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))') + assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2) + + # issue 5986 (modifications to radimp didn't initially recognize this so + # the test is included here) + assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1 + + # from issue 5934 + eq = ( + (-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) - + 360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) - + 120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) + + 120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) + + 120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) + + 120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) + + 120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 - + 7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) + + 24*sqrt(10)*sqrt(-sqrt(5) + 5))**2)) + assert radsimp(eq) is S.NaN # it's 0/0 + + # work with normal form + e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3 + assert radsimp(e) == ( + -sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) + + 35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15) + - 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) + + 8291415*sqrt(21))/1300423175 + 3) + + # obey power rules + base = sqrt(3) - sqrt(2) + assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3 + assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3 + assert radsimp(1/(-base)**x) == (-base)**(-x) + assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x + assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x) + + # recurse + e = cos(1/(1 + sqrt(2))) + assert radsimp(e) == cos(-sqrt(2) + 1) + assert radsimp(e/2) == cos(-sqrt(2) + 1)/2 + assert radsimp(1/e) == 1/cos(-sqrt(2) + 1) + assert radsimp(2/e) == 2/cos(-sqrt(2) + 1) + assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x) + + # test that symbolic denominators are not processed + r = 1 + sqrt(2) + assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1) + assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2)) + assert radsimp(x/(y + r)/r, symbolic=False) == \ + -x*(-sqrt(2) + 1)/(y + 1 + sqrt(2)) + + # issue 7408 + eq = sqrt(x)/sqrt(y) + assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y) + assert radsimp(eq, symbolic=False) == eq + + # issue 7498 + assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3) + + # for coverage + eq = sqrt(x)/y**2 + assert radsimp(eq) == eq + + # handle non-Expr args + from sympy.integrals.integrals import Integral + eq = Integral(x/(sqrt(2) - 1), (x, 0, 1/(sqrt(2) + 1))) + assert radsimp(eq) == Integral((sqrt(2) + 1)*x , (x, 0, sqrt(2) - 1)) + + from sympy.sets import FiniteSet + eq = FiniteSet(x/(sqrt(2) - 1)) + assert radsimp(eq) == FiniteSet((sqrt(2) + 1)*x) + +def test_radsimp_issue_3214(): + c, p = symbols('c p', positive=True) + s = sqrt(c**2 - p**2) + b = (c + I*p - s)/(c + I*p + s) + assert radsimp(b) == -I*(c + I*p - sqrt(c**2 - p**2))**2/(2*c*p) + + +def test_collect_1(): + """Collect with respect to Symbol""" + x, y, z, n = symbols('x,y,z,n') + assert collect(1, x) == 1 + assert collect( x + y*x, x ) == x * (1 + y) + assert collect( x + x**2, x ) == x + x**2 + assert collect( x**2 + y*x**2, x ) == (x**2)*(1 + y) + assert collect( x**2 + y*x, x ) == x*y + x**2 + assert collect( 2*x**2 + y*x**2 + 3*x*y, [x] ) == x**2*(2 + y) + 3*x*y + assert collect( 2*x**2 + y*x**2 + 3*x*y, [y] ) == 2*x**2 + y*(x**2 + 3*x) + + assert collect( ((1 + y + x)**4).expand(), x) == ((1 + y)**4).expand() + \ + x*(4*(1 + y)**3).expand() + x**2*(6*(1 + y)**2).expand() + \ + x**3*(4*(1 + y)).expand() + x**4 + # symbols can be given as any iterable + expr = x + y + assert collect(expr, expr.free_symbols) == expr + assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None + ) == x*exp(x) + 3*x + (y + 2)*sin(x) + assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x + + y*x*exp(x), x, exact=None + ) == x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x) + + +def test_collect_2(): + """Collect with respect to a sum""" + a, b, x = symbols('a,b,x') + assert collect(a*(cos(x) + sin(x)) + b*(cos(x) + sin(x)), + sin(x) + cos(x)) == (a + b)*(cos(x) + sin(x)) + + +def test_collect_3(): + """Collect with respect to a product""" + a, b, c = symbols('a,b,c') + f = Function('f') + x, y, z, n = symbols('x,y,z,n') + + assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8)) + + assert collect( 1 + x*(y**2), x*y ) == 1 + x*(y**2) + assert collect( x*y + a*x*y, x*y) == x*y*(1 + a) + assert collect( 1 + x*y + a*x*y, x*y) == 1 + x*y*(1 + a) + assert collect(a*x*f(x) + b*(x*f(x)), x*f(x)) == x*(a + b)*f(x) + + assert collect(a*x*log(x) + b*(x*log(x)), x*log(x)) == x*(a + b)*log(x) + assert collect(a*x**2*log(x)**2 + b*(x*log(x))**2, x*log(x)) == \ + x**2*log(x)**2*(a + b) + + # with respect to a product of three symbols + assert collect(y*x*z + a*x*y*z, x*y*z) == (1 + a)*x*y*z + + +def test_collect_4(): + """Collect with respect to a power""" + a, b, c, x = symbols('a,b,c,x') + + assert collect(a*x**c + b*x**c, x**c) == x**c*(a + b) + # issue 6096: 2 stays with c (unless c is integer or x is positive0 + assert collect(a*x**(2*c) + b*x**(2*c), x**c) == x**(2*c)*(a + b) + + +def test_collect_5(): + """Collect with respect to a tuple""" + a, x, y, z, n = symbols('a,x,y,z,n') + assert collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z]) in [ + z*(1 + a + x**2*y**4) + x**2*y**4, + z*(1 + a) + x**2*y**4*(1 + z) ] + assert collect((1 + (x + y) + (x + y)**2).expand(), + [x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2 + + +def test_collect_pr19431(): + """Unevaluated collect with respect to a product""" + a = symbols('a') + assert collect(a**2*(a**2 + 1), a**2, evaluate=False)[a**2] == (a**2 + 1) + + +def test_collect_D(): + D = Derivative + f = Function('f') + x, a, b = symbols('x,a,b') + fx = D(f(x), x) + fxx = D(f(x), x, x) + + assert collect(a*fx + b*fx, fx) == (a + b)*fx + assert collect(a*D(fx, x) + b*D(fx, x), fx) == (a + b)*D(fx, x) + assert collect(a*fxx + b*fxx, fx) == (a + b)*D(fx, x) + # issue 4784 + assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx + assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x)) == \ + (x*f(x) + f(x))*D(f(x), x) + f(x) + assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x), exact=True) == \ + (x*f(x) + f(x))*D(f(x), x) + f(x) + assert collect(1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x), f(x).diff(x), exact=True) == \ + (1/f(x) + x/f(x))*D(f(x), x) + 1/f(x) + e = (1 + x*fx + fx)/f(x) + assert collect(e.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x) + + +def test_collect_func(): + f = ((x + a + 1)**3).expand() + + assert collect(f, x) == a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) + \ + x*(3*a**2 + 6*a + 3) + 1 + assert collect(f, x, factor) == x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + \ + (a + 1)**3 + + assert collect(f, x, evaluate=False) == { + S.One: a**3 + 3*a**2 + 3*a + 1, + x: 3*a**2 + 6*a + 3, x**2: 3*a + 3, + x**3: 1 + } + + assert collect(f, x, factor, evaluate=False) == { + S.One: (a + 1)**3, x: 3*(a + 1)**2, + x**2: umul(S(3), a + 1), x**3: 1} + + +def test_collect_order(): + a, b, x, t = symbols('a,b,x,t') + + assert collect(t + t*x + t*x**2 + O(x**3), t) == t*(1 + x + x**2 + O(x**3)) + assert collect(t + t*x + x**2 + O(x**3), t) == \ + t*(1 + x + O(x**3)) + x**2 + O(x**3) + + f = a*x + b*x + c*x**2 + d*x**2 + O(x**3) + g = x*(a + b) + x**2*(c + d) + O(x**3) + + assert collect(f, x) == g + assert collect(f, x, distribute_order_term=False) == g + + f = sin(a + b).series(b, 0, 10) + + assert collect(f, [sin(a), cos(a)]) == \ + sin(a)*cos(b).series(b, 0, 10) + cos(a)*sin(b).series(b, 0, 10) + assert collect(f, [sin(a), cos(a)], distribute_order_term=False) == \ + sin(a)*cos(b).series(b, 0, 10).removeO() + \ + cos(a)*sin(b).series(b, 0, 10).removeO() + O(b**10) + + +def test_rcollect(): + assert rcollect((x**2*y + x*y + x + y)/(x + y), y) == \ + (x + y*(1 + x + x**2))/(x + y) + assert rcollect(sqrt(-((x + 1)*(y + 1))), z) == sqrt(-((x + 1)*(y + 1))) + + +def test_collect_D_0(): + D = Derivative + f = Function('f') + x, a, b = symbols('x,a,b') + fxx = D(f(x), x, x) + + assert collect(a*fxx + b*fxx, fxx) == (a + b)*fxx + + +def test_collect_Wild(): + """Collect with respect to functions with Wild argument""" + a, b, x, y = symbols('a b x y') + f = Function('f') + w1 = Wild('.1') + w2 = Wild('.2') + assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x) + assert collect(f(x, y) + a*f(x, y), f(w1)) == f(x, y) + a*f(x, y) + assert collect(f(x, y) + a*f(x, y), f(w1, w2)) == (1 + a)*f(x, y) + assert collect(f(x, y) + a*f(x, y), f(w1, w1)) == f(x, y) + a*f(x, y) + assert collect(f(x, x) + a*f(x, x), f(w1, w1)) == (1 + a)*f(x, x) + assert collect(a*(x + 1)**y + (x + 1)**y, w1**y) == (1 + a)*(x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, w1**b) == \ + a*(x + 1)**y + (x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, (x + 1)**w2) == \ + (1 + a)*(x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, w1**w2) == (1 + a)*(x + 1)**y + + +def test_collect_const(): + # coverage not provided by above tests + assert collect_const(2*sqrt(3) + 4*a*sqrt(5)) == \ + 2*(2*sqrt(5)*a + sqrt(3)) # let the primitive reabsorb + assert collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3)) == \ + 2*sqrt(3) + 4*a*sqrt(5) + assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == \ + sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3) + + # issue 5290 + assert collect_const(2*x + 2*y + 1, 2) == \ + collect_const(2*x + 2*y + 1) == \ + Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False) + assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False) + assert collect_const(2*x - 2*y - 2*z, 2) == \ + Mul(2, x - y - z, evaluate=False) + assert collect_const(2*x - 2*y - 2*z, -2) == \ + _unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False)) + + # this is why the content_primitive is used + eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2 + assert collect_sqrt(eq + 2) == \ + 2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2 + + # issue 16296 + assert collect_const(a + b + x/2 + y/2) == a + b + Mul(S.Half, x + y, evaluate=False) + + +def test_issue_13143(): + f = Function('f') + fx = f(x).diff(x) + e = f(x) + fx + f(x)*fx + # collect function before derivative + assert collect(e, Wild('w')) == f(x)*(fx + 1) + fx + e = f(x) + f(x)*fx + x*fx*f(x) + assert collect(e, fx) == (x*f(x) + f(x))*fx + f(x) + assert collect(e, f(x)) == (x*fx + fx + 1)*f(x) + e = f(x) + fx + f(x)*fx + assert collect(e, [f(x), fx]) == f(x)*(1 + fx) + fx + assert collect(e, [fx, f(x)]) == fx*(1 + f(x)) + f(x) + + +def test_issue_6097(): + assert collect(a*y**(2.0*x) + b*y**(2.0*x), y**x) == (a + b)*(y**x)**2.0 + assert collect(a*2**(2.0*x) + b*2**(2.0*x), 2**x) == (a + b)*(2**x)**2.0 + + +def test_fraction_expand(): + eq = (x + y)*y/x + assert eq.expand(frac=True) == fraction_expand(eq) == (x*y + y**2)/x + assert eq.expand() == y + y**2/x + + +def test_fraction(): + x, y, z = map(Symbol, 'xyz') + A = Symbol('A', commutative=False) + + assert fraction(S.Half) == (1, 2) + + assert fraction(x) == (x, 1) + assert fraction(1/x) == (1, x) + assert fraction(x/y) == (x, y) + assert fraction(x/2) == (x, 2) + + assert fraction(x*y/z) == (x*y, z) + assert fraction(x/(y*z)) == (x, y*z) + + assert fraction(1/y**2) == (1, y**2) + assert fraction(x/y**2) == (x, y**2) + + assert fraction((x**2 + 1)/y) == (x**2 + 1, y) + assert fraction(x*(y + 1)/y**7) == (x*(y + 1), y**7) + + assert fraction(exp(-x), exact=True) == (exp(-x), 1) + assert fraction((1/(x + y))/2, exact=True) == (1, Mul(2,(x + y), evaluate=False)) + + assert fraction(x*A/y) == (x*A, y) + assert fraction(x*A**-1/y) == (x*A**-1, y) + + n = symbols('n', negative=True) + assert fraction(exp(n)) == (1, exp(-n)) + assert fraction(exp(-n)) == (exp(-n), 1) + + p = symbols('p', positive=True) + assert fraction(exp(-p)*log(p), exact=True) == (exp(-p)*log(p), 1) + + m = Mul(1, 1, S.Half, evaluate=False) + assert fraction(m) == (1, 2) + assert fraction(m, exact=True) == (Mul(1, 1, evaluate=False), 2) + + m = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False) + assert fraction(m) == (1, 4) + assert fraction(m, exact=True) == \ + (Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False)) + + +def test_issue_5615(): + aA, Re, a, b, D = symbols('aA Re a b D') + e = ((D**3*a + b*aA**3)/Re).expand() + assert collect(e, [aA**3/Re, a]) == e + + +def test_issue_5933(): + from sympy.geometry.polygon import (Polygon, RegularPolygon) + from sympy.simplify.radsimp import denom + x = Polygon(*RegularPolygon((0, 0), 1, 5).vertices).centroid.x + assert abs(denom(x).n()) > 1e-12 + assert abs(denom(radsimp(x))) > 1e-12 # in case simplify didn't handle it + + +def test_issue_14608(): + a, b = symbols('a b', commutative=False) + x, y = symbols('x y') + raises(AttributeError, lambda: collect(a*b + b*a, a)) + assert collect(x*y + y*(x+1), a) == x*y + y*(x+1) + assert collect(x*y + y*(x+1) + a*b + b*a, y) == y*(2*x + 1) + a*b + b*a + + +def test_collect_abs(): + s = abs(x) + abs(y) + assert collect_abs(s) == s + assert unchanged(Mul, abs(x), abs(y)) + ans = Abs(x*y) + assert isinstance(ans, Abs) + assert collect_abs(abs(x)*abs(y)) == ans + assert collect_abs(1 + exp(abs(x)*abs(y))) == 1 + exp(ans) + + # See https://github.com/sympy/sympy/issues/12910 + p = Symbol('p', positive=True) + assert collect_abs(p/abs(1-p)).is_commutative is True + + +def test_issue_19149(): + eq = exp(3*x/4) + assert collect(eq, exp(x)) == eq + +def test_issue_19719(): + a, b = symbols('a, b') + expr = a**2 * (b + 1) + (7 + 1/b)/a + collected = collect(expr, (a**2, 1/a), evaluate=False) + # Would return {_Dummy_20**(-2): b + 1, 1/a: 7 + 1/b} without xreplace + assert collected == {a**2: b + 1, 1/a: 7 + 1/b} + + +def test_issue_21355(): + assert radsimp(1/(x + sqrt(x**2))) == 1/(x + sqrt(x**2)) + assert radsimp(1/(x - sqrt(x**2))) == 1/(x - sqrt(x**2)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_ratsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_ratsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..14e84fd2b227518baff1bda4e5b27ecc40a8bcdd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_ratsimp.py @@ -0,0 +1,78 @@ +from sympy.core.numbers import (Rational, pi) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.error_functions import erf +from sympy.polys.domains import GF +from sympy.simplify.ratsimp import (ratsimp, ratsimpmodprime) + +from sympy.abc import x, y, z, t, a, b, c, d, e + + +def test_ratsimp(): + f, g = 1/x + 1/y, (x + y)/(x*y) + + assert f != g and ratsimp(f) == g + + f, g = 1/(1 + 1/x), 1 - 1/(x + 1) + + assert f != g and ratsimp(f) == g + + f, g = x/(x + y) + y/(x + y), 1 + + assert f != g and ratsimp(f) == g + + f, g = -x - y - y**2/(x + y) + x**2/(x + y), -2*y + + assert f != g and ratsimp(f) == g + + f = (a*c*x*y + a*c*z - b*d*x*y - b*d*z - b*t*x*y - b*t*x - b*t*z + + e*x)/(x*y + z) + G = [a*c - b*d - b*t + (-b*t*x + e*x)/(x*y + z), + a*c - b*d - b*t - ( b*t*x - e*x)/(x*y + z)] + + assert f != g and ratsimp(f) in G + + A = sqrt(pi) + + B = log(erf(x) - 1) + C = log(erf(x) + 1) + + D = 8 - 8*erf(x) + + f = A*B/D - A*C/D + A*C*erf(x)/D - A*B*erf(x)/D + 2*A/D + + assert ratsimp(f) == A*B/8 - A*C/8 - A/(4*erf(x) - 4) + + +def test_ratsimpmodprime(): + a = y**5 + x + y + b = x - y + F = [x*y**5 - x - y] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (-x**2 - x*y - x - y) / (-x**2 + x*y) + + a = x + y**2 - 2 + b = x + y**2 - y - 1 + F = [x*y - 1] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (1 + y - x)/(y - x) + + a = 5*x**3 + 21*x**2 + 4*x*y + 23*x + 12*y + 15 + b = 7*x**3 - y*x**2 + 31*x**2 + 2*x*y + 15*y + 37*x + 21 + F = [x**2 + y**2 - 1] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (1 + 5*y - 5*x)/(8*y - 6*x) + + a = x*y - x - 2*y + 4 + b = x + y**2 - 2*y + F = [x - 2, y - 3] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + Rational(2, 5) + + # Test a bug where denominators would be dropped + assert ratsimpmodprime(x, [y - 2*x], order='lex') == \ + y/2 + + a = (x**5 + 2*x**4 + 2*x**3 + 2*x**2 + x + 2/x + x**(-2)) + assert ratsimpmodprime(a, [x + 1], domain=GF(2)) == 1 + assert ratsimpmodprime(a, [x + 1], domain=GF(3)) == -1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_rewrite.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..56d2fb7a85bd959bd4accc2f36127429efbdbe70 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_rewrite.py @@ -0,0 +1,31 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (cos, cot, sin) +from sympy.testing.pytest import _both_exp_pow + +x, y, z, n = symbols('x,y,z,n') + + +@_both_exp_pow +def test_has(): + assert cot(x).has(x) + assert cot(x).has(cot) + assert not cot(x).has(sin) + assert sin(x).has(x) + assert sin(x).has(sin) + assert not sin(x).has(cot) + assert exp(x).has(exp) + + +@_both_exp_pow +def test_sin_exp_rewrite(): + assert sin(x).rewrite(sin, exp) == -I/2*(exp(I*x) - exp(-I*x)) + assert sin(x).rewrite(sin, exp).rewrite(exp, sin) == sin(x) + assert cos(x).rewrite(cos, exp).rewrite(exp, cos) == cos(x) + assert (sin(5*y) - sin( + 2*x)).rewrite(sin, exp).rewrite(exp, sin) == sin(5*y) - sin(2*x) + assert sin(x + y).rewrite(sin, exp).rewrite(exp, sin) == sin(x + y) + assert cos(x + y).rewrite(cos, exp).rewrite(exp, cos) == cos(x + y) + # This next test currently passes... not clear whether it should or not? + assert cos(x).rewrite(cos, exp).rewrite(exp, sin) == cos(x) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_simplify.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..a5bf469f68adf5c5dfbdf7559414681e2fb28ba7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_simplify.py @@ -0,0 +1,1093 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.expr import unchanged +from sympy.core.function import (count_ops, diff, expand, expand_multinomial, Function, Derivative) +from sympy.core.mul import Mul, _keep_coeff +from sympy.core import GoldenRatio +from sympy.core.numbers import (E, Float, I, oo, pi, Rational, zoo) +from sympy.core.relational import (Eq, Lt, Gt, Ge, Le) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, sign) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import (cosh, csch, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, sinc, tan) +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.geometry.polygon import rad +from sympy.integrals.integrals import (Integral, integrate) +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import (Matrix, eye) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys.polytools import (factor, Poly) +from sympy.simplify.simplify import (besselsimp, hypersimp, inversecombine, logcombine, nsimplify, nthroot, posify, separatevars, signsimp, simplify) +from sympy.solvers.solvers import solve + +from sympy.testing.pytest import XFAIL, slow, _both_exp_pow +from sympy.abc import x, y, z, t, a, b, c, d, e, f, g, h, i, n + + +def test_issue_7263(): + assert abs((simplify(30.8**2 - 82.5**2 * sin(rad(11.6))**2)).evalf() - \ + 673.447451402970) < 1e-12 + + +def test_factorial_simplify(): + # There are more tests in test_factorials.py. + x = Symbol('x') + assert simplify(factorial(x)/x) == gamma(x) + assert simplify(factorial(factorial(x))) == factorial(factorial(x)) + + +def test_simplify_expr(): + x, y, z, k, n, m, w, s, A = symbols('x,y,z,k,n,m,w,s,A') + f = Function('f') + + assert all(simplify(tmp) == tmp for tmp in [I, E, oo, x, -x, -oo, -E, -I]) + + e = 1/x + 1/y + assert e != (x + y)/(x*y) + assert simplify(e) == (x + y)/(x*y) + + e = A**2*s**4/(4*pi*k*m**3) + assert simplify(e) == e + + e = (4 + 4*x - 2*(2 + 2*x))/(2 + 2*x) + assert simplify(e) == 0 + + e = (-4*x*y**2 - 2*y**3 - 2*x**2*y)/(x + y)**2 + assert simplify(e) == -2*y + + e = -x - y - (x + y)**(-1)*y**2 + (x + y)**(-1)*x**2 + assert simplify(e) == -2*y + + e = (x + x*y)/x + assert simplify(e) == 1 + y + + e = (f(x) + y*f(x))/f(x) + assert simplify(e) == 1 + y + + e = (2 * (1/n - cos(n * pi)/n))/pi + assert simplify(e) == (-cos(pi*n) + 1)/(pi*n)*2 + + e = integrate(1/(x**3 + 1), x).diff(x) + assert simplify(e) == 1/(x**3 + 1) + + e = integrate(x/(x**2 + 3*x + 1), x).diff(x) + assert simplify(e) == x/(x**2 + 3*x + 1) + + f = Symbol('f') + A = Matrix([[2*k - m*w**2, -k], [-k, k - m*w**2]]).inv() + assert simplify((A*Matrix([0, f]))[1] - + (-f*(2*k - m*w**2)/(k**2 - (k - m*w**2)*(2*k - m*w**2)))) == 0 + + f = -x + y/(z + t) + z*x/(z + t) + z*a/(z + t) + t*x/(z + t) + assert simplify(f) == (y + a*z)/(z + t) + + # issue 10347 + expr = -x*(y**2 - 1)*(2*y**2*(x**2 - 1)/(a*(x**2 - y**2)**2) + (x**2 - 1) + /(a*(x**2 - y**2)))/(a*(x**2 - y**2)) + x*(-2*x**2*sqrt(-x**2*y**2 + x**2 + + y**2 - 1)*sin(z)/(a*(x**2 - y**2)**2) - x**2*sqrt(-x**2*y**2 + x**2 + + y**2 - 1)*sin(z)/(a*(x**2 - 1)*(x**2 - y**2)) + (x**2*sqrt((-x**2 + 1)* + (y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(x**2 - 1) + sqrt( + (-x**2 + 1)*(y**2 - 1))*(x*(-x*y**2 + x)/sqrt(-x**2*y**2 + x**2 + y**2 - + 1) + sqrt(-x**2*y**2 + x**2 + y**2 - 1))*sin(z))/(a*sqrt((-x**2 + 1)*( + y**2 - 1))*(x**2 - y**2)))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(a* + (x**2 - y**2)) + x*(-2*x**2*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a* + (x**2 - y**2)**2) - x**2*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a* + (x**2 - 1)*(x**2 - y**2)) + (x**2*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2 + *y**2 + x**2 + y**2 - 1)*cos(z)/(x**2 - 1) + x*sqrt((-x**2 + 1)*(y**2 - + 1))*(-x*y**2 + x)*cos(z)/sqrt(-x**2*y**2 + x**2 + y**2 - 1) + sqrt((-x**2 + + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z))/(a*sqrt((-x**2 + + 1)*(y**2 - 1))*(x**2 - y**2)))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos( + z)/(a*(x**2 - y**2)) - y*sqrt((-x**2 + 1)*(y**2 - 1))*(-x*y*sqrt(-x**2* + y**2 + x**2 + y**2 - 1)*sin(z)/(a*(x**2 - y**2)*(y**2 - 1)) + 2*x*y*sqrt( + -x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(a*(x**2 - y**2)**2) + (x*y*sqrt(( + -x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(y**2 - + 1) + x*sqrt((-x**2 + 1)*(y**2 - 1))*(-x**2*y + y)*sin(z)/sqrt(-x**2*y**2 + + x**2 + y**2 - 1))/(a*sqrt((-x**2 + 1)*(y**2 - 1))*(x**2 - y**2)))*sin( + z)/(a*(x**2 - y**2)) + y*(x**2 - 1)*(-2*x*y*(x**2 - 1)/(a*(x**2 - y**2) + **2) + 2*x*y/(a*(x**2 - y**2)))/(a*(x**2 - y**2)) + y*(x**2 - 1)*(y**2 - + 1)*(-x*y*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a*(x**2 - y**2)*(y**2 + - 1)) + 2*x*y*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a*(x**2 - y**2) + **2) + (x*y*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - + 1)*cos(z)/(y**2 - 1) + x*sqrt((-x**2 + 1)*(y**2 - 1))*(-x**2*y + y)*cos( + z)/sqrt(-x**2*y**2 + x**2 + y**2 - 1))/(a*sqrt((-x**2 + 1)*(y**2 - 1) + )*(x**2 - y**2)))*cos(z)/(a*sqrt((-x**2 + 1)*(y**2 - 1))*(x**2 - y**2) + ) - x*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin( + z)**2/(a**2*(x**2 - 1)*(x**2 - y**2)*(y**2 - 1)) - x*sqrt((-x**2 + 1)*( + y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)**2/(a**2*(x**2 - 1)*( + x**2 - y**2)*(y**2 - 1)) + assert simplify(expr) == 2*x/(a**2*(x**2 - y**2)) + + #issue 17631 + assert simplify('((-1/2)*Boole(True)*Boole(False)-1)*Boole(True)') == \ + Mul(sympify('(2 + Boole(True)*Boole(False))'), sympify('-Boole(True)/2')) + + A, B = symbols('A,B', commutative=False) + + assert simplify(A*B - B*A) == A*B - B*A + assert simplify(A/(1 + y/x)) == x*A/(x + y) + assert simplify(A*(1/x + 1/y)) == A/x + A/y #(x + y)*A/(x*y) + + assert simplify(log(2) + log(3)) == log(6) + assert simplify(log(2*x) - log(2)) == log(x) + + assert simplify(hyper([], [], x)) == exp(x) + + +def test_issue_3557(): + f_1 = x*a + y*b + z*c - 1 + f_2 = x*d + y*e + z*f - 1 + f_3 = x*g + y*h + z*i - 1 + + solutions = solve([f_1, f_2, f_3], x, y, z, simplify=False) + + assert simplify(solutions[y]) == \ + (a*i + c*d + f*g - a*f - c*g - d*i)/ \ + (a*e*i + b*f*g + c*d*h - a*f*h - b*d*i - c*e*g) + + +def test_simplify_other(): + assert simplify(sin(x)**2 + cos(x)**2) == 1 + assert simplify(gamma(x + 1)/gamma(x)) == x + assert simplify(sin(x)**2 + cos(x)**2 + factorial(x)/gamma(x)) == 1 + x + assert simplify( + Eq(sin(x)**2 + cos(x)**2, factorial(x)/gamma(x))) == Eq(x, 1) + nc = symbols('nc', commutative=False) + assert simplify(x + x*nc) == x*(1 + nc) + # issue 6123 + # f = exp(-I*(k*sqrt(t) + x/(2*sqrt(t)))**2) + # ans = integrate(f, (k, -oo, oo), conds='none') + ans = I*(-pi*x*exp(I*pi*Rational(-3, 4) + I*x**2/(4*t))*erf(x*exp(I*pi*Rational(-3, 4))/ + (2*sqrt(t)))/(2*sqrt(t)) + pi*x*exp(I*pi*Rational(-3, 4) + I*x**2/(4*t))/ + (2*sqrt(t)))*exp(-I*x**2/(4*t))/(sqrt(pi)*x) - I*sqrt(pi) * \ + (-erf(x*exp(I*pi/4)/(2*sqrt(t))) + 1)*exp(I*pi/4)/(2*sqrt(t)) + assert simplify(ans) == -(-1)**Rational(3, 4)*sqrt(pi)/sqrt(t) + # issue 6370 + assert simplify(2**(2 + x)/4) == 2**x + + +@_both_exp_pow +def test_simplify_complex(): + cosAsExp = cos(x)._eval_rewrite_as_exp(x) + tanAsExp = tan(x)._eval_rewrite_as_exp(x) + assert simplify(cosAsExp*tanAsExp) == sin(x) # issue 4341 + + # issue 10124 + assert simplify(exp(Matrix([[0, -1], [1, 0]]))) == Matrix([[cos(1), + -sin(1)], [sin(1), cos(1)]]) + + +def test_simplify_ratio(): + # roots of x**3-3*x+5 + roots = ['(1/2 - sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3) + 1/((1/2 - ' + 'sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3))', + '1/((1/2 + sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3)) + ' + '(1/2 + sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3)', + '-(sqrt(21)/2 + 5/2)**(1/3) - 1/(sqrt(21)/2 + 5/2)**(1/3)'] + + for r in roots: + r = S(r) + assert count_ops(simplify(r, ratio=1)) <= count_ops(r) + # If ratio=oo, simplify() is always applied: + assert simplify(r, ratio=oo) is not r + + +def test_simplify_measure(): + measure1 = lambda expr: len(str(expr)) + measure2 = lambda expr: -count_ops(expr) + # Return the most complicated result + expr = (x + 1)/(x + sin(x)**2 + cos(x)**2) + assert measure1(simplify(expr, measure=measure1)) <= measure1(expr) + assert measure2(simplify(expr, measure=measure2)) <= measure2(expr) + + expr2 = Eq(sin(x)**2 + cos(x)**2, 1) + assert measure1(simplify(expr2, measure=measure1)) <= measure1(expr2) + assert measure2(simplify(expr2, measure=measure2)) <= measure2(expr2) + + +def test_simplify_rational(): + expr = 2**x*2.**y + assert simplify(expr, rational = True) == 2**(x+y) + assert simplify(expr, rational = None) == 2.0**(x+y) + assert simplify(expr, rational = False) == expr + assert simplify('0.9 - 0.8 - 0.1', rational = True) == 0 + + +def test_simplify_issue_1308(): + assert simplify(exp(Rational(-1, 2)) + exp(Rational(-3, 2))) == \ + (1 + E)*exp(Rational(-3, 2)) + + +def test_issue_5652(): + assert simplify(E + exp(-E)) == exp(-E) + E + n = symbols('n', commutative=False) + assert simplify(n + n**(-n)) == n + n**(-n) + +def test_issue_27380(): + assert simplify(1.0**(x+1)/1.0**x) == 1.0 + +def test_simplify_fail1(): + x = Symbol('x') + y = Symbol('y') + e = (x + y)**2/(-4*x*y**2 - 2*y**3 - 2*x**2*y) + assert simplify(e) == 1 / (-2*y) + + +def test_nthroot(): + assert nthroot(90 + 34*sqrt(7), 3) == sqrt(7) + 3 + q = 1 + sqrt(2) - 2*sqrt(3) + sqrt(6) + sqrt(7) + assert nthroot(expand_multinomial(q**3), 3) == q + assert nthroot(41 + 29*sqrt(2), 5) == 1 + sqrt(2) + assert nthroot(-41 - 29*sqrt(2), 5) == -1 - sqrt(2) + expr = 1320*sqrt(10) + 4216 + 2576*sqrt(6) + 1640*sqrt(15) + assert nthroot(expr, 5) == 1 + sqrt(6) + sqrt(15) + q = 1 + sqrt(2) + sqrt(3) + sqrt(5) + assert expand_multinomial(nthroot(expand_multinomial(q**5), 5)) == q + q = 1 + sqrt(2) + 7*sqrt(6) + 2*sqrt(10) + assert nthroot(expand_multinomial(q**5), 5, 8) == q + q = 1 + sqrt(2) - 2*sqrt(3) + 1171*sqrt(6) + assert nthroot(expand_multinomial(q**3), 3) == q + assert nthroot(expand_multinomial(q**6), 6) == q + + +def test_nthroot1(): + q = 1 + sqrt(2) + sqrt(3) + S.One/10**20 + p = expand_multinomial(q**5) + assert nthroot(p, 5) == q + q = 1 + sqrt(2) + sqrt(3) + S.One/10**30 + p = expand_multinomial(q**5) + assert nthroot(p, 5) == q + + +@_both_exp_pow +def test_separatevars(): + x, y, z, n = symbols('x,y,z,n') + assert separatevars(2*n*x*z + 2*x*y*z) == 2*x*z*(n + y) + assert separatevars(x*z + x*y*z) == x*z*(1 + y) + assert separatevars(pi*x*z + pi*x*y*z) == pi*x*z*(1 + y) + assert separatevars(x*y**2*sin(x) + x*sin(x)*sin(y)) == \ + x*(sin(y) + y**2)*sin(x) + assert separatevars(x*exp(x + y) + x*exp(x)) == x*(1 + exp(y))*exp(x) + assert separatevars((x*(y + 1))**z).is_Pow # != x**z*(1 + y)**z + assert separatevars(1 + x + y + x*y) == (x + 1)*(y + 1) + assert separatevars(y/pi*exp(-(z - x)/cos(n))) == \ + y*exp(x/cos(n))*exp(-z/cos(n))/pi + assert separatevars((x + y)*(x - y) + y**2 + 2*x + 1) == (x + 1)**2 + # issue 4858 + p = Symbol('p', positive=True) + assert separatevars(sqrt(p**2 + x*p**2)) == p*sqrt(1 + x) + assert separatevars(sqrt(y*(p**2 + x*p**2))) == p*sqrt(y*(1 + x)) + assert separatevars(sqrt(y*(p**2 + x*p**2)), force=True) == \ + p*sqrt(y)*sqrt(1 + x) + # issue 4865 + assert separatevars(sqrt(x*y)).is_Pow + assert separatevars(sqrt(x*y), force=True) == sqrt(x)*sqrt(y) + # issue 4957 + # any type sequence for symbols is fine + assert separatevars(((2*x + 2)*y), dict=True, symbols=()) == \ + {'coeff': 1, x: 2*x + 2, y: y} + # separable + assert separatevars(((2*x + 2)*y), dict=True, symbols=[x]) == \ + {'coeff': y, x: 2*x + 2} + assert separatevars(((2*x + 2)*y), dict=True, symbols=[]) == \ + {'coeff': 1, x: 2*x + 2, y: y} + assert separatevars(((2*x + 2)*y), dict=True) == \ + {'coeff': 1, x: 2*x + 2, y: y} + assert separatevars(((2*x + 2)*y), dict=True, symbols=None) == \ + {'coeff': y*(2*x + 2)} + # not separable + assert separatevars(3, dict=True) is None + assert separatevars(2*x + y, dict=True, symbols=()) is None + assert separatevars(2*x + y, dict=True) is None + assert separatevars(2*x + y, dict=True, symbols=None) == {'coeff': 2*x + y} + # issue 4808 + n, m = symbols('n,m', commutative=False) + assert separatevars(m + n*m) == (1 + n)*m + assert separatevars(x + x*n) == x*(1 + n) + # issue 4910 + f = Function('f') + assert separatevars(f(x) + x*f(x)) == f(x) + x*f(x) + # a noncommutable object present + eq = x*(1 + hyper((), (), y*z)) + assert separatevars(eq) == eq + + s = separatevars(abs(x*y)) + assert s == abs(x)*abs(y) and s.is_Mul + z = cos(1)**2 + sin(1)**2 - 1 + a = abs(x*z) + s = separatevars(a) + assert not a.is_Mul and s.is_Mul and s == abs(x)*abs(z) + s = separatevars(abs(x*y*z)) + assert s == abs(x)*abs(y)*abs(z) + + # abs(x+y)/abs(z) would be better but we test this here to + # see that it doesn't raise + assert separatevars(abs((x+y)/z)) == abs((x+y)/z) + + +def test_separatevars_advanced_factor(): + x, y, z = symbols('x,y,z') + assert separatevars(1 + log(x)*log(y) + log(x) + log(y)) == \ + (log(x) + 1)*(log(y) + 1) + assert separatevars(1 + x - log(z) - x*log(z) - exp(y)*log(z) - + x*exp(y)*log(z) + x*exp(y) + exp(y)) == \ + -((x + 1)*(log(z) - 1)*(exp(y) + 1)) + x, y = symbols('x,y', positive=True) + assert separatevars(1 + log(x**log(y)) + log(x*y)) == \ + (log(x) + 1)*(log(y) + 1) + + +def test_hypersimp(): + n, k = symbols('n,k', integer=True) + + assert hypersimp(factorial(k), k) == k + 1 + assert hypersimp(factorial(k**2), k) is None + + assert hypersimp(1/factorial(k), k) == 1/(k + 1) + + assert hypersimp(2**k/factorial(k)**2, k) == 2/(k + 1)**2 + + assert hypersimp(binomial(n, k), k) == (n - k)/(k + 1) + assert hypersimp(binomial(n + 1, k), k) == (n - k + 1)/(k + 1) + + term = (4*k + 1)*factorial(k)/factorial(2*k + 1) + assert hypersimp(term, k) == S.Half*((4*k + 5)/(3 + 14*k + 8*k**2)) + + term = 1/((2*k - 1)*factorial(2*k + 1)) + assert hypersimp(term, k) == (k - S.Half)/((k + 1)*(2*k + 1)*(2*k + 3)) + + term = binomial(n, k)*(-1)**k/factorial(k) + assert hypersimp(term, k) == (k - n)/(k + 1)**2 + + +def test_nsimplify(): + x = Symbol("x") + assert nsimplify(0) == 0 + assert nsimplify(-1) == -1 + assert nsimplify(1) == 1 + assert nsimplify(1 + x) == 1 + x + assert nsimplify(2.7) == Rational(27, 10) + assert nsimplify(1 - GoldenRatio) == (1 - sqrt(5))/2 + assert nsimplify((1 + sqrt(5))/4, [GoldenRatio]) == GoldenRatio/2 + assert nsimplify(2/GoldenRatio, [GoldenRatio]) == 2*GoldenRatio - 2 + assert nsimplify(exp(pi*I*Rational(5, 3), evaluate=False)) == \ + sympify('1/2 - sqrt(3)*I/2') + assert nsimplify(sin(pi*Rational(3, 5), evaluate=False)) == \ + sympify('sqrt(sqrt(5)/8 + 5/8)') + assert nsimplify(sqrt(atan('1', evaluate=False))*(2 + I), [pi]) == \ + sqrt(pi) + sqrt(pi)/2*I + assert nsimplify(2 + exp(2*atan('1/4')*I)) == sympify('49/17 + 8*I/17') + assert nsimplify(pi, tolerance=0.01) == Rational(22, 7) + assert nsimplify(pi, tolerance=0.001) == Rational(355, 113) + assert nsimplify(0.33333, tolerance=1e-4) == Rational(1, 3) + assert nsimplify(2.0**(1/3.), tolerance=0.001) == Rational(635, 504) + assert nsimplify(2.0**(1/3.), tolerance=0.001, full=True) == \ + 2**Rational(1, 3) + assert nsimplify(x + .5, rational=True) == S.Half + x + assert nsimplify(1/.3 + x, rational=True) == Rational(10, 3) + x + assert nsimplify(log(3).n(), rational=True) == \ + sympify('109861228866811/100000000000000') + assert nsimplify(Float(0.272198261287950), [pi, log(2)]) == pi*log(2)/8 + assert nsimplify(Float(0.272198261287950).n(3), [pi, log(2)]) == \ + -pi/4 - log(2) + Rational(7, 4) + assert nsimplify(x/7.0) == x/7 + assert nsimplify(pi/1e2) == pi/100 + assert nsimplify(pi/1e2, rational=False) == pi/100.0 + assert nsimplify(pi/1e-7) == 10000000*pi + assert not nsimplify( + factor(-3.0*z**2*(z**2)**(-2.5) + 3*(z**2)**(-1.5))).atoms(Float) + e = x**0.0 + assert e.is_Pow and nsimplify(x**0.0) == 1 + assert nsimplify(3.333333, tolerance=0.1, rational=True) == Rational(10, 3) + assert nsimplify(3.333333, tolerance=0.01, rational=True) == Rational(10, 3) + assert nsimplify(3.666666, tolerance=0.1, rational=True) == Rational(11, 3) + assert nsimplify(3.666666, tolerance=0.01, rational=True) == Rational(11, 3) + assert nsimplify(33, tolerance=10, rational=True) == Rational(33) + assert nsimplify(33.33, tolerance=10, rational=True) == Rational(30) + assert nsimplify(37.76, tolerance=10, rational=True) == Rational(40) + assert nsimplify(-203.1) == Rational(-2031, 10) + assert nsimplify(.2, tolerance=0) == Rational(1, 5) + assert nsimplify(-.2, tolerance=0) == Rational(-1, 5) + assert nsimplify(.2222, tolerance=0) == Rational(1111, 5000) + assert nsimplify(-.2222, tolerance=0) == Rational(-1111, 5000) + # issue 7211, PR 4112 + assert nsimplify(S(2e-8)) == Rational(1, 50000000) + # issue 7322 direct test + assert nsimplify(1e-42, rational=True) != 0 + # issue 10336 + inf = Float('inf') + infs = (-oo, oo, inf, -inf) + for zi in infs: + ans = sign(zi)*oo + assert nsimplify(zi) == ans + assert nsimplify(zi + x) == x + ans + + assert nsimplify(0.33333333, rational=True, rational_conversion='exact') == Rational(0.33333333) + + # Make sure nsimplify on expressions uses full precision + assert nsimplify(pi.evalf(100)*x, rational_conversion='exact').evalf(100) == pi.evalf(100)*x + + +def test_issue_9448(): + tmp = sympify("1/(1 - (-1)**(2/3) - (-1)**(1/3)) + 1/(1 + (-1)**(2/3) + (-1)**(1/3))") + assert nsimplify(tmp) == S.Half + + +def test_extract_minus_sign(): + x = Symbol("x") + y = Symbol("y") + a = Symbol("a") + b = Symbol("b") + assert simplify(-x/-y) == x/y + assert simplify(-x/y) == -x/y + assert simplify(x/y) == x/y + assert simplify(x/-y) == -x/y + assert simplify(-x/0) == zoo*x + assert simplify(Rational(-5, 0)) is zoo + assert simplify(-a*x/(-y - b)) == a*x/(b + y) + + +def test_diff(): + x = Symbol("x") + y = Symbol("y") + f = Function("f") + g = Function("g") + assert simplify(g(x).diff(x)*f(x).diff(x) - f(x).diff(x)*g(x).diff(x)) == 0 + assert simplify(2*f(x)*f(x).diff(x) - diff(f(x)**2, x)) == 0 + assert simplify(diff(1/f(x), x) + f(x).diff(x)/f(x)**2) == 0 + assert simplify(f(x).diff(x, y) - f(x).diff(y, x)) == 0 + + +def test_logcombine_1(): + x, y = symbols("x,y") + a = Symbol("a") + z, w = symbols("z,w", positive=True) + b = Symbol("b", real=True) + assert logcombine(log(x) + 2*log(y)) == log(x) + 2*log(y) + assert logcombine(log(x) + 2*log(y), force=True) == log(x*y**2) + assert logcombine(a*log(w) + log(z)) == a*log(w) + log(z) + assert logcombine(b*log(z) + b*log(x)) == log(z**b) + b*log(x) + assert logcombine(b*log(z) - log(w)) == log(z**b/w) + assert logcombine(log(x)*log(z)) == log(x)*log(z) + assert logcombine(log(w)*log(x)) == log(w)*log(x) + assert logcombine(cos(-2*log(z) + b*log(w))) in [cos(log(w**b/z**2)), + cos(log(z**2/w**b))] + assert logcombine(log(log(x) - log(y)) - log(z), force=True) == \ + log(log(x/y)/z) + assert logcombine((2 + I)*log(x), force=True) == (2 + I)*log(x) + assert logcombine((x**2 + log(x) - log(y))/(x*y), force=True) == \ + (x**2 + log(x/y))/(x*y) + # the following could also give log(z*x**log(y**2)), what we + # are testing is that a canonical result is obtained + assert logcombine(log(x)*2*log(y) + log(z), force=True) == \ + log(z*y**log(x**2)) + assert logcombine((x*y + sqrt(x**4 + y**4) + log(x) - log(y))/(pi*x**Rational(2, 3)* + sqrt(y)**3), force=True) == ( + x*y + sqrt(x**4 + y**4) + log(x/y))/(pi*x**Rational(2, 3)*y**Rational(3, 2)) + assert logcombine(gamma(-log(x/y))*acos(-log(x/y)), force=True) == \ + acos(-log(x/y))*gamma(-log(x/y)) + + assert logcombine(2*log(z)*log(w)*log(x) + log(z) + log(w)) == \ + log(z**log(w**2))*log(x) + log(w*z) + assert logcombine(3*log(w) + 3*log(z)) == log(w**3*z**3) + assert logcombine(x*(y + 1) + log(2) + log(3)) == x*(y + 1) + log(6) + assert logcombine((x + y)*log(w) + (-x - y)*log(3)) == (x + y)*log(w/3) + # a single unknown can combine + assert logcombine(log(x) + log(2)) == log(2*x) + eq = log(abs(x)) + log(abs(y)) + assert logcombine(eq) == eq + reps = {x: 0, y: 0} + assert log(abs(x)*abs(y)).subs(reps) != eq.subs(reps) + + +def test_logcombine_complex_coeff(): + i = Integral((sin(x**2) + cos(x**3))/x, x) + assert logcombine(i, force=True) == i + assert logcombine(i + 2*log(x), force=True) == \ + i + log(x**2) + + +def test_issue_5950(): + x, y = symbols("x,y", positive=True) + assert logcombine(log(3) - log(2)) == log(Rational(3,2), evaluate=False) + assert logcombine(log(x) - log(y)) == log(x/y) + assert logcombine(log(Rational(3,2), evaluate=False) - log(2)) == \ + log(Rational(3,4), evaluate=False) + + +def test_posify(): + x = symbols('x') + + assert str(posify( + x + + Symbol('p', positive=True) + + Symbol('n', negative=True))) == '(_x + n + p, {_x: x})' + + eq, rep = posify(1/x) + assert log(eq).expand().subs(rep) == -log(x) + assert str(posify([x, 1 + x])) == '([_x, _x + 1], {_x: x})' + + p = symbols('p', positive=True) + n = symbols('n', negative=True) + orig = [x, n, p] + modified, reps = posify(orig) + assert str(modified) == '[_x, n, p]' + assert [w.subs(reps) for w in modified] == orig + + assert str(Integral(posify(1/x + y)[0], (y, 1, 3)).expand()) == \ + 'Integral(1/_x, (y, 1, 3)) + Integral(_y, (y, 1, 3))' + assert str(Sum(posify(1/x**n)[0], (n,1,3)).expand()) == \ + 'Sum(_x**(-n), (n, 1, 3))' + + A = Matrix([[1, 2, 3], [4, 5, 6 * Abs(x)]]) + Ap, rep = posify(A) + assert Ap == A.subs(*reversed(rep.popitem())) + + # issue 16438 + k = Symbol('k', finite=True) + eq, rep = posify(k) + assert eq.assumptions0 == {'positive': True, 'zero': False, 'imaginary': False, + 'nonpositive': False, 'commutative': True, 'hermitian': True, 'real': True, 'nonzero': True, + 'nonnegative': True, 'negative': False, 'complex': True, 'finite': True, + 'infinite': False, 'extended_real':True, 'extended_negative': False, + 'extended_nonnegative': True, 'extended_nonpositive': False, + 'extended_nonzero': True, 'extended_positive': True} + + +def test_issue_4194(): + # simplify should call cancel + f = Function('f') + assert simplify((4*x + 6*f(y))/(2*x + 3*f(y))) == 2 + + +@XFAIL +def test_simplify_float_vs_integer(): + # Test for issue 4473: + # https://github.com/sympy/sympy/issues/4473 + assert simplify(x**2.0 - x**2) == 0 + assert simplify(x**2 - x**2.0) == 0 + + +def test_as_content_primitive(): + assert (x/2 + y).as_content_primitive() == (S.Half, x + 2*y) + assert (x/2 + y).as_content_primitive(clear=False) == (S.One, x/2 + y) + assert (y*(x/2 + y)).as_content_primitive() == (S.Half, y*(x + 2*y)) + assert (y*(x/2 + y)).as_content_primitive(clear=False) == (S.One, y*(x/2 + y)) + + # although the _as_content_primitive methods do not alter the underlying structure, + # the as_content_primitive function will touch up the expression and join + # bases that would otherwise have not been joined. + assert (x*(2 + 2*x)*(3*x + 3)**2).as_content_primitive() == \ + (18, x*(x + 1)**3) + assert (2 + 2*x + 2*y*(3 + 3*y)).as_content_primitive() == \ + (2, x + 3*y*(y + 1) + 1) + assert ((2 + 6*x)**2).as_content_primitive() == \ + (4, (3*x + 1)**2) + assert ((2 + 6*x)**(2*y)).as_content_primitive() == \ + (1, (_keep_coeff(S(2), (3*x + 1)))**(2*y)) + assert (5 + 10*x + 2*y*(3 + 3*y)).as_content_primitive() == \ + (1, 10*x + 6*y*(y + 1) + 5) + assert (5*(x*(1 + y)) + 2*x*(3 + 3*y)).as_content_primitive() == \ + (11, x*(y + 1)) + assert ((5*(x*(1 + y)) + 2*x*(3 + 3*y))**2).as_content_primitive() == \ + (121, x**2*(y + 1)**2) + assert (y**2).as_content_primitive() == \ + (1, y**2) + assert (S.Infinity).as_content_primitive() == (1, oo) + eq = x**(2 + y) + assert (eq).as_content_primitive() == (1, eq) + assert (S.Half**(2 + x)).as_content_primitive() == (Rational(1, 4), 2**-x) + assert (Rational(-1, 2)**(2 + x)).as_content_primitive() == \ + (Rational(1, 4), (Rational(-1, 2))**x) + assert (Rational(-1, 2)**(2 + x)).as_content_primitive() == \ + (Rational(1, 4), Rational(-1, 2)**x) + assert (4**((1 + y)/2)).as_content_primitive() == (2, 4**(y/2)) + assert (3**((1 + y)/2)).as_content_primitive() == \ + (1, 3**(Mul(S.Half, 1 + y, evaluate=False))) + assert (5**Rational(3, 4)).as_content_primitive() == (1, 5**Rational(3, 4)) + assert (5**Rational(7, 4)).as_content_primitive() == (5, 5**Rational(3, 4)) + assert Add(z*Rational(5, 7), 0.5*x, y*Rational(3, 2), evaluate=False).as_content_primitive() == \ + (Rational(1, 14), 7.0*x + 21*y + 10*z) + assert (2**Rational(3, 4) + 2**Rational(1, 4)*sqrt(3)).as_content_primitive(radical=True) == \ + (1, 2**Rational(1, 4)*(sqrt(2) + sqrt(3))) + + +def test_signsimp(): + e = x*(-x + 1) + x*(x - 1) + assert signsimp(Eq(e, 0)) is S.true + assert Abs(x - 1) == Abs(1 - x) + assert signsimp(y - x) == y - x + assert signsimp(y - x, evaluate=False) == Mul(-1, x - y, evaluate=False) + + +def test_besselsimp(): + from sympy.functions.special.bessel import (besseli, besselj, bessely) + from sympy.integrals.transforms import cosine_transform + assert besselsimp(exp(-I*pi*y/2)*besseli(y, z*exp_polar(I*pi/2))) == \ + besselj(y, z) + assert besselsimp(exp(-I*pi*a/2)*besseli(a, 2*sqrt(x)*exp_polar(I*pi/2))) == \ + besselj(a, 2*sqrt(x)) + assert besselsimp(sqrt(2)*sqrt(pi)*x**Rational(1, 4)*exp(I*pi/4)*exp(-I*pi*a/2) * + besseli(Rational(-1, 2), sqrt(x)*exp_polar(I*pi/2)) * + besseli(a, sqrt(x)*exp_polar(I*pi/2))/2) == \ + besselj(a, sqrt(x)) * cos(sqrt(x)) + assert besselsimp(besseli(Rational(-1, 2), z)) == \ + sqrt(2)*cosh(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(besseli(a, z*exp_polar(-I*pi/2))) == \ + exp(-I*pi*a/2)*besselj(a, z) + assert cosine_transform(1/t*sin(a/t), t, y) == \ + sqrt(2)*sqrt(pi)*besselj(0, 2*sqrt(a)*sqrt(y))/2 + + assert besselsimp(x**2*(a*(-2*besselj(5*I, x) + besselj(-2 + 5*I, x) + + besselj(2 + 5*I, x)) + b*(-2*bessely(5*I, x) + bessely(-2 + 5*I, x) + + bessely(2 + 5*I, x)))/4 + x*(a*(besselj(-1 + 5*I, x)/2 - besselj(1 + 5*I, x)/2) + + b*(bessely(-1 + 5*I, x)/2 - bessely(1 + 5*I, x)/2)) + (x**2 + 25)*(a*besselj(5*I, x) + + b*bessely(5*I, x))) == 0 + + assert besselsimp(81*x**2*(a*(besselj(Rational(-5, 3), 9*x) - 2*besselj(Rational(1, 3), 9*x) + besselj(Rational(7, 3), 9*x)) + + b*(bessely(Rational(-5, 3), 9*x) - 2*bessely(Rational(1, 3), 9*x) + bessely(Rational(7, 3), 9*x)))/4 + x*(a*(9*besselj(Rational(-2, 3), 9*x)/2 + - 9*besselj(Rational(4, 3), 9*x)/2) + b*(9*bessely(Rational(-2, 3), 9*x)/2 - 9*bessely(Rational(4, 3), 9*x)/2)) + + (81*x**2 - Rational(1, 9))*(a*besselj(Rational(1, 3), 9*x) + b*bessely(Rational(1, 3), 9*x))) == 0 + + assert besselsimp(besselj(a-1,x) + besselj(a+1, x) - 2*a*besselj(a, x)/x) == 0 + + assert besselsimp(besselj(a-1,x) + besselj(a+1, x) + besselj(a, x)) == (2*a + x)*besselj(a, x)/x + + assert besselsimp(x**2* besselj(a,x) + x**3*besselj(a+1, x) + besselj(a+2, x)) == \ + 2*a*x*besselj(a + 1, x) + x**3*besselj(a + 1, x) - x**2*besselj(a + 2, x) + 2*x*besselj(a + 1, x) + besselj(a + 2, x) + +def test_Piecewise(): + e1 = x*(x + y) - y*(x + y) + e2 = sin(x)**2 + cos(x)**2 + e3 = expand((x + y)*y/x) + s1 = simplify(e1) + s2 = simplify(e2) + s3 = simplify(e3) + assert simplify(Piecewise((e1, x < e2), (e3, True))) == \ + Piecewise((s1, x < s2), (s3, True)) + + +def test_polymorphism(): + class A(Basic): + def _eval_simplify(x, **kwargs): + return S.One + + a = A(S(5), S(2)) + assert simplify(a) == 1 + + +def test_issue_from_PR1599(): + n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True) + assert simplify(I*sqrt(n1)) == -sqrt(-n1) + + +def test_issue_6811(): + eq = (x + 2*y)*(2*x + 2) + assert simplify(eq) == (x + 1)*(x + 2*y)*2 + # reject the 2-arg Mul -- these are a headache for test writing + assert simplify(eq.expand()) == \ + 2*x**2 + 4*x*y + 2*x + 4*y + + +def test_issue_6920(): + e = [cos(x) + I*sin(x), cos(x) - I*sin(x), + cosh(x) - sinh(x), cosh(x) + sinh(x)] + ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)] + # wrap in f to show that the change happens wherever ei occurs + f = Function('f') + assert [simplify(f(ei)).args[0] for ei in e] == ok + + +def test_issue_7001(): + from sympy.abc import r, R + assert simplify(-(r*Piecewise((pi*Rational(4, 3), r <= R), + (-8*pi*R**3/(3*r**3), True)) + 2*Piecewise((pi*r*Rational(4, 3), r <= R), + (4*pi*R**3/(3*r**2), True)))/(4*pi*r)) == \ + Piecewise((-1, r <= R), (0, True)) + + +def test_inequality_no_auto_simplify(): + # no simplify on creation but can be simplified + lhs = cos(x)**2 + sin(x)**2 + rhs = 2 + e = Lt(lhs, rhs, evaluate=False) + assert e is not S.true + assert simplify(e) + + +def test_issue_9398(): + from sympy.core.numbers import Number + from sympy.polys.polytools import cancel + assert cancel(1e-14) != 0 + assert cancel(1e-14*I) != 0 + + assert simplify(1e-14) != 0 + assert simplify(1e-14*I) != 0 + + assert (I*Number(1.)*Number(10)**Number(-14)).simplify() != 0 + + assert cancel(1e-20) != 0 + assert cancel(1e-20*I) != 0 + + assert simplify(1e-20) != 0 + assert simplify(1e-20*I) != 0 + + assert cancel(1e-100) != 0 + assert cancel(1e-100*I) != 0 + + assert simplify(1e-100) != 0 + assert simplify(1e-100*I) != 0 + + f = Float("1e-1000") + assert cancel(f) != 0 + assert cancel(f*I) != 0 + + assert simplify(f) != 0 + assert simplify(f*I) != 0 + + +def test_issue_9324_simplify(): + M = MatrixSymbol('M', 10, 10) + e = M[0, 0] + M[5, 4] + 1304 + assert simplify(e) == e + + +def test_issue_9817_simplify(): + # simplify on trace of substituted explicit quadratic form of matrix + # expressions (a scalar) should return without errors (AttributeError) + # See issue #9817 and #9190 for the original bug more discussion on this + from sympy.matrices.expressions import Identity, trace + v = MatrixSymbol('v', 3, 1) + A = MatrixSymbol('A', 3, 3) + x = Matrix([i + 1 for i in range(3)]) + X = Identity(3) + quadratic = v.T * A * v + assert simplify((trace(quadratic.as_explicit())).xreplace({v:x, A:X})) == 14 + + +def test_issue_13474(): + x = Symbol('x') + assert simplify(x + csch(sinc(1))) == x + csch(sinc(1)) + + +@_both_exp_pow +def test_simplify_function_inverse(): + # "inverse" attribute does not guarantee that f(g(x)) is x + # so this simplification should not happen automatically. + # See issue #12140 + x, y = symbols('x, y') + g = Function('g') + + class f(Function): + def inverse(self, argindex=1): + return g + + assert simplify(f(g(x))) == f(g(x)) + assert inversecombine(f(g(x))) == x + assert simplify(f(g(x)), inverse=True) == x + assert simplify(f(g(sin(x)**2 + cos(x)**2)), inverse=True) == 1 + assert simplify(f(g(x, y)), inverse=True) == f(g(x, y)) + assert unchanged(asin, sin(x)) + assert simplify(asin(sin(x))) == asin(sin(x)) + assert simplify(2*asin(sin(3*x)), inverse=True) == 6*x + assert simplify(log(exp(x))) == log(exp(x)) + assert simplify(log(exp(x)), inverse=True) == x + assert simplify(exp(log(x)), inverse=True) == x + assert simplify(log(exp(x), 2), inverse=True) == x/log(2) + assert simplify(log(exp(x), 2, evaluate=False), inverse=True) == x/log(2) + + +def test_clear_coefficients(): + from sympy.simplify.simplify import clear_coefficients + assert clear_coefficients(4*y*(6*x + 3)) == (y*(2*x + 1), 0) + assert clear_coefficients(4*y*(6*x + 3) - 2) == (y*(2*x + 1), Rational(1, 6)) + assert clear_coefficients(4*y*(6*x + 3) - 2, x) == (y*(2*x + 1), x/12 + Rational(1, 6)) + assert clear_coefficients(sqrt(2) - 2) == (sqrt(2), 2) + assert clear_coefficients(4*sqrt(2) - 2) == (sqrt(2), S.Half) + assert clear_coefficients(S(3), x) == (0, x - 3) + assert clear_coefficients(S.Infinity, x) == (S.Infinity, x) + assert clear_coefficients(-S.Pi, x) == (S.Pi, -x) + assert clear_coefficients(2 - S.Pi/3, x) == (pi, -3*x + 6) + +def test_nc_simplify(): + from sympy.simplify.simplify import nc_simplify + from sympy.matrices.expressions import MatPow, Identity + from sympy.core import Pow + from functools import reduce + + a, b, c, d = symbols('a b c d', commutative = False) + x = Symbol('x') + A = MatrixSymbol("A", x, x) + B = MatrixSymbol("B", x, x) + C = MatrixSymbol("C", x, x) + D = MatrixSymbol("D", x, x) + subst = {a: A, b: B, c: C, d:D} + funcs = {Add: lambda x,y: x+y, Mul: lambda x,y: x*y } + + def _to_matrix(expr): + if expr in subst: + return subst[expr] + if isinstance(expr, Pow): + return MatPow(_to_matrix(expr.args[0]), expr.args[1]) + elif isinstance(expr, (Add, Mul)): + return reduce(funcs[expr.func],[_to_matrix(a) for a in expr.args]) + else: + return expr*Identity(x) + + def _check(expr, simplified, deep=True, matrix=True): + assert nc_simplify(expr, deep=deep) == simplified + assert expand(expr) == expand(simplified) + if matrix: + m_simp = _to_matrix(simplified).doit(inv_expand=False) + assert nc_simplify(_to_matrix(expr), deep=deep) == m_simp + + _check(a*b*a*b*a*b*c*(a*b)**3*c, ((a*b)**3*c)**2) + _check(a*b*(a*b)**-2*a*b, 1) + _check(a**2*b*a*b*a*b*(a*b)**-1, a*(a*b)**2, matrix=False) + _check(b*a*b**2*a*b**2*a*b**2, b*(a*b**2)**3) + _check(a*b*a**2*b*a**2*b*a**3, (a*b*a)**3*a**2) + _check(a**2*b*a**4*b*a**4*b*a**2, (a**2*b*a**2)**3) + _check(a**3*b*a**4*b*a**4*b*a, a**3*(b*a**4)**3*a**-3) + _check(a*b*a*b + a*b*c*x*a*b*c, (a*b)**2 + x*(a*b*c)**2) + _check(a*b*a*b*c*a*b*a*b*c, ((a*b)**2*c)**2) + _check(b**-1*a**-1*(a*b)**2, a*b) + _check(a**-1*b*c**-1, (c*b**-1*a)**-1) + expr = a**3*b*a**4*b*a**4*b*a**2*b*a**2*(b*a**2)**2*b*a**2*b*a**2 + for _ in range(10): + expr *= a*b + _check(expr, a**3*(b*a**4)**2*(b*a**2)**6*(a*b)**10) + _check((a*b*a*b)**2, (a*b*a*b)**2, deep=False) + _check(a*b*(c*d)**2, a*b*(c*d)**2) + expr = b**-1*(a**-1*b**-1 - a**-1*c*b**-1)**-1*a**-1 + assert nc_simplify(expr) == (1-c)**-1 + # commutative expressions should be returned without an error + assert nc_simplify(2*x**2) == 2*x**2 + +def test_issue_15965(): + A = Sum(z*x**y, (x, 1, a)) + anew = z*Sum(x**y, (x, 1, a)) + B = Integral(x*y, x) + bdo = x**2*y/2 + assert simplify(A + B) == anew + bdo + assert simplify(A) == anew + assert simplify(B) == bdo + assert simplify(B, doit=False) == y*Integral(x, x) + + +def test_issue_17137(): + assert simplify(cos(x)**I) == cos(x)**I + assert simplify(cos(x)**(2 + 3*I)) == cos(x)**(2 + 3*I) + + +def test_issue_21869(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + expr = And(Eq(x**2, 4), Le(x, y)) + assert expr.simplify() == expr + + expr = And(Eq(x**2, 4), Eq(x, 2)) + assert expr.simplify() == Eq(x, 2) + + expr = And(Eq(x**3, x**2), Eq(x, 1)) + assert expr.simplify() == Eq(x, 1) + + expr = And(Eq(sin(x), x**2), Eq(x, 0)) + assert expr.simplify() == Eq(x, 0) + + expr = And(Eq(x**3, x**2), Eq(x, 2)) + assert expr.simplify() == S.false + + expr = And(Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,1), Eq(x, 1)) + + expr = And(Eq(y**2, 1), Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,1), Eq(x, 1)) + + expr = And(Eq(y**2, 4), Eq(y, 2*x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,2), Eq(x, 1)) + + expr = And(Eq(y**2, 4), Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == S.false + + +def test_issue_7971_21740(): + z = Integral(x, (x, 1, 1)) + assert z != 0 + assert simplify(z) is S.Zero + assert simplify(S.Zero) is S.Zero + z = simplify(Float(0)) + assert z is not S.Zero and z == 0.0 + + +@slow +def test_issue_17141_slow(): + # Should not give RecursionError + assert simplify((2**acos(I+1)**2).rewrite('log')) == 2**((pi + 2*I*log(-1 + + sqrt(1 - 2*I) + I))**2/4) + + +def test_issue_17141(): + # Check that there is no RecursionError + assert simplify(x**(1 / acos(I))) == x**(2/(pi - 2*I*log(1 + sqrt(2)))) + assert simplify(acos(-I)**2*acos(I)**2) == \ + log(1 + sqrt(2))**4 + pi**2*log(1 + sqrt(2))**2/2 + pi**4/16 + assert simplify(2**acos(I)**2) == 2**((pi - 2*I*log(1 + sqrt(2)))**2/4) + p = 2**acos(I+1)**2 + assert simplify(p) == p + + +def test_simplify_kroneckerdelta(): + i, j = symbols("i j") + K = KroneckerDelta + + assert simplify(K(i, j)) == K(i, j) + assert simplify(K(0, j)) == K(0, j) + assert simplify(K(i, 0)) == K(i, 0) + + assert simplify(K(0, j).rewrite(Piecewise) * K(1, j)) == 0 + assert simplify(K(1, i) + Piecewise((1, Eq(j, 2)), (0, True))) == K(1, i) + K(2, j) + + # issue 17214 + assert simplify(K(0, j) * K(1, j)) == 0 + + n = Symbol('n', integer=True) + assert simplify(K(0, n) * K(1, n)) == 0 + + M = Matrix(4, 4, lambda i, j: K(j - i, n) if i <= j else 0) + assert simplify(M**2) == Matrix([[K(0, n), 0, K(1, n), 0], + [0, K(0, n), 0, K(1, n)], + [0, 0, K(0, n), 0], + [0, 0, 0, K(0, n)]]) + assert simplify(eye(1) * KroneckerDelta(0, n) * + KroneckerDelta(1, n)) == Matrix([[0]]) + + assert simplify(S.Infinity * KroneckerDelta(0, n) * + KroneckerDelta(1, n)) is S.NaN + + +def test_issue_17292(): + assert simplify(abs(x)/abs(x**2)) == 1/abs(x) + # this is bigger than the issue: check that deep processing works + assert simplify(5*abs((x**2 - 1)/(x - 1))) == 5*Abs(x + 1) + + +def test_issue_19822(): + expr = And(Gt(n-2, 1), Gt(n, 1)) + assert simplify(expr) == Gt(n, 3) + + +def test_issue_18645(): + expr = And(Ge(x, 3), Le(x, 3)) + assert simplify(expr) == Eq(x, 3) + expr = And(Eq(x, 3), Le(x, 3)) + assert simplify(expr) == Eq(x, 3) + + +@XFAIL +def test_issue_18642(): + i = Symbol("i", integer=True) + n = Symbol("n", integer=True) + expr = And(Eq(i, 2 * n), Le(i, 2*n -1)) + assert simplify(expr) == S.false + + +@XFAIL +def test_issue_18389(): + n = Symbol("n", integer=True) + expr = Eq(n, 0) | (n >= 1) + assert simplify(expr) == Ge(n, 0) + + +def test_issue_8373(): + x = Symbol('x', real=True) + assert simplify(Or(x < 1, x >= 1)) == S.true + + +def test_issue_7950(): + expr = And(Eq(x, 1), Eq(x, 2)) + assert simplify(expr) == S.false + + +def test_issue_22020(): + expr = I*pi/2 -oo + assert simplify(expr) == expr + # Used to throw an error + + +def test_issue_19484(): + assert simplify(sign(x) * Abs(x)) == x + + e = x + sign(x + x**3) + assert simplify(Abs(x + x**3)*e) == x**3 + x*Abs(x**3 + x) + x + + e = x**2 + sign(x**3 + 1) + assert simplify(Abs(x**3 + 1) * e) == x**3 + x**2*Abs(x**3 + 1) + 1 + + f = Function('f') + e = x + sign(x + f(x)**3) + assert simplify(Abs(x + f(x)**3) * e) == x*Abs(x + f(x)**3) + x + f(x)**3 + + +def test_issue_23543(): + # Used to give an error + x, y, z = symbols("x y z", commutative=False) + assert (x*(y + z/2)).simplify() == x*(2*y + z)/2 + + +def test_issue_11004(): + + def f(n): + return sqrt(2*pi*n) * (n/E)**n + + def m(n, k): + return f(n) / (f(n/k)**k) + + def p(n,k): + return m(n, k) / (k**n) + + N, k = symbols('N k') + half = Float('0.5', 4) + z = log(p(n, k) / p(n, k + 1)).expand(force=True) + r = simplify(z.subs(n, N).n(4)) + assert r == ( + half*k*log(k) + - half*k*log(k + 1) + + half*log(N) + - half*log(k + 1) + + Float(0.9189224, 4) + ) + + +def test_issue_19161(): + polynomial = Poly('x**2').simplify() + assert (polynomial-x**2).simplify() == 0 + + +def test_issue_22210(): + d = Symbol('d', integer=True) + expr = 2*Derivative(sin(x), (x, d)) + assert expr.simplify() == expr + + +def test_reduce_inverses_nc_pow(): + x, y = symbols("x y", commutative=True) + Z = symbols("Z", commutative=False) + assert simplify(2**Z * y**Z) == 2**Z * y**Z + assert simplify(x**Z * y**Z) == x**Z * y**Z + x, y = symbols("x y", positive=True) + assert expand((x*y)**Z) == x**Z * y**Z + assert simplify(x**Z * y**Z) == expand((x*y)**Z) + +def test_nc_recursion_coeff(): + X = symbols("X", commutative = False) + assert (2 * cos(pi/3) * X).simplify() == X + assert (2.0 * cos(pi/3) * X).simplify() == X diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_sqrtdenest.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_sqrtdenest.py new file mode 100644 index 0000000000000000000000000000000000000000..41c771bb2055a1199d349ae3649f33927d79313a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_sqrtdenest.py @@ -0,0 +1,204 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import cos +from sympy.integrals.integrals import Integral +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.sqrtdenest import ( + _subsets as subsets, _sqrt_numeric_denest) + +r2, r3, r5, r6, r7, r10, r15, r29 = [sqrt(x) for x in (2, 3, 5, 6, 7, 10, + 15, 29)] + + +def test_sqrtdenest(): + d = {sqrt(5 + 2 * r6): r2 + r3, + sqrt(5. + 2 * r6): sqrt(5. + 2 * r6), + sqrt(5. + 4*sqrt(5 + 2 * r6)): sqrt(5.0 + 4*r2 + 4*r3), + sqrt(r2): sqrt(r2), + sqrt(5 + r7): sqrt(5 + r7), + sqrt(3 + sqrt(5 + 2*r7)): + 3*r2*(5 + 2*r7)**Rational(1, 4)/(2*sqrt(6 + 3*r7)) + + r2*sqrt(6 + 3*r7)/(2*(5 + 2*r7)**Rational(1, 4)), + sqrt(3 + 2*r3): 3**Rational(3, 4)*(r6/2 + 3*r2/2)/3} + for i in d: + assert sqrtdenest(i) == d[i], i + + +def test_sqrtdenest2(): + assert sqrtdenest(sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))) == \ + r5 + sqrt(11 - 2*r29) + e = sqrt(-r5 + sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16)) + assert sqrtdenest(e) == root(-2*r29 + 11, 4) + r = sqrt(1 + r7) + assert sqrtdenest(sqrt(1 + r)) == sqrt(1 + r) + e = sqrt(((1 + sqrt(1 + 2*sqrt(3 + r2 + r5)))**2).expand()) + assert sqrtdenest(e) == 1 + sqrt(1 + 2*sqrt(r2 + r5 + 3)) + + assert sqrtdenest(sqrt(5*r3 + 6*r2)) == \ + sqrt(2)*root(3, 4) + root(3, 4)**3 + + assert sqrtdenest(sqrt(((1 + r5 + sqrt(1 + r3))**2).expand())) == \ + 1 + r5 + sqrt(1 + r3) + + assert sqrtdenest(sqrt(((1 + r5 + r7 + sqrt(1 + r3))**2).expand())) == \ + 1 + sqrt(1 + r3) + r5 + r7 + + e = sqrt(((1 + cos(2) + cos(3) + sqrt(1 + r3))**2).expand()) + assert sqrtdenest(e) == cos(3) + cos(2) + 1 + sqrt(1 + r3) + + e = sqrt(-2*r10 + 2*r2*sqrt(-2*r10 + 11) + 14) + assert sqrtdenest(e) == sqrt(-2*r10 - 2*r2 + 4*r5 + 14) + + # check that the result is not more complicated than the input + z = sqrt(-2*r29 + cos(2) + 2*sqrt(-10*r29 + 55) + 16) + assert sqrtdenest(z) == z + + assert sqrtdenest(sqrt(r6 + sqrt(15))) == sqrt(r6 + sqrt(15)) + + z = sqrt(15 - 2*sqrt(31) + 2*sqrt(55 - 10*r29)) + assert sqrtdenest(z) == z + + +def test_sqrtdenest_rec(): + assert sqrtdenest(sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 33)) == \ + -r2 + r3 + 2*r7 + assert sqrtdenest(sqrt(-28*r7 - 14*r5 + 4*sqrt(35) + 82)) == \ + -7 + r5 + 2*r7 + assert sqrtdenest(sqrt(6*r2/11 + 2*sqrt(22)/11 + 6*sqrt(11)/11 + 2)) == \ + sqrt(11)*(r2 + 3 + sqrt(11))/11 + assert sqrtdenest(sqrt(468*r3 + 3024*r2 + 2912*r6 + 19735)) == \ + 9*r3 + 26 + 56*r6 + z = sqrt(-490*r3 - 98*sqrt(115) - 98*sqrt(345) - 2107) + assert sqrtdenest(z) == sqrt(-1)*(7*r5 + 7*r15 + 7*sqrt(23)) + z = sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 34) + assert sqrtdenest(z) == z + assert sqrtdenest(sqrt(-8*r2 - 2*r5 + 18)) == -r10 + 1 + r2 + r5 + assert sqrtdenest(sqrt(8*r2 + 2*r5 - 18)) == \ + sqrt(-1)*(-r10 + 1 + r2 + r5) + assert sqrtdenest(sqrt(8*r2/3 + 14*r5/3 + Rational(154, 9))) == \ + -r10/3 + r2 + r5 + 3 + assert sqrtdenest(sqrt(sqrt(2*r6 + 5) + sqrt(2*r7 + 8))) == \ + sqrt(1 + r2 + r3 + r7) + assert sqrtdenest(sqrt(4*r15 + 8*r5 + 12*r3 + 24)) == 1 + r3 + r5 + r15 + + w = 1 + r2 + r3 + r5 + r7 + assert sqrtdenest(sqrt((w**2).expand())) == w + z = sqrt((w**2).expand() + 1) + assert sqrtdenest(z) == z + + z = sqrt(2*r10 + 6*r2 + 4*r5 + 12 + 10*r15 + 30*r3) + assert sqrtdenest(z) == z + + +def test_issue_6241(): + z = sqrt( -320 + 32*sqrt(5) + 64*r15) + assert sqrtdenest(z) == z + + +def test_sqrtdenest3(): + z = sqrt(13 - 2*r10 + 2*r2*sqrt(-2*r10 + 11)) + assert sqrtdenest(z) == -1 + r2 + r10 + assert sqrtdenest(z, max_iter=1) == -1 + sqrt(2) + sqrt(10) + z = sqrt(sqrt(r2 + 2) + 2) + assert sqrtdenest(z) == z + assert sqrtdenest(sqrt(-2*r10 + 4*r2*sqrt(-2*r10 + 11) + 20)) == \ + sqrt(-2*r10 - 4*r2 + 8*r5 + 20) + assert sqrtdenest(sqrt((112 + 70*r2) + (46 + 34*r2)*r5)) == \ + r10 + 5 + 4*r2 + 3*r5 + z = sqrt(5 + sqrt(2*r6 + 5)*sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16)) + r = sqrt(-2*r29 + 11) + assert sqrtdenest(z) == sqrt(r2*r + r3*r + r10 + r15 + 5) + + n = sqrt(2*r6/7 + 2*r7/7 + 2*sqrt(42)/7 + 2) + d = sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29)) + assert sqrtdenest(n/d) == r7*(1 + r6 + r7)/(Mul(7, (sqrt(-2*r29 + 11) + r5), + evaluate=False)) + + +def test_sqrtdenest4(): + # see Denest_en.pdf in https://github.com/sympy/sympy/issues/3192 + z = sqrt(8 - r2*sqrt(5 - r5) - sqrt(3)*(1 + r5)) + z1 = sqrtdenest(z) + c = sqrt(-r5 + 5) + z1 = ((-r15*c - r3*c + c + r5*c - r6 - r2 + r10 + sqrt(30))/4).expand() + assert sqrtdenest(z) == z1 + + z = sqrt(2*r2*sqrt(r2 + 2) + 5*r2 + 4*sqrt(r2 + 2) + 8) + assert sqrtdenest(z) == r2 + sqrt(r2 + 2) + 2 + + w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3) + z = sqrt((w**2).expand()) + assert sqrtdenest(z) == w.expand() + + +def test_sqrt_symbolic_denest(): + x = Symbol('x') + z = sqrt(((1 + sqrt(sqrt(2 + x) + 3))**2).expand()) + assert sqrtdenest(z) == sqrt((1 + sqrt(sqrt(2 + x) + 3))**2) + z = sqrt(((1 + sqrt(sqrt(2 + cos(1)) + 3))**2).expand()) + assert sqrtdenest(z) == 1 + sqrt(sqrt(2 + cos(1)) + 3) + z = ((1 + cos(2))**4 + 1).expand() + assert sqrtdenest(z) == z + z = sqrt(((1 + sqrt(sqrt(2 + cos(3*x)) + 3))**2 + 1).expand()) + assert sqrtdenest(z) == z + c = cos(3) + c2 = c**2 + assert sqrtdenest(sqrt(2*sqrt(1 + r3)*c + c2 + 1 + r3*c2)) == \ + -1 - sqrt(1 + r3)*c + ra = sqrt(1 + r3) + z = sqrt(20*ra*sqrt(3 + 3*r3) + 12*r3*ra*sqrt(3 + 3*r3) + 64*r3 + 112) + assert sqrtdenest(z) == z + + +def test_issue_5857(): + from sympy.abc import x, y + z = sqrt(1/(4*r3 + 7) + 1) + ans = (r2 + r6)/(r3 + 2) + assert sqrtdenest(z) == ans + assert sqrtdenest(1 + z) == 1 + ans + assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \ + Integral(1 + ans, (x, 1, 2)) + assert sqrtdenest(x + sqrt(y)) == x + sqrt(y) + ans = (r2 + r6)/(r3 + 2) + assert sqrtdenest(z) == ans + assert sqrtdenest(1 + z) == 1 + ans + assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \ + Integral(1 + ans, (x, 1, 2)) + assert sqrtdenest(x + sqrt(y)) == x + sqrt(y) + + +def test_subsets(): + assert subsets(1) == [[1]] + assert subsets(4) == [ + [1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [1, 0, 1, 0], + [0, 1, 1, 0], [1, 1, 1, 0], [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], + [1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]] + + +def test_issue_5653(): + assert sqrtdenest( + sqrt(2 + sqrt(2 + sqrt(2)))) == sqrt(2 + sqrt(2 + sqrt(2))) + +def test_issue_12420(): + assert sqrtdenest((3 - sqrt(2)*sqrt(4 + 3*I) + 3*I)/2) == I + e = 3 - sqrt(2)*sqrt(4 + I) + 3*I + assert sqrtdenest(e) == e + +def test_sqrt_ratcomb(): + assert sqrtdenest(sqrt(1 + r3) + sqrt(3 + 3*r3) - sqrt(10 + 6*r3)) == 0 + +def test_issue_18041(): + e = -sqrt(-2 + 2*sqrt(3)*I) + assert sqrtdenest(e) == -1 - sqrt(3)*I + +def test_issue_19914(): + a = Integer(-8) + b = Integer(-1) + r = Integer(63) + d2 = a*a - b*b*r + + assert _sqrt_numeric_denest(a, b, r, d2) == \ + sqrt(14)*I/2 + 3*sqrt(2)*I/2 + assert sqrtdenest(sqrt(-8-sqrt(63))) == sqrt(14)*I/2 + 3*sqrt(2)*I/2 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_trigsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_trigsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..ea091ec8a6c7d654405968e3d035c2bbe02ccdf7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/tests/test_trigsimp.py @@ -0,0 +1,520 @@ +from itertools import product +from sympy.core.function import (Subs, count_ops, diff, expand) +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, coth, sinh, tanh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, cot, sin, tan) +from sympy.functions.elementary.trigonometric import (acos, asin, atan2) +from sympy.functions.elementary.trigonometric import (asec, acsc) +from sympy.functions.elementary.trigonometric import (acot, atan) +from sympy.integrals.integrals import integrate +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.simplify.trigsimp import (exptrigsimp, trigsimp) + +from sympy.testing.pytest import XFAIL + +from sympy.abc import x, y + + + +def test_trigsimp1(): + x, y = symbols('x,y') + + assert trigsimp(1 - sin(x)**2) == cos(x)**2 + assert trigsimp(1 - cos(x)**2) == sin(x)**2 + assert trigsimp(sin(x)**2 + cos(x)**2) == 1 + assert trigsimp(1 + tan(x)**2) == 1/cos(x)**2 + assert trigsimp(1/cos(x)**2 - 1) == tan(x)**2 + assert trigsimp(1/cos(x)**2 - tan(x)**2) == 1 + assert trigsimp(1 + cot(x)**2) == 1/sin(x)**2 + assert trigsimp(1/sin(x)**2 - 1) == 1/tan(x)**2 + assert trigsimp(1/sin(x)**2 - cot(x)**2) == 1 + + assert trigsimp(5*cos(x)**2 + 5*sin(x)**2) == 5 + assert trigsimp(5*cos(x/2)**2 + 2*sin(x/2)**2) == 3*cos(x)/2 + Rational(7, 2) + + assert trigsimp(sin(x)/cos(x)) == tan(x) + assert trigsimp(2*tan(x)*cos(x)) == 2*sin(x) + assert trigsimp(cot(x)**3*sin(x)**3) == cos(x)**3 + assert trigsimp(y*tan(x)**2/sin(x)**2) == y/cos(x)**2 + assert trigsimp(cot(x)/cos(x)) == 1/sin(x) + + assert trigsimp(sin(x + y) + sin(x - y)) == 2*sin(x)*cos(y) + assert trigsimp(sin(x + y) - sin(x - y)) == 2*sin(y)*cos(x) + assert trigsimp(cos(x + y) + cos(x - y)) == 2*cos(x)*cos(y) + assert trigsimp(cos(x + y) - cos(x - y)) == -2*sin(x)*sin(y) + assert trigsimp(tan(x + y) - tan(x)/(1 - tan(x)*tan(y))) == \ + sin(y)/(-sin(y)*tan(x) + cos(y)) # -tan(y)/(tan(x)*tan(y) - 1) + + assert trigsimp(sinh(x + y) + sinh(x - y)) == 2*sinh(x)*cosh(y) + assert trigsimp(sinh(x + y) - sinh(x - y)) == 2*sinh(y)*cosh(x) + assert trigsimp(cosh(x + y) + cosh(x - y)) == 2*cosh(x)*cosh(y) + assert trigsimp(cosh(x + y) - cosh(x - y)) == 2*sinh(x)*sinh(y) + assert trigsimp(tanh(x + y) - tanh(x)/(1 + tanh(x)*tanh(y))) == \ + sinh(y)/(sinh(y)*tanh(x) + cosh(y)) + + assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2) == 1.0 + e = 2*sin(x)**2 + 2*cos(x)**2 + assert trigsimp(log(e)) == log(2) + + +def test_trigsimp1a(): + assert trigsimp(sin(2)**2*cos(3)*exp(2)/cos(2)**2) == tan(2)**2*cos(3)*exp(2) + assert trigsimp(tan(2)**2*cos(3)*exp(2)*cos(2)**2) == sin(2)**2*cos(3)*exp(2) + assert trigsimp(cot(2)*cos(3)*exp(2)*sin(2)) == cos(3)*exp(2)*cos(2) + assert trigsimp(tan(2)*cos(3)*exp(2)/sin(2)) == cos(3)*exp(2)/cos(2) + assert trigsimp(cot(2)*cos(3)*exp(2)/cos(2)) == cos(3)*exp(2)/sin(2) + assert trigsimp(cot(2)*cos(3)*exp(2)*tan(2)) == cos(3)*exp(2) + assert trigsimp(sinh(2)*cos(3)*exp(2)/cosh(2)) == tanh(2)*cos(3)*exp(2) + assert trigsimp(tanh(2)*cos(3)*exp(2)*cosh(2)) == sinh(2)*cos(3)*exp(2) + assert trigsimp(coth(2)*cos(3)*exp(2)*sinh(2)) == cosh(2)*cos(3)*exp(2) + assert trigsimp(tanh(2)*cos(3)*exp(2)/sinh(2)) == cos(3)*exp(2)/cosh(2) + assert trigsimp(coth(2)*cos(3)*exp(2)/cosh(2)) == cos(3)*exp(2)/sinh(2) + assert trigsimp(coth(2)*cos(3)*exp(2)*tanh(2)) == cos(3)*exp(2) + + +def test_trigsimp2(): + x, y = symbols('x,y') + assert trigsimp(cos(x)**2*sin(y)**2 + cos(x)**2*cos(y)**2 + sin(x)**2, + recursive=True) == 1 + assert trigsimp(sin(x)**2*sin(y)**2 + sin(x)**2*cos(y)**2 + cos(x)**2, + recursive=True) == 1 + assert trigsimp( + Subs(x, x, sin(y)**2 + cos(y)**2)) == Subs(x, x, 1) + + +def test_issue_4373(): + x = Symbol("x") + assert abs(trigsimp(2.0*sin(x)**2 + 2.0*cos(x)**2) - 2.0) < 1e-10 + + +def test_trigsimp3(): + x, y = symbols('x,y') + assert trigsimp(sin(x)/cos(x)) == tan(x) + assert trigsimp(sin(x)**2/cos(x)**2) == tan(x)**2 + assert trigsimp(sin(x)**3/cos(x)**3) == tan(x)**3 + assert trigsimp(sin(x)**10/cos(x)**10) == tan(x)**10 + + assert trigsimp(cos(x)/sin(x)) == 1/tan(x) + assert trigsimp(cos(x)**2/sin(x)**2) == 1/tan(x)**2 + assert trigsimp(cos(x)**10/sin(x)**10) == 1/tan(x)**10 + + assert trigsimp(tan(x)) == trigsimp(sin(x)/cos(x)) + + +def test_issue_4661(): + a, x, y = symbols('a x y') + eq = -4*sin(x)**4 + 4*cos(x)**4 - 8*cos(x)**2 + assert trigsimp(eq) == -4 + n = sin(x)**6 + 4*sin(x)**4*cos(x)**2 + 5*sin(x)**2*cos(x)**4 + 2*cos(x)**6 + d = -sin(x)**2 - 2*cos(x)**2 + assert simplify(n/d) == -1 + assert trigsimp(-2*cos(x)**2 + cos(x)**4 - sin(x)**4) == -1 + eq = (- sin(x)**3/4)*cos(x) + (cos(x)**3/4)*sin(x) - sin(2*x)*cos(2*x)/8 + assert trigsimp(eq) == 0 + + +def test_issue_4494(): + a, b = symbols('a b') + eq = sin(a)**2*sin(b)**2 + cos(a)**2*cos(b)**2*tan(a)**2 + cos(a)**2 + assert trigsimp(eq) == 1 + + +def test_issue_5948(): + a, x, y = symbols('a x y') + assert trigsimp(diff(integrate(cos(x)/sin(x)**7, x), x)) == \ + cos(x)/sin(x)**7 + + +def test_issue_4775(): + a, x, y = symbols('a x y') + assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)) == sin(x + y) + assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)+3) == sin(x + y) + 3 + + +def test_issue_4280(): + a, x, y = symbols('a x y') + assert trigsimp(cos(x)**2 + cos(y)**2*sin(x)**2 + sin(y)**2*sin(x)**2) == 1 + assert trigsimp(a**2*sin(x)**2 + a**2*cos(y)**2*cos(x)**2 + a**2*cos(x)**2*sin(y)**2) == a**2 + assert trigsimp(a**2*cos(y)**2*sin(x)**2 + a**2*sin(y)**2*sin(x)**2) == a**2*sin(x)**2 + + +def test_issue_3210(): + eqs = (sin(2)*cos(3) + sin(3)*cos(2), + -sin(2)*sin(3) + cos(2)*cos(3), + sin(2)*cos(3) - sin(3)*cos(2), + sin(2)*sin(3) + cos(2)*cos(3), + sin(2)*sin(3) + cos(2)*cos(3) + cos(2), + sinh(2)*cosh(3) + sinh(3)*cosh(2), + sinh(2)*sinh(3) + cosh(2)*cosh(3), + ) + assert [trigsimp(e) for e in eqs] == [ + sin(5), + cos(5), + -sin(1), + cos(1), + cos(1) + cos(2), + sinh(5), + cosh(5), + ] + + +def test_trigsimp_issues(): + a, x, y = symbols('a x y') + + # issue 4625 - factor_terms works, too + assert trigsimp(sin(x)**3 + cos(x)**2*sin(x)) == sin(x) + + # issue 5948 + assert trigsimp(diff(integrate(cos(x)/sin(x)**3, x), x)) == \ + cos(x)/sin(x)**3 + assert trigsimp(diff(integrate(sin(x)/cos(x)**3, x), x)) == \ + sin(x)/cos(x)**3 + + # check integer exponents + e = sin(x)**y/cos(x)**y + assert trigsimp(e) == e + assert trigsimp(e.subs(y, 2)) == tan(x)**2 + assert trigsimp(e.subs(x, 1)) == tan(1)**y + + # check for multiple patterns + assert (cos(x)**2/sin(x)**2*cos(y)**2/sin(y)**2).trigsimp() == \ + 1/tan(x)**2/tan(y)**2 + assert trigsimp(cos(x)/sin(x)*cos(x+y)/sin(x+y)) == \ + 1/(tan(x)*tan(x + y)) + + eq = cos(2)*(cos(3) + 1)**2/(cos(3) - 1)**2 + assert trigsimp(eq) == eq.factor() # factor makes denom (-1 + cos(3))**2 + assert trigsimp(cos(2)*(cos(3) + 1)**2*(cos(3) - 1)**2) == \ + cos(2)*sin(3)**4 + + # issue 6789; this generates an expression that formerly caused + # trigsimp to hang + assert cot(x).equals(tan(x)) is False + + # nan or the unchanged expression is ok, but not sin(1) + z = cos(x)**2 + sin(x)**2 - 1 + z1 = tan(x)**2 - 1/cot(x)**2 + n = (1 + z1/z) + assert trigsimp(sin(n)) != sin(1) + eq = x*(n - 1) - x*n + assert trigsimp(eq) is S.NaN + assert trigsimp(eq, recursive=True) is S.NaN + assert trigsimp(1).is_Integer + + assert trigsimp(-sin(x)**4 - 2*sin(x)**2*cos(x)**2 - cos(x)**4) == -1 + + +def test_trigsimp_issue_2515(): + x = Symbol('x') + assert trigsimp(x*cos(x)*tan(x)) == x*sin(x) + assert trigsimp(-sin(x) + cos(x)*tan(x)) == 0 + + +def test_trigsimp_issue_3826(): + assert trigsimp(tan(2*x).expand(trig=True)) == tan(2*x) + + +def test_trigsimp_issue_4032(): + n = Symbol('n', integer=True, positive=True) + assert trigsimp(2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2) == \ + 2**(n/2)*cos(pi*n/4)/2 + 2**n/4 + + +def test_trigsimp_issue_7761(): + assert trigsimp(cosh(pi/4)) == cosh(pi/4) + + +def test_trigsimp_noncommutative(): + x, y = symbols('x,y') + A, B = symbols('A,B', commutative=False) + + assert trigsimp(A - A*sin(x)**2) == A*cos(x)**2 + assert trigsimp(A - A*cos(x)**2) == A*sin(x)**2 + assert trigsimp(A*sin(x)**2 + A*cos(x)**2) == A + assert trigsimp(A + A*tan(x)**2) == A/cos(x)**2 + assert trigsimp(A/cos(x)**2 - A) == A*tan(x)**2 + assert trigsimp(A/cos(x)**2 - A*tan(x)**2) == A + assert trigsimp(A + A*cot(x)**2) == A/sin(x)**2 + assert trigsimp(A/sin(x)**2 - A) == A/tan(x)**2 + assert trigsimp(A/sin(x)**2 - A*cot(x)**2) == A + + assert trigsimp(y*A*cos(x)**2 + y*A*sin(x)**2) == y*A + + assert trigsimp(A*sin(x)/cos(x)) == A*tan(x) + assert trigsimp(A*tan(x)*cos(x)) == A*sin(x) + assert trigsimp(A*cot(x)**3*sin(x)**3) == A*cos(x)**3 + assert trigsimp(y*A*tan(x)**2/sin(x)**2) == y*A/cos(x)**2 + assert trigsimp(A*cot(x)/cos(x)) == A/sin(x) + + assert trigsimp(A*sin(x + y) + A*sin(x - y)) == 2*A*sin(x)*cos(y) + assert trigsimp(A*sin(x + y) - A*sin(x - y)) == 2*A*sin(y)*cos(x) + assert trigsimp(A*cos(x + y) + A*cos(x - y)) == 2*A*cos(x)*cos(y) + assert trigsimp(A*cos(x + y) - A*cos(x - y)) == -2*A*sin(x)*sin(y) + + assert trigsimp(A*sinh(x + y) + A*sinh(x - y)) == 2*A*sinh(x)*cosh(y) + assert trigsimp(A*sinh(x + y) - A*sinh(x - y)) == 2*A*sinh(y)*cosh(x) + assert trigsimp(A*cosh(x + y) + A*cosh(x - y)) == 2*A*cosh(x)*cosh(y) + assert trigsimp(A*cosh(x + y) - A*cosh(x - y)) == 2*A*sinh(x)*sinh(y) + + assert trigsimp(A*cos(0.12345)**2 + A*sin(0.12345)**2) == 1.0*A + + +def test_hyperbolic_simp(): + x, y = symbols('x,y') + + assert trigsimp(sinh(x)**2 + 1) == cosh(x)**2 + assert trigsimp(cosh(x)**2 - 1) == sinh(x)**2 + assert trigsimp(cosh(x)**2 - sinh(x)**2) == 1 + assert trigsimp(1 - tanh(x)**2) == 1/cosh(x)**2 + assert trigsimp(1 - 1/cosh(x)**2) == tanh(x)**2 + assert trigsimp(tanh(x)**2 + 1/cosh(x)**2) == 1 + assert trigsimp(coth(x)**2 - 1) == 1/sinh(x)**2 + assert trigsimp(1/sinh(x)**2 + 1) == 1/tanh(x)**2 + assert trigsimp(coth(x)**2 - 1/sinh(x)**2) == 1 + + assert trigsimp(5*cosh(x)**2 - 5*sinh(x)**2) == 5 + assert trigsimp(5*cosh(x/2)**2 - 2*sinh(x/2)**2) == 3*cosh(x)/2 + Rational(7, 2) + + assert trigsimp(sinh(x)/cosh(x)) == tanh(x) + assert trigsimp(tanh(x)) == trigsimp(sinh(x)/cosh(x)) + assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x) + assert trigsimp(2*tanh(x)*cosh(x)) == 2*sinh(x) + assert trigsimp(coth(x)**3*sinh(x)**3) == cosh(x)**3 + assert trigsimp(y*tanh(x)**2/sinh(x)**2) == y/cosh(x)**2 + assert trigsimp(coth(x)/cosh(x)) == 1/sinh(x) + + for a in (pi/6*I, pi/4*I, pi/3*I): + assert trigsimp(sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x + a) + assert trigsimp(-sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x - a) + + e = 2*cosh(x)**2 - 2*sinh(x)**2 + assert trigsimp(log(e)) == log(2) + + # issue 19535: + assert trigsimp(sqrt(cosh(x)**2 - 1)) == sqrt(sinh(x)**2) + + assert trigsimp(cosh(x)**2*cosh(y)**2 - cosh(x)**2*sinh(y)**2 - sinh(x)**2, + recursive=True) == 1 + assert trigsimp(sinh(x)**2*sinh(y)**2 - sinh(x)**2*cosh(y)**2 + cosh(x)**2, + recursive=True) == 1 + + assert abs(trigsimp(2.0*cosh(x)**2 - 2.0*sinh(x)**2) - 2.0) < 1e-10 + + assert trigsimp(sinh(x)**2/cosh(x)**2) == tanh(x)**2 + assert trigsimp(sinh(x)**3/cosh(x)**3) == tanh(x)**3 + assert trigsimp(sinh(x)**10/cosh(x)**10) == tanh(x)**10 + assert trigsimp(cosh(x)**3/sinh(x)**3) == 1/tanh(x)**3 + + assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x) + assert trigsimp(cosh(x)**2/sinh(x)**2) == 1/tanh(x)**2 + assert trigsimp(cosh(x)**10/sinh(x)**10) == 1/tanh(x)**10 + + assert trigsimp(x*cosh(x)*tanh(x)) == x*sinh(x) + assert trigsimp(-sinh(x) + cosh(x)*tanh(x)) == 0 + + assert tan(x) != 1/cot(x) # cot doesn't auto-simplify + + assert trigsimp(tan(x) - 1/cot(x)) == 0 + assert trigsimp(3*tanh(x)**7 - 2/coth(x)**7) == tanh(x)**7 + + +def test_trigsimp_groebner(): + from sympy.simplify.trigsimp import trigsimp_groebner + + c = cos(x) + s = sin(x) + ex = (4*s*c + 12*s + 5*c**3 + 21*c**2 + 23*c + 15)/( + -s*c**2 + 2*s*c + 15*s + 7*c**3 + 31*c**2 + 37*c + 21) + resnum = (5*s - 5*c + 1) + resdenom = (8*s - 6*c) + results = [resnum/resdenom, (-resnum)/(-resdenom)] + assert trigsimp_groebner(ex) in results + assert trigsimp_groebner(s/c, hints=[tan]) == tan(x) + assert trigsimp_groebner(c*s) == c*s + assert trigsimp((-s + 1)/c + c/(-s + 1), + method='groebner') == 2/c + assert trigsimp((-s + 1)/c + c/(-s + 1), + method='groebner', polynomial=True) == 2/c + + # Test quick=False works + assert trigsimp_groebner(ex, hints=[2]) in results + assert trigsimp_groebner(ex, hints=[int(2)]) in results + + # test "I" + assert trigsimp_groebner(sin(I*x)/cos(I*x), hints=[tanh]) == I*tanh(x) + + # test hyperbolic / sums + assert trigsimp_groebner((tanh(x)+tanh(y))/(1+tanh(x)*tanh(y)), + hints=[(tanh, x, y)]) == tanh(x + y) + + +def test_issue_2827_trigsimp_methods(): + measure1 = lambda expr: len(str(expr)) + measure2 = lambda expr: -count_ops(expr) + # Return the most complicated result + expr = (x + 1)/(x + sin(x)**2 + cos(x)**2) + ans = Matrix([1]) + M = Matrix([expr]) + assert trigsimp(M, method='fu', measure=measure1) == ans + assert trigsimp(M, method='fu', measure=measure2) != ans + # all methods should work with Basic expressions even if they + # aren't Expr + M = Matrix.eye(1) + assert all(trigsimp(M, method=m) == M for m in + 'fu matching groebner old'.split()) + # watch for E in exptrigsimp, not only exp() + eq = 1/sqrt(E) + E + assert exptrigsimp(eq) == eq + +def test_issue_15129_trigsimp_methods(): + t1 = Matrix([sin(Rational(1, 50)), cos(Rational(1, 50)), 0]) + t2 = Matrix([sin(Rational(1, 25)), cos(Rational(1, 25)), 0]) + t3 = Matrix([cos(Rational(1, 25)), sin(Rational(1, 25)), 0]) + r1 = t1.dot(t2) + r2 = t1.dot(t3) + assert trigsimp(r1) == cos(Rational(1, 50)) + assert trigsimp(r2) == sin(Rational(3, 50)) + +def test_exptrigsimp(): + def valid(a, b): + from sympy.core.random import verify_numerically as tn + if not (tn(a, b) and a == b): + return False + return True + + assert exptrigsimp(exp(x) + exp(-x)) == 2*cosh(x) + assert exptrigsimp(exp(x) - exp(-x)) == 2*sinh(x) + assert exptrigsimp((2*exp(x)-2*exp(-x))/(exp(x)+exp(-x))) == 2*tanh(x) + assert exptrigsimp((2*exp(2*x)-2)/(exp(2*x)+1)) == 2*tanh(x) + e = [cos(x) + I*sin(x), cos(x) - I*sin(x), + cosh(x) - sinh(x), cosh(x) + sinh(x)] + ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)] + assert all(valid(i, j) for i, j in zip( + [exptrigsimp(ei) for ei in e], ok)) + + ue = [cos(x) + sin(x), cos(x) - sin(x), + cosh(x) + I*sinh(x), cosh(x) - I*sinh(x)] + assert [exptrigsimp(ei) == ei for ei in ue] + + res = [] + ok = [y*tanh(1), 1/(y*tanh(1)), I*y*tan(1), -I/(y*tan(1)), + y*tanh(x), 1/(y*tanh(x)), I*y*tan(x), -I/(y*tan(x)), + y*tanh(1 + I), 1/(y*tanh(1 + I))] + for a in (1, I, x, I*x, 1 + I): + w = exp(a) + eq = y*(w - 1/w)/(w + 1/w) + res.append(simplify(eq)) + res.append(simplify(1/eq)) + assert all(valid(i, j) for i, j in zip(res, ok)) + + for a in range(1, 3): + w = exp(a) + e = w + 1/w + s = simplify(e) + assert s == exptrigsimp(e) + assert valid(s, 2*cosh(a)) + e = w - 1/w + s = simplify(e) + assert s == exptrigsimp(e) + assert valid(s, 2*sinh(a)) + +def test_exptrigsimp_noncommutative(): + a,b = symbols('a b', commutative=False) + x = Symbol('x', commutative=True) + assert exp(a + x) == exptrigsimp(exp(a)*exp(x)) + p = exp(a)*exp(b) - exp(b)*exp(a) + assert p == exptrigsimp(p) != 0 + +def test_powsimp_on_numbers(): + assert 2**(Rational(1, 3) - 2) == 2**Rational(1, 3)/4 + + +@XFAIL +def test_issue_6811_fail(): + # from doc/src/modules/physics/mechanics/examples.rst, the current `eq` + # at Line 576 (in different variables) was formerly the equivalent and + # shorter expression given below...it would be nice to get the short one + # back again + xp, y, x, z = symbols('xp, y, x, z') + eq = 4*(-19*sin(x)*y + 5*sin(3*x)*y + 15*cos(2*x)*z - 21*z)*xp/(9*cos(x) - 5*cos(3*x)) + assert trigsimp(eq) == -2*(2*cos(x)*tan(x)*y + 3*z)*xp/cos(x) + + +def test_Piecewise(): + e1 = x*(x + y) - y*(x + y) + e2 = sin(x)**2 + cos(x)**2 + e3 = expand((x + y)*y/x) + # s1 = simplify(e1) + s2 = simplify(e2) + # s3 = simplify(e3) + + # trigsimp tries not to touch non-trig containing args + assert trigsimp(Piecewise((e1, e3 < e2), (e3, True))) == \ + Piecewise((e1, e3 < s2), (e3, True)) + + +def test_issue_21594(): + assert simplify(exp(Rational(1,2)) + exp(Rational(-1,2))) == cosh(S.Half)*2 + + +def test_trigsimp_old(): + x, y = symbols('x,y') + + assert trigsimp(1 - sin(x)**2, old=True) == cos(x)**2 + assert trigsimp(1 - cos(x)**2, old=True) == sin(x)**2 + assert trigsimp(sin(x)**2 + cos(x)**2, old=True) == 1 + assert trigsimp(1 + tan(x)**2, old=True) == 1/cos(x)**2 + assert trigsimp(1/cos(x)**2 - 1, old=True) == tan(x)**2 + assert trigsimp(1/cos(x)**2 - tan(x)**2, old=True) == 1 + assert trigsimp(1 + cot(x)**2, old=True) == 1/sin(x)**2 + assert trigsimp(1/sin(x)**2 - cot(x)**2, old=True) == 1 + + assert trigsimp(5*cos(x)**2 + 5*sin(x)**2, old=True) == 5 + + assert trigsimp(sin(x)/cos(x), old=True) == tan(x) + assert trigsimp(2*tan(x)*cos(x), old=True) == 2*sin(x) + assert trigsimp(cot(x)**3*sin(x)**3, old=True) == cos(x)**3 + assert trigsimp(y*tan(x)**2/sin(x)**2, old=True) == y/cos(x)**2 + assert trigsimp(cot(x)/cos(x), old=True) == 1/sin(x) + + assert trigsimp(sin(x + y) + sin(x - y), old=True) == 2*sin(x)*cos(y) + assert trigsimp(sin(x + y) - sin(x - y), old=True) == 2*sin(y)*cos(x) + assert trigsimp(cos(x + y) + cos(x - y), old=True) == 2*cos(x)*cos(y) + assert trigsimp(cos(x + y) - cos(x - y), old=True) == -2*sin(x)*sin(y) + + assert trigsimp(sinh(x + y) + sinh(x - y), old=True) == 2*sinh(x)*cosh(y) + assert trigsimp(sinh(x + y) - sinh(x - y), old=True) == 2*sinh(y)*cosh(x) + assert trigsimp(cosh(x + y) + cosh(x - y), old=True) == 2*cosh(x)*cosh(y) + assert trigsimp(cosh(x + y) - cosh(x - y), old=True) == 2*sinh(x)*sinh(y) + + assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2, old=True) == 1.0 + + assert trigsimp(sin(x)/cos(x), old=True, method='combined') == tan(x) + assert trigsimp(sin(x)/cos(x), old=True, method='groebner') == sin(x)/cos(x) + assert trigsimp(sin(x)/cos(x), old=True, method='groebner', hints=[tan]) == tan(x) + + assert trigsimp(1-sin(sin(x)**2+cos(x)**2)**2, old=True, deep=True) == cos(1)**2 + + +def test_trigsimp_inverse(): + alpha = symbols('alpha') + s, c = sin(alpha), cos(alpha) + + for finv in [asin, acos, asec, acsc, atan, acot]: + f = finv.inverse(None) + assert alpha == trigsimp(finv(f(alpha)), inverse=True) + + # test atan2(cos, sin), atan2(sin, cos), etc... + for a, b in [[c, s], [s, c]]: + for i, j in product([-1, 1], repeat=2): + angle = atan2(i*b, j*a) + angle_inverted = trigsimp(angle, inverse=True) + assert angle_inverted != angle # assures simplification happened + assert sin(angle_inverted) == trigsimp(sin(angle)) + assert cos(angle_inverted) == trigsimp(cos(angle)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/traversaltools.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/traversaltools.py new file mode 100644 index 0000000000000000000000000000000000000000..75b0bd0d8fd198cb12640ab8a0fe63a23c81ed8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/traversaltools.py @@ -0,0 +1,15 @@ +from sympy.core.traversal import use as _use +from sympy.utilities.decorator import deprecated + +use = deprecated( + """ + Using use from the sympy.simplify.traversaltools submodule is + deprecated. + + Instead, use use from the top-level sympy namespace, like + + sympy.use + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved" +)(_use) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/trigsimp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/trigsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5be1444a4625e4b63b339877e441d12cfbe8de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/simplify/trigsimp.py @@ -0,0 +1,1252 @@ +from collections import defaultdict +from functools import reduce + +from sympy.core import (sympify, Basic, S, Expr, factor_terms, + Mul, Add, bottom_up) +from sympy.core.cache import cacheit +from sympy.core.function import (count_ops, _mexpand, FunctionClass, expand, + expand_mul, _coeff_isneg, Derivative) +from sympy.core.numbers import I, Integer +from sympy.core.intfunc import igcd +from sympy.core.sorting import _nodes +from sympy.core.symbol import Dummy, symbols, Wild +from sympy.external.gmpy import SYMPY_INTS +from sympy.functions import sin, cos, exp, cosh, tanh, sinh, tan, cot, coth +from sympy.functions import atan2 +from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.polys import Poly, factor, cancel, parallel_poly_from_expr +from sympy.polys.domains import ZZ +from sympy.polys.polyerrors import PolificationFailed +from sympy.polys.polytools import groebner +from sympy.simplify.cse_main import cse +from sympy.strategies.core import identity +from sympy.strategies.tree import greedy +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import debug + +def trigsimp_groebner(expr, hints=[], quick=False, order="grlex", + polynomial=False): + """ + Simplify trigonometric expressions using a groebner basis algorithm. + + Explanation + =========== + + This routine takes a fraction involving trigonometric or hyperbolic + expressions, and tries to simplify it. The primary metric is the + total degree. Some attempts are made to choose the simplest possible + expression of the minimal degree, but this is non-rigorous, and also + very slow (see the ``quick=True`` option). + + If ``polynomial`` is set to True, instead of simplifying numerator and + denominator together, this function just brings numerator and denominator + into a canonical form. This is much faster, but has potentially worse + results. However, if the input is a polynomial, then the result is + guaranteed to be an equivalent polynomial of minimal degree. + + The most important option is hints. Its entries can be any of the + following: + + - a natural number + - a function + - an iterable of the form (func, var1, var2, ...) + - anything else, interpreted as a generator + + A number is used to indicate that the search space should be increased. + A function is used to indicate that said function is likely to occur in a + simplified expression. + An iterable is used indicate that func(var1 + var2 + ...) is likely to + occur in a simplified . + An additional generator also indicates that it is likely to occur. + (See examples below). + + This routine carries out various computationally intensive algorithms. + The option ``quick=True`` can be used to suppress one particularly slow + step (at the expense of potentially more complicated results, but never at + the expense of increased total degree). + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import sin, tan, cos, sinh, cosh, tanh + >>> from sympy.simplify.trigsimp import trigsimp_groebner + + Suppose you want to simplify ``sin(x)*cos(x)``. Naively, nothing happens: + + >>> ex = sin(x)*cos(x) + >>> trigsimp_groebner(ex) + sin(x)*cos(x) + + This is because ``trigsimp_groebner`` only looks for a simplification + involving just ``sin(x)`` and ``cos(x)``. You can tell it to also try + ``2*x`` by passing ``hints=[2]``: + + >>> trigsimp_groebner(ex, hints=[2]) + sin(2*x)/2 + >>> trigsimp_groebner(sin(x)**2 - cos(x)**2, hints=[2]) + -cos(2*x) + + Increasing the search space this way can quickly become expensive. A much + faster way is to give a specific expression that is likely to occur: + + >>> trigsimp_groebner(ex, hints=[sin(2*x)]) + sin(2*x)/2 + + Hyperbolic expressions are similarly supported: + + >>> trigsimp_groebner(sinh(2*x)/sinh(x)) + 2*cosh(x) + + Note how no hints had to be passed, since the expression already involved + ``2*x``. + + The tangent function is also supported. You can either pass ``tan`` in the + hints, to indicate that tan should be tried whenever cosine or sine are, + or you can pass a specific generator: + + >>> trigsimp_groebner(sin(x)/cos(x), hints=[tan]) + tan(x) + >>> trigsimp_groebner(sinh(x)/cosh(x), hints=[tanh(x)]) + tanh(x) + + Finally, you can use the iterable form to suggest that angle sum formulae + should be tried: + + >>> ex = (tan(x) + tan(y))/(1 - tan(x)*tan(y)) + >>> trigsimp_groebner(ex, hints=[(tan, x, y)]) + tan(x + y) + """ + # TODO + # - preprocess by replacing everything by funcs we can handle + # - optionally use cot instead of tan + # - more intelligent hinting. + # For example, if the ideal is small, and we have sin(x), sin(y), + # add sin(x + y) automatically... ? + # - algebraic numbers ... + # - expressions of lowest degree are not distinguished properly + # e.g. 1 - sin(x)**2 + # - we could try to order the generators intelligently, so as to influence + # which monomials appear in the quotient basis + + # THEORY + # ------ + # Ratsimpmodprime above can be used to "simplify" a rational function + # modulo a prime ideal. "Simplify" mainly means finding an equivalent + # expression of lower total degree. + # + # We intend to use this to simplify trigonometric functions. To do that, + # we need to decide (a) which ring to use, and (b) modulo which ideal to + # simplify. In practice, (a) means settling on a list of "generators" + # a, b, c, ..., such that the fraction we want to simplify is a rational + # function in a, b, c, ..., with coefficients in ZZ (integers). + # (2) means that we have to decide what relations to impose on the + # generators. There are two practical problems: + # (1) The ideal has to be *prime* (a technical term). + # (2) The relations have to be polynomials in the generators. + # + # We typically have two kinds of generators: + # - trigonometric expressions, like sin(x), cos(5*x), etc + # - "everything else", like gamma(x), pi, etc. + # + # Since this function is trigsimp, we will concentrate on what to do with + # trigonometric expressions. We can also simplify hyperbolic expressions, + # but the extensions should be clear. + # + # One crucial point is that all *other* generators really should behave + # like indeterminates. In particular if (say) "I" is one of them, then + # in fact I**2 + 1 = 0 and we may and will compute non-sensical + # expressions. However, we can work with a dummy and add the relation + # I**2 + 1 = 0 to our ideal, then substitute back in the end. + # + # Now regarding trigonometric generators. We split them into groups, + # according to the argument of the trigonometric functions. We want to + # organise this in such a way that most trigonometric identities apply in + # the same group. For example, given sin(x), cos(2*x) and cos(y), we would + # group as [sin(x), cos(2*x)] and [cos(y)]. + # + # Our prime ideal will be built in three steps: + # (1) For each group, compute a "geometrically prime" ideal of relations. + # Geometrically prime means that it generates a prime ideal in + # CC[gens], not just ZZ[gens]. + # (2) Take the union of all the generators of the ideals for all groups. + # By the geometric primality condition, this is still prime. + # (3) Add further inter-group relations which preserve primality. + # + # Step (1) works as follows. We will isolate common factors in the + # argument, so that all our generators are of the form sin(n*x), cos(n*x) + # or tan(n*x), with n an integer. Suppose first there are no tan terms. + # The ideal [sin(x)**2 + cos(x)**2 - 1] is geometrically prime, since + # X**2 + Y**2 - 1 is irreducible over CC. + # Now, if we have a generator sin(n*x), than we can, using trig identities, + # express sin(n*x) as a polynomial in sin(x) and cos(x). We can add this + # relation to the ideal, preserving geometric primality, since the quotient + # ring is unchanged. + # Thus we have treated all sin and cos terms. + # For tan(n*x), we add a relation tan(n*x)*cos(n*x) - sin(n*x) = 0. + # (This requires of course that we already have relations for cos(n*x) and + # sin(n*x).) It is not obvious, but it seems that this preserves geometric + # primality. + # XXX A real proof would be nice. HELP! + # Sketch that is a prime ideal of + # CC[S, C, T]: + # - it suffices to show that the projective closure in CP**3 is + # irreducible + # - using the half-angle substitutions, we can express sin(x), tan(x), + # cos(x) as rational functions in tan(x/2) + # - from this, we get a rational map from CP**1 to our curve + # - this is a morphism, hence the curve is prime + # + # Step (2) is trivial. + # + # Step (3) works by adding selected relations of the form + # sin(x + y) - sin(x)*cos(y) - sin(y)*cos(x), etc. Geometric primality is + # preserved by the same argument as before. + + def parse_hints(hints): + """Split hints into (n, funcs, iterables, gens).""" + n = 1 + funcs, iterables, gens = [], [], [] + for e in hints: + if isinstance(e, (SYMPY_INTS, Integer)): + n = e + elif isinstance(e, FunctionClass): + funcs.append(e) + elif iterable(e): + iterables.append((e[0], e[1:])) + # XXX sin(x+2y)? + # Note: we go through polys so e.g. + # sin(-x) -> -sin(x) -> sin(x) + gens.extend(parallel_poly_from_expr( + [e[0](x) for x in e[1:]] + [e[0](Add(*e[1:]))])[1].gens) + else: + gens.append(e) + return n, funcs, iterables, gens + + def build_ideal(x, terms): + """ + Build generators for our ideal. ``Terms`` is an iterable with elements of + the form (fn, coeff), indicating that we have a generator fn(coeff*x). + + If any of the terms is trigonometric, sin(x) and cos(x) are guaranteed + to appear in terms. Similarly for hyperbolic functions. For tan(n*x), + sin(n*x) and cos(n*x) are guaranteed. + """ + I = [] + y = Dummy('y') + for fn, coeff in terms: + for c, s, t, rel in ( + [cos, sin, tan, cos(x)**2 + sin(x)**2 - 1], + [cosh, sinh, tanh, cosh(x)**2 - sinh(x)**2 - 1]): + if coeff == 1 and fn in [c, s]: + I.append(rel) + elif fn == t: + I.append(t(coeff*x)*c(coeff*x) - s(coeff*x)) + elif fn in [c, s]: + cn = fn(coeff*y).expand(trig=True).subs(y, x) + I.append(fn(coeff*x) - cn) + return list(set(I)) + + def analyse_gens(gens, hints): + """ + Analyse the generators ``gens``, using the hints ``hints``. + + The meaning of ``hints`` is described in the main docstring. + Return a new list of generators, and also the ideal we should + work with. + """ + # First parse the hints + n, funcs, iterables, extragens = parse_hints(hints) + debug('n=%s funcs: %s iterables: %s extragens: %s', + (funcs, iterables, extragens)) + + # We just add the extragens to gens and analyse them as before + gens = list(gens) + gens.extend(extragens) + + # remove duplicates + funcs = list(set(funcs)) + iterables = list(set(iterables)) + gens = list(set(gens)) + + # all the functions we can do anything with + allfuncs = {sin, cos, tan, sinh, cosh, tanh} + # sin(3*x) -> ((3, x), sin) + trigterms = [(g.args[0].as_coeff_mul(), g.func) for g in gens + if g.func in allfuncs] + # Our list of new generators - start with anything that we cannot + # work with (i.e. is not a trigonometric term) + freegens = [g for g in gens if g.func not in allfuncs] + newgens = [] + trigdict = {} + for (coeff, var), fn in trigterms: + trigdict.setdefault(var, []).append((coeff, fn)) + res = [] # the ideal + + for key, val in trigdict.items(): + # We have now assembeled a dictionary. Its keys are common + # arguments in trigonometric expressions, and values are lists of + # pairs (fn, coeff). x0, (fn, coeff) in trigdict means that we + # need to deal with fn(coeff*x0). We take the rational gcd of the + # coeffs, call it ``gcd``. We then use x = x0/gcd as "base symbol", + # all other arguments are integral multiples thereof. + # We will build an ideal which works with sin(x), cos(x). + # If hint tan is provided, also work with tan(x). Moreover, if + # n > 1, also work with sin(k*x) for k <= n, and similarly for cos + # (and tan if the hint is provided). Finally, any generators which + # the ideal does not work with but we need to accommodate (either + # because it was in expr or because it was provided as a hint) + # we also build into the ideal. + # This selection process is expressed in the list ``terms``. + # build_ideal then generates the actual relations in our ideal, + # from this list. + fns = [x[1] for x in val] + val = [x[0] for x in val] + gcd = reduce(igcd, val) + terms = [(fn, v/gcd) for (fn, v) in zip(fns, val)] + fs = set(funcs + fns) + for c, s, t in ([cos, sin, tan], [cosh, sinh, tanh]): + if any(x in fs for x in (c, s, t)): + fs.add(c) + fs.add(s) + for fn in fs: + terms.extend((fn, k) for k in range(1, n + 1)) + extra = [] + for fn, v in terms: + if fn == tan: + extra.append((sin, v)) + extra.append((cos, v)) + if fn in [sin, cos] and tan in fs: + extra.append((tan, v)) + if fn == tanh: + extra.append((sinh, v)) + extra.append((cosh, v)) + if fn in [sinh, cosh] and tanh in fs: + extra.append((tanh, v)) + terms.extend(extra) + x = gcd*Mul(*key) + r = build_ideal(x, terms) + res.extend(r) + newgens.extend({fn(v*x) for fn, v in terms}) + + # Add generators for compound expressions from iterables + for fn, args in iterables: + if fn == tan: + # Tan expressions are recovered from sin and cos. + iterables.extend([(sin, args), (cos, args)]) + elif fn == tanh: + # Tanh expressions are recovered from sihn and cosh. + iterables.extend([(sinh, args), (cosh, args)]) + else: + dummys = symbols('d:%i' % len(args), cls=Dummy) + expr = fn( Add(*dummys)).expand(trig=True).subs(list(zip(dummys, args))) + res.append(fn(Add(*args)) - expr) + + if myI in gens: + res.append(myI**2 + 1) + freegens.remove(myI) + newgens.append(myI) + + return res, freegens, newgens + + myI = Dummy('I') + expr = expr.subs(S.ImaginaryUnit, myI) + subs = [(myI, S.ImaginaryUnit)] + + num, denom = cancel(expr).as_numer_denom() + try: + (pnum, pdenom), opt = parallel_poly_from_expr([num, denom]) + except PolificationFailed: + return expr + debug('initial gens:', opt.gens) + ideal, freegens, gens = analyse_gens(opt.gens, hints) + debug('ideal:', ideal) + debug('new gens:', gens, " -- len", len(gens)) + debug('free gens:', freegens, " -- len", len(gens)) + # NOTE we force the domain to be ZZ to stop polys from injecting generators + # (which is usually a sign of a bug in the way we build the ideal) + if not gens: + return expr + G = groebner(ideal, order=order, gens=gens, domain=ZZ) + debug('groebner basis:', list(G), " -- len", len(G)) + + # If our fraction is a polynomial in the free generators, simplify all + # coefficients separately: + + from sympy.simplify.ratsimp import ratsimpmodprime + + if freegens and pdenom.has_only_gens(*set(gens).intersection(pdenom.gens)): + num = Poly(num, gens=gens+freegens).eject(*gens) + res = [] + for monom, coeff in num.terms(): + ourgens = set(parallel_poly_from_expr([coeff, denom])[1].gens) + # We compute the transitive closure of all generators that can + # be reached from our generators through relations in the ideal. + changed = True + while changed: + changed = False + for p in ideal: + p = Poly(p) + if not ourgens.issuperset(p.gens) and \ + not p.has_only_gens(*set(p.gens).difference(ourgens)): + changed = True + ourgens.update(p.exclude().gens) + # NOTE preserve order! + realgens = [x for x in gens if x in ourgens] + # The generators of the ideal have now been (implicitly) split + # into two groups: those involving ourgens and those that don't. + # Since we took the transitive closure above, these two groups + # live in subgrings generated by a *disjoint* set of variables. + # Any sensible groebner basis algorithm will preserve this disjoint + # structure (i.e. the elements of the groebner basis can be split + # similarly), and and the two subsets of the groebner basis then + # form groebner bases by themselves. (For the smaller generating + # sets, of course.) + ourG = [g.as_expr() for g in G.polys if + g.has_only_gens(*ourgens.intersection(g.gens))] + res.append(Mul(*[a**b for a, b in zip(freegens, monom)]) * \ + ratsimpmodprime(coeff/denom, ourG, order=order, + gens=realgens, quick=quick, domain=ZZ, + polynomial=polynomial).subs(subs)) + return Add(*res) + # NOTE The following is simpler and has less assumptions on the + # groebner basis algorithm. If the above turns out to be broken, + # use this. + return Add(*[Mul(*[a**b for a, b in zip(freegens, monom)]) * \ + ratsimpmodprime(coeff/denom, list(G), order=order, + gens=gens, quick=quick, domain=ZZ) + for monom, coeff in num.terms()]) + else: + return ratsimpmodprime( + expr, list(G), order=order, gens=freegens+gens, + quick=quick, domain=ZZ, polynomial=polynomial).subs(subs) + + +_trigs = (TrigonometricFunction, HyperbolicFunction) + + +def _trigsimp_inverse(rv): + + def check_args(x, y): + try: + return x.args[0] == y.args[0] + except IndexError: + return False + + def f(rv): + # for simple functions + g = getattr(rv, 'inverse', None) + if (g is not None and isinstance(rv.args[0], g()) and + isinstance(g()(1), TrigonometricFunction)): + return rv.args[0].args[0] + + # for atan2 simplifications, harder because atan2 has 2 args + if isinstance(rv, atan2): + y, x = rv.args + if _coeff_isneg(y): + return -f(atan2(-y, x)) + elif _coeff_isneg(x): + return S.Pi - f(atan2(y, -x)) + + if check_args(x, y): + if isinstance(y, sin) and isinstance(x, cos): + return x.args[0] + if isinstance(y, cos) and isinstance(x, sin): + return S.Pi / 2 - x.args[0] + + return rv + + return bottom_up(rv, f) + + +def trigsimp(expr, inverse=False, **opts): + """Returns a reduced expression by using known trig identities. + + Parameters + ========== + + inverse : bool, optional + If ``inverse=True``, it will be assumed that a composition of inverse + functions, such as sin and asin, can be cancelled in any order. + For example, ``asin(sin(x))`` will yield ``x`` without checking whether + x belongs to the set where this relation is true. The default is False. + Default : True + + method : string, optional + Specifies the method to use. Valid choices are: + + - ``'matching'``, default + - ``'groebner'`` + - ``'combined'`` + - ``'fu'`` + - ``'old'`` + + If ``'matching'``, simplify the expression recursively by targeting + common patterns. If ``'groebner'``, apply an experimental groebner + basis algorithm. In this case further options are forwarded to + ``trigsimp_groebner``, please refer to + its docstring. If ``'combined'``, it first runs the groebner basis + algorithm with small default parameters, then runs the ``'matching'`` + algorithm. If ``'fu'``, run the collection of trigonometric + transformations described by Fu, et al. (see the + :py:func:`~sympy.simplify.fu.fu` docstring). If ``'old'``, the original + SymPy trig simplification function is run. + opts : + Optional keyword arguments passed to the method. See each method's + function docstring for details. + + Examples + ======== + + >>> from sympy import trigsimp, sin, cos, log + >>> from sympy.abc import x + >>> e = 2*sin(x)**2 + 2*cos(x)**2 + >>> trigsimp(e) + 2 + + Simplification occurs wherever trigonometric functions are located. + + >>> trigsimp(log(e)) + log(2) + + Using ``method='groebner'`` (or ``method='combined'``) might lead to + greater simplification. + + The old trigsimp routine can be accessed as with method ``method='old'``. + + >>> from sympy import coth, tanh + >>> t = 3*tanh(x)**7 - 2/coth(x)**7 + >>> trigsimp(t, method='old') == t + True + >>> trigsimp(t) + tanh(x)**7 + + """ + from sympy.simplify.fu import fu + + expr = sympify(expr) + + _eval_trigsimp = getattr(expr, '_eval_trigsimp', None) + if _eval_trigsimp is not None: + return _eval_trigsimp(**opts) + + old = opts.pop('old', False) + if not old: + opts.pop('deep', None) + opts.pop('recursive', None) + method = opts.pop('method', 'matching') + else: + method = 'old' + + def groebnersimp(ex, **opts): + def traverse(e): + if e.is_Atom: + return e + args = [traverse(x) for x in e.args] + if e.is_Function or e.is_Pow: + args = [trigsimp_groebner(x, **opts) for x in args] + return e.func(*args) + new = traverse(ex) + if not isinstance(new, Expr): + return new + return trigsimp_groebner(new, **opts) + + trigsimpfunc = { + 'fu': (lambda x: fu(x, **opts)), + 'matching': (lambda x: futrig(x)), + 'groebner': (lambda x: groebnersimp(x, **opts)), + 'combined': (lambda x: futrig(groebnersimp(x, + polynomial=True, hints=[2, tan]))), + 'old': lambda x: trigsimp_old(x, **opts), + }[method] + + expr_simplified = trigsimpfunc(expr) + if inverse: + expr_simplified = _trigsimp_inverse(expr_simplified) + + return expr_simplified + + +def exptrigsimp(expr): + """ + Simplifies exponential / trigonometric / hyperbolic functions. + + Examples + ======== + + >>> from sympy import exptrigsimp, exp, cosh, sinh + >>> from sympy.abc import z + + >>> exptrigsimp(exp(z) + exp(-z)) + 2*cosh(z) + >>> exptrigsimp(cosh(z) - sinh(z)) + exp(-z) + """ + from sympy.simplify.fu import hyper_as_trig, TR2i + + def exp_trig(e): + # select the better of e, and e rewritten in terms of exp or trig + # functions + choices = [e] + if e.has(*_trigs): + choices.append(e.rewrite(exp)) + choices.append(e.rewrite(cos)) + return min(*choices, key=count_ops) + newexpr = bottom_up(expr, exp_trig) + + def f(rv): + if not rv.is_Mul: + return rv + commutative_part, noncommutative_part = rv.args_cnc() + # Since as_powers_dict loses order information, + # if there is more than one noncommutative factor, + # it should only be used to simplify the commutative part. + if (len(noncommutative_part) > 1): + return f(Mul(*commutative_part))*Mul(*noncommutative_part) + rvd = rv.as_powers_dict() + newd = rvd.copy() + + def signlog(expr, sign=S.One): + if expr is S.Exp1: + return sign, S.One + elif isinstance(expr, exp) or (expr.is_Pow and expr.base == S.Exp1): + return sign, expr.exp + elif sign is S.One: + return signlog(-expr, sign=-S.One) + else: + return None, None + + ee = rvd[S.Exp1] + for k in rvd: + if k.is_Add and len(k.args) == 2: + # k == c*(1 + sign*E**x) + c = k.args[0] + sign, x = signlog(k.args[1]/c) + if not x: + continue + m = rvd[k] + newd[k] -= m + if ee == -x*m/2: + # sinh and cosh + newd[S.Exp1] -= ee + ee = 0 + if sign == 1: + newd[2*c*cosh(x/2)] += m + else: + newd[-2*c*sinh(x/2)] += m + elif newd[1 - sign*S.Exp1**x] == -m: + # tanh + del newd[1 - sign*S.Exp1**x] + if sign == 1: + newd[-c/tanh(x/2)] += m + else: + newd[-c*tanh(x/2)] += m + else: + newd[1 + sign*S.Exp1**x] += m + newd[c] += m + + return Mul(*[k**newd[k] for k in newd]) + newexpr = bottom_up(newexpr, f) + + # sin/cos and sinh/cosh ratios to tan and tanh, respectively + if newexpr.has(HyperbolicFunction): + e, f = hyper_as_trig(newexpr) + newexpr = f(TR2i(e)) + if newexpr.has(TrigonometricFunction): + newexpr = TR2i(newexpr) + + # can we ever generate an I where there was none previously? + if not (newexpr.has(I) and not expr.has(I)): + expr = newexpr + return expr + +#-------------------- the old trigsimp routines --------------------- + +def trigsimp_old(expr, *, first=True, **opts): + """ + Reduces expression by using known trig identities. + + Notes + ===== + + deep: + - Apply trigsimp inside all objects with arguments + + recursive: + - Use common subexpression elimination (cse()) and apply + trigsimp recursively (this is quite expensive if the + expression is large) + + method: + - Determine the method to use. Valid choices are 'matching' (default), + 'groebner', 'combined', 'fu' and 'futrig'. If 'matching', simplify the + expression recursively by pattern matching. If 'groebner', apply an + experimental groebner basis algorithm. In this case further options + are forwarded to ``trigsimp_groebner``, please refer to its docstring. + If 'combined', first run the groebner basis algorithm with small + default parameters, then run the 'matching' algorithm. 'fu' runs the + collection of trigonometric transformations described by Fu, et al. + (see the `fu` docstring) while `futrig` runs a subset of Fu-transforms + that mimic the behavior of `trigsimp`. + + compare: + - show input and output from `trigsimp` and `futrig` when different, + but returns the `trigsimp` value. + + Examples + ======== + + >>> from sympy import trigsimp, sin, cos, log, cot + >>> from sympy.abc import x + >>> e = 2*sin(x)**2 + 2*cos(x)**2 + >>> trigsimp(e, old=True) + 2 + >>> trigsimp(log(e), old=True) + log(2*sin(x)**2 + 2*cos(x)**2) + >>> trigsimp(log(e), deep=True, old=True) + log(2) + + Using `method="groebner"` (or `"combined"`) can sometimes lead to a lot + more simplification: + + >>> e = (-sin(x) + 1)/cos(x) + cos(x)/(-sin(x) + 1) + >>> trigsimp(e, old=True) + (1 - sin(x))/cos(x) + cos(x)/(1 - sin(x)) + >>> trigsimp(e, method="groebner", old=True) + 2/cos(x) + + >>> trigsimp(1/cot(x)**2, compare=True, old=True) + futrig: tan(x)**2 + cot(x)**(-2) + + """ + old = expr + if first: + if not expr.has(*_trigs): + return expr + + trigsyms = set().union(*[t.free_symbols for t in expr.atoms(*_trigs)]) + if len(trigsyms) > 1: + from sympy.simplify.simplify import separatevars + + d = separatevars(expr) + if d.is_Mul: + d = separatevars(d, dict=True) or d + if isinstance(d, dict): + expr = 1 + for v in d.values(): + # remove hollow factoring + was = v + v = expand_mul(v) + opts['first'] = False + vnew = trigsimp(v, **opts) + if vnew == v: + vnew = was + expr *= vnew + old = expr + else: + if d.is_Add: + for s in trigsyms: + r, e = expr.as_independent(s) + if r: + opts['first'] = False + expr = r + trigsimp(e, **opts) + if not expr.is_Add: + break + old = expr + + recursive = opts.pop('recursive', False) + deep = opts.pop('deep', False) + method = opts.pop('method', 'matching') + + def groebnersimp(ex, deep, **opts): + def traverse(e): + if e.is_Atom: + return e + args = [traverse(x) for x in e.args] + if e.is_Function or e.is_Pow: + args = [trigsimp_groebner(x, **opts) for x in args] + return e.func(*args) + if deep: + ex = traverse(ex) + return trigsimp_groebner(ex, **opts) + + trigsimpfunc = { + 'matching': (lambda x, d: _trigsimp(x, d)), + 'groebner': (lambda x, d: groebnersimp(x, d, **opts)), + 'combined': (lambda x, d: _trigsimp(groebnersimp(x, + d, polynomial=True, hints=[2, tan]), + d)) + }[method] + + if recursive: + w, g = cse(expr) + g = trigsimpfunc(g[0], deep) + + for sub in reversed(w): + g = g.subs(sub[0], sub[1]) + g = trigsimpfunc(g, deep) + result = g + else: + result = trigsimpfunc(expr, deep) + + if opts.get('compare', False): + f = futrig(old) + if f != result: + print('\tfutrig:', f) + + return result + + +def _dotrig(a, b): + """Helper to tell whether ``a`` and ``b`` have the same sorts + of symbols in them -- no need to test hyperbolic patterns against + expressions that have no hyperbolics in them.""" + return a.func == b.func and ( + a.has(TrigonometricFunction) and b.has(TrigonometricFunction) or + a.has(HyperbolicFunction) and b.has(HyperbolicFunction)) + + +_trigpat = None +def _trigpats(): + global _trigpat + a, b, c = symbols('a b c', cls=Wild) + d = Wild('d', commutative=False) + + # for the simplifications like sinh/cosh -> tanh: + # DO NOT REORDER THE FIRST 14 since these are assumed to be in this + # order in _match_div_rewrite. + matchers_division = ( + (a*sin(b)**c/cos(b)**c, a*tan(b)**c, sin(b), cos(b)), + (a*tan(b)**c*cos(b)**c, a*sin(b)**c, sin(b), cos(b)), + (a*cot(b)**c*sin(b)**c, a*cos(b)**c, sin(b), cos(b)), + (a*tan(b)**c/sin(b)**c, a/cos(b)**c, sin(b), cos(b)), + (a*cot(b)**c/cos(b)**c, a/sin(b)**c, sin(b), cos(b)), + (a*cot(b)**c*tan(b)**c, a, sin(b), cos(b)), + (a*(cos(b) + 1)**c*(cos(b) - 1)**c, + a*(-sin(b)**2)**c, cos(b) + 1, cos(b) - 1), + (a*(sin(b) + 1)**c*(sin(b) - 1)**c, + a*(-cos(b)**2)**c, sin(b) + 1, sin(b) - 1), + + (a*sinh(b)**c/cosh(b)**c, a*tanh(b)**c, S.One, S.One), + (a*tanh(b)**c*cosh(b)**c, a*sinh(b)**c, S.One, S.One), + (a*coth(b)**c*sinh(b)**c, a*cosh(b)**c, S.One, S.One), + (a*tanh(b)**c/sinh(b)**c, a/cosh(b)**c, S.One, S.One), + (a*coth(b)**c/cosh(b)**c, a/sinh(b)**c, S.One, S.One), + (a*coth(b)**c*tanh(b)**c, a, S.One, S.One), + + (c*(tanh(a) + tanh(b))/(1 + tanh(a)*tanh(b)), + tanh(a + b)*c, S.One, S.One), + ) + + matchers_add = ( + (c*sin(a)*cos(b) + c*cos(a)*sin(b) + d, sin(a + b)*c + d), + (c*cos(a)*cos(b) - c*sin(a)*sin(b) + d, cos(a + b)*c + d), + (c*sin(a)*cos(b) - c*cos(a)*sin(b) + d, sin(a - b)*c + d), + (c*cos(a)*cos(b) + c*sin(a)*sin(b) + d, cos(a - b)*c + d), + (c*sinh(a)*cosh(b) + c*sinh(b)*cosh(a) + d, sinh(a + b)*c + d), + (c*cosh(a)*cosh(b) + c*sinh(a)*sinh(b) + d, cosh(a + b)*c + d), + ) + + # for cos(x)**2 + sin(x)**2 -> 1 + matchers_identity = ( + (a*sin(b)**2, a - a*cos(b)**2), + (a*tan(b)**2, a*(1/cos(b))**2 - a), + (a*cot(b)**2, a*(1/sin(b))**2 - a), + (a*sin(b + c), a*(sin(b)*cos(c) + sin(c)*cos(b))), + (a*cos(b + c), a*(cos(b)*cos(c) - sin(b)*sin(c))), + (a*tan(b + c), a*((tan(b) + tan(c))/(1 - tan(b)*tan(c)))), + + (a*sinh(b)**2, a*cosh(b)**2 - a), + (a*tanh(b)**2, a - a*(1/cosh(b))**2), + (a*coth(b)**2, a + a*(1/sinh(b))**2), + (a*sinh(b + c), a*(sinh(b)*cosh(c) + sinh(c)*cosh(b))), + (a*cosh(b + c), a*(cosh(b)*cosh(c) + sinh(b)*sinh(c))), + (a*tanh(b + c), a*((tanh(b) + tanh(c))/(1 + tanh(b)*tanh(c)))), + + ) + + # Reduce any lingering artifacts, such as sin(x)**2 changing + # to 1-cos(x)**2 when sin(x)**2 was "simpler" + artifacts = ( + (a - a*cos(b)**2 + c, a*sin(b)**2 + c, cos), + (a - a*(1/cos(b))**2 + c, -a*tan(b)**2 + c, cos), + (a - a*(1/sin(b))**2 + c, -a*cot(b)**2 + c, sin), + + (a - a*cosh(b)**2 + c, -a*sinh(b)**2 + c, cosh), + (a - a*(1/cosh(b))**2 + c, a*tanh(b)**2 + c, cosh), + (a + a*(1/sinh(b))**2 + c, a*coth(b)**2 + c, sinh), + + # same as above but with noncommutative prefactor + (a*d - a*d*cos(b)**2 + c, a*d*sin(b)**2 + c, cos), + (a*d - a*d*(1/cos(b))**2 + c, -a*d*tan(b)**2 + c, cos), + (a*d - a*d*(1/sin(b))**2 + c, -a*d*cot(b)**2 + c, sin), + + (a*d - a*d*cosh(b)**2 + c, -a*d*sinh(b)**2 + c, cosh), + (a*d - a*d*(1/cosh(b))**2 + c, a*d*tanh(b)**2 + c, cosh), + (a*d + a*d*(1/sinh(b))**2 + c, a*d*coth(b)**2 + c, sinh), + ) + + _trigpat = (a, b, c, d, matchers_division, matchers_add, + matchers_identity, artifacts) + return _trigpat + + +def _replace_mul_fpowxgpow(expr, f, g, rexp, h, rexph): + """Helper for _match_div_rewrite. + + Replace f(b_)**c_*g(b_)**(rexp(c_)) with h(b)**rexph(c) if f(b_) + and g(b_) are both positive or if c_ is an integer. + """ + # assert expr.is_Mul and expr.is_commutative and f != g + fargs = defaultdict(int) + gargs = defaultdict(int) + args = [] + for x in expr.args: + if x.is_Pow or x.func in (f, g): + b, e = x.as_base_exp() + if b.is_positive or e.is_integer: + if b.func == f: + fargs[b.args[0]] += e + continue + elif b.func == g: + gargs[b.args[0]] += e + continue + args.append(x) + common = set(fargs) & set(gargs) + hit = False + while common: + key = common.pop() + fe = fargs.pop(key) + ge = gargs.pop(key) + if fe == rexp(ge): + args.append(h(key)**rexph(fe)) + hit = True + else: + fargs[key] = fe + gargs[key] = ge + if not hit: + return expr + while fargs: + key, e = fargs.popitem() + args.append(f(key)**e) + while gargs: + key, e = gargs.popitem() + args.append(g(key)**e) + return Mul(*args) + + +_idn = lambda x: x +_midn = lambda x: -x +_one = lambda x: S.One + +def _match_div_rewrite(expr, i): + """helper for __trigsimp""" + if i == 0: + expr = _replace_mul_fpowxgpow(expr, sin, cos, + _midn, tan, _idn) + elif i == 1: + expr = _replace_mul_fpowxgpow(expr, tan, cos, + _idn, sin, _idn) + elif i == 2: + expr = _replace_mul_fpowxgpow(expr, cot, sin, + _idn, cos, _idn) + elif i == 3: + expr = _replace_mul_fpowxgpow(expr, tan, sin, + _midn, cos, _midn) + elif i == 4: + expr = _replace_mul_fpowxgpow(expr, cot, cos, + _midn, sin, _midn) + elif i == 5: + expr = _replace_mul_fpowxgpow(expr, cot, tan, + _idn, _one, _idn) + # i in (6, 7) is skipped + elif i == 8: + expr = _replace_mul_fpowxgpow(expr, sinh, cosh, + _midn, tanh, _idn) + elif i == 9: + expr = _replace_mul_fpowxgpow(expr, tanh, cosh, + _idn, sinh, _idn) + elif i == 10: + expr = _replace_mul_fpowxgpow(expr, coth, sinh, + _idn, cosh, _idn) + elif i == 11: + expr = _replace_mul_fpowxgpow(expr, tanh, sinh, + _midn, cosh, _midn) + elif i == 12: + expr = _replace_mul_fpowxgpow(expr, coth, cosh, + _midn, sinh, _midn) + elif i == 13: + expr = _replace_mul_fpowxgpow(expr, coth, tanh, + _idn, _one, _idn) + else: + return None + return expr + + +def _trigsimp(expr, deep=False): + # protect the cache from non-trig patterns; we only allow + # trig patterns to enter the cache + if expr.has(*_trigs): + return __trigsimp(expr, deep) + return expr + + +@cacheit +def __trigsimp(expr, deep=False): + """recursive helper for trigsimp""" + from sympy.simplify.fu import TR10i + + if _trigpat is None: + _trigpats() + a, b, c, d, matchers_division, matchers_add, \ + matchers_identity, artifacts = _trigpat + + if expr.is_Mul: + # do some simplifications like sin/cos -> tan: + if not expr.is_commutative: + com, nc = expr.args_cnc() + expr = _trigsimp(Mul._from_args(com), deep)*Mul._from_args(nc) + else: + for i, (pattern, simp, ok1, ok2) in enumerate(matchers_division): + if not _dotrig(expr, pattern): + continue + + newexpr = _match_div_rewrite(expr, i) + if newexpr is not None: + if newexpr != expr: + expr = newexpr + break + else: + continue + + # use SymPy matching instead + res = expr.match(pattern) + if res and res.get(c, 0): + if not res[c].is_integer: + ok = ok1.subs(res) + if not ok.is_positive: + continue + ok = ok2.subs(res) + if not ok.is_positive: + continue + # if "a" contains any of trig or hyperbolic funcs with + # argument "b" then skip the simplification + if any(w.args[0] == res[b] for w in res[a].atoms( + TrigonometricFunction, HyperbolicFunction)): + continue + # simplify and finish: + expr = simp.subs(res) + break # process below + + if expr.is_Add: + args = [] + for term in expr.args: + if not term.is_commutative: + com, nc = term.args_cnc() + nc = Mul._from_args(nc) + term = Mul._from_args(com) + else: + nc = S.One + term = _trigsimp(term, deep) + for pattern, result in matchers_identity: + res = term.match(pattern) + if res is not None: + term = result.subs(res) + break + args.append(term*nc) + if args != expr.args: + expr = Add(*args) + expr = min(expr, expand(expr), key=count_ops) + if expr.is_Add: + for pattern, result in matchers_add: + if not _dotrig(expr, pattern): + continue + expr = TR10i(expr) + if expr.has(HyperbolicFunction): + res = expr.match(pattern) + # if "d" contains any trig or hyperbolic funcs with + # argument "a" or "b" then skip the simplification; + # this isn't perfect -- see tests + if res is None or not (a in res and b in res) or any( + w.args[0] in (res[a], res[b]) for w in res[d].atoms( + TrigonometricFunction, HyperbolicFunction)): + continue + expr = result.subs(res) + break + + # Reduce any lingering artifacts, such as sin(x)**2 changing + # to 1 - cos(x)**2 when sin(x)**2 was "simpler" + for pattern, result, ex in artifacts: + if not _dotrig(expr, pattern): + continue + # Substitute a new wild that excludes some function(s) + # to help influence a better match. This is because + # sometimes, for example, 'a' would match sec(x)**2 + a_t = Wild('a', exclude=[ex]) + pattern = pattern.subs(a, a_t) + result = result.subs(a, a_t) + + m = expr.match(pattern) + was = None + while m and was != expr: + was = expr + if m[a_t] == 0 or \ + -m[a_t] in m[c].args or m[a_t] + m[c] == 0: + break + if d in m and m[a_t]*m[d] + m[c] == 0: + break + expr = result.subs(m) + m = expr.match(pattern) + m.setdefault(c, S.Zero) + + elif expr.is_Mul or expr.is_Pow or deep and expr.args: + expr = expr.func(*[_trigsimp(a, deep) for a in expr.args]) + + try: + if not expr.has(*_trigs): + raise TypeError + e = expr.atoms(exp) + new = expr.rewrite(exp, deep=deep) + if new == e: + raise TypeError + fnew = factor(new) + if fnew != new: + new = min([new, factor(new)], key=count_ops) + # if all exp that were introduced disappeared then accept it + if not (new.atoms(exp) - e): + expr = new + except TypeError: + pass + + return expr +#------------------- end of old trigsimp routines -------------------- + + +def futrig(e, *, hyper=True, **kwargs): + """Return simplified ``e`` using Fu-like transformations. + This is not the "Fu" algorithm. This is called by default + from ``trigsimp``. By default, hyperbolics subexpressions + will be simplified, but this can be disabled by setting + ``hyper=False``. + + Examples + ======== + + >>> from sympy import trigsimp, tan, sinh, tanh + >>> from sympy.simplify.trigsimp import futrig + >>> from sympy.abc import x + >>> trigsimp(1/tan(x)**2) + tan(x)**(-2) + + >>> futrig(sinh(x)/tanh(x)) + cosh(x) + + """ + from sympy.simplify.fu import hyper_as_trig + + e = sympify(e) + + if not isinstance(e, Basic): + return e + + if not e.args: + return e + + old = e + e = bottom_up(e, _futrig) + + if hyper and e.has(HyperbolicFunction): + e, f = hyper_as_trig(e) + e = f(bottom_up(e, _futrig)) + + if e != old and e.is_Mul and e.args[0].is_Rational: + # redistribute leading coeff on 2-arg Add + e = Mul(*e.as_coeff_Mul()) + return e + + +def _futrig(e): + """Helper for futrig.""" + from sympy.simplify.fu import ( + TR1, TR2, TR3, TR2i, TR10, L, TR10i, + TR8, TR6, TR15, TR16, TR111, TR5, TRmorrie, TR11, _TR11, TR14, TR22, + TR12) + + if not e.has(TrigonometricFunction): + return e + + if e.is_Mul: + coeff, e = e.as_independent(TrigonometricFunction) + else: + coeff = None + + Lops = lambda x: (L(x), x.count_ops(), _nodes(x), len(x.args), x.is_Add) + trigs = lambda x: x.has(TrigonometricFunction) + + tree = [identity, + ( + TR3, # canonical angles + TR1, # sec-csc -> cos-sin + TR12, # expand tan of sum + lambda x: _eapply(factor, x, trigs), + TR2, # tan-cot -> sin-cos + [identity, lambda x: _eapply(_mexpand, x, trigs)], + TR2i, # sin-cos ratio -> tan + lambda x: _eapply(lambda i: factor(i.normal()), x, trigs), + TR14, # factored identities + TR5, # sin-pow -> cos_pow + TR10, # sin-cos of sums -> sin-cos prod + TR11, _TR11, TR6, # reduce double angles and rewrite cos pows + lambda x: _eapply(factor, x, trigs), + TR14, # factored powers of identities + [identity, lambda x: _eapply(_mexpand, x, trigs)], + TR10i, # sin-cos products > sin-cos of sums + TRmorrie, + [identity, TR8], # sin-cos products -> sin-cos of sums + [identity, lambda x: TR2i(TR2(x))], # tan -> sin-cos -> tan + [ + lambda x: _eapply(expand_mul, TR5(x), trigs), + lambda x: _eapply( + expand_mul, TR15(x), trigs)], # pos/neg powers of sin + [ + lambda x: _eapply(expand_mul, TR6(x), trigs), + lambda x: _eapply( + expand_mul, TR16(x), trigs)], # pos/neg powers of cos + TR111, # tan, sin, cos to neg power -> cot, csc, sec + [identity, TR2i], # sin-cos ratio to tan + [identity, lambda x: _eapply( + expand_mul, TR22(x), trigs)], # tan-cot to sec-csc + TR1, TR2, TR2i, + [identity, lambda x: _eapply( + factor_terms, TR12(x), trigs)], # expand tan of sum + )] + e = greedy(tree, objective=Lops)(e) + + if coeff is not None: + e = coeff * e + + return e + + +def _is_Expr(e): + """_eapply helper to tell whether ``e`` and all its args + are Exprs.""" + if isinstance(e, Derivative): + return _is_Expr(e.expr) + if not isinstance(e, Expr): + return False + return all(_is_Expr(i) for i in e.args) + + +def _eapply(func, e, cond=None): + """Apply ``func`` to ``e`` if all args are Exprs else only + apply it to those args that *are* Exprs.""" + if not isinstance(e, Expr): + return e + if _is_Expr(e) or not e.args: + return func(e) + return e.func(*[ + _eapply(func, ei) if (cond is None or cond(ei)) else ei + for ei in e.args]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a832614b1d48e26bf01e16f040f34dd412e8e32b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__init__.py @@ -0,0 +1,23 @@ +"""A module to manipulate symbolic objects with indices including tensors + +""" +from .indexed import IndexedBase, Idx, Indexed +from .index_methods import get_contraction_structure, get_indices +from .functions import shape +from .array import (MutableDenseNDimArray, ImmutableDenseNDimArray, + MutableSparseNDimArray, ImmutableSparseNDimArray, NDimArray, tensorproduct, + tensorcontraction, tensordiagonal, derive_by_array, permutedims, Array, + DenseNDimArray, SparseNDimArray,) + +__all__ = [ + 'IndexedBase', 'Idx', 'Indexed', + + 'get_contraction_structure', 'get_indices', + + 'shape', + + 'MutableDenseNDimArray', 'ImmutableDenseNDimArray', + 'MutableSparseNDimArray', 'ImmutableSparseNDimArray', 'NDimArray', + 'tensorproduct', 'tensorcontraction', 'tensordiagonal', 'derive_by_array', 'permutedims', + 'Array', 'DenseNDimArray', 'SparseNDimArray', +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21d25f5aca6fba2d68ffe230f9901e4fa2e8ac79 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd2b990aecf240d723186fcd28fb8ba54ca9f3d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/index_methods.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/index_methods.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2c0dd1791ad35198c819345f81909db834e4b29 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/index_methods.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/indexed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/indexed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac9ad0fb848af4f1c8a0c7666502b6cf961197c3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/indexed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/toperators.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/toperators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a80b48e14a1bd52a249686888d90683aa1546082 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/__pycache__/toperators.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eca2eb4c6c58cb113517b6e41737e9d97abbb84e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__init__.py @@ -0,0 +1,271 @@ +r""" +N-dim array module for SymPy. + +Four classes are provided to handle N-dim arrays, given by the combinations +dense/sparse (i.e. whether to store all elements or only the non-zero ones in +memory) and mutable/immutable (immutable classes are SymPy objects, but cannot +change after they have been created). + +Examples +======== + +The following examples show the usage of ``Array``. This is an abbreviation for +``ImmutableDenseNDimArray``, that is an immutable and dense N-dim array, the +other classes are analogous. For mutable classes it is also possible to change +element values after the object has been constructed. + +Array construction can detect the shape of nested lists and tuples: + +>>> from sympy import Array +>>> a1 = Array([[1, 2], [3, 4], [5, 6]]) +>>> a1 +[[1, 2], [3, 4], [5, 6]] +>>> a1.shape +(3, 2) +>>> a1.rank() +2 +>>> from sympy.abc import x, y, z +>>> a2 = Array([[[x, y], [z, x*z]], [[1, x*y], [1/x, x/y]]]) +>>> a2 +[[[x, y], [z, x*z]], [[1, x*y], [1/x, x/y]]] +>>> a2.shape +(2, 2, 2) +>>> a2.rank() +3 + +Otherwise one could pass a 1-dim array followed by a shape tuple: + +>>> m1 = Array(range(12), (3, 4)) +>>> m1 +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] +>>> m2 = Array(range(12), (3, 2, 2)) +>>> m2 +[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]] +>>> m2[1,1,1] +7 +>>> m2.reshape(4, 3) +[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + +Slice support: + +>>> m2[:, 1, 1] +[3, 7, 11] + +Elementwise derivative: + +>>> from sympy.abc import x, y, z +>>> m3 = Array([x**3, x*y, z]) +>>> m3.diff(x) +[3*x**2, y, 0] +>>> m3.diff(z) +[0, 0, 1] + +Multiplication with other SymPy expressions is applied elementwisely: + +>>> (1+x)*m3 +[x**3*(x + 1), x*y*(x + 1), z*(x + 1)] + +To apply a function to each element of the N-dim array, use ``applyfunc``: + +>>> m3.applyfunc(lambda x: x/2) +[x**3/2, x*y/2, z/2] + +N-dim arrays can be converted to nested lists by the ``tolist()`` method: + +>>> m2.tolist() +[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]] +>>> isinstance(m2.tolist(), list) +True + +If the rank is 2, it is possible to convert them to matrices with ``tomatrix()``: + +>>> m1.tomatrix() +Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7], +[8, 9, 10, 11]]) + +Products and contractions +------------------------- + +Tensor product between arrays `A_{i_1,\ldots,i_n}` and `B_{j_1,\ldots,j_m}` +creates the combined array `P = A \otimes B` defined as + +`P_{i_1,\ldots,i_n,j_1,\ldots,j_m} := A_{i_1,\ldots,i_n}\cdot B_{j_1,\ldots,j_m}.` + +It is available through ``tensorproduct(...)``: + +>>> from sympy import Array, tensorproduct +>>> from sympy.abc import x,y,z,t +>>> A = Array([x, y, z, t]) +>>> B = Array([1, 2, 3, 4]) +>>> tensorproduct(A, B) +[[x, 2*x, 3*x, 4*x], [y, 2*y, 3*y, 4*y], [z, 2*z, 3*z, 4*z], [t, 2*t, 3*t, 4*t]] + +In case you don't want to evaluate the tensor product immediately, you can use +``ArrayTensorProduct``, which creates an unevaluated tensor product expression: + +>>> from sympy.tensor.array.expressions import ArrayTensorProduct +>>> ArrayTensorProduct(A, B) +ArrayTensorProduct([x, y, z, t], [1, 2, 3, 4]) + +Calling ``.as_explicit()`` on ``ArrayTensorProduct`` is equivalent to just calling +``tensorproduct(...)``: + +>>> ArrayTensorProduct(A, B).as_explicit() +[[x, 2*x, 3*x, 4*x], [y, 2*y, 3*y, 4*y], [z, 2*z, 3*z, 4*z], [t, 2*t, 3*t, 4*t]] + +Tensor product between a rank-1 array and a matrix creates a rank-3 array: + +>>> from sympy import eye +>>> p1 = tensorproduct(A, eye(4)) +>>> p1 +[[[x, 0, 0, 0], [0, x, 0, 0], [0, 0, x, 0], [0, 0, 0, x]], [[y, 0, 0, 0], [0, y, 0, 0], [0, 0, y, 0], [0, 0, 0, y]], [[z, 0, 0, 0], [0, z, 0, 0], [0, 0, z, 0], [0, 0, 0, z]], [[t, 0, 0, 0], [0, t, 0, 0], [0, 0, t, 0], [0, 0, 0, t]]] + +Now, to get back `A_0 \otimes \mathbf{1}` one can access `p_{0,m,n}` by slicing: + +>>> p1[0,:,:] +[[x, 0, 0, 0], [0, x, 0, 0], [0, 0, x, 0], [0, 0, 0, x]] + +Tensor contraction sums over the specified axes, for example contracting +positions `a` and `b` means + +`A_{i_1,\ldots,i_a,\ldots,i_b,\ldots,i_n} \implies \sum_k A_{i_1,\ldots,k,\ldots,k,\ldots,i_n}` + +Remember that Python indexing is zero starting, to contract the a-th and b-th +axes it is therefore necessary to specify `a-1` and `b-1` + +>>> from sympy import tensorcontraction +>>> C = Array([[x, y], [z, t]]) + +The matrix trace is equivalent to the contraction of a rank-2 array: + +`A_{m,n} \implies \sum_k A_{k,k}` + +>>> tensorcontraction(C, (0, 1)) +t + x + +To create an expression representing a tensor contraction that does not get +evaluated immediately, use ``ArrayContraction``, which is equivalent to +``tensorcontraction(...)`` if it is followed by ``.as_explicit()``: + +>>> from sympy.tensor.array.expressions import ArrayContraction +>>> ArrayContraction(C, (0, 1)) +ArrayContraction([[x, y], [z, t]], (0, 1)) +>>> ArrayContraction(C, (0, 1)).as_explicit() +t + x + +Matrix product is equivalent to a tensor product of two rank-2 arrays, followed +by a contraction of the 2nd and 3rd axes (in Python indexing axes number 1, 2). + +`A_{m,n}\cdot B_{i,j} \implies \sum_k A_{m, k}\cdot B_{k, j}` + +>>> D = Array([[2, 1], [0, -1]]) +>>> tensorcontraction(tensorproduct(C, D), (1, 2)) +[[2*x, x - y], [2*z, -t + z]] + +One may verify that the matrix product is equivalent: + +>>> from sympy import Matrix +>>> Matrix([[x, y], [z, t]])*Matrix([[2, 1], [0, -1]]) +Matrix([ +[2*x, x - y], +[2*z, -t + z]]) + +or equivalently + +>>> C.tomatrix()*D.tomatrix() +Matrix([ +[2*x, x - y], +[2*z, -t + z]]) + +Diagonal operator +----------------- + +The ``tensordiagonal`` function acts in a similar manner as ``tensorcontraction``, +but the joined indices are not summed over, for example diagonalizing +positions `a` and `b` means + +`A_{i_1,\ldots,i_a,\ldots,i_b,\ldots,i_n} \implies A_{i_1,\ldots,k,\ldots,k,\ldots,i_n} +\implies \tilde{A}_{i_1,\ldots,i_{a-1},i_{a+1},\ldots,i_{b-1},i_{b+1},\ldots,i_n,k}` + +where `\tilde{A}` is the array equivalent to the diagonal of `A` at positions +`a` and `b` moved to the last index slot. + +Compare the difference between contraction and diagonal operators: + +>>> from sympy import tensordiagonal +>>> from sympy.abc import a, b, c, d +>>> m = Matrix([[a, b], [c, d]]) +>>> tensorcontraction(m, [0, 1]) +a + d +>>> tensordiagonal(m, [0, 1]) +[a, d] + +In short, no summation occurs with ``tensordiagonal``. + + +Derivatives by array +-------------------- + +The usual derivative operation may be extended to support derivation with +respect to arrays, provided that all elements in the that array are symbols or +expressions suitable for derivations. + +The definition of a derivative by an array is as follows: given the array +`A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}` +the derivative of arrays will return a new array `B` defined by + +`B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}` + +The function ``derive_by_array`` performs such an operation: + +>>> from sympy import derive_by_array +>>> from sympy.abc import x, y, z, t +>>> from sympy import sin, exp + +With scalars, it behaves exactly as the ordinary derivative: + +>>> derive_by_array(sin(x*y), x) +y*cos(x*y) + +Scalar derived by an array basis: + +>>> derive_by_array(sin(x*y), [x, y, z]) +[y*cos(x*y), x*cos(x*y), 0] + +Deriving array by an array basis: `B^{nm} := \frac{\partial A^m}{\partial x^n}` + +>>> basis = [x, y, z] +>>> ax = derive_by_array([exp(x), sin(y*z), t], basis) +>>> ax +[[exp(x), 0, 0], [0, z*cos(y*z), 0], [0, y*cos(y*z), 0]] + +Contraction of the resulting array: `\sum_m \frac{\partial A^m}{\partial x^m}` + +>>> tensorcontraction(ax, (0, 1)) +z*cos(y*z) + exp(x) + +""" + +from .dense_ndim_array import MutableDenseNDimArray, ImmutableDenseNDimArray, DenseNDimArray +from .sparse_ndim_array import MutableSparseNDimArray, ImmutableSparseNDimArray, SparseNDimArray +from .ndim_array import NDimArray, ArrayKind +from .arrayop import tensorproduct, tensorcontraction, tensordiagonal, derive_by_array, permutedims +from .array_comprehension import ArrayComprehension, ArrayComprehensionMap + +Array = ImmutableDenseNDimArray + +__all__ = [ + 'MutableDenseNDimArray', 'ImmutableDenseNDimArray', 'DenseNDimArray', + + 'MutableSparseNDimArray', 'ImmutableSparseNDimArray', 'SparseNDimArray', + + 'NDimArray', 'ArrayKind', + + 'tensorproduct', 'tensorcontraction', 'tensordiagonal', 'derive_by_array', + + 'permutedims', 'ArrayComprehension', 'ArrayComprehensionMap', + + 'Array', +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03a320f5c9bc52939a1c8064bab2f4186f9cf812 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/array_comprehension.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/array_comprehension.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6f29836ade33dece18d1274d4e1d40993a99ef4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/array_comprehension.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/array_derivatives.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/array_derivatives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..966f0d5e2c47378a8d73ac1f2408085b69af491b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/array_derivatives.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/arrayop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/arrayop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..154276ba2eba95a82be9b184675f358024889003 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/arrayop.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/dense_ndim_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/dense_ndim_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..243b4ab1b4c0e3d55d90df6bce725bbcc9aab10e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/dense_ndim_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ec7dc99fa9f00d4f1d081857cbbac35cb121db0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/ndim_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/ndim_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93acace481b6ce627fc5433f1b126f19b14f30ff Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/ndim_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001a6650122de0b768ab1c03eb319c00f4b2c3f5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/array_comprehension.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/array_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..95702f499f3e40597fd0144929138ac1329962ee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/array_comprehension.py @@ -0,0 +1,399 @@ +import functools, itertools +from sympy.core.sympify import _sympify, sympify +from sympy.core.expr import Expr +from sympy.core import Basic, Tuple +from sympy.tensor.array import ImmutableDenseNDimArray +from sympy.core.symbol import Symbol +from sympy.core.numbers import Integer + + +class ArrayComprehension(Basic): + """ + Generate a list comprehension. + + Explanation + =========== + + If there is a symbolic dimension, for example, say [i for i in range(1, N)] where + N is a Symbol, then the expression will not be expanded to an array. Otherwise, + calling the doit() function will launch the expansion. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a + ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.doit() + [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]] + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k)) + >>> b.doit() + ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k)) + """ + def __new__(cls, function, *symbols, **assumptions): + if any(len(l) != 3 or None for l in symbols): + raise ValueError('ArrayComprehension requires values lower and upper bound' + ' for the expression') + arglist = [sympify(function)] + arglist.extend(cls._check_limits_validity(function, symbols)) + obj = Basic.__new__(cls, *arglist, **assumptions) + obj._limits = obj._args[1:] + obj._shape = cls._calculate_shape_from_limits(obj._limits) + obj._rank = len(obj._shape) + obj._loop_size = cls._calculate_loop_size(obj._shape) + return obj + + @property + def function(self): + """The function applied across limits. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.function + 10*i + j + """ + return self._args[0] + + @property + def limits(self): + """ + The list of limits that will be applied while expanding the array. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.limits + ((i, 1, 4), (j, 1, 3)) + """ + return self._limits + + @property + def free_symbols(self): + """ + The set of the free_symbols in the array. + Variables appeared in the bounds are supposed to be excluded + from the free symbol set. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.free_symbols + set() + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.free_symbols + {k} + """ + expr_free_sym = self.function.free_symbols + for var, inf, sup in self._limits: + expr_free_sym.discard(var) + curr_free_syms = inf.free_symbols.union(sup.free_symbols) + expr_free_sym = expr_free_sym.union(curr_free_syms) + return expr_free_sym + + @property + def variables(self): + """The tuples of the variables in the limits. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.variables + [i, j] + """ + return [l[0] for l in self._limits] + + @property + def bound_symbols(self): + """The list of dummy variables. + + Note + ==== + + Note that all variables are dummy variables since a limit without + lower bound or upper bound is not accepted. + """ + return [l[0] for l in self._limits if len(l) != 1] + + @property + def shape(self): + """ + The shape of the expanded array, which may have symbols. + + Note + ==== + + Both the lower and the upper bounds are included while + calculating the shape. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.shape + (4, 3) + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.shape + (4, k + 3) + """ + return self._shape + + @property + def is_shape_numeric(self): + """ + Test if the array is shape-numeric which means there is no symbolic + dimension. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.is_shape_numeric + True + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.is_shape_numeric + False + """ + for _, inf, sup in self._limits: + if Basic(inf, sup).atoms(Symbol): + return False + return True + + def rank(self): + """The rank of the expanded array. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.rank() + 2 + """ + return self._rank + + def __len__(self): + """ + The length of the expanded array which means the number + of elements in the array. + + Raises + ====== + + ValueError : When the length of the array is symbolic + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> len(a) + 12 + """ + if self._loop_size.free_symbols: + raise ValueError('Symbolic length is not supported') + return self._loop_size + + @classmethod + def _check_limits_validity(cls, function, limits): + #limits = sympify(limits) + new_limits = [] + for var, inf, sup in limits: + var = _sympify(var) + inf = _sympify(inf) + #since this is stored as an argument, it should be + #a Tuple + if isinstance(sup, list): + sup = Tuple(*sup) + else: + sup = _sympify(sup) + new_limits.append(Tuple(var, inf, sup)) + if any((not isinstance(i, Expr)) or i.atoms(Symbol, Integer) != i.atoms() + for i in [inf, sup]): + raise TypeError('Bounds should be an Expression(combination of Integer and Symbol)') + if (inf > sup) == True: + raise ValueError('Lower bound should be inferior to upper bound') + if var in inf.free_symbols or var in sup.free_symbols: + raise ValueError('Variable should not be part of its bounds') + return new_limits + + @classmethod + def _calculate_shape_from_limits(cls, limits): + return tuple([sup - inf + 1 for _, inf, sup in limits]) + + @classmethod + def _calculate_loop_size(cls, shape): + if not shape: + return 0 + loop_size = 1 + for l in shape: + loop_size = loop_size * l + + return loop_size + + def doit(self, **hints): + if not self.is_shape_numeric: + return self + + return self._expand_array() + + def _expand_array(self): + res = [] + for values in itertools.product(*[range(inf, sup+1) + for var, inf, sup + in self._limits]): + res.append(self._get_element(values)) + + return ImmutableDenseNDimArray(res, self.shape) + + def _get_element(self, values): + temp = self.function + for var, val in zip(self.variables, values): + temp = temp.subs(var, val) + return temp + + def tolist(self): + """Transform the expanded array to a list. + + Raises + ====== + + ValueError : When there is a symbolic dimension + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.tolist() + [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]] + """ + if self.is_shape_numeric: + return self._expand_array().tolist() + + raise ValueError("A symbolic array cannot be expanded to a list") + + def tomatrix(self): + """Transform the expanded array to a matrix. + + Raises + ====== + + ValueError : When there is a symbolic dimension + ValueError : When the rank of the expanded array is not equal to 2 + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.tomatrix() + Matrix([ + [11, 12, 13], + [21, 22, 23], + [31, 32, 33], + [41, 42, 43]]) + """ + from sympy.matrices import Matrix + + if not self.is_shape_numeric: + raise ValueError("A symbolic array cannot be expanded to a matrix") + if self._rank != 2: + raise ValueError('Dimensions must be of size of 2') + + return Matrix(self._expand_array().tomatrix()) + + +def isLambda(v): + LAMBDA = lambda: 0 + return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__ + +class ArrayComprehensionMap(ArrayComprehension): + ''' + A subclass of ArrayComprehension dedicated to map external function lambda. + + Notes + ===== + + Only the lambda function is considered. + At most one argument in lambda function is accepted in order to avoid ambiguity + in value assignment. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehensionMap + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehensionMap(lambda: 1, (i, 1, 4)) + >>> a.doit() + [1, 1, 1, 1] + >>> b = ArrayComprehensionMap(lambda a: a+1, (j, 1, 4)) + >>> b.doit() + [2, 3, 4, 5] + + ''' + def __new__(cls, function, *symbols, **assumptions): + if any(len(l) != 3 or None for l in symbols): + raise ValueError('ArrayComprehension requires values lower and upper bound' + ' for the expression') + + if not isLambda(function): + raise ValueError('Data type not supported') + + arglist = cls._check_limits_validity(function, symbols) + obj = Basic.__new__(cls, *arglist, **assumptions) + obj._limits = obj._args + obj._shape = cls._calculate_shape_from_limits(obj._limits) + obj._rank = len(obj._shape) + obj._loop_size = cls._calculate_loop_size(obj._shape) + obj._lambda = function + return obj + + @property + def func(self): + class _(ArrayComprehensionMap): + def __new__(cls, *args, **kwargs): + return ArrayComprehensionMap(self._lambda, *args, **kwargs) + return _ + + def _get_element(self, values): + temp = self._lambda + if self._lambda.__code__.co_argcount == 0: + temp = temp() + elif self._lambda.__code__.co_argcount == 1: + temp = temp(functools.reduce(lambda a, b: a*b, values)) + return temp diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/array_derivatives.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/array_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..a38db6caefe256a8c7e1f3415b78351b3787fee9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/array_derivatives.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from sympy.core.expr import Expr +from sympy.core.function import Derivative +from sympy.core.numbers import Integer +from sympy.matrices.matrixbase import MatrixBase +from .ndim_array import NDimArray +from .arrayop import derive_by_array +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.matrices.expressions.matexpr import _matrix_derivative + + +class ArrayDerivative(Derivative): + + is_scalar = False + + def __new__(cls, expr, *variables, **kwargs): + obj = super().__new__(cls, expr, *variables, **kwargs) + if isinstance(obj, ArrayDerivative): + obj._shape = obj._get_shape() + return obj + + def _get_shape(self): + shape = () + for v, count in self.variable_count: + if hasattr(v, "shape"): + for i in range(count): + shape += v.shape + if hasattr(self.expr, "shape"): + shape += self.expr.shape + return shape + + @property + def shape(self): + return self._shape + + @classmethod + def _get_zero_with_shape_like(cls, expr): + if isinstance(expr, (MatrixBase, NDimArray)): + return expr.zeros(*expr.shape) + elif isinstance(expr, MatrixExpr): + return ZeroMatrix(*expr.shape) + else: + raise RuntimeError("Unable to determine shape of array-derivative.") + + @staticmethod + def _call_derive_scalar_by_matrix(expr: Expr, v: MatrixBase) -> Expr: + return v.applyfunc(lambda x: expr.diff(x)) + + @staticmethod + def _call_derive_scalar_by_matexpr(expr: Expr, v: MatrixExpr) -> Expr: + if expr.has(v): + return _matrix_derivative(expr, v) + else: + return ZeroMatrix(*v.shape) + + @staticmethod + def _call_derive_scalar_by_array(expr: Expr, v: NDimArray) -> Expr: + return v.applyfunc(lambda x: expr.diff(x)) + + @staticmethod + def _call_derive_matrix_by_scalar(expr: MatrixBase, v: Expr) -> Expr: + return _matrix_derivative(expr, v) + + @staticmethod + def _call_derive_matexpr_by_scalar(expr: MatrixExpr, v: Expr) -> Expr: + return expr._eval_derivative(v) + + @staticmethod + def _call_derive_array_by_scalar(expr: NDimArray, v: Expr) -> Expr: + return expr.applyfunc(lambda x: x.diff(v)) + + @staticmethod + def _call_derive_default(expr: Expr, v: Expr) -> Expr | None: + if expr.has(v): + return _matrix_derivative(expr, v) + else: + return None + + @classmethod + def _dispatch_eval_derivative_n_times(cls, expr, v, count): + # Evaluate the derivative `n` times. If + # `_eval_derivative_n_times` is not overridden by the current + # object, the default in `Basic` will call a loop over + # `_eval_derivative`: + + if not isinstance(count, (int, Integer)) or ((count <= 0) == True): + return None + + # TODO: this could be done with multiple-dispatching: + if expr.is_scalar: + if isinstance(v, MatrixBase): + result = cls._call_derive_scalar_by_matrix(expr, v) + elif isinstance(v, MatrixExpr): + result = cls._call_derive_scalar_by_matexpr(expr, v) + elif isinstance(v, NDimArray): + result = cls._call_derive_scalar_by_array(expr, v) + elif v.is_scalar: + # scalar by scalar has a special + return super()._dispatch_eval_derivative_n_times(expr, v, count) + else: + return None + elif v.is_scalar: + if isinstance(expr, MatrixBase): + result = cls._call_derive_matrix_by_scalar(expr, v) + elif isinstance(expr, MatrixExpr): + result = cls._call_derive_matexpr_by_scalar(expr, v) + elif isinstance(expr, NDimArray): + result = cls._call_derive_array_by_scalar(expr, v) + else: + return None + else: + # Both `expr` and `v` are some array/matrix type: + if isinstance(expr, MatrixBase) or isinstance(v, MatrixBase): + result = derive_by_array(expr, v) + elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr): + result = cls._call_derive_default(expr, v) + elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr): + # if one expression is a symbolic matrix expression while the other isn't, don't evaluate: + return None + else: + result = derive_by_array(expr, v) + if result is None: + return None + if count == 1: + return result + else: + return cls._dispatch_eval_derivative_n_times(result, v, count - 1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/arrayop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/arrayop.py new file mode 100644 index 0000000000000000000000000000000000000000..a81e6b381a8a93f0cd585278a4be0259b06406dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/arrayop.py @@ -0,0 +1,528 @@ +import itertools +from collections.abc import Iterable + +from sympy.core._print_helpers import Printable +from sympy.core.containers import Tuple +from sympy.core.function import diff +from sympy.core.singleton import S +from sympy.core.sympify import _sympify + +from sympy.tensor.array.ndim_array import NDimArray +from sympy.tensor.array.dense_ndim_array import DenseNDimArray, ImmutableDenseNDimArray +from sympy.tensor.array.sparse_ndim_array import SparseNDimArray + + +def _arrayfy(a): + from sympy.matrices import MatrixBase + + if isinstance(a, NDimArray): + return a + if isinstance(a, (MatrixBase, list, tuple, Tuple)): + return ImmutableDenseNDimArray(a) + return a + + +def tensorproduct(*args): + """ + Tensor product among scalars or array-like objects. + + The equivalent operator for array expressions is ``ArrayTensorProduct``, + which can be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, Array + >>> from sympy.abc import x, y, z, t + >>> A = Array([[1, 2], [3, 4]]) + >>> B = Array([x, y]) + >>> tensorproduct(A, B) + [[[x, y], [2*x, 2*y]], [[3*x, 3*y], [4*x, 4*y]]] + >>> tensorproduct(A, x) + [[x, 2*x], [3*x, 4*x]] + >>> tensorproduct(A, B, B) + [[[[x**2, x*y], [x*y, y**2]], [[2*x**2, 2*x*y], [2*x*y, 2*y**2]]], [[[3*x**2, 3*x*y], [3*x*y, 3*y**2]], [[4*x**2, 4*x*y], [4*x*y, 4*y**2]]]] + + Applying this function on two matrices will result in a rank 4 array. + + >>> from sympy import Matrix, eye + >>> m = Matrix([[x, y], [z, t]]) + >>> p = tensorproduct(eye(3), m) + >>> p + [[[[x, y], [z, t]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[x, y], [z, t]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[x, y], [z, t]]]] + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayTensorProduct + + """ + from sympy.tensor.array import SparseNDimArray, ImmutableSparseNDimArray + + if len(args) == 0: + return S.One + if len(args) == 1: + return _arrayfy(args[0]) + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.matrices.expressions.matexpr import MatrixSymbol + if any(isinstance(arg, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)) for arg in args): + return ArrayTensorProduct(*args) + if len(args) > 2: + return tensorproduct(tensorproduct(args[0], args[1]), *args[2:]) + + # length of args is 2: + a, b = map(_arrayfy, args) + + if not isinstance(a, NDimArray) or not isinstance(b, NDimArray): + return a*b + + if isinstance(a, SparseNDimArray) and isinstance(b, SparseNDimArray): + lp = len(b) + new_array = {k1*lp + k2: v1*v2 for k1, v1 in a._sparse_array.items() for k2, v2 in b._sparse_array.items()} + return ImmutableSparseNDimArray(new_array, a.shape + b.shape) + + product_list = [i*j for i in Flatten(a) for j in Flatten(b)] + return ImmutableDenseNDimArray(product_list, a.shape + b.shape) + + +def _util_contraction_diagonal(array, *contraction_or_diagonal_axes): + array = _arrayfy(array) + + # Verify contraction_axes: + taken_dims = set() + for axes_group in contraction_or_diagonal_axes: + if not isinstance(axes_group, Iterable): + raise ValueError("collections of contraction/diagonal axes expected") + + dim = array.shape[axes_group[0]] + + for d in axes_group: + if d in taken_dims: + raise ValueError("dimension specified more than once") + if dim != array.shape[d]: + raise ValueError("cannot contract or diagonalize between axes of different dimension") + taken_dims.add(d) + + rank = array.rank() + + remaining_shape = [dim for i, dim in enumerate(array.shape) if i not in taken_dims] + cum_shape = [0]*rank + _cumul = 1 + for i in range(rank): + cum_shape[rank - i - 1] = _cumul + _cumul *= int(array.shape[rank - i - 1]) + + # DEFINITION: by absolute position it is meant the position along the one + # dimensional array containing all the tensor components. + + # Possible future work on this module: move computation of absolute + # positions to a class method. + + # Determine absolute positions of the uncontracted indices: + remaining_indices = [[cum_shape[i]*j for j in range(array.shape[i])] + for i in range(rank) if i not in taken_dims] + + # Determine absolute positions of the contracted indices: + summed_deltas = [] + for axes_group in contraction_or_diagonal_axes: + lidx = [] + for js in range(array.shape[axes_group[0]]): + lidx.append(sum(cum_shape[ig] * js for ig in axes_group)) + summed_deltas.append(lidx) + + return array, remaining_indices, remaining_shape, summed_deltas + + +def tensorcontraction(array, *contraction_axes): + """ + Contraction of an array-like object on the specified axes. + + The equivalent operator for array expressions is ``ArrayContraction``, + which can be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy import Array, tensorcontraction + >>> from sympy import Matrix, eye + >>> tensorcontraction(eye(3), (0, 1)) + 3 + >>> A = Array(range(18), (3, 2, 3)) + >>> A + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]] + >>> tensorcontraction(A, (0, 2)) + [21, 30] + + Matrix multiplication may be emulated with a proper combination of + ``tensorcontraction`` and ``tensorproduct`` + + >>> from sympy import tensorproduct + >>> from sympy.abc import a,b,c,d,e,f,g,h + >>> m1 = Matrix([[a, b], [c, d]]) + >>> m2 = Matrix([[e, f], [g, h]]) + >>> p = tensorproduct(m1, m2) + >>> p + [[[[a*e, a*f], [a*g, a*h]], [[b*e, b*f], [b*g, b*h]]], [[[c*e, c*f], [c*g, c*h]], [[d*e, d*f], [d*g, d*h]]]] + >>> tensorcontraction(p, (1, 2)) + [[a*e + b*g, a*f + b*h], [c*e + d*g, c*f + d*h]] + >>> m1*m2 + Matrix([ + [a*e + b*g, a*f + b*h], + [c*e + d*g, c*f + d*h]]) + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayContraction + + """ + from sympy.tensor.array.expressions.array_expressions import _array_contraction + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.matrices.expressions.matexpr import MatrixSymbol + if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _array_contraction(array, *contraction_axes) + + array, remaining_indices, remaining_shape, summed_deltas = _util_contraction_diagonal(array, *contraction_axes) + + # Compute the contracted array: + # + # 1. external for loops on all uncontracted indices. + # Uncontracted indices are determined by the combinatorial product of + # the absolute positions of the remaining indices. + # 2. internal loop on all contracted indices. + # It sums the values of the absolute contracted index and the absolute + # uncontracted index for the external loop. + contracted_array = [] + for icontrib in itertools.product(*remaining_indices): + index_base_position = sum(icontrib) + isum = S.Zero + for sum_to_index in itertools.product(*summed_deltas): + idx = array._get_tuple_index(index_base_position + sum(sum_to_index)) + isum += array[idx] + + contracted_array.append(isum) + + if len(remaining_indices) == 0: + assert len(contracted_array) == 1 + return contracted_array[0] + + return type(array)(contracted_array, remaining_shape) + + +def tensordiagonal(array, *diagonal_axes): + """ + Diagonalization of an array-like object on the specified axes. + + This is equivalent to multiplying the expression by Kronecker deltas + uniting the axes. + + The diagonal indices are put at the end of the axes. + + The equivalent operator for array expressions is ``ArrayDiagonal``, which + can be used to keep the expression unevaluated. + + Examples + ======== + + ``tensordiagonal`` acting on a 2-dimensional array by axes 0 and 1 is + equivalent to the diagonal of the matrix: + + >>> from sympy import Array, tensordiagonal + >>> from sympy import Matrix, eye + >>> tensordiagonal(eye(3), (0, 1)) + [1, 1, 1] + + >>> from sympy.abc import a,b,c,d + >>> m1 = Matrix([[a, b], [c, d]]) + >>> tensordiagonal(m1, [0, 1]) + [a, d] + + In case of higher dimensional arrays, the diagonalized out dimensions + are appended removed and appended as a single dimension at the end: + + >>> A = Array(range(18), (3, 2, 3)) + >>> A + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]] + >>> tensordiagonal(A, (0, 2)) + [[0, 7, 14], [3, 10, 17]] + >>> from sympy import permutedims + >>> tensordiagonal(A, (0, 2)) == permutedims(Array([A[0, :, 0], A[1, :, 1], A[2, :, 2]]), [1, 0]) + True + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayDiagonal + + """ + if any(len(i) <= 1 for i in diagonal_axes): + raise ValueError("need at least two axes to diagonalize") + + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, _array_diagonal + from sympy.matrices.expressions.matexpr import MatrixSymbol + if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _array_diagonal(array, *diagonal_axes) + + ArrayDiagonal._validate(array, *diagonal_axes) + + array, remaining_indices, remaining_shape, diagonal_deltas = _util_contraction_diagonal(array, *diagonal_axes) + + # Compute the diagonalized array: + # + # 1. external for loops on all undiagonalized indices. + # Undiagonalized indices are determined by the combinatorial product of + # the absolute positions of the remaining indices. + # 2. internal loop on all diagonal indices. + # It appends the values of the absolute diagonalized index and the absolute + # undiagonalized index for the external loop. + diagonalized_array = [] + diagonal_shape = [len(i) for i in diagonal_deltas] + for icontrib in itertools.product(*remaining_indices): + index_base_position = sum(icontrib) + isum = [] + for sum_to_index in itertools.product(*diagonal_deltas): + idx = array._get_tuple_index(index_base_position + sum(sum_to_index)) + isum.append(array[idx]) + + isum = type(array)(isum).reshape(*diagonal_shape) + diagonalized_array.append(isum) + + return type(array)(diagonalized_array, remaining_shape + diagonal_shape) + + +def derive_by_array(expr, dx): + r""" + Derivative by arrays. Supports both arrays and scalars. + + The equivalent operator for array expressions is ``array_derive``. + + Explanation + =========== + + Given the array `A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}` + this function will return a new array `B` defined by + + `B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}` + + Examples + ======== + + >>> from sympy import derive_by_array + >>> from sympy.abc import x, y, z, t + >>> from sympy import cos + >>> derive_by_array(cos(x*t), x) + -t*sin(t*x) + >>> derive_by_array(cos(x*t), [x, y, z, t]) + [-t*sin(t*x), 0, 0, -x*sin(t*x)] + >>> derive_by_array([x, y**2*z], [[x, y], [z, t]]) + [[[1, 0], [0, 2*y*z]], [[0, y**2], [0, 0]]] + + """ + from sympy.matrices import MatrixBase + from sympy.tensor.array import SparseNDimArray + array_types = (Iterable, MatrixBase, NDimArray) + + if isinstance(dx, array_types): + dx = ImmutableDenseNDimArray(dx) + for i in dx: + if not i._diff_wrt: + raise ValueError("cannot derive by this array") + + if isinstance(expr, array_types): + if isinstance(expr, NDimArray): + expr = expr.as_immutable() + else: + expr = ImmutableDenseNDimArray(expr) + + if isinstance(dx, array_types): + if isinstance(expr, SparseNDimArray): + lp = len(expr) + new_array = {k + i*lp: v + for i, x in enumerate(Flatten(dx)) + for k, v in expr.diff(x)._sparse_array.items()} + else: + new_array = [[y.diff(x) for y in Flatten(expr)] for x in Flatten(dx)] + return type(expr)(new_array, dx.shape + expr.shape) + else: + return expr.diff(dx) + else: + expr = _sympify(expr) + if isinstance(dx, array_types): + return ImmutableDenseNDimArray([expr.diff(i) for i in Flatten(dx)], dx.shape) + else: + dx = _sympify(dx) + return diff(expr, dx) + + +def permutedims(expr, perm=None, index_order_old=None, index_order_new=None): + """ + Permutes the indices of an array. + + Parameter specifies the permutation of the indices. + + The equivalent operator for array expressions is ``PermuteDims``, which can + be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy.abc import x, y, z, t + >>> from sympy import sin + >>> from sympy import Array, permutedims + >>> a = Array([[x, y, z], [t, sin(x), 0]]) + >>> a + [[x, y, z], [t, sin(x), 0]] + >>> permutedims(a, (1, 0)) + [[x, t], [y, sin(x)], [z, 0]] + + If the array is of second order, ``transpose`` can be used: + + >>> from sympy import transpose + >>> transpose(a) + [[x, t], [y, sin(x)], [z, 0]] + + Examples on higher dimensions: + + >>> b = Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + >>> permutedims(b, (2, 1, 0)) + [[[1, 5], [3, 7]], [[2, 6], [4, 8]]] + >>> permutedims(b, (1, 2, 0)) + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + An alternative way to specify the same permutations as in the previous + lines involves passing the *old* and *new* indices, either as a list or as + a string: + + >>> permutedims(b, index_order_old="cba", index_order_new="abc") + [[[1, 5], [3, 7]], [[2, 6], [4, 8]]] + >>> permutedims(b, index_order_old="cab", index_order_new="abc") + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + ``Permutation`` objects are also allowed: + + >>> from sympy.combinatorics import Permutation + >>> permutedims(b, Permutation([1, 2, 0])) + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.PermuteDims + + """ + from sympy.tensor.array import SparseNDimArray + + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import _permute_dims + from sympy.matrices.expressions.matexpr import MatrixSymbol + from sympy.tensor.array.expressions import PermuteDims + from sympy.tensor.array.expressions.array_expressions import get_rank + perm = PermuteDims._get_permutation_from_arguments(perm, index_order_old, index_order_new, get_rank(expr)) + if isinstance(expr, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _permute_dims(expr, perm) + + if not isinstance(expr, NDimArray): + expr = ImmutableDenseNDimArray(expr) + + from sympy.combinatorics import Permutation + if not isinstance(perm, Permutation): + perm = Permutation(list(perm)) + + if perm.size != expr.rank(): + raise ValueError("wrong permutation size") + + # Get the inverse permutation: + iperm = ~perm + new_shape = perm(expr.shape) + + if isinstance(expr, SparseNDimArray): + return type(expr)({tuple(perm(expr._get_tuple_index(k))): v + for k, v in expr._sparse_array.items()}, new_shape) + + indices_span = perm([range(i) for i in expr.shape]) + + new_array = [None]*len(expr) + for i, idx in enumerate(itertools.product(*indices_span)): + t = iperm(idx) + new_array[i] = expr[t] + + return type(expr)(new_array, new_shape) + + +class Flatten(Printable): + """ + Flatten an iterable object to a list in a lazy-evaluation way. + + Notes + ===== + + This class is an iterator with which the memory cost can be economised. + Optimisation has been considered to ameliorate the performance for some + specific data types like DenseNDimArray and SparseNDimArray. + + Examples + ======== + + >>> from sympy.tensor.array.arrayop import Flatten + >>> from sympy.tensor.array import Array + >>> A = Array(range(6)).reshape(2, 3) + >>> Flatten(A) + Flatten([[0, 1, 2], [3, 4, 5]]) + >>> [i for i in Flatten(A)] + [0, 1, 2, 3, 4, 5] + """ + def __init__(self, iterable): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import NDimArray + + if not isinstance(iterable, (Iterable, MatrixBase)): + raise NotImplementedError("Data type not yet supported") + + if isinstance(iterable, list): + iterable = NDimArray(iterable) + + self._iter = iterable + self._idx = 0 + + def __iter__(self): + return self + + def __next__(self): + from sympy.matrices.matrixbase import MatrixBase + + if len(self._iter) > self._idx: + if isinstance(self._iter, DenseNDimArray): + result = self._iter._array[self._idx] + + elif isinstance(self._iter, SparseNDimArray): + if self._idx in self._iter._sparse_array: + result = self._iter._sparse_array[self._idx] + else: + result = 0 + + elif isinstance(self._iter, MatrixBase): + result = self._iter[self._idx] + + elif hasattr(self._iter, '__next__'): + result = next(self._iter) + + else: + result = self._iter[self._idx] + + else: + raise StopIteration + + self._idx += 1 + return result + + def next(self): + return self.__next__() + + def _sympystr(self, printer): + return type(self).__name__ + '(' + printer._print(self._iter) + ')' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/dense_ndim_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/dense_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..576e452c55d8d374ca1f72c553f3a64de7227d43 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/dense_ndim_array.py @@ -0,0 +1,206 @@ +import functools +from typing import List + +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.tensor.array.mutable_ndim_array import MutableNDimArray +from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray, ArrayKind +from sympy.utilities.iterables import flatten + + +class DenseNDimArray(NDimArray): + + _array: List[Basic] + + def __new__(self, *args, **kwargs): + return ImmutableDenseNDimArray(*args, **kwargs) + + @property + def kind(self) -> ArrayKind: + return ArrayKind._union(self._array) + + def __getitem__(self, index): + """ + Allows to get items from N-dim array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([0, 1, 2, 3], (2, 2)) + >>> a + [[0, 1], [2, 3]] + >>> a[0, 0] + 0 + >>> a[1, 1] + 3 + >>> a[0] + [0, 1] + >>> a[1] + [2, 3] + + + Symbolic index: + + >>> from sympy.abc import i, j + >>> a[i, j] + [[0, 1], [2, 3]][i, j] + + Replace `i` and `j` to get element `(1, 1)`: + + >>> a[i, j].subs({i: 1, j: 1}) + 3 + + """ + syindex = self._check_symbolic_index(index) + if syindex is not None: + return syindex + + index = self._check_index_for_getitem(index) + + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + sl_factors, eindices = self._get_slice_data_for_array_access(index) + array = [self._array[self._parse_index(i)] for i in eindices] + nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)] + return type(self)(array, nshape) + else: + index = self._parse_index(index) + return self._array[index] + + @classmethod + def zeros(cls, *shape): + list_length = functools.reduce(lambda x, y: x*y, shape, S.One) + return cls._new(([0]*list_length,), shape) + + def tomatrix(self): + """ + Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1 for i in range(9)], (3, 3)) + >>> b = a.tomatrix() + >>> b + Matrix([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]) + + """ + from sympy.matrices import Matrix + + if self.rank() != 2: + raise ValueError('Dimensions must be of size of 2') + + return Matrix(self.shape[0], self.shape[1], self._array) + + def reshape(self, *newshape): + """ + Returns MutableDenseNDimArray instance with new shape. Elements number + must be suitable to new shape. The only argument of method sets + new shape. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3)) + >>> a.shape + (2, 3) + >>> a + [[1, 2, 3], [4, 5, 6]] + >>> b = a.reshape(3, 2) + >>> b.shape + (3, 2) + >>> b + [[1, 2], [3, 4], [5, 6]] + + """ + new_total_size = functools.reduce(lambda x,y: x*y, newshape) + if new_total_size != self._loop_size: + raise ValueError('Expecting reshape size to %d but got prod(%s) = %d' % ( + self._loop_size, str(newshape), new_total_size)) + + # there is no `.func` as this class does not subtype `Basic`: + return type(self)(self._array, newshape) + + +class ImmutableDenseNDimArray(DenseNDimArray, ImmutableNDimArray): # type: ignore + def __new__(cls, iterable, shape=None, **kwargs): + return cls._new(iterable, shape, **kwargs) + + @classmethod + def _new(cls, iterable, shape, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + shape = Tuple(*map(_sympify, shape)) + cls._check_special_bounds(flat_list, shape) + flat_list = flatten(flat_list) + flat_list = Tuple(*flat_list) + self = Basic.__new__(cls, flat_list, shape, **kwargs) + self._shape = shape + self._array = list(flat_list) + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape, 1) + return self + + def __setitem__(self, index, value): + raise TypeError('immutable N-dim array') + + def as_mutable(self): + return MutableDenseNDimArray(self) + + def _eval_simplify(self, **kwargs): + from sympy.simplify.simplify import simplify + return self.applyfunc(simplify) + +class MutableDenseNDimArray(DenseNDimArray, MutableNDimArray): + + def __new__(cls, iterable=None, shape=None, **kwargs): + return cls._new(iterable, shape, **kwargs) + + @classmethod + def _new(cls, iterable, shape, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + flat_list = flatten(flat_list) + self = object.__new__(cls) + self._shape = shape + self._array = list(flat_list) + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + return self + + def __setitem__(self, index, value): + """Allows to set items to MutableDenseNDimArray. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 2) + >>> a[0,0] = 1 + >>> a[1,1] = 1 + >>> a + [[1, 0], [0, 1]] + + """ + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value) + for i in eindices: + other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None] + self._array[self._parse_index(i)] = value[other_i] + else: + index = self._parse_index(index) + self._setter_iterable_check(value) + value = _sympify(value) + self._array[index] = value + + def as_immutable(self): + return ImmutableDenseNDimArray(self) + + @property + def free_symbols(self): + return {i for j in self._array for i in j.free_symbols} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1658241782cdf0e38a30c43a6d67f9811297f4c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__init__.py @@ -0,0 +1,178 @@ +r""" +Array expressions are expressions representing N-dimensional arrays, without +evaluating them. These expressions represent in a certain way abstract syntax +trees of operations on N-dimensional arrays. + +Every N-dimensional array operator has a corresponding array expression object. + +Table of correspondences: + +=============================== ============================= + Array operator Array expression operator +=============================== ============================= + tensorproduct ArrayTensorProduct + tensorcontraction ArrayContraction + tensordiagonal ArrayDiagonal + permutedims PermuteDims +=============================== ============================= + +Examples +======== + +``ArraySymbol`` objects are the N-dimensional equivalent of ``MatrixSymbol`` +objects in the matrix module: + +>>> from sympy.tensor.array.expressions import ArraySymbol +>>> from sympy.abc import i, j, k +>>> A = ArraySymbol("A", (3, 2, 4)) +>>> A.shape +(3, 2, 4) +>>> A[i, j, k] +A[i, j, k] +>>> A.as_explicit() +[[[A[0, 0, 0], A[0, 0, 1], A[0, 0, 2], A[0, 0, 3]], + [A[0, 1, 0], A[0, 1, 1], A[0, 1, 2], A[0, 1, 3]]], + [[A[1, 0, 0], A[1, 0, 1], A[1, 0, 2], A[1, 0, 3]], + [A[1, 1, 0], A[1, 1, 1], A[1, 1, 2], A[1, 1, 3]]], + [[A[2, 0, 0], A[2, 0, 1], A[2, 0, 2], A[2, 0, 3]], + [A[2, 1, 0], A[2, 1, 1], A[2, 1, 2], A[2, 1, 3]]]] + +Component-explicit arrays can be added inside array expressions: + +>>> from sympy import Array +>>> from sympy import tensorproduct +>>> from sympy.tensor.array.expressions import ArrayTensorProduct +>>> a = Array([1, 2, 3]) +>>> b = Array([i, j, k]) +>>> expr = ArrayTensorProduct(a, b, b) +>>> expr +ArrayTensorProduct([1, 2, 3], [i, j, k], [i, j, k]) +>>> expr.as_explicit() == tensorproduct(a, b, b) +True + +Constructing array expressions from index-explicit forms +-------------------------------------------------------- + +Array expressions are index-implicit. This means they do not use any indices to +represent array operations. The function ``convert_indexed_to_array( ... )`` +may be used to convert index-explicit expressions to array expressions. +It takes as input two parameters: the index-explicit expression and the order +of the indices: + +>>> from sympy.tensor.array.expressions import convert_indexed_to_array +>>> from sympy import Sum +>>> A = ArraySymbol("A", (3, 3)) +>>> B = ArraySymbol("B", (3, 3)) +>>> convert_indexed_to_array(A[i, j], [i, j]) +A +>>> convert_indexed_to_array(A[i, j], [j, i]) +PermuteDims(A, (0 1)) +>>> convert_indexed_to_array(A[i, j] + B[j, i], [i, j]) +ArrayAdd(A, PermuteDims(B, (0 1))) +>>> convert_indexed_to_array(Sum(A[i, j]*B[j, k], (j, 0, 2)), [i, k]) +ArrayContraction(ArrayTensorProduct(A, B), (1, 2)) + +The diagonal of a matrix in the array expression form: + +>>> convert_indexed_to_array(A[i, i], [i]) +ArrayDiagonal(A, (0, 1)) + +The trace of a matrix in the array expression form: + +>>> convert_indexed_to_array(Sum(A[i, i], (i, 0, 2)), [i]) +ArrayContraction(A, (0, 1)) + +Compatibility with matrices +--------------------------- + +Array expressions can be mixed with objects from the matrix module: + +>>> from sympy import MatrixSymbol +>>> from sympy.tensor.array.expressions import ArrayContraction +>>> M = MatrixSymbol("M", 3, 3) +>>> N = MatrixSymbol("N", 3, 3) + +Express the matrix product in the array expression form: + +>>> from sympy.tensor.array.expressions import convert_matrix_to_array +>>> expr = convert_matrix_to_array(M*N) +>>> expr +ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + +The expression can be converted back to matrix form: + +>>> from sympy.tensor.array.expressions import convert_array_to_matrix +>>> convert_array_to_matrix(expr) +M*N + +Add a second contraction on the remaining axes in order to get the trace of `M \cdot N`: + +>>> expr_tr = ArrayContraction(expr, (0, 1)) +>>> expr_tr +ArrayContraction(ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), (0, 1)) + +Flatten the expression by calling ``.doit()`` and remove the nested array contraction operations: + +>>> expr_tr.doit() +ArrayContraction(ArrayTensorProduct(M, N), (0, 3), (1, 2)) + +Get the explicit form of the array expression: + +>>> expr.as_explicit() +[[M[0, 0]*N[0, 0] + M[0, 1]*N[1, 0] + M[0, 2]*N[2, 0], M[0, 0]*N[0, 1] + M[0, 1]*N[1, 1] + M[0, 2]*N[2, 1], M[0, 0]*N[0, 2] + M[0, 1]*N[1, 2] + M[0, 2]*N[2, 2]], + [M[1, 0]*N[0, 0] + M[1, 1]*N[1, 0] + M[1, 2]*N[2, 0], M[1, 0]*N[0, 1] + M[1, 1]*N[1, 1] + M[1, 2]*N[2, 1], M[1, 0]*N[0, 2] + M[1, 1]*N[1, 2] + M[1, 2]*N[2, 2]], + [M[2, 0]*N[0, 0] + M[2, 1]*N[1, 0] + M[2, 2]*N[2, 0], M[2, 0]*N[0, 1] + M[2, 1]*N[1, 1] + M[2, 2]*N[2, 1], M[2, 0]*N[0, 2] + M[2, 1]*N[1, 2] + M[2, 2]*N[2, 2]]] + +Express the trace of a matrix: + +>>> from sympy import Trace +>>> convert_matrix_to_array(Trace(M)) +ArrayContraction(M, (0, 1)) +>>> convert_matrix_to_array(Trace(M*N)) +ArrayContraction(ArrayTensorProduct(M, N), (0, 3), (1, 2)) + +Express the transposition of a matrix (will be expressed as a permutation of the axes: + +>>> convert_matrix_to_array(M.T) +PermuteDims(M, (0 1)) + +Compute the derivative array expressions: + +>>> from sympy.tensor.array.expressions import array_derive +>>> d = array_derive(M, M) +>>> d +PermuteDims(ArrayTensorProduct(I, I), (3)(1 2)) + +Verify that the derivative corresponds to the form computed with explicit matrices: + +>>> d.as_explicit() +[[[[1, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 1, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 1]]]] +>>> Me = M.as_explicit() +>>> Me.diff(Me) +[[[[1, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 1, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 1]]]] + +""" + +__all__ = [ + "ArraySymbol", "ArrayElement", "ZeroArray", "OneArray", + "ArrayTensorProduct", + "ArrayContraction", + "ArrayDiagonal", + "PermuteDims", + "ArrayAdd", + "ArrayElementwiseApplyFunc", + "Reshape", + "convert_array_to_matrix", + "convert_matrix_to_array", + "convert_array_to_indexed", + "convert_indexed_to_array", + "array_derive", +] + +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, PermuteDims, ArrayDiagonal, \ + ArrayContraction, Reshape, ArraySymbol, ArrayElement, ZeroArray, OneArray, ArrayElementwiseApplyFunc +from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive +from sympy.tensor.array.expressions.from_array_to_indexed import convert_array_to_indexed +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ba8f4ef306c0e545b412fa419fc7c2809845be3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..274a99e8141709b893e6df207b2c2f62bce87ea6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fa3efb404b7d4fd158eb58ac0d77f7fa24ce45b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b30950e58faebb953eaa41258d1cf6567757c74d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a7fde46e09d9e0e69915fe76f2caaba748b52f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..115e28d3ee061d15a55c7d338c510556709494c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..642e5410ae6928b803a8d0a048ce5beea95923ba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba210b1e96928f91dfc3f7a5d4cee77bf3e00d07 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f4783aa493e3b173e4c36ca479519297c73e00a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2a37206f8ba958ff211b823dbe8de19dd761bc3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3498461006c4f5139959692401b7e606c8d7e865 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/array_expressions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/array_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..f062e3de4c24987d62ba0b3a19fe474fb4687940 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/array_expressions.py @@ -0,0 +1,1969 @@ +from __future__ import annotations +import collections.abc +import operator +from collections import defaultdict, Counter +from functools import reduce +import itertools +from itertools import accumulate + +import typing + +from sympy.core.numbers import Integer +from sympy.core.relational import Equality +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda) +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import (Dummy, Symbol) +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.expressions.diagonal import diagonalize_vector +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensordiagonal, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.tensor.array.ndim_array import NDimArray +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.utils import _apply_recursively_over_nested_lists, _sort_contraction_indices, \ + _get_mapping_from_subranks, _build_push_indices_up_func_transformation, _get_contraction_links, \ + _build_push_indices_down_func_transformation +from sympy.combinatorics import Permutation +from sympy.combinatorics.permutations import _af_invert +from sympy.core.sympify import _sympify + + +class _ArrayExpr(Expr): + shape: tuple[Expr, ...] + + def __getitem__(self, item): + if not isinstance(item, collections.abc.Iterable): + item = (item,) + ArrayElement._check_shape(self, item) + return self._get(item) + + def _get(self, item): + return _get_array_element_or_slice(self, item) + + +class ArraySymbol(_ArrayExpr): + """ + Symbol representing an array expression + """ + + _iterable = False + + def __new__(cls, symbol, shape: typing.Iterable) -> "ArraySymbol": + if isinstance(symbol, str): + symbol = Symbol(symbol) + # symbol = _sympify(symbol) + shape = Tuple(*map(_sympify, shape)) + obj = Expr.__new__(cls, symbol, shape) + return obj + + @property + def name(self): + return self._args[0] + + @property + def shape(self): + return self._args[1] + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("cannot express explicit array with symbolic shape") + data = [self[i] for i in itertools.product(*[range(j) for j in self.shape])] + return ImmutableDenseNDimArray(data).reshape(*self.shape) + + +class ArrayElement(Expr): + """ + An element of an array. + """ + + _diff_wrt = True + is_symbol = True + is_commutative = True + + def __new__(cls, name, indices): + if isinstance(name, str): + name = Symbol(name) + name = _sympify(name) + if not isinstance(indices, collections.abc.Iterable): + indices = (indices,) + indices = _sympify(tuple(indices)) + cls._check_shape(name, indices) + obj = Expr.__new__(cls, name, indices) + return obj + + @classmethod + def _check_shape(cls, name, indices): + indices = tuple(indices) + if hasattr(name, "shape"): + index_error = IndexError("number of indices does not match shape of the array") + if len(indices) != len(name.shape): + raise index_error + if any((i >= s) == True for i, s in zip(indices, name.shape)): + raise ValueError("shape is out of bounds") + if any((i < 0) == True for i in indices): + raise ValueError("shape contains negative values") + + @property + def name(self): + return self._args[0] + + @property + def indices(self): + return self._args[1] + + def _eval_derivative(self, s): + if not isinstance(s, ArrayElement): + return S.Zero + + if s == self: + return S.One + + if s.name != self.name: + return S.Zero + + return Mul.fromiter(KroneckerDelta(i, j) for i, j in zip(self.indices, s.indices)) + + +class ZeroArray(_ArrayExpr): + """ + Symbolic array of zeros. Equivalent to ``ZeroMatrix`` for matrices. + """ + + def __new__(cls, *shape): + if len(shape) == 0: + return S.Zero + shape = map(_sympify, shape) + obj = Expr.__new__(cls, *shape) + return obj + + @property + def shape(self): + return self._args + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("Cannot return explicit form for symbolic shape.") + return ImmutableDenseNDimArray.zeros(*self.shape) + + def _get(self, item): + return S.Zero + + +class OneArray(_ArrayExpr): + """ + Symbolic array of ones. + """ + + def __new__(cls, *shape): + if len(shape) == 0: + return S.One + shape = map(_sympify, shape) + obj = Expr.__new__(cls, *shape) + return obj + + @property + def shape(self): + return self._args + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("Cannot return explicit form for symbolic shape.") + return ImmutableDenseNDimArray([S.One for i in range(reduce(operator.mul, self.shape))]).reshape(*self.shape) + + def _get(self, item): + return S.One + + +class _CodegenArrayAbstract(Basic): + + @property + def subranks(self): + """ + Returns the ranks of the objects in the uppermost tensor product inside + the current object. In case no tensor products are contained, return + the atomic ranks. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> N = MatrixSymbol("N", 3, 3) + >>> P = MatrixSymbol("P", 3, 3) + + Important: do not confuse the rank of the matrix with the rank of an array. + + >>> tp = tensorproduct(M, N, P) + >>> tp.subranks + [2, 2, 2] + + >>> co = tensorcontraction(tp, (1, 2), (3, 4)) + >>> co.subranks + [2, 2, 2] + """ + return self._subranks[:] + + def subrank(self): + """ + The sum of ``subranks``. + """ + return sum(self.subranks) + + @property + def shape(self): + return self._shape + + def doit(self, **hints): + deep = hints.get("deep", True) + if deep: + return self.func(*[arg.doit(**hints) for arg in self.args])._canonicalize() + else: + return self._canonicalize() + +class ArrayTensorProduct(_CodegenArrayAbstract): + r""" + Class to represent the tensor product of array-like objects. + """ + + def __new__(cls, *args, **kwargs): + args = [_sympify(arg) for arg in args] + + canonicalize = kwargs.pop("canonicalize", False) + + ranks = [get_rank(arg) for arg in args] + + obj = Basic.__new__(cls, *args) + obj._subranks = ranks + shapes = [get_shape(i) for i in args] + + if any(i is None for i in shapes): + obj._shape = None + else: + obj._shape = tuple(j for i in shapes for j in i) + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + args = self.args + args = self._flatten(args) + + ranks = [get_rank(arg) for arg in args] + + # Check if there are nested permutation and lift them up: + permutation_cycles = [] + for i, arg in enumerate(args): + if not isinstance(arg, PermuteDims): + continue + permutation_cycles.extend([[k + sum(ranks[:i]) for k in j] for j in arg.permutation.cyclic_form]) + args[i] = arg.expr + if permutation_cycles: + return _permute_dims(_array_tensor_product(*args), Permutation(sum(ranks)-1)*Permutation(permutation_cycles)) + + if len(args) == 1: + return args[0] + + # If any object is a ZeroArray, return a ZeroArray: + if any(isinstance(arg, (ZeroArray, ZeroMatrix)) for arg in args): + shapes = reduce(operator.add, [get_shape(i) for i in args], ()) + return ZeroArray(*shapes) + + # If there are contraction objects inside, transform the whole + # expression into `ArrayContraction`: + contractions = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayContraction)} + if contractions: + ranks = [_get_subrank(arg) if isinstance(arg, ArrayContraction) else get_rank(arg) for arg in args] + cumulative_ranks = list(accumulate([0] + ranks))[:-1] + tp = _array_tensor_product(*[arg.expr if isinstance(arg, ArrayContraction) else arg for arg in args]) + contraction_indices = [tuple(cumulative_ranks[i] + k for k in j) for i, arg in contractions.items() for j in arg.contraction_indices] + return _array_contraction(tp, *contraction_indices) + + diagonals = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayDiagonal)} + if diagonals: + inverse_permutation = [] + last_perm = [] + ranks = [get_rank(arg) for arg in args] + cumulative_ranks = list(accumulate([0] + ranks))[:-1] + for i, arg in enumerate(args): + if isinstance(arg, ArrayDiagonal): + i1 = get_rank(arg) - len(arg.diagonal_indices) + i2 = len(arg.diagonal_indices) + inverse_permutation.extend([cumulative_ranks[i] + j for j in range(i1)]) + last_perm.extend([cumulative_ranks[i] + j for j in range(i1, i1 + i2)]) + else: + inverse_permutation.extend([cumulative_ranks[i] + j for j in range(get_rank(arg))]) + inverse_permutation.extend(last_perm) + tp = _array_tensor_product(*[arg.expr if isinstance(arg, ArrayDiagonal) else arg for arg in args]) + ranks2 = [_get_subrank(arg) if isinstance(arg, ArrayDiagonal) else get_rank(arg) for arg in args] + cumulative_ranks2 = list(accumulate([0] + ranks2))[:-1] + diagonal_indices = [tuple(cumulative_ranks2[i] + k for k in j) for i, arg in diagonals.items() for j in arg.diagonal_indices] + return _permute_dims(_array_diagonal(tp, *diagonal_indices), _af_invert(inverse_permutation)) + + return self.func(*args, canonicalize=False) + + @classmethod + def _flatten(cls, args): + args = [i for arg in args for i in (arg.args if isinstance(arg, cls) else [arg])] + return args + + def as_explicit(self): + return tensorproduct(*[arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args]) + + +class ArrayAdd(_CodegenArrayAbstract): + r""" + Class for elementwise array additions. + """ + + def __new__(cls, *args, **kwargs): + args = [_sympify(arg) for arg in args] + ranks = [get_rank(arg) for arg in args] + ranks = list(set(ranks)) + if len(ranks) != 1: + raise ValueError("summing arrays of different ranks") + shapes = [arg.shape for arg in args] + if len({i for i in shapes if i is not None}) > 1: + raise ValueError("mismatching shapes in addition") + + canonicalize = kwargs.pop("canonicalize", False) + + obj = Basic.__new__(cls, *args) + obj._subranks = ranks + if any(i is None for i in shapes): + obj._shape = None + else: + obj._shape = shapes[0] + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + args = self.args + + # Flatten: + args = self._flatten_args(args) + + shapes = [get_shape(arg) for arg in args] + args = [arg for arg in args if not isinstance(arg, (ZeroArray, ZeroMatrix))] + if len(args) == 0: + if any(i for i in shapes if i is None): + raise NotImplementedError("cannot handle addition of ZeroMatrix/ZeroArray and undefined shape object") + return ZeroArray(*shapes[0]) + elif len(args) == 1: + return args[0] + return self.func(*args, canonicalize=False) + + @classmethod + def _flatten_args(cls, args): + new_args = [] + for arg in args: + if isinstance(arg, ArrayAdd): + new_args.extend(arg.args) + else: + new_args.append(arg) + return new_args + + def as_explicit(self): + return reduce( + operator.add, + [arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args]) + + +class PermuteDims(_CodegenArrayAbstract): + r""" + Class to represent permutation of axes of arrays. + + Examples + ======== + + >>> from sympy.tensor.array import permutedims + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> cg = permutedims(M, [1, 0]) + + The object ``cg`` represents the transposition of ``M``, as the permutation + ``[1, 0]`` will act on its indices by switching them: + + `M_{ij} \Rightarrow M_{ji}` + + This is evident when transforming back to matrix form: + + >>> from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + >>> convert_array_to_matrix(cg) + M.T + + >>> N = MatrixSymbol("N", 3, 2) + >>> cg = permutedims(N, [1, 0]) + >>> cg.shape + (2, 3) + + There are optional parameters that can be used as alternative to the permutation: + + >>> from sympy.tensor.array.expressions import ArraySymbol, PermuteDims + >>> M = ArraySymbol("M", (1, 2, 3, 4, 5)) + >>> expr = PermuteDims(M, index_order_old="ijklm", index_order_new="kijml") + >>> expr + PermuteDims(M, (0 2 1)(3 4)) + >>> expr.shape + (3, 1, 2, 5, 4) + + Permutations of tensor products are simplified in order to achieve a + standard form: + + >>> from sympy.tensor.array import tensorproduct + >>> M = MatrixSymbol("M", 4, 5) + >>> tp = tensorproduct(M, N) + >>> tp.shape + (4, 5, 3, 2) + >>> perm1 = permutedims(tp, [2, 3, 1, 0]) + + The args ``(M, N)`` have been sorted and the permutation has been + simplified, the expression is equivalent: + + >>> perm1.expr.args + (N, M) + >>> perm1.shape + (3, 2, 5, 4) + >>> perm1.permutation + (2 3) + + The permutation in its array form has been simplified from + ``[2, 3, 1, 0]`` to ``[0, 1, 3, 2]``, as the arguments of the tensor + product `M` and `N` have been switched: + + >>> perm1.permutation.array_form + [0, 1, 3, 2] + + We can nest a second permutation: + + >>> perm2 = permutedims(perm1, [1, 0, 2, 3]) + >>> perm2.shape + (2, 3, 5, 4) + >>> perm2.permutation.array_form + [1, 0, 3, 2] + """ + + def __new__(cls, expr, permutation=None, index_order_old=None, index_order_new=None, **kwargs): + from sympy.combinatorics import Permutation + expr = _sympify(expr) + expr_rank = get_rank(expr) + permutation = cls._get_permutation_from_arguments(permutation, index_order_old, index_order_new, expr_rank) + permutation = Permutation(permutation) + permutation_size = permutation.size + if permutation_size != expr_rank: + raise ValueError("Permutation size must be the length of the shape of expr") + + canonicalize = kwargs.pop("canonicalize", False) + + obj = Basic.__new__(cls, expr, permutation) + obj._subranks = [get_rank(expr)] + shape = get_shape(expr) + if shape is None: + obj._shape = None + else: + obj._shape = tuple(shape[permutation(i)] for i in range(len(shape))) + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + permutation = self.permutation + if isinstance(expr, PermuteDims): + subexpr = expr.expr + subperm = expr.permutation + permutation = permutation * subperm + expr = subexpr + if isinstance(expr, ArrayContraction): + expr, permutation = self._PermuteDims_denestarg_ArrayContraction(expr, permutation) + if isinstance(expr, ArrayTensorProduct): + expr, permutation = self._PermuteDims_denestarg_ArrayTensorProduct(expr, permutation) + if isinstance(expr, (ZeroArray, ZeroMatrix)): + return ZeroArray(*[expr.shape[i] for i in permutation.array_form]) + plist = permutation.array_form + if plist == sorted(plist): + return expr + return self.func(expr, permutation, canonicalize=False) + + @property + def expr(self): + return self.args[0] + + @property + def permutation(self): + return self.args[1] + + @classmethod + def _PermuteDims_denestarg_ArrayTensorProduct(cls, expr, permutation): + # Get the permutation in its image-form: + perm_image_form = _af_invert(permutation.array_form) + args = list(expr.args) + # Starting index global position for every arg: + cumul = list(accumulate([0] + expr.subranks)) + # Split `perm_image_form` into a list of list corresponding to the indices + # of every argument: + perm_image_form_in_components = [perm_image_form[cumul[i]:cumul[i+1]] for i in range(len(args))] + # Create an index, target-position-key array: + ps = [(i, sorted(comp)) for i, comp in enumerate(perm_image_form_in_components)] + # Sort the array according to the target-position-key: + # In this way, we define a canonical way to sort the arguments according + # to the permutation. + ps.sort(key=lambda x: x[1]) + # Read the inverse-permutation (i.e. image-form) of the args: + perm_args_image_form = [i[0] for i in ps] + # Apply the args-permutation to the `args`: + args_sorted = [args[i] for i in perm_args_image_form] + # Apply the args-permutation to the array-form of the permutation of the axes (of `expr`): + perm_image_form_sorted_args = [perm_image_form_in_components[i] for i in perm_args_image_form] + new_permutation = Permutation(_af_invert([j for i in perm_image_form_sorted_args for j in i])) + return _array_tensor_product(*args_sorted), new_permutation + + @classmethod + def _PermuteDims_denestarg_ArrayContraction(cls, expr, permutation): + if not isinstance(expr, ArrayContraction): + return expr, permutation + if not isinstance(expr.expr, ArrayTensorProduct): + return expr, permutation + args = expr.expr.args + subranks = [get_rank(arg) for arg in expr.expr.args] + + contraction_indices = expr.contraction_indices + contraction_indices_flat = [j for i in contraction_indices for j in i] + cumul = list(accumulate([0] + subranks)) + + # Spread the permutation in its array form across the args in the corresponding + # tensor-product arguments with free indices: + permutation_array_blocks_up = [] + image_form = _af_invert(permutation.array_form) + counter = 0 + for i in range(len(subranks)): + current = [] + for j in range(cumul[i], cumul[i+1]): + if j in contraction_indices_flat: + continue + current.append(image_form[counter]) + counter += 1 + permutation_array_blocks_up.append(current) + + # Get the map of axis repositioning for every argument of tensor-product: + index_blocks = [list(range(cumul[i], cumul[i+1])) for i, e in enumerate(expr.subranks)] + index_blocks_up = expr._push_indices_up(expr.contraction_indices, index_blocks) + inverse_permutation = permutation**(-1) + index_blocks_up_permuted = [[inverse_permutation(j) for j in i if j is not None] for i in index_blocks_up] + + # Sorting key is a list of tuple, first element is the index of `args`, second element of + # the tuple is the sorting key to sort `args` of the tensor product: + sorting_keys = list(enumerate(index_blocks_up_permuted)) + sorting_keys.sort(key=lambda x: x[1]) + + # Now we can get the permutation acting on the args in its image-form: + new_perm_image_form = [i[0] for i in sorting_keys] + # Apply the args-level permutation to various elements: + new_index_blocks = [index_blocks[i] for i in new_perm_image_form] + new_index_perm_array_form = _af_invert([j for i in new_index_blocks for j in i]) + new_args = [args[i] for i in new_perm_image_form] + new_contraction_indices = [tuple(new_index_perm_array_form[j] for j in i) for i in contraction_indices] + new_expr = _array_contraction(_array_tensor_product(*new_args), *new_contraction_indices) + new_permutation = Permutation(_af_invert([j for i in [permutation_array_blocks_up[k] for k in new_perm_image_form] for j in i])) + return new_expr, new_permutation + + @classmethod + def _check_permutation_mapping(cls, expr, permutation): + subranks = expr.subranks + index2arg = [i for i, arg in enumerate(expr.args) for j in range(expr.subranks[i])] + permuted_indices = [permutation(i) for i in range(expr.subrank())] + new_args = list(expr.args) + arg_candidate_index = index2arg[permuted_indices[0]] + current_indices = [] + new_permutation = [] + inserted_arg_cand_indices = set() + for i, idx in enumerate(permuted_indices): + if index2arg[idx] != arg_candidate_index: + new_permutation.extend(current_indices) + current_indices = [] + arg_candidate_index = index2arg[idx] + current_indices.append(idx) + arg_candidate_rank = subranks[arg_candidate_index] + if len(current_indices) == arg_candidate_rank: + new_permutation.extend(sorted(current_indices)) + local_current_indices = [j - min(current_indices) for j in current_indices] + i1 = index2arg[i] + new_args[i1] = _permute_dims(new_args[i1], Permutation(local_current_indices)) + inserted_arg_cand_indices.add(arg_candidate_index) + current_indices = [] + new_permutation.extend(current_indices) + + # TODO: swap args positions in order to simplify the expression: + # TODO: this should be in a function + args_positions = list(range(len(new_args))) + # Get possible shifts: + maps = {} + cumulative_subranks = [0] + list(accumulate(subranks)) + for i in range(len(subranks)): + s = {index2arg[new_permutation[j]] for j in range(cumulative_subranks[i], cumulative_subranks[i+1])} + if len(s) != 1: + continue + elem = next(iter(s)) + if i != elem: + maps[i] = elem + + # Find cycles in the map: + lines = [] + current_line = [] + while maps: + if len(current_line) == 0: + k, v = maps.popitem() + current_line.append(k) + else: + k = current_line[-1] + if k not in maps: + current_line = [] + continue + v = maps.pop(k) + if v in current_line: + lines.append(current_line) + current_line = [] + continue + current_line.append(v) + for line in lines: + for i, e in enumerate(line): + args_positions[line[(i + 1) % len(line)]] = e + + # TODO: function in order to permute the args: + permutation_blocks = [[new_permutation[cumulative_subranks[i] + j] for j in range(e)] for i, e in enumerate(subranks)] + new_args = [new_args[i] for i in args_positions] + new_permutation_blocks = [permutation_blocks[i] for i in args_positions] + new_permutation2 = [j for i in new_permutation_blocks for j in i] + return _array_tensor_product(*new_args), Permutation(new_permutation2) # **(-1) + + @classmethod + def _check_if_there_are_closed_cycles(cls, expr, permutation): + args = list(expr.args) + subranks = expr.subranks + cyclic_form = permutation.cyclic_form + cumulative_subranks = [0] + list(accumulate(subranks)) + cyclic_min = [min(i) for i in cyclic_form] + cyclic_max = [max(i) for i in cyclic_form] + cyclic_keep = [] + for i, cycle in enumerate(cyclic_form): + flag = True + for j in range(len(cumulative_subranks) - 1): + if cyclic_min[i] >= cumulative_subranks[j] and cyclic_max[i] < cumulative_subranks[j+1]: + # Found a sinkable cycle. + args[j] = _permute_dims(args[j], Permutation([[k - cumulative_subranks[j] for k in cycle]])) + flag = False + break + if flag: + cyclic_keep.append(cycle) + return _array_tensor_product(*args), Permutation(cyclic_keep, size=permutation.size) + + def nest_permutation(self): + r""" + DEPRECATED. + """ + ret = self._nest_permutation(self.expr, self.permutation) + if ret is None: + return self + return ret + + @classmethod + def _nest_permutation(cls, expr, permutation): + if isinstance(expr, ArrayTensorProduct): + return _permute_dims(*cls._check_if_there_are_closed_cycles(expr, permutation)) + elif isinstance(expr, ArrayContraction): + # Invert tree hierarchy: put the contraction above. + cycles = permutation.cyclic_form + newcycles = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *cycles) + newpermutation = Permutation(newcycles) + new_contr_indices = [tuple(newpermutation(j) for j in i) for i in expr.contraction_indices] + return _array_contraction(PermuteDims(expr.expr, newpermutation), *new_contr_indices) + elif isinstance(expr, ArrayAdd): + return _array_add(*[PermuteDims(arg, permutation) for arg in expr.args]) + return None + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return permutedims(expr, self.permutation) + + @classmethod + def _get_permutation_from_arguments(cls, permutation, index_order_old, index_order_new, dim): + if permutation is None: + if index_order_new is None or index_order_old is None: + raise ValueError("Permutation not defined") + return PermuteDims._get_permutation_from_index_orders(index_order_old, index_order_new, dim) + else: + if index_order_new is not None: + raise ValueError("index_order_new cannot be defined with permutation") + if index_order_old is not None: + raise ValueError("index_order_old cannot be defined with permutation") + return permutation + + @classmethod + def _get_permutation_from_index_orders(cls, index_order_old, index_order_new, dim): + if len(set(index_order_new)) != dim: + raise ValueError("wrong number of indices in index_order_new") + if len(set(index_order_old)) != dim: + raise ValueError("wrong number of indices in index_order_old") + if len(set.symmetric_difference(set(index_order_new), set(index_order_old))) > 0: + raise ValueError("index_order_new and index_order_old must have the same indices") + permutation = [index_order_old.index(i) for i in index_order_new] + return permutation + + +class ArrayDiagonal(_CodegenArrayAbstract): + r""" + Class to represent the diagonal operator. + + Explanation + =========== + + In a 2-dimensional array it returns the diagonal, this looks like the + operation: + + `A_{ij} \rightarrow A_{ii}` + + The diagonal over axes 1 and 2 (the second and third) of the tensor product + of two 2-dimensional arrays `A \otimes B` is + + `\Big[ A_{ab} B_{cd} \Big]_{abcd} \rightarrow \Big[ A_{ai} B_{id} \Big]_{adi}` + + In this last example the array expression has been reduced from + 4-dimensional to 3-dimensional. Notice that no contraction has occurred, + rather there is a new index `i` for the diagonal, contraction would have + reduced the array to 2 dimensions. + + Notice that the diagonalized out dimensions are added as new dimensions at + the end of the indices. + """ + + def __new__(cls, expr, *diagonal_indices, **kwargs): + expr = _sympify(expr) + diagonal_indices = [Tuple(*sorted(i)) for i in diagonal_indices] + canonicalize = kwargs.get("canonicalize", False) + + shape = get_shape(expr) + if shape is not None: + cls._validate(expr, *diagonal_indices, **kwargs) + # Get new shape: + positions, shape = cls._get_positions_shape(shape, diagonal_indices) + else: + positions = None + if len(diagonal_indices) == 0: + return expr + obj = Basic.__new__(cls, expr, *diagonal_indices) + obj._positions = positions + obj._subranks = _get_subranks(expr) + obj._shape = shape + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + diagonal_indices = self.diagonal_indices + trivial_diags = [i for i in diagonal_indices if len(i) == 1] + if len(trivial_diags) > 0: + trivial_pos = {e[0]: i for i, e in enumerate(diagonal_indices) if len(e) == 1} + diag_pos = {e: i for i, e in enumerate(diagonal_indices) if len(e) > 1} + diagonal_indices_short = [i for i in diagonal_indices if len(i) > 1] + rank1 = get_rank(self) + rank2 = len(diagonal_indices) + rank3 = rank1 - rank2 + inv_permutation = [] + counter1 = 0 + indices_down = ArrayDiagonal._push_indices_down(diagonal_indices_short, list(range(rank1)), get_rank(expr)) + for i in indices_down: + if i in trivial_pos: + inv_permutation.append(rank3 + trivial_pos[i]) + elif isinstance(i, (Integer, int)): + inv_permutation.append(counter1) + counter1 += 1 + else: + inv_permutation.append(rank3 + diag_pos[i]) + permutation = _af_invert(inv_permutation) + if len(diagonal_indices_short) > 0: + return _permute_dims(_array_diagonal(expr, *diagonal_indices_short), permutation) + else: + return _permute_dims(expr, permutation) + if isinstance(expr, ArrayAdd): + return self._ArrayDiagonal_denest_ArrayAdd(expr, *diagonal_indices) + if isinstance(expr, ArrayDiagonal): + return self._ArrayDiagonal_denest_ArrayDiagonal(expr, *diagonal_indices) + if isinstance(expr, PermuteDims): + return self._ArrayDiagonal_denest_PermuteDims(expr, *diagonal_indices) + if isinstance(expr, (ZeroArray, ZeroMatrix)): + positions, shape = self._get_positions_shape(expr.shape, diagonal_indices) + return ZeroArray(*shape) + return self.func(expr, *diagonal_indices, canonicalize=False) + + @staticmethod + def _validate(expr, *diagonal_indices, **kwargs): + # Check that no diagonalization happens on indices with mismatched + # dimensions: + shape = get_shape(expr) + for i in diagonal_indices: + if any(j >= len(shape) for j in i): + raise ValueError("index is larger than expression shape") + if len({shape[j] for j in i}) != 1: + raise ValueError("diagonalizing indices of different dimensions") + if not kwargs.get("allow_trivial_diags", False) and len(i) <= 1: + raise ValueError("need at least two axes to diagonalize") + if len(set(i)) != len(i): + raise ValueError("axis index cannot be repeated") + + @staticmethod + def _remove_trivial_dimensions(shape, *diagonal_indices): + return [tuple(j for j in i) for i in diagonal_indices if shape[i[0]] != 1] + + @property + def expr(self): + return self.args[0] + + @property + def diagonal_indices(self): + return self.args[1:] + + @staticmethod + def _flatten(expr, *outer_diagonal_indices): + inner_diagonal_indices = expr.diagonal_indices + all_inner = [j for i in inner_diagonal_indices for j in i] + all_inner.sort() + # TODO: add API for total rank and cumulative rank: + total_rank = _get_subrank(expr) + inner_rank = len(all_inner) + outer_rank = total_rank - inner_rank + shifts = [0 for i in range(outer_rank)] + counter = 0 + pointer = 0 + for i in range(outer_rank): + while pointer < inner_rank and counter >= all_inner[pointer]: + counter += 1 + pointer += 1 + shifts[i] += pointer + counter += 1 + outer_diagonal_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_diagonal_indices) + diagonal_indices = inner_diagonal_indices + outer_diagonal_indices + return _array_diagonal(expr.expr, *diagonal_indices) + + @classmethod + def _ArrayDiagonal_denest_ArrayAdd(cls, expr, *diagonal_indices): + return _array_add(*[_array_diagonal(arg, *diagonal_indices) for arg in expr.args]) + + @classmethod + def _ArrayDiagonal_denest_ArrayDiagonal(cls, expr, *diagonal_indices): + return cls._flatten(expr, *diagonal_indices) + + @classmethod + def _ArrayDiagonal_denest_PermuteDims(cls, expr: PermuteDims, *diagonal_indices): + back_diagonal_indices = [[expr.permutation(j) for j in i] for i in diagonal_indices] + nondiag = [i for i in range(get_rank(expr)) if not any(i in j for j in diagonal_indices)] + back_nondiag = [expr.permutation(i) for i in nondiag] + remap = {e: i for i, e in enumerate(sorted(back_nondiag))} + new_permutation1 = [remap[i] for i in back_nondiag] + shift = len(new_permutation1) + diag_block_perm = [i + shift for i in range(len(back_diagonal_indices))] + new_permutation = new_permutation1 + diag_block_perm + return _permute_dims( + _array_diagonal( + expr.expr, + *back_diagonal_indices + ), + new_permutation + ) + + def _push_indices_down_nonstatic(self, indices): + transform = lambda x: self._positions[x] if x < len(self._positions) else None + return _apply_recursively_over_nested_lists(transform, indices) + + def _push_indices_up_nonstatic(self, indices): + + def transform(x): + for i, e in enumerate(self._positions): + if (isinstance(e, int) and x == e) or (isinstance(e, tuple) and x in e): + return i + + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_down(cls, diagonal_indices, indices, rank): + positions, shape = cls._get_positions_shape(range(rank), diagonal_indices) + transform = lambda x: positions[x] if x < len(positions) else None + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_up(cls, diagonal_indices, indices, rank): + positions, shape = cls._get_positions_shape(range(rank), diagonal_indices) + + def transform(x): + for i, e in enumerate(positions): + if (isinstance(e, int) and x == e) or (isinstance(e, (tuple, Tuple)) and (x in e)): + return i + + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _get_positions_shape(cls, shape, diagonal_indices): + data1 = tuple((i, shp) for i, shp in enumerate(shape) if not any(i in j for j in diagonal_indices)) + pos1, shp1 = zip(*data1) if data1 else ((), ()) + data2 = tuple((i, shape[i[0]]) for i in diagonal_indices) + pos2, shp2 = zip(*data2) if data2 else ((), ()) + positions = pos1 + pos2 + shape = shp1 + shp2 + return positions, shape + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return tensordiagonal(expr, *self.diagonal_indices) + + +class ArrayElementwiseApplyFunc(_CodegenArrayAbstract): + + def __new__(cls, function, element): + + if not isinstance(function, Lambda): + d = Dummy('d') + function = Lambda(d, function(d)) + + obj = _CodegenArrayAbstract.__new__(cls, function, element) + obj._subranks = _get_subranks(element) + return obj + + @property + def function(self): + return self.args[0] + + @property + def expr(self): + return self.args[1] + + @property + def shape(self): + return self.expr.shape + + def _get_function_fdiff(self): + d = Dummy("d") + function = self.function(d) + fdiff = function.diff(d) + if isinstance(fdiff, Function): + fdiff = type(fdiff) + else: + fdiff = Lambda(d, fdiff) + return fdiff + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return expr.applyfunc(self.function) + + +class ArrayContraction(_CodegenArrayAbstract): + r""" + This class is meant to represent contractions of arrays in a form easily + processable by the code printers. + """ + + def __new__(cls, expr, *contraction_indices, **kwargs): + contraction_indices = _sort_contraction_indices(contraction_indices) + expr = _sympify(expr) + + canonicalize = kwargs.get("canonicalize", False) + + obj = Basic.__new__(cls, expr, *contraction_indices) + obj._subranks = _get_subranks(expr) + obj._mapping = _get_mapping_from_subranks(obj._subranks) + + free_indices_to_position = {i: i for i in range(sum(obj._subranks)) if all(i not in cind for cind in contraction_indices)} + obj._free_indices_to_position = free_indices_to_position + + shape = get_shape(expr) + cls._validate(expr, *contraction_indices) + if shape: + shape = tuple(shp for i, shp in enumerate(shape) if not any(i in j for j in contraction_indices)) + obj._shape = shape + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + contraction_indices = self.contraction_indices + + if len(contraction_indices) == 0: + return expr + + if isinstance(expr, ArrayContraction): + return self._ArrayContraction_denest_ArrayContraction(expr, *contraction_indices) + + if isinstance(expr, (ZeroArray, ZeroMatrix)): + return self._ArrayContraction_denest_ZeroArray(expr, *contraction_indices) + + if isinstance(expr, PermuteDims): + return self._ArrayContraction_denest_PermuteDims(expr, *contraction_indices) + + if isinstance(expr, ArrayTensorProduct): + expr, contraction_indices = self._sort_fully_contracted_args(expr, contraction_indices) + expr, contraction_indices = self._lower_contraction_to_addends(expr, contraction_indices) + if len(contraction_indices) == 0: + return expr + + if isinstance(expr, ArrayDiagonal): + return self._ArrayContraction_denest_ArrayDiagonal(expr, *contraction_indices) + + if isinstance(expr, ArrayAdd): + return self._ArrayContraction_denest_ArrayAdd(expr, *contraction_indices) + + # Check single index contractions on 1-dimensional axes: + contraction_indices = [i for i in contraction_indices if len(i) > 1 or get_shape(expr)[i[0]] != 1] + if len(contraction_indices) == 0: + return expr + + return self.func(expr, *contraction_indices, canonicalize=False) + + def __mul__(self, other): + if other == 1: + return self + else: + raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.") + + def __rmul__(self, other): + if other == 1: + return self + else: + raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.") + + @staticmethod + def _validate(expr, *contraction_indices): + shape = get_shape(expr) + if shape is None: + return + + # Check that no contraction happens when the shape is mismatched: + for i in contraction_indices: + if len({shape[j] for j in i if shape[j] != -1}) != 1: + raise ValueError("contracting indices of different dimensions") + + @classmethod + def _push_indices_down(cls, contraction_indices, indices): + flattened_contraction_indices = [j for i in contraction_indices for j in i] + flattened_contraction_indices.sort() + transform = _build_push_indices_down_func_transformation(flattened_contraction_indices) + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_up(cls, contraction_indices, indices): + flattened_contraction_indices = [j for i in contraction_indices for j in i] + flattened_contraction_indices.sort() + transform = _build_push_indices_up_func_transformation(flattened_contraction_indices) + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _lower_contraction_to_addends(cls, expr, contraction_indices): + if isinstance(expr, ArrayAdd): + raise NotImplementedError() + if not isinstance(expr, ArrayTensorProduct): + return expr, contraction_indices + subranks = expr.subranks + cumranks = list(accumulate([0] + subranks)) + contraction_indices_remaining = [] + contraction_indices_args = [[] for i in expr.args] + backshift = set() + for contraction_group in contraction_indices: + for j in range(len(expr.args)): + if not isinstance(expr.args[j], ArrayAdd): + continue + if all(cumranks[j] <= k < cumranks[j+1] for k in contraction_group): + contraction_indices_args[j].append([k - cumranks[j] for k in contraction_group]) + backshift.update(contraction_group) + break + else: + contraction_indices_remaining.append(contraction_group) + if len(contraction_indices_remaining) == len(contraction_indices): + return expr, contraction_indices + total_rank = get_rank(expr) + shifts = list(accumulate([1 if i in backshift else 0 for i in range(total_rank)])) + contraction_indices_remaining = [Tuple.fromiter(j - shifts[j] for j in i) for i in contraction_indices_remaining] + ret = _array_tensor_product(*[ + _array_contraction(arg, *contr) for arg, contr in zip(expr.args, contraction_indices_args) + ]) + return ret, contraction_indices_remaining + + def split_multiple_contractions(self): + """ + Recognize multiple contractions and attempt at rewriting them as paired-contractions. + + This allows some contractions involving more than two indices to be + rewritten as multiple contractions involving two indices, thus allowing + the expression to be rewritten as a matrix multiplication line. + + Examples: + + * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C` + + Care for: + - matrix being diagonalized (i.e. `A_ii`) + - vectors being diagonalized (i.e. `a_i0`) + + Multiple contractions can be split into matrix multiplications if + not more than two arguments are non-diagonals or non-vectors. + Vectors get diagonalized while diagonal matrices remain diagonal. + The non-diagonal matrices can be at the beginning or at the end + of the final matrix multiplication line. + """ + + editor = _EditArrayContraction(self) + + contraction_indices = self.contraction_indices + + onearray_insert = [] + + for indl, links in enumerate(contraction_indices): + if len(links) <= 2: + continue + + # Check multiple contractions: + # + # Examples: + # + # * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C \otimes OneArray(1)` with permutation (1 2) + # + # Care for: + # - matrix being diagonalized (i.e. `A_ii`) + # - vectors being diagonalized (i.e. `a_i0`) + + # Multiple contractions can be split into matrix multiplications if + # not more than three arguments are non-diagonals or non-vectors. + # + # Vectors get diagonalized while diagonal matrices remain diagonal. + # The non-diagonal matrices can be at the beginning or at the end + # of the final matrix multiplication line. + + positions = editor.get_mapping_for_index(indl) + + # Also consider the case of diagonal matrices being contracted: + current_dimension = self.expr.shape[links[0]] + + not_vectors = [] + vectors = [] + for arg_ind, rel_ind in positions: + arg = editor.args_with_ind[arg_ind] + mat = arg.element + abs_arg_start, abs_arg_end = editor.get_absolute_range(arg) + other_arg_pos = 1-rel_ind + other_arg_abs = abs_arg_start + other_arg_pos + if ((1 not in mat.shape) or + ((current_dimension == 1) is True and mat.shape != (1, 1)) or + any(other_arg_abs in l for li, l in enumerate(contraction_indices) if li != indl) + ): + not_vectors.append((arg, rel_ind)) + else: + vectors.append((arg, rel_ind)) + if len(not_vectors) > 2: + # If more than two arguments in the multiple contraction are + # non-vectors and non-diagonal matrices, we cannot find a way + # to split this contraction into a matrix multiplication line: + continue + # Three cases to handle: + # - zero non-vectors + # - one non-vector + # - two non-vectors + for v, rel_ind in vectors: + v.element = diagonalize_vector(v.element) + vectors_to_loop = not_vectors[:1] + vectors + not_vectors[1:] + first_not_vector, rel_ind = vectors_to_loop[0] + new_index = first_not_vector.indices[rel_ind] + + for v, rel_ind in vectors_to_loop[1:-1]: + v.indices[rel_ind] = new_index + new_index = editor.get_new_contraction_index() + assert v.indices.index(None) == 1 - rel_ind + v.indices[v.indices.index(None)] = new_index + onearray_insert.append(v) + + last_vec, rel_ind = vectors_to_loop[-1] + last_vec.indices[rel_ind] = new_index + + for v in onearray_insert: + editor.insert_after(v, _ArgE(OneArray(1), [None])) + + return editor.to_array_contraction() + + def flatten_contraction_of_diagonal(self): + if not isinstance(self.expr, ArrayDiagonal): + return self + contraction_down = self.expr._push_indices_down(self.expr.diagonal_indices, self.contraction_indices) + new_contraction_indices = [] + diagonal_indices = self.expr.diagonal_indices[:] + for i in contraction_down: + contraction_group = list(i) + for j in i: + diagonal_with = [k for k in diagonal_indices if j in k] + contraction_group.extend([l for k in diagonal_with for l in k]) + diagonal_indices = [k for k in diagonal_indices if k not in diagonal_with] + new_contraction_indices.append(sorted(set(contraction_group))) + + new_contraction_indices = ArrayDiagonal._push_indices_up(diagonal_indices, new_contraction_indices) + return _array_contraction( + _array_diagonal( + self.expr.expr, + *diagonal_indices + ), + *new_contraction_indices + ) + + @staticmethod + def _get_free_indices_to_position_map(free_indices, contraction_indices): + free_indices_to_position = {} + flattened_contraction_indices = [j for i in contraction_indices for j in i] + counter = 0 + for ind in free_indices: + while counter in flattened_contraction_indices: + counter += 1 + free_indices_to_position[ind] = counter + counter += 1 + return free_indices_to_position + + @staticmethod + def _get_index_shifts(expr): + """ + Get the mapping of indices at the positions before the contraction + occurs. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> N = MatrixSymbol("N", 3, 3) + >>> cg = tensorcontraction(tensorproduct(M, N), [1, 2]) + >>> cg._get_index_shifts(cg) + [0, 2] + + Indeed, ``cg`` after the contraction has two dimensions, 0 and 1. They + need to be shifted by 0 and 2 to get the corresponding positions before + the contraction (that is, 0 and 3). + """ + inner_contraction_indices = expr.contraction_indices + all_inner = [j for i in inner_contraction_indices for j in i] + all_inner.sort() + # TODO: add API for total rank and cumulative rank: + total_rank = _get_subrank(expr) + inner_rank = len(all_inner) + outer_rank = total_rank - inner_rank + shifts = [0 for i in range(outer_rank)] + counter = 0 + pointer = 0 + for i in range(outer_rank): + while pointer < inner_rank and counter >= all_inner[pointer]: + counter += 1 + pointer += 1 + shifts[i] += pointer + counter += 1 + return shifts + + @staticmethod + def _convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices): + shifts = ArrayContraction._get_index_shifts(expr) + outer_contraction_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_contraction_indices) + return outer_contraction_indices + + @staticmethod + def _flatten(expr, *outer_contraction_indices): + inner_contraction_indices = expr.contraction_indices + outer_contraction_indices = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices) + contraction_indices = inner_contraction_indices + outer_contraction_indices + return _array_contraction(expr.expr, *contraction_indices) + + @classmethod + def _ArrayContraction_denest_ArrayContraction(cls, expr, *contraction_indices): + return cls._flatten(expr, *contraction_indices) + + @classmethod + def _ArrayContraction_denest_ZeroArray(cls, expr, *contraction_indices): + contraction_indices_flat = [j for i in contraction_indices for j in i] + shape = [e for i, e in enumerate(expr.shape) if i not in contraction_indices_flat] + return ZeroArray(*shape) + + @classmethod + def _ArrayContraction_denest_ArrayAdd(cls, expr, *contraction_indices): + return _array_add(*[_array_contraction(i, *contraction_indices) for i in expr.args]) + + @classmethod + def _ArrayContraction_denest_PermuteDims(cls, expr, *contraction_indices): + permutation = expr.permutation + plist = permutation.array_form + new_contraction_indices = [tuple(permutation(j) for j in i) for i in contraction_indices] + new_plist = [i for i in plist if not any(i in j for j in new_contraction_indices)] + new_plist = cls._push_indices_up(new_contraction_indices, new_plist) + return _permute_dims( + _array_contraction(expr.expr, *new_contraction_indices), + Permutation(new_plist) + ) + + @classmethod + def _ArrayContraction_denest_ArrayDiagonal(cls, expr: 'ArrayDiagonal', *contraction_indices): + diagonal_indices = list(expr.diagonal_indices) + down_contraction_indices = expr._push_indices_down(expr.diagonal_indices, contraction_indices, get_rank(expr.expr)) + # Flatten diagonally contracted indices: + down_contraction_indices = [[k for j in i for k in (j if isinstance(j, (tuple, Tuple)) else [j])] for i in down_contraction_indices] + new_contraction_indices = [] + for contr_indgrp in down_contraction_indices: + ind = contr_indgrp[:] + for j, diag_indgrp in enumerate(diagonal_indices): + if diag_indgrp is None: + continue + if any(i in diag_indgrp for i in contr_indgrp): + ind.extend(diag_indgrp) + diagonal_indices[j] = None + new_contraction_indices.append(sorted(set(ind))) + + new_diagonal_indices_down = [i for i in diagonal_indices if i is not None] + new_diagonal_indices = ArrayContraction._push_indices_up(new_contraction_indices, new_diagonal_indices_down) + return _array_diagonal( + _array_contraction(expr.expr, *new_contraction_indices), + *new_diagonal_indices + ) + + @classmethod + def _sort_fully_contracted_args(cls, expr, contraction_indices): + if expr.shape is None: + return expr, contraction_indices + cumul = list(accumulate([0] + expr.subranks)) + index_blocks = [list(range(cumul[i], cumul[i+1])) for i in range(len(expr.args))] + contraction_indices_flat = {j for i in contraction_indices for j in i} + fully_contracted = [all(j in contraction_indices_flat for j in range(cumul[i], cumul[i+1])) for i, arg in enumerate(expr.args)] + new_pos = sorted(range(len(expr.args)), key=lambda x: (0, default_sort_key(expr.args[x])) if fully_contracted[x] else (1,)) + new_args = [expr.args[i] for i in new_pos] + new_index_blocks_flat = [j for i in new_pos for j in index_blocks[i]] + index_permutation_array_form = _af_invert(new_index_blocks_flat) + new_contraction_indices = [tuple(index_permutation_array_form[j] for j in i) for i in contraction_indices] + new_contraction_indices = _sort_contraction_indices(new_contraction_indices) + return _array_tensor_product(*new_args), new_contraction_indices + + def _get_contraction_tuples(self): + r""" + Return tuples containing the argument index and position within the + argument of the index position. + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + + >>> cg = tensorcontraction(tensorproduct(A, B), (1, 2)) + >>> cg._get_contraction_tuples() + [[(0, 1), (1, 0)]] + + Notes + ===== + + Here the contraction pair `(1, 2)` meaning that the 2nd and 3rd indices + of the tensor product `A\otimes B` are contracted, has been transformed + into `(0, 1)` and `(1, 0)`, identifying the same indices in a different + notation. `(0, 1)` is the second index (1) of the first argument (i.e. + 0 or `A`). `(1, 0)` is the first index (i.e. 0) of the second + argument (i.e. 1 or `B`). + """ + mapping = self._mapping + return [[mapping[j] for j in i] for i in self.contraction_indices] + + @staticmethod + def _contraction_tuples_to_contraction_indices(expr, contraction_tuples): + # TODO: check that `expr` has `.subranks`: + ranks = expr.subranks + cumulative_ranks = [0] + list(accumulate(ranks)) + return [tuple(cumulative_ranks[j]+k for j, k in i) for i in contraction_tuples] + + @property + def free_indices(self): + return self._free_indices[:] + + @property + def free_indices_to_position(self): + return dict(self._free_indices_to_position) + + @property + def expr(self): + return self.args[0] + + @property + def contraction_indices(self): + return self.args[1:] + + def _contraction_indices_to_components(self): + expr = self.expr + if not isinstance(expr, ArrayTensorProduct): + raise NotImplementedError("only for contractions of tensor products") + ranks = expr.subranks + mapping = {} + counter = 0 + for i, rank in enumerate(ranks): + for j in range(rank): + mapping[counter] = (i, j) + counter += 1 + return mapping + + def sort_args_by_name(self): + """ + Sort arguments in the tensor product so that their order is lexicographical. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + >>> cg = convert_matrix_to_array(C*D*A*B) + >>> cg + ArrayContraction(ArrayTensorProduct(A, D, C, B), (0, 3), (1, 6), (2, 5)) + >>> cg.sort_args_by_name() + ArrayContraction(ArrayTensorProduct(A, D, B, C), (0, 3), (1, 4), (2, 7)) + """ + expr = self.expr + if not isinstance(expr, ArrayTensorProduct): + return self + args = expr.args + sorted_data = sorted(enumerate(args), key=lambda x: default_sort_key(x[1])) + pos_sorted, args_sorted = zip(*sorted_data) + reordering_map = {i: pos_sorted.index(i) for i, arg in enumerate(args)} + contraction_tuples = self._get_contraction_tuples() + contraction_tuples = [[(reordering_map[j], k) for j, k in i] for i in contraction_tuples] + c_tp = _array_tensor_product(*args_sorted) + new_contr_indices = self._contraction_tuples_to_contraction_indices( + c_tp, + contraction_tuples + ) + return _array_contraction(c_tp, *new_contr_indices) + + def _get_contraction_links(self): + r""" + Returns a dictionary of links between arguments in the tensor product + being contracted. + + See the example for an explanation of the values. + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + Matrix multiplications are pairwise contractions between neighboring + matrices: + + `A_{ij} B_{jk} C_{kl} D_{lm}` + + >>> cg = convert_matrix_to_array(A*B*C*D) + >>> cg + ArrayContraction(ArrayTensorProduct(B, C, A, D), (0, 5), (1, 2), (3, 6)) + + >>> cg._get_contraction_links() + {0: {0: (2, 1), 1: (1, 0)}, 1: {0: (0, 1), 1: (3, 0)}, 2: {1: (0, 0)}, 3: {0: (1, 1)}} + + This dictionary is interpreted as follows: argument in position 0 (i.e. + matrix `A`) has its second index (i.e. 1) contracted to `(1, 0)`, that + is argument in position 1 (matrix `B`) on the first index slot of `B`, + this is the contraction provided by the index `j` from `A`. + + The argument in position 1 (that is, matrix `B`) has two contractions, + the ones provided by the indices `j` and `k`, respectively the first + and second indices (0 and 1 in the sub-dict). The link `(0, 1)` and + `(2, 0)` respectively. `(0, 1)` is the index slot 1 (the 2nd) of + argument in position 0 (that is, `A_{\ldot j}`), and so on. + """ + args, dlinks = _get_contraction_links([self], self.subranks, *self.contraction_indices) + return dlinks + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return tensorcontraction(expr, *self.contraction_indices) + + +class Reshape(_CodegenArrayAbstract): + """ + Reshape the dimensions of an array expression. + + Examples + ======== + + >>> from sympy.tensor.array.expressions import ArraySymbol, Reshape + >>> A = ArraySymbol("A", (6,)) + >>> A.shape + (6,) + >>> Reshape(A, (3, 2)).shape + (3, 2) + + Check the component-explicit forms: + + >>> A.as_explicit() + [A[0], A[1], A[2], A[3], A[4], A[5]] + >>> Reshape(A, (3, 2)).as_explicit() + [[A[0], A[1]], [A[2], A[3]], [A[4], A[5]]] + + """ + + def __new__(cls, expr, shape): + expr = _sympify(expr) + if not isinstance(shape, Tuple): + shape = Tuple(*shape) + if Equality(Mul.fromiter(expr.shape), Mul.fromiter(shape)) == False: + raise ValueError("shape mismatch") + obj = Expr.__new__(cls, expr, shape) + obj._shape = tuple(shape) + obj._expr = expr + return obj + + @property + def shape(self): + return self._shape + + @property + def expr(self): + return self._expr + + def doit(self, *args, **kwargs): + if kwargs.get("deep", True): + expr = self.expr.doit(*args, **kwargs) + else: + expr = self.expr + if isinstance(expr, (MatrixBase, NDimArray)): + return expr.reshape(*self.shape) + return Reshape(expr, self.shape) + + def as_explicit(self): + ee = self.expr + if hasattr(ee, "as_explicit"): + ee = ee.as_explicit() + if isinstance(ee, MatrixBase): + from sympy import Array + ee = Array(ee) + elif isinstance(ee, MatrixExpr): + return self + return ee.reshape(*self.shape) + + +class _ArgE: + """ + The ``_ArgE`` object contains references to the array expression + (``.element``) and a list containing the information about index + contractions (``.indices``). + + Index contractions are numbered and contracted indices show the number of + the contraction. Uncontracted indices have ``None`` value. + + For example: + ``_ArgE(M, [None, 3])`` + This object means that expression ``M`` is part of an array contraction + and has two indices, the first is not contracted (value ``None``), + the second index is contracted to the 4th (i.e. number ``3``) group of the + array contraction object. + """ + indices: list[int | None] + + def __init__(self, element, indices: list[int | None] | None = None): + self.element = element + if indices is None: + self.indices = [None for i in range(get_rank(element))] + else: + self.indices = indices + + def __str__(self): + return "_ArgE(%s, %s)" % (self.element, self.indices) + + __repr__ = __str__ + + +class _IndPos: + """ + Index position, requiring two integers in the constructor: + + - arg: the position of the argument in the tensor product, + - rel: the relative position of the index inside the argument. + """ + def __init__(self, arg: int, rel: int): + self.arg = arg + self.rel = rel + + def __str__(self): + return "_IndPos(%i, %i)" % (self.arg, self.rel) + + __repr__ = __str__ + + def __iter__(self): + yield from [self.arg, self.rel] + + +class _EditArrayContraction: + """ + Utility class to help manipulate array contraction objects. + + This class takes as input an ``ArrayContraction`` object and turns it into + an editable object. + + The field ``args_with_ind`` of this class is a list of ``_ArgE`` objects + which can be used to easily edit the contraction structure of the + expression. + + Once editing is finished, the ``ArrayContraction`` object may be recreated + by calling the ``.to_array_contraction()`` method. + """ + + def __init__(self, base_array: typing.Union[ArrayContraction, ArrayDiagonal, ArrayTensorProduct]): + + expr: Basic + diagonalized: tuple[tuple[int, ...], ...] + contraction_indices: list[tuple[int]] + if isinstance(base_array, ArrayContraction): + mapping = _get_mapping_from_subranks(base_array.subranks) + expr = base_array.expr + contraction_indices = base_array.contraction_indices + diagonalized = () + elif isinstance(base_array, ArrayDiagonal): + + if isinstance(base_array.expr, ArrayContraction): + mapping = _get_mapping_from_subranks(base_array.expr.subranks) + expr = base_array.expr.expr + diagonalized = ArrayContraction._push_indices_down(base_array.expr.contraction_indices, base_array.diagonal_indices) + contraction_indices = base_array.expr.contraction_indices + elif isinstance(base_array.expr, ArrayTensorProduct): + mapping = {} + expr = base_array.expr + diagonalized = base_array.diagonal_indices + contraction_indices = [] + else: + mapping = {} + expr = base_array.expr + diagonalized = base_array.diagonal_indices + contraction_indices = [] + + elif isinstance(base_array, ArrayTensorProduct): + expr = base_array + contraction_indices = [] + diagonalized = () + else: + raise NotImplementedError() + + if isinstance(expr, ArrayTensorProduct): + args = list(expr.args) + else: + args = [expr] + + args_with_ind: list[_ArgE] = [_ArgE(arg) for arg in args] + for i, contraction_tuple in enumerate(contraction_indices): + for j in contraction_tuple: + arg_pos, rel_pos = mapping[j] + args_with_ind[arg_pos].indices[rel_pos] = i + self.args_with_ind: list[_ArgE] = args_with_ind + self.number_of_contraction_indices: int = len(contraction_indices) + self._track_permutation: list[list[int]] | None = None + + mapping = _get_mapping_from_subranks(base_array.subranks) + + # Trick: add diagonalized indices as negative indices into the editor object: + for i, e in enumerate(diagonalized): + for j in e: + arg_pos, rel_pos = mapping[j] + self.args_with_ind[arg_pos].indices[rel_pos] = -1 - i + + def insert_after(self, arg: _ArgE, new_arg: _ArgE): + pos = self.args_with_ind.index(arg) + self.args_with_ind.insert(pos + 1, new_arg) + + def get_new_contraction_index(self): + self.number_of_contraction_indices += 1 + return self.number_of_contraction_indices - 1 + + def refresh_indices(self): + updates = {} + for arg_with_ind in self.args_with_ind: + updates.update({i: -1 for i in arg_with_ind.indices if i is not None}) + for i, e in enumerate(sorted(updates)): + updates[e] = i + self.number_of_contraction_indices = len(updates) + for arg_with_ind in self.args_with_ind: + arg_with_ind.indices = [updates.get(i, None) for i in arg_with_ind.indices] + + def merge_scalars(self): + scalars = [] + for arg_with_ind in self.args_with_ind: + if len(arg_with_ind.indices) == 0: + scalars.append(arg_with_ind) + for i in scalars: + self.args_with_ind.remove(i) + scalar = Mul.fromiter([i.element for i in scalars]) + if len(self.args_with_ind) == 0: + self.args_with_ind.append(_ArgE(scalar)) + else: + from sympy.tensor.array.expressions.from_array_to_matrix import _a2m_tensor_product + self.args_with_ind[0].element = _a2m_tensor_product(scalar, self.args_with_ind[0].element) + + def to_array_contraction(self): + + # Count the ranks of the arguments: + counter = 0 + # Create a collector for the new diagonal indices: + diag_indices = defaultdict(list) + + count_index_freq = Counter() + for arg_with_ind in self.args_with_ind: + count_index_freq.update(Counter(arg_with_ind.indices)) + + free_index_count = count_index_freq[None] + + # Construct the inverse permutation: + inv_perm1 = [] + inv_perm2 = [] + # Keep track of which diagonal indices have already been processed: + done = set() + + # Counter for the diagonal indices: + counter4 = 0 + + for arg_with_ind in self.args_with_ind: + # If some diagonalization axes have been removed, they should be + # permuted in order to keep the permutation. + # Add permutation here + counter2 = 0 # counter for the indices + for i in arg_with_ind.indices: + if i is None: + inv_perm1.append(counter4) + counter2 += 1 + counter4 += 1 + continue + if i >= 0: + continue + # Reconstruct the diagonal indices: + diag_indices[-1 - i].append(counter + counter2) + if count_index_freq[i] == 1 and i not in done: + inv_perm1.append(free_index_count - 1 - i) + done.add(i) + elif i not in done: + inv_perm2.append(free_index_count - 1 - i) + done.add(i) + counter2 += 1 + # Remove negative indices to restore a proper editor object: + arg_with_ind.indices = [i if i is not None and i >= 0 else None for i in arg_with_ind.indices] + counter += len([i for i in arg_with_ind.indices if i is None or i < 0]) + + inverse_permutation = inv_perm1 + inv_perm2 + permutation = _af_invert(inverse_permutation) + + # Get the diagonal indices after the detection of HadamardProduct in the expression: + diag_indices_filtered = [tuple(v) for v in diag_indices.values() if len(v) > 1] + + self.merge_scalars() + self.refresh_indices() + args = [arg.element for arg in self.args_with_ind] + contraction_indices = self.get_contraction_indices() + expr = _array_contraction(_array_tensor_product(*args), *contraction_indices) + expr2 = _array_diagonal(expr, *diag_indices_filtered) + if self._track_permutation is not None: + permutation2 = _af_invert([j for i in self._track_permutation for j in i]) + expr2 = _permute_dims(expr2, permutation2) + + expr3 = _permute_dims(expr2, permutation) + return expr3 + + def get_contraction_indices(self) -> list[list[int]]: + contraction_indices: list[list[int]] = [[] for i in range(self.number_of_contraction_indices)] + current_position: int = 0 + for arg_with_ind in self.args_with_ind: + for j in arg_with_ind.indices: + if j is not None: + contraction_indices[j].append(current_position) + current_position += 1 + return contraction_indices + + def get_mapping_for_index(self, ind) -> list[_IndPos]: + if ind >= self.number_of_contraction_indices: + raise ValueError("index value exceeding the index range") + positions: list[_IndPos] = [] + for i, arg_with_ind in enumerate(self.args_with_ind): + for j, arg_ind in enumerate(arg_with_ind.indices): + if ind == arg_ind: + positions.append(_IndPos(i, j)) + return positions + + def get_contraction_indices_to_ind_rel_pos(self) -> list[list[_IndPos]]: + contraction_indices: list[list[_IndPos]] = [[] for i in range(self.number_of_contraction_indices)] + for i, arg_with_ind in enumerate(self.args_with_ind): + for j, ind in enumerate(arg_with_ind.indices): + if ind is not None: + contraction_indices[ind].append(_IndPos(i, j)) + return contraction_indices + + def count_args_with_index(self, index: int) -> int: + """ + Count the number of arguments that have the given index. + """ + counter: int = 0 + for arg_with_ind in self.args_with_ind: + if index in arg_with_ind.indices: + counter += 1 + return counter + + def get_args_with_index(self, index: int) -> list[_ArgE]: + """ + Get a list of arguments having the given index. + """ + ret: list[_ArgE] = [i for i in self.args_with_ind if index in i.indices] + return ret + + @property + def number_of_diagonal_indices(self): + data = set() + for arg in self.args_with_ind: + data.update({i for i in arg.indices if i is not None and i < 0}) + return len(data) + + def track_permutation_start(self): + permutation = [] + perm_diag = [] + counter = 0 + counter2 = -1 + for arg_with_ind in self.args_with_ind: + perm = [] + for i in arg_with_ind.indices: + if i is not None: + if i < 0: + perm_diag.append(counter2) + counter2 -= 1 + continue + perm.append(counter) + counter += 1 + permutation.append(perm) + max_ind = max(max(i) if i else -1 for i in permutation) if permutation else -1 + perm_diag = [max_ind - i for i in perm_diag] + self._track_permutation = permutation + [perm_diag] + + def track_permutation_merge(self, destination: _ArgE, from_element: _ArgE): + index_destination = self.args_with_ind.index(destination) + index_element = self.args_with_ind.index(from_element) + self._track_permutation[index_destination].extend(self._track_permutation[index_element]) # type: ignore + self._track_permutation.pop(index_element) # type: ignore + + def get_absolute_free_range(self, arg: _ArgE) -> typing.Tuple[int, int]: + """ + Return the range of the free indices of the arg as absolute positions + among all free indices. + """ + counter = 0 + for arg_with_ind in self.args_with_ind: + number_free_indices = len([i for i in arg_with_ind.indices if i is None]) + if arg_with_ind == arg: + return counter, counter + number_free_indices + counter += number_free_indices + raise IndexError("argument not found") + + def get_absolute_range(self, arg: _ArgE) -> typing.Tuple[int, int]: + """ + Return the absolute range of indices for arg, disregarding dummy + indices. + """ + counter = 0 + for arg_with_ind in self.args_with_ind: + number_indices = len(arg_with_ind.indices) + if arg_with_ind == arg: + return counter, counter + number_indices + counter += number_indices + raise IndexError("argument not found") + + +def get_rank(expr): + if isinstance(expr, (MatrixExpr, MatrixElement)): + return 2 + if isinstance(expr, _CodegenArrayAbstract): + return len(expr.shape) + if isinstance(expr, NDimArray): + return expr.rank() + if isinstance(expr, Indexed): + return expr.rank + if isinstance(expr, IndexedBase): + shape = expr.shape + if shape is None: + return -1 + else: + return len(shape) + if hasattr(expr, "shape"): + return len(expr.shape) + return 0 + + +def _get_subrank(expr): + if isinstance(expr, _CodegenArrayAbstract): + return expr.subrank() + return get_rank(expr) + + +def _get_subranks(expr): + if isinstance(expr, _CodegenArrayAbstract): + return expr.subranks + else: + return [get_rank(expr)] + + +def get_shape(expr): + if hasattr(expr, "shape"): + return expr.shape + return () + + +def nest_permutation(expr): + if isinstance(expr, PermuteDims): + return expr.nest_permutation() + else: + return expr + + +def _array_tensor_product(*args, **kwargs): + return ArrayTensorProduct(*args, canonicalize=True, **kwargs) + + +def _array_contraction(expr, *contraction_indices, **kwargs): + return ArrayContraction(expr, *contraction_indices, canonicalize=True, **kwargs) + + +def _array_diagonal(expr, *diagonal_indices, **kwargs): + return ArrayDiagonal(expr, *diagonal_indices, canonicalize=True, **kwargs) + + +def _permute_dims(expr, permutation, **kwargs): + return PermuteDims(expr, permutation, canonicalize=True, **kwargs) + + +def _array_add(*args, **kwargs): + return ArrayAdd(*args, canonicalize=True, **kwargs) + + +def _get_array_element_or_slice(expr, indices): + return ArrayElement(expr, indices) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..ab44a6fbf715ac7f2b8c287dcc84a49289f2dd76 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py @@ -0,0 +1,194 @@ +import operator +from functools import reduce, singledispatch + +from sympy.core.expr import Expr +from sympy.core.singleton import S +from sympy.matrices.expressions.hadamard import HadamardProduct +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import (MatrixExpr, MatrixSymbol) +from sympy.matrices.expressions.special import Identity, OneMatrix +from sympy.matrices.expressions.transpose import Transpose +from sympy.combinatorics.permutations import _af_invert +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.tensor.array.expressions.array_expressions import ( + _ArrayExpr, ZeroArray, ArraySymbol, ArrayTensorProduct, ArrayAdd, + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, get_rank, + get_shape, ArrayContraction, _array_tensor_product, _array_contraction, + _array_diagonal, _array_add, _permute_dims, Reshape) +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + + +@singledispatch +def array_derive(expr, x): + """ + Derivatives (gradients) for array expressions. + """ + raise NotImplementedError(f"not implemented for type {type(expr)}") + + +@array_derive.register(Expr) +def _(expr: Expr, x: _ArrayExpr): + return ZeroArray(*x.shape) + + +@array_derive.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct, x: Expr): + args = expr.args + addend_list = [] + for i, arg in enumerate(expr.args): + darg = array_derive(arg, x) + if darg == 0: + continue + args_prev = args[:i] + args_succ = args[i+1:] + shape_prev = reduce(operator.add, map(get_shape, args_prev), ()) + shape_succ = reduce(operator.add, map(get_shape, args_succ), ()) + addend = _array_tensor_product(*args_prev, darg, *args_succ) + tot1 = len(get_shape(x)) + tot2 = tot1 + len(shape_prev) + tot3 = tot2 + len(get_shape(arg)) + tot4 = tot3 + len(shape_succ) + perm = list(range(tot1, tot2)) + \ + list(range(tot1)) + list(range(tot2, tot3)) + \ + list(range(tot3, tot4)) + addend = _permute_dims(addend, _af_invert(perm)) + addend_list.append(addend) + if len(addend_list) == 1: + return addend_list[0] + elif len(addend_list) == 0: + return S.Zero + else: + return _array_add(*addend_list) + + +@array_derive.register(ArraySymbol) +def _(expr: ArraySymbol, x: _ArrayExpr): + if expr == x: + return _permute_dims( + ArrayTensorProduct.fromiter(Identity(i) for i in expr.shape), + [2*i for i in range(len(expr.shape))] + [2*i+1 for i in range(len(expr.shape))] + ) + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(MatrixSymbol) +def _(expr: MatrixSymbol, x: _ArrayExpr): + m, n = expr.shape + if expr == x: + return _permute_dims( + _array_tensor_product(Identity(m), Identity(n)), + [0, 2, 1, 3] + ) + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(Identity) +def _(expr: Identity, x: _ArrayExpr): + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(OneMatrix) +def _(expr: OneMatrix, x: _ArrayExpr): + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(Transpose) +def _(expr: Transpose, x: Expr): + # D(A.T, A) ==> (m,n,i,j) ==> D(A_ji, A_mn) = d_mj d_ni + # D(B.T, A) ==> (m,n,i,j) ==> D(B_ji, A_mn) + fd = array_derive(expr.arg, x) + return _permute_dims(fd, [0, 1, 3, 2]) + + +@array_derive.register(Inverse) +def _(expr: Inverse, x: Expr): + mat = expr.I + dexpr = array_derive(mat, x) + tp = _array_tensor_product(-expr, dexpr, expr) + mp = _array_contraction(tp, (1, 4), (5, 6)) + pp = _permute_dims(mp, [1, 2, 0, 3]) + return pp + + +@array_derive.register(ElementwiseApplyFunction) +def _(expr: ElementwiseApplyFunction, x: Expr): + assert get_rank(expr) == 2 + assert get_rank(x) == 2 + fdiff = expr._get_function_fdiff() + dexpr = array_derive(expr.expr, x) + tp = _array_tensor_product( + ElementwiseApplyFunction(fdiff, expr.expr), + dexpr + ) + td = _array_diagonal( + tp, (0, 4), (1, 5) + ) + return td + + +@array_derive.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc, x: Expr): + fdiff = expr._get_function_fdiff() + subexpr = expr.expr + dsubexpr = array_derive(subexpr, x) + tp = _array_tensor_product( + dsubexpr, + ArrayElementwiseApplyFunc(fdiff, subexpr) + ) + b = get_rank(x) + c = get_rank(expr) + diag_indices = [(b + i, b + c + i) for i in range(c)] + return _array_diagonal(tp, *diag_indices) + + +@array_derive.register(MatrixExpr) +def _(expr: MatrixExpr, x: Expr): + cg = convert_matrix_to_array(expr) + return array_derive(cg, x) + + +@array_derive.register(HadamardProduct) +def _(expr: HadamardProduct, x: Expr): + raise NotImplementedError() + + +@array_derive.register(ArrayContraction) +def _(expr: ArrayContraction, x: Expr): + fd = array_derive(expr.expr, x) + rank_x = len(get_shape(x)) + contraction_indices = expr.contraction_indices + new_contraction_indices = [tuple(j + rank_x for j in i) for i in contraction_indices] + return _array_contraction(fd, *new_contraction_indices) + + +@array_derive.register(ArrayDiagonal) +def _(expr: ArrayDiagonal, x: Expr): + dsubexpr = array_derive(expr.expr, x) + rank_x = len(get_shape(x)) + diag_indices = [[j + rank_x for j in i] for i in expr.diagonal_indices] + return _array_diagonal(dsubexpr, *diag_indices) + + +@array_derive.register(ArrayAdd) +def _(expr: ArrayAdd, x: Expr): + return _array_add(*[array_derive(arg, x) for arg in expr.args]) + + +@array_derive.register(PermuteDims) +def _(expr: PermuteDims, x: Expr): + de = array_derive(expr.expr, x) + perm = [0, 1] + [i + 2 for i in expr.permutation.array_form] + return _permute_dims(de, perm) + + +@array_derive.register(Reshape) +def _(expr: Reshape, x: Expr): + de = array_derive(expr.expr, x) + return Reshape(de, get_shape(x) + expr.shape) + + +def matrix_derive(expr, x): + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + ce = convert_matrix_to_array(expr) + dce = array_derive(ce, x) + return convert_array_to_matrix(dce).doit() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..1929c3401e131cca0a83080131ead9198b37bcbb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py @@ -0,0 +1,12 @@ +from sympy.tensor.array.expressions import from_array_to_indexed +from sympy.utilities.decorator import deprecated + + +_conv_to_from_decorator = deprecated( + "module has been renamed by replacing 'conv_' with 'from_' in its name", + deprecated_since_version="1.11", + active_deprecations_target="deprecated-conv-array-expr-module-names", +) + + +convert_array_to_indexed = _conv_to_from_decorator(from_array_to_indexed.convert_array_to_indexed) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..2708e74aaa98d6ee38eae46d97d4483a546e0776 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py @@ -0,0 +1,6 @@ +from sympy.tensor.array.expressions import from_array_to_matrix +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_array_to_matrix = _conv_to_from_decorator(from_array_to_matrix.convert_array_to_matrix) +_array2matrix = _conv_to_from_decorator(from_array_to_matrix._array2matrix) +_remove_trivial_dims = _conv_to_from_decorator(from_array_to_matrix._remove_trivial_dims) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..6058b31f20778834ea23a01553d594b7965eb6bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py @@ -0,0 +1,4 @@ +from sympy.tensor.array.expressions import from_indexed_to_array +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_indexed_to_array = _conv_to_from_decorator(from_indexed_to_array.convert_indexed_to_array) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..46469df60703c237527c0b2834235309640afe7c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py @@ -0,0 +1,4 @@ +from sympy.tensor.array.expressions import from_matrix_to_array +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_matrix_to_array = _conv_to_from_decorator(from_matrix_to_array.convert_matrix_to_array) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb86e7cfbe31ebfe7c9649803d9cb5e34b98276 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py @@ -0,0 +1,84 @@ +import collections.abc +import operator +from itertools import accumulate + +from sympy import Mul, Sum, Dummy, Add +from sympy.tensor.array.expressions import PermuteDims, ArrayAdd, ArrayElementwiseApplyFunc, Reshape +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, get_rank, ArrayContraction, \ + ArrayDiagonal, get_shape, _get_array_element_or_slice, _ArrayExpr +from sympy.tensor.array.expressions.utils import _apply_permutation_to_list + + +def convert_array_to_indexed(expr, indices): + return _ConvertArrayToIndexed().do_convert(expr, indices) + + +class _ConvertArrayToIndexed: + + def __init__(self): + self.count_dummies = 0 + + def do_convert(self, expr, indices): + if isinstance(expr, ArrayTensorProduct): + cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args])) + indices_grp = [indices[cumul[i]:cumul[i+1]] for i in range(len(expr.args))] + return Mul.fromiter(self.do_convert(arg, ind) for arg, ind in zip(expr.args, indices_grp)) + if isinstance(expr, ArrayContraction): + new_indices = [None for i in range(get_rank(expr.expr))] + limits = [] + bottom_shape = get_shape(expr.expr) + for contraction_index_grp in expr.contraction_indices: + d = Dummy(f"d{self.count_dummies}") + self.count_dummies += 1 + dim = bottom_shape[contraction_index_grp[0]] + limits.append((d, 0, dim-1)) + for i in contraction_index_grp: + new_indices[i] = d + j = 0 + for i in range(len(new_indices)): + if new_indices[i] is None: + new_indices[i] = indices[j] + j += 1 + newexpr = self.do_convert(expr.expr, new_indices) + return Sum(newexpr, *limits) + if isinstance(expr, ArrayDiagonal): + new_indices = [None for i in range(get_rank(expr.expr))] + ind_pos = expr._push_indices_down(expr.diagonal_indices, list(range(len(indices))), get_rank(expr)) + for i, index in zip(ind_pos, indices): + if isinstance(i, collections.abc.Iterable): + for j in i: + new_indices[j] = index + else: + new_indices[i] = index + newexpr = self.do_convert(expr.expr, new_indices) + return newexpr + if isinstance(expr, PermuteDims): + permuted_indices = _apply_permutation_to_list(expr.permutation, indices) + return self.do_convert(expr.expr, permuted_indices) + if isinstance(expr, ArrayAdd): + return Add.fromiter(self.do_convert(arg, indices) for arg in expr.args) + if isinstance(expr, _ArrayExpr): + return expr.__getitem__(tuple(indices)) + if isinstance(expr, ArrayElementwiseApplyFunc): + return expr.function(self.do_convert(expr.expr, indices)) + if isinstance(expr, Reshape): + shape_up = expr.shape + shape_down = get_shape(expr.expr) + cumul = list(accumulate([1] + list(reversed(shape_up)), operator.mul)) + one_index = Add.fromiter(i*s for i, s in zip(reversed(indices), cumul)) + dest_indices = [None for _ in shape_down] + c = 1 + for i, e in enumerate(reversed(shape_down)): + if c == 1: + if i == len(shape_down) - 1: + dest_indices[i] = one_index + else: + dest_indices[i] = one_index % e + elif i == len(shape_down) - 1: + dest_indices[i] = one_index // c + else: + dest_indices[i] = one_index // c % e + c *= e + dest_indices.reverse() + return self.do_convert(expr.expr, dest_indices) + return _get_array_element_or_slice(expr, indices) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..debfdd7eb5c4533996b3d72b55d679be3daf3afe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py @@ -0,0 +1,1004 @@ +from __future__ import annotations +import itertools +from collections import defaultdict +from typing import FrozenSet +from functools import singledispatch +from itertools import accumulate + +from sympy import MatMul, Basic, Wild, KroneckerProduct +from sympy.assumptions.ask import (Q, ask) +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.matrices.expressions.diagonal import DiagMatrix +from sympy.matrices.expressions.hadamard import hadamard_product, HadamardPower +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import (Identity, ZeroMatrix, OneMatrix) +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.combinatorics.permutations import _af_invert, Permutation +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.array_expressions import PermuteDims, ArrayDiagonal, \ + ArrayTensorProduct, OneArray, get_rank, _get_subrank, ZeroArray, ArrayContraction, \ + ArrayAdd, _CodegenArrayAbstract, get_shape, ArrayElementwiseApplyFunc, _ArrayExpr, _EditArrayContraction, _ArgE, \ + ArrayElement, _array_tensor_product, _array_contraction, _array_diagonal, _array_add, _permute_dims +from sympy.tensor.array.expressions.utils import _get_mapping_from_subranks + + +def _get_candidate_for_matmul_from_contraction(scan_indices: list[int | None], remaining_args: list[_ArgE]) -> tuple[_ArgE | None, bool, int]: + + scan_indices_int: list[int] = [i for i in scan_indices if i is not None] + if len(scan_indices_int) == 0: + return None, False, -1 + + transpose: bool = False + candidate: _ArgE | None = None + candidate_index: int = -1 + for arg_with_ind2 in remaining_args: + if not isinstance(arg_with_ind2.element, MatrixExpr): + continue + for index in scan_indices_int: + if candidate_index != -1 and candidate_index != index: + # A candidate index has already been selected, check + # repetitions only for that index: + continue + if index in arg_with_ind2.indices: + if set(arg_with_ind2.indices) == {index}: + # Index repeated twice in arg_with_ind2 + candidate = None + break + if candidate is None: + candidate = arg_with_ind2 + candidate_index = index + transpose = (index == arg_with_ind2.indices[1]) + else: + # Index repeated more than twice, break + candidate = None + break + return candidate, transpose, candidate_index + + +def _insert_candidate_into_editor(editor: _EditArrayContraction, arg_with_ind: _ArgE, candidate: _ArgE, transpose1: bool, transpose2: bool): + other = candidate.element + other_index: int | None + if transpose2: + other = Transpose(other) + other_index = candidate.indices[0] + else: + other_index = candidate.indices[1] + new_element = (Transpose(arg_with_ind.element) if transpose1 else arg_with_ind.element) * other + editor.args_with_ind.remove(candidate) + new_arge = _ArgE(new_element) + return new_arge, other_index + + +def _support_function_tp1_recognize(contraction_indices, args): + if len(contraction_indices) == 0: + return _a2m_tensor_product(*args) + + ac = _array_contraction(_array_tensor_product(*args), *contraction_indices) + editor = _EditArrayContraction(ac) + editor.track_permutation_start() + + while True: + flag_stop = True + for i, arg_with_ind in enumerate(editor.args_with_ind): + if not isinstance(arg_with_ind.element, MatrixExpr): + continue + + first_index = arg_with_ind.indices[0] + second_index = arg_with_ind.indices[1] + + first_frequency = editor.count_args_with_index(first_index) + second_frequency = editor.count_args_with_index(second_index) + + if first_index is not None and first_frequency == 1 and first_index == second_index: + flag_stop = False + arg_with_ind.element = Trace(arg_with_ind.element)._normalize() + arg_with_ind.indices = [] + break + + scan_indices = [] + if first_frequency == 2: + scan_indices.append(first_index) + if second_frequency == 2: + scan_indices.append(second_index) + + candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(scan_indices, editor.args_with_ind[i+1:]) + if candidate is not None: + flag_stop = False + editor.track_permutation_merge(arg_with_ind, candidate) + transpose1 = found_index == first_index + new_arge, other_index = _insert_candidate_into_editor(editor, arg_with_ind, candidate, transpose1, transpose) + if found_index == first_index: + new_arge.indices = [second_index, other_index] + else: + new_arge.indices = [first_index, other_index] + set_indices = set(new_arge.indices) + if len(set_indices) == 1 and set_indices != {None}: + # This is a trace: + new_arge.element = Trace(new_arge.element)._normalize() + new_arge.indices = [] + editor.args_with_ind[i] = new_arge + # TODO: is this break necessary? + break + + if flag_stop: + break + + editor.refresh_indices() + return editor.to_array_contraction() + + +def _find_trivial_matrices_rewrite(expr: ArrayTensorProduct): + # If there are matrices of trivial shape in the tensor product (i.e. shape + # (1, 1)), try to check if there is a suitable non-trivial MatMul where the + # expression can be inserted. + + # For example, if "a" has shape (1, 1) and "b" has shape (k, 1), the + # expressions "_array_tensor_product(a, b*b.T)" can be rewritten as + # "b*a*b.T" + + trivial_matrices = [] + pos: int | None = None # must be initialized else causes UnboundLocalError + first: MatrixExpr | None = None # may cause UnboundLocalError if not initialized + second: MatrixExpr | None = None # may cause UnboundLocalError if not initialized + removed: list[int] = [] + counter: int = 0 + args: list[Basic | None] = list(expr.args) + for i, arg in enumerate(expr.args): + if isinstance(arg, MatrixExpr): + if arg.shape == (1, 1): + trivial_matrices.append(arg) + args[i] = None + removed.extend([counter, counter+1]) + elif pos is None and isinstance(arg, MatMul): + margs = arg.args + for j, e in enumerate(margs): + if isinstance(e, MatrixExpr) and e.shape[1] == 1: + pos = i + first = MatMul.fromiter(margs[:j+1]) + second = MatMul.fromiter(margs[j+1:]) + break + counter += get_rank(arg) + if pos is None: + return expr, [] + args[pos] = (first*MatMul.fromiter(i for i in trivial_matrices)*second).doit() + return _array_tensor_product(*[i for i in args if i is not None]), removed + + +def _find_trivial_kronecker_products_broadcast(expr: ArrayTensorProduct): + newargs: list[Basic] = [] + removed = [] + count_dims = 0 + for arg in expr.args: + count_dims += get_rank(arg) + shape = get_shape(arg) + current_range = [count_dims-i for i in range(len(shape), 0, -1)] + if (shape == (1, 1) and len(newargs) > 0 and 1 not in get_shape(newargs[-1]) and + isinstance(newargs[-1], MatrixExpr) and isinstance(arg, MatrixExpr)): + # KroneckerProduct object allows the trick of broadcasting: + newargs[-1] = KroneckerProduct(newargs[-1], arg) + removed.extend(current_range) + elif 1 not in shape and len(newargs) > 0 and get_shape(newargs[-1]) == (1, 1): + # Broadcast: + newargs[-1] = KroneckerProduct(newargs[-1], arg) + prev_range = [i for i in range(min(current_range)) if i not in removed] + removed.extend(prev_range[-2:]) + else: + newargs.append(arg) + return _array_tensor_product(*newargs), removed + + +@singledispatch +def _array2matrix(expr): + return expr + + +@_array2matrix.register(ZeroArray) +def _(expr: ZeroArray): + if get_rank(expr) == 2: + return ZeroMatrix(*expr.shape) + else: + return expr + + +@_array2matrix.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct): + return _a2m_tensor_product(*[_array2matrix(arg) for arg in expr.args]) + + +@_array2matrix.register(ArrayContraction) +def _(expr: ArrayContraction): + expr = expr.flatten_contraction_of_diagonal() + expr = identify_removable_identity_matrices(expr) + expr = expr.split_multiple_contractions() + expr = identify_hadamard_products(expr) + if not isinstance(expr, ArrayContraction): + return _array2matrix(expr) + subexpr = expr.expr + contraction_indices: tuple[tuple[int]] = expr.contraction_indices + if contraction_indices == ((0,), (1,)) or ( + contraction_indices == ((0,),) and subexpr.shape[1] == 1 + ) or ( + contraction_indices == ((1,),) and subexpr.shape[0] == 1 + ): + shape = subexpr.shape + subexpr = _array2matrix(subexpr) + if isinstance(subexpr, MatrixExpr): + return OneMatrix(1, shape[0])*subexpr*OneMatrix(shape[1], 1) + if isinstance(subexpr, ArrayTensorProduct): + newexpr = _array_contraction(_array2matrix(subexpr), *contraction_indices) + contraction_indices = newexpr.contraction_indices + if any(i > 2 for i in newexpr.subranks): + addends = _array_add(*[_a2m_tensor_product(*j) for j in itertools.product(*[i.args if isinstance(i, + ArrayAdd) else [i] for i in expr.expr.args])]) + newexpr = _array_contraction(addends, *contraction_indices) + if isinstance(newexpr, ArrayAdd): + ret = _array2matrix(newexpr) + return ret + assert isinstance(newexpr, ArrayContraction) + ret = _support_function_tp1_recognize(contraction_indices, list(newexpr.expr.args)) + return ret + elif not isinstance(subexpr, _CodegenArrayAbstract): + ret = _array2matrix(subexpr) + if isinstance(ret, MatrixExpr): + assert expr.contraction_indices == ((0, 1),) + return _a2m_trace(ret) + else: + return _array_contraction(ret, *expr.contraction_indices) + + +@_array2matrix.register(ArrayDiagonal) +def _(expr: ArrayDiagonal): + pexpr = _array_diagonal(_array2matrix(expr.expr), *expr.diagonal_indices) + pexpr = identify_hadamard_products(pexpr) + if isinstance(pexpr, ArrayDiagonal): + pexpr = _array_diag2contr_diagmatrix(pexpr) + if expr == pexpr: + return expr + return _array2matrix(pexpr) + + +@_array2matrix.register(PermuteDims) +def _(expr: PermuteDims): + if expr.permutation.array_form == [1, 0]: + return _a2m_transpose(_array2matrix(expr.expr)) + elif isinstance(expr.expr, ArrayTensorProduct): + ranks = expr.expr.subranks + inv_permutation = expr.permutation**(-1) + newrange = [inv_permutation(i) for i in range(sum(ranks))] + newpos = [] + counter = 0 + for rank in ranks: + newpos.append(newrange[counter:counter+rank]) + counter += rank + newargs = [] + newperm = [] + scalars = [] + for pos, arg in zip(newpos, expr.expr.args): + if len(pos) == 0: + scalars.append(_array2matrix(arg)) + elif pos == sorted(pos): + newargs.append((_array2matrix(arg), pos[0])) + newperm.extend(pos) + elif len(pos) == 2: + newargs.append((_a2m_transpose(_array2matrix(arg)), pos[0])) + newperm.extend(reversed(pos)) + else: + raise NotImplementedError() + newargs = [i[0] for i in newargs] + return _permute_dims(_a2m_tensor_product(*scalars, *newargs), _af_invert(newperm)) + elif isinstance(expr.expr, ArrayContraction): + mat_mul_lines = _array2matrix(expr.expr) + if not isinstance(mat_mul_lines, ArrayTensorProduct): + return _permute_dims(mat_mul_lines, expr.permutation) + # TODO: this assumes that all arguments are matrices, it may not be the case: + permutation = Permutation(2*len(mat_mul_lines.args)-1)*expr.permutation + permuted = [permutation(i) for i in range(2*len(mat_mul_lines.args))] + args_array = [None for i in mat_mul_lines.args] + for i in range(len(mat_mul_lines.args)): + p1 = permuted[2*i] + p2 = permuted[2*i+1] + if p1 // 2 != p2 // 2: + return _permute_dims(mat_mul_lines, permutation) + if p1 > p2: + args_array[i] = _a2m_transpose(mat_mul_lines.args[p1 // 2]) + else: + args_array[i] = mat_mul_lines.args[p1 // 2] + return _a2m_tensor_product(*args_array) + else: + return expr + + +@_array2matrix.register(ArrayAdd) +def _(expr: ArrayAdd): + addends = [_array2matrix(arg) for arg in expr.args] + return _a2m_add(*addends) + + +@_array2matrix.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc): + subexpr = _array2matrix(expr.expr) + if isinstance(subexpr, MatrixExpr): + if subexpr.shape != (1, 1): + d = expr.function.bound_symbols[0] + w = Wild("w", exclude=[d]) + p = Wild("p", exclude=[d]) + m = expr.function.expr.match(w*d**p) + if m is not None: + return m[w]*HadamardPower(subexpr, m[p]) + return ElementwiseApplyFunction(expr.function, subexpr) + else: + return ArrayElementwiseApplyFunc(expr.function, subexpr) + + +@_array2matrix.register(ArrayElement) +def _(expr: ArrayElement): + ret = _array2matrix(expr.name) + if isinstance(ret, MatrixExpr): + return MatrixElement(ret, *expr.indices) + return ArrayElement(ret, expr.indices) + + +@singledispatch +def _remove_trivial_dims(expr): + return expr, [] + + +@_remove_trivial_dims.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct): + # Recognize expressions like [x, y] with shape (k, 1, k, 1) as `x*y.T`. + # The matrix expression has to be equivalent to the tensor product of the + # matrices, with trivial dimensions (i.e. dim=1) dropped. + # That is, add contractions over trivial dimensions: + + removed = [] + newargs = [] + cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args])) + pending = None + prev_i = None + for i, arg in enumerate(expr.args): + current_range = list(range(cumul[i], cumul[i+1])) + if isinstance(arg, OneArray): + removed.extend(current_range) + continue + if not isinstance(arg, (MatrixExpr, MatrixBase)): + rarg, rem = _remove_trivial_dims(arg) + removed.extend(rem) + newargs.append(rarg) + continue + elif getattr(arg, "is_Identity", False) and arg.shape == (1, 1): + if arg.shape == (1, 1): + # Ignore identity matrices of shape (1, 1) - they are equivalent to scalar 1. + removed.extend(current_range) + continue + elif arg.shape == (1, 1): + arg, _ = _remove_trivial_dims(arg) + # Matrix is equivalent to scalar: + if len(newargs) == 0: + newargs.append(arg) + elif 1 in get_shape(newargs[-1]): + if newargs[-1].shape[1] == 1: + newargs[-1] = newargs[-1]*arg + else: + newargs[-1] = arg*newargs[-1] + removed.extend(current_range) + else: + newargs.append(arg) + elif 1 in arg.shape: + k = [i for i in arg.shape if i != 1][0] + if pending is None: + pending = k + prev_i = i + newargs.append(arg) + elif pending == k: + prev = newargs[-1] + if prev.shape[0] == 1: + d1 = cumul[prev_i] # type: ignore + prev = _a2m_transpose(prev) + else: + d1 = cumul[prev_i] + 1 # type: ignore + if arg.shape[1] == 1: + d2 = cumul[i] + 1 + arg = _a2m_transpose(arg) + else: + d2 = cumul[i] + newargs[-1] = prev*arg + pending = None + removed.extend([d1, d2]) + else: + newargs.append(arg) + pending = k + prev_i = i + else: + newargs.append(arg) + pending = None + newexpr, newremoved = _a2m_tensor_product(*newargs), sorted(removed) + if isinstance(newexpr, ArrayTensorProduct): + newexpr, newremoved2 = _find_trivial_matrices_rewrite(newexpr) + newremoved = _combine_removed(-1, newremoved, newremoved2) + if isinstance(newexpr, ArrayTensorProduct): + newexpr, newremoved2 = _find_trivial_kronecker_products_broadcast(newexpr) + newremoved = _combine_removed(-1, newremoved, newremoved2) + return newexpr, newremoved + + +@_remove_trivial_dims.register(ArrayAdd) +def _(expr: ArrayAdd): + rec = [_remove_trivial_dims(arg) for arg in expr.args] + newargs, removed = zip(*rec) + if len({get_shape(i) for i in newargs}) > 1: + return expr, [] + if len(removed) == 0: + return expr, removed + removed1 = removed[0] + return _a2m_add(*newargs), removed1 + + +@_remove_trivial_dims.register(PermuteDims) +def _(expr: PermuteDims): + subexpr, subremoved = _remove_trivial_dims(expr.expr) + p = expr.permutation.array_form + pinv = _af_invert(expr.permutation.array_form) + shift = list(accumulate([1 if i in subremoved else 0 for i in range(len(p))])) + premoved = [pinv[i] for i in subremoved] + p2 = [e - shift[e] for e in p if e not in subremoved] + # TODO: check if subremoved should be permuted as well... + newexpr = _permute_dims(subexpr, p2) + premoved = sorted(premoved) + if newexpr != expr: + newexpr, removed2 = _remove_trivial_dims(_array2matrix(newexpr)) + premoved = _combine_removed(-1, premoved, removed2) + return newexpr, premoved + + +@_remove_trivial_dims.register(ArrayContraction) +def _(expr: ArrayContraction): + new_expr, removed0 = _array_contraction_to_diagonal_multiple_identity(expr) + if new_expr != expr: + new_expr2, removed1 = _remove_trivial_dims(_array2matrix(new_expr)) + removed = _combine_removed(-1, removed0, removed1) + return new_expr2, removed + rank1 = get_rank(expr) + expr, removed1 = remove_identity_matrices(expr) + if not isinstance(expr, ArrayContraction): + expr2, removed2 = _remove_trivial_dims(expr) + return expr2, _combine_removed(rank1, removed1, removed2) + newexpr, removed2 = _remove_trivial_dims(expr.expr) + shifts = list(accumulate([1 if i in removed2 else 0 for i in range(get_rank(expr.expr))])) + new_contraction_indices = [tuple(j for j in i if j not in removed2) for i in expr.contraction_indices] + # Remove possible empty tuples "()": + new_contraction_indices = [i for i in new_contraction_indices if len(i) > 0] + contraction_indices_flat = [j for i in expr.contraction_indices for j in i] + removed2 = [i for i in removed2 if i not in contraction_indices_flat] + new_contraction_indices = [tuple(j - shifts[j] for j in i) for i in new_contraction_indices] + # Shift removed2: + removed2 = ArrayContraction._push_indices_up(expr.contraction_indices, removed2) + removed = _combine_removed(rank1, removed1, removed2) + return _array_contraction(newexpr, *new_contraction_indices), list(removed) + + +def _remove_diagonalized_identity_matrices(expr: ArrayDiagonal): + assert isinstance(expr, ArrayDiagonal) + editor = _EditArrayContraction(expr) + mapping = {i: {j for j in editor.args_with_ind if i in j.indices} for i in range(-1, -1-editor.number_of_diagonal_indices, -1)} + removed = [] + counter: int = 0 + for i, arg_with_ind in enumerate(editor.args_with_ind): + counter += len(arg_with_ind.indices) + if isinstance(arg_with_ind.element, Identity): + if None in arg_with_ind.indices and any(i is not None and (i < 0) == True for i in arg_with_ind.indices): + diag_ind = [j for j in arg_with_ind.indices if j is not None][0] + other = [j for j in mapping[diag_ind] if j != arg_with_ind][0] + if not isinstance(other.element, MatrixExpr): + continue + if 1 not in other.element.shape: + continue + if None not in other.indices: + continue + editor.args_with_ind[i].element = None + none_index = other.indices.index(None) + other.element = DiagMatrix(other.element) + other_range = editor.get_absolute_range(other) + removed.extend([other_range[0] + none_index]) + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, get_rank(expr.expr)) + return editor.to_array_contraction(), removed + + +@_remove_trivial_dims.register(ArrayDiagonal) +def _(expr: ArrayDiagonal): + newexpr, removed = _remove_trivial_dims(expr.expr) + shifts = list(accumulate([0] + [1 if i in removed else 0 for i in range(get_rank(expr.expr))])) + new_diag_indices_map = {i: tuple(j for j in i if j not in removed) for i in expr.diagonal_indices} + for old_diag_tuple, new_diag_tuple in new_diag_indices_map.items(): + if len(new_diag_tuple) == 1: + removed = [i for i in removed if i not in old_diag_tuple] + new_diag_indices = [tuple(j - shifts[j] for j in i) for i in new_diag_indices_map.values()] + rank = get_rank(expr.expr) + removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, rank) + removed = sorted(set(removed)) + # If there are single axes to diagonalize remaining, it means that their + # corresponding dimension has been removed, they no longer need diagonalization: + new_diag_indices = [i for i in new_diag_indices if len(i) > 0] + if len(new_diag_indices) > 0: + newexpr2 = _array_diagonal(newexpr, *new_diag_indices, allow_trivial_diags=True) + else: + newexpr2 = newexpr + if isinstance(newexpr2, ArrayDiagonal): + newexpr3, removed2 = _remove_diagonalized_identity_matrices(newexpr2) + removed = _combine_removed(-1, removed, removed2) + return newexpr3, removed + else: + return newexpr2, removed + + +@_remove_trivial_dims.register(ElementwiseApplyFunction) +def _(expr: ElementwiseApplyFunction): + subexpr, removed = _remove_trivial_dims(expr.expr) + if subexpr.shape == (1, 1): + # TODO: move this to ElementwiseApplyFunction + return expr.function(subexpr), removed + [0, 1] + return ElementwiseApplyFunction(expr.function, subexpr), [] + + +@_remove_trivial_dims.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc): + subexpr, removed = _remove_trivial_dims(expr.expr) + return ArrayElementwiseApplyFunc(expr.function, subexpr), removed + + +def convert_array_to_matrix(expr): + r""" + Recognize matrix expressions in codegen objects. + + If more than one matrix multiplication line have been detected, return a + list with the matrix expressions. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + >>> from sympy.tensor.array import tensorcontraction, tensorproduct + >>> from sympy import MatrixSymbol, Sum + >>> from sympy.abc import i, j, k, l, N + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B + >>> cg = convert_indexed_to_array(expr, first_indices=[k]) + >>> convert_array_to_matrix(cg) + B.T*A.T + + Transposition is detected: + + >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A.T*B + >>> cg = convert_indexed_to_array(expr, first_indices=[k]) + >>> convert_array_to_matrix(cg) + B.T*A + + Detect the trace: + + >>> expr = Sum(A[i, i], (i, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + Trace(A) + + Recognize some more complex traces: + + >>> expr = Sum(A[i, j]*B[j, i], (i, 0, N-1), (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + Trace(A*B) + + More complicated expressions: + + >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B.T*A.T + + Expressions constructed from matrix expressions do not contain literal + indices, the positions of free indices are returned instead: + + >>> expr = A*B + >>> cg = convert_matrix_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B + + If more than one line of matrix multiplications is detected, return + separate matrix multiplication factors embedded in a tensor product object: + + >>> cg = tensorcontraction(tensorproduct(A, B, C, D), (1, 2), (5, 6)) + >>> convert_array_to_matrix(cg) + ArrayTensorProduct(A*B, C*D) + + The two lines have free indices at axes 0, 3 and 4, 7, respectively. + """ + rec = _array2matrix(expr) + rec, removed = _remove_trivial_dims(rec) + return rec + + +def _array_diag2contr_diagmatrix(expr: ArrayDiagonal): + if isinstance(expr.expr, ArrayTensorProduct): + args = list(expr.expr.args) + diag_indices = list(expr.diagonal_indices) + mapping = _get_mapping_from_subranks([_get_subrank(arg) for arg in args]) + tuple_links = [[mapping[j] for j in i] for i in diag_indices] + contr_indices = [] + total_rank = get_rank(expr) + replaced = [False for arg in args] + for i, (abs_pos, rel_pos) in enumerate(zip(diag_indices, tuple_links)): + if len(abs_pos) != 2: + continue + (pos1_outer, pos1_inner), (pos2_outer, pos2_inner) = rel_pos + arg1 = args[pos1_outer] + arg2 = args[pos2_outer] + if get_rank(arg1) != 2 or get_rank(arg2) != 2: + if replaced[pos1_outer]: + diag_indices[i] = None + if replaced[pos2_outer]: + diag_indices[i] = None + continue + pos1_in2 = 1 - pos1_inner + pos2_in2 = 1 - pos2_inner + if arg1.shape[pos1_in2] == 1: + if arg1.shape[pos1_inner] != 1: + darg1 = DiagMatrix(arg1) + else: + darg1 = arg1 + args.append(darg1) + contr_indices.append(((pos2_outer, pos2_inner), (len(args)-1, pos1_inner))) + total_rank += 1 + diag_indices[i] = None + args[pos1_outer] = OneArray(arg1.shape[pos1_in2]) + replaced[pos1_outer] = True + elif arg2.shape[pos2_in2] == 1: + if arg2.shape[pos2_inner] != 1: + darg2 = DiagMatrix(arg2) + else: + darg2 = arg2 + args.append(darg2) + contr_indices.append(((pos1_outer, pos1_inner), (len(args)-1, pos2_inner))) + total_rank += 1 + diag_indices[i] = None + args[pos2_outer] = OneArray(arg2.shape[pos2_in2]) + replaced[pos2_outer] = True + diag_indices_new = [i for i in diag_indices if i is not None] + cumul = list(accumulate([0] + [get_rank(arg) for arg in args])) + contr_indices2 = [tuple(cumul[a] + b for a, b in i) for i in contr_indices] + tc = _array_contraction( + _array_tensor_product(*args), *contr_indices2 + ) + td = _array_diagonal(tc, *diag_indices_new) + return td + return expr + + +def _a2m_mul(*args): + if not any(isinstance(i, _CodegenArrayAbstract) for i in args): + from sympy.matrices.expressions.matmul import MatMul + return MatMul(*args).doit() + else: + return _array_contraction( + _array_tensor_product(*args), + *[(2*i-1, 2*i) for i in range(1, len(args))] + ) + + +def _a2m_tensor_product(*args): + scalars = [] + arrays = [] + for arg in args: + if isinstance(arg, (MatrixExpr, _ArrayExpr, _CodegenArrayAbstract)): + arrays.append(arg) + else: + scalars.append(arg) + scalar = Mul.fromiter(scalars) + if len(arrays) == 0: + return scalar + if scalar != 1: + if isinstance(arrays[0], _CodegenArrayAbstract): + arrays = [scalar] + arrays + else: + arrays[0] *= scalar + return _array_tensor_product(*arrays) + + +def _a2m_add(*args): + if not any(isinstance(i, _CodegenArrayAbstract) for i in args): + from sympy.matrices.expressions.matadd import MatAdd + return MatAdd(*args).doit() + else: + return _array_add(*args) + + +def _a2m_trace(arg): + if isinstance(arg, _CodegenArrayAbstract): + return _array_contraction(arg, (0, 1)) + else: + from sympy.matrices.expressions.trace import Trace + return Trace(arg) + + +def _a2m_transpose(arg): + if isinstance(arg, _CodegenArrayAbstract): + return _permute_dims(arg, [1, 0]) + else: + from sympy.matrices.expressions.transpose import Transpose + return Transpose(arg).doit() + + +def identify_hadamard_products(expr: ArrayContraction | ArrayDiagonal): + + editor: _EditArrayContraction = _EditArrayContraction(expr) + + map_contr_to_args: dict[FrozenSet, list[_ArgE]] = defaultdict(list) + map_ind_to_inds: dict[int | None, int] = defaultdict(int) + for arg_with_ind in editor.args_with_ind: + for ind in arg_with_ind.indices: + map_ind_to_inds[ind] += 1 + if None in arg_with_ind.indices: + continue + map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind) + + k: FrozenSet[int] + v: list[_ArgE] + for k, v in map_contr_to_args.items(): + make_trace: bool = False + if len(k) == 1 and next(iter(k)) >= 0 and sum(next(iter(k)) in i for i in map_contr_to_args) == 1: + # This is a trace: the arguments are fully contracted with only one + # index, and the index isn't used anywhere else: + make_trace = True + first_element = S.One + elif len(k) != 2: + # Hadamard product only defined for matrices: + continue + if len(v) == 1: + # Hadamard product with a single argument makes no sense: + continue + for ind in k: + if map_ind_to_inds[ind] <= 2: + # There is no other contraction, skip: + continue + + def check_transpose(x): + x = [i if i >= 0 else -1-i for i in x] + return x == sorted(x) + + # Check if expression is a trace: + if all(map_ind_to_inds[j] == len(v) and j >= 0 for j in k) and all(j >= 0 for j in k): + # This is a trace + make_trace = True + first_element = v[0].element + if not check_transpose(v[0].indices): + first_element = first_element.T # type: ignore + hadamard_factors = v[1:] + else: + hadamard_factors = v + + # This is a Hadamard product: + + hp = hadamard_product(*[i.element if check_transpose(i.indices) else Transpose(i.element) for i in hadamard_factors]) + hp_indices = v[0].indices + if not check_transpose(hadamard_factors[0].indices): + hp_indices = list(reversed(hp_indices)) + if make_trace: + hp = Trace(first_element*hp.T)._normalize() + hp_indices = [] + editor.insert_after(v[0], _ArgE(hp, hp_indices)) + for i in v: + editor.args_with_ind.remove(i) + + return editor.to_array_contraction() + + +def identify_removable_identity_matrices(expr): + editor = _EditArrayContraction(expr) + + flag = True + while flag: + flag = False + for arg_with_ind in editor.args_with_ind: + if isinstance(arg_with_ind.element, Identity): + k = arg_with_ind.element.shape[0] + # Candidate for removal: + if arg_with_ind.indices == [None, None]: + # Free identity matrix, will be cleared by _remove_trivial_dims: + continue + elif None in arg_with_ind.indices: + ind = [j for j in arg_with_ind.indices if j is not None][0] + counted = editor.count_args_with_index(ind) + if counted == 1: + # Identity matrix contracted only on one index with itself, + # transform to a OneArray(k) element: + editor.insert_after(arg_with_ind, OneArray(k)) + editor.args_with_ind.remove(arg_with_ind) + flag = True + break + elif counted > 2: + # Case counted = 2 is a matrix multiplication by identity matrix, skip it. + # Case counted > 2 is a multiple contraction, + # this is a case where the contraction becomes a diagonalization if the + # identity matrix is dropped. + continue + elif arg_with_ind.indices[0] == arg_with_ind.indices[1]: + ind = arg_with_ind.indices[0] + counted = editor.count_args_with_index(ind) + if counted > 1: + editor.args_with_ind.remove(arg_with_ind) + flag = True + break + else: + # This is a trace, skip it as it will be recognized somewhere else: + pass + elif ask(Q.diagonal(arg_with_ind.element)): + if arg_with_ind.indices == [None, None]: + continue + elif None in arg_with_ind.indices: + pass + elif arg_with_ind.indices[0] == arg_with_ind.indices[1]: + ind = arg_with_ind.indices[0] + counted = editor.count_args_with_index(ind) + if counted == 3: + # A_ai B_bi D_ii ==> A_ai D_ij B_bj + ind_new = editor.get_new_contraction_index() + other_args = [j for j in editor.args_with_ind if j != arg_with_ind] + other_args[1].indices = [ind_new if j == ind else j for j in other_args[1].indices] + arg_with_ind.indices = [ind, ind_new] + flag = True + break + + return editor.to_array_contraction() + + +def remove_identity_matrices(expr: ArrayContraction): + editor = _EditArrayContraction(expr) + removed: list[int] = [] + + permutation_map = {} + + free_indices = list(accumulate([0] + [sum(i is None for i in arg.indices) for arg in editor.args_with_ind])) + free_map = dict(zip(editor.args_with_ind, free_indices[:-1])) + + update_pairs = {} + + for ind in range(editor.number_of_contraction_indices): + args = editor.get_args_with_index(ind) + identity_matrices = [i for i in args if isinstance(i.element, Identity)] + number_identity_matrices = len(identity_matrices) + # If the contraction involves a non-identity matrix and multiple identity matrices: + if number_identity_matrices != len(args) - 1 or number_identity_matrices == 0: + continue + # Get the non-identity element: + non_identity = [i for i in args if not isinstance(i.element, Identity)][0] + # Check that all identity matrices have at least one free index + # (otherwise they would be contractions to some other elements) + if any(None not in i.indices for i in identity_matrices): + continue + # Mark the identity matrices for removal: + for i in identity_matrices: + i.element = None + removed.extend(range(free_map[i], free_map[i] + len([j for j in i.indices if j is None]))) + last_removed = removed.pop(-1) + update_pairs[last_removed, ind] = non_identity.indices[:] + # Remove the indices from the non-identity matrix, as the contraction + # no longer exists: + non_identity.indices = [None if i == ind else i for i in non_identity.indices] + + removed.sort() + + shifts = list(accumulate([1 if i in removed else 0 for i in range(get_rank(expr))])) + for (last_removed, ind), non_identity_indices in update_pairs.items(): + pos = [free_map[non_identity] + i for i, e in enumerate(non_identity_indices) if e == ind] + assert len(pos) == 1 + for j in pos: + permutation_map[j] = last_removed + + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + ret_expr = editor.to_array_contraction() + permutation = [] + counter = 0 + counter2 = 0 + for j in range(get_rank(expr)): + if j in removed: + continue + if counter2 in permutation_map: + target = permutation_map[counter2] + permutation.append(target - shifts[target]) + counter2 += 1 + else: + while counter in permutation_map.values(): + counter += 1 + permutation.append(counter) + counter += 1 + counter2 += 1 + ret_expr2 = _permute_dims(ret_expr, _af_invert(permutation)) + return ret_expr2, removed + + +def _combine_removed(dim: int, removed1: list[int], removed2: list[int]) -> list[int]: + # Concatenate two axis removal operations as performed by + # _remove_trivial_dims, + removed1 = sorted(removed1) + removed2 = sorted(removed2) + i = 0 + j = 0 + removed = [] + while True: + if j >= len(removed2): + while i < len(removed1): + removed.append(removed1[i]) + i += 1 + break + elif i < len(removed1) and removed1[i] <= i + removed2[j]: + removed.append(removed1[i]) + i += 1 + else: + removed.append(i + removed2[j]) + j += 1 + return removed + + +def _array_contraction_to_diagonal_multiple_identity(expr: ArrayContraction): + editor = _EditArrayContraction(expr) + editor.track_permutation_start() + removed: list[int] = [] + diag_index_counter: int = 0 + for i in range(editor.number_of_contraction_indices): + identities = [] + args = [] + for j, arg in enumerate(editor.args_with_ind): + if i not in arg.indices: + continue + if isinstance(arg.element, Identity): + identities.append(arg) + else: + args.append(arg) + if len(identities) == 0: + continue + if len(args) + len(identities) < 3: + continue + new_diag_ind = -1 - diag_index_counter + diag_index_counter += 1 + # Variable "flag" to control whether to skip this contraction set: + flag: bool = True + for i1, id1 in enumerate(identities): + if None not in id1.indices: + flag = True + break + free_pos = list(range(*editor.get_absolute_free_range(id1)))[0] + editor._track_permutation[-1].append(free_pos) # type: ignore + id1.element = None + flag = False + break + if flag: + continue + for arg in identities[:i1] + identities[i1+1:]: + arg.element = None + removed.extend(range(*editor.get_absolute_free_range(arg))) + for arg in args: + arg.indices = [new_diag_ind if j == i else j for j in arg.indices] + for j, e in enumerate(editor.args_with_ind): + if e.element is None: + editor._track_permutation[j] = None # type: ignore + editor._track_permutation = [i for i in editor._track_permutation if i is not None] # type: ignore + # Renumber permutation array form in order to deal with deleted positions: + remap = {e: i for i, e in enumerate(sorted({k for j in editor._track_permutation for k in j}))} + editor._track_permutation = [[remap[j] for j in i] for i in editor._track_permutation] + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + new_expr = editor.to_array_contraction() + return new_expr, removed diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..c219a205c4305bd7070e5117978146224521c58c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py @@ -0,0 +1,257 @@ +from collections import defaultdict + +from sympy import Function +from sympy.combinatorics.permutations import _af_invert +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.tensor.array.expressions import ArrayElementwiseApplyFunc +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.combinatorics import Permutation +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, \ + get_shape, ArrayElement, _array_tensor_product, _array_diagonal, _array_contraction, _array_add, \ + _permute_dims, OneArray, ArrayAdd +from sympy.tensor.array.expressions.utils import _get_argindex, _get_diagonal_indices + + +def convert_indexed_to_array(expr, first_indices=None): + r""" + Parse indexed expression into a form useful for code generation. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + >>> from sympy import MatrixSymbol, Sum, symbols + + >>> i, j, k, d = symbols("i j k d") + >>> M = MatrixSymbol("M", d, d) + >>> N = MatrixSymbol("N", d, d) + + Recognize the trace in summation form: + + >>> expr = Sum(M[i, i], (i, 0, d-1)) + >>> convert_indexed_to_array(expr) + ArrayContraction(M, (0, 1)) + + Recognize the extraction of the diagonal by using the same index `i` on + both axes of the matrix: + + >>> expr = M[i, i] + >>> convert_indexed_to_array(expr) + ArrayDiagonal(M, (0, 1)) + + This function can help perform the transformation expressed in two + different mathematical notations as: + + `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}` + + Recognize the matrix multiplication in summation form: + + >>> expr = Sum(M[i, j]*N[j, k], (j, 0, d-1)) + >>> convert_indexed_to_array(expr) + ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + + Specify that ``k`` has to be the starting index: + + >>> convert_indexed_to_array(expr, first_indices=[k]) + ArrayContraction(ArrayTensorProduct(N, M), (0, 3)) + """ + + result, indices = _convert_indexed_to_array(expr) + + if any(isinstance(i, (int, Integer)) for i in indices): + result = ArrayElement(result, indices) + indices = [] + + if not first_indices: + return result + + def _check_is_in(elem, indices): + if elem in indices: + return True + if any(elem in i for i in indices if isinstance(i, frozenset)): + return True + return False + + repl = {j: i for i in indices if isinstance(i, frozenset) for j in i} + first_indices = [repl.get(i, i) for i in first_indices] + for i in first_indices: + if not _check_is_in(i, indices): + first_indices.remove(i) + first_indices.extend([i for i in indices if not _check_is_in(i, first_indices)]) + + def _get_pos(elem, indices): + if elem in indices: + return indices.index(elem) + for i, e in enumerate(indices): + if not isinstance(e, frozenset): + continue + if elem in e: + return i + raise ValueError("not found") + + permutation = _af_invert([_get_pos(i, first_indices) for i in indices]) + if isinstance(result, ArrayAdd): + return _array_add(*[_permute_dims(arg, permutation) for arg in result.args]) + else: + return _permute_dims(result, permutation) + + +def _convert_indexed_to_array(expr): + if isinstance(expr, Sum): + function = expr.function + summation_indices = expr.variables + subexpr, subindices = _convert_indexed_to_array(function) + subindicessets = {j: i for i in subindices if isinstance(i, frozenset) for j in i} + summation_indices = sorted({subindicessets.get(i, i) for i in summation_indices}, key=default_sort_key) + # TODO: check that Kronecker delta is only contracted to one other element: + kronecker_indices = set() + if isinstance(function, Mul): + for arg in function.args: + if not isinstance(arg, KroneckerDelta): + continue + arg_indices = sorted(set(arg.indices), key=default_sort_key) + if len(arg_indices) == 2: + kronecker_indices.update(arg_indices) + kronecker_indices = sorted(kronecker_indices, key=default_sort_key) + # Check dimensional consistency: + shape = get_shape(subexpr) + if shape: + for ind, istart, iend in expr.limits: + i = _get_argindex(subindices, ind) + if istart != 0 or iend+1 != shape[i]: + raise ValueError("summation index and array dimension mismatch: %s" % ind) + contraction_indices = [] + subindices = list(subindices) + if isinstance(subexpr, ArrayDiagonal): + diagonal_indices = list(subexpr.diagonal_indices) + dindices = subindices[-len(diagonal_indices):] + subindices = subindices[:-len(diagonal_indices)] + for index in summation_indices: + if index in dindices: + position = dindices.index(index) + contraction_indices.append(diagonal_indices[position]) + diagonal_indices[position] = None + diagonal_indices = [i for i in diagonal_indices if i is not None] + for i, ind in enumerate(subindices): + if ind in summation_indices: + pass + if diagonal_indices: + subexpr = _array_diagonal(subexpr.expr, *diagonal_indices) + else: + subexpr = subexpr.expr + + axes_contraction = defaultdict(list) + for i, ind in enumerate(subindices): + include = all(j not in kronecker_indices for j in ind) if isinstance(ind, frozenset) else ind not in kronecker_indices + if ind in summation_indices and include: + axes_contraction[ind].append(i) + subindices[i] = None + for k, v in axes_contraction.items(): + if any(i in kronecker_indices for i in k) if isinstance(k, frozenset) else k in kronecker_indices: + continue + contraction_indices.append(tuple(v)) + free_indices = [i for i in subindices if i is not None] + indices_ret = list(free_indices) + indices_ret.sort(key=lambda x: free_indices.index(x)) + return _array_contraction( + subexpr, + *contraction_indices, + free_indices=free_indices + ), tuple(indices_ret) + if isinstance(expr, Mul): + args, indices = zip(*[_convert_indexed_to_array(arg) for arg in expr.args]) + # Check if there are KroneckerDelta objects: + kronecker_delta_repl = {} + for arg in args: + if not isinstance(arg, KroneckerDelta): + continue + # Diagonalize two indices: + i, j = arg.indices + kindices = set(arg.indices) + if i in kronecker_delta_repl: + kindices.update(kronecker_delta_repl[i]) + if j in kronecker_delta_repl: + kindices.update(kronecker_delta_repl[j]) + kindices = frozenset(kindices) + for index in kindices: + kronecker_delta_repl[index] = kindices + # Remove KroneckerDelta objects, their relations should be handled by + # ArrayDiagonal: + newargs = [] + newindices = [] + for arg, loc_indices in zip(args, indices): + if isinstance(arg, KroneckerDelta): + continue + newargs.append(arg) + newindices.append(loc_indices) + flattened_indices = [kronecker_delta_repl.get(j, j) for i in newindices for j in i] + diagonal_indices, ret_indices = _get_diagonal_indices(flattened_indices) + tp = _array_tensor_product(*newargs) + if diagonal_indices: + return _array_diagonal(tp, *diagonal_indices), ret_indices + else: + return tp, ret_indices + if isinstance(expr, MatrixElement): + indices = expr.args[1:] + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.args[0], *diagonal_indices), ret_indices + else: + return expr.args[0], ret_indices + if isinstance(expr, ArrayElement): + indices = expr.indices + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.name, *diagonal_indices), ret_indices + else: + return expr.name, ret_indices + if isinstance(expr, Indexed): + indices = expr.indices + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.base, *diagonal_indices), ret_indices + else: + return expr.args[0], ret_indices + if isinstance(expr, IndexedBase): + raise NotImplementedError + if isinstance(expr, KroneckerDelta): + return expr, expr.indices + if isinstance(expr, Add): + args, indices = zip(*[_convert_indexed_to_array(arg) for arg in expr.args]) + args = list(args) + # Check if all indices are compatible. Otherwise expand the dimensions: + index0 = [] + shape0 = [] + for arg, arg_indices in zip(args, indices): + arg_indices_set = set(arg_indices) + arg_indices_missing = arg_indices_set.difference(index0) + index0.extend([i for i in arg_indices if i in arg_indices_missing]) + arg_shape = get_shape(arg) + shape0.extend([arg_shape[i] for i, e in enumerate(arg_indices) if e in arg_indices_missing]) + for i, (arg, arg_indices) in enumerate(zip(args, indices)): + if len(arg_indices) < len(index0): + missing_indices_pos = [i for i, e in enumerate(index0) if e not in arg_indices] + missing_shape = [shape0[i] for i in missing_indices_pos] + arg_indices = tuple(index0[j] for j in missing_indices_pos) + arg_indices + args[i] = _array_tensor_product(OneArray(*missing_shape), args[i]) + permutation = Permutation([arg_indices.index(j) for j in index0]) + # Perform index permutations: + args[i] = _permute_dims(args[i], permutation) + return _array_add(*args), tuple(index0) + if isinstance(expr, Pow): + subexpr, subindices = _convert_indexed_to_array(expr.base) + if isinstance(expr.exp, (int, Integer)): + diags = zip(*[(2*i, 2*i + 1) for i in range(expr.exp)]) + arr = _array_diagonal(_array_tensor_product(*[subexpr for i in range(expr.exp)]), *diags) + return arr, subindices + if isinstance(expr, Function): + subexpr, subindices = _convert_indexed_to_array(expr.args[0]) + return ArrayElementwiseApplyFunc(type(expr), subexpr), subindices + return expr, () diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..8f66961727f6338318d65876a7768802773e4f2d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py @@ -0,0 +1,87 @@ +from sympy import KroneckerProduct +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct) +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.tensor.array.expressions.array_expressions import \ + ArrayElementwiseApplyFunc, _array_tensor_product, _array_contraction, \ + _array_diagonal, _array_add, _permute_dims, Reshape + + +def convert_matrix_to_array(expr: Basic) -> Basic: + if isinstance(expr, MatMul): + args_nonmat = [] + args = [] + for arg in expr.args: + if isinstance(arg, MatrixExpr): + args.append(arg) + else: + args_nonmat.append(convert_matrix_to_array(arg)) + contractions = [(2*i+1, 2*i+2) for i in range(len(args)-1)] + scalar = _array_tensor_product(*args_nonmat) if args_nonmat else S.One + if scalar == 1: + tprod = _array_tensor_product( + *[convert_matrix_to_array(arg) for arg in args]) + else: + tprod = _array_tensor_product( + scalar, + *[convert_matrix_to_array(arg) for arg in args]) + return _array_contraction( + tprod, + *contractions + ) + elif isinstance(expr, MatAdd): + return _array_add( + *[convert_matrix_to_array(arg) for arg in expr.args] + ) + elif isinstance(expr, Transpose): + return _permute_dims( + convert_matrix_to_array(expr.args[0]), [1, 0] + ) + elif isinstance(expr, Trace): + inner_expr: MatrixExpr = convert_matrix_to_array(expr.arg) # type: ignore + return _array_contraction(inner_expr, (0, len(inner_expr.shape) - 1)) + elif isinstance(expr, Mul): + return _array_tensor_product(*[convert_matrix_to_array(i) for i in expr.args]) + elif isinstance(expr, Pow): + base = convert_matrix_to_array(expr.base) + if (expr.exp > 0) == True: + return _array_tensor_product(*[base for i in range(expr.exp)]) + else: + return expr + elif isinstance(expr, MatPow): + base = convert_matrix_to_array(expr.base) + if expr.exp.is_Integer != True: + b = symbols("b", cls=Dummy) + return ArrayElementwiseApplyFunc(Lambda(b, b**expr.exp), convert_matrix_to_array(base)) + elif (expr.exp > 0) == True: + return convert_matrix_to_array(MatMul.fromiter(base for i in range(expr.exp))) + else: + return expr + elif isinstance(expr, HadamardProduct): + tp = _array_tensor_product(*[convert_matrix_to_array(arg) for arg in expr.args]) + diag = [[2*i for i in range(len(expr.args))], [2*i+1 for i in range(len(expr.args))]] + return _array_diagonal(tp, *diag) + elif isinstance(expr, HadamardPower): + base, exp = expr.args + if isinstance(exp, Integer) and exp > 0: + return convert_matrix_to_array(HadamardProduct.fromiter(base for i in range(exp))) + else: + d = Dummy("d") + return ArrayElementwiseApplyFunc(Lambda(d, d**exp), base) + elif isinstance(expr, KroneckerProduct): + kp_args = [convert_matrix_to_array(arg) for arg in expr.args] + permutation = [2*i for i in range(len(kp_args))] + [2*i + 1 for i in range(len(kp_args))] + return Reshape(_permute_dims(_array_tensor_product(*kp_args), permutation), expr.shape) + else: + return expr diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f96a49f7643eee1fcb0b9debe9775df77a20e686 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7498104d77906294212bff0f4adcd9656270043d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fa0d3d6d67093d7c5f7a900c6a4963db98e1d30 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fc10c447a0f05cebc13df2eb2b60915e62b91d7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26a6e427908d40bcb13b316603b9d62c55410667 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0648f9f2a4a9acd96ba29897a866ae004d5df1d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fff55bcaa0149d747cde7807e2c300c195158506 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b1b2702ffc95295dbb323bf11b7cd8f957ce164 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..183446abdceb48d306833cc46add15a40e7c0c45 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..63fb79ab7ced7bff5ecb55b1764f43e29f98609d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py @@ -0,0 +1,808 @@ +import random + +from sympy import tensordiagonal, eye, KroneckerDelta, Array +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.expressions.diagonal import DiagMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.combinatorics import Permutation +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, ArraySymbol, ArrayElement, \ + PermuteDims, ArrayContraction, ArrayTensorProduct, ArrayDiagonal, \ + ArrayAdd, nest_permutation, ArrayElementwiseApplyFunc, _EditArrayContraction, _ArgE, _array_tensor_product, \ + _array_contraction, _array_diagonal, _array_add, _permute_dims, Reshape +from sympy.testing.pytest import raises + +i, j, k, l, m, n = symbols("i j k l m n") + + +M = ArraySymbol("M", (k, k)) +N = ArraySymbol("N", (k, k)) +P = ArraySymbol("P", (k, k)) +Q = ArraySymbol("Q", (k, k)) + +A = ArraySymbol("A", (k, k)) +B = ArraySymbol("B", (k, k)) +C = ArraySymbol("C", (k, k)) +D = ArraySymbol("D", (k, k)) + +X = ArraySymbol("X", (k, k)) +Y = ArraySymbol("Y", (k, k)) + +a = ArraySymbol("a", (k, 1)) +b = ArraySymbol("b", (k, 1)) +c = ArraySymbol("c", (k, 1)) +d = ArraySymbol("d", (k, 1)) + + +def test_array_symbol_and_element(): + A = ArraySymbol("A", (2,)) + A0 = ArrayElement(A, (0,)) + A1 = ArrayElement(A, (1,)) + assert A[0] == A0 + assert A[1] != A0 + assert A.as_explicit() == ImmutableDenseNDimArray([A0, A1]) + + A2 = tensorproduct(A, A) + assert A2.shape == (2, 2) + # TODO: not yet supported: + # assert A2.as_explicit() == Array([[A[0]*A[0], A[1]*A[0]], [A[0]*A[1], A[1]*A[1]]]) + A3 = tensorcontraction(A2, (0, 1)) + assert A3.shape == () + # TODO: not yet supported: + # assert A3.as_explicit() == Array([]) + + A = ArraySymbol("A", (2, 3, 4)) + Ae = A.as_explicit() + assert Ae == ImmutableDenseNDimArray( + [[[ArrayElement(A, (i, j, k)) for k in range(4)] for j in range(3)] for i in range(2)]) + + p = _permute_dims(A, Permutation(0, 2, 1)) + assert isinstance(p, PermuteDims) + + A = ArraySymbol("A", (2,)) + raises(IndexError, lambda: A[()]) + raises(IndexError, lambda: A[0, 1]) + raises(ValueError, lambda: A[-1]) + raises(ValueError, lambda: A[2]) + + O = OneArray(3, 4) + Z = ZeroArray(m, n) + + raises(IndexError, lambda: O[()]) + raises(IndexError, lambda: O[1, 2, 3]) + raises(ValueError, lambda: O[3, 0]) + raises(ValueError, lambda: O[0, 4]) + + assert O[1, 2] == 1 + assert Z[1, 2] == 0 + + +def test_zero_array(): + assert ZeroArray() == 0 + assert ZeroArray().is_Integer + + za = ZeroArray(3, 2, 4) + assert za.shape == (3, 2, 4) + za_e = za.as_explicit() + assert za_e.shape == (3, 2, 4) + + m, n, k = symbols("m n k") + za = ZeroArray(m, n, k, 2) + assert za.shape == (m, n, k, 2) + raises(ValueError, lambda: za.as_explicit()) + + +def test_one_array(): + assert OneArray() == 1 + assert OneArray().is_Integer + + oa = OneArray(3, 2, 4) + assert oa.shape == (3, 2, 4) + oa_e = oa.as_explicit() + assert oa_e.shape == (3, 2, 4) + + m, n, k = symbols("m n k") + oa = OneArray(m, n, k, 2) + assert oa.shape == (m, n, k, 2) + raises(ValueError, lambda: oa.as_explicit()) + + +def test_arrayexpr_contraction_construction(): + + cg = _array_contraction(A) + assert cg == A + + cg = _array_contraction(_array_tensor_product(A, B), (1, 0)) + assert cg == _array_contraction(_array_tensor_product(A, B), (0, 1)) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 1)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 0), (0, 1)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(0, 1)] + + cg = _array_contraction(_array_tensor_product(M, N), (1, 2)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 1), (1, 0)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(1, 2)] + + cg = _array_contraction(_array_tensor_product(M, M, N), (1, 4), (2, 5)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 0), (1, 1)], [(0, 1), (2, 0)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(0, 3), (1, 4)] + + # Test removal of trivial contraction: + assert _array_contraction(a, (1,)) == a + assert _array_contraction( + _array_tensor_product(a, b), (0, 2), (1,), (3,)) == _array_contraction( + _array_tensor_product(a, b), (0, 2)) + + +def test_arrayexpr_array_flatten(): + + # Flatten nested ArrayTensorProduct objects: + expr1 = _array_tensor_product(M, N) + expr2 = _array_tensor_product(P, Q) + expr = _array_tensor_product(expr1, expr2) + assert expr == _array_tensor_product(M, N, P, Q) + assert expr.args == (M, N, P, Q) + + # Flatten mixed ArrayTensorProduct and ArrayContraction objects: + cg1 = _array_contraction(expr1, (1, 2)) + cg2 = _array_contraction(expr2, (0, 3)) + + expr = _array_tensor_product(cg1, cg2) + assert expr == _array_contraction(_array_tensor_product(M, N, P, Q), (1, 2), (4, 7)) + + expr = _array_tensor_product(M, cg1) + assert expr == _array_contraction(_array_tensor_product(M, M, N), (3, 4)) + + # Flatten nested ArrayContraction objects: + cgnested = _array_contraction(cg1, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N), (0, 3), (1, 2)) + + cgnested = _array_contraction(_array_tensor_product(cg1, cg2), (0, 3)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 6), (1, 2), (4, 7)) + + cg3 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4)) + cgnested = _array_contraction(cg3, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 5), (1, 3), (2, 4)) + + cgnested = _array_contraction(cg3, (0, 3), (1, 2)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 7), (1, 3), (2, 4), (5, 6)) + + cg4 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7)) + cgnested = _array_contraction(cg4, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7)) + + cgnested = _array_contraction(cg4, (0, 1), (2, 3)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7), (4, 6)) + + cg = _array_diagonal(cg4) + assert cg == cg4 + assert isinstance(cg, type(cg4)) + + # Flatten nested ArrayDiagonal objects: + cg1 = _array_diagonal(expr1, (1, 2)) + cg2 = _array_diagonal(expr2, (0, 3)) + cg3 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4)) + cg4 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7)) + + cgnested = _array_diagonal(cg1, (0, 1)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N), (1, 2), (0, 3)) + + cgnested = _array_diagonal(cg3, (1, 2)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4), (5, 6)) + + cgnested = _array_diagonal(cg4, (1, 2)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7), (2, 4)) + + cg = _array_add(M, N) + cg2 = _array_add(cg, P) + assert isinstance(cg2, ArrayAdd) + assert cg2.args == (M, N, P) + assert cg2.shape == (k, k) + + expr = _array_tensor_product(_array_diagonal(X, (0, 1)), _array_diagonal(A, (0, 1))) + assert expr == _array_diagonal(_array_tensor_product(X, A), (0, 1), (2, 3)) + + expr1 = _array_diagonal(_array_tensor_product(X, A), (1, 2)) + expr2 = _array_tensor_product(expr1, a) + assert expr2 == _permute_dims(_array_diagonal(_array_tensor_product(X, A, a), (1, 2)), [0, 1, 4, 2, 3]) + + expr1 = _array_contraction(_array_tensor_product(X, A), (1, 2)) + expr2 = _array_tensor_product(expr1, a) + assert isinstance(expr2, ArrayContraction) + assert isinstance(expr2.expr, ArrayTensorProduct) + + cg = _array_tensor_product(_array_diagonal(_array_tensor_product(A, X, Y), (0, 3), (1, 5)), a, b) + assert cg == _permute_dims(_array_diagonal(_array_tensor_product(A, X, Y, a, b), (0, 3), (1, 5)), [0, 1, 6, 7, 2, 3, 4, 5]) + + +def test_arrayexpr_array_diagonal(): + cg = _array_diagonal(M, (1, 0)) + assert cg == _array_diagonal(M, (0, 1)) + + cg = _array_diagonal(_array_tensor_product(M, N, P), (4, 1), (2, 0)) + assert cg == _array_diagonal(_array_tensor_product(M, N, P), (1, 4), (0, 2)) + + cg = _array_diagonal(_array_tensor_product(M, N), (1, 2), (3,), allow_trivial_diags=True) + assert cg == _permute_dims(_array_diagonal(_array_tensor_product(M, N), (1, 2)), [0, 2, 1]) + + Ax = ArraySymbol("Ax", shape=(1, 2, 3, 4, 3, 5, 6, 2, 7)) + cg = _array_diagonal(Ax, (1, 7), (3,), (2, 4), (6,), allow_trivial_diags=True) + assert cg == _permute_dims(_array_diagonal(Ax, (1, 7), (2, 4)), [0, 2, 4, 5, 1, 6, 3]) + + cg = _array_diagonal(M, (0,), allow_trivial_diags=True) + assert cg == _permute_dims(M, [1, 0]) + + raises(ValueError, lambda: _array_diagonal(M, (0, 0))) + + +def test_arrayexpr_array_shape(): + expr = _array_tensor_product(M, N, P, Q) + assert expr.shape == (k, k, k, k, k, k, k, k) + Z = MatrixSymbol("Z", m, n) + expr = _array_tensor_product(M, Z) + assert expr.shape == (k, k, m, n) + expr2 = _array_contraction(expr, (0, 1)) + assert expr2.shape == (m, n) + expr2 = _array_diagonal(expr, (0, 1)) + assert expr2.shape == (m, n, k) + exprp = _permute_dims(expr, [2, 1, 3, 0]) + assert exprp.shape == (m, k, n, k) + expr3 = _array_tensor_product(N, Z) + expr2 = _array_add(expr, expr3) + assert expr2.shape == (k, k, m, n) + + # Contraction along axes with discordant dimensions: + raises(ValueError, lambda: _array_contraction(expr, (1, 2))) + # Also diagonal needs the same dimensions: + raises(ValueError, lambda: _array_diagonal(expr, (1, 2))) + # Diagonal requires at least to axes to compute the diagonal: + raises(ValueError, lambda: _array_diagonal(expr, (1,))) + + +def test_arrayexpr_permutedims_sink(): + + cg = _permute_dims(_array_tensor_product(M, N), [0, 1, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(M, _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_tensor_product(M, N), [1, 0, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(M, [1, 0]), _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_tensor_product(M, N), [3, 2, 1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(N, [1, 0]), _permute_dims(M, [1, 0])) + + cg = _permute_dims(_array_contraction(_array_tensor_product(M, N), (1, 2)), [1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_contraction(_permute_dims(_array_tensor_product(M, N), [[0, 3]]), (1, 2)) + + cg = _permute_dims(_array_tensor_product(M, N), [1, 0, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(M, [1, 0]), _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_contraction(_array_tensor_product(M, N, P), (1, 2), (3, 4)), [1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_contraction(_permute_dims(_array_tensor_product(M, N, P), [[0, 5]]), (1, 2), (3, 4)) + + +def test_arrayexpr_push_indices_up_and_down(): + + indices = list(range(12)) + + contr_diag_indices = [(0, 6), (2, 8)] + assert ArrayContraction._push_indices_down(contr_diag_indices, indices) == (1, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14, 15) + assert ArrayContraction._push_indices_up(contr_diag_indices, indices) == (None, 0, None, 1, 2, 3, None, 4, None, 5, 6, 7) + + assert ArrayDiagonal._push_indices_down(contr_diag_indices, indices, 10) == (1, 3, 4, 5, 7, 9, (0, 6), (2, 8), None, None, None, None) + assert ArrayDiagonal._push_indices_up(contr_diag_indices, indices, 10) == (6, 0, 7, 1, 2, 3, 6, 4, 7, 5, None, None) + + contr_diag_indices = [(1, 2), (7, 8)] + assert ArrayContraction._push_indices_down(contr_diag_indices, indices) == (0, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15) + assert ArrayContraction._push_indices_up(contr_diag_indices, indices) == (0, None, None, 1, 2, 3, 4, None, None, 5, 6, 7) + + assert ArrayDiagonal._push_indices_down(contr_diag_indices, indices, 10) == (0, 3, 4, 5, 6, 9, (1, 2), (7, 8), None, None, None, None) + assert ArrayDiagonal._push_indices_up(contr_diag_indices, indices, 10) == (0, 6, 6, 1, 2, 3, 4, 7, 7, 5, None, None) + + +def test_arrayexpr_split_multiple_contractions(): + a = MatrixSymbol("a", k, 1) + b = MatrixSymbol("b", k, 1) + A = MatrixSymbol("A", k, k) + B = MatrixSymbol("B", k, k) + C = MatrixSymbol("C", k, k) + X = MatrixSymbol("X", k, k) + + cg = _array_contraction(_array_tensor_product(A.T, a, b, b.T, (A*X*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9)) + expected = _array_contraction(_array_tensor_product(A.T, DiagMatrix(a), OneArray(1), b, b.T, (A*X*b).applyfunc(cos)), (1, 3), (2, 9), (6, 7, 10)) + assert cg.split_multiple_contractions().dummy_eq(expected) + + # Check no overlap of lines: + + cg = _array_contraction(_array_tensor_product(A, a, C, a, B), (1, 2, 4), (5, 6, 8), (3, 7)) + assert cg.split_multiple_contractions() == cg + + cg = _array_contraction(_array_tensor_product(a, b, A), (0, 2, 4), (1, 3)) + assert cg.split_multiple_contractions() == cg + + +def test_arrayexpr_nested_permutations(): + + cg = _permute_dims(_permute_dims(M, (1, 0)), (1, 0)) + assert cg == M + + times = 3 + plist1 = [list(range(6)) for i in range(times)] + plist2 = [list(range(6)) for i in range(times)] + + for i in range(times): + random.shuffle(plist1[i]) + random.shuffle(plist2[i]) + + plist1.append([2, 5, 4, 1, 0, 3]) + plist2.append([3, 5, 0, 4, 1, 2]) + + plist1.append([2, 5, 4, 0, 3, 1]) + plist2.append([3, 0, 5, 1, 2, 4]) + + plist1.append([5, 4, 2, 0, 3, 1]) + plist2.append([4, 5, 0, 2, 3, 1]) + + Me = M.subs(k, 3).as_explicit() + Ne = N.subs(k, 3).as_explicit() + Pe = P.subs(k, 3).as_explicit() + cge = tensorproduct(Me, Ne, Pe) + + for permutation_array1, permutation_array2 in zip(plist1, plist2): + p1 = Permutation(permutation_array1) + p2 = Permutation(permutation_array2) + + cg = _permute_dims( + _permute_dims( + _array_tensor_product(M, N, P), + p1), + p2 + ) + result = _permute_dims( + _array_tensor_product(M, N, P), + p2*p1 + ) + assert cg == result + + # Check that `permutedims` behaves the same way with explicit-component arrays: + result1 = _permute_dims(_permute_dims(cge, p1), p2) + result2 = _permute_dims(cge, p2*p1) + assert result1 == result2 + + +def test_arrayexpr_contraction_permutation_mix(): + + Me = M.subs(k, 3).as_explicit() + Ne = N.subs(k, 3).as_explicit() + + cg1 = _array_contraction(PermuteDims(_array_tensor_product(M, N), Permutation([0, 2, 1, 3])), (2, 3)) + cg2 = _array_contraction(_array_tensor_product(M, N), (1, 3)) + assert cg1 == cg2 + cge1 = tensorcontraction(permutedims(tensorproduct(Me, Ne), Permutation([0, 2, 1, 3])), (2, 3)) + cge2 = tensorcontraction(tensorproduct(Me, Ne), (1, 3)) + assert cge1 == cge2 + + cg1 = _permute_dims(_array_tensor_product(M, N), Permutation([0, 1, 3, 2])) + cg2 = _array_tensor_product(M, _permute_dims(N, Permutation([1, 0]))) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([0, 2, 3, 1, 4, 5, 7, 6])), + (1, 2), (3, 5) + ) + cg2 = _array_contraction( + _array_tensor_product(M, N, P, _permute_dims(Q, Permutation([1, 0]))), + (1, 5), (2, 3) + ) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([1, 0, 4, 6, 2, 7, 5, 3])), + (0, 1), (2, 6), (3, 7) + ) + cg2 = _permute_dims( + _array_contraction( + _array_tensor_product(M, P, Q, N), + (0, 1), (2, 3), (4, 7)), + [1, 0] + ) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([1, 0, 4, 6, 7, 2, 5, 3])), + (0, 1), (2, 6), (3, 7) + ) + cg2 = _permute_dims( + _array_contraction( + _array_tensor_product(_permute_dims(M, [1, 0]), N, P, Q), + (0, 1), (3, 6), (4, 5) + ), + Permutation([1, 0]) + ) + assert cg1 == cg2 + + +def test_arrayexpr_permute_tensor_product(): + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 1, 0, 5, 4, 6, 7])) + cg2 = _array_tensor_product(N, _permute_dims(M, [1, 0]), + _permute_dims(P, [1, 0]), Q) + assert cg1 == cg2 + + # TODO: reverse operation starting with `PermuteDims` and getting down to `bb`... + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 4, 5, 0, 1, 6, 7])) + cg2 = _array_tensor_product(N, P, M, Q) + assert cg1 == cg2 + + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 4, 6, 5, 7, 0, 1])) + assert cg1.expr == _array_tensor_product(N, P, Q, M) + assert cg1.permutation == Permutation([0, 1, 2, 4, 3, 5, 6, 7]) + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(N, Q, Q, M), + [2, 1, 5, 4, 0, 3, 6, 7]), + [1, 2, 6]) + cg2 = _permute_dims(_array_contraction(_array_tensor_product(Q, Q, N, M), (3, 5, 6)), [0, 2, 3, 1, 4]) + assert cg1 == cg2 + + cg1 = _array_contraction( + _array_contraction( + _array_contraction( + _array_contraction( + _permute_dims( + _array_tensor_product(N, Q, Q, M), + [2, 1, 5, 4, 0, 3, 6, 7]), + [1, 2, 6]), + [1, 3, 4]), + [1]), + [0]) + cg2 = _array_contraction(_array_tensor_product(M, N, Q, Q), (0, 3, 5), (1, 4, 7), (2,), (6,)) + assert cg1 == cg2 + + +def test_arrayexpr_canonicalize_diagonal__permute_dims(): + tp = _array_tensor_product(M, Q, N, P) + expr = _array_diagonal( + _permute_dims(tp, [0, 1, 2, 4, 7, 6, 3, 5]), (2, 4, 5), (6, 7), + (0, 3)) + result = _array_diagonal(tp, (2, 6, 7), (3, 5), (0, 4)) + assert expr == result + + tp = _array_tensor_product(M, N, P, Q) + expr = _array_diagonal(_permute_dims(tp, [0, 5, 2, 4, 1, 6, 3, 7]), (1, 2, 6), (3, 4)) + result = _array_diagonal(_array_tensor_product(M, P, N, Q), (3, 4, 5), (1, 2)) + assert expr == result + + +def test_arrayexpr_canonicalize_diagonal_contraction(): + tp = _array_tensor_product(M, N, P, Q) + expr = _array_contraction(_array_diagonal(tp, (1, 3, 4)), (0, 3)) + result = _array_diagonal(_array_contraction(_array_tensor_product(M, N, P, Q), (0, 6)), (0, 2, 3)) + assert expr == result + + expr = _array_contraction(_array_diagonal(tp, (0, 1, 2, 3, 7)), (1, 2, 3)) + result = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 1, 2, 3, 5, 6, 7)) + assert expr == result + + expr = _array_contraction(_array_diagonal(tp, (0, 2, 6, 7)), (1, 2, 3)) + result = _array_diagonal(_array_contraction(tp, (3, 4, 5)), (0, 2, 3, 4)) + assert expr == result + + td = _array_diagonal(_array_tensor_product(M, N, P, Q), (0, 3)) + expr = _array_contraction(td, (2, 1), (0, 4, 6, 5, 3)) + result = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 1, 3, 5, 6, 7), (2, 4)) + assert expr == result + + +def test_arrayexpr_array_wrong_permutation_size(): + cg = _array_tensor_product(M, N) + raises(ValueError, lambda: _permute_dims(cg, [1, 0])) + raises(ValueError, lambda: _permute_dims(cg, [1, 0, 2, 3, 5, 4])) + + +def test_arrayexpr_nested_array_elementwise_add(): + cg = _array_contraction(_array_add( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ), (1, 2)) + result = _array_add( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)) + ) + assert cg == result + + cg = _array_diagonal(_array_add( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ), (1, 2)) + result = _array_add( + _array_diagonal(_array_tensor_product(M, N), (1, 2)), + _array_diagonal(_array_tensor_product(N, M), (1, 2)) + ) + assert cg == result + + +def test_arrayexpr_array_expr_zero_array(): + za1 = ZeroArray(k, l, m, n) + zm1 = ZeroMatrix(m, n) + + za2 = ZeroArray(k, m, m, n) + zm2 = ZeroMatrix(m, m) + zm3 = ZeroMatrix(k, k) + + assert _array_tensor_product(M, N, za1) == ZeroArray(k, k, k, k, k, l, m, n) + assert _array_tensor_product(M, N, zm1) == ZeroArray(k, k, k, k, m, n) + + assert _array_contraction(za1, (3,)) == ZeroArray(k, l, m) + assert _array_contraction(zm1, (1,)) == ZeroArray(m) + assert _array_contraction(za2, (1, 2)) == ZeroArray(k, n) + assert _array_contraction(zm2, (0, 1)) == 0 + + assert _array_diagonal(za2, (1, 2)) == ZeroArray(k, n, m) + assert _array_diagonal(zm2, (0, 1)) == ZeroArray(m) + + assert _permute_dims(za1, [2, 1, 3, 0]) == ZeroArray(m, l, n, k) + assert _permute_dims(zm1, [1, 0]) == ZeroArray(n, m) + + assert _array_add(za1) == za1 + assert _array_add(zm1) == ZeroArray(m, n) + tp1 = _array_tensor_product(MatrixSymbol("A", k, l), MatrixSymbol("B", m, n)) + assert _array_add(tp1, za1) == tp1 + tp2 = _array_tensor_product(MatrixSymbol("C", k, l), MatrixSymbol("D", m, n)) + assert _array_add(tp1, za1, tp2) == _array_add(tp1, tp2) + assert _array_add(M, zm3) == M + assert _array_add(M, N, zm3) == _array_add(M, N) + + +def test_arrayexpr_array_expr_applyfunc(): + + A = ArraySymbol("A", (3, k, 2)) + aaf = ArrayElementwiseApplyFunc(sin, A) + assert aaf.shape == (3, k, 2) + + +def test_edit_array_contraction(): + cg = _array_contraction(_array_tensor_product(A, B, C, D), (1, 2, 5)) + ecg = _EditArrayContraction(cg) + assert ecg.to_array_contraction() == cg + + ecg.args_with_ind[1], ecg.args_with_ind[2] = ecg.args_with_ind[2], ecg.args_with_ind[1] + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, C, B, D), (1, 3, 4)) + + ci = ecg.get_new_contraction_index() + new_arg = _ArgE(X) + new_arg.indices = [ci, ci] + ecg.args_with_ind.insert(2, new_arg) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, C, X, B, D), (1, 3, 6), (4, 5)) + + assert ecg.get_contraction_indices() == [[1, 3, 6], [4, 5]] + assert [[tuple(j) for j in i] for i in ecg.get_contraction_indices_to_ind_rel_pos()] == [[(0, 1), (1, 1), (3, 0)], [(2, 0), (2, 1)]] + assert [list(i) for i in ecg.get_mapping_for_index(0)] == [[0, 1], [1, 1], [3, 0]] + assert [list(i) for i in ecg.get_mapping_for_index(1)] == [[2, 0], [2, 1]] + raises(ValueError, lambda: ecg.get_mapping_for_index(2)) + + ecg.args_with_ind.pop(1) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, B, D), (1, 4), (2, 3)) + + ecg.args_with_ind[0].indices[1] = ecg.args_with_ind[1].indices[0] + ecg.args_with_ind[1].indices[1] = ecg.args_with_ind[2].indices[0] + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, B, D), (1, 2), (3, 4)) + + ecg.insert_after(ecg.args_with_ind[1], _ArgE(C)) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, C, B, D), (1, 2), (3, 6)) + + +def test_array_expressions_no_canonicalization(): + + tp = _array_tensor_product(M, N, P) + + # ArrayTensorProduct: + + expr = ArrayTensorProduct(tp, N) + assert str(expr) == "ArrayTensorProduct(ArrayTensorProduct(M, N, P), N)" + assert expr.doit() == ArrayTensorProduct(M, N, P, N) + + expr = ArrayTensorProduct(ArrayContraction(M, (0, 1)), N) + assert str(expr) == "ArrayTensorProduct(ArrayContraction(M, (0, 1)), N)" + assert expr.doit() == ArrayContraction(ArrayTensorProduct(M, N), (0, 1)) + + expr = ArrayTensorProduct(ArrayDiagonal(M, (0, 1)), N) + assert str(expr) == "ArrayTensorProduct(ArrayDiagonal(M, (0, 1)), N)" + assert expr.doit() == PermuteDims(ArrayDiagonal(ArrayTensorProduct(M, N), (0, 1)), [2, 0, 1]) + + expr = ArrayTensorProduct(PermuteDims(M, [1, 0]), N) + assert str(expr) == "ArrayTensorProduct(PermuteDims(M, (0 1)), N)" + assert expr.doit() == PermuteDims(ArrayTensorProduct(M, N), [1, 0, 2, 3]) + + # ArrayContraction: + + expr = ArrayContraction(_array_contraction(tp, (0, 2)), (0, 1)) + assert isinstance(expr, ArrayContraction) + assert isinstance(expr.expr, ArrayContraction) + assert str(expr) == "ArrayContraction(ArrayContraction(ArrayTensorProduct(M, N, P), (0, 2)), (0, 1))" + assert expr.doit() == ArrayContraction(tp, (0, 2), (1, 3)) + + expr = ArrayContraction(ArrayContraction(ArrayContraction(tp, (0, 1)), (0, 1)), (0, 1)) + assert expr.doit() == ArrayContraction(tp, (0, 1), (2, 3), (4, 5)) + # assert expr._canonicalize() == ArrayContraction(ArrayContraction(tp, (0, 1)), (0, 1), (2, 3)) + + expr = ArrayContraction(ArrayDiagonal(tp, (0, 1)), (0, 1)) + assert str(expr) == "ArrayContraction(ArrayDiagonal(ArrayTensorProduct(M, N, P), (0, 1)), (0, 1))" + assert expr.doit() == ArrayDiagonal(ArrayContraction(ArrayTensorProduct(N, M, P), (0, 1)), (0, 1)) + + expr = ArrayContraction(PermuteDims(M, [1, 0]), (0, 1)) + assert str(expr) == "ArrayContraction(PermuteDims(M, (0 1)), (0, 1))" + assert expr.doit() == ArrayContraction(M, (0, 1)) + + # ArrayDiagonal: + + expr = ArrayDiagonal(ArrayDiagonal(tp, (0, 2)), (0, 1)) + assert str(expr) == "ArrayDiagonal(ArrayDiagonal(ArrayTensorProduct(M, N, P), (0, 2)), (0, 1))" + assert expr.doit() == ArrayDiagonal(tp, (0, 2), (1, 3)) + + expr = ArrayDiagonal(ArrayDiagonal(ArrayDiagonal(tp, (0, 1)), (0, 1)), (0, 1)) + assert expr.doit() == ArrayDiagonal(tp, (0, 1), (2, 3), (4, 5)) + assert expr._canonicalize() == expr.doit() + + expr = ArrayDiagonal(ArrayContraction(tp, (0, 1)), (0, 1)) + assert str(expr) == "ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, P), (0, 1)), (0, 1))" + assert expr.doit() == expr + + expr = ArrayDiagonal(PermuteDims(M, [1, 0]), (0, 1)) + assert str(expr) == "ArrayDiagonal(PermuteDims(M, (0 1)), (0, 1))" + assert expr.doit() == ArrayDiagonal(M, (0, 1)) + + # ArrayAdd: + + expr = ArrayAdd(M) + assert isinstance(expr, ArrayAdd) + assert expr.doit() == M + + expr = ArrayAdd(ArrayAdd(M, N), P) + assert str(expr) == "ArrayAdd(ArrayAdd(M, N), P)" + assert expr.doit() == ArrayAdd(M, N, P) + + expr = ArrayAdd(M, ArrayAdd(N, ArrayAdd(P, M))) + assert expr.doit() == ArrayAdd(M, N, P, M) + assert expr._canonicalize() == ArrayAdd(M, N, ArrayAdd(P, M)) + + expr = ArrayAdd(M, ZeroArray(k, k), N) + assert str(expr) == "ArrayAdd(M, ZeroArray(k, k), N)" + assert expr.doit() == ArrayAdd(M, N) + + # PermuteDims: + + expr = PermuteDims(PermuteDims(M, [1, 0]), [1, 0]) + assert str(expr) == "PermuteDims(PermuteDims(M, (0 1)), (0 1))" + assert expr.doit() == M + + expr = PermuteDims(PermuteDims(PermuteDims(M, [1, 0]), [1, 0]), [1, 0]) + assert expr.doit() == PermuteDims(M, [1, 0]) + assert expr._canonicalize() == expr.doit() + + # Reshape + + expr = Reshape(A, (k**2,)) + assert expr.shape == (k**2,) + assert isinstance(expr, Reshape) + + +def test_array_expr_construction_with_functions(): + + tp = tensorproduct(M, N) + assert tp == ArrayTensorProduct(M, N) + + expr = tensorproduct(A, eye(2)) + assert expr == ArrayTensorProduct(A, eye(2)) + + # Contraction: + + expr = tensorcontraction(M, (0, 1)) + assert expr == ArrayContraction(M, (0, 1)) + + expr = tensorcontraction(tp, (1, 2)) + assert expr == ArrayContraction(tp, (1, 2)) + + expr = tensorcontraction(tensorcontraction(tp, (1, 2)), (0, 1)) + assert expr == ArrayContraction(tp, (0, 3), (1, 2)) + + # Diagonalization: + + expr = tensordiagonal(M, (0, 1)) + assert expr == ArrayDiagonal(M, (0, 1)) + + expr = tensordiagonal(tensordiagonal(tp, (0, 1)), (0, 1)) + assert expr == ArrayDiagonal(tp, (0, 1), (2, 3)) + + # Permutation of dimensions: + + expr = permutedims(M, [1, 0]) + assert expr == PermuteDims(M, [1, 0]) + + expr = permutedims(PermuteDims(tp, [1, 0, 2, 3]), [0, 1, 3, 2]) + assert expr == PermuteDims(tp, [1, 0, 3, 2]) + + expr = PermuteDims(tp, index_order_new=["a", "b", "c", "d"], index_order_old=["d", "c", "b", "a"]) + assert expr == PermuteDims(tp, [3, 2, 1, 0]) + + arr = Array(range(32)).reshape(2, 2, 2, 2, 2) + expr = PermuteDims(arr, index_order_new=["a", "b", "c", "d", "e"], index_order_old=['b', 'e', 'a', 'd', 'c']) + assert expr == PermuteDims(arr, [2, 0, 4, 3, 1]) + assert expr.as_explicit() == permutedims(arr, index_order_new=["a", "b", "c", "d", "e"], index_order_old=['b', 'e', 'a', 'd', 'c']) + + +def test_array_element_expressions(): + # Check commutative property: + assert M[0, 0]*N[0, 0] == N[0, 0]*M[0, 0] + + # Check derivatives: + assert M[0, 0].diff(M[0, 0]) == 1 + assert M[0, 0].diff(M[1, 0]) == 0 + assert M[0, 0].diff(N[0, 0]) == 0 + assert M[0, 1].diff(M[i, j]) == KroneckerDelta(i, 0)*KroneckerDelta(j, 1) + assert M[0, 1].diff(N[i, j]) == 0 + + K4 = ArraySymbol("K4", shape=(k, k, k, k)) + + assert K4[i, j, k, l].diff(K4[1, 2, 3, 4]) == ( + KroneckerDelta(i, 1)*KroneckerDelta(j, 2)*KroneckerDelta(k, 3)*KroneckerDelta(l, 4) + ) + + +def test_array_expr_reshape(): + + A = MatrixSymbol("A", 2, 2) + B = ArraySymbol("B", (2, 2, 2)) + C = Array([1, 2, 3, 4]) + + expr = Reshape(A, (4,)) + assert expr.expr == A + assert expr.shape == (4,) + assert expr.as_explicit() == Array([A[0, 0], A[0, 1], A[1, 0], A[1, 1]]) + + expr = Reshape(B, (2, 4)) + assert expr.expr == B + assert expr.shape == (2, 4) + ee = expr.as_explicit() + assert isinstance(ee, ImmutableDenseNDimArray) + assert ee.shape == (2, 4) + assert ee == Array([[B[0, 0, 0], B[0, 0, 1], B[0, 1, 0], B[0, 1, 1]], [B[1, 0, 0], B[1, 0, 1], B[1, 1, 0], B[1, 1, 1]]]) + + expr = Reshape(A, (k, 2)) + assert expr.shape == (k, 2) + + raises(ValueError, lambda: Reshape(A, (2, 3))) + raises(ValueError, lambda: Reshape(A, (3,))) + + expr = Reshape(C, (2, 2)) + assert expr.expr == C + assert expr.shape == (2, 2) + assert expr.doit() == Array([[1, 2], [3, 4]]) + + +def test_array_expr_as_explicit_with_explicit_component_arrays(): + # Test if .as_explicit() works with explicit-component arrays + # nested in array expressions: + from sympy.abc import x, y, z, t + A = Array([[x, y], [z, t]]) + assert ArrayTensorProduct(A, A).as_explicit() == tensorproduct(A, A) + assert ArrayDiagonal(A, (0, 1)).as_explicit() == tensordiagonal(A, (0, 1)) + assert ArrayContraction(A, (0, 1)).as_explicit() == tensorcontraction(A, (0, 1)) + assert ArrayAdd(A, A).as_explicit() == A + A + assert ArrayElementwiseApplyFunc(sin, A).as_explicit() == A.applyfunc(sin) + assert PermuteDims(A, [1, 0]).as_explicit() == permutedims(A, [1, 0]) + assert Reshape(A, [4]).as_explicit() == A.reshape(4) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0fcf63f2607b23feb38758e4f0994de4f0384b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py @@ -0,0 +1,78 @@ +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayTensorProduct, \ + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, ArrayContraction, _permute_dims, Reshape +from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive + +k = symbols("k") + +I = Identity(k) +X = MatrixSymbol("X", k, k) +x = MatrixSymbol("x", k, 1) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +A1 = ArraySymbol("A", (3, 2, k)) + + +def test_arrayexpr_derivatives1(): + + res = array_derive(X, X) + assert res == PermuteDims(ArrayTensorProduct(I, I), [0, 2, 1, 3]) + + cg = ArrayTensorProduct(A, X, B) + res = array_derive(cg, X) + assert res == _permute_dims( + ArrayTensorProduct(I, A, I, B), + [0, 4, 2, 3, 1, 5, 6, 7]) + + cg = ArrayContraction(X, (0, 1)) + res = array_derive(cg, X) + assert res == ArrayContraction(ArrayTensorProduct(I, I), (1, 3)) + + cg = ArrayDiagonal(X, (0, 1)) + res = array_derive(cg, X) + assert res == ArrayDiagonal(ArrayTensorProduct(I, I), (1, 3)) + + cg = ElementwiseApplyFunction(sin, X) + res = array_derive(cg, X) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + ElementwiseApplyFunction(cos, X), + I, + I + ), (0, 3), (1, 5))) + + cg = ArrayElementwiseApplyFunc(sin, X) + res = array_derive(cg, X) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + I, + I, + ArrayElementwiseApplyFunc(cos, X) + ), (1, 4), (3, 5))) + + res = array_derive(A1, A1) + assert res == PermuteDims( + ArrayTensorProduct(Identity(3), Identity(2), Identity(k)), + [0, 2, 4, 1, 3, 5] + ) + + cg = ArrayElementwiseApplyFunc(sin, A1) + res = array_derive(cg, A1) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + Identity(3), Identity(2), Identity(k), + ArrayElementwiseApplyFunc(cos, A1) + ), (1, 6), (3, 7), (5, 8) + )) + + cg = Reshape(A, (k**2,)) + res = array_derive(cg, A) + assert res == Reshape(PermuteDims(ArrayTensorProduct(I, I), [0, 2, 1, 3]), (k, k, k**2)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py new file mode 100644 index 0000000000000000000000000000000000000000..30cc61b1ee651ca032e165cd67926fa33c71354f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py @@ -0,0 +1,63 @@ +from sympy.core.symbol import Symbol +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensordiagonal, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, ArraySymbol, \ + ArrayTensorProduct, PermuteDims, ArrayDiagonal, ArrayContraction, ArrayAdd +from sympy.testing.pytest import raises + + +def test_array_as_explicit_call(): + + assert ZeroArray(3, 2, 4).as_explicit() == ImmutableDenseNDimArray.zeros(3, 2, 4) + assert OneArray(3, 2, 4).as_explicit() == ImmutableDenseNDimArray([1 for i in range(3*2*4)]).reshape(3, 2, 4) + + k = Symbol("k") + X = ArraySymbol("X", (k, 3, 2)) + raises(ValueError, lambda: X.as_explicit()) + raises(ValueError, lambda: ZeroArray(k, 2, 3).as_explicit()) + raises(ValueError, lambda: OneArray(2, k, 2).as_explicit()) + + A = ArraySymbol("A", (3, 3)) + B = ArraySymbol("B", (3, 3)) + + texpr = tensorproduct(A, B) + assert isinstance(texpr, ArrayTensorProduct) + assert texpr.as_explicit() == tensorproduct(A.as_explicit(), B.as_explicit()) + + texpr = tensorcontraction(A, (0, 1)) + assert isinstance(texpr, ArrayContraction) + assert texpr.as_explicit() == A[0, 0] + A[1, 1] + A[2, 2] + + texpr = tensordiagonal(A, (0, 1)) + assert isinstance(texpr, ArrayDiagonal) + assert texpr.as_explicit() == ImmutableDenseNDimArray([A[0, 0], A[1, 1], A[2, 2]]) + + texpr = permutedims(A, [1, 0]) + assert isinstance(texpr, PermuteDims) + assert texpr.as_explicit() == permutedims(A.as_explicit(), [1, 0]) + + +def test_array_as_explicit_matrix_symbol(): + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + + texpr = tensorproduct(A, B) + assert isinstance(texpr, ArrayTensorProduct) + assert texpr.as_explicit() == tensorproduct(A.as_explicit(), B.as_explicit()) + + texpr = tensorcontraction(A, (0, 1)) + assert isinstance(texpr, ArrayContraction) + assert texpr.as_explicit() == A[0, 0] + A[1, 1] + A[2, 2] + + texpr = tensordiagonal(A, (0, 1)) + assert isinstance(texpr, ArrayDiagonal) + assert texpr.as_explicit() == ImmutableDenseNDimArray([A[0, 0], A[1, 1], A[2, 2]]) + + texpr = permutedims(A, [1, 0]) + assert isinstance(texpr, PermuteDims) + assert texpr.as_explicit() == permutedims(A.as_explicit(), [1, 0]) + + expr = ArrayAdd(ArrayTensorProduct(A, B), ArrayTensorProduct(B, A)) + assert expr.as_explicit() == expr.args[0].as_explicit() + expr.args[1].as_explicit() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b713fbec94ab7808c5a8a778b3313402d9d0c7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py @@ -0,0 +1,61 @@ +from sympy import Sum, Dummy, sin +from sympy.tensor.array.expressions import ArraySymbol, ArrayTensorProduct, ArrayContraction, PermuteDims, \ + ArrayDiagonal, ArrayAdd, OneArray, ZeroArray, convert_indexed_to_array, ArrayElementwiseApplyFunc, Reshape +from sympy.tensor.array.expressions.from_array_to_indexed import convert_array_to_indexed + +from sympy.abc import i, j, k, l, m, n, o + + +def test_convert_array_to_indexed_main(): + A = ArraySymbol("A", (3, 3, 3)) + B = ArraySymbol("B", (3, 3)) + C = ArraySymbol("C", (3, 3)) + + d_ = Dummy("d_") + + assert convert_array_to_indexed(A, [i, j, k]) == A[i, j, k] + + expr = ArrayTensorProduct(A, B, C) + conv = convert_array_to_indexed(expr, [i,j,k,l,m,n,o]) + assert conv == A[i,j,k]*B[l,m]*C[n,o] + assert convert_indexed_to_array(conv, [i,j,k,l,m,n,o]) == expr + + expr = ArrayContraction(A, (0, 2)) + assert convert_array_to_indexed(expr, [i]).dummy_eq(Sum(A[d_, i, d_], (d_, 0, 2))) + + expr = ArrayDiagonal(A, (0, 2)) + assert convert_array_to_indexed(expr, [i, j]) == A[j, i, j] + + expr = PermuteDims(A, [1, 2, 0]) + conv = convert_array_to_indexed(expr, [i, j, k]) + assert conv == A[k, i, j] + assert convert_indexed_to_array(conv, [i, j, k]) == expr + + expr = ArrayAdd(B, C, PermuteDims(C, [1, 0])) + conv = convert_array_to_indexed(expr, [i, j]) + assert conv == B[i, j] + C[i, j] + C[j, i] + assert convert_indexed_to_array(conv, [i, j]) == expr + + expr = ArrayElementwiseApplyFunc(sin, A) + conv = convert_array_to_indexed(expr, [i, j, k]) + assert conv == sin(A[i, j, k]) + assert convert_indexed_to_array(conv, [i, j, k]).dummy_eq(expr) + + assert convert_array_to_indexed(OneArray(3, 3), [i, j]) == 1 + assert convert_array_to_indexed(ZeroArray(3, 3), [i, j]) == 0 + + expr = Reshape(A, (27,)) + assert convert_array_to_indexed(expr, [i]) == A[i // 9, i // 3 % 3, i % 3] + + X = ArraySymbol("X", (2, 3, 4, 5, 6)) + expr = Reshape(X, (2*3*4*5*6,)) + assert convert_array_to_indexed(expr, [i]) == X[i // 360, i // 120 % 3, i // 30 % 4, i // 6 % 5, i % 6] + + expr = Reshape(X, (4, 9, 2, 2, 5)) + one_index = 180*i + 20*j + 10*k + 5*l + m + expected = X[one_index // (3*4*5*6), one_index // (4*5*6) % 3, one_index // (5*6) % 4, one_index // 6 % 5, one_index % 6] + assert convert_array_to_indexed(expr, [i, j, k, l, m]) == expected + + X = ArraySymbol("X", (2*3*5,)) + expr = Reshape(X, (2, 3, 5)) + assert convert_array_to_indexed(expr, [i, j, k]) == X[15*i + 5*j + k] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..26839d5e7cec0554948c6b726482f9d8ca250b1c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py @@ -0,0 +1,689 @@ +from sympy import Lambda, S, Dummy, KroneckerProduct +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.expressions.hadamard import HadamardProduct, HadamardPower +from sympy.matrices.expressions.special import (Identity, OneMatrix, ZeroMatrix) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array +from sympy.tensor.array.expressions.from_array_to_matrix import _support_function_tp1_recognize, \ + _array_diag2contr_diagmatrix, convert_array_to_matrix, _remove_trivial_dims, _array2matrix, \ + _combine_removed, identify_removable_identity_matrices, _array_contraction_to_diagonal_multiple_identity +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.combinatorics import Permutation +from sympy.matrices.expressions.diagonal import DiagMatrix, DiagonalMatrix +from sympy.matrices import Trace, MatMul, Transpose +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, \ + ArrayElement, ArraySymbol, ArrayElementwiseApplyFunc, _array_tensor_product, _array_contraction, \ + _array_diagonal, _permute_dims, PermuteDims, ArrayAdd, ArrayDiagonal, ArrayContraction, ArrayTensorProduct +from sympy.testing.pytest import raises + + +i, j, k, l, m, n = symbols("i j k l m n") + +I = Identity(k) +I1 = Identity(1) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +X = MatrixSymbol("X", k, k) +Y = MatrixSymbol("Y", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + +x = MatrixSymbol("x", k, 1) +y = MatrixSymbol("y", k, 1) + + +def test_arrayexpr_convert_array_to_matrix(): + + cg = _array_contraction(_array_tensor_product(M), (0, 1)) + assert convert_array_to_matrix(cg) == Trace(M) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 1), (2, 3)) + assert convert_array_to_matrix(cg) == Trace(M) * Trace(N) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 3), (1, 2)) + assert convert_array_to_matrix(cg) == Trace(M * N) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 2), (1, 3)) + assert convert_array_to_matrix(cg) == Trace(M * N.T) + + cg = convert_matrix_to_array(M * N * P) + assert convert_array_to_matrix(cg) == M * N * P + + cg = convert_matrix_to_array(M * N.T * P) + assert convert_array_to_matrix(cg) == M * N.T * P + + cg = _array_contraction(_array_tensor_product(M,N,P,Q), (1, 2), (5, 6)) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * N, P * Q) + + cg = _array_contraction(_array_tensor_product(-2, M, N), (1, 2)) + assert convert_array_to_matrix(cg) == -2 * M * N + + a = MatrixSymbol("a", k, 1) + b = MatrixSymbol("b", k, 1) + c = MatrixSymbol("c", k, 1) + cg = PermuteDims( + _array_contraction( + _array_tensor_product( + a, + ArrayAdd( + _array_tensor_product(b, c), + _array_tensor_product(c, b), + ) + ), (2, 4)), [0, 1, 3, 2]) + assert convert_array_to_matrix(cg) == a * (b.T * c + c.T * b) + + za = ZeroArray(m, n) + assert convert_array_to_matrix(za) == ZeroMatrix(m, n) + + cg = _array_tensor_product(3, M) + assert convert_array_to_matrix(cg) == 3 * M + + # Partial conversion to matrix multiplication: + expr = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 4, 6)) + assert convert_array_to_matrix(expr) == _array_contraction(_array_tensor_product(M.T*N, P, Q), (0, 2, 4)) + + x = MatrixSymbol("x", k, 1) + cg = PermuteDims( + _array_contraction(_array_tensor_product(OneArray(1), x, OneArray(1), DiagMatrix(Identity(1))), + (0, 5)), Permutation(1, 2, 3)) + assert convert_array_to_matrix(cg) == x + + expr = ArrayAdd(M, PermuteDims(M, [1, 0])) + assert convert_array_to_matrix(expr) == M + Transpose(M) + + +def test_arrayexpr_convert_array_to_matrix2(): + cg = _array_contraction(_array_tensor_product(M, N), (1, 3)) + assert convert_array_to_matrix(cg) == M * N.T + + cg = PermuteDims(_array_tensor_product(M, N), Permutation([0, 1, 3, 2])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_tensor_product(M, PermuteDims(N, Permutation([1, 0]))) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_contraction( + PermuteDims( + _array_tensor_product(M, N, P, Q), Permutation([0, 2, 3, 1, 4, 5, 7, 6])), + (1, 2), (3, 5) + ) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * P.T * Trace(N), Q.T) + + cg = _array_contraction( + _array_tensor_product(M, N, P, PermuteDims(Q, Permutation([1, 0]))), + (1, 5), (2, 3) + ) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * P.T * Trace(N), Q.T) + + cg = _array_tensor_product(M, PermuteDims(N, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_tensor_product(PermuteDims(M, [1, 0]), PermuteDims(N, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M.T, N.T) + + cg = _array_tensor_product(PermuteDims(N, [1, 0]), PermuteDims(M, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(N.T, M.T) + + cg = _array_contraction(M, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, k)*M*OneMatrix(k, 1) + + cg = _array_contraction(x, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, k)*x + + Xm = MatrixSymbol("Xm", m, n) + cg = _array_contraction(Xm, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, m)*Xm*OneMatrix(n, 1) + + +def test_arrayexpr_convert_array_to_diagonalized_vector(): + + # Check matrix recognition over trivial dimensions: + + cg = _array_tensor_product(a, b) + assert convert_array_to_matrix(cg) == a * b.T + + cg = _array_tensor_product(I1, a, b) + assert convert_array_to_matrix(cg) == a * b.T + + # Recognize trace inside a tensor product: + + cg = _array_contraction(_array_tensor_product(A, B, C), (0, 3), (1, 2)) + assert convert_array_to_matrix(cg) == Trace(A * B) * C + + # Transform diagonal operator to contraction: + + cg = _array_diagonal(_array_tensor_product(A, a), (1, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(A, OneArray(1), DiagMatrix(a)), (1, 3)) + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(a, b), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _permute_dims( + _array_contraction(_array_tensor_product(DiagMatrix(a), OneArray(1), b), (0, 3)), [1, 2, 0] + ) + assert convert_array_to_matrix(cg) == b.T * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(A, a), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(A, OneArray(1), DiagMatrix(a)), (0, 3)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(I, x, I1), (0, 2), (3, 5)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(I, OneArray(1), I1, DiagMatrix(x)), (0, 5)) + assert convert_array_to_matrix(cg) == DiagMatrix(x) + + cg = _array_diagonal(_array_tensor_product(I, x, A, B), (1, 2), (5, 6)) + assert _array_diag2contr_diagmatrix(cg) == _array_diagonal(_array_contraction(_array_tensor_product(I, OneArray(1), A, B, DiagMatrix(x)), (1, 7)), (5, 6)) + # TODO: this is returning a wrong result: + # convert_array_to_matrix(cg) + + cg = _array_diagonal(_array_tensor_product(I1, a, b), (1, 3, 5)) + assert convert_array_to_matrix(cg) == a*b.T + + cg = _array_diagonal(_array_tensor_product(I1, a, b), (1, 3)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(OneArray(1), a, b, I1), (2, 6)) + assert convert_array_to_matrix(cg) == a*b.T + + cg = _array_diagonal(_array_tensor_product(x, I1), (1, 2)) + assert isinstance(cg, ArrayDiagonal) + assert cg.diagonal_indices == ((1, 2),) + assert convert_array_to_matrix(cg) == x + + cg = _array_diagonal(_array_tensor_product(x, I), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(OneArray(1), I, DiagMatrix(x)), (1, 3)) + assert convert_array_to_matrix(cg).doit() == DiagMatrix(x) + + raises(ValueError, lambda: _array_diagonal(x, (1,))) + + # Ignore identity matrices with contractions: + + cg = _array_contraction(_array_tensor_product(I, A, I, I), (0, 2), (1, 3), (5, 7)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg) == Trace(A) * I + + cg = _array_contraction(_array_tensor_product(Trace(A) * I, I, I), (1, 5), (3, 4)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg).doit() == Trace(A) * I + + # Add DiagMatrix when required: + + cg = _array_contraction(_array_tensor_product(A, a), (1, 2)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg) == A * a + + cg = _array_contraction(_array_tensor_product(A, a, B), (1, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), B), (1, 2), (3, 5)) + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(A, a, B), (0, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), B), (0, 2), (3, 5)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(A, a, b, a.T, B), (0, 2, 4, 7, 9)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), + DiagMatrix(b), OneArray(1), DiagMatrix(a), OneArray(1), B), + (0, 2), (3, 5), (6, 9), (8, 12)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * DiagMatrix(b) * DiagMatrix(a) * B.T + + cg = _array_contraction(_array_tensor_product(I1, I1, I1), (1, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(I1, I1, OneArray(1), I1), (1, 2), (3, 5)) + assert convert_array_to_matrix(cg) == 1 + + cg = _array_contraction(_array_tensor_product(I, I, I, I, A), (1, 2, 8), (5, 6, 9)) + assert convert_array_to_matrix(cg.split_multiple_contractions()).doit() == A + + cg = _array_contraction(_array_tensor_product(A, a, C, a, B), (1, 2, 4), (5, 6, 8)) + expected = _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), C, DiagMatrix(a), OneArray(1), B), (1, 3), (2, 5), (6, 7), (8, 10)) + assert cg.split_multiple_contractions() == expected + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * C * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(a, I1, b, I1, (a.T*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9)) + expected = _array_contraction(_array_tensor_product(a, I1, OneArray(1), b, I1, OneArray(1), (a.T*b).applyfunc(cos)), + (1, 3), (2, 10), (6, 8), (7, 11)) + assert cg.split_multiple_contractions().dummy_eq(expected) + assert convert_array_to_matrix(cg).doit().dummy_eq(MatMul(a, (a.T * b).applyfunc(cos), b.T)) + + +def test_arrayexpr_convert_array_contraction_tp_additions(): + a = ArrayAdd( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ) + tp = _array_tensor_product(P, a, Q) + expr = _array_contraction(tp, (3, 4)) + expected = _array_tensor_product( + P, + ArrayAdd( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)), + ), + Q + ) + assert expr == expected + assert convert_array_to_matrix(expr) == _array_tensor_product(P, M * N + N * M, Q) + + expr = _array_contraction(tp, (1, 2), (3, 4), (5, 6)) + result = _array_contraction( + _array_tensor_product( + P, + ArrayAdd( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)), + ), + Q + ), (1, 2), (3, 4)) + assert expr == result + assert convert_array_to_matrix(expr) == P * (M * N + N * M) * Q + + +def test_arrayexpr_convert_array_to_implicit_matmul(): + # Trivial dimensions are suppressed, so the result can be expressed in matrix form: + + cg = _array_tensor_product(a, b) + assert convert_array_to_matrix(cg) == a * b.T + + cg = _array_tensor_product(a, b, I) + assert convert_array_to_matrix(cg) == _array_tensor_product(a*b.T, I) + + cg = _array_tensor_product(I, a, b) + assert convert_array_to_matrix(cg) == _array_tensor_product(I, a*b.T) + + cg = _array_tensor_product(a, I, b) + assert convert_array_to_matrix(cg) == _array_tensor_product(a, I, b) + + cg = _array_contraction(_array_tensor_product(I, I), (1, 2)) + assert convert_array_to_matrix(cg) == I + + cg = PermuteDims(_array_tensor_product(I, Identity(1)), [0, 2, 1, 3]) + assert convert_array_to_matrix(cg) == I + + +def test_arrayexpr_convert_array_to_matrix_remove_trivial_dims(): + + # Tensor Product: + assert _remove_trivial_dims(_array_tensor_product(a, b)) == (a * b.T, [1, 3]) + assert _remove_trivial_dims(_array_tensor_product(a.T, b)) == (a * b.T, [0, 3]) + assert _remove_trivial_dims(_array_tensor_product(a, b.T)) == (a * b.T, [1, 2]) + assert _remove_trivial_dims(_array_tensor_product(a.T, b.T)) == (a * b.T, [0, 2]) + + assert _remove_trivial_dims(_array_tensor_product(I, a.T, b.T)) == (_array_tensor_product(I, a * b.T), [2, 4]) + assert _remove_trivial_dims(_array_tensor_product(a.T, I, b.T)) == (_array_tensor_product(a.T, I, b.T), []) + + assert _remove_trivial_dims(_array_tensor_product(a, I)) == (_array_tensor_product(a, I), []) + assert _remove_trivial_dims(_array_tensor_product(I, a)) == (_array_tensor_product(I, a), []) + + assert _remove_trivial_dims(_array_tensor_product(a.T, b.T, c, d)) == ( + _array_tensor_product(a * b.T, c * d.T), [0, 2, 5, 7]) + assert _remove_trivial_dims(_array_tensor_product(a.T, I, b.T, c, d, I)) == ( + _array_tensor_product(a.T, I, b*c.T, d, I), [4, 7]) + + # Addition: + + cg = ArrayAdd(_array_tensor_product(a, b), _array_tensor_product(c, d)) + assert _remove_trivial_dims(cg) == (a * b.T + c * d.T, [1, 3]) + + # Permute Dims: + + cg = PermuteDims(_array_tensor_product(a, b), Permutation(3)(1, 2)) + assert _remove_trivial_dims(cg) == (a * b.T, [2, 3]) + + cg = PermuteDims(_array_tensor_product(a, I, b), Permutation(5)(1, 2, 3, 4)) + assert _remove_trivial_dims(cg) == (cg, []) + + cg = PermuteDims(_array_tensor_product(I, b, a), Permutation(5)(1, 2, 4, 5, 3)) + assert _remove_trivial_dims(cg) == (PermuteDims(_array_tensor_product(I, b * a.T), [0, 2, 3, 1]), [4, 5]) + + # Diagonal: + + cg = _array_diagonal(_array_tensor_product(M, a), (1, 2)) + assert _remove_trivial_dims(cg) == (cg, []) + + # Contraction: + + cg = _array_contraction(_array_tensor_product(M, a), (1, 2)) + assert _remove_trivial_dims(cg) == (cg, []) + + # A few more cases to test the removal and shift of nested removed axes + # with array contractions and array diagonals: + tp = _array_tensor_product( + OneMatrix(1, 1), + M, + x, + OneMatrix(1, 1), + Identity(1), + ) + + expr = _array_contraction(tp, (1, 8)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 5, 6, 7] + + expr = _array_contraction(tp, (1, 8), (3, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 3, 4, 5] + + expr = _array_diagonal(tp, (1, 8)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 5, 6, 7, 8] + + expr = _array_diagonal(tp, (1, 8), (3, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 3, 4, 5, 6] + + expr = _array_diagonal(_array_contraction(_array_tensor_product(A, x, I, I1), (1, 2, 5)), (1, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [2, 3] + + cg = _array_diagonal(_array_tensor_product(PermuteDims(_array_tensor_product(x, I1), Permutation(1, 2, 3)), (x.T*x).applyfunc(sqrt)), (2, 4), (3, 5)) + rexpr, removed = _remove_trivial_dims(cg) + assert removed == [1, 2] + + # Contractions with identity matrices need to be followed by a permutation + # in order + cg = _array_contraction(_array_tensor_product(A, B, C, M, I), (1, 8)) + ret, removed = _remove_trivial_dims(cg) + assert ret == PermuteDims(_array_tensor_product(A, B, C, M), [0, 2, 3, 4, 5, 6, 7, 1]) + assert removed == [] + + cg = _array_contraction(_array_tensor_product(A, B, C, M, I), (1, 8), (3, 4)) + ret, removed = _remove_trivial_dims(cg) + assert ret == PermuteDims(_array_contraction(_array_tensor_product(A, B, C, M), (3, 4)), [0, 2, 3, 4, 5, 1]) + assert removed == [] + + # Trivial matrices are sometimes inserted into MatMul expressions: + + cg = _array_tensor_product(b*b.T, a.T*a) + ret, removed = _remove_trivial_dims(cg) + assert ret == b*a.T*a*b.T + assert removed == [2, 3] + + Xs = ArraySymbol("X", (3, 2, k)) + cg = _array_tensor_product(M, Xs, b.T*c, a*a.T, b*b.T, c.T*d) + ret, removed = _remove_trivial_dims(cg) + assert ret == _array_tensor_product(M, Xs, a*b.T*c*c.T*d*a.T, b*b.T) + assert removed == [5, 6, 11, 12] + + cg = _array_diagonal(_array_tensor_product(I, I1, x), (1, 4), (3, 5)) + assert _remove_trivial_dims(cg) == (PermuteDims(_array_diagonal(_array_tensor_product(I, x), (1, 2)), Permutation(1, 2)), [1]) + + expr = _array_diagonal(_array_tensor_product(x, I, y), (0, 2)) + assert _remove_trivial_dims(expr) == (PermuteDims(_array_tensor_product(DiagMatrix(x), y), [1, 2, 3, 0]), [0]) + + expr = _array_diagonal(_array_tensor_product(x, I, y), (0, 2), (3, 4)) + assert _remove_trivial_dims(expr) == (expr, []) + + +def test_arrayexpr_convert_array_to_matrix_diag2contraction_diagmatrix(): + cg = _array_diagonal(_array_tensor_product(M, a), (1, 2)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction(_array_tensor_product(M, OneArray(1), DiagMatrix(a)), (1, 3)) + + raises(ValueError, lambda: _array_diagonal(_array_tensor_product(a, M), (1, 2))) + + cg = _array_diagonal(_array_tensor_product(a.T, M), (1, 2)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction(_array_tensor_product(OneArray(1), M, DiagMatrix(a.T)), (1, 4)) + + cg = _array_diagonal(_array_tensor_product(a.T, M, N, b.T), (1, 2), (4, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a.T), DiagMatrix(b.T)), (1, 7), (3, 9)) + + cg = _array_diagonal(_array_tensor_product(a, M, N, b.T), (0, 2), (4, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a), DiagMatrix(b.T)), (1, 6), (3, 9)) + + cg = _array_diagonal(_array_tensor_product(a, M, N, b.T), (0, 4), (3, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a), DiagMatrix(b.T)), (3, 6), (2, 9)) + + I1 = Identity(1) + x = MatrixSymbol("x", k, 1) + A = MatrixSymbol("A", k, k) + cg = _array_diagonal(_array_tensor_product(x, A.T, I1), (0, 2)) + assert _array_diag2contr_diagmatrix(cg).shape == cg.shape + assert _array2matrix(cg).shape == cg.shape + + +def test_arrayexpr_convert_array_to_matrix_support_function(): + + assert _support_function_tp1_recognize([], [2 * k]) == 2 * k + + assert _support_function_tp1_recognize([(1, 2)], [A, 2 * k, B, 3]) == 6 * k * A * B + + assert _support_function_tp1_recognize([(0, 3), (1, 2)], [A, B]) == Trace(A * B) + + assert _support_function_tp1_recognize([(1, 2)], [A, B]) == A * B + assert _support_function_tp1_recognize([(0, 2)], [A, B]) == A.T * B + assert _support_function_tp1_recognize([(1, 3)], [A, B]) == A * B.T + assert _support_function_tp1_recognize([(0, 3)], [A, B]) == A.T * B.T + + assert _support_function_tp1_recognize([(1, 2), (5, 6)], [A, B, C, D]) == _array_tensor_product(A * B, C * D) + assert _support_function_tp1_recognize([(1, 4), (3, 6)], [A, B, C, D]) == PermuteDims( + _array_tensor_product(A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(0, 3), (1, 4)], [A, B, C]) == B * A * C + + assert _support_function_tp1_recognize([(9, 10), (1, 2), (5, 6), (3, 4), (7, 8)], + [X, Y, A, B, C, D]) == X * Y * A * B * C * D + + assert _support_function_tp1_recognize([(9, 10), (1, 2), (5, 6), (3, 4)], + [X, Y, A, B, C, D]) == _array_tensor_product(X * Y * A * B, C * D) + + assert _support_function_tp1_recognize([(1, 7), (3, 8), (4, 11)], [X, Y, A, B, C, D]) == PermuteDims( + _array_tensor_product(X * B.T, Y * C, A.T * D.T), [0, 2, 4, 1, 3, 5] + ) + + assert _support_function_tp1_recognize([(0, 1), (3, 6), (5, 8)], [X, A, B, C, D]) == PermuteDims( + _array_tensor_product(Trace(X) * A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(1, 2), (3, 4), (5, 6), (7, 8)], [A, A, B, C, D]) == A ** 2 * B * C * D + assert _support_function_tp1_recognize([(1, 2), (3, 4), (5, 6), (7, 8)], [X, A, B, C, D]) == X * A * B * C * D + + assert _support_function_tp1_recognize([(1, 6), (3, 8), (5, 10)], [X, Y, A, B, C, D]) == PermuteDims( + _array_tensor_product(X * B, Y * C, A * D), [0, 2, 4, 1, 3, 5] + ) + + assert _support_function_tp1_recognize([(1, 4), (3, 6)], [A, B, C, D]) == PermuteDims( + _array_tensor_product(A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(0, 4), (1, 7), (2, 5), (3, 8)], [X, A, B, C, D]) == C*X.T*B*A*D + + assert _support_function_tp1_recognize([(0, 4), (1, 7), (2, 5), (3, 8)], [X, A, B, C, D]) == C*X.T*B*A*D + + +def test_convert_array_to_hadamard_products(): + + expr = HadamardProduct(M, N) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = HadamardProduct(M, N)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = Q*HadamardProduct(M, N)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = Q*HadamardProduct(M, N.T)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = HadamardProduct(M, N)*HadamardProduct(Q, P) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert expr == ret + + expr = P.T*HadamardProduct(M, N)*HadamardProduct(Q, P) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert expr == ret + + # ArrayDiagonal should be converted + cg = _array_diagonal(_array_tensor_product(M, N, Q), (1, 3), (0, 2, 4)) + ret = convert_array_to_matrix(cg) + expected = PermuteDims(_array_diagonal(_array_tensor_product(HadamardProduct(M.T, N.T), Q), (1, 2)), [1, 0, 2]) + assert expected == ret + + # Special case that should return the same expression: + cg = _array_diagonal(_array_tensor_product(HadamardProduct(M, N), Q), (0, 2)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + # Hadamard products with traces: + + expr = Trace(HadamardProduct(M, N)) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M.T, N.T)) + + expr = Trace(A*HadamardProduct(M, N)) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M, N)*A) + + expr = Trace(HadamardProduct(A, M)*N) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M.T, N)*A) + + # These should not be converted into Hadamard products: + + cg = _array_diagonal(_array_tensor_product(M, N), (0, 1, 2, 3)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + cg = _array_diagonal(_array_tensor_product(A), (0, 1)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + cg = _array_diagonal(_array_tensor_product(M, N, P), (0, 2, 4), (1, 3, 5)) + assert convert_array_to_matrix(cg) == HadamardProduct(M, N, P) + + cg = _array_diagonal(_array_tensor_product(M, N, P), (0, 3, 4), (1, 2, 5)) + assert convert_array_to_matrix(cg) == HadamardProduct(M, P, N.T) + + cg = _array_diagonal(_array_tensor_product(I, I1, x), (1, 4), (3, 5)) + assert convert_array_to_matrix(cg) == DiagMatrix(x) + + +def test_identify_removable_identity_matrices(): + + D = DiagonalMatrix(MatrixSymbol("D", k, k)) + + cg = _array_contraction(_array_tensor_product(A, B, I), (1, 2, 4, 5)) + expected = _array_contraction(_array_tensor_product(A, B), (1, 2)) + assert identify_removable_identity_matrices(cg) == expected + + cg = _array_contraction(_array_tensor_product(A, B, C, I), (1, 3, 5, 6, 7)) + expected = _array_contraction(_array_tensor_product(A, B, C), (1, 3, 5)) + assert identify_removable_identity_matrices(cg) == expected + + # Tests with diagonal matrices: + + cg = _array_contraction(_array_tensor_product(A, B, D), (1, 2, 4, 5)) + ret = identify_removable_identity_matrices(cg) + expected = _array_contraction(_array_tensor_product(A, B, D), (1, 4), (2, 5)) + assert ret == expected + + cg = _array_contraction(_array_tensor_product(A, B, D, M, N), (1, 2, 4, 5, 6, 8)) + ret = identify_removable_identity_matrices(cg) + assert ret == cg + + +def test_combine_removed(): + + assert _combine_removed(6, [0, 1, 2], [0, 1, 2]) == [0, 1, 2, 3, 4, 5] + assert _combine_removed(8, [2, 5], [1, 3, 4]) == [1, 2, 4, 5, 6] + assert _combine_removed(8, [7], []) == [7] + + +def test_array_contraction_to_diagonal_multiple_identities(): + + expr = _array_contraction(_array_tensor_product(A, B, I, C), (1, 2, 4), (5, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + assert convert_array_to_matrix(expr) == _array_contraction(_array_tensor_product(A, B, C), (1, 2, 4)) + + expr = _array_contraction(_array_tensor_product(A, I, I), (1, 2, 4)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (A, [2]) + assert convert_array_to_matrix(expr) == A + + expr = _array_contraction(_array_tensor_product(A, I, I, B), (1, 2, 4), (3, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + + expr = _array_contraction(_array_tensor_product(A, I, I, B), (1, 2, 3, 4, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + + +def test_convert_array_element_to_matrix(): + + expr = ArrayElement(M, (i, j)) + assert convert_array_to_matrix(expr) == MatrixElement(M, i, j) + + expr = ArrayElement(_array_contraction(_array_tensor_product(M, N), (1, 3)), (i, j)) + assert convert_array_to_matrix(expr) == MatrixElement(M*N.T, i, j) + + expr = ArrayElement(_array_tensor_product(M, N), (i, j, m, n)) + assert convert_array_to_matrix(expr) == expr + + +def test_convert_array_elementwise_function_to_matrix(): + + d = Dummy("d") + + expr = ArrayElementwiseApplyFunc(Lambda(d, sin(d)), x.T*y) + assert convert_array_to_matrix(expr) == sin(x.T*y) + + expr = ArrayElementwiseApplyFunc(Lambda(d, d**2), x.T*y) + assert convert_array_to_matrix(expr) == (x.T*y)**2 + + expr = ArrayElementwiseApplyFunc(Lambda(d, sin(d)), x) + assert convert_array_to_matrix(expr).dummy_eq(x.applyfunc(sin)) + + expr = ArrayElementwiseApplyFunc(Lambda(d, 1 / (2 * sqrt(d))), x) + assert convert_array_to_matrix(expr) == S.Half * HadamardPower(x, -S.Half) + + +def test_array2matrix(): + # See issue https://github.com/sympy/sympy/pull/22877 + expr = PermuteDims(ArrayContraction(ArrayTensorProduct(x, I, I1, x), (0, 3), (1, 7)), Permutation(2, 3)) + expected = PermuteDims(ArrayTensorProduct(x*x.T, I1), Permutation(3)(1, 2)) + assert _array2matrix(expr) == expected + + +def test_recognize_broadcasting(): + expr = ArrayTensorProduct(x.T*x, A) + assert _remove_trivial_dims(expr) == (KroneckerProduct(x.T*x, A), [0, 1]) + + expr = ArrayTensorProduct(A, x.T*x) + assert _remove_trivial_dims(expr) == (KroneckerProduct(A, x.T*x), [2, 3]) + + expr = ArrayTensorProduct(A, B, x.T*x, C) + assert _remove_trivial_dims(expr) == (ArrayTensorProduct(A, KroneckerProduct(B, x.T*x), C), [4, 5]) + + # Always prefer matrix multiplication to Kronecker product, if possible: + expr = ArrayTensorProduct(a, b, x.T*x) + assert _remove_trivial_dims(expr) == (a*x.T*x*b.T, [1, 3, 4, 5]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..258062eadeca041ae3c864dabeefd5165f1cef11 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py @@ -0,0 +1,205 @@ +from sympy import tanh +from sympy.concrete.summations import Sum +from sympy.core.symbol import symbols +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.tensor.array.expressions import ArrayElementwiseApplyFunc +from sympy.tensor.indexed import IndexedBase +from sympy.combinatorics import Permutation +from sympy.tensor.array.expressions.array_expressions import ArrayContraction, ArrayTensorProduct, \ + ArrayDiagonal, ArrayAdd, PermuteDims, ArrayElement, _array_tensor_product, _array_contraction, _array_diagonal, \ + _array_add, _permute_dims, ArraySymbol, OneArray +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array, _convert_indexed_to_array +from sympy.testing.pytest import raises + + +A, B = symbols("A B", cls=IndexedBase) +i, j, k, l, m, n = symbols("i j k l m n") +d0, d1, d2, d3 = symbols("d0:4") + +I = Identity(k) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + + +def test_arrayexpr_convert_index_to_array_support_function(): + expr = M[i, j] + assert _convert_indexed_to_array(expr) == (M, (i, j)) + expr = M[i, j]*N[k, l] + assert _convert_indexed_to_array(expr) == (ArrayTensorProduct(M, N), (i, j, k, l)) + expr = M[i, j]*N[j, k] + assert _convert_indexed_to_array(expr) == (ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)), (i, k, j)) + expr = Sum(M[i, j]*N[j, k], (j, 0, k-1)) + assert _convert_indexed_to_array(expr) == (ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), (i, k)) + expr = M[i, j] + N[i, j] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, N), (i, j)) + expr = M[i, j] + N[j, i] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, PermuteDims(N, Permutation([1, 0]))), (i, j)) + expr = M[i, j] + M[j, i] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, PermuteDims(M, Permutation([1, 0]))), (i, j)) + expr = (M*N*P)[i, j] + assert _convert_indexed_to_array(expr) == (_array_contraction(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)), (i, j)) + expr = expr.function # Disregard summation in previous expression + ret1, ret2 = _convert_indexed_to_array(expr) + assert ret1 == ArrayDiagonal(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)) + assert str(ret2) == "(i, j, _i_1, _i_2)" + expr = KroneckerDelta(i, j)*M[i, k] + assert _convert_indexed_to_array(expr) == (M, ({i, j}, k)) + expr = KroneckerDelta(i, j)*KroneckerDelta(j, k)*M[i, l] + assert _convert_indexed_to_array(expr) == (M, ({i, j, k}, l)) + expr = KroneckerDelta(j, k)*(M[i, j]*N[k, l] + N[i, j]*M[k, l]) + assert _convert_indexed_to_array(expr) == (_array_diagonal(_array_add( + ArrayTensorProduct(M, N), + _permute_dims(ArrayTensorProduct(M, N), Permutation(0, 2)(1, 3)) + ), (1, 2)), (i, l, frozenset({j, k}))) + expr = KroneckerDelta(j, m)*KroneckerDelta(m, k)*(M[i, j]*N[k, l] + N[i, j]*M[k, l]) + assert _convert_indexed_to_array(expr) == (_array_diagonal(_array_add( + ArrayTensorProduct(M, N), + _permute_dims(ArrayTensorProduct(M, N), Permutation(0, 2)(1, 3)) + ), (1, 2)), (i, l, frozenset({j, m, k}))) + expr = KroneckerDelta(i, j)*KroneckerDelta(j, k)*KroneckerDelta(k,m)*M[i, 0]*KroneckerDelta(m, n) + assert _convert_indexed_to_array(expr) == (M, ({i, j, k, m, n}, 0)) + expr = M[i, i] + assert _convert_indexed_to_array(expr) == (ArrayDiagonal(M, (0, 1)), (i,)) + + +def test_arrayexpr_convert_indexed_to_array_expression(): + + s = Sum(A[i]*B[i], (i, 0, 3)) + cg = convert_indexed_to_array(s) + assert cg == ArrayContraction(ArrayTensorProduct(A, B), (0, 1)) + + expr = M*N + result = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + elem = expr[i, j] + assert convert_indexed_to_array(elem) == result + + expr = M*N*M + elem = expr[i, j] + result = _array_contraction(_array_tensor_product(M, M, N), (1, 4), (2, 5)) + cg = convert_indexed_to_array(elem) + assert cg == result + + cg = convert_indexed_to_array((M * N * P)[i, j]) + assert cg == _array_contraction(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)) + + cg = convert_indexed_to_array((M * N.T * P)[i, j]) + assert cg == _array_contraction(ArrayTensorProduct(M, N, P), (1, 3), (2, 4)) + + expr = -2*M*N + elem = expr[i, j] + cg = convert_indexed_to_array(elem) + assert cg == ArrayContraction(ArrayTensorProduct(-2, M, N), (1, 2)) + + +def test_arrayexpr_convert_array_element_to_array_expression(): + A = ArraySymbol("A", (k,)) + B = ArraySymbol("B", (k,)) + + s = Sum(A[i]*B[i], (i, 0, k-1)) + cg = convert_indexed_to_array(s) + assert cg == ArrayContraction(ArrayTensorProduct(A, B), (0, 1)) + + s = A[i]*B[i] + cg = convert_indexed_to_array(s) + assert cg == ArrayDiagonal(ArrayTensorProduct(A, B), (0, 1)) + + s = A[i]*B[j] + cg = convert_indexed_to_array(s, [i, j]) + assert cg == ArrayTensorProduct(A, B) + cg = convert_indexed_to_array(s, [j, i]) + assert cg == ArrayTensorProduct(B, A) + + s = tanh(A[i]*B[j]) + cg = convert_indexed_to_array(s, [i, j]) + assert cg.dummy_eq(ArrayElementwiseApplyFunc(tanh, ArrayTensorProduct(A, B))) + + +def test_arrayexpr_convert_indexed_to_array_and_back_to_matrix(): + + expr = a.T*b + elem = expr[0, 0] + cg = convert_indexed_to_array(elem) + assert cg == ArrayElement(ArrayContraction(ArrayTensorProduct(a, b), (0, 2)), [0, 0]) + + expr = M[i,j] + N[i,j] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M + N + + expr = M[i,j] + N[j,i] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M + N.T + + expr = M[i,j]*N[k,l] + N[i,j]*M[k,l] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == ArrayAdd( + ArrayTensorProduct(M, N), + ArrayTensorProduct(N, M)) + + expr = (M*N*P)[i, j] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * N * P + + expr = Sum(M[i,j]*(N*P)[j,m], (j, 0, k-1)) + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * N * P + + expr = Sum((P[j, m] + P[m, j])*(M[i,j]*N[m,n] + N[i,j]*M[m,n]), (j, 0, k-1), (m, 0, k-1)) + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * P * N + M * P.T * N + N * P * M + N * P.T * M + + +def test_arrayexpr_convert_indexed_to_array_out_of_bounds(): + + expr = Sum(M[i, i], (i, 0, 4)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, i], (i, 0, k)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, i], (i, 1, k-1)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + + expr = Sum(M[i, j]*N[j,m], (j, 0, 4)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, j]*N[j,m], (j, 0, k)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, j]*N[j,m], (j, 1, k-1)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + + +def test_arrayexpr_convert_indexed_to_array_broadcast(): + A = ArraySymbol("A", (3, 3)) + B = ArraySymbol("B", (3, 3)) + + expr = A[i, j] + B[k, l] + O2 = OneArray(3, 3) + expected = ArrayAdd(ArrayTensorProduct(A, O2), ArrayTensorProduct(O2, B)) + assert convert_indexed_to_array(expr) == expected + assert convert_indexed_to_array(expr, [i, j, k, l]) == expected + assert convert_indexed_to_array(expr, [l, k, i, j]) == ArrayAdd(PermuteDims(ArrayTensorProduct(O2, A), [1, 0, 2, 3]), PermuteDims(ArrayTensorProduct(B, O2), [1, 0, 2, 3])) + + expr = A[i, j] + B[j, k] + O1 = OneArray(3) + assert convert_indexed_to_array(expr, [i, j, k]) == ArrayAdd(ArrayTensorProduct(A, O1), ArrayTensorProduct(O1, B)) + + C = ArraySymbol("C", (d0, d1)) + D = ArraySymbol("D", (d3, d1)) + + expr = C[i, j] + D[k, j] + assert convert_indexed_to_array(expr, [i, j, k]) == ArrayAdd(ArrayTensorProduct(C, OneArray(d3)), PermuteDims(ArrayTensorProduct(OneArray(d0), D), [0, 2, 1])) + + X = ArraySymbol("X", (5, 3)) + + expr = X[i, n] - X[j, n] + assert convert_indexed_to_array(expr, [i, j, n]) == ArrayAdd(ArrayTensorProduct(-1, OneArray(5), X), PermuteDims(ArrayTensorProduct(X, OneArray(5)), [0, 2, 1])) + + raises(ValueError, lambda: convert_indexed_to_array(C[i, j] + D[i, j])) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..142585882588df6aa0e4648d9d8881ea755f42a0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py @@ -0,0 +1,128 @@ +from sympy import Lambda, KroneckerProduct +from sympy.core.symbol import symbols, Dummy +from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct) +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.special import Identity +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction, \ + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, _array_contraction, _array_tensor_product, Reshape +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +i, j, k, l, m, n = symbols("i j k l m n") + +I = Identity(k) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +X = MatrixSymbol("X", k, k) +Y = MatrixSymbol("Y", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + + +def test_arrayexpr_convert_matrix_to_array(): + + expr = M*N + result = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + assert convert_matrix_to_array(expr) == result + + expr = M*N*M + result = _array_contraction(ArrayTensorProduct(M, N, M), (1, 2), (3, 4)) + assert convert_matrix_to_array(expr) == result + + expr = Transpose(M) + assert convert_matrix_to_array(expr) == PermuteDims(M, [1, 0]) + + expr = M*Transpose(N) + assert convert_matrix_to_array(expr) == _array_contraction(_array_tensor_product(M, PermuteDims(N, [1, 0])), (1, 2)) + + expr = 3*M*N + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = 3*M + N*M.T*M + 4*k*N + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = Inverse(M)*N + rexpr = convert_array_to_matrix(convert_matrix_to_array(expr)) + assert expr == rexpr + + expr = M**2 + rexpr = convert_array_to_matrix(convert_matrix_to_array(expr)) + assert expr == rexpr + + expr = M*(2*N + 3*M) + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = Trace(M) + result = ArrayContraction(M, (0, 1)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(M) + result = ArrayContraction(ArrayTensorProduct(3, M), (0, 1)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(Trace(M) * M) + result = ArrayContraction(ArrayTensorProduct(3, M, M), (0, 1), (2, 3)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(M)**2 + result = ArrayContraction(ArrayTensorProduct(3, M, M), (0, 1), (2, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardProduct(M, N) + result = ArrayDiagonal(ArrayTensorProduct(M, N), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardProduct(M*N, N*M) + result = ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, N, M), (1, 2), (5, 6)), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M, 2) + result = ArrayDiagonal(ArrayTensorProduct(M, M), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M*N, 2) + result = ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, M, N), (1, 2), (5, 6)), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M, n) + d0 = Dummy("d0") + result = ArrayElementwiseApplyFunc(Lambda(d0, d0**n), M) + assert convert_matrix_to_array(expr).dummy_eq(result) + + expr = M**2 + assert isinstance(expr, MatPow) + assert convert_matrix_to_array(expr) == ArrayContraction(ArrayTensorProduct(M, M), (1, 2)) + + expr = a.T*b + cg = convert_matrix_to_array(expr) + assert cg == ArrayContraction(ArrayTensorProduct(a, b), (0, 2)) + + expr = KroneckerProduct(A, B) + cg = convert_matrix_to_array(expr) + assert cg == Reshape(PermuteDims(ArrayTensorProduct(A, B), [0, 2, 1, 3]), (k**2, k**2)) + + expr = KroneckerProduct(A, B, C, D) + cg = convert_matrix_to_array(expr) + assert cg == Reshape(PermuteDims(ArrayTensorProduct(A, B, C, D), [0, 2, 4, 6, 1, 3, 5, 7]), (k**4, k**4)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b41b6105410a308e7774fce760b235497d0303bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py @@ -0,0 +1,22 @@ +from sympy import MatrixSymbol, symbols, Sum +from sympy.tensor.array.expressions import conv_array_to_indexed, from_array_to_indexed, ArrayTensorProduct, \ + ArrayContraction, conv_array_to_matrix, from_array_to_matrix, conv_matrix_to_array, from_matrix_to_array, \ + conv_indexed_to_array, from_indexed_to_array +from sympy.testing.pytest import warns +from sympy.utilities.exceptions import SymPyDeprecationWarning + + +def test_deprecated_conv_module_results(): + + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + i, j, d = symbols("i j d") + + x = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + y = Sum(M[i, d]*N[d, j], (d, 0, 2)) + + with warns(SymPyDeprecationWarning, test_stacklevel=False): + assert conv_array_to_indexed.convert_array_to_indexed(x, [i, j]).dummy_eq(from_array_to_indexed.convert_array_to_indexed(x, [i, j])) + assert conv_array_to_matrix.convert_array_to_matrix(x) == from_array_to_matrix.convert_array_to_matrix(x) + assert conv_matrix_to_array.convert_matrix_to_array(M*N) == from_matrix_to_array.convert_matrix_to_array(M*N) + assert conv_indexed_to_array.convert_indexed_to_array(y) == from_indexed_to_array.convert_indexed_to_array(y) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e55c0e6ed47cdc9ff1c24cc92f006998aeb86822 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/expressions/utils.py @@ -0,0 +1,123 @@ +import bisect +from collections import defaultdict + +from sympy.combinatorics import Permutation +from sympy.core.containers import Tuple +from sympy.core.numbers import Integer + + +def _get_mapping_from_subranks(subranks): + mapping = {} + counter = 0 + for i, rank in enumerate(subranks): + for j in range(rank): + mapping[counter] = (i, j) + counter += 1 + return mapping + + +def _get_contraction_links(args, subranks, *contraction_indices): + mapping = _get_mapping_from_subranks(subranks) + contraction_tuples = [[mapping[j] for j in i] for i in contraction_indices] + dlinks = defaultdict(dict) + for links in contraction_tuples: + if len(links) == 2: + (arg1, pos1), (arg2, pos2) = links + dlinks[arg1][pos1] = (arg2, pos2) + dlinks[arg2][pos2] = (arg1, pos1) + continue + + return args, dict(dlinks) + + +def _sort_contraction_indices(pairing_indices): + pairing_indices = [Tuple(*sorted(i)) for i in pairing_indices] + pairing_indices.sort(key=lambda x: min(x)) + return pairing_indices + + +def _get_diagonal_indices(flattened_indices): + axes_contraction = defaultdict(list) + for i, ind in enumerate(flattened_indices): + if isinstance(ind, (int, Integer)): + # If the indices is a number, there can be no diagonal operation: + continue + axes_contraction[ind].append(i) + axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1} + # Put the diagonalized indices at the end: + ret_indices = [i for i in flattened_indices if i not in axes_contraction] + diag_indices = list(axes_contraction) + diag_indices.sort(key=lambda x: flattened_indices.index(x)) + diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices] + ret_indices += diag_indices + ret_indices = tuple(ret_indices) + return diagonal_indices, ret_indices + + +def _get_argindex(subindices, ind): + for i, sind in enumerate(subindices): + if ind == sind: + return i + if isinstance(sind, (set, frozenset)) and ind in sind: + return i + raise IndexError("%s not found in %s" % (ind, subindices)) + + +def _apply_recursively_over_nested_lists(func, arr): + if isinstance(arr, (tuple, list, Tuple)): + return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr) + elif isinstance(arr, Tuple): + return Tuple.fromiter(_apply_recursively_over_nested_lists(func, i) for i in arr) + else: + return func(arr) + + +def _build_push_indices_up_func_transformation(flattened_contraction_indices): + shifts = {0: 0} + i = 0 + cumulative = 0 + while i < len(flattened_contraction_indices): + j = 1 + while i+j < len(flattened_contraction_indices): + if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]: + break + j += 1 + cumulative += j + shifts[flattened_contraction_indices[i]] = cumulative + i += j + shift_keys = sorted(shifts.keys()) + + def func(idx): + return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]] + + def transform(j): + if j in flattened_contraction_indices: + return None + else: + return j - func(j) + + return transform + + +def _build_push_indices_down_func_transformation(flattened_contraction_indices): + N = flattened_contraction_indices[-1]+2 + + shifts = [i for i in range(N) if i not in flattened_contraction_indices] + + def transform(j): + if j < len(shifts): + return shifts[j] + else: + return j + shifts[-1] - len(shifts) + 1 + + return transform + + +def _apply_permutation_to_list(perm: Permutation, target_list: list): + """ + Permute a list according to the given permutation. + """ + new_list = [None for i in range(perm.size)] + for i, e in enumerate(target_list): + new_list[perm(i)] = e + return new_list diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/mutable_ndim_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/mutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..e1eaaf7241bc3b4a48234178d18da3aa5736e189 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/mutable_ndim_array.py @@ -0,0 +1,13 @@ +from sympy.tensor.array.ndim_array import NDimArray + + +class MutableNDimArray(NDimArray): + + def as_immutable(self): + raise NotImplementedError("abstract method") + + def as_mutable(self): + return self + + def _sympy_(self): + return self.as_immutable() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/ndim_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9a857b8cfd9ee46646c46f274636d6b9962b6e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/ndim_array.py @@ -0,0 +1,601 @@ +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.expr import Expr +from sympy.core.kind import Kind, NumberKind, UndefinedKind +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.external.gmpy import SYMPY_INTS +from sympy.printing.defaults import Printable + +import itertools +from collections.abc import Iterable + + +class ArrayKind(Kind): + """ + Kind for N-dimensional array in SymPy. + + This kind represents the multidimensional array that algebraic + operations are defined. Basic class for this kind is ``NDimArray``, + but any expression representing the array can have this. + + Parameters + ========== + + element_kind : Kind + Kind of the element. Default is :obj:NumberKind ``, + which means that the array contains only numbers. + + Examples + ======== + + Any instance of array class has ``ArrayKind``. + + >>> from sympy import NDimArray + >>> NDimArray([1,2,3]).kind + ArrayKind(NumberKind) + + Although expressions representing an array may be not instance of + array class, it will have ``ArrayKind`` as well. + + >>> from sympy import Integral + >>> from sympy.tensor.array import NDimArray + >>> from sympy.abc import x + >>> intA = Integral(NDimArray([1,2,3]), x) + >>> isinstance(intA, NDimArray) + False + >>> intA.kind + ArrayKind(NumberKind) + + Use ``isinstance()`` to check for ``ArrayKind` without specifying + the element kind. Use ``is`` with specifying the element kind. + + >>> from sympy.tensor.array import ArrayKind + >>> from sympy.core import NumberKind + >>> boolA = NDimArray([True, False]) + >>> isinstance(boolA.kind, ArrayKind) + True + >>> boolA.kind is ArrayKind(NumberKind) + False + + See Also + ======== + + shape : Function to return the shape of objects with ``MatrixKind``. + + """ + def __new__(cls, element_kind=NumberKind): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + return "ArrayKind(%s)" % self.element_kind + + @classmethod + def _union(cls, kinds) -> 'ArrayKind': + elem_kinds = {e.kind for e in kinds} + if len(elem_kinds) == 1: + elemkind, = elem_kinds + else: + elemkind = UndefinedKind + return ArrayKind(elemkind) + + +class NDimArray(Printable): + """N-dimensional array. + + Examples + ======== + + Create an N-dim array of zeros: + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 3, 4) + >>> a + [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] + + Create an N-dim array from a list; + + >>> a = MutableDenseNDimArray([[2, 3], [4, 5]]) + >>> a + [[2, 3], [4, 5]] + + >>> b = MutableDenseNDimArray([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]) + >>> b + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + + Create an N-dim array from a flat list with dimension shape: + + >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3)) + >>> a + [[1, 2, 3], [4, 5, 6]] + + Create an N-dim array from a matrix: + + >>> from sympy import Matrix + >>> a = Matrix([[1,2],[3,4]]) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> b = MutableDenseNDimArray(a) + >>> b + [[1, 2], [3, 4]] + + Arithmetic operations on N-dim arrays + + >>> a = MutableDenseNDimArray([1, 1, 1, 1], (2, 2)) + >>> b = MutableDenseNDimArray([4, 4, 4, 4], (2, 2)) + >>> c = a + b + >>> c + [[5, 5], [5, 5]] + >>> a - b + [[-3, -3], [-3, -3]] + + """ + + _diff_wrt = True + is_scalar = False + + def __new__(cls, iterable, shape=None, **kwargs): + from sympy.tensor.array import ImmutableDenseNDimArray + return ImmutableDenseNDimArray(iterable, shape, **kwargs) + + def __getitem__(self, index): + raise NotImplementedError("A subclass of NDimArray should implement __getitem__") + + def _parse_index(self, index): + if isinstance(index, (SYMPY_INTS, Integer)): + if index >= self._loop_size: + raise ValueError("Only a tuple index is accepted") + return index + + if self._loop_size == 0: + raise ValueError("Index not valid with an empty array") + + if len(index) != self._rank: + raise ValueError('Wrong number of array axes') + + real_index = 0 + # check if input index can exist in current indexing + for i in range(self._rank): + if (index[i] >= self.shape[i]) or (index[i] < -self.shape[i]): + raise ValueError('Index ' + str(index) + ' out of border') + if index[i] < 0: + real_index += 1 + real_index = real_index*self.shape[i] + index[i] + + return real_index + + def _get_tuple_index(self, integer_index): + index = [] + for sh in reversed(self.shape): + index.append(integer_index % sh) + integer_index //= sh + index.reverse() + return tuple(index) + + def _check_symbolic_index(self, index): + # Check if any index is symbolic: + tuple_index = (index if isinstance(index, tuple) else (index,)) + if any((isinstance(i, Expr) and (not i.is_number)) for i in tuple_index): + for i, nth_dim in zip(tuple_index, self.shape): + if ((i < 0) == True) or ((i >= nth_dim) == True): + raise ValueError("index out of range") + from sympy.tensor import Indexed + return Indexed(self, *tuple_index) + return None + + def _setter_iterable_check(self, value): + from sympy.matrices.matrixbase import MatrixBase + if isinstance(value, (Iterable, MatrixBase, NDimArray)): + raise NotImplementedError + + @classmethod + def _scan_iterable_shape(cls, iterable): + def f(pointer): + if not isinstance(pointer, Iterable): + return [pointer], () + + if len(pointer) == 0: + return [], (0,) + + result = [] + elems, shapes = zip(*[f(i) for i in pointer]) + if len(set(shapes)) != 1: + raise ValueError("could not determine shape unambiguously") + for i in elems: + result.extend(i) + return result, (len(shapes),)+shapes[0] + + return f(iterable) + + @classmethod + def _handle_ndarray_creation_inputs(cls, iterable=None, shape=None, **kwargs): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + + if shape is None: + if iterable is None: + shape = () + iterable = () + # Construction of a sparse array from a sparse array + elif isinstance(iterable, SparseNDimArray): + return iterable._shape, iterable._sparse_array + + # Construct N-dim array from another N-dim array: + elif isinstance(iterable, NDimArray): + shape = iterable.shape + + # Construct N-dim array from an iterable (numpy arrays included): + elif isinstance(iterable, Iterable): + iterable, shape = cls._scan_iterable_shape(iterable) + + # Construct N-dim array from a Matrix: + elif isinstance(iterable, MatrixBase): + shape = iterable.shape + + else: + shape = () + iterable = (iterable,) + + if isinstance(iterable, (Dict, dict)) and shape is not None: + new_dict = iterable.copy() + for k in new_dict: + if isinstance(k, (tuple, Tuple)): + new_key = 0 + for i, idx in enumerate(k): + new_key = new_key * shape[i] + idx + iterable[new_key] = iterable[k] + del iterable[k] + + if isinstance(shape, (SYMPY_INTS, Integer)): + shape = (shape,) + + if not all(isinstance(dim, (SYMPY_INTS, Integer)) for dim in shape): + raise TypeError("Shape should contain integers only.") + + return tuple(shape), iterable + + def __len__(self): + """Overload common function len(). Returns number of elements in array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3, 3) + >>> a + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + >>> len(a) + 9 + + """ + return self._loop_size + + @property + def shape(self): + """ + Returns array shape (dimension). + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3, 3) + >>> a.shape + (3, 3) + + """ + return self._shape + + def rank(self): + """ + Returns rank of array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3,4,5,6,3) + >>> a.rank() + 5 + + """ + return self._rank + + def diff(self, *args, **kwargs): + """ + Calculate the derivative of each element in the array. + + Examples + ======== + + >>> from sympy import ImmutableDenseNDimArray + >>> from sympy.abc import x, y + >>> M = ImmutableDenseNDimArray([[x, y], [1, x*y]]) + >>> M.diff(x) + [[1, 0], [0, y]] + + """ + from sympy.tensor.array.array_derivatives import ArrayDerivative + kwargs.setdefault('evaluate', True) + return ArrayDerivative(self.as_immutable(), *args, **kwargs) + + def _eval_derivative(self, base): + # Types are (base: scalar, self: array) + return self.applyfunc(lambda x: base.diff(x)) + + def _eval_derivative_n_times(self, s, n): + return Basic._eval_derivative_n_times(self, s, n) + + def applyfunc(self, f): + """Apply a function to each element of the N-dim array. + + Examples + ======== + + >>> from sympy import ImmutableDenseNDimArray + >>> m = ImmutableDenseNDimArray([i*2+j for i in range(2) for j in range(2)], (2, 2)) + >>> m + [[0, 1], [2, 3]] + >>> m.applyfunc(lambda i: 2*i) + [[0, 2], [4, 6]] + """ + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(self, SparseNDimArray) and f(S.Zero) == 0: + return type(self)({k: f(v) for k, v in self._sparse_array.items() if f(v) != 0}, self.shape) + + return type(self)(map(f, Flatten(self)), self.shape) + + def _sympystr(self, printer): + def f(sh, shape_left, i, j): + if len(shape_left) == 1: + return "["+", ".join([printer._print(self[self._get_tuple_index(e)]) for e in range(i, j)])+"]" + + sh //= shape_left[0] + return "[" + ", ".join([f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh) for e in range(shape_left[0])]) + "]" # + "\n"*len(shape_left) + + if self.rank() == 0: + return printer._print(self[()]) + if 0 in self.shape: + return f"{self.__class__.__name__}([], {self.shape})" + return f(self._loop_size, self.shape, 0, self._loop_size) + + def tolist(self): + """ + Converting MutableDenseNDimArray to one-dim list + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + >>> a + [[1, 2], [3, 4]] + >>> b = a.tolist() + >>> b + [[1, 2], [3, 4]] + """ + + def f(sh, shape_left, i, j): + if len(shape_left) == 1: + return [self[self._get_tuple_index(e)] for e in range(i, j)] + result = [] + sh //= shape_left[0] + for e in range(shape_left[0]): + result.append(f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh)) + return result + + return f(self._loop_size, self.shape, 0, self._loop_size) + + def __add__(self, other): + from sympy.tensor.array.arrayop import Flatten + + if not isinstance(other, NDimArray): + return NotImplemented + + if self.shape != other.shape: + raise ValueError("array shape mismatch") + result_list = [i+j for i,j in zip(Flatten(self), Flatten(other))] + + return type(self)(result_list, self.shape) + + def __sub__(self, other): + from sympy.tensor.array.arrayop import Flatten + + if not isinstance(other, NDimArray): + return NotImplemented + + if self.shape != other.shape: + raise ValueError("array shape mismatch") + result_list = [i-j for i,j in zip(Flatten(self), Flatten(other))] + + return type(self)(result_list, self.shape) + + def __mul__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected, use tensorproduct(...) for tensorial product") + + other = sympify(other) + if isinstance(self, SparseNDimArray): + if other.is_zero: + return type(self)({}, self.shape) + return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [i*other for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __rmul__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected, use tensorproduct(...) for tensorial product") + + other = sympify(other) + if isinstance(self, SparseNDimArray): + if other.is_zero: + return type(self)({}, self.shape) + return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [other*i for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __truediv__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected") + + other = sympify(other) + if isinstance(self, SparseNDimArray) and other != S.Zero: + return type(self)({k: v/other for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [i/other for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __rtruediv__(self, other): + raise NotImplementedError('unsupported operation on NDimArray') + + def __neg__(self): + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(self, SparseNDimArray): + return type(self)({k: -v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [-i for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __iter__(self): + def iterator(): + if self._shape: + for i in range(self._shape[0]): + yield self[i] + else: + yield self[()] + + return iterator() + + def __eq__(self, other): + """ + NDimArray instances can be compared to each other. + Instances equal if they have same shape and data. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 3) + >>> b = MutableDenseNDimArray.zeros(2, 3) + >>> a == b + True + >>> c = a.reshape(3, 2) + >>> c == b + False + >>> a[0,0] = 1 + >>> b[0,0] = 2 + >>> a == b + False + """ + from sympy.tensor.array import SparseNDimArray + if not isinstance(other, NDimArray): + return False + + if not self.shape == other.shape: + return False + + if isinstance(self, SparseNDimArray) and isinstance(other, SparseNDimArray): + return dict(self._sparse_array) == dict(other._sparse_array) + + return list(self) == list(other) + + def __ne__(self, other): + return not self == other + + def _eval_transpose(self): + if self.rank() != 2: + raise ValueError("array rank not 2") + from .arrayop import permutedims + return permutedims(self, (1, 0)) + + def transpose(self): + return self._eval_transpose() + + def _eval_conjugate(self): + from sympy.tensor.array.arrayop import Flatten + + return self.func([i.conjugate() for i in Flatten(self)], self.shape) + + def conjugate(self): + return self._eval_conjugate() + + def _eval_adjoint(self): + return self.transpose().conjugate() + + def adjoint(self): + return self._eval_adjoint() + + def _slice_expand(self, s, dim): + if not isinstance(s, slice): + return (s,) + start, stop, step = s.indices(dim) + return [start + i*step for i in range((stop-start)//step)] + + def _get_slice_data_for_array_access(self, index): + sl_factors = [self._slice_expand(i, dim) for (i, dim) in zip(index, self.shape)] + eindices = itertools.product(*sl_factors) + return sl_factors, eindices + + def _get_slice_data_for_array_assignment(self, index, value): + if not isinstance(value, NDimArray): + value = type(self)(value) + sl_factors, eindices = self._get_slice_data_for_array_access(index) + slice_offsets = [min(i) if isinstance(i, list) else None for i in sl_factors] + # TODO: add checks for dimensions for `value`? + return value, eindices, slice_offsets + + @classmethod + def _check_special_bounds(cls, flat_list, shape): + if shape == () and len(flat_list) != 1: + raise ValueError("arrays without shape need one scalar value") + if shape == (0,) and len(flat_list) > 0: + raise ValueError("if array shape is (0,) there cannot be elements") + + def _check_index_for_getitem(self, index): + if isinstance(index, (SYMPY_INTS, Integer, slice)): + index = (index,) + + if len(index) < self.rank(): + index = tuple(index) + \ + tuple(slice(None) for i in range(len(index), self.rank())) + + if len(index) > self.rank(): + raise ValueError('Dimension of index greater than rank of array') + + return index + + +class ImmutableNDimArray(NDimArray, Basic): + _op_priority = 11.0 + + def __hash__(self): + return Basic.__hash__(self) + + def as_immutable(self): + return self + + def as_mutable(self): + raise NotImplementedError("abstract method") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/sparse_ndim_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/sparse_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..f11aa95be8ec9d10a9104d48fb28f406fe43845e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/sparse_ndim_array.py @@ -0,0 +1,196 @@ +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.tensor.array.mutable_ndim_array import MutableNDimArray +from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray +from sympy.utilities.iterables import flatten + +import functools + +class SparseNDimArray(NDimArray): + + def __new__(self, *args, **kwargs): + return ImmutableSparseNDimArray(*args, **kwargs) + + def __getitem__(self, index): + """ + Get an element from a sparse N-dim array. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray(range(4), (2, 2)) + >>> a + [[0, 1], [2, 3]] + >>> a[0, 0] + 0 + >>> a[1, 1] + 3 + >>> a[0] + [0, 1] + >>> a[1] + [2, 3] + + Symbolic indexing: + + >>> from sympy.abc import i, j + >>> a[i, j] + [[0, 1], [2, 3]][i, j] + + Replace `i` and `j` to get element `(0, 0)`: + + >>> a[i, j].subs({i: 0, j: 0}) + 0 + + """ + syindex = self._check_symbolic_index(index) + if syindex is not None: + return syindex + + index = self._check_index_for_getitem(index) + + # `index` is a tuple with one or more slices: + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + sl_factors, eindices = self._get_slice_data_for_array_access(index) + array = [self._sparse_array.get(self._parse_index(i), S.Zero) for i in eindices] + nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)] + return type(self)(array, nshape) + else: + index = self._parse_index(index) + return self._sparse_array.get(index, S.Zero) + + @classmethod + def zeros(cls, *shape): + """ + Return a sparse N-dim array of zeros. + """ + return cls({}, shape) + + def tomatrix(self): + """ + Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray([1 for i in range(9)], (3, 3)) + >>> b = a.tomatrix() + >>> b + Matrix([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]) + """ + from sympy.matrices import SparseMatrix + if self.rank() != 2: + raise ValueError('Dimensions must be of size of 2') + + mat_sparse = {} + for key, value in self._sparse_array.items(): + mat_sparse[self._get_tuple_index(key)] = value + + return SparseMatrix(self.shape[0], self.shape[1], mat_sparse) + + def reshape(self, *newshape): + new_total_size = functools.reduce(lambda x,y: x*y, newshape) + if new_total_size != self._loop_size: + raise ValueError("Invalid reshape parameters " + newshape) + + return type(self)(self._sparse_array, newshape) + +class ImmutableSparseNDimArray(SparseNDimArray, ImmutableNDimArray): # type: ignore + + def __new__(cls, iterable=None, shape=None, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + shape = Tuple(*map(_sympify, shape)) + cls._check_special_bounds(flat_list, shape) + loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + + # Sparse array: + if isinstance(flat_list, (dict, Dict)): + sparse_array = Dict(flat_list) + else: + sparse_array = {} + for i, el in enumerate(flatten(flat_list)): + if el != 0: + sparse_array[i] = _sympify(el) + + sparse_array = Dict(sparse_array) + + self = Basic.__new__(cls, sparse_array, shape, **kwargs) + self._shape = shape + self._rank = len(shape) + self._loop_size = loop_size + self._sparse_array = sparse_array + + return self + + def __setitem__(self, index, value): + raise TypeError("immutable N-dim array") + + def as_mutable(self): + return MutableSparseNDimArray(self) + + +class MutableSparseNDimArray(MutableNDimArray, SparseNDimArray): + + def __new__(cls, iterable=None, shape=None, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + self = object.__new__(cls) + self._shape = shape + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + + # Sparse array: + if isinstance(flat_list, (dict, Dict)): + self._sparse_array = dict(flat_list) + return self + + self._sparse_array = {} + + for i, el in enumerate(flatten(flat_list)): + if el != 0: + self._sparse_array[i] = _sympify(el) + + return self + + def __setitem__(self, index, value): + """Allows to set items to MutableDenseNDimArray. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray.zeros(2, 2) + >>> a[0, 0] = 1 + >>> a[1, 1] = 1 + >>> a + [[1, 0], [0, 1]] + """ + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value) + for i in eindices: + other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None] + other_value = value[other_i] + complete_index = self._parse_index(i) + if other_value != 0: + self._sparse_array[complete_index] = other_value + elif complete_index in self._sparse_array: + self._sparse_array.pop(complete_index) + else: + index = self._parse_index(index) + value = _sympify(value) + if value == 0 and index in self._sparse_array: + self._sparse_array.pop(index) + else: + self._sparse_array[index] = value + + def as_immutable(self): + return ImmutableSparseNDimArray(self) + + @property + def free_symbols(self): + return {i for j in self._sparse_array.values() for i in j.free_symbols} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eb964e7ceae559ef04255f330f9fc3cd1bf54ea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_array_comprehension.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_array_comprehension.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0420cad49e50fa771ed82d49776cbfd22a74d9ef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_array_comprehension.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_array_derivatives.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_array_derivatives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eee96290b40f368844cafa0231cedb5845d2225 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_array_derivatives.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_arrayop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_arrayop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c83b2851059c0a9e6b1ea15bfda1d7659f4c76b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_arrayop.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_immutable_ndim_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_immutable_ndim_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eb41a2defe9ffbe72758e012a414cd6a4b3b0c5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_immutable_ndim_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_mutable_ndim_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_mutable_ndim_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f9d290ae41023c237129e2abb79654db397adfb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_mutable_ndim_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6703f0f201a6eda219698bc860e51251857becff Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array_conversions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array_conversions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..223beb0c8c3dd4edd72d4d84905932bcaf223041 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array_conversions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_array_comprehension.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_array_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..510e068f287fa04419712e5e9a16a314e522a62d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_array_comprehension.py @@ -0,0 +1,78 @@ +from sympy.tensor.array.array_comprehension import ArrayComprehension, ArrayComprehensionMap +from sympy.tensor.array import ImmutableDenseNDimArray +from sympy.abc import i, j, k, l +from sympy.testing.pytest import raises +from sympy.matrices import Matrix + + +def test_array_comprehension(): + a = ArrayComprehension(i*j, (i, 1, 3), (j, 2, 4)) + b = ArrayComprehension(i, (i, 1, j+1)) + c = ArrayComprehension(i+j+k+l, (i, 1, 2), (j, 1, 3), (k, 1, 4), (l, 1, 5)) + d = ArrayComprehension(k, (i, 1, 5)) + e = ArrayComprehension(i, (j, k+1, k+5)) + assert a.doit().tolist() == [[2, 3, 4], [4, 6, 8], [6, 9, 12]] + assert a.shape == (3, 3) + assert a.is_shape_numeric == True + assert a.tolist() == [[2, 3, 4], [4, 6, 8], [6, 9, 12]] + assert a.tomatrix() == Matrix([ + [2, 3, 4], + [4, 6, 8], + [6, 9, 12]]) + assert len(a) == 9 + assert isinstance(b.doit(), ArrayComprehension) + assert isinstance(a.doit(), ImmutableDenseNDimArray) + assert b.subs(j, 3) == ArrayComprehension(i, (i, 1, 4)) + assert b.free_symbols == {j} + assert b.shape == (j + 1,) + assert b.rank() == 1 + assert b.is_shape_numeric == False + assert c.free_symbols == set() + assert c.function == i + j + k + l + assert c.limits == ((i, 1, 2), (j, 1, 3), (k, 1, 4), (l, 1, 5)) + assert c.doit().tolist() == [[[[4, 5, 6, 7, 8], [5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11]], + [[5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12]], + [[6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13]]], + [[[5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12]], + [[6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13]], + [[7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13], [10, 11, 12, 13, 14]]]] + assert c.free_symbols == set() + assert c.variables == [i, j, k, l] + assert c.bound_symbols == [i, j, k, l] + assert d.doit().tolist() == [k, k, k, k, k] + assert len(e) == 5 + raises(TypeError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, [1, 3, 2]))) + raises(ValueError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, 1))) + raises(ValueError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, j+1))) + raises(ValueError, lambda: len(ArrayComprehension(i*j, (i, 1, 3), (j, 2, j+4)))) + raises(TypeError, lambda: ArrayComprehension(i*j, (i, 0, i + 1.5), (j, 0, 2))) + raises(ValueError, lambda: b.tolist()) + raises(ValueError, lambda: b.tomatrix()) + raises(ValueError, lambda: c.tomatrix()) + +def test_arraycomprehensionmap(): + a = ArrayComprehensionMap(lambda i: i+1, (i, 1, 5)) + assert a.doit().tolist() == [2, 3, 4, 5, 6] + assert a.shape == (5,) + assert a.is_shape_numeric + assert a.tolist() == [2, 3, 4, 5, 6] + assert len(a) == 5 + assert isinstance(a.doit(), ImmutableDenseNDimArray) + expr = ArrayComprehensionMap(lambda i: i+1, (i, 1, k)) + assert expr.doit() == expr + assert expr.subs(k, 4) == ArrayComprehensionMap(lambda i: i+1, (i, 1, 4)) + assert expr.subs(k, 4).doit() == ImmutableDenseNDimArray([2, 3, 4, 5]) + b = ArrayComprehensionMap(lambda i: i+1, (i, 1, 2), (i, 1, 3), (i, 1, 4), (i, 1, 5)) + assert b.doit().tolist() == [[[[2, 3, 4, 5, 6], [3, 5, 7, 9, 11], [4, 7, 10, 13, 16], [5, 9, 13, 17, 21]], + [[3, 5, 7, 9, 11], [5, 9, 13, 17, 21], [7, 13, 19, 25, 31], [9, 17, 25, 33, 41]], + [[4, 7, 10, 13, 16], [7, 13, 19, 25, 31], [10, 19, 28, 37, 46], [13, 25, 37, 49, 61]]], + [[[3, 5, 7, 9, 11], [5, 9, 13, 17, 21], [7, 13, 19, 25, 31], [9, 17, 25, 33, 41]], + [[5, 9, 13, 17, 21], [9, 17, 25, 33, 41], [13, 25, 37, 49, 61], [17, 33, 49, 65, 81]], + [[7, 13, 19, 25, 31], [13, 25, 37, 49, 61], [19, 37, 55, 73, 91], [25, 49, 73, 97, 121]]]] + + # tests about lambda expression + assert ArrayComprehensionMap(lambda: 3, (i, 1, 5)).doit().tolist() == [3, 3, 3, 3, 3] + assert ArrayComprehensionMap(lambda i: i+1, (i, 1, 5)).doit().tolist() == [2, 3, 4, 5, 6] + raises(ValueError, lambda: ArrayComprehensionMap(i*j, (i, 1, 3), (j, 2, 4))) + a = ArrayComprehensionMap(lambda i, j: i+j, (i, 1, 5)) + raises(ValueError, lambda: a.doit()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_array_derivatives.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_array_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..7f6c777c55a9170704f309bf74387d140bf2ec32 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_array_derivatives.py @@ -0,0 +1,52 @@ +from sympy.core.symbol import symbols +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array.ndim_array import NDimArray +from sympy.matrices.matrixbase import MatrixBase +from sympy.tensor.array.array_derivatives import ArrayDerivative + +x, y, z, t = symbols("x y z t") + +m = Matrix([[x, y], [z, t]]) + +M = MatrixSymbol("M", 3, 2) +N = MatrixSymbol("N", 4, 3) + + +def test_array_derivative_construction(): + + d = ArrayDerivative(x, m, evaluate=False) + assert d.shape == (2, 2) + expr = d.doit() + assert isinstance(expr, MatrixBase) + assert expr.shape == (2, 2) + + d = ArrayDerivative(m, m, evaluate=False) + assert d.shape == (2, 2, 2, 2) + expr = d.doit() + assert isinstance(expr, NDimArray) + assert expr.shape == (2, 2, 2, 2) + + d = ArrayDerivative(m, x, evaluate=False) + assert d.shape == (2, 2) + expr = d.doit() + assert isinstance(expr, MatrixBase) + assert expr.shape == (2, 2) + + d = ArrayDerivative(M, N, evaluate=False) + assert d.shape == (4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, ArrayDerivative) + assert expr.shape == (4, 3, 3, 2) + + d = ArrayDerivative(M, (N, 2), evaluate=False) + assert d.shape == (4, 3, 4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, ArrayDerivative) + assert expr.shape == (4, 3, 4, 3, 3, 2) + + d = ArrayDerivative(M.as_explicit(), (N.as_explicit(), 2), evaluate=False) + assert d.doit().shape == (4, 3, 4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, NDimArray) + assert expr.shape == (4, 3, 4, 3, 3, 2) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_arrayop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_arrayop.py new file mode 100644 index 0000000000000000000000000000000000000000..de56e81e0064f1e303a7a58e41932d15f2d0b41e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_arrayop.py @@ -0,0 +1,361 @@ +import itertools +import random + +from sympy.combinatorics import Permutation +from sympy.combinatorics.permutations import _af_invert +from sympy.testing.pytest import raises + +from sympy.core.function import diff +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import (adjoint, conjugate, transpose) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.tensor.array import Array, ImmutableDenseNDimArray, ImmutableSparseNDimArray, MutableSparseNDimArray + +from sympy.tensor.array.arrayop import tensorproduct, tensorcontraction, derive_by_array, permutedims, Flatten, \ + tensordiagonal + + +def test_import_NDimArray(): + from sympy.tensor.array import NDimArray + del NDimArray + + +def test_tensorproduct(): + x,y,z,t = symbols('x y z t') + from sympy.abc import a,b,c,d + assert tensorproduct() == 1 + assert tensorproduct([x]) == Array([x]) + assert tensorproduct([x], [y]) == Array([[x*y]]) + assert tensorproduct([x], [y], [z]) == Array([[[x*y*z]]]) + assert tensorproduct([x], [y], [z], [t]) == Array([[[[x*y*z*t]]]]) + + assert tensorproduct(x) == x + assert tensorproduct(x, y) == x*y + assert tensorproduct(x, y, z) == x*y*z + assert tensorproduct(x, y, z, t) == x*y*z*t + + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + A = ArrayType([x, y]) + B = ArrayType([1, 2, 3]) + C = ArrayType([a, b, c, d]) + + assert tensorproduct(A, B, C) == ArrayType([[[a*x, b*x, c*x, d*x], [2*a*x, 2*b*x, 2*c*x, 2*d*x], [3*a*x, 3*b*x, 3*c*x, 3*d*x]], + [[a*y, b*y, c*y, d*y], [2*a*y, 2*b*y, 2*c*y, 2*d*y], [3*a*y, 3*b*y, 3*c*y, 3*d*y]]]) + + assert tensorproduct([x, y], [1, 2, 3]) == tensorproduct(A, B) + + assert tensorproduct(A, 2) == ArrayType([2*x, 2*y]) + assert tensorproduct(A, [2]) == ArrayType([[2*x], [2*y]]) + assert tensorproduct([2], A) == ArrayType([[2*x, 2*y]]) + assert tensorproduct(a, A) == ArrayType([a*x, a*y]) + assert tensorproduct(a, A, B) == ArrayType([[a*x, 2*a*x, 3*a*x], [a*y, 2*a*y, 3*a*y]]) + assert tensorproduct(A, B, a) == ArrayType([[a*x, 2*a*x, 3*a*x], [a*y, 2*a*y, 3*a*y]]) + assert tensorproduct(B, a, A) == ArrayType([[a*x, a*y], [2*a*x, 2*a*y], [3*a*x, 3*a*y]]) + + # tests for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + a = SparseArrayType({1:2, 3:4},(1000, 2000)) + b = SparseArrayType({1:2, 3:4},(1000, 2000)) + assert tensorproduct(a, b) == ImmutableSparseNDimArray({2000001: 4, 2000003: 8, 6000001: 8, 6000003: 16}, (1000, 2000, 1000, 2000)) + + +def test_tensorcontraction(): + from sympy.abc import a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x + B = Array(range(18), (2, 3, 3)) + assert tensorcontraction(B, (1, 2)) == Array([12, 39]) + C1 = Array([a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x], (2, 3, 2, 2)) + + assert tensorcontraction(C1, (0, 2)) == Array([[a + o, b + p], [e + s, f + t], [i + w, j + x]]) + assert tensorcontraction(C1, (0, 2, 3)) == Array([a + p, e + t, i + x]) + assert tensorcontraction(C1, (2, 3)) == Array([[a + d, e + h, i + l], [m + p, q + t, u + x]]) + + +def test_derivative_by_array(): + from sympy.abc import i, j, t, x, y, z + + bexpr = x*y**2*exp(z)*log(t) + sexpr = sin(bexpr) + cexpr = cos(bexpr) + + a = Array([sexpr]) + + assert derive_by_array(sexpr, t) == x*y**2*exp(z)*cos(x*y**2*exp(z)*log(t))/t + assert derive_by_array(sexpr, [x, y, z]) == Array([bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr, bexpr*cexpr]) + assert derive_by_array(a, [x, y, z]) == Array([[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr], [bexpr*cexpr]]) + + assert derive_by_array(sexpr, [[x, y], [z, t]]) == Array([[bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr], [bexpr*cexpr, bexpr/log(t)/t*cexpr]]) + assert derive_by_array(a, [[x, y], [z, t]]) == Array([[[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr]], [[bexpr*cexpr], [bexpr/log(t)/t*cexpr]]]) + assert derive_by_array([[x, y], [z, t]], [x, y]) == Array([[[1, 0], [0, 0]], [[0, 1], [0, 0]]]) + assert derive_by_array([[x, y], [z, t]], [[x, y], [z, t]]) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], + [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + assert diff(sexpr, t) == x*y**2*exp(z)*cos(x*y**2*exp(z)*log(t))/t + assert diff(sexpr, Array([x, y, z])) == Array([bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr, bexpr*cexpr]) + assert diff(a, Array([x, y, z])) == Array([[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr], [bexpr*cexpr]]) + + assert diff(sexpr, Array([[x, y], [z, t]])) == Array([[bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr], [bexpr*cexpr, bexpr/log(t)/t*cexpr]]) + assert diff(a, Array([[x, y], [z, t]])) == Array([[[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr]], [[bexpr*cexpr], [bexpr/log(t)/t*cexpr]]]) + assert diff(Array([[x, y], [z, t]]), Array([x, y])) == Array([[[1, 0], [0, 0]], [[0, 1], [0, 0]]]) + assert diff(Array([[x, y], [z, t]]), Array([[x, y], [z, t]])) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], + [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # test for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + b = MutableSparseNDimArray({0:i, 1:j}, (10000, 20000)) + assert derive_by_array(b, i) == ImmutableSparseNDimArray({0: 1}, (10000, 20000)) + assert derive_by_array(b, (i, j)) == ImmutableSparseNDimArray({0: 1, 200000001: 1}, (2, 10000, 20000)) + + #https://github.com/sympy/sympy/issues/20655 + U = Array([x, y, z]) + E = 2 + assert derive_by_array(E, U) == ImmutableDenseNDimArray([0, 0, 0]) + + +def test_issue_emerged_while_discussing_10972(): + ua = Array([-1,0]) + Fa = Array([[0, 1], [-1, 0]]) + po = tensorproduct(Fa, ua, Fa, ua) + assert tensorcontraction(po, (1, 2), (4, 5)) == Array([[0, 0], [0, 1]]) + + sa = symbols('a0:144') + po = Array(sa, [2, 2, 3, 3, 2, 2]) + assert tensorcontraction(po, (0, 1), (2, 3), (4, 5)) == sa[0] + sa[108] + sa[111] + sa[124] + sa[127] + sa[140] + sa[143] + sa[16] + sa[19] + sa[3] + sa[32] + sa[35] + assert tensorcontraction(po, (0, 1, 4, 5), (2, 3)) == sa[0] + sa[111] + sa[127] + sa[143] + sa[16] + sa[32] + assert tensorcontraction(po, (0, 1), (4, 5)) == Array([[sa[0] + sa[108] + sa[111] + sa[3], sa[112] + sa[115] + sa[4] + sa[7], + sa[11] + sa[116] + sa[119] + sa[8]], [sa[12] + sa[120] + sa[123] + sa[15], + sa[124] + sa[127] + sa[16] + sa[19], sa[128] + sa[131] + sa[20] + sa[23]], + [sa[132] + sa[135] + sa[24] + sa[27], sa[136] + sa[139] + sa[28] + sa[31], + sa[140] + sa[143] + sa[32] + sa[35]]]) + assert tensorcontraction(po, (0, 1), (2, 3)) == Array([[sa[0] + sa[108] + sa[124] + sa[140] + sa[16] + sa[32], sa[1] + sa[109] + sa[125] + sa[141] + sa[17] + sa[33]], + [sa[110] + sa[126] + sa[142] + sa[18] + sa[2] + sa[34], sa[111] + sa[127] + sa[143] + sa[19] + sa[3] + sa[35]]]) + + +def test_array_permutedims(): + sa = symbols('a0:144') + + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + m1 = ArrayType(sa[:6], (2, 3)) + assert permutedims(m1, (1, 0)) == transpose(m1) + assert m1.tomatrix().T == permutedims(m1, (1, 0)).tomatrix() + + assert m1.tomatrix().T == transpose(m1).tomatrix() + assert m1.tomatrix().C == conjugate(m1).tomatrix() + assert m1.tomatrix().H == adjoint(m1).tomatrix() + + assert m1.tomatrix().T == m1.transpose().tomatrix() + assert m1.tomatrix().C == m1.conjugate().tomatrix() + assert m1.tomatrix().H == m1.adjoint().tomatrix() + + raises(ValueError, lambda: permutedims(m1, (0,))) + raises(ValueError, lambda: permutedims(m1, (0, 0))) + raises(ValueError, lambda: permutedims(m1, (1, 2, 0))) + + # Some tests with random arrays: + dims = 6 + shape = [random.randint(1,5) for i in range(dims)] + elems = [random.random() for i in range(tensorproduct(*shape))] + ra = ArrayType(elems, shape) + perm = list(range(dims)) + # Randomize the permutation: + random.shuffle(perm) + # Test inverse permutation: + assert permutedims(permutedims(ra, perm), _af_invert(perm)) == ra + # Test that permuted shape corresponds to action by `Permutation`: + assert permutedims(ra, perm).shape == tuple(Permutation(perm)(shape)) + + z = ArrayType.zeros(4,5,6,7) + + assert permutedims(z, (2, 3, 1, 0)).shape == (6, 7, 5, 4) + assert permutedims(z, [2, 3, 1, 0]).shape == (6, 7, 5, 4) + assert permutedims(z, Permutation([2, 3, 1, 0])).shape == (6, 7, 5, 4) + + po = ArrayType(sa, [2, 2, 3, 3, 2, 2]) + + raises(ValueError, lambda: permutedims(po, (1, 1))) + raises(ValueError, lambda: po.transpose()) + raises(ValueError, lambda: po.adjoint()) + + assert permutedims(po, reversed(range(po.rank()))) == ArrayType( + [[[[[[sa[0], sa[72]], [sa[36], sa[108]]], [[sa[12], sa[84]], [sa[48], sa[120]]], [[sa[24], + sa[96]], [sa[60], sa[132]]]], + [[[sa[4], sa[76]], [sa[40], sa[112]]], [[sa[16], + sa[88]], [sa[52], sa[124]]], + [[sa[28], sa[100]], [sa[64], sa[136]]]], + [[[sa[8], + sa[80]], [sa[44], sa[116]]], [[sa[20], sa[92]], [sa[56], sa[128]]], [[sa[32], + sa[104]], [sa[68], sa[140]]]]], + [[[[sa[2], sa[74]], [sa[38], sa[110]]], [[sa[14], + sa[86]], [sa[50], sa[122]]], [[sa[26], sa[98]], [sa[62], sa[134]]]], + [[[sa[6], + sa[78]], [sa[42], sa[114]]], [[sa[18], sa[90]], [sa[54], sa[126]]], [[sa[30], + sa[102]], [sa[66], sa[138]]]], + [[[sa[10], sa[82]], [sa[46], sa[118]]], [[sa[22], + sa[94]], [sa[58], sa[130]]], + [[sa[34], sa[106]], [sa[70], sa[142]]]]]], + [[[[[sa[1], + sa[73]], [sa[37], sa[109]]], [[sa[13], sa[85]], [sa[49], sa[121]]], [[sa[25], + sa[97]], [sa[61], sa[133]]]], + [[[sa[5], sa[77]], [sa[41], sa[113]]], [[sa[17], + sa[89]], [sa[53], sa[125]]], + [[sa[29], sa[101]], [sa[65], sa[137]]]], + [[[sa[9], + sa[81]], [sa[45], sa[117]]], [[sa[21], sa[93]], [sa[57], sa[129]]], [[sa[33], + sa[105]], [sa[69], sa[141]]]]], + [[[[sa[3], sa[75]], [sa[39], sa[111]]], [[sa[15], + sa[87]], [sa[51], sa[123]]], [[sa[27], sa[99]], [sa[63], sa[135]]]], + [[[sa[7], + sa[79]], [sa[43], sa[115]]], [[sa[19], sa[91]], [sa[55], sa[127]]], [[sa[31], + sa[103]], [sa[67], sa[139]]]], + [[[sa[11], sa[83]], [sa[47], sa[119]]], [[sa[23], + sa[95]], [sa[59], sa[131]]], + [[sa[35], sa[107]], [sa[71], sa[143]]]]]]]) + + assert permutedims(po, (1, 0, 2, 3, 4, 5)) == ArrayType( + [[[[[[sa[0], sa[1]], [sa[2], sa[3]]], [[sa[4], sa[5]], [sa[6], sa[7]]], [[sa[8], sa[9]], [sa[10], + sa[11]]]], + [[[sa[12], sa[13]], [sa[14], sa[15]]], [[sa[16], sa[17]], [sa[18], + sa[19]]], [[sa[20], sa[21]], [sa[22], sa[23]]]], + [[[sa[24], sa[25]], [sa[26], + sa[27]]], [[sa[28], sa[29]], [sa[30], sa[31]]], [[sa[32], sa[33]], [sa[34], + sa[35]]]]], + [[[[sa[72], sa[73]], [sa[74], sa[75]]], [[sa[76], sa[77]], [sa[78], + sa[79]]], [[sa[80], sa[81]], [sa[82], sa[83]]]], + [[[sa[84], sa[85]], [sa[86], + sa[87]]], [[sa[88], sa[89]], [sa[90], sa[91]]], [[sa[92], sa[93]], [sa[94], + sa[95]]]], + [[[sa[96], sa[97]], [sa[98], sa[99]]], [[sa[100], sa[101]], [sa[102], + sa[103]]], + [[sa[104], sa[105]], [sa[106], sa[107]]]]]], [[[[[sa[36], sa[37]], [sa[38], + sa[39]]], + [[sa[40], sa[41]], [sa[42], sa[43]]], + [[sa[44], sa[45]], [sa[46], + sa[47]]]], + [[[sa[48], sa[49]], [sa[50], sa[51]]], + [[sa[52], sa[53]], [sa[54], + sa[55]]], + [[sa[56], sa[57]], [sa[58], sa[59]]]], + [[[sa[60], sa[61]], [sa[62], + sa[63]]], + [[sa[64], sa[65]], [sa[66], sa[67]]], + [[sa[68], sa[69]], [sa[70], + sa[71]]]]], [ + [[[sa[108], sa[109]], [sa[110], sa[111]]], + [[sa[112], sa[113]], [sa[114], + sa[115]]], + [[sa[116], sa[117]], [sa[118], sa[119]]]], + [[[sa[120], sa[121]], [sa[122], + sa[123]]], + [[sa[124], sa[125]], [sa[126], sa[127]]], + [[sa[128], sa[129]], [sa[130], + sa[131]]]], + [[[sa[132], sa[133]], [sa[134], sa[135]]], + [[sa[136], sa[137]], [sa[138], + sa[139]]], + [[sa[140], sa[141]], [sa[142], sa[143]]]]]]]) + + assert permutedims(po, (0, 2, 1, 4, 3, 5)) == ArrayType( + [[[[[[sa[0], sa[1]], [sa[4], sa[5]], [sa[8], sa[9]]], [[sa[2], sa[3]], [sa[6], sa[7]], [sa[10], + sa[11]]]], + [[[sa[36], sa[37]], [sa[40], sa[41]], [sa[44], sa[45]]], [[sa[38], + sa[39]], [sa[42], sa[43]], [sa[46], sa[47]]]]], + [[[[sa[12], sa[13]], [sa[16], + sa[17]], [sa[20], sa[21]]], [[sa[14], sa[15]], [sa[18], sa[19]], [sa[22], + sa[23]]]], + [[[sa[48], sa[49]], [sa[52], sa[53]], [sa[56], sa[57]]], [[sa[50], + sa[51]], [sa[54], sa[55]], [sa[58], sa[59]]]]], + [[[[sa[24], sa[25]], [sa[28], + sa[29]], [sa[32], sa[33]]], [[sa[26], sa[27]], [sa[30], sa[31]], [sa[34], + sa[35]]]], + [[[sa[60], sa[61]], [sa[64], sa[65]], [sa[68], sa[69]]], [[sa[62], + sa[63]], [sa[66], sa[67]], [sa[70], sa[71]]]]]], + [[[[[sa[72], sa[73]], [sa[76], + sa[77]], [sa[80], sa[81]]], [[sa[74], sa[75]], [sa[78], sa[79]], [sa[82], + sa[83]]]], + [[[sa[108], sa[109]], [sa[112], sa[113]], [sa[116], sa[117]]], [[sa[110], + sa[111]], [sa[114], sa[115]], + [sa[118], sa[119]]]]], + [[[[sa[84], sa[85]], [sa[88], + sa[89]], [sa[92], sa[93]]], [[sa[86], sa[87]], [sa[90], sa[91]], [sa[94], + sa[95]]]], + [[[sa[120], sa[121]], [sa[124], sa[125]], [sa[128], sa[129]]], [[sa[122], + sa[123]], [sa[126], sa[127]], + [sa[130], sa[131]]]]], + [[[[sa[96], sa[97]], [sa[100], + sa[101]], [sa[104], sa[105]]], [[sa[98], sa[99]], [sa[102], sa[103]], [sa[106], + sa[107]]]], + [[[sa[132], sa[133]], [sa[136], sa[137]], [sa[140], sa[141]]], [[sa[134], + sa[135]], [sa[138], sa[139]], + [sa[142], sa[143]]]]]]]) + + po2 = po.reshape(4, 9, 2, 2) + assert po2 == ArrayType([[[[sa[0], sa[1]], [sa[2], sa[3]]], [[sa[4], sa[5]], [sa[6], sa[7]]], [[sa[8], sa[9]], [sa[10], sa[11]]], [[sa[12], sa[13]], [sa[14], sa[15]]], [[sa[16], sa[17]], [sa[18], sa[19]]], [[sa[20], sa[21]], [sa[22], sa[23]]], [[sa[24], sa[25]], [sa[26], sa[27]]], [[sa[28], sa[29]], [sa[30], sa[31]]], [[sa[32], sa[33]], [sa[34], sa[35]]]], [[[sa[36], sa[37]], [sa[38], sa[39]]], [[sa[40], sa[41]], [sa[42], sa[43]]], [[sa[44], sa[45]], [sa[46], sa[47]]], [[sa[48], sa[49]], [sa[50], sa[51]]], [[sa[52], sa[53]], [sa[54], sa[55]]], [[sa[56], sa[57]], [sa[58], sa[59]]], [[sa[60], sa[61]], [sa[62], sa[63]]], [[sa[64], sa[65]], [sa[66], sa[67]]], [[sa[68], sa[69]], [sa[70], sa[71]]]], [[[sa[72], sa[73]], [sa[74], sa[75]]], [[sa[76], sa[77]], [sa[78], sa[79]]], [[sa[80], sa[81]], [sa[82], sa[83]]], [[sa[84], sa[85]], [sa[86], sa[87]]], [[sa[88], sa[89]], [sa[90], sa[91]]], [[sa[92], sa[93]], [sa[94], sa[95]]], [[sa[96], sa[97]], [sa[98], sa[99]]], [[sa[100], sa[101]], [sa[102], sa[103]]], [[sa[104], sa[105]], [sa[106], sa[107]]]], [[[sa[108], sa[109]], [sa[110], sa[111]]], [[sa[112], sa[113]], [sa[114], sa[115]]], [[sa[116], sa[117]], [sa[118], sa[119]]], [[sa[120], sa[121]], [sa[122], sa[123]]], [[sa[124], sa[125]], [sa[126], sa[127]]], [[sa[128], sa[129]], [sa[130], sa[131]]], [[sa[132], sa[133]], [sa[134], sa[135]]], [[sa[136], sa[137]], [sa[138], sa[139]]], [[sa[140], sa[141]], [sa[142], sa[143]]]]]) + + assert permutedims(po2, (3, 2, 0, 1)) == ArrayType([[[[sa[0], sa[4], sa[8], sa[12], sa[16], sa[20], sa[24], sa[28], sa[32]], [sa[36], sa[40], sa[44], sa[48], sa[52], sa[56], sa[60], sa[64], sa[68]], [sa[72], sa[76], sa[80], sa[84], sa[88], sa[92], sa[96], sa[100], sa[104]], [sa[108], sa[112], sa[116], sa[120], sa[124], sa[128], sa[132], sa[136], sa[140]]], [[sa[2], sa[6], sa[10], sa[14], sa[18], sa[22], sa[26], sa[30], sa[34]], [sa[38], sa[42], sa[46], sa[50], sa[54], sa[58], sa[62], sa[66], sa[70]], [sa[74], sa[78], sa[82], sa[86], sa[90], sa[94], sa[98], sa[102], sa[106]], [sa[110], sa[114], sa[118], sa[122], sa[126], sa[130], sa[134], sa[138], sa[142]]]], [[[sa[1], sa[5], sa[9], sa[13], sa[17], sa[21], sa[25], sa[29], sa[33]], [sa[37], sa[41], sa[45], sa[49], sa[53], sa[57], sa[61], sa[65], sa[69]], [sa[73], sa[77], sa[81], sa[85], sa[89], sa[93], sa[97], sa[101], sa[105]], [sa[109], sa[113], sa[117], sa[121], sa[125], sa[129], sa[133], sa[137], sa[141]]], [[sa[3], sa[7], sa[11], sa[15], sa[19], sa[23], sa[27], sa[31], sa[35]], [sa[39], sa[43], sa[47], sa[51], sa[55], sa[59], sa[63], sa[67], sa[71]], [sa[75], sa[79], sa[83], sa[87], sa[91], sa[95], sa[99], sa[103], sa[107]], [sa[111], sa[115], sa[119], sa[123], sa[127], sa[131], sa[135], sa[139], sa[143]]]]]) + + # test for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + A = SparseArrayType({1:1, 10000:2}, (10000, 20000, 10000)) + assert permutedims(A, (0, 1, 2)) == A + assert permutedims(A, (1, 0, 2)) == SparseArrayType({1: 1, 100000000: 2}, (20000, 10000, 10000)) + B = SparseArrayType({1:1, 20000:2}, (10000, 20000)) + assert B.transpose() == SparseArrayType({10000: 1, 1: 2}, (20000, 10000)) + + +def test_permutedims_with_indices(): + A = Array(range(32)).reshape(2, 2, 2, 2, 2) + indices_new = list("abcde") + indices_old = list("ebdac") + new_A = permutedims(A, index_order_new=indices_new, index_order_old=indices_old) + for a, b, c, d, e in itertools.product(range(2), range(2), range(2), range(2), range(2)): + assert new_A[a, b, c, d, e] == A[e, b, d, a, c] + indices_old = list("cabed") + new_A = permutedims(A, index_order_new=indices_new, index_order_old=indices_old) + for a, b, c, d, e in itertools.product(range(2), range(2), range(2), range(2), range(2)): + assert new_A[a, b, c, d, e] == A[c, a, b, e, d] + raises(ValueError, lambda: permutedims(A, index_order_old=list("aacde"), index_order_new=list("abcde"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abcde"), index_order_new=list("abcce"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abcde"), index_order_new=list("abce"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abce"), index_order_new=list("abce"))) + raises(ValueError, lambda: permutedims(A, [2, 1, 0, 3, 4], index_order_old=list("abcde"))) + raises(ValueError, lambda: permutedims(A, [2, 1, 0, 3, 4], index_order_new=list("abcde"))) + + +def test_flatten(): + from sympy.matrices.dense import Matrix + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray, Matrix]: + A = ArrayType(range(24)).reshape(4, 6) + assert list(Flatten(A)) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + + for i, v in enumerate(Flatten(A)): + assert i == v + + +def test_tensordiagonal(): + from sympy.matrices.dense import eye + expr = Array(range(9)).reshape(3, 3) + raises(ValueError, lambda: tensordiagonal(expr, [0], [1])) + raises(ValueError, lambda: tensordiagonal(expr, [0, 0])) + assert tensordiagonal(eye(3), [0, 1]) == Array([1, 1, 1]) + assert tensordiagonal(expr, [0, 1]) == Array([0, 4, 8]) + x, y, z = symbols("x y z") + expr2 = tensorproduct([x, y, z], expr) + assert tensordiagonal(expr2, [1, 2]) == Array([[0, 4*x, 8*x], [0, 4*y, 8*y], [0, 4*z, 8*z]]) + assert tensordiagonal(expr2, [0, 1]) == Array([[0, 3*y, 6*z], [x, 4*y, 7*z], [2*x, 5*y, 8*z]]) + assert tensordiagonal(expr2, [0, 1, 2]) == Array([0, 4*y, 8*z]) + # assert tensordiagonal(expr2, [0]) == permutedims(expr2, [1, 2, 0]) + # assert tensordiagonal(expr2, [1]) == permutedims(expr2, [0, 2, 1]) + # assert tensordiagonal(expr2, [2]) == expr2 + # assert tensordiagonal(expr2, [1], [2]) == expr2 + # assert tensordiagonal(expr2, [0], [1]) == permutedims(expr2, [2, 0, 1]) + + a, b, c, X, Y, Z = symbols("a b c X Y Z") + expr3 = tensorproduct([x, y, z], [1, 2, 3], [a, b, c], [X, Y, Z]) + assert tensordiagonal(expr3, [0, 1, 2, 3]) == Array([x*a*X, 2*y*b*Y, 3*z*c*Z]) + assert tensordiagonal(expr3, [0, 1], [2, 3]) == tensorproduct([x, 2*y, 3*z], [a*X, b*Y, c*Z]) + + # assert tensordiagonal(expr3, [0], [1, 2], [3]) == tensorproduct([x, y, z], [a, 2*b, 3*c], [X, Y, Z]) + assert tensordiagonal(tensordiagonal(expr3, [2, 3]), [0, 1]) == tensorproduct([a*X, b*Y, c*Z], [x, 2*y, 3*z]) + + raises(ValueError, lambda: tensordiagonal([[1, 2, 3], [4, 5, 6]], [0, 1])) + raises(ValueError, lambda: tensordiagonal(expr3.reshape(3, 3, 9), [1, 2])) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..c6bed4b605c424284b4752592b03b13a9178aac8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py @@ -0,0 +1,452 @@ +from copy import copy + +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.core.containers import Dict +from sympy.core.function import diff +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.matrices import SparseMatrix +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.matrices import Matrix +from sympy.tensor.array.sparse_ndim_array import ImmutableSparseNDimArray +from sympy.testing.pytest import raises + + +def test_ndim_array_initiation(): + arr_with_no_elements = ImmutableDenseNDimArray([], shape=(0,)) + assert len(arr_with_no_elements) == 0 + assert arr_with_no_elements.rank() == 1 + + raises(ValueError, lambda: ImmutableDenseNDimArray([0], shape=(0,))) + raises(ValueError, lambda: ImmutableDenseNDimArray([1, 2, 3], shape=(0,))) + raises(ValueError, lambda: ImmutableDenseNDimArray([], shape=())) + + raises(ValueError, lambda: ImmutableSparseNDimArray([0], shape=(0,))) + raises(ValueError, lambda: ImmutableSparseNDimArray([1, 2, 3], shape=(0,))) + raises(ValueError, lambda: ImmutableSparseNDimArray([], shape=())) + + arr_with_one_element = ImmutableDenseNDimArray([23]) + assert len(arr_with_one_element) == 1 + assert arr_with_one_element[0] == 23 + assert arr_with_one_element[:] == ImmutableDenseNDimArray([23]) + assert arr_with_one_element.rank() == 1 + + arr_with_symbol_element = ImmutableDenseNDimArray([Symbol('x')]) + assert len(arr_with_symbol_element) == 1 + assert arr_with_symbol_element[0] == Symbol('x') + assert arr_with_symbol_element[:] == ImmutableDenseNDimArray([Symbol('x')]) + assert arr_with_symbol_element.rank() == 1 + + number5 = 5 + vector = ImmutableDenseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector.rank() == 1 + + vector = ImmutableSparseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector._sparse_array == Dict() + assert vector.rank() == 1 + + n_dim_array = ImmutableDenseNDimArray(range(3**4), (3, 3, 3, 3,)) + assert len(n_dim_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == (3, 3, 3, 3) + assert n_dim_array.rank() == 4 + + array_shape = (3, 3, 3, 3) + sparse_array = ImmutableSparseNDimArray.zeros(*array_shape) + assert len(sparse_array._sparse_array) == 0 + assert len(sparse_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == array_shape + assert n_dim_array.rank() == 4 + + one_dim_array = ImmutableDenseNDimArray([2, 3, 1]) + assert len(one_dim_array) == 3 + assert one_dim_array.shape == (3,) + assert one_dim_array.rank() == 1 + assert one_dim_array.tolist() == [2, 3, 1] + + shape = (3, 3) + array_with_many_args = ImmutableSparseNDimArray.zeros(*shape) + assert len(array_with_many_args) == 3 * 3 + assert array_with_many_args.shape == shape + assert array_with_many_args[0, 0] == 0 + assert array_with_many_args.rank() == 2 + + shape = (int(3), int(3)) + array_with_long_shape = ImmutableSparseNDimArray.zeros(*shape) + assert len(array_with_long_shape) == 3 * 3 + assert array_with_long_shape.shape == shape + assert array_with_long_shape[int(0), int(0)] == 0 + assert array_with_long_shape.rank() == 2 + + vector_with_long_shape = ImmutableDenseNDimArray(range(5), int(5)) + assert len(vector_with_long_shape) == 5 + assert vector_with_long_shape.shape == (int(5),) + assert vector_with_long_shape.rank() == 1 + raises(ValueError, lambda: vector_with_long_shape[int(5)]) + + from sympy.abc import x + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + rank_zero_array = ArrayType(x) + assert len(rank_zero_array) == 1 + assert rank_zero_array.shape == () + assert rank_zero_array.rank() == 0 + assert rank_zero_array[()] == x + raises(ValueError, lambda: rank_zero_array[0]) + + +def test_reshape(): + array = ImmutableDenseNDimArray(range(50), 50) + assert array.shape == (50,) + assert array.rank() == 1 + + array = array.reshape(5, 5, 2) + assert array.shape == (5, 5, 2) + assert array.rank() == 3 + assert len(array) == 50 + + +def test_getitem(): + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + array = ArrayType(range(24)).reshape(2, 3, 4) + assert array.tolist() == [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + assert array[0] == ArrayType([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + assert array[0, 0] == ArrayType([0, 1, 2, 3]) + value = 0 + for i in range(2): + for j in range(3): + for k in range(4): + assert array[i, j, k] == value + value += 1 + + raises(ValueError, lambda: array[3, 4, 5]) + raises(ValueError, lambda: array[3, 4, 5, 6]) + raises(ValueError, lambda: array[3, 4, 5, 3:4]) + + +def test_iterator(): + array = ImmutableDenseNDimArray(range(4), (2, 2)) + assert array[0] == ImmutableDenseNDimArray([0, 1]) + assert array[1] == ImmutableDenseNDimArray([2, 3]) + + array = array.reshape(4) + j = 0 + for i in array: + assert i == j + j += 1 + + +def test_sparse(): + sparse_array = ImmutableSparseNDimArray([0, 0, 0, 1], (2, 2)) + assert len(sparse_array) == 2 * 2 + # dictionary where all data is, only non-zero entries are actually stored: + assert len(sparse_array._sparse_array) == 1 + + assert sparse_array.tolist() == [[0, 0], [0, 1]] + + for i, j in zip(sparse_array, [[0, 0], [0, 1]]): + assert i == ImmutableSparseNDimArray(j) + + def sparse_assignment(): + sparse_array[0, 0] = 123 + + assert len(sparse_array._sparse_array) == 1 + raises(TypeError, sparse_assignment) + assert len(sparse_array._sparse_array) == 1 + assert sparse_array[0, 0] == 0 + assert sparse_array/0 == ImmutableSparseNDimArray([[S.NaN, S.NaN], [S.NaN, S.ComplexInfinity]], (2, 2)) + + # test for large scale sparse array + # equality test + assert ImmutableSparseNDimArray.zeros(100000, 200000) == ImmutableSparseNDimArray.zeros(100000, 200000) + + # __mul__ and __rmul__ + a = ImmutableSparseNDimArray({200001: 1}, (100000, 200000)) + assert a * 3 == ImmutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert 3 * a == ImmutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert a * 0 == ImmutableSparseNDimArray({}, (100000, 200000)) + assert 0 * a == ImmutableSparseNDimArray({}, (100000, 200000)) + + # __truediv__ + assert a/3 == ImmutableSparseNDimArray({200001: Rational(1, 3)}, (100000, 200000)) + + # __neg__ + assert -a == ImmutableSparseNDimArray({200001: -1}, (100000, 200000)) + + +def test_calculation(): + + a = ImmutableDenseNDimArray([1]*9, (3, 3)) + b = ImmutableDenseNDimArray([9]*9, (3, 3)) + + c = a + b + for i in c: + assert i == ImmutableDenseNDimArray([10, 10, 10]) + + assert c == ImmutableDenseNDimArray([10]*9, (3, 3)) + assert c == ImmutableSparseNDimArray([10]*9, (3, 3)) + + c = b - a + for i in c: + assert i == ImmutableDenseNDimArray([8, 8, 8]) + + assert c == ImmutableDenseNDimArray([8]*9, (3, 3)) + assert c == ImmutableSparseNDimArray([8]*9, (3, 3)) + + +def test_ndim_array_converting(): + dense_array = ImmutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + alist = dense_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = dense_array.tomatrix() + assert (isinstance(matrix, Matrix)) + + for i in range(len(dense_array)): + assert dense_array[dense_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == dense_array.shape + + assert ImmutableDenseNDimArray(matrix) == dense_array + assert ImmutableDenseNDimArray(matrix.as_immutable()) == dense_array + assert ImmutableDenseNDimArray(matrix.as_mutable()) == dense_array + + sparse_array = ImmutableSparseNDimArray([1, 2, 3, 4], (2, 2)) + alist = sparse_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = sparse_array.tomatrix() + assert(isinstance(matrix, SparseMatrix)) + + for i in range(len(sparse_array)): + assert sparse_array[sparse_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == sparse_array.shape + + assert ImmutableSparseNDimArray(matrix) == sparse_array + assert ImmutableSparseNDimArray(matrix.as_immutable()) == sparse_array + assert ImmutableSparseNDimArray(matrix.as_mutable()) == sparse_array + + +def test_converting_functions(): + arr_list = [1, 2, 3, 4] + arr_matrix = Matrix(((1, 2), (3, 4))) + + # list + arr_ndim_array = ImmutableDenseNDimArray(arr_list, (2, 2)) + assert (isinstance(arr_ndim_array, ImmutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + + # Matrix + arr_ndim_array = ImmutableDenseNDimArray(arr_matrix) + assert (isinstance(arr_ndim_array, ImmutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + assert arr_matrix.shape == arr_ndim_array.shape + + +def test_equality(): + first_list = [1, 2, 3, 4] + second_list = [1, 2, 3, 4] + third_list = [4, 3, 2, 1] + assert first_list == second_list + assert first_list != third_list + + first_ndim_array = ImmutableDenseNDimArray(first_list, (2, 2)) + second_ndim_array = ImmutableDenseNDimArray(second_list, (2, 2)) + fourth_ndim_array = ImmutableDenseNDimArray(first_list, (2, 2)) + + assert first_ndim_array == second_ndim_array + + def assignment_attempt(a): + a[0, 0] = 0 + + raises(TypeError, lambda: assignment_attempt(second_ndim_array)) + assert first_ndim_array == second_ndim_array + assert first_ndim_array == fourth_ndim_array + + +def test_arithmetic(): + a = ImmutableDenseNDimArray([3 for i in range(9)], (3, 3)) + b = ImmutableDenseNDimArray([7 for i in range(9)], (3, 3)) + + c1 = a + b + c2 = b + a + assert c1 == c2 + + d1 = a - b + d2 = b - a + assert d1 == d2 * (-1) + + e1 = a * 5 + e2 = 5 * a + e3 = copy(a) + e3 *= 5 + assert e1 == e2 == e3 + + f1 = a / 5 + f2 = copy(a) + f2 /= 5 + assert f1 == f2 + assert f1[0, 0] == f1[0, 1] == f1[0, 2] == f1[1, 0] == f1[1, 1] == \ + f1[1, 2] == f1[2, 0] == f1[2, 1] == f1[2, 2] == Rational(3, 5) + + assert type(a) == type(b) == type(c1) == type(c2) == type(d1) == type(d2) \ + == type(e1) == type(e2) == type(e3) == type(f1) + + z0 = -a + assert z0 == ImmutableDenseNDimArray([-3 for i in range(9)], (3, 3)) + + +def test_higher_dimenions(): + m3 = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert m3.tolist() == [[[10, 11, 12, 13], + [14, 15, 16, 17], + [18, 19, 20, 21]], + + [[22, 23, 24, 25], + [26, 27, 28, 29], + [30, 31, 32, 33]]] + + assert m3._get_tuple_index(0) == (0, 0, 0) + assert m3._get_tuple_index(1) == (0, 0, 1) + assert m3._get_tuple_index(4) == (0, 1, 0) + assert m3._get_tuple_index(12) == (1, 0, 0) + + assert str(m3) == '[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]' + + m3_rebuilt = ImmutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]) + assert m3 == m3_rebuilt + + m3_other = ImmutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]], (2, 3, 4)) + + assert m3 == m3_other + + +def test_rebuild_immutable_arrays(): + sparr = ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + densarr = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert sparr == sparr.func(*sparr.args) + assert densarr == densarr.func(*densarr.args) + + +def test_slices(): + md = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert md[:] == ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + assert md[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert md[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert md[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert md[:, :, :] == md + + sd = ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd == ImmutableSparseNDimArray(md) + + assert sd[:] == ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert sd[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert sd[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert sd[:, :, :] == sd + + +def test_diff_and_applyfunc(): + from sympy.abc import x, y, z + md = ImmutableDenseNDimArray([[x, y], [x*z, x*y*z]]) + assert md.diff(x) == ImmutableDenseNDimArray([[1, 0], [z, y*z]]) + assert diff(md, x) == ImmutableDenseNDimArray([[1, 0], [z, y*z]]) + + sd = ImmutableSparseNDimArray(md) + assert sd == ImmutableSparseNDimArray([x, y, x*z, x*y*z], (2, 2)) + assert sd.diff(x) == ImmutableSparseNDimArray([[1, 0], [z, y*z]]) + assert diff(sd, x) == ImmutableSparseNDimArray([[1, 0], [z, y*z]]) + + mdn = md.applyfunc(lambda x: x*3) + assert mdn == ImmutableDenseNDimArray([[3*x, 3*y], [3*x*z, 3*x*y*z]]) + assert md != mdn + + sdn = sd.applyfunc(lambda x: x/2) + assert sdn == ImmutableSparseNDimArray([[x/2, y/2], [x*z/2, x*y*z/2]]) + assert sd != sdn + + sdp = sd.applyfunc(lambda x: x+1) + assert sdp == ImmutableSparseNDimArray([[x + 1, y + 1], [x*z + 1, x*y*z + 1]]) + assert sd != sdp + + +def test_op_priority(): + from sympy.abc import x + md = ImmutableDenseNDimArray([1, 2, 3]) + e1 = (1+x)*md + e2 = md*(1+x) + assert e1 == ImmutableDenseNDimArray([1+x, 2+2*x, 3+3*x]) + assert e1 == e2 + + sd = ImmutableSparseNDimArray([1, 2, 3]) + e3 = (1+x)*sd + e4 = sd*(1+x) + assert e3 == ImmutableDenseNDimArray([1+x, 2+2*x, 3+3*x]) + assert e3 == e4 + + +def test_symbolic_indexing(): + x, y, z, w = symbols("x y z w") + M = ImmutableDenseNDimArray([[x, y], [z, w]]) + i, j = symbols("i, j") + Mij = M[i, j] + assert isinstance(Mij, Indexed) + Ms = ImmutableSparseNDimArray([[2, 3*x], [4, 5]]) + msij = Ms[i, j] + assert isinstance(msij, Indexed) + for oi, oj in [(0, 0), (0, 1), (1, 0), (1, 1)]: + assert Mij.subs({i: oi, j: oj}) == M[oi, oj] + assert msij.subs({i: oi, j: oj}) == Ms[oi, oj] + A = IndexedBase("A", (0, 2)) + assert A[0, 0].subs(A, M) == x + assert A[i, j].subs(A, M) == M[i, j] + assert M[i, j].subs(M, A) == A[i, j] + + assert isinstance(M[3 * i - 2, j], Indexed) + assert M[3 * i - 2, j].subs({i: 1, j: 0}) == M[1, 0] + assert isinstance(M[i, 0], Indexed) + assert M[i, 0].subs(i, 0) == M[0, 0] + assert M[0, i].subs(i, 1) == M[0, 1] + + assert M[i, j].diff(x) == ImmutableDenseNDimArray([[1, 0], [0, 0]])[i, j] + assert Ms[i, j].diff(x) == ImmutableSparseNDimArray([[0, 3], [0, 0]])[i, j] + + Mo = ImmutableDenseNDimArray([1, 2, 3]) + assert Mo[i].subs(i, 1) == 2 + Mos = ImmutableSparseNDimArray([1, 2, 3]) + assert Mos[i].subs(i, 1) == 2 + + raises(ValueError, lambda: M[i, 2]) + raises(ValueError, lambda: M[i, -1]) + raises(ValueError, lambda: M[2, i]) + raises(ValueError, lambda: M[-1, i]) + + raises(ValueError, lambda: Ms[i, 2]) + raises(ValueError, lambda: Ms[i, -1]) + raises(ValueError, lambda: Ms[2, i]) + raises(ValueError, lambda: Ms[-1, i]) + + +def test_issue_12665(): + # Testing Python 3 hash of immutable arrays: + arr = ImmutableDenseNDimArray([1, 2, 3]) + # This should NOT raise an exception: + hash(arr) + + +def test_zeros_without_shape(): + arr = ImmutableDenseNDimArray.zeros() + assert arr == ImmutableDenseNDimArray(0) + +def test_issue_21870(): + a0 = ImmutableDenseNDimArray(0) + assert a0.rank() == 0 + a1 = ImmutableDenseNDimArray(a0) + assert a1.rank() == 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..9a232f399bbc0639d326217975fb0a12e645a984 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py @@ -0,0 +1,374 @@ +from copy import copy + +from sympy.tensor.array.dense_ndim_array import MutableDenseNDimArray +from sympy.core.function import diff +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices import SparseMatrix +from sympy.matrices import Matrix +from sympy.tensor.array.sparse_ndim_array import MutableSparseNDimArray +from sympy.testing.pytest import raises + + +def test_ndim_array_initiation(): + arr_with_one_element = MutableDenseNDimArray([23]) + assert len(arr_with_one_element) == 1 + assert arr_with_one_element[0] == 23 + assert arr_with_one_element.rank() == 1 + raises(ValueError, lambda: arr_with_one_element[1]) + + arr_with_symbol_element = MutableDenseNDimArray([Symbol('x')]) + assert len(arr_with_symbol_element) == 1 + assert arr_with_symbol_element[0] == Symbol('x') + assert arr_with_symbol_element.rank() == 1 + + number5 = 5 + vector = MutableDenseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector.rank() == 1 + raises(ValueError, lambda: arr_with_one_element[5]) + + vector = MutableSparseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector._sparse_array == {} + assert vector.rank() == 1 + + n_dim_array = MutableDenseNDimArray(range(3**4), (3, 3, 3, 3,)) + assert len(n_dim_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == (3, 3, 3, 3) + assert n_dim_array.rank() == 4 + raises(ValueError, lambda: n_dim_array[0, 0, 0, 3]) + raises(ValueError, lambda: n_dim_array[3, 0, 0, 0]) + raises(ValueError, lambda: n_dim_array[3**4]) + + array_shape = (3, 3, 3, 3) + sparse_array = MutableSparseNDimArray.zeros(*array_shape) + assert len(sparse_array._sparse_array) == 0 + assert len(sparse_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == array_shape + assert n_dim_array.rank() == 4 + + one_dim_array = MutableDenseNDimArray([2, 3, 1]) + assert len(one_dim_array) == 3 + assert one_dim_array.shape == (3,) + assert one_dim_array.rank() == 1 + assert one_dim_array.tolist() == [2, 3, 1] + + shape = (3, 3) + array_with_many_args = MutableSparseNDimArray.zeros(*shape) + assert len(array_with_many_args) == 3 * 3 + assert array_with_many_args.shape == shape + assert array_with_many_args[0, 0] == 0 + assert array_with_many_args.rank() == 2 + + shape = (int(3), int(3)) + array_with_long_shape = MutableSparseNDimArray.zeros(*shape) + assert len(array_with_long_shape) == 3 * 3 + assert array_with_long_shape.shape == shape + assert array_with_long_shape[int(0), int(0)] == 0 + assert array_with_long_shape.rank() == 2 + + vector_with_long_shape = MutableDenseNDimArray(range(5), int(5)) + assert len(vector_with_long_shape) == 5 + assert vector_with_long_shape.shape == (int(5),) + assert vector_with_long_shape.rank() == 1 + raises(ValueError, lambda: vector_with_long_shape[int(5)]) + + from sympy.abc import x + for ArrayType in [MutableDenseNDimArray, MutableSparseNDimArray]: + rank_zero_array = ArrayType(x) + assert len(rank_zero_array) == 1 + assert rank_zero_array.shape == () + assert rank_zero_array.rank() == 0 + assert rank_zero_array[()] == x + raises(ValueError, lambda: rank_zero_array[0]) + +def test_sympify(): + from sympy.abc import x, y, z, t + arr = MutableDenseNDimArray([[x, y], [1, z*t]]) + arr_other = sympify(arr) + assert arr_other.shape == (2, 2) + assert arr_other == arr + + +def test_reshape(): + array = MutableDenseNDimArray(range(50), 50) + assert array.shape == (50,) + assert array.rank() == 1 + + array = array.reshape(5, 5, 2) + assert array.shape == (5, 5, 2) + assert array.rank() == 3 + assert len(array) == 50 + + +def test_iterator(): + array = MutableDenseNDimArray(range(4), (2, 2)) + assert array[0] == MutableDenseNDimArray([0, 1]) + assert array[1] == MutableDenseNDimArray([2, 3]) + + array = array.reshape(4) + j = 0 + for i in array: + assert i == j + j += 1 + + +def test_getitem(): + for ArrayType in [MutableDenseNDimArray, MutableSparseNDimArray]: + array = ArrayType(range(24)).reshape(2, 3, 4) + assert array.tolist() == [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + assert array[0] == ArrayType([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + assert array[0, 0] == ArrayType([0, 1, 2, 3]) + value = 0 + for i in range(2): + for j in range(3): + for k in range(4): + assert array[i, j, k] == value + value += 1 + + raises(ValueError, lambda: array[3, 4, 5]) + raises(ValueError, lambda: array[3, 4, 5, 6]) + raises(ValueError, lambda: array[3, 4, 5, 3:4]) + + +def test_sparse(): + sparse_array = MutableSparseNDimArray([0, 0, 0, 1], (2, 2)) + assert len(sparse_array) == 2 * 2 + # dictionary where all data is, only non-zero entries are actually stored: + assert len(sparse_array._sparse_array) == 1 + + assert sparse_array.tolist() == [[0, 0], [0, 1]] + + for i, j in zip(sparse_array, [[0, 0], [0, 1]]): + assert i == MutableSparseNDimArray(j) + + sparse_array[0, 0] = 123 + assert len(sparse_array._sparse_array) == 2 + assert sparse_array[0, 0] == 123 + assert sparse_array/0 == MutableSparseNDimArray([[S.ComplexInfinity, S.NaN], [S.NaN, S.ComplexInfinity]], (2, 2)) + + # when element in sparse array become zero it will disappear from + # dictionary + sparse_array[0, 0] = 0 + assert len(sparse_array._sparse_array) == 1 + sparse_array[1, 1] = 0 + assert len(sparse_array._sparse_array) == 0 + assert sparse_array[0, 0] == 0 + + # test for large scale sparse array + # equality test + a = MutableSparseNDimArray.zeros(100000, 200000) + b = MutableSparseNDimArray.zeros(100000, 200000) + assert a == b + a[1, 1] = 1 + b[1, 1] = 2 + assert a != b + + # __mul__ and __rmul__ + assert a * 3 == MutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert 3 * a == MutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert a * 0 == MutableSparseNDimArray({}, (100000, 200000)) + assert 0 * a == MutableSparseNDimArray({}, (100000, 200000)) + + # __truediv__ + assert a/3 == MutableSparseNDimArray({200001: Rational(1, 3)}, (100000, 200000)) + + # __neg__ + assert -a == MutableSparseNDimArray({200001: -1}, (100000, 200000)) + + +def test_calculation(): + + a = MutableDenseNDimArray([1]*9, (3, 3)) + b = MutableDenseNDimArray([9]*9, (3, 3)) + + c = a + b + for i in c: + assert i == MutableDenseNDimArray([10, 10, 10]) + + assert c == MutableDenseNDimArray([10]*9, (3, 3)) + assert c == MutableSparseNDimArray([10]*9, (3, 3)) + + c = b - a + for i in c: + assert i == MutableSparseNDimArray([8, 8, 8]) + + assert c == MutableDenseNDimArray([8]*9, (3, 3)) + assert c == MutableSparseNDimArray([8]*9, (3, 3)) + + +def test_ndim_array_converting(): + dense_array = MutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + alist = dense_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = dense_array.tomatrix() + assert (isinstance(matrix, Matrix)) + + for i in range(len(dense_array)): + assert dense_array[dense_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == dense_array.shape + + assert MutableDenseNDimArray(matrix) == dense_array + assert MutableDenseNDimArray(matrix.as_immutable()) == dense_array + assert MutableDenseNDimArray(matrix.as_mutable()) == dense_array + + sparse_array = MutableSparseNDimArray([1, 2, 3, 4], (2, 2)) + alist = sparse_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = sparse_array.tomatrix() + assert(isinstance(matrix, SparseMatrix)) + + for i in range(len(sparse_array)): + assert sparse_array[sparse_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == sparse_array.shape + + assert MutableSparseNDimArray(matrix) == sparse_array + assert MutableSparseNDimArray(matrix.as_immutable()) == sparse_array + assert MutableSparseNDimArray(matrix.as_mutable()) == sparse_array + + +def test_converting_functions(): + arr_list = [1, 2, 3, 4] + arr_matrix = Matrix(((1, 2), (3, 4))) + + # list + arr_ndim_array = MutableDenseNDimArray(arr_list, (2, 2)) + assert (isinstance(arr_ndim_array, MutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + + # Matrix + arr_ndim_array = MutableDenseNDimArray(arr_matrix) + assert (isinstance(arr_ndim_array, MutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + assert arr_matrix.shape == arr_ndim_array.shape + + +def test_equality(): + first_list = [1, 2, 3, 4] + second_list = [1, 2, 3, 4] + third_list = [4, 3, 2, 1] + assert first_list == second_list + assert first_list != third_list + + first_ndim_array = MutableDenseNDimArray(first_list, (2, 2)) + second_ndim_array = MutableDenseNDimArray(second_list, (2, 2)) + third_ndim_array = MutableDenseNDimArray(third_list, (2, 2)) + fourth_ndim_array = MutableDenseNDimArray(first_list, (2, 2)) + + assert first_ndim_array == second_ndim_array + second_ndim_array[0, 0] = 0 + assert first_ndim_array != second_ndim_array + assert first_ndim_array != third_ndim_array + assert first_ndim_array == fourth_ndim_array + + +def test_arithmetic(): + a = MutableDenseNDimArray([3 for i in range(9)], (3, 3)) + b = MutableDenseNDimArray([7 for i in range(9)], (3, 3)) + + c1 = a + b + c2 = b + a + assert c1 == c2 + + d1 = a - b + d2 = b - a + assert d1 == d2 * (-1) + + e1 = a * 5 + e2 = 5 * a + e3 = copy(a) + e3 *= 5 + assert e1 == e2 == e3 + + f1 = a / 5 + f2 = copy(a) + f2 /= 5 + assert f1 == f2 + assert f1[0, 0] == f1[0, 1] == f1[0, 2] == f1[1, 0] == f1[1, 1] == \ + f1[1, 2] == f1[2, 0] == f1[2, 1] == f1[2, 2] == Rational(3, 5) + + assert type(a) == type(b) == type(c1) == type(c2) == type(d1) == type(d2) \ + == type(e1) == type(e2) == type(e3) == type(f1) + + z0 = -a + assert z0 == MutableDenseNDimArray([-3 for i in range(9)], (3, 3)) + + +def test_higher_dimenions(): + m3 = MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert m3.tolist() == [[[10, 11, 12, 13], + [14, 15, 16, 17], + [18, 19, 20, 21]], + + [[22, 23, 24, 25], + [26, 27, 28, 29], + [30, 31, 32, 33]]] + + assert m3._get_tuple_index(0) == (0, 0, 0) + assert m3._get_tuple_index(1) == (0, 0, 1) + assert m3._get_tuple_index(4) == (0, 1, 0) + assert m3._get_tuple_index(12) == (1, 0, 0) + + assert str(m3) == '[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]' + + m3_rebuilt = MutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]) + assert m3 == m3_rebuilt + + m3_other = MutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]], (2, 3, 4)) + + assert m3 == m3_other + + +def test_slices(): + md = MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert md[:] == MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + assert md[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert md[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert md[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert md[:, :, :] == md + + sd = MutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd == MutableSparseNDimArray(md) + + assert sd[:] == MutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert sd[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert sd[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert sd[:, :, :] == sd + + +def test_slices_assign(): + a = MutableDenseNDimArray(range(12), shape=(4, 3)) + b = MutableSparseNDimArray(range(12), shape=(4, 3)) + + for i in [a, b]: + assert i.tolist() == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[0, :] = [2, 2, 2] + assert i.tolist() == [[2, 2, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[0, 1:] = [8, 8] + assert i.tolist() == [[2, 8, 8], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[1:3, 1] = [20, 44] + assert i.tolist() == [[2, 8, 8], [3, 20, 5], [6, 44, 8], [9, 10, 11]] + + +def test_diff(): + from sympy.abc import x, y, z + md = MutableDenseNDimArray([[x, y], [x*z, x*y*z]]) + assert md.diff(x) == MutableDenseNDimArray([[1, 0], [z, y*z]]) + assert diff(md, x) == MutableDenseNDimArray([[1, 0], [z, y*z]]) + + sd = MutableSparseNDimArray(md) + assert sd == MutableSparseNDimArray([x, y, x*z, x*y*z], (2, 2)) + assert sd.diff(x) == MutableSparseNDimArray([[1, 0], [z, y*z]]) + assert diff(sd, x) == MutableSparseNDimArray([[1, 0], [z, y*z]]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_ndim_array.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff9b032631c01272c00478e4cdf0dcbc6997990 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_ndim_array.py @@ -0,0 +1,73 @@ +from sympy.testing.pytest import raises +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.matrices.dense import Matrix +from sympy.simplify import simplify +from sympy.tensor.array import Array +from sympy.tensor.array.dense_ndim_array import ( + ImmutableDenseNDimArray, MutableDenseNDimArray) +from sympy.tensor.array.sparse_ndim_array import ( + ImmutableSparseNDimArray, MutableSparseNDimArray) + +from sympy.abc import x, y + +mutable_array_types = [ + MutableDenseNDimArray, + MutableSparseNDimArray +] + +array_types = [ + ImmutableDenseNDimArray, + ImmutableSparseNDimArray, + MutableDenseNDimArray, + MutableSparseNDimArray +] + + +def test_array_negative_indices(): + for ArrayType in array_types: + test_array = ArrayType([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + assert test_array[:, -1] == Array([5, 10]) + assert test_array[:, -2] == Array([4, 9]) + assert test_array[:, -3] == Array([3, 8]) + assert test_array[:, -4] == Array([2, 7]) + assert test_array[:, -5] == Array([1, 6]) + assert test_array[:, 0] == Array([1, 6]) + assert test_array[:, 1] == Array([2, 7]) + assert test_array[:, 2] == Array([3, 8]) + assert test_array[:, 3] == Array([4, 9]) + assert test_array[:, 4] == Array([5, 10]) + + raises(ValueError, lambda: test_array[:, -6]) + raises(ValueError, lambda: test_array[-3, :]) + + assert test_array[-1, -1] == 10 + + +def test_issue_18361(): + A = Array([sin(2 * x) - 2 * sin(x) * cos(x)]) + B = Array([sin(x)**2 + cos(x)**2, 0]) + C = Array([(x + x**2)/(x*sin(y)**2 + x*cos(y)**2), 2*sin(x)*cos(x)]) + assert simplify(A) == Array([0]) + assert simplify(B) == Array([1, 0]) + assert simplify(C) == Array([x + 1, sin(2*x)]) + + +def test_issue_20222(): + A = Array([[1, 2], [3, 4]]) + B = Matrix([[1,2],[3,4]]) + raises(TypeError, lambda: A - B) + + +def test_issue_17851(): + for array_type in array_types: + A = array_type([]) + assert isinstance(A, array_type) + assert A.shape == (0,) + assert list(A) == [] + + +def test_issue_and_18715(): + for array_type in mutable_array_types: + A = array_type([0, 1, 2]) + A[0] += 5 + assert A[0] == 5 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..f43260ccc636ac461ba0c06dbfcf3fe3a8d5338d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py @@ -0,0 +1,22 @@ +from sympy.tensor.array import (ImmutableDenseNDimArray, + ImmutableSparseNDimArray, MutableDenseNDimArray, MutableSparseNDimArray) +from sympy.abc import x, y, z + + +def test_NDim_array_conv(): + MD = MutableDenseNDimArray([x, y, z]) + MS = MutableSparseNDimArray([x, y, z]) + ID = ImmutableDenseNDimArray([x, y, z]) + IS = ImmutableSparseNDimArray([x, y, z]) + + assert MD.as_immutable() == ID + assert MD.as_mutable() == MD + + assert MS.as_immutable() == IS + assert MS.as_mutable() == MS + + assert ID.as_immutable() == ID + assert ID.as_mutable() == MD + + assert IS.as_immutable() == IS + assert IS.as_mutable() == MS diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..f14599d69152db1713f21c9dd785683901c5eeb9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/functions.py @@ -0,0 +1,154 @@ +from collections.abc import Iterable +from functools import singledispatch + +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.core.parameters import global_parameters + + +class TensorProduct(Expr): + """ + Generic class for tensor products. + """ + is_number = False + + def __new__(cls, *args, **kwargs): + from sympy.tensor.array import NDimArray, tensorproduct, Array + from sympy.matrices.expressions.matexpr import MatrixExpr + from sympy.matrices.matrixbase import MatrixBase + from sympy.strategies import flatten + + args = [sympify(arg) for arg in args] + evaluate = kwargs.get("evaluate", global_parameters.evaluate) + + if not evaluate: + obj = Expr.__new__(cls, *args) + return obj + + arrays = [] + other = [] + scalar = S.One + for arg in args: + if isinstance(arg, (Iterable, MatrixBase, NDimArray)): + arrays.append(Array(arg)) + elif isinstance(arg, (MatrixExpr,)): + other.append(arg) + else: + scalar *= arg + + coeff = scalar*tensorproduct(*arrays) + if len(other) == 0: + return coeff + if coeff != 1: + newargs = [coeff] + other + else: + newargs = other + obj = Expr.__new__(cls, *newargs, **kwargs) + return flatten(obj) + + def rank(self): + return len(self.shape) + + def _get_args_shapes(self): + from sympy.tensor.array import Array + return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args] + + @property + def shape(self): + shape_list = self._get_args_shapes() + return sum(shape_list, ()) + + def __getitem__(self, index): + index = iter(index) + return Mul.fromiter( + arg.__getitem__(tuple(next(index) for i in shp)) + for arg, shp in zip(self.args, self._get_args_shapes()) + ) + + +@singledispatch +def shape(expr): + """ + Return the shape of the *expr* as a tuple. *expr* should represent + suitable object such as matrix or array. + + Parameters + ========== + + expr : SymPy object having ``MatrixKind`` or ``ArrayKind``. + + Raises + ====== + + NoShapeError : Raised when object with wrong kind is passed. + + Examples + ======== + + This function returns the shape of any object representing matrix or array. + + >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral + >>> from sympy.abc import x + >>> A = Array([1, 2]) + >>> shape(A) + (2,) + >>> shape(Integral(A, x)) + (2,) + >>> M = ImmutableDenseMatrix([1, 2]) + >>> shape(M) + (2, 1) + >>> shape(Integral(M, x)) + (2, 1) + + You can support new type by dispatching. + + >>> from sympy import Expr + >>> class NewExpr(Expr): + ... pass + >>> @shape.register(NewExpr) + ... def _(expr): + ... return shape(expr.args[0]) + >>> shape(NewExpr(M)) + (2, 1) + + If unsuitable expression is passed, ``NoShapeError()`` will be raised. + + >>> shape(Integral(x, x)) + Traceback (most recent call last): + ... + sympy.tensor.functions.NoShapeError: shape() called on non-array object: Integral(x, x) + + Notes + ===== + + Array-like classes (such as ``Matrix`` or ``NDimArray``) has ``shape`` + property which returns its shape, but it cannot be used for non-array + classes containing array. This function returns the shape of any + registered object representing array. + + """ + if hasattr(expr, "shape"): + return expr.shape + raise NoShapeError( + "%s does not have shape, or its type is not registered to shape()." % expr) + + +class NoShapeError(Exception): + """ + Raised when ``shape()`` is called on non-array object. + + This error can be imported from ``sympy.tensor.functions``. + + Examples + ======== + + >>> from sympy import shape + >>> from sympy.abc import x + >>> shape(x) + Traceback (most recent call last): + ... + sympy.tensor.functions.NoShapeError: shape() called on non-array object: x + """ + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/index_methods.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/index_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..12f707b60b4ad0bcadc35a222d9abe0cc5e033fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/index_methods.py @@ -0,0 +1,469 @@ +"""Module with functions operating on IndexedBase, Indexed and Idx objects + + - Check shape conformance + - Determine indices in resulting expression + + etc. + + Methods in this module could be implemented by calling methods on Expr + objects instead. When things stabilize this could be a useful + refactoring. +""" + +from functools import reduce + +from sympy.core.function import Function +from sympy.functions import exp, Piecewise +from sympy.tensor.indexed import Idx, Indexed +from sympy.utilities import sift + +from collections import OrderedDict + +class IndexConformanceException(Exception): + pass + +def _unique_and_repeated(inds): + """ + Returns the unique and repeated indices. Also note, from the examples given below + that the order of indices is maintained as given in the input. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _unique_and_repeated + >>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0]) + ([2, 1, 4], [3, 0]) + """ + uniq = OrderedDict() + for i in inds: + if i in uniq: + uniq[i] = 0 + else: + uniq[i] = 1 + return sift(uniq, lambda x: uniq[x], binary=True) + +def _remove_repeated(inds): + """ + Removes repeated objects from sequences + + Returns a set of the unique objects and a tuple of all that have been + removed. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _remove_repeated + >>> l1 = [1, 2, 3, 2] + >>> _remove_repeated(l1) + ({1, 3}, (2,)) + + """ + u, r = _unique_and_repeated(inds) + return set(u), tuple(r) + + +def _get_indices_Mul(expr, return_dummies=False): + """Determine the outer indices of a Mul object. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Mul + >>> from sympy.tensor.indexed import IndexedBase, Idx + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> x = IndexedBase('x') + >>> y = IndexedBase('y') + >>> _get_indices_Mul(x[i, k]*y[j, k]) + ({i, j}, {}) + >>> _get_indices_Mul(x[i, k]*y[j, k], return_dummies=True) + ({i, j}, {}, (k,)) + + """ + + inds = list(map(get_indices, expr.args)) + inds, syms = list(zip(*inds)) + + inds = list(map(list, inds)) + inds = list(reduce(lambda x, y: x + y, inds)) + inds, dummies = _remove_repeated(inds) + + symmetry = {} + for s in syms: + for pair in s: + if pair in symmetry: + symmetry[pair] *= s[pair] + else: + symmetry[pair] = s[pair] + + if return_dummies: + return inds, symmetry, dummies + else: + return inds, symmetry + + +def _get_indices_Pow(expr): + """Determine outer indices of a power or an exponential. + + A power is considered a universal function, so that the indices of a Pow is + just the collection of indices present in the expression. This may be + viewed as a bit inconsistent in the special case: + + x[i]**2 = x[i]*x[i] (1) + + The above expression could have been interpreted as the contraction of x[i] + with itself, but we choose instead to interpret it as a function + + lambda y: y**2 + + applied to each element of x (a universal function in numpy terms). In + order to allow an interpretation of (1) as a contraction, we need + contravariant and covariant Idx subclasses. (FIXME: this is not yet + implemented) + + Expressions in the base or exponent are subject to contraction as usual, + but an index that is present in the exponent, will not be considered + contractable with its own base. Note however, that indices in the same + exponent can be contracted with each other. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Pow + >>> from sympy import Pow, exp, IndexedBase, Idx + >>> A = IndexedBase('A') + >>> x = IndexedBase('x') + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> _get_indices_Pow(exp(A[i, j]*x[j])) + ({i}, {}) + >>> _get_indices_Pow(Pow(x[i], x[i])) + ({i}, {}) + >>> _get_indices_Pow(Pow(A[i, j]*x[j], x[i])) + ({i}, {}) + + """ + base, exp = expr.as_base_exp() + binds, bsyms = get_indices(base) + einds, esyms = get_indices(exp) + + inds = binds | einds + + # FIXME: symmetries from power needs to check special cases, else nothing + symmetries = {} + + return inds, symmetries + + +def _get_indices_Add(expr): + """Determine outer indices of an Add object. + + In a sum, each term must have the same set of outer indices. A valid + expression could be + + x(i)*y(j) - x(j)*y(i) + + But we do not allow expressions like: + + x(i)*y(j) - z(j)*z(j) + + FIXME: Add support for Numpy broadcasting + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Add + >>> from sympy.tensor.indexed import IndexedBase, Idx + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> x = IndexedBase('x') + >>> y = IndexedBase('y') + >>> _get_indices_Add(x[i] + x[k]*y[i, k]) + ({i}, {}) + + """ + + inds = list(map(get_indices, expr.args)) + inds, syms = list(zip(*inds)) + + # allow broadcast of scalars + non_scalars = [x for x in inds if x != set()] + if not non_scalars: + return set(), {} + + if not all(x == non_scalars[0] for x in non_scalars[1:]): + raise IndexConformanceException("Indices are not consistent: %s" % expr) + if not reduce(lambda x, y: x != y or y, syms): + symmetries = syms[0] + else: + # FIXME: search for symmetries + symmetries = {} + + return non_scalars[0], symmetries + + +def get_indices(expr): + """Determine the outer indices of expression ``expr`` + + By *outer* we mean indices that are not summation indices. Returns a set + and a dict. The set contains outer indices and the dict contains + information about index symmetries. + + Examples + ======== + + >>> from sympy.tensor.index_methods import get_indices + >>> from sympy import symbols + >>> from sympy.tensor import IndexedBase + >>> x, y, A = map(IndexedBase, ['x', 'y', 'A']) + >>> i, j, a, z = symbols('i j a z', integer=True) + + The indices of the total expression is determined, Repeated indices imply a + summation, for instance the trace of a matrix A: + + >>> get_indices(A[i, i]) + (set(), {}) + + In the case of many terms, the terms are required to have identical + outer indices. Else an IndexConformanceException is raised. + + >>> get_indices(x[i] + A[i, j]*y[j]) + ({i}, {}) + + :Exceptions: + + An IndexConformanceException means that the terms ar not compatible, e.g. + + >>> get_indices(x[i] + y[j]) #doctest: +SKIP + (...) + IndexConformanceException: Indices are not consistent: x(i) + y(j) + + .. warning:: + The concept of *outer* indices applies recursively, starting on the deepest + level. This implies that dummies inside parenthesis are assumed to be + summed first, so that the following expression is handled gracefully: + + >>> get_indices((x[i] + A[i, j]*y[j])*x[j]) + ({i, j}, {}) + + This is correct and may appear convenient, but you need to be careful + with this as SymPy will happily .expand() the product, if requested. The + resulting expression would mix the outer ``j`` with the dummies inside + the parenthesis, which makes it a different expression. To be on the + safe side, it is best to avoid such ambiguities by using unique indices + for all contractions that should be held separate. + + """ + # We call ourself recursively to determine indices of sub expressions. + + # break recursion + if isinstance(expr, Indexed): + c = expr.indices + inds, dummies = _remove_repeated(c) + return inds, {} + elif expr is None: + return set(), {} + elif isinstance(expr, Idx): + return {expr}, {} + elif expr.is_Atom: + return set(), {} + + + # recurse via specialized functions + else: + if expr.is_Mul: + return _get_indices_Mul(expr) + elif expr.is_Add: + return _get_indices_Add(expr) + elif expr.is_Pow or isinstance(expr, exp): + return _get_indices_Pow(expr) + + elif isinstance(expr, Piecewise): + # FIXME: No support for Piecewise yet + return set(), {} + elif isinstance(expr, Function): + # Support ufunc like behaviour by returning indices from arguments. + # Functions do not interpret repeated indices across arguments + # as summation + ind0 = set() + for arg in expr.args: + ind, sym = get_indices(arg) + ind0 |= ind + return ind0, sym + + # this test is expensive, so it should be at the end + elif not expr.has(Indexed): + return set(), {} + raise NotImplementedError( + "FIXME: No specialized handling of type %s" % type(expr)) + + +def get_contraction_structure(expr): + """Determine dummy indices of ``expr`` and describe its structure + + By *dummy* we mean indices that are summation indices. + + The structure of the expression is determined and described as follows: + + 1) A conforming summation of Indexed objects is described with a dict where + the keys are summation indices and the corresponding values are sets + containing all terms for which the summation applies. All Add objects + in the SymPy expression tree are described like this. + + 2) For all nodes in the SymPy expression tree that are *not* of type Add, the + following applies: + + If a node discovers contractions in one of its arguments, the node + itself will be stored as a key in the dict. For that key, the + corresponding value is a list of dicts, each of which is the result of a + recursive call to get_contraction_structure(). The list contains only + dicts for the non-trivial deeper contractions, omitting dicts with None + as the one and only key. + + .. Note:: The presence of expressions among the dictionary keys indicates + multiple levels of index contractions. A nested dict displays nested + contractions and may itself contain dicts from a deeper level. In + practical calculations the summation in the deepest nested level must be + calculated first so that the outer expression can access the resulting + indexed object. + + Examples + ======== + + >>> from sympy.tensor.index_methods import get_contraction_structure + >>> from sympy import default_sort_key + >>> from sympy.tensor import IndexedBase, Idx + >>> x, y, A = map(IndexedBase, ['x', 'y', 'A']) + >>> i, j, k, l = map(Idx, ['i', 'j', 'k', 'l']) + >>> get_contraction_structure(x[i]*y[i] + A[j, j]) + {(i,): {x[i]*y[i]}, (j,): {A[j, j]}} + >>> get_contraction_structure(x[i]*y[j]) + {None: {x[i]*y[j]}} + + A multiplication of contracted factors results in nested dicts representing + the internal contractions. + + >>> d = get_contraction_structure(x[i, i]*y[j, j]) + >>> sorted(d.keys(), key=default_sort_key) + [None, x[i, i]*y[j, j]] + + In this case, the product has no contractions: + + >>> d[None] + {x[i, i]*y[j, j]} + + Factors are contracted "first": + + >>> sorted(d[x[i, i]*y[j, j]], key=default_sort_key) + [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}] + + A parenthesized Add object is also returned as a nested dictionary. The + term containing the parenthesis is a Mul with a contraction among the + arguments, so it will be found as a key in the result. It stores the + dictionary resulting from a recursive call on the Add expression. + + >>> d = get_contraction_structure(x[i]*(y[i] + A[i, j]*x[j])) + >>> sorted(d.keys(), key=default_sort_key) + [(A[i, j]*x[j] + y[i])*x[i], (i,)] + >>> d[(i,)] + {(A[i, j]*x[j] + y[i])*x[i]} + >>> d[x[i]*(A[i, j]*x[j] + y[i])] + [{None: {y[i]}, (j,): {A[i, j]*x[j]}}] + + Powers with contractions in either base or exponent will also be found as + keys in the dictionary, mapping to a list of results from recursive calls: + + >>> d = get_contraction_structure(A[j, j]**A[i, i]) + >>> d[None] + {A[j, j]**A[i, i]} + >>> nested_contractions = d[A[j, j]**A[i, i]] + >>> nested_contractions[0] + {(j,): {A[j, j]}} + >>> nested_contractions[1] + {(i,): {A[i, i]}} + + The description of the contraction structure may appear complicated when + represented with a string in the above examples, but it is easy to iterate + over: + + >>> from sympy import Expr + >>> for key in d: + ... if isinstance(key, Expr): + ... continue + ... for term in d[key]: + ... if term in d: + ... # treat deepest contraction first + ... pass + ... # treat outermost contactions here + + """ + + # We call ourself recursively to inspect sub expressions. + + if isinstance(expr, Indexed): + junk, key = _remove_repeated(expr.indices) + return {key or None: {expr}} + elif expr.is_Atom: + return {None: {expr}} + elif expr.is_Mul: + junk, junk, key = _get_indices_Mul(expr, return_dummies=True) + result = {key or None: {expr}} + # recurse on every factor + nested = [] + for fac in expr.args: + facd = get_contraction_structure(fac) + if not (None in facd and len(facd) == 1): + nested.append(facd) + if nested: + result[expr] = nested + return result + elif expr.is_Pow or isinstance(expr, exp): + # recurse in base and exp separately. If either has internal + # contractions we must include ourselves as a key in the returned dict + b, e = expr.as_base_exp() + dbase = get_contraction_structure(b) + dexp = get_contraction_structure(e) + + dicts = [] + for d in dbase, dexp: + if not (None in d and len(d) == 1): + dicts.append(d) + result = {None: {expr}} + if dicts: + result[expr] = dicts + return result + elif expr.is_Add: + # Note: we just collect all terms with identical summation indices, We + # do nothing to identify equivalent terms here, as this would require + # substitutions or pattern matching in expressions of unknown + # complexity. + result = {} + for term in expr.args: + # recurse on every term + d = get_contraction_structure(term) + for key in d: + if key in result: + result[key] |= d[key] + else: + result[key] = d[key] + return result + + elif isinstance(expr, Piecewise): + # FIXME: No support for Piecewise yet + return {None: expr} + elif isinstance(expr, Function): + # Collect non-trivial contraction structures in each argument + # We do not report repeated indices in separate arguments as a + # contraction + deeplist = [] + for arg in expr.args: + deep = get_contraction_structure(arg) + if not (None in deep and len(deep) == 1): + deeplist.append(deep) + d = {None: {expr}} + if deeplist: + d[expr] = deeplist + return d + + # this test is expensive, so it should be at the end + elif not expr.has(Indexed): + return {None: {expr}} + raise NotImplementedError( + "FIXME: No specialized handling of type %s" % type(expr)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/indexed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..feddad21e52bbab2e1243beafdb11f30b2eded4d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/indexed.py @@ -0,0 +1,793 @@ +r"""Module that defines indexed objects. + +The classes ``IndexedBase``, ``Indexed``, and ``Idx`` represent a +matrix element ``M[i, j]`` as in the following diagram:: + + 1) The Indexed class represents the entire indexed object. + | + ___|___ + ' ' + M[i, j] + / \__\______ + | | + | | + | 2) The Idx class represents indices; each Idx can + | optionally contain information about its range. + | + 3) IndexedBase represents the 'stem' of an indexed object, here `M`. + The stem used by itself is usually taken to represent the entire + array. + +There can be any number of indices on an Indexed object. No +transformation properties are implemented in these Base objects, but +implicit contraction of repeated indices is supported. + +Note that the support for complicated (i.e. non-atomic) integer +expressions as indices is limited. (This should be improved in +future releases.) + +Examples +======== + +To express the above matrix element example you would write: + +>>> from sympy import symbols, IndexedBase, Idx +>>> M = IndexedBase('M') +>>> i, j = symbols('i j', cls=Idx) +>>> M[i, j] +M[i, j] + +Repeated indices in a product implies a summation, so to express a +matrix-vector product in terms of Indexed objects: + +>>> x = IndexedBase('x') +>>> M[i, j]*x[j] +M[i, j]*x[j] + +If the indexed objects will be converted to component based arrays, e.g. +with the code printers or the autowrap framework, you also need to provide +(symbolic or numerical) dimensions. This can be done by passing an +optional shape parameter to IndexedBase upon construction: + +>>> dim1, dim2 = symbols('dim1 dim2', integer=True) +>>> A = IndexedBase('A', shape=(dim1, 2*dim1, dim2)) +>>> A.shape +(dim1, 2*dim1, dim2) +>>> A[i, j, 3].shape +(dim1, 2*dim1, dim2) + +If an IndexedBase object has no shape information, it is assumed that the +array is as large as the ranges of its indices: + +>>> n, m = symbols('n m', integer=True) +>>> i = Idx('i', m) +>>> j = Idx('j', n) +>>> M[i, j].shape +(m, n) +>>> M[i, j].ranges +[(0, m - 1), (0, n - 1)] + +The above can be compared with the following: + +>>> A[i, 2, j].shape +(dim1, 2*dim1, dim2) +>>> A[i, 2, j].ranges +[(0, m - 1), None, (0, n - 1)] + +To analyze the structure of indexed expressions, you can use the methods +get_indices() and get_contraction_structure(): + +>>> from sympy.tensor import get_indices, get_contraction_structure +>>> get_indices(A[i, j, j]) +({i}, {}) +>>> get_contraction_structure(A[i, j, j]) +{(j,): {A[i, j, j]}} + +See the appropriate docstrings for a detailed explanation of the output. +""" + +# TODO: (some ideas for improvement) +# +# o test and guarantee numpy compatibility +# - implement full support for broadcasting +# - strided arrays +# +# o more functions to analyze indexed expressions +# - identify standard constructs, e.g matrix-vector product in a subexpression +# +# o functions to generate component based arrays (numpy and sympy.Matrix) +# - generate a single array directly from Indexed +# - convert simple sub-expressions +# +# o sophisticated indexing (possibly in subclasses to preserve simplicity) +# - Idx with range smaller than dimension of Indexed +# - Idx with stepsize != 1 +# - Idx with step determined by function call +from collections.abc import Iterable + +from sympy.core.numbers import Number +from sympy.core.assumptions import StdFactKB +from sympy.core import Expr, Tuple, sympify, S +from sympy.core.symbol import _filter_assumptions, Symbol +from sympy.core.logic import fuzzy_bool, fuzzy_not +from sympy.core.sympify import _sympify +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.multipledispatch import dispatch +from sympy.utilities.iterables import is_sequence, NotIterable +from sympy.utilities.misc import filldedent + + +class IndexException(Exception): + pass + + +class Indexed(Expr): + """Represents a mathematical object with indices. + + >>> from sympy import Indexed, IndexedBase, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j) + A[i, j] + + It is recommended that ``Indexed`` objects be created by indexing ``IndexedBase``: + ``IndexedBase('A')[i, j]`` instead of ``Indexed(IndexedBase('A'), i, j)``. + + >>> A = IndexedBase('A') + >>> a_ij = A[i, j] # Prefer this, + >>> b_ij = Indexed(A, i, j) # over this. + >>> a_ij == b_ij + True + + """ + is_Indexed = True + is_symbol = True + is_Atom = True + + def __new__(cls, base, *args, **kw_args): + from sympy.tensor.array.ndim_array import NDimArray + from sympy.matrices.matrixbase import MatrixBase + + if not args: + raise IndexException("Indexed needs at least one index.") + if isinstance(base, (str, Symbol)): + base = IndexedBase(base) + elif not hasattr(base, '__getitem__') and not isinstance(base, IndexedBase): + raise TypeError(filldedent(""" + The base can only be replaced with a string, Symbol, + IndexedBase or an object with a method for getting + items (i.e. an object with a `__getitem__` method). + """)) + args = list(map(sympify, args)) + if isinstance(base, (NDimArray, Iterable, Tuple, MatrixBase)) and all(i.is_number for i in args): + if len(args) == 1: + return base[args[0]] + else: + return base[args] + + base = _sympify(base) + + obj = Expr.__new__(cls, base, *args, **kw_args) + + IndexedBase._set_assumptions(obj, base.assumptions0) + + return obj + + def _hashable_content(self): + return super()._hashable_content() + tuple(sorted(self.assumptions0.items())) + + @property + def name(self): + return str(self) + + @property + def _diff_wrt(self): + """Allow derivatives with respect to an ``Indexed`` object.""" + return True + + def _eval_derivative(self, wrt): + from sympy.tensor.array.ndim_array import NDimArray + + if isinstance(wrt, Indexed) and wrt.base == self.base: + if len(self.indices) != len(wrt.indices): + msg = "Different # of indices: d({!s})/d({!s})".format(self, + wrt) + raise IndexException(msg) + result = S.One + for index1, index2 in zip(self.indices, wrt.indices): + result *= KroneckerDelta(index1, index2) + return result + elif isinstance(self.base, NDimArray): + from sympy.tensor.array import derive_by_array + return Indexed(derive_by_array(self.base, wrt), *self.args[1:]) + else: + if Tuple(self.indices).has(wrt): + return S.NaN + return S.Zero + + @property + def assumptions0(self): + return {k: v for k, v in self._assumptions.items() if v is not None} + + @property + def base(self): + """Returns the ``IndexedBase`` of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, IndexedBase, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j).base + A + >>> B = IndexedBase('B') + >>> B == B[i, j].base + True + + """ + return self.args[0] + + @property + def indices(self): + """ + Returns the indices of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j).indices + (i, j) + + """ + return self.args[1:] + + @property + def rank(self): + """ + Returns the rank of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, Idx, symbols + >>> i, j, k, l, m = symbols('i:m', cls=Idx) + >>> Indexed('A', i, j).rank + 2 + >>> q = Indexed('A', i, j, k, l, m) + >>> q.rank + 5 + >>> q.rank == len(q.indices) + True + + """ + return len(self.args) - 1 + + @property + def shape(self): + """Returns a list with dimensions of each index. + + Dimensions is a property of the array, not of the indices. Still, if + the ``IndexedBase`` does not define a shape attribute, it is assumed + that the ranges of the indices correspond to the shape of the array. + + >>> from sympy import IndexedBase, Idx, symbols + >>> n, m = symbols('n m', integer=True) + >>> i = Idx('i', m) + >>> j = Idx('j', m) + >>> A = IndexedBase('A', shape=(n, n)) + >>> B = IndexedBase('B') + >>> A[i, j].shape + (n, n) + >>> B[i, j].shape + (m, m) + """ + + if self.base.shape: + return self.base.shape + sizes = [] + for i in self.indices: + upper = getattr(i, 'upper', None) + lower = getattr(i, 'lower', None) + if None in (upper, lower): + raise IndexException(filldedent(""" + Range is not defined for all indices in: %s""" % self)) + try: + size = upper - lower + 1 + except TypeError: + raise IndexException(filldedent(""" + Shape cannot be inferred from Idx with + undefined range: %s""" % self)) + sizes.append(size) + return Tuple(*sizes) + + @property + def ranges(self): + """Returns a list of tuples with lower and upper range of each index. + + If an index does not define the data members upper and lower, the + corresponding slot in the list contains ``None`` instead of a tuple. + + Examples + ======== + + >>> from sympy import Indexed,Idx, symbols + >>> Indexed('A', Idx('i', 2), Idx('j', 4), Idx('k', 8)).ranges + [(0, 1), (0, 3), (0, 7)] + >>> Indexed('A', Idx('i', 3), Idx('j', 3), Idx('k', 3)).ranges + [(0, 2), (0, 2), (0, 2)] + >>> x, y, z = symbols('x y z', integer=True) + >>> Indexed('A', x, y, z).ranges + [None, None, None] + + """ + ranges = [] + sentinel = object() + for i in self.indices: + upper = getattr(i, 'upper', sentinel) + lower = getattr(i, 'lower', sentinel) + if sentinel not in (upper, lower): + ranges.append((lower, upper)) + else: + ranges.append(None) + return ranges + + def _sympystr(self, p): + indices = list(map(p.doprint, self.indices)) + return "%s[%s]" % (p.doprint(self.base), ", ".join(indices)) + + @property + def free_symbols(self): + base_free_symbols = self.base.free_symbols + indices_free_symbols = { + fs for i in self.indices for fs in i.free_symbols} + if base_free_symbols: + return {self} | base_free_symbols | indices_free_symbols + else: + return indices_free_symbols + + @property + def expr_free_symbols(self): + from sympy.utilities.exceptions import sympy_deprecation_warning + sympy_deprecation_warning(""" + The expr_free_symbols property is deprecated. Use free_symbols to get + the free symbols of an expression. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-expr-free-symbols") + + return {self} + + +class IndexedBase(Expr, NotIterable): + """Represent the base or stem of an indexed object + + The IndexedBase class represent an array that contains elements. The main purpose + of this class is to allow the convenient creation of objects of the Indexed + class. The __getitem__ method of IndexedBase returns an instance of + Indexed. Alone, without indices, the IndexedBase class can be used as a + notation for e.g. matrix equations, resembling what you could do with the + Symbol class. But, the IndexedBase class adds functionality that is not + available for Symbol instances: + + - An IndexedBase object can optionally store shape information. This can + be used in to check array conformance and conditions for numpy + broadcasting. (TODO) + - An IndexedBase object implements syntactic sugar that allows easy symbolic + representation of array operations, using implicit summation of + repeated indices. + - The IndexedBase object symbolizes a mathematical structure equivalent + to arrays, and is recognized as such for code generation and automatic + compilation and wrapping. + + >>> from sympy.tensor import IndexedBase, Idx + >>> from sympy import symbols + >>> A = IndexedBase('A'); A + A + >>> type(A) + + + When an IndexedBase object receives indices, it returns an array with named + axes, represented by an Indexed object: + + >>> i, j = symbols('i j', integer=True) + >>> A[i, j, 2] + A[i, j, 2] + >>> type(A[i, j, 2]) + + + The IndexedBase constructor takes an optional shape argument. If given, + it overrides any shape information in the indices. (But not the index + ranges!) + + >>> m, n, o, p = symbols('m n o p', integer=True) + >>> i = Idx('i', m) + >>> j = Idx('j', n) + >>> A[i, j].shape + (m, n) + >>> B = IndexedBase('B', shape=(o, p)) + >>> B[i, j].shape + (o, p) + + Assumptions can be specified with keyword arguments the same way as for Symbol: + + >>> A_real = IndexedBase('A', real=True) + >>> A_real.is_real + True + >>> A != A_real + True + + Assumptions can also be inherited if a Symbol is used to initialize the IndexedBase: + + >>> I = symbols('I', integer=True) + >>> C_inherit = IndexedBase(I) + >>> C_explicit = IndexedBase('I', integer=True) + >>> C_inherit == C_explicit + True + """ + is_symbol = True + is_Atom = True + + @staticmethod + def _set_assumptions(obj, assumptions): + """Set assumptions on obj, making sure to apply consistent values.""" + tmp_asm_copy = assumptions.copy() + is_commutative = fuzzy_bool(assumptions.get('commutative', True)) + assumptions['commutative'] = is_commutative + obj._assumptions = StdFactKB(assumptions) + obj._assumptions._generator = tmp_asm_copy # Issue #8873 + + def __new__(cls, label, shape=None, *, offset=S.Zero, strides=None, **kw_args): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array.ndim_array import NDimArray + + assumptions, kw_args = _filter_assumptions(kw_args) + if isinstance(label, str): + label = Symbol(label, **assumptions) + elif isinstance(label, Symbol): + assumptions = label._merge(assumptions) + elif isinstance(label, (MatrixBase, NDimArray)): + return label + elif isinstance(label, Iterable): + return _sympify(label) + else: + label = _sympify(label) + + if is_sequence(shape): + shape = Tuple(*shape) + elif shape is not None: + shape = Tuple(shape) + + if shape is not None: + obj = Expr.__new__(cls, label, shape) + else: + obj = Expr.__new__(cls, label) + obj._shape = shape + obj._offset = offset + obj._strides = strides + obj._name = str(label) + + IndexedBase._set_assumptions(obj, assumptions) + return obj + + @property + def name(self): + return self._name + + def _hashable_content(self): + return super()._hashable_content() + tuple(sorted(self.assumptions0.items())) + + @property + def assumptions0(self): + return {k: v for k, v in self._assumptions.items() if v is not None} + + def __getitem__(self, indices, **kw_args): + if is_sequence(indices): + # Special case needed because M[*my_tuple] is a syntax error. + if self.shape and len(self.shape) != len(indices): + raise IndexException("Rank mismatch.") + return Indexed(self, *indices, **kw_args) + else: + if self.shape and len(self.shape) != 1: + raise IndexException("Rank mismatch.") + return Indexed(self, indices, **kw_args) + + @property + def shape(self): + """Returns the shape of the ``IndexedBase`` object. + + Examples + ======== + + >>> from sympy import IndexedBase, Idx + >>> from sympy.abc import x, y + >>> IndexedBase('A', shape=(x, y)).shape + (x, y) + + Note: If the shape of the ``IndexedBase`` is specified, it will override + any shape information given by the indices. + + >>> A = IndexedBase('A', shape=(x, y)) + >>> B = IndexedBase('B') + >>> i = Idx('i', 2) + >>> j = Idx('j', 1) + >>> A[i, j].shape + (x, y) + >>> B[i, j].shape + (2, 1) + + """ + return self._shape + + @property + def strides(self): + """Returns the strided scheme for the ``IndexedBase`` object. + + Normally this is a tuple denoting the number of + steps to take in the respective dimension when traversing + an array. For code generation purposes strides='C' and + strides='F' can also be used. + + strides='C' would mean that code printer would unroll + in row-major order and 'F' means unroll in column major + order. + + """ + + return self._strides + + @property + def offset(self): + """Returns the offset for the ``IndexedBase`` object. + + This is the value added to the resulting index when the + 2D Indexed object is unrolled to a 1D form. Used in code + generation. + + Examples + ========== + >>> from sympy.printing import ccode + >>> from sympy.tensor import IndexedBase, Idx + >>> from sympy import symbols + >>> l, m, n, o = symbols('l m n o', integer=True) + >>> A = IndexedBase('A', strides=(l, m, n), offset=o) + >>> i, j, k = map(Idx, 'ijk') + >>> ccode(A[i, j, k]) + 'A[l*i + m*j + n*k + o]' + + """ + return self._offset + + @property + def label(self): + """Returns the label of the ``IndexedBase`` object. + + Examples + ======== + + >>> from sympy import IndexedBase + >>> from sympy.abc import x, y + >>> IndexedBase('A', shape=(x, y)).label + A + + """ + return self.args[0] + + def _sympystr(self, p): + return p.doprint(self.label) + + +class Idx(Expr): + """Represents an integer index as an ``Integer`` or integer expression. + + There are a number of ways to create an ``Idx`` object. The constructor + takes two arguments: + + ``label`` + An integer or a symbol that labels the index. + ``range`` + Optionally you can specify a range as either + + * ``Symbol`` or integer: This is interpreted as a dimension. Lower and + upper bounds are set to ``0`` and ``range - 1``, respectively. + * ``tuple``: The two elements are interpreted as the lower and upper + bounds of the range, respectively. + + Note: bounds of the range are assumed to be either integer or infinite (oo + and -oo are allowed to specify an unbounded range). If ``n`` is given as a + bound, then ``n.is_integer`` must not return false. + + For convenience, if the label is given as a string it is automatically + converted to an integer symbol. (Note: this conversion is not done for + range or dimension arguments.) + + Examples + ======== + + >>> from sympy import Idx, symbols, oo + >>> n, i, L, U = symbols('n i L U', integer=True) + + If a string is given for the label an integer ``Symbol`` is created and the + bounds are both ``None``: + + >>> idx = Idx('qwerty'); idx + qwerty + >>> idx.lower, idx.upper + (None, None) + + Both upper and lower bounds can be specified: + + >>> idx = Idx(i, (L, U)); idx + i + >>> idx.lower, idx.upper + (L, U) + + When only a single bound is given it is interpreted as the dimension + and the lower bound defaults to 0: + + >>> idx = Idx(i, n); idx.lower, idx.upper + (0, n - 1) + >>> idx = Idx(i, 4); idx.lower, idx.upper + (0, 3) + >>> idx = Idx(i, oo); idx.lower, idx.upper + (0, oo) + + """ + + is_integer = True + is_finite = True + is_real = True + is_symbol = True + is_Atom = True + _diff_wrt = True + + def __new__(cls, label, range=None, **kw_args): + + if isinstance(label, str): + label = Symbol(label, integer=True) + label, range = list(map(sympify, (label, range))) + + if label.is_Number: + if not label.is_integer: + raise TypeError("Index is not an integer number.") + return label + + if not label.is_integer: + raise TypeError("Idx object requires an integer label.") + + elif is_sequence(range): + if len(range) != 2: + raise ValueError(filldedent(""" + Idx range tuple must have length 2, but got %s""" % len(range))) + for bound in range: + if (bound.is_integer is False and bound is not S.Infinity + and bound is not S.NegativeInfinity): + raise TypeError("Idx object requires integer bounds.") + args = label, Tuple(*range) + elif isinstance(range, Expr): + if range is not S.Infinity and fuzzy_not(range.is_integer): + raise TypeError("Idx object requires an integer dimension.") + args = label, Tuple(0, range - 1) + elif range: + raise TypeError(filldedent(""" + The range must be an ordered iterable or + integer SymPy expression.""")) + else: + args = label, + + obj = Expr.__new__(cls, *args, **kw_args) + obj._assumptions["finite"] = True + obj._assumptions["real"] = True + return obj + + @property + def label(self): + """Returns the label (Integer or integer expression) of the Idx object. + + Examples + ======== + + >>> from sympy import Idx, Symbol + >>> x = Symbol('x', integer=True) + >>> Idx(x).label + x + >>> j = Symbol('j', integer=True) + >>> Idx(j).label + j + >>> Idx(j + 1).label + j + 1 + + """ + return self.args[0] + + @property + def lower(self): + """Returns the lower bound of the ``Idx``. + + Examples + ======== + + >>> from sympy import Idx + >>> Idx('j', 2).lower + 0 + >>> Idx('j', 5).lower + 0 + >>> Idx('j').lower is None + True + + """ + try: + return self.args[1][0] + except IndexError: + return + + @property + def upper(self): + """Returns the upper bound of the ``Idx``. + + Examples + ======== + + >>> from sympy import Idx + >>> Idx('j', 2).upper + 1 + >>> Idx('j', 5).upper + 4 + >>> Idx('j').upper is None + True + + """ + try: + return self.args[1][1] + except IndexError: + return + + def _sympystr(self, p): + return p.doprint(self.label) + + @property + def name(self): + return self.label.name if self.label.is_Symbol else str(self.label) + + @property + def free_symbols(self): + return {self} + + +@dispatch(Idx, Idx) +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = rhs if rhs.upper is None else rhs.upper + other_lower = rhs if rhs.lower is None else rhs.lower + + if lhs.lower is not None and (lhs.lower >= other_upper) == True: + return True + if lhs.upper is not None and (lhs.upper < other_lower) == True: + return False + return None + + +@dispatch(Idx, Number) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = rhs + other_lower = rhs + + if lhs.lower is not None and (lhs.lower >= other_upper) == True: + return True + if lhs.upper is not None and (lhs.upper < other_lower) == True: + return False + return None + + +@dispatch(Number, Idx) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = lhs + other_lower = lhs + + if rhs.upper is not None and (rhs.upper <= other_lower) == True: + return True + if rhs.lower is not None and (rhs.lower > other_upper) == True: + return False + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..579e7c7a86c2a1f18ab889af32ce0053a729ff5f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tensor.py @@ -0,0 +1,5265 @@ +""" +This module defines tensors with abstract index notation. + +The abstract index notation has been first formalized by Penrose. + +Tensor indices are formal objects, with a tensor type; there is no +notion of index range, it is only possible to assign the dimension, +used to trace the Kronecker delta; the dimension can be a Symbol. + +The Einstein summation convention is used. +The covariant indices are indicated with a minus sign in front of the index. + +For instance the tensor ``t = p(a)*A(b,c)*q(-c)`` has the index ``c`` +contracted. + +A tensor expression ``t`` can be called; called with its +indices in sorted order it is equal to itself: +in the above example ``t(a, b) == t``; +one can call ``t`` with different indices; ``t(c, d) == p(c)*A(d,a)*q(-a)``. + +The contracted indices are dummy indices, internally they have no name, +the indices being represented by a graph-like structure. + +Tensors are put in canonical form using ``canon_bp``, which uses +the Butler-Portugal algorithm for canonicalization using the monoterm +symmetries of the tensors. + +If there is a (anti)symmetric metric, the indices can be raised and +lowered when the tensor is put in canonical form. +""" + +from __future__ import annotations +from typing import Any +from functools import reduce +from math import prod + +from abc import abstractmethod, ABC +from collections import defaultdict +import operator +import itertools + +from sympy.core.numbers import (Integer, Rational) +from sympy.combinatorics import Permutation +from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, \ + bsgs_direct_product, canonicalize, riemann_bsgs +from sympy.core import Basic, Expr, sympify, Add, Mul, S +from sympy.core.cache import clear_cache +from sympy.core.containers import Tuple, Dict +from sympy.core.function import WildFunction +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol, symbols, Wild +from sympy.core.sympify import CantSympify, _sympify +from sympy.core.operations import AssocOp +from sympy.external.gmpy import SYMPY_INTS +from sympy.matrices import eye +from sympy.utilities.exceptions import (sympy_deprecation_warning, + SymPyDeprecationWarning, + ignore_warnings) +from sympy.utilities.decorator import memoize_property, deprecated +from sympy.utilities.iterables import sift + + +def deprecate_data(): + sympy_deprecation_warning( + """ + The data attribute of TensorIndexType is deprecated. Use The + replace_with_arrays() method instead. + """, + deprecated_since_version="1.4", + active_deprecations_target="deprecated-tensorindextype-attrs", + stacklevel=4, + ) + +def deprecate_fun_eval(): + sympy_deprecation_warning( + """ + The Tensor.fun_eval() method is deprecated. Use + Tensor.substitute_indices() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensor-fun-eval", + stacklevel=4, + ) + + +def deprecate_call(): + sympy_deprecation_warning( + """ + Calling a tensor like Tensor(*indices) is deprecated. Use + Tensor.substitute_indices() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensor-fun-eval", + stacklevel=4, + ) + + +class _IndexStructure(CantSympify): + """ + This class handles the indices (free and dummy ones). It contains the + algorithms to manage the dummy indices replacements and contractions of + free indices under multiplications of tensor expressions, as well as stuff + related to canonicalization sorting, getting the permutation of the + expression and so on. It also includes tools to get the ``TensorIndex`` + objects corresponding to the given index structure. + """ + + def __init__(self, free, dum, index_types, indices, canon_bp=False): + self.free = free + self.dum = dum + self.index_types = index_types + self.indices = indices + self._ext_rank = len(self.free) + 2*len(self.dum) + self.dum.sort(key=lambda x: x[0]) + + @staticmethod + def from_indices(*indices): + """ + Create a new ``_IndexStructure`` object from a list of ``indices``. + + Explanation + =========== + + ``indices`` ``TensorIndex`` objects, the indices. Contractions are + detected upon construction. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, _IndexStructure + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2, m3 = tensor_indices('m0,m1,m2,m3', Lorentz) + >>> _IndexStructure.from_indices(m0, m1, -m1, m3) + _IndexStructure([(m0, 0), (m3, 3)], [(1, 2)], [Lorentz, Lorentz, Lorentz, Lorentz]) + """ + + free, dum = _IndexStructure._free_dum_from_indices(*indices) + index_types = [i.tensor_index_type for i in indices] + indices = _IndexStructure._replace_dummy_names(indices, free, dum) + return _IndexStructure(free, dum, index_types, indices) + + @staticmethod + def from_components_free_dum(components, free, dum): + index_types = [] + for component in components: + index_types.extend(component.index_types) + indices = _IndexStructure.generate_indices_from_free_dum_index_types(free, dum, index_types) + return _IndexStructure(free, dum, index_types, indices) + + @staticmethod + def _free_dum_from_indices(*indices): + """ + Convert ``indices`` into ``free``, ``dum`` for single component tensor. + + Explanation + =========== + + ``free`` list of tuples ``(index, pos, 0)``, + where ``pos`` is the position of index in + the list of indices formed by the component tensors + + ``dum`` list of tuples ``(pos_contr, pos_cov, 0, 0)`` + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, \ + _IndexStructure + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2, m3 = tensor_indices('m0,m1,m2,m3', Lorentz) + >>> _IndexStructure._free_dum_from_indices(m0, m1, -m1, m3) + ([(m0, 0), (m3, 3)], [(1, 2)]) + """ + n = len(indices) + if n == 1: + return [(indices[0], 0)], [] + + # find the positions of the free indices and of the dummy indices + free = [True]*len(indices) + index_dict = {} + dum = [] + for i, index in enumerate(indices): + name = index.name + typ = index.tensor_index_type + contr = index.is_up + if (name, typ) in index_dict: + # found a pair of dummy indices + is_contr, pos = index_dict[(name, typ)] + # check consistency and update free + if is_contr: + if contr: + raise ValueError('two equal contravariant indices in slots %d and %d' %(pos, i)) + else: + free[pos] = False + free[i] = False + else: + if contr: + free[pos] = False + free[i] = False + else: + raise ValueError('two equal covariant indices in slots %d and %d' %(pos, i)) + if contr: + dum.append((i, pos)) + else: + dum.append((pos, i)) + else: + index_dict[(name, typ)] = index.is_up, i + + free = [(index, i) for i, index in enumerate(indices) if free[i]] + free.sort() + return free, dum + + def get_indices(self): + """ + Get a list of indices, creating new tensor indices to complete dummy indices. + """ + return self.indices[:] + + @staticmethod + def generate_indices_from_free_dum_index_types(free, dum, index_types): + indices = [None]*(len(free)+2*len(dum)) + for idx, pos in free: + indices[pos] = idx + + generate_dummy_name = _IndexStructure._get_generator_for_dummy_indices(free) + for pos1, pos2 in dum: + typ1 = index_types[pos1] + indname = generate_dummy_name(typ1) + indices[pos1] = TensorIndex(indname, typ1, True) + indices[pos2] = TensorIndex(indname, typ1, False) + + return _IndexStructure._replace_dummy_names(indices, free, dum) + + @staticmethod + def _get_generator_for_dummy_indices(free): + cdt = defaultdict(int) + # if the free indices have names with dummy_name, start with an + # index higher than those for the dummy indices + # to avoid name collisions + for indx, ipos in free: + if indx.name.split('_')[0] == indx.tensor_index_type.dummy_name: + cdt[indx.tensor_index_type] = max(cdt[indx.tensor_index_type], int(indx.name.split('_')[1]) + 1) + + def dummy_name_gen(tensor_index_type): + nd = str(cdt[tensor_index_type]) + cdt[tensor_index_type] += 1 + return tensor_index_type.dummy_name + '_' + nd + + return dummy_name_gen + + @staticmethod + def _replace_dummy_names(indices, free, dum): + dum.sort(key=lambda x: x[0]) + new_indices = list(indices) + assert len(indices) == len(free) + 2*len(dum) + generate_dummy_name = _IndexStructure._get_generator_for_dummy_indices(free) + for ipos1, ipos2 in dum: + typ1 = new_indices[ipos1].tensor_index_type + indname = generate_dummy_name(typ1) + new_indices[ipos1] = TensorIndex(indname, typ1, True) + new_indices[ipos2] = TensorIndex(indname, typ1, False) + return new_indices + + def get_free_indices(self) -> list[TensorIndex]: + """ + Get a list of free indices. + """ + # get sorted indices according to their position: + free = sorted(self.free, key=lambda x: x[1]) + return [i[0] for i in free] + + def __str__(self): + return "_IndexStructure({}, {}, {})".format(self.free, self.dum, self.index_types) + + def __repr__(self): + return self.__str__() + + def _get_sorted_free_indices_for_canon(self): + sorted_free = self.free[:] + sorted_free.sort(key=lambda x: x[0]) + return sorted_free + + def _get_sorted_dum_indices_for_canon(self): + return sorted(self.dum, key=lambda x: x[0]) + + def _get_lexicographically_sorted_index_types(self): + permutation = self.indices_canon_args()[0] + index_types = [None]*self._ext_rank + for i, it in enumerate(self.index_types): + index_types[permutation(i)] = it + return index_types + + def _get_lexicographically_sorted_indices(self): + permutation = self.indices_canon_args()[0] + indices = [None]*self._ext_rank + for i, it in enumerate(self.indices): + indices[permutation(i)] = it + return indices + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns a ``_IndexStructure`` instance corresponding to the permutation ``g``. + + Explanation + =========== + + ``g`` permutation corresponding to the tensor in the representation + used in canonicalization + + ``is_canon_bp`` if True, then ``g`` is the permutation + corresponding to the canonical form of the tensor + """ + sorted_free = [i[0] for i in self._get_sorted_free_indices_for_canon()] + lex_index_types = self._get_lexicographically_sorted_index_types() + lex_indices = self._get_lexicographically_sorted_indices() + nfree = len(sorted_free) + rank = self._ext_rank + dum = [[None]*2 for i in range((rank - nfree)//2)] + free = [] + + index_types = [None]*rank + indices = [None]*rank + for i in range(rank): + gi = g[i] + index_types[i] = lex_index_types[gi] + indices[i] = lex_indices[gi] + if gi < nfree: + ind = sorted_free[gi] + assert index_types[i] == sorted_free[gi].tensor_index_type + free.append((ind, i)) + else: + j = gi - nfree + idum, cov = divmod(j, 2) + if cov: + dum[idum][1] = i + else: + dum[idum][0] = i + dum = [tuple(x) for x in dum] + + return _IndexStructure(free, dum, index_types, indices) + + def indices_canon_args(self): + """ + Returns ``(g, dummies, msym, v)``, the entries of ``canonicalize`` + + See ``canonicalize`` in ``tensor_can.py`` in combinatorics module. + """ + # to be called after sorted_components + from sympy.combinatorics.permutations import _af_new + n = self._ext_rank + g = [None]*n + [n, n+1] + + # Converts the symmetry of the metric into msym from .canonicalize() + # method in the combinatorics module + def metric_symmetry_to_msym(metric): + if metric is None: + return None + sym = metric.symmetry + if sym == TensorSymmetry.fully_symmetric(2): + return 0 + if sym == TensorSymmetry.fully_symmetric(-2): + return 1 + return None + + # ordered indices: first the free indices, ordered by types + # then the dummy indices, ordered by types and contravariant before + # covariant + # g[position in tensor] = position in ordered indices + for i, (indx, ipos) in enumerate(self._get_sorted_free_indices_for_canon()): + g[ipos] = i + pos = len(self.free) + j = len(self.free) + dummies = [] + prev = None + a = [] + msym = [] + for ipos1, ipos2 in self._get_sorted_dum_indices_for_canon(): + g[ipos1] = j + g[ipos2] = j + 1 + j += 2 + typ = self.index_types[ipos1] + if typ != prev: + if a: + dummies.append(a) + a = [pos, pos + 1] + prev = typ + msym.append(metric_symmetry_to_msym(typ.metric)) + else: + a.extend([pos, pos + 1]) + pos += 2 + if a: + dummies.append(a) + + return _af_new(g), dummies, msym + + +def components_canon_args(components): + numtyp = [] + prev = None + for t in components: + if t == prev: + numtyp[-1][1] += 1 + else: + prev = t + numtyp.append([prev, 1]) + v = [] + for h, n in numtyp: + if h.comm in (0, 1): + comm = h.comm + else: + comm = TensorManager.get_comm(h.comm, h.comm) + v.append((h.symmetry.base, h.symmetry.generators, n, comm)) + return v + + +class _TensorDataLazyEvaluator(CantSympify): + """ + EXPERIMENTAL: do not rely on this class, it may change without deprecation + warnings in future versions of SymPy. + + Explanation + =========== + + This object contains the logic to associate components data to a tensor + expression. Components data are set via the ``.data`` property of tensor + expressions, is stored inside this class as a mapping between the tensor + expression and the ``ndarray``. + + Computations are executed lazily: whereas the tensor expressions can have + contractions, tensor products, and additions, components data are not + computed until they are accessed by reading the ``.data`` property + associated to the tensor expression. + """ + _substitutions_dict: dict[Any, Any] = {} + _substitutions_dict_tensmul: dict[Any, Any] = {} + + def __getitem__(self, key): + dat = self._get(key) + if dat is None: + return None + + from .array import NDimArray + if not isinstance(dat, NDimArray): + return dat + + if dat.rank() == 0: + return dat[()] + elif dat.rank() == 1 and len(dat) == 1: + return dat[0] + return dat + + def _get(self, key): + """ + Retrieve ``data`` associated with ``key``. + + Explanation + =========== + + This algorithm looks into ``self._substitutions_dict`` for all + ``TensorHead`` in the ``TensExpr`` (or just ``TensorHead`` if key is a + TensorHead instance). It reconstructs the components data that the + tensor expression should have by performing on components data the + operations that correspond to the abstract tensor operations applied. + + Metric tensor is handled in a different manner: it is pre-computed in + ``self._substitutions_dict_tensmul``. + """ + if key in self._substitutions_dict: + return self._substitutions_dict[key] + + if isinstance(key, TensorHead): + return None + + if isinstance(key, Tensor): + # special case to handle metrics. Metric tensors cannot be + # constructed through contraction by the metric, their + # components show if they are a matrix or its inverse. + signature = tuple([i.is_up for i in key.get_indices()]) + srch = (key.component,) + signature + if srch in self._substitutions_dict_tensmul: + return self._substitutions_dict_tensmul[srch] + array_list = [self.data_from_tensor(key)] + return self.data_contract_dum(array_list, key.dum, key.ext_rank) + + if isinstance(key, TensMul): + tensmul_args = key.args + if len(tensmul_args) == 1 and len(tensmul_args[0].components) == 1: + # special case to handle metrics. Metric tensors cannot be + # constructed through contraction by the metric, their + # components show if they are a matrix or its inverse. + signature = tuple([i.is_up for i in tensmul_args[0].get_indices()]) + srch = (tensmul_args[0].components[0],) + signature + if srch in self._substitutions_dict_tensmul: + return self._substitutions_dict_tensmul[srch] + #data_list = [self.data_from_tensor(i) for i in tensmul_args if isinstance(i, TensExpr)] + data_list = [self.data_from_tensor(i) if isinstance(i, Tensor) else i.data for i in tensmul_args if isinstance(i, TensExpr)] + coeff = prod([i for i in tensmul_args if not isinstance(i, TensExpr)]) + if all(i is None for i in data_list): + return None + if any(i is None for i in data_list): + raise ValueError("Mixing tensors with associated components "\ + "data with tensors without components data") + data_result = self.data_contract_dum(data_list, key.dum, key.ext_rank) + return coeff*data_result + + if isinstance(key, TensAdd): + data_list = [] + free_args_list = [] + for arg in key.args: + if isinstance(arg, TensExpr): + data_list.append(arg.data) + free_args_list.append([x[0] for x in arg.free]) + else: + data_list.append(arg) + free_args_list.append([]) + if all(i is None for i in data_list): + return None + if any(i is None for i in data_list): + raise ValueError("Mixing tensors with associated components "\ + "data with tensors without components data") + + sum_list = [] + from .array import permutedims + for data, free_args in zip(data_list, free_args_list): + if len(free_args) < 2: + sum_list.append(data) + else: + free_args_pos = {y: x for x, y in enumerate(free_args)} + axes = [free_args_pos[arg] for arg in key.free_args] + sum_list.append(permutedims(data, axes)) + return reduce(lambda x, y: x+y, sum_list) + + return None + + @staticmethod + def data_contract_dum(ndarray_list, dum, ext_rank): + from .array import tensorproduct, tensorcontraction, MutableDenseNDimArray + arrays = list(map(MutableDenseNDimArray, ndarray_list)) + prodarr = tensorproduct(*arrays) + return tensorcontraction(prodarr, *dum) + + def data_tensorhead_from_tensmul(self, data, tensmul, tensorhead): + """ + This method is used when assigning components data to a ``TensMul`` + object, it converts components data to a fully contravariant ndarray, + which is then stored according to the ``TensorHead`` key. + """ + if data is None: + return None + + return self._correct_signature_from_indices( + data, + tensmul.get_indices(), + tensmul.free, + tensmul.dum, + True) + + def data_from_tensor(self, tensor): + """ + This method corrects the components data to the right signature + (covariant/contravariant) using the metric associated with each + ``TensorIndexType``. + """ + tensorhead = tensor.component + + if tensorhead.data is None: + return None + + return self._correct_signature_from_indices( + tensorhead.data, + tensor.get_indices(), + tensor.free, + tensor.dum) + + def _assign_data_to_tensor_expr(self, key, data): + if isinstance(key, TensAdd): + raise ValueError('cannot assign data to TensAdd') + # here it is assumed that `key` is a `TensMul` instance. + if len(key.components) != 1: + raise ValueError('cannot assign data to TensMul with multiple components') + tensorhead = key.components[0] + newdata = self.data_tensorhead_from_tensmul(data, key, tensorhead) + return tensorhead, newdata + + def _check_permutations_on_data(self, tens, data): + from .array import permutedims + from .array.arrayop import Flatten + + if isinstance(tens, TensorHead): + rank = tens.rank + generators = tens.symmetry.generators + elif isinstance(tens, Tensor): + rank = tens.rank + generators = tens.components[0].symmetry.generators + elif isinstance(tens, TensorIndexType): + rank = tens.metric.rank + generators = tens.metric.symmetry.generators + + # Every generator is a permutation, check that by permuting the array + # by that permutation, the array will be the same, except for a + # possible sign change if the permutation admits it. + for gener in generators: + sign_change = +1 if (gener(rank) == rank) else -1 + data_swapped = data + last_data = data + permute_axes = list(map(gener, range(rank))) + # the order of a permutation is the number of times to get the + # identity by applying that permutation. + for i in range(gener.order()-1): + data_swapped = permutedims(data_swapped, permute_axes) + # if any value in the difference array is non-zero, raise an error: + if any(Flatten(last_data - sign_change*data_swapped)): + raise ValueError("Component data symmetry structure error") + last_data = data_swapped + + def __setitem__(self, key, value): + """ + Set the components data of a tensor object/expression. + + Explanation + =========== + + Components data are transformed to the all-contravariant form and stored + with the corresponding ``TensorHead`` object. If a ``TensorHead`` object + cannot be uniquely identified, it will raise an error. + """ + data = _TensorDataLazyEvaluator.parse_data(value) + self._check_permutations_on_data(key, data) + + # TensorHead and TensorIndexType can be assigned data directly, while + # TensMul must first convert data to a fully contravariant form, and + # assign it to its corresponding TensorHead single component. + if not isinstance(key, (TensorHead, TensorIndexType)): + key, data = self._assign_data_to_tensor_expr(key, data) + + if isinstance(key, TensorHead): + for dim, indextype in zip(data.shape, key.index_types): + if indextype.data is None: + raise ValueError("index type {} has no components data"\ + " associated (needed to raise/lower index)".format(indextype)) + if not indextype.dim.is_number: + continue + if dim != indextype.dim: + raise ValueError("wrong dimension of ndarray") + self._substitutions_dict[key] = data + + def __delitem__(self, key): + del self._substitutions_dict[key] + + def __contains__(self, key): + return key in self._substitutions_dict + + def add_metric_data(self, metric, data): + """ + Assign data to the ``metric`` tensor. The metric tensor behaves in an + anomalous way when raising and lowering indices. + + Explanation + =========== + + A fully covariant metric is the inverse transpose of the fully + contravariant metric (it is meant matrix inverse). If the metric is + symmetric, the transpose is not necessary and mixed + covariant/contravariant metrics are Kronecker deltas. + """ + # hard assignment, data should not be added to `TensorHead` for metric: + # the problem with `TensorHead` is that the metric is anomalous, i.e. + # raising and lowering the index means considering the metric or its + # inverse, this is not the case for other tensors. + self._substitutions_dict_tensmul[metric, True, True] = data + inverse_transpose = self.inverse_transpose_matrix(data) + # in symmetric spaces, the transpose is the same as the original matrix, + # the full covariant metric tensor is the inverse transpose, so this + # code will be able to handle non-symmetric metrics. + self._substitutions_dict_tensmul[metric, False, False] = inverse_transpose + # now mixed cases, these are identical to the unit matrix if the metric + # is symmetric. + m = data.tomatrix() + invt = inverse_transpose.tomatrix() + self._substitutions_dict_tensmul[metric, True, False] = m * invt + self._substitutions_dict_tensmul[metric, False, True] = invt * m + + @staticmethod + def _flip_index_by_metric(data, metric, pos): + from .array import tensorproduct, tensorcontraction + + mdim = metric.rank() + ddim = data.rank() + + if pos == 0: + data = tensorcontraction( + tensorproduct( + metric, + data + ), + (1, mdim+pos) + ) + else: + data = tensorcontraction( + tensorproduct( + data, + metric + ), + (pos, ddim) + ) + return data + + @staticmethod + def inverse_matrix(ndarray): + m = ndarray.tomatrix().inv() + return _TensorDataLazyEvaluator.parse_data(m) + + @staticmethod + def inverse_transpose_matrix(ndarray): + m = ndarray.tomatrix().inv().T + return _TensorDataLazyEvaluator.parse_data(m) + + @staticmethod + def _correct_signature_from_indices(data, indices, free, dum, inverse=False): + """ + Utility function to correct the values inside the components data + ndarray according to whether indices are covariant or contravariant. + + It uses the metric matrix to lower values of covariant indices. + """ + # change the ndarray values according covariantness/contravariantness of the indices + # use the metric + for i, indx in enumerate(indices): + if not indx.is_up and not inverse: + data = _TensorDataLazyEvaluator._flip_index_by_metric(data, indx.tensor_index_type.data, i) + elif not indx.is_up and inverse: + data = _TensorDataLazyEvaluator._flip_index_by_metric( + data, + _TensorDataLazyEvaluator.inverse_matrix(indx.tensor_index_type.data), + i + ) + return data + + @staticmethod + def _sort_data_axes(old, new): + from .array import permutedims + + new_data = old.data.copy() + + old_free = [i[0] for i in old.free] + new_free = [i[0] for i in new.free] + + for i in range(len(new_free)): + for j in range(i, len(old_free)): + if old_free[j] == new_free[i]: + old_free[i], old_free[j] = old_free[j], old_free[i] + new_data = permutedims(new_data, (i, j)) + break + return new_data + + @staticmethod + def add_rearrange_tensmul_parts(new_tensmul, old_tensmul): + def sorted_compo(): + return _TensorDataLazyEvaluator._sort_data_axes(old_tensmul, new_tensmul) + + _TensorDataLazyEvaluator._substitutions_dict[new_tensmul] = sorted_compo() + + @staticmethod + def parse_data(data): + """ + Transform ``data`` to array. The parameter ``data`` may + contain data in various formats, e.g. nested lists, SymPy ``Matrix``, + and so on. + + Examples + ======== + + >>> from sympy.tensor.tensor import _TensorDataLazyEvaluator + >>> _TensorDataLazyEvaluator.parse_data([1, 3, -6, 12]) + [1, 3, -6, 12] + + >>> _TensorDataLazyEvaluator.parse_data([[1, 2], [4, 7]]) + [[1, 2], [4, 7]] + """ + from .array import MutableDenseNDimArray + + if not isinstance(data, MutableDenseNDimArray): + if len(data) == 2 and hasattr(data[0], '__call__'): + data = MutableDenseNDimArray(data[0], data[1]) + else: + data = MutableDenseNDimArray(data) + return data + +_tensor_data_substitution_dict = _TensorDataLazyEvaluator() + + +class _TensorManager: + """ + Class to manage tensor properties. + + Notes + ===== + + Tensors belong to tensor commutation groups; each group has a label + ``comm``; there are predefined labels: + + ``0`` tensors commuting with any other tensor + + ``1`` tensors anticommuting among themselves + + ``2`` tensors not commuting, apart with those with ``comm=0`` + + Other groups can be defined using ``set_comm``; tensors in those + groups commute with those with ``comm=0``; by default they + do not commute with any other group. + """ + def __init__(self): + self._comm_init() + + def _comm_init(self): + self._comm = [{} for i in range(3)] + for i in range(3): + self._comm[0][i] = 0 + self._comm[i][0] = 0 + self._comm[1][1] = 1 + self._comm[2][1] = None + self._comm[1][2] = None + self._comm_symbols2i = {0:0, 1:1, 2:2} + self._comm_i2symbol = {0:0, 1:1, 2:2} + + @property + def comm(self): + return self._comm + + def comm_symbols2i(self, i): + """ + Get the commutation group number corresponding to ``i``. + + ``i`` can be a symbol or a number or a string. + + If ``i`` is not already defined its commutation group number + is set. + """ + if i not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[n][0] = 0 + self._comm[0][n] = 0 + self._comm_symbols2i[i] = n + self._comm_i2symbol[n] = i + return n + return self._comm_symbols2i[i] + + def comm_i2symbol(self, i): + """ + Returns the symbol corresponding to the commutation group number. + """ + return self._comm_i2symbol[i] + + def set_comm(self, i, j, c): + """ + Set the commutation parameter ``c`` for commutation groups ``i, j``. + + Parameters + ========== + + i, j : symbols representing commutation groups + + c : group commutation number + + Notes + ===== + + ``i, j`` can be symbols, strings or numbers, + apart from ``0, 1`` and ``2`` which are reserved respectively + for commuting, anticommuting tensors and tensors not commuting + with any other group apart with the commuting tensors. + For the remaining cases, use this method to set the commutation rules; + by default ``c=None``. + + The group commutation number ``c`` is assigned in correspondence + to the group commutation symbols; it can be + + 0 commuting + + 1 anticommuting + + None no commutation property + + Examples + ======== + + ``G`` and ``GH`` do not commute with themselves and commute with + each other; A is commuting. + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, TensorManager, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz') + >>> i0,i1,i2,i3,i4 = tensor_indices('i0:5', Lorentz) + >>> A = TensorHead('A', [Lorentz]) + >>> G = TensorHead('G', [Lorentz], TensorSymmetry.no_symmetry(1), 'Gcomm') + >>> GH = TensorHead('GH', [Lorentz], TensorSymmetry.no_symmetry(1), 'GHcomm') + >>> TensorManager.set_comm('Gcomm', 'GHcomm', 0) + >>> (GH(i1)*G(i0)).canon_bp() + G(i0)*GH(i1) + >>> (G(i1)*G(i0)).canon_bp() + G(i1)*G(i0) + >>> (G(i1)*A(i0)).canon_bp() + A(i0)*G(i1) + """ + if c not in (0, 1, None): + raise ValueError('`c` can assume only the values 0, 1 or None') + + i = sympify(i) + j = sympify(j) + + if i not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[n][0] = 0 + self._comm[0][n] = 0 + self._comm_symbols2i[i] = n + self._comm_i2symbol[n] = i + if j not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[0][n] = 0 + self._comm[n][0] = 0 + self._comm_symbols2i[j] = n + self._comm_i2symbol[n] = j + ni = self._comm_symbols2i[i] + nj = self._comm_symbols2i[j] + self._comm[ni][nj] = c + self._comm[nj][ni] = c + + """ + Cached sympy functions (e.g. expand) may have cached the results of + expressions involving tensors, but those results may not be valid after + changing the commutation properties. To stay on the safe side, we clear + the cache of all functions. + """ + clear_cache() + + def set_comms(self, *args): + """ + Set the commutation group numbers ``c`` for symbols ``i, j``. + + Parameters + ========== + + args : sequence of ``(i, j, c)`` + """ + for i, j, c in args: + self.set_comm(i, j, c) + + def get_comm(self, i, j): + """ + Return the commutation parameter for commutation group numbers ``i, j`` + + see ``_TensorManager.set_comm`` + """ + return self._comm[i].get(j, 0 if i == 0 or j == 0 else None) + + def clear(self): + """ + Clear the TensorManager. + """ + self._comm_init() + + +TensorManager = _TensorManager() + + +class TensorIndexType(Basic): + """ + A TensorIndexType is characterized by its name and its metric. + + Parameters + ========== + + name : name of the tensor type + dummy_name : name of the head of dummy indices + dim : dimension, it can be a symbol or an integer or ``None`` + eps_dim : dimension of the epsilon tensor + metric_symmetry : integer that denotes metric symmetry or ``None`` for no metric + metric_name : string with the name of the metric tensor + + Attributes + ========== + + ``metric`` : the metric tensor + ``delta`` : ``Kronecker delta`` + ``epsilon`` : the ``Levi-Civita epsilon`` tensor + ``data`` : (deprecated) a property to add ``ndarray`` values, to work in a specified basis. + + Notes + ===== + + The possible values of the ``metric_symmetry`` parameter are: + + ``1`` : metric tensor is fully symmetric + ``0`` : metric tensor possesses no index symmetry + ``-1`` : metric tensor is fully antisymmetric + ``None``: there is no metric tensor (metric equals to ``None``) + + The metric is assumed to be symmetric by default. It can also be set + to a custom tensor by the ``.set_metric()`` method. + + If there is a metric the metric is used to raise and lower indices. + + In the case of non-symmetric metric, the following raising and + lowering conventions will be adopted: + + ``psi(a) = g(a, b)*psi(-b); chi(-a) = chi(b)*g(-b, -a)`` + + From these it is easy to find: + + ``g(-a, b) = delta(-a, b)`` + + where ``delta(-a, b) = delta(b, -a)`` is the ``Kronecker delta`` + (see ``TensorIndex`` for the conventions on indices). + For antisymmetric metrics there is also the following equality: + + ``g(a, -b) = -delta(a, -b)`` + + If there is no metric it is not possible to raise or lower indices; + e.g. the index of the defining representation of ``SU(N)`` + is 'covariant' and the conjugate representation is + 'contravariant'; for ``N > 2`` they are linearly independent. + + ``eps_dim`` is by default equal to ``dim``, if the latter is an integer; + else it can be assigned (for use in naive dimensional regularization); + if ``eps_dim`` is not an integer ``epsilon`` is ``None``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> Lorentz.metric + metric(Lorentz,Lorentz) + """ + + def __new__(cls, name, dummy_name=None, dim=None, eps_dim=None, + metric_symmetry=1, metric_name='metric', **kwargs): + if 'dummy_fmt' in kwargs: + dummy_fmt = kwargs['dummy_fmt'] + sympy_deprecation_warning( + f""" + The dummy_fmt keyword to TensorIndexType is deprecated. Use + dummy_name={dummy_fmt} instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-dummy-fmt", + ) + dummy_name = dummy_fmt + + if isinstance(name, str): + name = Symbol(name) + + if dummy_name is None: + dummy_name = str(name)[0] + if isinstance(dummy_name, str): + dummy_name = Symbol(dummy_name) + + if dim is None: + dim = Symbol("dim_" + dummy_name.name) + else: + dim = sympify(dim) + + if eps_dim is None: + eps_dim = dim + else: + eps_dim = sympify(eps_dim) + + metric_symmetry = sympify(metric_symmetry) + + if isinstance(metric_name, str): + metric_name = Symbol(metric_name) + + if 'metric' in kwargs: + SymPyDeprecationWarning( + """ + The 'metric' keyword argument to TensorIndexType is + deprecated. Use the 'metric_symmetry' keyword argument or the + TensorIndexType.set_metric() method instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-metric", + ) + metric = kwargs.get('metric') + if metric is not None: + if metric in (True, False, 0, 1): + metric_name = 'metric' + #metric_antisym = metric + else: + metric_name = metric.name + #metric_antisym = metric.antisym + + if metric: + metric_symmetry = -1 + else: + metric_symmetry = 1 + + obj = Basic.__new__(cls, name, dummy_name, dim, eps_dim, + metric_symmetry, metric_name) + + obj._autogenerated = [] + return obj + + @property + def name(self): + return self.args[0].name + + @property + def dummy_name(self): + return self.args[1].name + + @property + def dim(self): + return self.args[2] + + @property + def eps_dim(self): + return self.args[3] + + @memoize_property + def metric(self): + metric_symmetry = self.args[4] + metric_name = self.args[5] + if metric_symmetry is None: + return None + + if metric_symmetry == 0: + symmetry = TensorSymmetry.no_symmetry(2) + elif metric_symmetry == 1: + symmetry = TensorSymmetry.fully_symmetric(2) + elif metric_symmetry == -1: + symmetry = TensorSymmetry.fully_symmetric(-2) + + return TensorHead(metric_name, [self]*2, symmetry) + + @memoize_property + def delta(self): + return TensorHead('KD', [self]*2, TensorSymmetry.fully_symmetric(2)) + + @memoize_property + def epsilon(self): + if not isinstance(self.eps_dim, (SYMPY_INTS, Integer)): + return None + symmetry = TensorSymmetry.fully_symmetric(-self.eps_dim) + return TensorHead('Eps', [self]*self.eps_dim, symmetry) + + def set_metric(self, tensor): + self._metric = tensor + + def __lt__(self, other): + return self.name < other.name + + def __str__(self): + return self.name + + __repr__ = __str__ + + # Everything below this line is deprecated + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + # This assignment is a bit controversial, should metric components be assigned + # to the metric only or also to the TensorIndexType object? The advantage here + # is the ability to assign a 1D array and transform it to a 2D diagonal array. + from .array import MutableDenseNDimArray + + data = _TensorDataLazyEvaluator.parse_data(data) + if data.rank() > 2: + raise ValueError("data have to be of rank 1 (diagonal metric) or 2.") + if data.rank() == 1: + if self.dim.is_number: + nda_dim = data.shape[0] + if nda_dim != self.dim: + raise ValueError("Dimension mismatch") + + dim = data.shape[0] + newndarray = MutableDenseNDimArray.zeros(dim, dim) + for i, val in enumerate(data): + newndarray[i, i] = val + data = newndarray + dim1, dim2 = data.shape + if dim1 != dim2: + raise ValueError("Non-square matrix tensor.") + if self.dim.is_number: + if self.dim != dim1: + raise ValueError("Dimension mismatch") + _tensor_data_substitution_dict[self] = data + _tensor_data_substitution_dict.add_metric_data(self.metric, data) + with ignore_warnings(SymPyDeprecationWarning): + delta = self.get_kronecker_delta() + i1 = TensorIndex('i1', self) + i2 = TensorIndex('i2', self) + with ignore_warnings(SymPyDeprecationWarning): + delta(i1, -i2).data = _TensorDataLazyEvaluator.parse_data(eye(dim1)) + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + if self.metric in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self.metric] + + @deprecated( + """ + The TensorIndexType.get_kronecker_delta() method is deprecated. Use + the TensorIndexType.delta attribute instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-methods", + ) + def get_kronecker_delta(self): + sym2 = TensorSymmetry(get_symmetric_group_sgs(2)) + delta = TensorHead('KD', [self]*2, sym2) + return delta + + @deprecated( + """ + The TensorIndexType.get_epsilon() method is deprecated. Use + the TensorIndexType.epsilon attribute instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-methods", + ) + def get_epsilon(self): + if not isinstance(self._eps_dim, (SYMPY_INTS, Integer)): + return None + sym = TensorSymmetry(get_symmetric_group_sgs(self._eps_dim, 1)) + epsilon = TensorHead('Eps', [self]*self._eps_dim, sym) + return epsilon + + def _components_data_full_destroy(self): + """ + EXPERIMENTAL: do not rely on this API method. + + This destroys components data associated to the ``TensorIndexType``, if + any, specifically: + + * metric tensor data + * Kronecker tensor data + """ + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def delete_tensmul_data(key): + if key in _tensor_data_substitution_dict._substitutions_dict_tensmul: + del _tensor_data_substitution_dict._substitutions_dict_tensmul[key] + + # delete metric data: + delete_tensmul_data((self.metric, True, True)) + delete_tensmul_data((self.metric, True, False)) + delete_tensmul_data((self.metric, False, True)) + delete_tensmul_data((self.metric, False, False)) + + # delete delta tensor data: + delta = self.get_kronecker_delta() + if delta in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[delta] + + +class TensorIndex(Basic): + """ + Represents a tensor index + + Parameters + ========== + + name : name of the index, or ``True`` if you want it to be automatically assigned + tensor_index_type : ``TensorIndexType`` of the index + is_up : flag for contravariant index (is_up=True by default) + + Attributes + ========== + + ``name`` + ``tensor_index_type`` + ``is_up`` + + Notes + ===== + + Tensor indices are contracted with the Einstein summation convention. + + An index can be in contravariant or in covariant form; in the latter + case it is represented prepending a ``-`` to the index name. Adding + ``-`` to a covariant (is_up=False) index makes it contravariant. + + Dummy indices have a name with head given by + ``tensor_inde_type.dummy_name`` with underscore and a number. + + Similar to ``symbols`` multiple contravariant indices can be created + at once using ``tensor_indices(s, typ)``, where ``s`` is a string + of names. + + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, TensorIndex, TensorHead, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> mu = TensorIndex('mu', Lorentz, is_up=False) + >>> nu, rho = tensor_indices('nu, rho', Lorentz) + >>> A = TensorHead('A', [Lorentz, Lorentz]) + >>> A(mu, nu) + A(-mu, nu) + >>> A(-mu, -rho) + A(mu, -rho) + >>> A(mu, -mu) + A(-L_0, L_0) + """ + def __new__(cls, name, tensor_index_type, is_up=True): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + elif name is True: + name = "_i{}".format(len(tensor_index_type._autogenerated)) + name_symbol = Symbol(name) + tensor_index_type._autogenerated.append(name_symbol) + else: + raise ValueError("invalid name") + + is_up = sympify(is_up) + return Basic.__new__(cls, name_symbol, tensor_index_type, is_up) + + @property + def name(self): + return self.args[0].name + + @property + def tensor_index_type(self): + return self.args[1] + + @property + def is_up(self): + return self.args[2] + + def _print(self): + s = self.name + if not self.is_up: + s = '-%s' % s + return s + + def __lt__(self, other): + return ((self.tensor_index_type, self.name) < + (other.tensor_index_type, other.name)) + + def __neg__(self): + t1 = TensorIndex(self.name, self.tensor_index_type, + (not self.is_up)) + return t1 + + +def tensor_indices(s, typ): + """ + Returns list of tensor indices given their names and their types. + + Parameters + ========== + + s : string of comma separated names of indices + + typ : ``TensorIndexType`` of the indices + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b, c, d = tensor_indices('a,b,c,d', Lorentz) + """ + if isinstance(s, str): + a = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + + tilist = [TensorIndex(i, typ) for i in a] + if len(tilist) == 1: + return tilist[0] + return tilist + + +class TensorSymmetry(Basic): + """ + Monoterm symmetry of a tensor (i.e. any symmetric or anti-symmetric + index permutation). For the relevant terminology see ``tensor_can.py`` + section of the combinatorics module. + + Parameters + ========== + + bsgs : tuple ``(base, sgs)`` BSGS of the symmetry of the tensor + + Attributes + ========== + + ``base`` : base of the BSGS + ``generators`` : generators of the BSGS + ``rank`` : rank of the tensor + + Notes + ===== + + A tensor can have an arbitrary monoterm symmetry provided by its BSGS. + Multiterm symmetries, like the cyclic symmetry of the Riemann tensor + (i.e., Bianchi identity), are not covered. See combinatorics module for + information on how to generate BSGS for a general index permutation group. + Simple symmetries can be generated using built-in methods. + + See Also + ======== + + sympy.combinatorics.tensor_can.get_symmetric_group_sgs + + Examples + ======== + + Define a symmetric tensor of rank 2 + + >>> from sympy.tensor.tensor import TensorIndexType, TensorSymmetry, get_symmetric_group_sgs, TensorHead + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> sym = TensorSymmetry(get_symmetric_group_sgs(2)) + >>> T = TensorHead('T', [Lorentz]*2, sym) + + Note, that the same can also be done using built-in TensorSymmetry methods + + >>> sym2 = TensorSymmetry.fully_symmetric(2) + >>> sym == sym2 + True + """ + def __new__(cls, *args, **kw_args): + if len(args) == 1: + base, generators = args[0] + elif len(args) == 2: + base, generators = args + else: + raise TypeError("bsgs required, either two separate parameters or one tuple") + + if not isinstance(base, Tuple): + base = Tuple(*base) + if not isinstance(generators, Tuple): + generators = Tuple(*generators) + + return Basic.__new__(cls, base, generators, **kw_args) + + @property + def base(self): + return self.args[0] + + @property + def generators(self): + return self.args[1] + + @property + def rank(self): + return self.generators[0].size - 2 + + @classmethod + def fully_symmetric(cls, rank): + """ + Returns a fully symmetric (antisymmetric if ``rank``<0) + TensorSymmetry object for ``abs(rank)`` indices. + """ + if rank > 0: + bsgs = get_symmetric_group_sgs(rank, False) + elif rank < 0: + bsgs = get_symmetric_group_sgs(-rank, True) + elif rank == 0: + bsgs = ([], [Permutation(1)]) + return TensorSymmetry(bsgs) + + @classmethod + def direct_product(cls, *args): + """ + Returns a TensorSymmetry object that is being a direct product of + fully (anti-)symmetric index permutation groups. + + Notes + ===== + + Some examples for different values of ``(*args)``: + ``(1)`` vector, equivalent to ``TensorSymmetry.fully_symmetric(1)`` + ``(2)`` tensor with 2 symmetric indices, equivalent to ``.fully_symmetric(2)`` + ``(-2)`` tensor with 2 antisymmetric indices, equivalent to ``.fully_symmetric(-2)`` + ``(2, -2)`` tensor with the first 2 indices commuting and the last 2 anticommuting + ``(1, 1, 1)`` tensor with 3 indices without any symmetry + """ + base, sgs = [], [Permutation(1)] + for arg in args: + if arg > 0: + bsgs2 = get_symmetric_group_sgs(arg, False) + elif arg < 0: + bsgs2 = get_symmetric_group_sgs(-arg, True) + else: + continue + base, sgs = bsgs_direct_product(base, sgs, *bsgs2) + + return TensorSymmetry(base, sgs) + + @classmethod + def riemann(cls): + """ + Returns a monotorem symmetry of the Riemann tensor + """ + return TensorSymmetry(riemann_bsgs) + + @classmethod + def no_symmetry(cls, rank): + """ + TensorSymmetry object for ``rank`` indices with no symmetry + """ + return TensorSymmetry([], [Permutation(rank+1)]) + + +@deprecated( + """ + The tensorsymmetry() function is deprecated. Use the TensorSymmetry + constructor instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorsymmetry", +) +def tensorsymmetry(*args): + """ + Returns a ``TensorSymmetry`` object. This method is deprecated, use + ``TensorSymmetry.direct_product()`` or ``.riemann()`` instead. + + Explanation + =========== + + One can represent a tensor with any monoterm slot symmetry group + using a BSGS. + + ``args`` can be a BSGS + ``args[0]`` base + ``args[1]`` sgs + + Usually tensors are in (direct products of) representations + of the symmetric group; + ``args`` can be a list of lists representing the shapes of Young tableaux + + Notes + ===== + + For instance: + ``[[1]]`` vector + ``[[1]*n]`` symmetric tensor of rank ``n`` + ``[[n]]`` antisymmetric tensor of rank ``n`` + ``[[2, 2]]`` monoterm slot symmetry of the Riemann tensor + ``[[1],[1]]`` vector*vector + ``[[2],[1],[1]`` (antisymmetric tensor)*vector*vector + + Notice that with the shape ``[2, 2]`` we associate only the monoterm + symmetries of the Riemann tensor; this is an abuse of notation, + since the shape ``[2, 2]`` corresponds usually to the irreducible + representation characterized by the monoterm symmetries and by the + cyclic symmetry. + """ + from sympy.combinatorics import Permutation + + def tableau2bsgs(a): + if len(a) == 1: + # antisymmetric vector + n = a[0] + bsgs = get_symmetric_group_sgs(n, 1) + else: + if all(x == 1 for x in a): + # symmetric vector + n = len(a) + bsgs = get_symmetric_group_sgs(n) + elif a == [2, 2]: + bsgs = riemann_bsgs + else: + raise NotImplementedError + return bsgs + + if not args: + return TensorSymmetry(Tuple(), Tuple(Permutation(1))) + + if len(args) == 2 and isinstance(args[1][0], Permutation): + return TensorSymmetry(args) + base, sgs = tableau2bsgs(args[0]) + for a in args[1:]: + basex, sgsx = tableau2bsgs(a) + base, sgs = bsgs_direct_product(base, sgs, basex, sgsx) + return TensorSymmetry(Tuple(base, sgs)) + +@deprecated( + "TensorType is deprecated. Use tensor_heads() instead.", + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensortype", +) +class TensorType(Basic): + """ + Class of tensor types. Deprecated, use tensor_heads() instead. + + Parameters + ========== + + index_types : list of ``TensorIndexType`` of the tensor indices + symmetry : ``TensorSymmetry`` of the tensor + + Attributes + ========== + + ``index_types`` + ``symmetry`` + ``types`` : list of ``TensorIndexType`` without repetitions + """ + is_commutative = False + + def __new__(cls, index_types, symmetry, **kw_args): + assert symmetry.rank == len(index_types) + obj = Basic.__new__(cls, Tuple(*index_types), symmetry, **kw_args) + return obj + + @property + def index_types(self): + return self.args[0] + + @property + def symmetry(self): + return self.args[1] + + @property + def types(self): + return sorted(set(self.index_types), key=lambda x: x.name) + + def __str__(self): + return 'TensorType(%s)' % ([str(x) for x in self.index_types]) + + def __call__(self, s, comm=0): + """ + Return a TensorHead object or a list of TensorHead objects. + + Parameters + ========== + + s : name or string of names. + + comm : Commutation group. + + see ``_TensorManager.set_comm`` + """ + if isinstance(s, str): + names = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + if len(names) == 1: + return TensorHead(names[0], self.index_types, self.symmetry, comm) + else: + return [TensorHead(name, self.index_types, self.symmetry, comm) for name in names] + + +@deprecated( + """ + The tensorhead() function is deprecated. Use tensor_heads() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorhead", +) +def tensorhead(name, typ, sym=None, comm=0): + """ + Function generating tensorhead(s). This method is deprecated, + use TensorHead constructor or tensor_heads() instead. + + Parameters + ========== + + name : name or sequence of names (as in ``symbols``) + + typ : index types + + sym : same as ``*args`` in ``tensorsymmetry`` + + comm : commutation group number + see ``_TensorManager.set_comm`` + """ + if sym is None: + sym = [[1] for i in range(len(typ))] + with ignore_warnings(SymPyDeprecationWarning): + sym = tensorsymmetry(*sym) + return TensorHead(name, typ, sym, comm) + + +class TensorHead(Basic): + """ + Tensor head of the tensor. + + Parameters + ========== + + name : name of the tensor + index_types : list of TensorIndexType + symmetry : TensorSymmetry of the tensor + comm : commutation group number + + Attributes + ========== + + ``name`` + ``index_types`` + ``rank`` : total number of indices + ``symmetry`` + ``comm`` : commutation group + + Notes + ===== + + Similar to ``symbols`` multiple TensorHeads can be created using + ``tensorhead(s, typ, sym=None, comm=0)`` function, where ``s`` + is the string of names and ``sym`` is the monoterm tensor symmetry + (see ``tensorsymmetry``). + + A ``TensorHead`` belongs to a commutation group, defined by a + symbol on number ``comm`` (see ``_TensorManager.set_comm``); + tensors in a commutation group have the same commutation properties; + by default ``comm`` is ``0``, the group of the commuting tensors. + + Examples + ======== + + Define a fully antisymmetric tensor of rank 2: + + >>> from sympy.tensor.tensor import TensorIndexType, TensorHead, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> asym2 = TensorSymmetry.fully_symmetric(-2) + >>> A = TensorHead('A', [Lorentz, Lorentz], asym2) + + Examples with ndarray values, the components data assigned to the + ``TensorHead`` object are assumed to be in a fully-contravariant + representation. In case it is necessary to assign components data which + represents the values of a non-fully covariant tensor, see the other + examples. + + >>> from sympy.tensor.tensor import tensor_indices + >>> from sympy import diag + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i0, i1 = tensor_indices('i0:2', Lorentz) + + Specify a replacement dictionary to keep track of the arrays to use for + replacements in the tensorial expression. The ``TensorIndexType`` is + associated to the metric used for contractions (in fully covariant form): + + >>> repl = {Lorentz: diag(1, -1, -1, -1)} + + Let's see some examples of working with components with the electromagnetic + tensor: + + >>> from sympy import symbols + >>> Ex, Ey, Ez, Bx, By, Bz = symbols('E_x E_y E_z B_x B_y B_z') + >>> c = symbols('c', positive=True) + + Let's define `F`, an antisymmetric tensor: + + >>> F = TensorHead('F', [Lorentz, Lorentz], asym2) + + Let's update the dictionary to contain the matrix to use in the + replacements: + + >>> repl.update({F(-i0, -i1): [ + ... [0, Ex/c, Ey/c, Ez/c], + ... [-Ex/c, 0, -Bz, By], + ... [-Ey/c, Bz, 0, -Bx], + ... [-Ez/c, -By, Bx, 0]]}) + + Now it is possible to retrieve the contravariant form of the Electromagnetic + tensor: + + >>> F(i0, i1).replace_with_arrays(repl, [i0, i1]) + [[0, -E_x/c, -E_y/c, -E_z/c], [E_x/c, 0, -B_z, B_y], [E_y/c, B_z, 0, -B_x], [E_z/c, -B_y, B_x, 0]] + + and the mixed contravariant-covariant form: + + >>> F(i0, -i1).replace_with_arrays(repl, [i0, -i1]) + [[0, E_x/c, E_y/c, E_z/c], [E_x/c, 0, B_z, -B_y], [E_y/c, -B_z, 0, B_x], [E_z/c, B_y, -B_x, 0]] + + Energy-momentum of a particle may be represented as: + + >>> from sympy import symbols + >>> P = TensorHead('P', [Lorentz], TensorSymmetry.no_symmetry(1)) + >>> E, px, py, pz = symbols('E p_x p_y p_z', positive=True) + >>> repl.update({P(i0): [E, px, py, pz]}) + + The contravariant and covariant components are, respectively: + + >>> P(i0).replace_with_arrays(repl, [i0]) + [E, p_x, p_y, p_z] + >>> P(-i0).replace_with_arrays(repl, [-i0]) + [E, -p_x, -p_y, -p_z] + + The contraction of a 1-index tensor by itself: + + >>> expr = P(i0)*P(-i0) + >>> expr.replace_with_arrays(repl, []) + E**2 - p_x**2 - p_y**2 - p_z**2 + """ + is_commutative = False + + def __new__(cls, name, index_types, symmetry=None, comm=0): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + else: + raise ValueError("invalid name") + + if symmetry is None: + symmetry = TensorSymmetry.no_symmetry(len(index_types)) + else: + assert symmetry.rank == len(index_types) + + obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), symmetry, sympify(comm)) + return obj + + @property + def name(self): + return self.args[0].name + + @property + def index_types(self): + return list(self.args[1]) + + @property + def symmetry(self): + return self.args[2] + + @property + def comm(self): + return TensorManager.comm_symbols2i(self.args[3]) + + @property + def rank(self): + return len(self.index_types) + + def __lt__(self, other): + return (self.name, self.index_types) < (other.name, other.index_types) + + def commutes_with(self, other): + """ + Returns ``0`` if ``self`` and ``other`` commute, ``1`` if they anticommute. + + Returns ``None`` if ``self`` and ``other`` neither commute nor anticommute. + """ + r = TensorManager.get_comm(self.comm, other.comm) + return r + + def _print(self): + return '%s(%s)' %(self.name, ','.join([str(x) for x in self.index_types])) + + def __call__(self, *indices, **kw_args): + """ + Returns a tensor with indices. + + Explanation + =========== + + There is a special behavior in case of indices denoted by ``True``, + they are considered auto-matrix indices, their slots are automatically + filled, and confer to the tensor the behavior of a matrix or vector + upon multiplication with another tensor containing auto-matrix indices + of the same ``TensorIndexType``. This means indices get summed over the + same way as in matrix multiplication. For matrix behavior, define two + auto-matrix indices, for vector behavior define just one. + + Indices can also be strings, in which case the attribute + ``index_types`` is used to convert them to proper ``TensorIndex``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorSymmetry, TensorHead + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b = tensor_indices('a,b', Lorentz) + >>> A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + >>> t = A(a, -b) + >>> t + A(a, -b) + + """ + + updated_indices = [] + for idx, typ in zip(indices, self.index_types): + if isinstance(idx, str): + idx = idx.strip().replace(" ", "") + if idx.startswith('-'): + updated_indices.append(TensorIndex(idx[1:], typ, + is_up=False)) + else: + updated_indices.append(TensorIndex(idx, typ)) + else: + updated_indices.append(idx) + + updated_indices += indices[len(updated_indices):] + + tensor = Tensor(self, updated_indices, **kw_args) + return tensor.doit() + + # Everything below this line is deprecated + + def __pow__(self, other): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No power on abstract tensors.") + from .array import tensorproduct, tensorcontraction + metrics = [_.data for _ in self.index_types] + + marray = self.data + marraydim = marray.rank() + for metric in metrics: + marray = tensorproduct(marray, metric, marray) + marray = tensorcontraction(marray, (0, marraydim), (marraydim+1, marraydim+2)) + + return marray ** (other * S.Half) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data.__iter__() + + def _components_data_full_destroy(self): + """ + EXPERIMENTAL: do not rely on this API method. + + Destroy components data associated to the ``TensorHead`` object, this + checks for attached components data, and destroys components data too. + """ + # do not garbage collect Kronecker tensor (it should be done by + # ``TensorIndexType`` garbage collection) + deprecate_data() + if self.name == "KD": + return + + # the data attached to a tensor must be deleted only by the TensorHead + # destructor. If the TensorHead is deleted, it means that there are no + # more instances of that tensor anywhere. + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + +def tensor_heads(s, index_types, symmetry=None, comm=0): + """ + Returns a sequence of TensorHeads from a string `s` + """ + if isinstance(s, str): + names = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + + thlist = [TensorHead(name, index_types, symmetry, comm) for name in names] + if len(thlist) == 1: + return thlist[0] + return thlist + + +class TensExpr(Expr, ABC): + """ + Abstract base class for tensor expressions + + Notes + ===== + + A tensor expression is an expression formed by tensors; + currently the sums of tensors are distributed. + + A ``TensExpr`` can be a ``TensAdd`` or a ``TensMul``. + + ``TensMul`` objects are formed by products of component tensors, + and include a coefficient, which is a SymPy expression. + + + In the internal representation contracted indices are represented + by ``(ipos1, ipos2, icomp1, icomp2)``, where ``icomp1`` is the position + of the component tensor with contravariant index, ``ipos1`` is the + slot which the index occupies in that component tensor. + + Contracted indices are therefore nameless in the internal representation. + """ + + _op_priority = 12.0 + is_commutative = False + + def __neg__(self): + return self*S.NegativeOne + + def __abs__(self): + raise NotImplementedError + + def __add__(self, other): + return TensAdd(self, other).doit(deep=False) + + def __radd__(self, other): + return TensAdd(other, self).doit(deep=False) + + def __sub__(self, other): + return TensAdd(self, -other).doit(deep=False) + + def __rsub__(self, other): + return TensAdd(other, -self).doit(deep=False) + + def __mul__(self, other): + """ + Multiply two tensors using Einstein summation convention. + + Explanation + =========== + + If the two tensors have an index in common, one contravariant + and the other covariant, in their product the indices are summed + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t1 = p(m0) + >>> t2 = q(-m0) + >>> t1*t2 + p(L_0)*q(-L_0) + """ + return TensMul(self, other).doit(deep=False) + + def __rmul__(self, other): + return TensMul(other, self).doit(deep=False) + + def __truediv__(self, other): + other = _sympify(other) + if isinstance(other, TensExpr): + raise ValueError('cannot divide by a tensor') + return TensMul(self, S.One/other).doit(deep=False) + + def __rtruediv__(self, other): + raise ValueError('cannot divide by a tensor') + + def __pow__(self, other): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No power without ndarray data.") + from .array import tensorproduct, tensorcontraction + free = self.free + marray = self.data + mdim = marray.rank() + for metric in free: + marray = tensorcontraction( + tensorproduct( + marray, + metric[0].tensor_index_type.data, + marray), + (0, mdim), (mdim+1, mdim+2) + ) + return marray ** (other * S.Half) + + def __rpow__(self, other): + raise NotImplementedError + + @property + @abstractmethod + def nocoeff(self): + raise NotImplementedError("abstract method") + + @property + @abstractmethod + def coeff(self): + raise NotImplementedError("abstract method") + + @abstractmethod + def get_indices(self): + raise NotImplementedError("abstract method") + + @abstractmethod + def get_free_indices(self) -> list[TensorIndex]: + raise NotImplementedError("abstract method") + + @abstractmethod + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + raise NotImplementedError("abstract method") + + def fun_eval(self, *index_tuples): + deprecate_fun_eval() + return self.substitute_indices(*index_tuples) + + def get_matrix(self): + """ + DEPRECATED: do not use. + + Returns ndarray components data as a matrix, if components data are + available and ndarray dimension does not exceed 2. + """ + from sympy.matrices.dense import Matrix + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if 0 < self.rank <= 2: + rows = self.data.shape[0] + columns = self.data.shape[1] if self.rank == 2 else 1 + if self.rank == 2: + mat_list = [] * rows + for i in range(rows): + mat_list.append([]) + for j in range(columns): + mat_list[i].append(self[i, j]) + else: + mat_list = [None] * rows + for i in range(rows): + mat_list[i] = self[i] + return Matrix(mat_list) + else: + raise NotImplementedError( + "missing multidimensional reduction to matrix.") + + @staticmethod + def _get_indices_permutation(indices1, indices2): + return [indices1.index(i) for i in indices2] + + def _get_free_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_free_indices_set()) + return indset + + def _get_dummy_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_dummy_indices_set()) + return indset + + def _get_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_indices_set()) + return indset + + @property + def _iterate_dummy_indices(self): + dummy_set = self._get_dummy_indices_set() + + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + if expr in dummy_set: + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @property + def _iterate_free_indices(self): + free_set = self._get_free_indices_set() + + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + if expr in free_set: + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @property + def _iterate_indices(self): + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @staticmethod + def _contract_and_permute_with_metric(metric, array, pos, dim): + # TODO: add possibility of metric after (spinors) + from .array import tensorcontraction, tensorproduct, permutedims + + array = tensorcontraction(tensorproduct(metric, array), (1, 2+pos)) + permu = list(range(dim)) + permu[0], permu[pos] = permu[pos], permu[0] + return permutedims(array, permu) + + @staticmethod + def _match_indices_with_other_tensor(array, free_ind1, free_ind2, replacement_dict): + from .array import permutedims + + index_types1 = [i.tensor_index_type for i in free_ind1] + + # Check if variance of indices needs to be fixed: + pos2up = [] + pos2down = [] + free2remaining = free_ind2[:] + for pos1, index1 in enumerate(free_ind1): + if index1 in free2remaining: + pos2 = free2remaining.index(index1) + free2remaining[pos2] = None + continue + if -index1 in free2remaining: + pos2 = free2remaining.index(-index1) + free2remaining[pos2] = None + free_ind2[pos2] = index1 + if index1.is_up: + pos2up.append(pos2) + else: + pos2down.append(pos2) + else: + index2 = free2remaining[pos1] + if index2 is None: + raise ValueError("incompatible indices: %s and %s" % (free_ind1, free_ind2)) + free2remaining[pos1] = None + free_ind2[pos1] = index1 + if index1.is_up ^ index2.is_up: + if index1.is_up: + pos2up.append(pos1) + else: + pos2down.append(pos1) + + if len(set(free_ind1) & set(free_ind2)) < len(free_ind1): + raise ValueError("incompatible indices: %s and %s" % (free_ind1, free_ind2)) + + # Raise indices: + for pos in pos2up: + index_type_pos = index_types1[pos] + if index_type_pos not in replacement_dict: + raise ValueError("No metric provided to lower index") + metric = replacement_dict[index_type_pos] + metric_inverse = _TensorDataLazyEvaluator.inverse_matrix(metric) + array = TensExpr._contract_and_permute_with_metric(metric_inverse, array, pos, len(free_ind1)) + # Lower indices: + for pos in pos2down: + index_type_pos = index_types1[pos] + if index_type_pos not in replacement_dict: + raise ValueError("No metric provided to lower index") + metric = replacement_dict[index_type_pos] + array = TensExpr._contract_and_permute_with_metric(metric, array, pos, len(free_ind1)) + + if free_ind1: + permutation = TensExpr._get_indices_permutation(free_ind2, free_ind1) + array = permutedims(array, permutation) + + if hasattr(array, "rank") and array.rank() == 0: + array = array[()] + + return free_ind2, array + + def replace_with_arrays(self, replacement_dict, indices=None): + """ + Replace the tensorial expressions with arrays. The final array will + correspond to the N-dimensional array with indices arranged according + to ``indices``. + + Parameters + ========== + + replacement_dict + dictionary containing the replacement rules for tensors. + indices + the index order with respect to which the array is read. The + original index order will be used if no value is passed. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices + >>> from sympy.tensor.tensor import TensorHead + >>> from sympy import symbols, diag + + >>> L = TensorIndexType("L") + >>> i, j = tensor_indices("i j", L) + >>> A = TensorHead("A", [L]) + >>> A(i).replace_with_arrays({A(i): [1, 2]}, [i]) + [1, 2] + + Since 'indices' is optional, we can also call replace_with_arrays by + this way if no specific index order is needed: + + >>> A(i).replace_with_arrays({A(i): [1, 2]}) + [1, 2] + + >>> expr = A(i)*A(j) + >>> expr.replace_with_arrays({A(i): [1, 2]}) + [[1, 2], [2, 4]] + + For contractions, specify the metric of the ``TensorIndexType``, which + in this case is ``L``, in its covariant form: + + >>> expr = A(i)*A(-i) + >>> expr.replace_with_arrays({A(i): [1, 2], L: diag(1, -1)}) + -3 + + Symmetrization of an array: + + >>> H = TensorHead("H", [L, L]) + >>> a, b, c, d = symbols("a b c d") + >>> expr = H(i, j)/2 + H(j, i)/2 + >>> expr.replace_with_arrays({H(i, j): [[a, b], [c, d]]}) + [[a, b/2 + c/2], [b/2 + c/2, d]] + + Anti-symmetrization of an array: + + >>> expr = H(i, j)/2 - H(j, i)/2 + >>> repl = {H(i, j): [[a, b], [c, d]]} + >>> expr.replace_with_arrays(repl) + [[0, b/2 - c/2], [-b/2 + c/2, 0]] + + The same expression can be read as the transpose by inverting ``i`` and + ``j``: + + >>> expr.replace_with_arrays(repl, [j, i]) + [[0, -b/2 + c/2], [b/2 - c/2, 0]] + """ + from .array import Array + + indices = indices or [] + remap = {k.args[0] if k.is_up else -k.args[0]: k for k in self.get_free_indices()} + for i, index in enumerate(indices): + if isinstance(index, (Symbol, Mul)): + if index in remap: + indices[i] = remap[index] + else: + indices[i] = -remap[-index] + + replacement_dict = {tensor: Array(array) for tensor, array in replacement_dict.items()} + + # Check dimensions of replaced arrays: + for tensor, array in replacement_dict.items(): + if isinstance(tensor, TensorIndexType): + expected_shape = [tensor.dim for i in range(2)] + else: + expected_shape = [index_type.dim for index_type in tensor.index_types] + if len(expected_shape) != array.rank() or (not all(dim1 == dim2 if + dim1.is_number else True for dim1, dim2 in zip(expected_shape, + array.shape))): + raise ValueError("shapes for tensor %s expected to be %s, "\ + "replacement array shape is %s" % (tensor, expected_shape, + array.shape)) + + ret_indices, array = self._extract_data(replacement_dict) + + last_indices, array = self._match_indices_with_other_tensor(array, indices, ret_indices, replacement_dict) + return array + + def _check_add_Sum(self, expr, index_symbols): + from sympy.concrete.summations import Sum + indices = self.get_indices() + dum = self.dum + sum_indices = [ (index_symbols[i], 0, + indices[i].tensor_index_type.dim-1) for i, j in dum] + if sum_indices: + expr = Sum(expr, *sum_indices) + return expr + + def _expand_partial_derivative(self): + # simply delegate the _expand_partial_derivative() to + # its arguments to expand a possibly found PartialDerivative + return self.func(*[ + a._expand_partial_derivative() + if isinstance(a, TensExpr) else a + for a in self.args]) + + def _matches_simple(self, expr, repl_dict=None, old=False): + """ + Matches assuming there are no wild objects in self. + """ + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if not isinstance(expr, TensExpr): + if len(self.get_free_indices()) > 0: + #self has indices, but expr does not. + return None + elif set(self.get_free_indices()) != set(expr.get_free_indices()): + #If there are no wilds and the free indices are not the same, they cannot match. + return None + + if canon_bp(self - expr) == S.Zero: + return repl_dict + else: + return None + + +class TensAdd(TensExpr, AssocOp): + """ + Sum of tensors. + + Parameters + ========== + + free_args : list of the free indices + + Attributes + ========== + + ``args`` : tuple of addends + ``rank`` : rank of the tensor + ``free_args`` : list of the free indices in sorted order + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_heads, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b = tensor_indices('a,b', Lorentz) + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(a) + q(a); t + p(a) + q(a) + + Examples with components data added to the tensor expression: + + >>> from sympy import symbols, diag + >>> x, y, z, t = symbols("x y z t") + >>> repl = {} + >>> repl[Lorentz] = diag(1, -1, -1, -1) + >>> repl[p(a)] = [1, 2, 3, 4] + >>> repl[q(a)] = [x, y, z, t] + + The following are: 2**2 - 3**2 - 2**2 - 7**2 ==> -58 + + >>> expr = p(a) + q(a) + >>> expr.replace_with_arrays(repl, [a]) + [x + 1, y + 2, z + 3, t + 4] + """ + + def __new__(cls, *args, **kw_args): + args = [_sympify(x) for x in args if x] + args = TensAdd._tensAdd_flatten(args) + args.sort(key=default_sort_key) + if not args: + return S.Zero + if len(args) == 1: + return args[0] + + return Basic.__new__(cls, *args, **kw_args) + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + def get_free_indices(self) -> list[TensorIndex]: + return self.free_indices + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + newargs = [arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg for arg in self.args] + return self.func(*newargs) + + @memoize_property + def rank(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].rank + else: + return 0 + + @memoize_property + def free_args(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].free_args + else: + return [] + + @memoize_property + def free_indices(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].get_free_indices() + else: + return set() + + def doit(self, **hints) -> Expr: + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args # type: ignore + + # if any of the args are zero (after doit), drop them. Otherwise, _tensAdd_check will complain about non-matching indices, even though the TensAdd is correctly formed. + args = [arg for arg in args if arg != S.Zero] + + if len(args) == 0: + return S.Zero + elif len(args) == 1: + return args[0] + + # now check that all addends have the same indices: + TensAdd._tensAdd_check(args) + + # Collect terms appearing more than once, differing by their coefficients: + args = TensAdd._tensAdd_collect_terms(args) + + # collect canonicalized terms + def sort_key(t): + if not isinstance(t, TensExpr): + return [], [], [] + if hasattr(t, "_index_structure") and hasattr(t, "components"): + x = get_index_structure(t) + return t.components, x.free, x.dum + return [], [], [] + args.sort(key=sort_key) + + if not args: + return S.Zero + # it there is only a component tensor return it + if len(args) == 1: + return args[0] + + obj = self.func(*args) + return obj + + @staticmethod + def _tensAdd_flatten(args): + # flatten TensAdd, coerce terms which are not tensors to tensors + a = [] + for x in args: + if isinstance(x, (Add, TensAdd)): + a.extend(list(x.args)) + else: + a.append(x) + args = [x for x in a if x.coeff] + return args + + @staticmethod + def _tensAdd_check(args): + # check that all addends have the same free indices + + def get_indices_set(x: Expr) -> set[TensorIndex]: + if isinstance(x, TensExpr): + return set(x.get_free_indices()) + return set() + + indices0 = get_indices_set(args[0]) + list_indices = [get_indices_set(arg) for arg in args[1:]] + if not all(x == indices0 for x in list_indices): + raise ValueError('all tensors must have the same indices') + + @staticmethod + def _tensAdd_collect_terms(args): + # collect TensMul terms differing at most by their coefficient + terms_dict = defaultdict(list) + scalars = S.Zero + if isinstance(args[0], TensExpr): + free_indices = set(args[0].get_free_indices()) + else: + free_indices = set() + + for arg in args: + if not isinstance(arg, TensExpr): + if free_indices != set(): + raise ValueError("wrong valence") + scalars += arg + continue + if free_indices != set(arg.get_free_indices()): + raise ValueError("wrong valence") + # TODO: what is the part which is not a coeff? + # needs an implementation similar to .as_coeff_Mul() + terms_dict[arg.nocoeff].append(arg.coeff) + + new_args = [TensMul(Add(*coeff), t).doit(deep=False) for t, coeff in terms_dict.items() if Add(*coeff) != 0] + if isinstance(scalars, Add): + new_args = list(scalars.args) + new_args + elif scalars != 0: + new_args = [scalars] + new_args + return new_args + + def get_indices(self): + indices = [] + for arg in self.args: + indices.extend([i for i in get_indices(arg) if i not in indices]) + return indices + + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + index_tuples = list(zip(free_args, indices)) + a = [x.func(*x.substitute_indices(*index_tuples).args) for x in self.args] + res = TensAdd(*a).doit(deep=False) + return res + + def canon_bp(self): + """ + Canonicalize using the Butler-Portugal algorithm for canonicalization + under monoterm symmetries. + """ + expr = self.expand() + if isinstance(expr, self.func): + args = [canon_bp(x) for x in expr.args] + res = TensAdd(*args).doit(deep=False) + return res + else: + return canon_bp(expr) + + def equals(self, other): + other = _sympify(other) + if isinstance(other, TensMul) and other.coeff == 0: + return all(x.coeff == 0 for x in self.args) + if isinstance(other, TensExpr): + if self.rank != other.rank: + return False + if isinstance(other, TensAdd): + if set(self.args) != set(other.args): + return False + else: + return True + t = self - other + if not isinstance(t, TensExpr): + return t == 0 + else: + if isinstance(t, TensMul): + return t.coeff == 0 + else: + return all(x.coeff == 0 for x in t.args) + + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def contract_delta(self, delta): + args = [x.contract_delta(delta) if isinstance(x, TensExpr) else x for x in self.args] + t = TensAdd(*args).doit(deep=False) + return canon_bp(t) + + def contract_metric(self, g): + """ + Raise or lower indices with the metric ``g``. + + Parameters + ========== + + g : metric + + contract_all : if True, eliminate all ``g`` which are contracted + + Notes + ===== + + see the ``TensorIndexType`` docstring for the contraction conventions + """ + + args = [contract_metric(x, g) for x in self.args] + t = TensAdd(*args).doit(deep=False) + return canon_bp(t) + + def substitute_indices(self, *index_tuples): + new_args = [] + for arg in self.args: + if isinstance(arg, TensExpr): + arg = arg.substitute_indices(*index_tuples) + new_args.append(arg) + return TensAdd(*new_args).doit(deep=False) + + def _print(self): + a = [] + args = self.args + for x in args: + a.append(str(x)) + s = ' + '.join(a) + s = s.replace('+ -', '- ') + return s + + def _extract_data(self, replacement_dict): + from sympy.tensor.array import Array, permutedims + args_indices, arrays = zip(*[ + arg._extract_data(replacement_dict) if + isinstance(arg, TensExpr) else ([], arg) for arg in self.args + ]) + arrays = [Array(i) for i in arrays] + ref_indices = args_indices[0] + for i in range(1, len(args_indices)): + indices = args_indices[i] + array = arrays[i] + permutation = TensMul._get_indices_permutation(indices, ref_indices) + arrays[i] = permutedims(array, permutation) + return ref_indices, sum(arrays, Array.zeros(*array.shape)) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self.expand()] + + @data.setter + def data(self, data): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def __iter__(self): + deprecate_data() + if not self.data: + raise ValueError("No iteration on abstract tensors") + return self.data.flatten().__iter__() + + def _eval_rewrite_as_Indexed(self, *args, **kwargs): + return Add.fromiter(args) + + def _eval_partial_derivative(self, s): + # Evaluation like Add + list_addends = [] + for a in self.args: + if isinstance(a, TensExpr): + list_addends.append(a._eval_partial_derivative(s)) + # do not call diff if s is no symbol + elif s._diff_wrt: + list_addends.append(a._eval_derivative(s)) + + return self.func(*list_addends).doit(deep=False) + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if not isinstance(expr, TensAdd): + return None + + if len(_get_wilds(self)) == 0: + return self._matches_simple(expr, repl_dict, old) + + def siftkey(arg): + wildatoms = _get_wilds(arg) + wildatom_types = sift(wildatoms, type) + if len(wildatoms) == 0: + return "nonwild" + elif WildTensor in wildatom_types.keys(): + for w in wildatom_types["WildTensor"]: + if len(w.get_indices()) == 0: + return "indexless_wildtensor" + return "wildtensor" + else: + return "otherwild" + + query_sifted = sift(self.args, siftkey) + expr_sifted = sift(expr.args, siftkey) + + #First try to match the terms without WildTensors + matched_e_tensors = [] #Used to make sure that the same tensor in expr is not matched with more than one tensor in self. + for q_tensor in query_sifted["nonwild"]: + matched_this_q = False + for e_tensor in expr_sifted["nonwild"]: + if e_tensor in matched_e_tensors: + continue + + m = q_tensor.matches(e_tensor, repl_dict=repl_dict, old=old) + if m is None: + continue + else: + matched_this_q = True + repl_dict.update(m) + matched_e_tensors.append(e_tensor) + break + + if not matched_this_q: + return None + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["otherwild"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w in repl_dict.keys(): + repl_dict[w] += m.pop(w) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["wildtensor"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w.component in repl_dict.keys(): + repl_dict[w.component] += m.pop(w.component) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["indexless_wildtensor"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w.component in repl_dict.keys(): + repl_dict[w.component] += m.pop(w.component) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + if len(remaining_e_tensors) > 0: + return None + else: + return repl_dict + + +class Tensor(TensExpr): + """ + Base tensor class, i.e. this represents a tensor, the single unit to be + put into an expression. + + Explanation + =========== + + This object is usually created from a ``TensorHead``, by attaching indices + to it. Indices preceded by a minus sign are considered contravariant, + otherwise covariant. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead + >>> Lorentz = TensorIndexType("Lorentz", dummy_name="L") + >>> mu, nu = tensor_indices('mu nu', Lorentz) + >>> A = TensorHead("A", [Lorentz, Lorentz]) + >>> A(mu, -nu) + A(mu, -nu) + >>> A(mu, -mu) + A(L_0, -L_0) + + It is also possible to use symbols instead of inidices (appropriate indices + are then generated automatically). + + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> A(x, mu) + A(x, mu) + >>> A(x, -x) + A(L_0, -L_0) + + """ + + is_commutative = False + + _index_structure: _IndexStructure + args: tuple[TensorHead, Tuple] + + def __new__(cls, tensor_head, indices, *, is_canon_bp=False, **kw_args): + indices = cls._parse_indices(tensor_head, indices) + obj = Basic.__new__(cls, tensor_head, Tuple(*indices), **kw_args) + obj._index_structure = _IndexStructure.from_indices(*indices) + obj._free = obj._index_structure.free[:] + obj._dum = obj._index_structure.dum[:] + obj._ext_rank = obj._index_structure._ext_rank + obj._coeff = S.One + obj._nocoeff = obj + obj._component = tensor_head + obj._components = [tensor_head] + if tensor_head.rank != len(indices): + raise ValueError("wrong number of indices") + obj.is_canon_bp = is_canon_bp + obj._index_map = Tensor._build_index_map(indices, obj._index_structure) + return obj + + @property + def free(self): + return self._free + + @property + def dum(self): + return self._dum + + @property + def ext_rank(self): + return self._ext_rank + + @property + def coeff(self): + return self._coeff + + @property + def nocoeff(self): + return self._nocoeff + + @property + def component(self): + return self._component + + @property + def components(self): + return self._components + + @property + def head(self): + return self.args[0] + + @property + def indices(self): + return self.args[1] + + @property + def free_indices(self): + return set(self._index_structure.get_free_indices()) + + @property + def index_types(self): + return self.head.index_types + + @property + def rank(self): + return len(self.free_indices) + + @staticmethod + def _build_index_map(indices, index_structure): + index_map = {} + for idx in indices: + index_map[idx] = (indices.index(idx),) + return index_map + + def doit(self, **hints): + args, indices, free, dum = TensMul._tensMul_contract_indices([self]) + return args[0] + + @staticmethod + def _parse_indices(tensor_head, indices): + if not isinstance(indices, (tuple, list, Tuple)): + raise TypeError("indices should be an array, got %s" % type(indices)) + indices = list(indices) + for i, index in enumerate(indices): + if isinstance(index, Symbol): + indices[i] = TensorIndex(index, tensor_head.index_types[i], True) + elif isinstance(index, Mul): + c, e = index.as_coeff_Mul() + if c == -1 and isinstance(e, Symbol): + indices[i] = TensorIndex(e, tensor_head.index_types[i], False) + else: + raise ValueError("index not understood: %s" % index) + elif not isinstance(index, TensorIndex): + raise TypeError("wrong type for index: %s is %s" % (index, type(index))) + return indices + + def _set_new_index_structure(self, im, is_canon_bp=False): + indices = im.get_indices() + return self._set_indices(*indices, is_canon_bp=is_canon_bp) + + def _set_indices(self, *indices, is_canon_bp=False, **kw_args): + if len(indices) != self.ext_rank: + raise ValueError("indices length mismatch") + return self.func(self.args[0], indices, is_canon_bp=is_canon_bp).doit() + + def _get_free_indices_set(self): + return {i[0] for i in self._index_structure.free} + + def _get_dummy_indices_set(self): + dummy_pos = set(itertools.chain(*self._index_structure.dum)) + return {idx for i, idx in enumerate(self.args[1]) if i in dummy_pos} + + def _get_indices_set(self): + return set(self.args[1].args) + + @property + def free_in_args(self): + return [(ind, pos, 0) for ind, pos in self.free] + + @property + def dum_in_args(self): + return [(p1, p2, 0, 0) for p1, p2 in self.dum] + + @property + def free_args(self): + return sorted([x[0] for x in self.free]) + + def commutes_with(self, other): + """ + :param other: + :return: + 0 commute + 1 anticommute + None neither commute nor anticommute + """ + if not isinstance(other, TensExpr): + return 0 + elif isinstance(other, Tensor): + return self.component.commutes_with(other.component) + return NotImplementedError + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g``. + + For further details, see the method in ``TIDS`` with the same name. + """ + return perm2tensor(self, g, is_canon_bp) + + def canon_bp(self): + if self.is_canon_bp: + return self + expr = self.expand() + g, dummies, msym = expr._index_structure.indices_canon_args() + v = components_canon_args([expr.component]) + can = canonicalize(g, dummies, msym, *v) + if can == 0: + return S.Zero + tensor = self.perm2tensor(can, True) + return tensor + + def split(self): + return [self] + + def sorted_components(self): + return self + + def get_indices(self) -> list[TensorIndex]: + """ + Get a list of indices, corresponding to those of the tensor. + """ + return list(self.args[1]) + + def get_free_indices(self) -> list[TensorIndex]: + """ + Get a list of free indices, corresponding to those of the tensor. + """ + return self._index_structure.get_free_indices() + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + # TODO: this could be optimized by only swapping the indices + # instead of visiting the whole expression tree: + return self.xreplace(repl) + + def as_base_exp(self): + return self, S.One + + def substitute_indices(self, *index_tuples): + """ + Return a tensor with free indices substituted according to ``index_tuples``. + + ``index_types`` list of tuples ``(old_index, new_index)``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i, j, k, l = tensor_indices('i,j,k,l', Lorentz) + >>> A, B = tensor_heads('A,B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + >>> t = A(i, k)*B(-k, -j); t + A(i, L_0)*B(-L_0, -j) + >>> t.substitute_indices((i, k),(-j, l)) + A(k, L_0)*B(-L_0, l) + """ + indices = [] + for index in self.indices: + for ind_old, ind_new in index_tuples: + if (index.name == ind_old.name and index.tensor_index_type == + ind_old.tensor_index_type): + if index.is_up == ind_old.is_up: + indices.append(ind_new) + else: + indices.append(-ind_new) + break + else: + indices.append(index) + return self.head(*indices) + + def _get_symmetrized_forms(self): + """ + Return a list giving all possible permutations of self that are allowed by its symmetries. + """ + comp = self.component + gens = comp.symmetry.generators + rank = comp.rank + + old_perms = None + new_perms = {self} + while new_perms != old_perms: + old_perms = new_perms.copy() + for tens in old_perms: + for gen in gens: + inds = tens.get_indices() + per = [gen.apply(i) for i in range(0,rank)] + sign = (-1)**(gen.apply(rank) - rank) + ind_map = dict(zip(inds, [inds[i] for i in per])) + new_perms.add( sign * tens._replace_indices(ind_map) ) + + return new_perms + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #simple checks + if self == expr: + return repl_dict + if not isinstance(expr, Tensor): + return None + if self.head != expr.head: + return None + + #Now consider all index symmetries of expr, and see if any of them allow a match. + for new_expr in expr._get_symmetrized_forms(): + m = self._matches(new_expr, repl_dict, old=old) + if m is not None: + repl_dict.update(m) + return repl_dict + + return None + + def _matches(self, expr, repl_dict=None, old=False): + """ + This does not account for index symmetries of expr + """ + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #simple checks + if self == expr: + return repl_dict + if not isinstance(expr, Tensor): + return None + if self.head != expr.head: + return None + + s_indices = self.get_indices() + e_indices = expr.get_indices() + + if len(s_indices) != len(e_indices): + return None + + for i in range(len(s_indices)): + s_ind = s_indices[i] + m = s_ind.matches(e_indices[i]) + if m is None: + return None + elif -s_ind in repl_dict.keys() and -repl_dict[-s_ind] != m[s_ind]: + return None + else: + repl_dict.update(m) + + return repl_dict + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + t = self.substitute_indices(*list(zip(free_args, indices))) + + # object is rebuilt in order to make sure that all contracted indices + # get recognized as dummies, but only if there are contracted indices. + if len({i if i.is_up else -i for i in indices}) != len(indices): + return t.func(*t.args) + return t + + # TODO: put this into TensExpr? + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data.__iter__() + + # TODO: put this into TensExpr? + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def _extract_data(self, replacement_dict): + from .array import Array + for k, v in replacement_dict.items(): + if isinstance(k, Tensor) and k.args[0] == self.args[0]: + other = k + array = v + break + else: + raise ValueError("%s not found in %s" % (self, replacement_dict)) + + # TODO: inefficient, this should be done at root level only: + replacement_dict = {k: Array(v) for k, v in replacement_dict.items()} + array = Array(array) + + dum1 = self.dum + dum2 = other.dum + + if len(dum2) > 0: + for pair in dum2: + # allow `dum2` if the contained values are also in `dum1`. + if pair not in dum1: + raise NotImplementedError("%s with contractions is not implemented" % other) + # Remove elements in `dum2` from `dum1`: + dum1 = [pair for pair in dum1 if pair not in dum2] + if len(dum1) > 0: + indices1 = self.get_indices() + indices2 = other.get_indices() + repl = {} + for p1, p2 in dum1: + repl[indices2[p2]] = -indices2[p1] + for pos in (p1, p2): + if indices1[pos].is_up ^ indices2[pos].is_up: + metric = replacement_dict[indices1[pos].tensor_index_type] + if indices1[pos].is_up: + metric = _TensorDataLazyEvaluator.inverse_matrix(metric) + array = self._contract_and_permute_with_metric(metric, array, pos, len(indices2)) + other = other.xreplace(repl).doit() + array = _TensorDataLazyEvaluator.data_contract_dum([array], dum1, len(indices2)) + + free_ind1 = self.get_free_indices() + free_ind2 = other.get_free_indices() + + return self._match_indices_with_other_tensor(array, free_ind1, free_ind2, replacement_dict) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + # TODO: check data compatibility with properties of tensor. + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + if self.metric in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self.metric] + + def _print(self): + indices = [str(ind) for ind in self.indices] + component = self.component + if component.rank > 0: + return ('%s(%s)' % (component.name, ', '.join(indices))) + else: + return ('%s' % component.name) + + def equals(self, other): + if other == 0: + return self.coeff == 0 + other = _sympify(other) + if not isinstance(other, TensExpr): + assert not self.components + return S.One == other + + def _get_compar_comp(self): + t = self.canon_bp() + r = (t.coeff, tuple(t.components), \ + tuple(sorted(t.free)), tuple(sorted(t.dum))) + return r + + return _get_compar_comp(self) == _get_compar_comp(other) + + def contract_metric(self, g): + # if metric is not the same, ignore this step: + if self.component != g: + return self + # in case there are free components, do not perform anything: + if len(self.free) != 0: + return self + + #antisym = g.index_types[0].metric_antisym + if g.symmetry == TensorSymmetry.fully_symmetric(-2): + antisym = 1 + elif g.symmetry == TensorSymmetry.fully_symmetric(2): + antisym = 0 + elif g.symmetry == TensorSymmetry.no_symmetry(2): + antisym = None + else: + raise NotImplementedError + sign = S.One + typ = g.index_types[0] + + if not antisym: + # g(i, -i) + sign = sign*typ.dim + else: + # g(i, -i) + sign = sign*typ.dim + + dp0, dp1 = self.dum[0] + if dp0 < dp1: + # g(i, -i) = -D with antisymmetric metric + sign = -sign + + return sign + + def contract_delta(self, metric): + return self.contract_metric(metric) + + def _eval_rewrite_as_Indexed(self, tens, indices, **kwargs): + from sympy.tensor.indexed import Indexed + # TODO: replace .args[0] with .name: + index_symbols = [i.args[0] for i in self.get_indices()] + expr = Indexed(tens.args[0], *index_symbols) + return self._check_add_Sum(expr, index_symbols) + + def _eval_partial_derivative(self, s: Tensor) -> Expr: + + if not isinstance(s, Tensor): + return S.Zero + else: + + # @a_i/@a_k = delta_i^k + # @a_i/@a^k = g_ij delta^j_k + # @a^i/@a^k = delta^i_k + # @a^i/@a_k = g^ij delta_j^k + # TODO: if there is no metric present, the derivative should be zero? + + if self.head != s.head: + return S.Zero + + # if heads are the same, provide delta and/or metric products + # for every free index pair in the appropriate tensor + # assumed that the free indices are in proper order + # A contravariante index in the derivative becomes covariant + # after performing the derivative and vice versa + + kronecker_delta_list = [1] + + # not guarantee a correct index order + + for (count, (iself, iother)) in enumerate(zip(self.get_free_indices(), s.get_free_indices())): + if iself.tensor_index_type != iother.tensor_index_type: + raise ValueError("index types not compatible") + else: + tensor_index_type = iself.tensor_index_type + tensor_metric = tensor_index_type.metric + dummy = TensorIndex("d_" + str(count), tensor_index_type, + is_up=iself.is_up) + if iself.is_up == iother.is_up: + kroneckerdelta = tensor_index_type.delta(iself, -iother) + else: + kroneckerdelta = ( + TensMul(tensor_metric(iself, dummy), + tensor_index_type.delta(-dummy, -iother)) + ) + kronecker_delta_list.append(kroneckerdelta) + return TensMul.fromiter(kronecker_delta_list).doit(deep=False) + # doit necessary to rename dummy indices accordingly + + +class TensMul(TensExpr, AssocOp): + """ + Product of tensors. + + Parameters + ========== + + coeff : SymPy coefficient of the tensor + args + + Attributes + ========== + + ``components`` : list of ``TensorHead`` of the component tensors + ``types`` : list of nonrepeated ``TensorIndexType`` + ``free`` : list of ``(ind, ipos, icomp)``, see Notes + ``dum`` : list of ``(ipos1, ipos2, icomp1, icomp2)``, see Notes + ``ext_rank`` : rank of the tensor counting the dummy indices + ``rank`` : rank of the tensor + ``coeff`` : SymPy coefficient of the tensor + ``free_args`` : list of the free indices in sorted order + ``is_canon_bp`` : ``True`` if the tensor in in canonical form + + Notes + ===== + + ``args[0]`` list of ``TensorHead`` of the component tensors. + + ``args[1]`` list of ``(ind, ipos, icomp)`` + where ``ind`` is a free index, ``ipos`` is the slot position + of ``ind`` in the ``icomp``-th component tensor. + + ``args[2]`` list of tuples representing dummy indices. + ``(ipos1, ipos2, icomp1, icomp2)`` indicates that the contravariant + dummy index is the ``ipos1``-th slot position in the ``icomp1``-th + component tensor; the corresponding covariant index is + in the ``ipos2`` slot position in the ``icomp2``-th component tensor. + + """ + identity = S.One + + _index_structure: _IndexStructure + + def __new__(cls, *args, **kw_args): + is_canon_bp = kw_args.get('is_canon_bp', False) + args = list(map(_sympify, args)) + + """ + If the internal dummy indices in one arg conflict with the free indices + of the remaining args, we need to rename those internal dummy indices. + """ + free = [get_free_indices(arg) for arg in args] + free = set(itertools.chain(*free)) #flatten free + newargs = [] + for arg in args: + dum_this = set(get_dummy_indices(arg)) + dum_other = [get_dummy_indices(a) for a in newargs] + dum_other = set(itertools.chain(*dum_other)) #flatten dum_other + free_this = set(get_free_indices(arg)) + if len(dum_this.intersection(free)) > 0: + exclude = free_this.union(free, dum_other) + newarg = TensMul._dedupe_indices(arg, exclude) + else: + newarg = arg + newargs.append(newarg) + + args = newargs + + # Flatten: + args = [i for arg in args for i in (arg.args if isinstance(arg, (TensMul, Mul)) else [arg])] + + args, indices, free, dum = TensMul._tensMul_contract_indices(args, replace_indices=False) + + # Data for indices: + index_types = [i.tensor_index_type for i in indices] + index_structure = _IndexStructure(free, dum, index_types, indices, canon_bp=is_canon_bp) + + obj = TensExpr.__new__(cls, *args) + obj._indices = indices + obj._index_types = index_types.copy() + obj._index_structure = index_structure + obj._free = index_structure.free[:] + obj._dum = index_structure.dum[:] + obj._free_indices = {x[0] for x in obj.free} + obj._rank = len(obj.free) + obj._ext_rank = len(obj._index_structure.free) + 2*len(obj._index_structure.dum) + obj._coeff = S.One + obj._is_canon_bp = is_canon_bp + return obj + + index_types = property(lambda self: self._index_types) + free = property(lambda self: self._free) + dum = property(lambda self: self._dum) + free_indices = property(lambda self: self._free_indices) + rank = property(lambda self: self._rank) + ext_rank = property(lambda self: self._ext_rank) + + @staticmethod + def _indices_to_free_dum(args_indices): + free2pos1 = {} + free2pos2 = {} + dummy_data = [] + indices = [] + + # Notation for positions (to better understand the code): + # `pos1`: position in the `args`. + # `pos2`: position in the indices. + + # Example: + # A(i, j)*B(k, m, n)*C(p) + # `pos1` of `n` is 1 because it's in `B` (second `args` of TensMul). + # `pos2` of `n` is 4 because it's the fifth overall index. + + # Counter for the index position wrt the whole expression: + pos2 = 0 + + for pos1, arg_indices in enumerate(args_indices): + + for index in arg_indices: + if not isinstance(index, TensorIndex): + raise TypeError("expected TensorIndex") + if -index in free2pos1: + # Dummy index detected: + other_pos1 = free2pos1.pop(-index) + other_pos2 = free2pos2.pop(-index) + if index.is_up: + dummy_data.append((index, pos1, other_pos1, pos2, other_pos2)) + else: + dummy_data.append((-index, other_pos1, pos1, other_pos2, pos2)) + indices.append(index) + elif index in free2pos1: + raise ValueError("Repeated index: %s" % index) + else: + free2pos1[index] = pos1 + free2pos2[index] = pos2 + indices.append(index) + pos2 += 1 + + free = list(free2pos2.items()) + free_names = [i.name for i in free2pos2.keys()] + + dummy_data.sort(key=lambda x: x[3]) + return indices, free, free_names, dummy_data + + @staticmethod + def _dummy_data_to_dum(dummy_data): + return [(p2a, p2b) for (i, p1a, p1b, p2a, p2b) in dummy_data] + + @staticmethod + def _tensMul_contract_indices(args, replace_indices=True): + replacements = [{} for _ in args] + + #_index_order = all(_has_index_order(arg) for arg in args) + + args_indices = [get_indices(arg) for arg in args] + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + + cdt = defaultdict(int) + + def dummy_name_gen(tensor_index_type): + nd = str(cdt[tensor_index_type]) + cdt[tensor_index_type] += 1 + return tensor_index_type.dummy_name + '_' + nd + + if replace_indices: + for old_index, pos1cov, pos1contra, pos2cov, pos2contra in dummy_data: + index_type = old_index.tensor_index_type + while True: + dummy_name = dummy_name_gen(index_type) + if dummy_name not in free_names: + break + dummy = old_index.func(dummy_name, index_type, *old_index.args[2:]) + replacements[pos1cov][old_index] = dummy + replacements[pos1contra][-old_index] = -dummy + indices[pos2cov] = dummy + indices[pos2contra] = -dummy + args = [ + arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg + for arg, repl in zip(args, replacements)] + + """ + The order of indices might've changed due to the replacements (e.g. if one of the args is a TensAdd, replacing an index can change the sort order of the terms, thus changing the order of indices returned by its get_indices() method). + To stay on the safe side, we calculate these quantities again. + """ + args_indices = [get_indices(arg) for arg in args] + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + + dum = TensMul._dummy_data_to_dum(dummy_data) + return args, indices, free, dum + + @staticmethod + def _get_components_from_args(args): + """ + Get a list of ``Tensor`` objects having the same ``TIDS`` if multiplied + by one another. + """ + components = [] + for arg in args: + if not isinstance(arg, TensExpr): + continue + if isinstance(arg, TensAdd): + continue + components.extend(arg.components) + return components + + @staticmethod + def _rebuild_tensors_list(args, index_structure): + indices = index_structure.get_indices() + #tensors = [None for i in components] # pre-allocate list + ind_pos = 0 + for i, arg in enumerate(args): + if not isinstance(arg, TensExpr): + continue + prev_pos = ind_pos + ind_pos += arg.ext_rank + args[i] = Tensor(arg.component, indices[prev_pos:ind_pos]) + + def doit(self, **hints): + is_canon_bp = self._is_canon_bp + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + + """ + There may now be conflicts between dummy indices of different args + (each arg's doit method does not have any information about which + dummy indices are already used in the other args), so we + deduplicate them. + """ + rule = dict(zip(self.args, args)) + rule = self._dedupe_indices_in_rule(rule) + args = [rule[a] for a in self.args] + + else: + args = self.args + + args = [arg for arg in args if arg != self.identity] + + # Extract non-tensor coefficients: + coeff = reduce(lambda a, b: a*b, [arg for arg in args if not isinstance(arg, TensExpr)], S.One) + args = [arg for arg in args if isinstance(arg, TensExpr)] + + if len(args) == 0: + return coeff + + if coeff != self.identity: + args = [coeff] + args + if coeff == 0: + return S.Zero + + if len(args) == 1: + return args[0] + + args, indices, free, dum = TensMul._tensMul_contract_indices(args) + + # Data for indices: + index_types = [i.tensor_index_type for i in indices] + index_structure = _IndexStructure(free, dum, index_types, indices, canon_bp=is_canon_bp) + + obj = self.func(*args) + obj._index_types = index_types + obj._index_structure = index_structure + obj._ext_rank = len(obj._index_structure.free) + 2*len(obj._index_structure.dum) + obj._coeff = coeff + obj._is_canon_bp = is_canon_bp + return obj + + # TODO: this method should be private + # TODO: should this method be renamed _from_components_free_dum ? + @staticmethod + def from_data(coeff, components, free, dum, **kw_args): + return TensMul(coeff, *TensMul._get_tensors_from_components_free_dum(components, free, dum), **kw_args).doit(deep=False) + + @staticmethod + def _get_tensors_from_components_free_dum(components, free, dum): + """ + Get a list of ``Tensor`` objects by distributing ``free`` and ``dum`` indices on the ``components``. + """ + index_structure = _IndexStructure.from_components_free_dum(components, free, dum) + indices = index_structure.get_indices() + tensors = [None for i in components] # pre-allocate list + + # distribute indices on components to build a list of tensors: + ind_pos = 0 + for i, component in enumerate(components): + prev_pos = ind_pos + ind_pos += component.rank + tensors[i] = Tensor(component, indices[prev_pos:ind_pos]) + return tensors + + def _get_free_indices_set(self): + return {i[0] for i in self.free} + + def _get_dummy_indices_set(self): + dummy_pos = set(itertools.chain(*self.dum)) + return {idx for i, idx in enumerate(self._index_structure.get_indices()) if i in dummy_pos} + + def _get_position_offset_for_indices(self): + arg_offset = [None for i in range(self.ext_rank)] + counter = 0 + for arg in self.args: + if not isinstance(arg, TensExpr): + continue + for j in range(arg.ext_rank): + arg_offset[j + counter] = counter + counter += arg.ext_rank + return arg_offset + + @property + def free_args(self): + return sorted([x[0] for x in self.free]) + + @property + def components(self): + return self._get_components_from_args(self.args) + + @property + def free_in_args(self): + arg_offset = self._get_position_offset_for_indices() + argpos = self._get_indices_to_args_pos() + return [(ind, pos-arg_offset[pos], argpos[pos]) for (ind, pos) in self.free] + + @property + def coeff(self): + # return Mul.fromiter([c for c in self.args if not isinstance(c, TensExpr)]) + return self._coeff + + @property + def nocoeff(self): + return self.func(*self.args, 1/self.coeff).doit(deep=False) + + @property + def dum_in_args(self): + arg_offset = self._get_position_offset_for_indices() + argpos = self._get_indices_to_args_pos() + return [(p1-arg_offset[p1], p2-arg_offset[p2], argpos[p1], argpos[p2]) for p1, p2 in self.dum] + + def equals(self, other): + if other == 0: + return self.coeff == 0 + other = _sympify(other) + if not isinstance(other, TensExpr): + assert not self.components + return self.coeff == other + + return self.canon_bp() == other.canon_bp() + + def get_indices(self): + """ + Returns the list of indices of the tensor. + + Explanation + =========== + + The indices are listed in the order in which they appear in the + component tensors. + The dummy indices are given a name which does not collide with + the names of the free indices. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m1)*g(m0,m2) + >>> t.get_indices() + [m1, m0, m2] + >>> t2 = p(m1)*g(-m1, m2) + >>> t2.get_indices() + [L_0, -L_0, m2] + """ + return self._indices + + def get_free_indices(self) -> list[TensorIndex]: + """ + Returns the list of free indices of the tensor. + + Explanation + =========== + + The indices are listed in the order in which they appear in the + component tensors. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m1)*g(m0,m2) + >>> t.get_free_indices() + [m1, m0, m2] + >>> t2 = p(m1)*g(-m1, m2) + >>> t2.get_free_indices() + [m2] + """ + return self._index_structure.get_free_indices() + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + return self.func(*[arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg for arg in self.args]) + + def split(self): + """ + Returns a list of tensors, whose product is ``self``. + + Explanation + =========== + + Dummy indices contracted among different tensor components + become free indices with the same name as the one used to + represent the dummy indices. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b, c, d = tensor_indices('a,b,c,d', Lorentz) + >>> A, B = tensor_heads('A,B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + >>> t = A(a,b)*B(-b,c) + >>> t + A(a, L_0)*B(-L_0, c) + >>> t.split() + [A(a, L_0), B(-L_0, c)] + """ + if self.args == (): + return [self] + splitp = [] + res = 1 + for arg in self.args: + if isinstance(arg, Tensor): + splitp.append(res*arg) + res = 1 + else: + res *= arg + return splitp + + def _eval_expand_mul(self, **hints): + args1 = [arg.args if isinstance(arg, (Add, TensAdd)) else (arg,) for arg in self.args] + return TensAdd(*[ + TensMul(*i).doit(deep=False) for i in itertools.product(*args1)] + ) + + def __neg__(self): + return TensMul(S.NegativeOne, self, is_canon_bp=self._is_canon_bp).doit(deep=False) + + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def _get_args_for_traditional_printer(self): + args = list(self.args) + if self.coeff.could_extract_minus_sign(): + # expressions like "-A(a)" + sign = "-" + if args[0] == S.NegativeOne: + args = args[1:] + else: + args[0] = -args[0] + else: + sign = "" + return sign, args + + def _sort_args_for_sorted_components(self): + """ + Returns the ``args`` sorted according to the components commutation + properties. + + Explanation + =========== + + The sorting is done taking into account the commutation group + of the component tensors. + """ + cv = [arg for arg in self.args if isinstance(arg, TensExpr)] + sign = 1 + n = len(cv) - 1 + for i in range(n): + for j in range(n, i, -1): + c = cv[j-1].commutes_with(cv[j]) + # if `c` is `None`, it does neither commute nor anticommute, skip: + if c not in (0, 1): + continue + typ1 = sorted(set(cv[j-1].component.index_types), key=lambda x: x.name) + typ2 = sorted(set(cv[j].component.index_types), key=lambda x: x.name) + if (typ1, cv[j-1].component.name) > (typ2, cv[j].component.name): + cv[j-1], cv[j] = cv[j], cv[j-1] + # if `c` is 1, the anticommute, so change sign: + if c: + sign = -sign + + coeff = sign * self.coeff + if coeff != 1: + return [coeff] + cv + return cv + + def sorted_components(self): + """ + Returns a tensor product with sorted components. + """ + return TensMul(*self._sort_args_for_sorted_components()).doit(deep=False) + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g`` + + For further details, see the method in ``TIDS`` with the same name. + """ + return perm2tensor(self, g, is_canon_bp=is_canon_bp) + + def canon_bp(self): + """ + Canonicalize using the Butler-Portugal algorithm for canonicalization + under monoterm symmetries. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + >>> t = A(m0,-m1)*A(m1,-m0) + >>> t.canon_bp() + -A(L_0, L_1)*A(-L_0, -L_1) + >>> t = A(m0,-m1)*A(m1,-m2)*A(m2,-m0) + >>> t.canon_bp() + 0 + """ + if self._is_canon_bp: + return self + expr = self.expand() + if isinstance(expr, TensAdd): + return expr.canon_bp() + if not expr.components: + return expr + expr = expr.doit(deep=False) #make sure self.coeff is populated correctly + t = expr.sorted_components() + g, dummies, msym = t._index_structure.indices_canon_args() + v = components_canon_args(t.components) + can = canonicalize(g, dummies, msym, *v) + if can == 0: + return S.Zero + tmul = t.perm2tensor(can, True) + return tmul + + def contract_delta(self, delta): + t = self.contract_metric(delta) + return t + + def _get_indices_to_args_pos(self): + """ + Get a dict mapping the index position to TensMul's argument number. + """ + pos_map = {} + pos_counter = 0 + for arg_i, arg in enumerate(self.args): + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + for i in range(arg.ext_rank): + pos_map[pos_counter] = arg_i + pos_counter += 1 + return pos_map + + def contract_metric(self, g): + """ + Raise or lower indices with the metric ``g``. + + Parameters + ========== + + g : metric + + Notes + ===== + + See the ``TensorIndexType`` docstring for the contraction conventions. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m0)*q(m1)*g(-m0, -m1) + >>> t.canon_bp() + metric(L_0, L_1)*p(-L_0)*q(-L_1) + >>> t.contract_metric(g).canon_bp() + p(L_0)*q(-L_0) + """ + expr = self.expand().doit(deep=False) + if self != expr: + expr = canon_bp(expr) + return contract_metric(expr, g) + pos_map = self._get_indices_to_args_pos() + args = list(self.args) + + #antisym = g.index_types[0].metric_antisym + if g.symmetry == TensorSymmetry.fully_symmetric(-2): + antisym = 1 + elif g.symmetry == TensorSymmetry.fully_symmetric(2): + antisym = 0 + elif g.symmetry == TensorSymmetry.no_symmetry(2): + antisym = None + else: + raise NotImplementedError + + # list of positions of the metric ``g`` inside ``args`` + gpos = [i for i, x in enumerate(self.args) if isinstance(x, Tensor) and x.component == g] + if not gpos: + return self + + # Sign is either 1 or -1, to correct the sign after metric contraction + # (for spinor indices). + sign = 1 + dum = self.dum[:] + free = self.free[:] + elim = set() + for gposx in gpos: + if gposx in elim: + continue + free1 = [x for x in free if pos_map[x[1]] == gposx] + dum1 = [x for x in dum if pos_map[x[0]] == gposx or pos_map[x[1]] == gposx] + if not dum1: + continue + elim.add(gposx) + # subs with the multiplication neutral element, that is, remove it: + args[gposx] = 1 + if len(dum1) == 2: + if not antisym: + dum10, dum11 = dum1 + if pos_map[dum10[1]] == gposx: + # the index with pos p0 contravariant + p0 = dum10[0] + else: + # the index with pos p0 is covariant + p0 = dum10[1] + if pos_map[dum11[1]] == gposx: + # the index with pos p1 is contravariant + p1 = dum11[0] + else: + # the index with pos p1 is covariant + p1 = dum11[1] + + dum.append((p0, p1)) + else: + dum10, dum11 = dum1 + # change the sign to bring the indices of the metric to contravariant + # form; change the sign if dum10 has the metric index in position 0 + if pos_map[dum10[1]] == gposx: + # the index with pos p0 is contravariant + p0 = dum10[0] + if dum10[1] == 1: + sign = -sign + else: + # the index with pos p0 is covariant + p0 = dum10[1] + if dum10[0] == 0: + sign = -sign + if pos_map[dum11[1]] == gposx: + # the index with pos p1 is contravariant + p1 = dum11[0] + sign = -sign + else: + # the index with pos p1 is covariant + p1 = dum11[1] + + dum.append((p0, p1)) + + elif len(dum1) == 1: + if not antisym: + dp0, dp1 = dum1[0] + if pos_map[dp0] == pos_map[dp1]: + # g(i, -i) + typ = g.index_types[0] + sign = sign*typ.dim + + else: + # g(i0, i1)*p(-i1) + if pos_map[dp0] == gposx: + p1 = dp1 + else: + p1 = dp0 + + ind, p = free1[0] + free.append((ind, p1)) + else: + dp0, dp1 = dum1[0] + if pos_map[dp0] == pos_map[dp1]: + # g(i, -i) + typ = g.index_types[0] + sign = sign*typ.dim + + if dp0 < dp1: + # g(i, -i) = -D with antisymmetric metric + sign = -sign + else: + # g(i0, i1)*p(-i1) + if pos_map[dp0] == gposx: + p1 = dp1 + if dp0 == 0: + sign = -sign + else: + p1 = dp0 + ind, p = free1[0] + free.append((ind, p1)) + dum = [x for x in dum if x not in dum1] + free = [x for x in free if x not in free1] + + # shift positions: + shift = 0 + shifts = [0]*len(args) + for i in range(len(args)): + if i in elim: + shift += 2 + continue + shifts[i] = shift + free = [(ind, p - shifts[pos_map[p]]) for (ind, p) in free if pos_map[p] not in elim] + dum = [(p0 - shifts[pos_map[p0]], p1 - shifts[pos_map[p1]]) for p0, p1 in dum if pos_map[p0] not in elim and pos_map[p1] not in elim] + + res = ( sign*TensMul(*args) ).doit(deep=False) + if not isinstance(res, TensExpr): + return res + im = _IndexStructure.from_components_free_dum(res.components, free, dum) + return res._set_new_index_structure(im) + + def _set_new_index_structure(self, im, is_canon_bp=False): + indices = im.get_indices() + return self._set_indices(*indices, is_canon_bp=is_canon_bp) + + def _set_indices(self, *indices, is_canon_bp=False, **kw_args): + if len(indices) != self.ext_rank: + raise ValueError("indices length mismatch") + args = list(self.args) + pos = 0 + for i, arg in enumerate(args): + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + ext_rank = arg.ext_rank + args[i] = arg._set_indices(*indices[pos:pos+ext_rank]) + pos += ext_rank + return TensMul(*args, is_canon_bp=is_canon_bp).doit(deep=False) + + @staticmethod + def _index_replacement_for_contract_metric(args, free, dum): + for arg in args: + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + + def substitute_indices(self, *index_tuples): + new_args = [] + for arg in self.args: + if isinstance(arg, TensExpr): + arg = arg.substitute_indices(*index_tuples) + new_args.append(arg) + return TensMul(*new_args).doit(deep=False) + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + t = self.substitute_indices(*list(zip(free_args, indices))) + + # object is rebuilt in order to make sure that all contracted indices + # get recognized as dummies, but only if there are contracted indices. + if len({i if i.is_up else -i for i in indices}) != len(indices): + return t.func(*t.args) + return t + + def _extract_data(self, replacement_dict): + args_indices, arrays = zip(*[arg._extract_data(replacement_dict) for arg in self.args if isinstance(arg, TensExpr)]) + coeff = reduce(operator.mul, [a for a in self.args if not isinstance(a, TensExpr)], S.One) + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + dum = TensMul._dummy_data_to_dum(dummy_data) + ext_rank = self.ext_rank + free.sort(key=lambda x: x[1]) + free_indices = [i[0] for i in free] + return free_indices, coeff*_TensorDataLazyEvaluator.data_contract_dum(arrays, dum, ext_rank) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + dat = _tensor_data_substitution_dict[self.expand()] + return dat + + @data.setter + def data(self, data): + deprecate_data() + raise ValueError("Not possible to set component data to a tensor expression") + + @data.deleter + def data(self): + deprecate_data() + raise ValueError("Not possible to delete component data to a tensor expression") + + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No iteration on abstract tensors") + return self.data.__iter__() + + @staticmethod + def _dedupe_indices(new, exclude): + """ + exclude: set + new: TensExpr + + If ``new`` has any dummy indices that are in ``exclude``, return a version + of new with those indices replaced. If no replacements are needed, + return None + + """ + exclude = set(exclude) + dums_new = set(get_dummy_indices(new)) + free_new = set(get_free_indices(new)) + + conflicts = dums_new.intersection(exclude) + if len(conflicts) == 0: + return None + + """ + ``exclude_for_gen`` is to be passed to ``_IndexStructure._get_generator_for_dummy_indices()``. + Since the latter does not use the index position for anything, we just + set it as ``None`` here. + """ + exclude.update(dums_new) + exclude.update(free_new) + exclude_for_gen = [(i, None) for i in exclude] + gen = _IndexStructure._get_generator_for_dummy_indices(exclude_for_gen) + repl = {} + for d in conflicts: + if -d in repl.keys(): + continue + newname = gen(d.tensor_index_type) + new_d = d.func(newname, *d.args[1:]) + repl[d] = new_d + repl[-d] = -new_d + + if len(repl) == 0: + return None + + new_renamed = new._replace_indices(repl) + return new_renamed + + def _dedupe_indices_in_rule(self, rule): + """ + rule: dict + + This applies TensMul._dedupe_indices on all values of rule. + + """ + index_rules = {k:v for k,v in rule.items() if isinstance(k, TensorIndex)} + other_rules = {k:v for k,v in rule.items() if k not in index_rules.keys()} + exclude = set(self.get_indices()) + + newrule = {} + newrule.update(index_rules) + exclude.update(index_rules.keys()) + exclude.update(index_rules.values()) + for old, new in other_rules.items(): + new_renamed = TensMul._dedupe_indices(new, exclude) + if old == new or new_renamed is None: + newrule[old] = new + else: + newrule[old] = new_renamed + exclude.update(get_indices(new_renamed)) + return newrule + + def _eval_subs(self, old, new): + """ + If new is an index which is already present in self as a dummy, the dummies in self should be renamed. + """ + + if not isinstance(new, TensorIndex): + return None + + exclude = {new} + self_renamed = self._dedupe_indices(self, exclude) + if self_renamed is None: + return None + else: + return self_renamed._subs(old, new).doit(deep=False) + + def _eval_rewrite_as_Indexed(self, *args, **kwargs): + from sympy.concrete.summations import Sum + index_symbols = [i.args[0] for i in self.get_indices()] + args = [arg.args[0] if isinstance(arg, Sum) else arg for arg in args] + expr = Mul.fromiter(args) + return self._check_add_Sum(expr, index_symbols) + + def _eval_partial_derivative(self, s): + # Evaluation like Mul + terms = [] + for i, arg in enumerate(self.args): + # checking whether some tensor instance is differentiated + # or some other thing is necessary, but ugly + if isinstance(arg, TensExpr): + d = arg._eval_partial_derivative(s) + else: + # do not call diff is s is no symbol + if s._diff_wrt: + d = arg._eval_derivative(s) + else: + d = S.Zero + if d: + terms.append(TensMul.fromiter(self.args[:i] + (d,) + self.args[i + 1:]).doit(deep=False)) + return TensAdd.fromiter(terms).doit(deep=False) + + + def _matches_commutative(self, expr, repl_dict=None, old=False): + """ + Match assuming all tensors commute. But note that we are not assuming anything about their symmetry under index permutations. + """ + #Take care of the various possible types for expr. + if not isinstance(expr, TensMul): + if isinstance(expr, (TensExpr, Expr)): + expr = TensMul(expr) + else: + return None + + #The code that follows assumes expr is a TensMul + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #Make sure that none of the dummy indices in self, expr conflict with the values already present in repl_dict. This may happen due to automatic index relabelling when rem_query and rem_expr are formed later on in this function (it calls itself recursively). + indices = [k for k in repl_dict.values() if isinstance(k ,TensorIndex)] + + def dedupe(expr): + renamed = TensMul._dedupe_indices(expr, indices) + if renamed is not None: + return renamed + else: + return expr + + self = dedupe(self) + expr = dedupe(expr) + + #Find the non-tensor part of expr. This need not be the same as expr.coeff when expr.doit() has not been called. + expr_coeff = reduce(lambda a, b: a*b, [arg for arg in expr.args if not isinstance(arg, TensExpr)], S.One) + + # handle simple patterns + if self == expr: + return repl_dict + + if len(_get_wilds(self)) == 0: + return self._matches_simple(expr, repl_dict, old) + + def siftkey(arg): + if isinstance(arg, WildTensor): + return "WildTensor" + elif isinstance(arg, (Tensor, TensExpr)): + return "Tensor" + else: + return "coeff" + + query_sifted = sift(self.args, siftkey) + expr_sifted = sift(expr.args, siftkey) + + #Sanity checks + if "coeff" in query_sifted.keys(): + if TensMul(*query_sifted["coeff"]).doit(deep=False) != self.coeff: + raise NotImplementedError(f"Found something that we do not know to handle: {query_sifted['coeff']}") + if "coeff" in expr_sifted.keys(): + if TensMul(*expr_sifted["coeff"]).doit(deep=False) != expr_coeff: + raise NotImplementedError(f"Found something that we do not know to handle: {expr_sifted['coeff']}") + + query_tens_heads = {tuple(getattr(x, "components", [])) for x in query_sifted["Tensor"]} #We use getattr because, e.g. TensAdd does not have the 'components' attribute. + expr_tens_heads = {tuple(getattr(x, "components", [])) for x in expr_sifted["Tensor"]} + if not query_tens_heads.issubset(expr_tens_heads): + #Some tensorheads in self are not present in the expr + return None + + #Try to match all non-wild tensors of self with tensors that compose expr + if len(query_sifted["Tensor"]) > 0: + q_tensor = query_sifted["Tensor"][0] + """ + We need to iterate over all possible symmetrized forms of q_tensor since the matches given by some of them may map dummy indices to free indices; the information about which indices are dummy/free will only be available later, when we are doing rem_q.matches(rem_e) + """ + for q_tens in q_tensor._get_symmetrized_forms(): + for e in expr_sifted["Tensor"]: + if isinstance(q_tens, TensMul): + #q_tensor got a minus sign due to this permutation. + sign = -1 + else: + sign = 1 + + """ + _matches is used here since we are already iterating over index permutations of q_tensor. Also note that the sign is removed from q_tensor, and will later be put into rem_q. + """ + m = (sign*q_tens)._matches(e) + if m is None: + continue + + rem_query = self.func(sign, *[a for a in self.args if a != q_tensor]).doit(deep=False) + rem_expr = expr.func(*[a for a in expr.args if a != e]).doit(deep=False) + tmp_repl = {} + tmp_repl.update(repl_dict) + tmp_repl.update(m) + rem_m = rem_query.matches(rem_expr, repl_dict=tmp_repl) + if rem_m is not None: + #Check that contracted indices are not mapped to different indices. + internally_consistent = True + for k in rem_m.keys(): + if isinstance(k,TensorIndex): + if -k in rem_m.keys() and rem_m[-k] != -rem_m[k]: + internally_consistent = False + break + if internally_consistent: + repl_dict.update(rem_m) + return repl_dict + + return None + + #Try to match WildTensor instances which have indices + matched_e_tensors = [] + remaining_e_tensors = expr_sifted["Tensor"] + indexless_wilds, wilds = sift(query_sifted["WildTensor"], lambda x: len(x.get_free_indices()) == 0, binary=True) + + for w in wilds: + free_this_wild = set(w.get_free_indices()) + tensors_to_try = [] + for t in remaining_e_tensors: + free = t.get_free_indices() + shares_indices_with_wild = True + for i in free: + if all(j.matches(i) is None for j in free_this_wild): + #The index i matches none of the indices in free_this_wild + shares_indices_with_wild = False + if shares_indices_with_wild: + tensors_to_try.append(t) + + m = w.matches(TensMul(*tensors_to_try).doit(deep=False) ) + if m is None: + return None + else: + for tens in tensors_to_try: + matched_e_tensors.append(tens) + repl_dict.update(m) + + #Try to match indexless WildTensor instances + remaining_e_tensors = [t for t in expr_sifted["Tensor"] if t not in matched_e_tensors] + if len(indexless_wilds) > 0: + #If there are any remaining tensors, match them with the indexless WildTensor + m = indexless_wilds[0].matches( TensMul(1,*remaining_e_tensors).doit(deep=False) ) + if m is None: + return None + else: + repl_dict.update(m) + elif len(remaining_e_tensors) > 0: + return None + + #Try to match the non-tensorial coefficient + m = self.coeff.matches(expr_coeff, old=old) + if m is None: + return None + else: + repl_dict.update(m) + + return repl_dict + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + commute = all(arg.component.comm == 0 for arg in expr.args if isinstance(arg, Tensor)) + if commute: + return self._matches_commutative(expr, repl_dict, old) + else: + raise NotImplementedError("Tensor matching not implemented for non-commuting tensors") + +class TensorElement(TensExpr): + """ + Tensor with evaluated components. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, TensorHead, TensorSymmetry + >>> from sympy import symbols + >>> L = TensorIndexType("L") + >>> i, j, k = symbols("i j k") + >>> A = TensorHead("A", [L, L], TensorSymmetry.fully_symmetric(2)) + >>> A(i, j).get_free_indices() + [i, j] + + If we want to set component ``i`` to a specific value, use the + ``TensorElement`` class: + + >>> from sympy.tensor.tensor import TensorElement + >>> te = TensorElement(A(i, j), {i: 2}) + + As index ``i`` has been accessed (``{i: 2}`` is the evaluation of its 3rd + element), the free indices will only contain ``j``: + + >>> te.get_free_indices() + [j] + """ + + def __new__(cls, expr, index_map): + if not isinstance(expr, Tensor): + # remap + if not isinstance(expr, TensExpr): + raise TypeError("%s is not a tensor expression" % expr) + return expr.func(*[TensorElement(arg, index_map) for arg in expr.args]) + expr_free_indices = expr.get_free_indices() + name_translation = {i.args[0]: i for i in expr_free_indices} + index_map = {name_translation.get(index, index): value for index, value in index_map.items()} + index_map = {index: value for index, value in index_map.items() if index in expr_free_indices} + if len(index_map) == 0: + return expr + free_indices = [i for i in expr_free_indices if i not in index_map.keys()] + index_map = Dict(index_map) + obj = TensExpr.__new__(cls, expr, index_map) + obj._free_indices = free_indices + return obj + + @property + def free(self): + return [(index, i) for i, index in enumerate(self.get_free_indices())] + + @property + def dum(self): + # TODO: inherit dummies from expr + return [] + + @property + def expr(self): + return self._args[0] + + @property + def index_map(self): + return self._args[1] + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + def get_free_indices(self): + return self._free_indices + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + # TODO: can be improved: + return self.xreplace(repl) + + def get_indices(self): + return self.get_free_indices() + + def _extract_data(self, replacement_dict): + ret_indices, array = self.expr._extract_data(replacement_dict) + index_map = self.index_map + slice_tuple = tuple(index_map.get(i, slice(None)) for i in ret_indices) + ret_indices = [i for i in ret_indices if i not in index_map] + array = array.__getitem__(slice_tuple) + return ret_indices, array + + +class WildTensorHead(TensorHead): + """ + A wild object that is used to create ``WildTensor`` instances + + Explanation + =========== + + Examples + ======== + >>> from sympy.tensor.tensor import TensorHead, TensorIndex, WildTensorHead, TensorIndexType + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + + A WildTensorHead can be created without specifying a ``TensorIndexType`` + + >>> W = WildTensorHead("W") + + Calling it with a ``TensorIndex`` creates a ``WildTensor`` instance. + + >>> type(W(p)) + + + The ``TensorIndexType`` is automatically detected from the index that is passed + + >>> W(p).component + W(R3) + + Calling it with no indices returns an object that can match tensors with any number of indices. + + >>> K = TensorHead('K', [R3]) + >>> Q = TensorHead('Q', [R3, R3]) + >>> W().matches(K(p)) + {W: K(p)} + >>> W().matches(Q(p,q)) + {W: Q(p, q)} + + If you want to ignore the order of indices while matching, pass ``unordered_indices=True``. + + >>> U = WildTensorHead("U", unordered_indices=True) + >>> W(p,q).matches(Q(q,p)) + >>> U(p,q).matches(Q(q,p)) + {U(R3,R3): _WildTensExpr(Q(q, p))} + + Parameters + ========== + name : name of the tensor + unordered_indices : whether the order of the indices matters for matching + (default: False) + + See also + ======== + ``WildTensor`` + ``TensorHead`` + + """ + def __new__(cls, name, index_types=None, symmetry=None, comm=0, unordered_indices=False): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + else: + raise ValueError("invalid name") + + if index_types is None: + index_types = [] + + if symmetry is None: + symmetry = TensorSymmetry.no_symmetry(len(index_types)) + else: + assert symmetry.rank == len(index_types) + + if symmetry != TensorSymmetry.no_symmetry(len(index_types)): + raise NotImplementedError("Wild matching based on symmetry is not implemented.") + + obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), sympify(symmetry), sympify(comm), sympify(unordered_indices)) + + return obj + + @property + def unordered_indices(self): + return self.args[4] + + def __call__(self, *indices, **kwargs): + tensor = WildTensor(self, indices, **kwargs) + return tensor.doit() + + +class WildTensor(Tensor): + """ + A wild object which matches ``Tensor`` instances + + Explanation + =========== + This is instantiated by attaching indices to a ``WildTensorHead`` instance. + + Examples + ======== + >>> from sympy.tensor.tensor import TensorHead, TensorIndex, WildTensorHead, TensorIndexType + >>> W = WildTensorHead("W") + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + >>> K = TensorHead('K', [R3]) + >>> Q = TensorHead('Q', [R3, R3]) + + Matching also takes the indices into account + >>> W(p).matches(K(p)) + {W(R3): _WildTensExpr(K(p))} + >>> W(p).matches(K(q)) + >>> W(p).matches(K(-p)) + + If you want to match objects with any number of indices, just use a ``WildTensor`` with no indices. + >>> W().matches(K(p)) + {W: K(p)} + >>> W().matches(Q(p,q)) + {W: Q(p, q)} + + See Also + ======== + ``WildTensorHead`` + ``Tensor`` + + """ + def __new__(cls, tensor_head, indices, **kw_args): + is_canon_bp = kw_args.pop("is_canon_bp", False) + + if tensor_head.func == TensorHead: + """ + If someone tried to call WildTensor by supplying a TensorHead (not a WildTensorHead), return a normal tensor instead. This is helpful when using subs on an expression to replace occurrences of a WildTensorHead with a TensorHead. + """ + return Tensor(tensor_head, indices, is_canon_bp=is_canon_bp, **kw_args) + elif tensor_head.func == _WildTensExpr: + return tensor_head(*indices) + + indices = cls._parse_indices(tensor_head, indices) + index_types = [ind.tensor_index_type for ind in indices] + tensor_head = tensor_head.func( + tensor_head.name, + index_types, + symmetry=None, + comm=tensor_head.comm, + unordered_indices=tensor_head.unordered_indices, + ) + + obj = Basic.__new__(cls, tensor_head, Tuple(*indices)) + obj.name = tensor_head.name + obj._index_structure = _IndexStructure.from_indices(*indices) + obj._free = obj._index_structure.free[:] + obj._dum = obj._index_structure.dum[:] + obj._ext_rank = obj._index_structure._ext_rank + obj._coeff = S.One + obj._nocoeff = obj + obj._component = tensor_head + obj._components = [tensor_head] + if tensor_head.rank != len(indices): + raise ValueError("wrong number of indices") + obj.is_canon_bp = is_canon_bp + obj._index_map = obj._build_index_map(indices, obj._index_structure) + + return obj + + + def matches(self, expr, repl_dict=None, old=False): + if not isinstance(expr, TensExpr) and expr != S(1): + return None + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if len(self.indices) > 0: + if not hasattr(expr, "get_free_indices"): + return None + expr_indices = expr.get_free_indices() + if len(expr_indices) != len(self.indices): + return None + if self._component.unordered_indices: + m = self._match_indices_ignoring_order(expr) + if m is None: + return None + else: + repl_dict.update(m) + else: + for i in range(len(expr_indices)): + m = self.indices[i].matches(expr_indices[i]) + if m is None: + return None + else: + repl_dict.update(m) + + repl_dict[self.component] = _WildTensExpr(expr) + else: + #If no indices were passed to the WildTensor, it may match tensors with any number of indices. + repl_dict[self] = expr + + return repl_dict + + def _match_indices_ignoring_order(self, expr, repl_dict=None, old=False): + """ + Helper method for matches. Checks if the indices of self and expr + match disregarding index ordering. + """ + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + def siftkey(ind): + if isinstance(ind, WildTensorIndex): + if ind.ignore_updown: + return "wild, updown" + else: + return "wild" + else: + return "nonwild" + + indices_sifted = sift(self.indices, siftkey) + + matched_indices = [] + expr_indices_remaining = expr.get_indices() + for ind in indices_sifted["nonwild"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + if e_ind in matched_indices: + continue + m = ind.matches(e_ind) + if m is not None: + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + expr_indices_remaining = [i for i in expr_indices_remaining if i not in matched_indices] + for ind in indices_sifted["wild"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + m = ind.matches(e_ind) + if m is not None: + if -ind in repl_dict.keys() and -repl_dict[-ind] != m[ind]: + return None + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + expr_indices_remaining = [i for i in expr_indices_remaining if i not in matched_indices] + for ind in indices_sifted["wild, updown"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + m = ind.matches(e_ind) + if m is not None: + if -ind in repl_dict.keys() and -repl_dict[-ind] != m[ind]: + return None + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + if len(matched_indices) < len(self.indices): + return None + else: + return repl_dict + +class WildTensorIndex(TensorIndex): + """ + A wild object that matches TensorIndex instances. + + Examples + ======== + >>> from sympy.tensor.tensor import TensorIndex, TensorIndexType, WildTensorIndex + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex("p", R3) + + By default, covariant indices only match with covariant indices (and + similarly for contravariant) + + >>> q = WildTensorIndex("q", R3) + >>> (q).matches(p) + {q: p} + >>> (q).matches(-p) + + If you want matching to ignore whether the index is co/contra-variant, set + ignore_updown=True + + >>> r = WildTensorIndex("r", R3, ignore_updown=True) + >>> (r).matches(-p) + {r: -p} + >>> (r).matches(p) + {r: p} + + Parameters + ========== + name : name of the index (string), or ``True`` if you want it to be + automatically assigned + tensor_index_type : ``TensorIndexType`` of the index + is_up : flag for contravariant index (is_up=True by default) + ignore_updown : bool, Whether this should match both co- and contra-variant + indices (default:False) + """ + def __new__(cls, name, tensor_index_type, is_up=True, ignore_updown=False): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + elif name is True: + name = "_i{}".format(len(tensor_index_type._autogenerated)) + name_symbol = Symbol(name) + tensor_index_type._autogenerated.append(name_symbol) + else: + raise ValueError("invalid name") + + is_up = sympify(is_up) + ignore_updown = sympify(ignore_updown) + return Basic.__new__(cls, name_symbol, tensor_index_type, is_up, ignore_updown) + + @property + def ignore_updown(self): + return self.args[3] + + def __neg__(self): + t1 = WildTensorIndex(self.name, self.tensor_index_type, + (not self.is_up), self.ignore_updown) + return t1 + + def matches(self, expr, repl_dict=None, old=False): + if not isinstance(expr, TensorIndex): + return None + if self.tensor_index_type != expr.tensor_index_type: + return None + if not self.ignore_updown: + if self.is_up != expr.is_up: + return None + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + repl_dict[self] = expr + return repl_dict + + +class _WildTensExpr(Basic): + """ + INTERNAL USE ONLY + + This is an object that helps with replacement of WildTensors in expressions. + When this object is set as the tensor_head of a WildTensor, it replaces the + WildTensor by a TensExpr (passed when initializing this object). + + Examples + ======== + >>> from sympy.tensor.tensor import WildTensorHead, TensorIndex, TensorHead, TensorIndexType + >>> W = WildTensorHead("W") + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + >>> K = TensorHead('K', [R3]) + >>> print( ( K(p) ).replace( W(p), W(q)*W(-q)*W(p) ) ) + K(R_0)*K(-R_0)*K(p) + + """ + def __init__(self, expr): + if not isinstance(expr, TensExpr): + raise TypeError("_WildTensExpr expects a TensExpr as argument") + self.expr = expr + + def __call__(self, *indices): + return self.expr._replace_indices(dict(zip(self.expr.get_free_indices(), indices))) + + def __neg__(self): + return self.func(self.expr*S.NegativeOne) + + def __abs__(self): + raise NotImplementedError + + def __add__(self, other): + if other.func != self.func: + raise TypeError(f"Cannot add {self.func} to {other.func}") + return self.func(self.expr+other.expr) + + def __radd__(self, other): + if other.func != self.func: + raise TypeError(f"Cannot add {self.func} to {other.func}") + return self.func(other.expr+self.expr) + + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return other + (-self) + + def __mul__(self, other): + raise NotImplementedError + + def __rmul__(self, other): + raise NotImplementedError + + def __truediv__(self, other): + raise NotImplementedError + + def __rtruediv__(self, other): + raise NotImplementedError + + def __pow__(self, other): + raise NotImplementedError + + def __rpow__(self, other): + raise NotImplementedError + + +def canon_bp(p): + """ + Butler-Portugal canonicalization. See ``tensor_can.py`` from the + combinatorics module for the details. + """ + if isinstance(p, TensExpr): + return p.canon_bp() + return p + + +def tensor_mul(*a): + """ + product of tensors + """ + if not a: + return TensMul.from_data(S.One, [], [], []) + t = a[0] + for tx in a[1:]: + t = t*tx + return t + + +def riemann_cyclic_replace(t_r): + """ + replace Riemann tensor with an equivalent expression + + ``R(m,n,p,q) -> 2/3*R(m,n,p,q) - 1/3*R(m,q,n,p) + 1/3*R(m,p,n,q)`` + + """ + free = sorted(t_r.free, key=lambda x: x[1]) + m, n, p, q = [x[0] for x in free] + t0 = t_r*Rational(2, 3) + t1 = -t_r.substitute_indices((m,m),(n,q),(p,n),(q,p))*Rational(1, 3) + t2 = t_r.substitute_indices((m,m),(n,p),(p,n),(q,q))*Rational(1, 3) + t3 = t0 + t1 + t2 + return t3 + +def riemann_cyclic(t2): + """ + Replace each Riemann tensor with an equivalent expression + satisfying the cyclic identity. + + This trick is discussed in the reference guide to Cadabra. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, riemann_cyclic, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i, j, k, l = tensor_indices('i,j,k,l', Lorentz) + >>> R = TensorHead('R', [Lorentz]*4, TensorSymmetry.riemann()) + >>> t = R(i,j,k,l)*(R(-i,-j,-k,-l) - 2*R(-i,-k,-j,-l)) + >>> riemann_cyclic(t) + 0 + """ + t2 = t2.expand() + if isinstance(t2, (TensMul, Tensor)): + args = [t2] + else: + args = t2.args + a1 = [x.split() for x in args] + a2 = [[riemann_cyclic_replace(tx) for tx in y] for y in a1] + a3 = [tensor_mul(*v) for v in a2] + t3 = TensAdd(*a3).doit(deep=False) + if not t3: + return t3 + else: + return canon_bp(t3) + + +def get_lines(ex, index_type): + """ + Returns ``(lines, traces, rest)`` for an index type, + where ``lines`` is the list of list of positions of a matrix line, + ``traces`` is the list of list of traced matrix lines, + ``rest`` is the rest of the elements of the tensor. + """ + def _join_lines(a): + i = 0 + while i < len(a): + x = a[i] + xend = x[-1] + xstart = x[0] + hit = True + while hit: + hit = False + for j in range(i + 1, len(a)): + if j >= len(a): + break + if a[j][0] == xend: + hit = True + x.extend(a[j][1:]) + xend = x[-1] + a.pop(j) + continue + if a[j][0] == xstart: + hit = True + a[i] = reversed(a[j][1:]) + x + x = a[i] + xstart = a[i][0] + a.pop(j) + continue + if a[j][-1] == xend: + hit = True + x.extend(reversed(a[j][:-1])) + xend = x[-1] + a.pop(j) + continue + if a[j][-1] == xstart: + hit = True + a[i] = a[j][:-1] + x + x = a[i] + xstart = x[0] + a.pop(j) + continue + i += 1 + return a + + arguments = ex.args + dt = {} + for c in ex.args: + if not isinstance(c, TensExpr): + continue + if c in dt: + continue + index_types = c.index_types + a = [] + for i in range(len(index_types)): + if index_types[i] is index_type: + a.append(i) + if len(a) > 2: + raise ValueError('at most two indices of type %s allowed' % index_type) + if len(a) == 2: + dt[c] = a + #dum = ex.dum + lines = [] + traces = [] + traces1 = [] + #indices_to_args_pos = ex._get_indices_to_args_pos() + # TODO: add a dum_to_components_map ? + for p0, p1, c0, c1 in ex.dum_in_args: + if arguments[c0] not in dt: + continue + if c0 == c1: + traces.append([c0]) + continue + ta0 = dt[arguments[c0]] + ta1 = dt[arguments[c1]] + if p0 not in ta0: + continue + if ta0.index(p0) == ta1.index(p1): + # case gamma(i,s0,-s1) in c0, gamma(j,-s0,s2) in c1; + # to deal with this case one could add to the position + # a flag for transposition; + # one could write [(c0, False), (c1, True)] + raise NotImplementedError + # if p0 == ta0[1] then G in pos c0 is mult on the right by G in c1 + # if p0 == ta0[0] then G in pos c1 is mult on the right by G in c0 + ta0 = dt[arguments[c0]] + b0, b1 = (c0, c1) if p0 == ta0[1] else (c1, c0) + lines1 = lines.copy() + for line in lines: + if line[-1] == b0: + if line[0] == b1: + n = line.index(min(line)) + traces1.append(line) + traces.append(line[n:] + line[:n]) + else: + line.append(b1) + break + elif line[0] == b1: + line.insert(0, b0) + break + else: + lines1.append([b0, b1]) + + lines = [x for x in lines1 if x not in traces1] + lines = _join_lines(lines) + rest = [] + for line in lines: + for y in line: + rest.append(y) + for line in traces: + for y in line: + rest.append(y) + rest = [x for x in range(len(arguments)) if x not in rest] + + return lines, traces, rest + + +def get_free_indices(t): + if not isinstance(t, TensExpr): + return () + return t.get_free_indices() + + +def get_indices(t): + if not isinstance(t, TensExpr): + return () + return t.get_indices() + +def get_dummy_indices(t): + if not isinstance(t, TensExpr): + return () + inds = t.get_indices() + free = t.get_free_indices() + return [i for i in inds if i not in free] + +def get_index_structure(t): + if isinstance(t, TensExpr): + return t._index_structure + return _IndexStructure([], [], [], []) + + +def get_coeff(t): + if isinstance(t, Tensor): + return S.One + if isinstance(t, TensMul): + return t.coeff + if isinstance(t, TensExpr): + raise ValueError("no coefficient associated to this tensor expression") + return t + +def contract_metric(t, g): + if isinstance(t, TensExpr): + return t.contract_metric(g) + return t + +def perm2tensor(t, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g`` + + For further details, see the method in ``TIDS`` with the same name. + """ + if not isinstance(t, TensExpr): + return t + elif isinstance(t, (Tensor, TensMul)): + nim = get_index_structure(t).perm2tensor(g, is_canon_bp=is_canon_bp) + res = t._set_new_index_structure(nim, is_canon_bp=is_canon_bp) + if g[-1] != len(g) - 1: + return -res + + return res + raise NotImplementedError() + + +def substitute_indices(t, *index_tuples): + if not isinstance(t, TensExpr): + return t + return t.substitute_indices(*index_tuples) + + +def _get_wilds(expr): + return list(expr.atoms(Wild, WildFunction, WildTensor, WildTensorIndex, WildTensorHead)) + + +def get_postprocessor(cls): + def _postprocessor(expr): + tens_class = {Mul: TensMul, Add: TensAdd}[cls] + if any(isinstance(a, TensExpr) for a in expr.args): + return tens_class(*expr.args) + else: + return expr + + return _postprocessor + +Basic._constructor_postprocessor_mapping[TensExpr] = { + "Mul": [get_postprocessor(Mul)], +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..467d0dad6dd2072e16f41efcc59331342f29c8fc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91db13cfa9191d16304df3f6a405bf1038dc2c79 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_index_methods.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_index_methods.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fcc6b1edb019e68d1f0dd926c0970e4339fcacb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_index_methods.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_indexed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_indexed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62e86f2dbf6187131eaadcfd0de7639ebf991110 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_indexed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_printing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_printing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fdcbc27553676fb5210cf23425e384e4cdc988a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_printing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_tensor_element.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_tensor_element.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ee5120f757e7877bd9c22a2961812d4f1a7b683 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_tensor_element.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_tensor_operators.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_tensor_operators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cef6b8bd2995e31ea55fd59a853967ddad5d5381 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/__pycache__/test_tensor_operators.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..ae40865d1bddffaa976dc3d94ae1ef1b6c97ca35 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_functions.py @@ -0,0 +1,57 @@ +from sympy.tensor.functions import TensorProduct +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array import Array +from sympy.abc import x, y, z +from sympy.abc import i, j, k, l + + +A = MatrixSymbol("A", 3, 3) +B = MatrixSymbol("B", 3, 3) +C = MatrixSymbol("C", 3, 3) + + +def test_TensorProduct_construction(): + assert TensorProduct(3, 4) == 12 + assert isinstance(TensorProduct(A, A), TensorProduct) + + expr = TensorProduct(TensorProduct(x, y), z) + assert expr == x*y*z + + expr = TensorProduct(TensorProduct(A, B), C) + assert expr == TensorProduct(A, B, C) + + expr = TensorProduct(Matrix.eye(2), Array([[0, -1], [1, 0]])) + assert expr == Array([ + [ + [[0, -1], [1, 0]], + [[0, 0], [0, 0]] + ], + [ + [[0, 0], [0, 0]], + [[0, -1], [1, 0]] + ] + ]) + + +def test_TensorProduct_shape(): + + expr = TensorProduct(3, 4, evaluate=False) + assert expr.shape == () + assert expr.rank() == 0 + + expr = TensorProduct(Array([1, 2]), Array([x, y]), evaluate=False) + assert expr.shape == (2, 2) + assert expr.rank() == 2 + expr = TensorProduct(expr, expr, evaluate=False) + assert expr.shape == (2, 2, 2, 2) + assert expr.rank() == 4 + + expr = TensorProduct(Matrix.eye(2), Array([[0, -1], [1, 0]]), evaluate=False) + assert expr.shape == (2, 2, 2, 2) + assert expr.rank() == 4 + + +def test_TensorProduct_getitem(): + expr = TensorProduct(A, B) + assert expr[i, j, k, l] == A[i, j]*B[k, l] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_index_methods.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_index_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..df20f7e7c1ab392321e8350b95dd07c5639c1865 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_index_methods.py @@ -0,0 +1,227 @@ +from sympy.core import symbols, S, Pow, Function +from sympy.functions import exp +from sympy.testing.pytest import raises +from sympy.tensor.indexed import Idx, IndexedBase +from sympy.tensor.index_methods import IndexConformanceException + +from sympy.tensor.index_methods import (get_contraction_structure, get_indices) + + +def test_trivial_indices(): + x, y = symbols('x y') + assert get_indices(x) == (set(), {}) + assert get_indices(x*y) == (set(), {}) + assert get_indices(x + y) == (set(), {}) + assert get_indices(x**y) == (set(), {}) + + +def test_get_indices_Indexed(): + x = IndexedBase('x') + i, j = Idx('i'), Idx('j') + assert get_indices(x[i, j]) == ({i, j}, {}) + assert get_indices(x[j, i]) == ({j, i}, {}) + + +def test_get_indices_Idx(): + f = Function('f') + i, j = Idx('i'), Idx('j') + assert get_indices(f(i)*j) == ({i, j}, {}) + assert get_indices(f(j, i)) == ({j, i}, {}) + assert get_indices(f(i)*i) == (set(), {}) + + +def test_get_indices_mul(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_indices(x[j]*y[i]) == ({i, j}, {}) + assert get_indices(x[i]*y[j]) == ({i, j}, {}) + + +def test_get_indices_exceptions(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + raises(IndexConformanceException, lambda: get_indices(x[i] + y[j])) + + +def test_scalar_broadcast(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_indices(x[i] + y[i, i]) == ({i}, {}) + assert get_indices(x[i] + y[j, j]) == ({i}, {}) + + +def test_get_indices_add(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + assert get_indices(x[i] + 2*y[i]) == ({i}, {}) + assert get_indices(y[i] + 2*A[i, j]*x[j]) == ({i}, {}) + assert get_indices(y[i] + 2*(x[i] + A[i, j]*x[j])) == ({i}, {}) + assert get_indices(y[i] + x[i]*(A[j, j] + 1)) == ({i}, {}) + assert get_indices( + y[i] + x[i]*x[j]*(y[j] + A[j, k]*x[k])) == ({i}, {}) + + +def test_get_indices_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + assert get_indices(Pow(x[i], y[j])) == ({i, j}, {}) + assert get_indices(Pow(x[i, k], y[j, k])) == ({i, j, k}, {}) + assert get_indices(Pow(A[i, k], y[k] + A[k, j]*x[j])) == ({i, k}, {}) + assert get_indices(Pow(2, x[i])) == get_indices(exp(x[i])) + + # test of a design decision, this may change: + assert get_indices(Pow(x[i], 2)) == ({i}, {}) + + +def test_get_contraction_structure_basic(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_contraction_structure(x[i]*y[j]) == {None: {x[i]*y[j]}} + assert get_contraction_structure(x[i] + y[j]) == {None: {x[i], y[j]}} + assert get_contraction_structure(x[i]*y[i]) == {(i,): {x[i]*y[i]}} + assert get_contraction_structure( + 1 + x[i]*y[i]) == {None: {S.One}, (i,): {x[i]*y[i]}} + assert get_contraction_structure(x[i]**y[i]) == {None: {x[i]**y[i]}} + + +def test_get_contraction_structure_complex(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + expr1 = y[i] + A[i, j]*x[j] + d1 = {None: {y[i]}, (j,): {A[i, j]*x[j]}} + assert get_contraction_structure(expr1) == d1 + expr2 = expr1*A[k, i] + x[k] + d2 = {None: {x[k]}, (i,): {expr1*A[k, i]}, expr1*A[k, i]: [d1]} + assert get_contraction_structure(expr2) == d2 + + +def test_contraction_structure_simple_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + ii_jj = x[i, i]**y[j, j] + assert get_contraction_structure(ii_jj) == { + None: {ii_jj}, + ii_jj: [ + {(i,): {x[i, i]}}, + {(j,): {y[j, j]}} + ] + } + + ii_jk = x[i, i]**y[j, k] + assert get_contraction_structure(ii_jk) == { + None: {x[i, i]**y[j, k]}, + x[i, i]**y[j, k]: [ + {(i,): {x[i, i]}} + ] + } + + +def test_contraction_structure_Mul_and_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + + i_ji = x[i]**(y[j]*x[i]) + assert get_contraction_structure(i_ji) == {None: {i_ji}} + ij_i = (x[i]*y[j])**(y[i]) + assert get_contraction_structure(ij_i) == {None: {ij_i}} + j_ij_i = x[j]*(x[i]*y[j])**(y[i]) + assert get_contraction_structure(j_ij_i) == {(j,): {j_ij_i}} + j_i_ji = x[j]*x[i]**(y[j]*x[i]) + assert get_contraction_structure(j_i_ji) == {(j,): {j_i_ji}} + ij_exp_kki = x[i]*y[j]*exp(y[i]*y[k, k]) + result = get_contraction_structure(ij_exp_kki) + expected = { + (i,): {ij_exp_kki}, + ij_exp_kki: [{ + None: {exp(y[i]*y[k, k])}, + exp(y[i]*y[k, k]): [{ + None: {y[i]*y[k, k]}, + y[i]*y[k, k]: [{(k,): {y[k, k]}}] + }]} + ] + } + assert result == expected + + +def test_contraction_structure_Add_in_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + s_ii_jj_s = (1 + x[i, i])**(1 + y[j, j]) + expected = { + None: {s_ii_jj_s}, + s_ii_jj_s: [ + {None: {S.One}, (i,): {x[i, i]}}, + {None: {S.One}, (j,): {y[j, j]}} + ] + } + result = get_contraction_structure(s_ii_jj_s) + assert result == expected + + s_ii_jk_s = (1 + x[i, i]) ** (1 + y[j, k]) + expected_2 = { + None: {(x[i, i] + 1)**(y[j, k] + 1)}, + s_ii_jk_s: [ + {None: {S.One}, (i,): {x[i, i]}} + ] + } + result_2 = get_contraction_structure(s_ii_jk_s) + assert result_2 == expected_2 + + +def test_contraction_structure_Pow_in_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i, j, k = Idx('i'), Idx('j'), Idx('k') + ii_jj_kk = x[i, i]**y[j, j]**z[k, k] + expected = { + None: {ii_jj_kk}, + ii_jj_kk: [ + {(i,): {x[i, i]}}, + { + None: {y[j, j]**z[k, k]}, + y[j, j]**z[k, k]: [ + {(j,): {y[j, j]}}, + {(k,): {z[k, k]}} + ] + } + ] + } + assert get_contraction_structure(ii_jj_kk) == expected + + +def test_ufunc_support(): + f = Function('f') + g = Function('g') + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + a = symbols('a') + + assert get_indices(f(x[i])) == ({i}, {}) + assert get_indices(f(x[i], y[j])) == ({i, j}, {}) + assert get_indices(f(y[i])*g(x[i])) == (set(), {}) + assert get_indices(f(a, x[i])) == ({i}, {}) + assert get_indices(f(a, y[i], x[j])*g(x[i])) == ({j}, {}) + assert get_indices(g(f(x[i]))) == ({i}, {}) + + assert get_contraction_structure(f(x[i])) == {None: {f(x[i])}} + assert get_contraction_structure( + f(y[i])*g(x[i])) == {(i,): {f(y[i])*g(x[i])}} + assert get_contraction_structure( + f(y[i])*g(f(x[i]))) == {(i,): {f(y[i])*g(f(x[i]))}} + assert get_contraction_structure( + f(x[j], y[i])*g(x[i])) == {(i,): {f(x[j], y[i])*g(x[i])}} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_indexed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..689ec932c8fcefe0a24de289dd2ffd6820c63f19 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_indexed.py @@ -0,0 +1,511 @@ +from sympy.core import symbols, Symbol, Tuple, oo, Dummy +from sympy.tensor.indexed import IndexException +from sympy.testing.pytest import raises +from sympy.utilities.iterables import iterable + +# import test: +from sympy.concrete.summations import Sum +from sympy.core.function import Function, Subs, Derivative +from sympy.core.relational import (StrictLessThan, GreaterThan, + StrictGreaterThan, LessThan) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.series.order import Order +from sympy.sets.fancysets import Range +from sympy.tensor.indexed import IndexedBase, Idx, Indexed + + +def test_Idx_construction(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i) != Idx(i, 1) + assert Idx(i, a) == Idx(i, (0, a - 1)) + assert Idx(i, oo) == Idx(i, (0, oo)) + + x = symbols('x', integer=False) + raises(TypeError, lambda: Idx(x)) + raises(TypeError, lambda: Idx(0.5)) + raises(TypeError, lambda: Idx(i, x)) + raises(TypeError, lambda: Idx(i, 0.5)) + raises(TypeError, lambda: Idx(i, (x, 5))) + raises(TypeError, lambda: Idx(i, (2, x))) + raises(TypeError, lambda: Idx(i, (2, 3.5))) + + +def test_Idx_properties(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i).is_integer + assert Idx(i).name == 'i' + assert Idx(i + 2).name == 'i + 2' + assert Idx('foo').name == 'foo' + + +def test_Idx_bounds(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i).lower is None + assert Idx(i).upper is None + assert Idx(i, a).lower == 0 + assert Idx(i, a).upper == a - 1 + assert Idx(i, 5).lower == 0 + assert Idx(i, 5).upper == 4 + assert Idx(i, oo).lower == 0 + assert Idx(i, oo).upper is oo + assert Idx(i, (a, b)).lower == a + assert Idx(i, (a, b)).upper == b + assert Idx(i, (1, 5)).lower == 1 + assert Idx(i, (1, 5)).upper == 5 + assert Idx(i, (-oo, oo)).lower is -oo + assert Idx(i, (-oo, oo)).upper is oo + + +def test_Idx_fixed_bounds(): + i, a, b, x = symbols('i a b x', integer=True) + assert Idx(x).lower is None + assert Idx(x).upper is None + assert Idx(x, a).lower == 0 + assert Idx(x, a).upper == a - 1 + assert Idx(x, 5).lower == 0 + assert Idx(x, 5).upper == 4 + assert Idx(x, oo).lower == 0 + assert Idx(x, oo).upper is oo + assert Idx(x, (a, b)).lower == a + assert Idx(x, (a, b)).upper == b + assert Idx(x, (1, 5)).lower == 1 + assert Idx(x, (1, 5)).upper == 5 + assert Idx(x, (-oo, oo)).lower is -oo + assert Idx(x, (-oo, oo)).upper is oo + + +def test_Idx_inequalities(): + i14 = Idx("i14", (1, 4)) + i79 = Idx("i79", (7, 9)) + i46 = Idx("i46", (4, 6)) + i35 = Idx("i35", (3, 5)) + + assert i14 <= 5 + assert i14 < 5 + assert not (i14 >= 5) + assert not (i14 > 5) + + assert 5 >= i14 + assert 5 > i14 + assert not (5 <= i14) + assert not (5 < i14) + + assert LessThan(i14, 5) + assert StrictLessThan(i14, 5) + assert not GreaterThan(i14, 5) + assert not StrictGreaterThan(i14, 5) + + assert i14 <= 4 + assert isinstance(i14 < 4, StrictLessThan) + assert isinstance(i14 >= 4, GreaterThan) + assert not (i14 > 4) + + assert isinstance(i14 <= 1, LessThan) + assert not (i14 < 1) + assert i14 >= 1 + assert isinstance(i14 > 1, StrictGreaterThan) + + assert not (i14 <= 0) + assert not (i14 < 0) + assert i14 >= 0 + assert i14 > 0 + + from sympy.abc import x + + assert isinstance(i14 < x, StrictLessThan) + assert isinstance(i14 > x, StrictGreaterThan) + assert isinstance(i14 <= x, LessThan) + assert isinstance(i14 >= x, GreaterThan) + + assert i14 < i79 + assert i14 <= i79 + assert not (i14 > i79) + assert not (i14 >= i79) + + assert i14 <= i46 + assert isinstance(i14 < i46, StrictLessThan) + assert isinstance(i14 >= i46, GreaterThan) + assert not (i14 > i46) + + assert isinstance(i14 < i35, StrictLessThan) + assert isinstance(i14 > i35, StrictGreaterThan) + assert isinstance(i14 <= i35, LessThan) + assert isinstance(i14 >= i35, GreaterThan) + + iNone1 = Idx("iNone1") + iNone2 = Idx("iNone2") + + assert isinstance(iNone1 < iNone2, StrictLessThan) + assert isinstance(iNone1 > iNone2, StrictGreaterThan) + assert isinstance(iNone1 <= iNone2, LessThan) + assert isinstance(iNone1 >= iNone2, GreaterThan) + + +def test_Idx_inequalities_current_fails(): + i14 = Idx("i14", (1, 4)) + + assert S(5) >= i14 + assert S(5) > i14 + assert not (S(5) <= i14) + assert not (S(5) < i14) + + +def test_Idx_func_args(): + i, a, b = symbols('i a b', integer=True) + ii = Idx(i) + assert ii.func(*ii.args) == ii + ii = Idx(i, a) + assert ii.func(*ii.args) == ii + ii = Idx(i, (a, b)) + assert ii.func(*ii.args) == ii + + +def test_Idx_subs(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i, a).subs(a, b) == Idx(i, b) + assert Idx(i, a).subs(i, b) == Idx(b, a) + + assert Idx(i).subs(i, 2) == Idx(2) + assert Idx(i, a).subs(a, 2) == Idx(i, 2) + assert Idx(i, (a, b)).subs(i, 2) == Idx(2, (a, b)) + + +def test_IndexedBase_sugar(): + i, j = symbols('i j', integer=True) + a = symbols('a') + A1 = Indexed(a, i, j) + A2 = IndexedBase(a) + assert A1 == A2[i, j] + assert A1 == A2[(i, j)] + assert A1 == A2[[i, j]] + assert A1 == A2[Tuple(i, j)] + assert all(a.is_Integer for a in A2[1, 0].args[1:]) + + +def test_IndexedBase_subs(): + i = symbols('i', integer=True) + a, b = symbols('a b') + A = IndexedBase(a) + B = IndexedBase(b) + assert A[i] == B[i].subs(b, a) + C = {1: 2} + assert C[1] == A[1].subs(A, C) + + +def test_IndexedBase_shape(): + i, j, m, n = symbols('i j m n', integer=True) + a = IndexedBase('a', shape=(m, m)) + b = IndexedBase('a', shape=(m, n)) + assert b.shape == Tuple(m, n) + assert a[i, j] != b[i, j] + assert a[i, j] == b[i, j].subs(n, m) + assert b.func(*b.args) == b + assert b[i, j].func(*b[i, j].args) == b[i, j] + raises(IndexException, lambda: b[i]) + raises(IndexException, lambda: b[i, i, j]) + F = IndexedBase("F", shape=m) + assert F.shape == Tuple(m) + assert F[i].subs(i, j) == F[j] + raises(IndexException, lambda: F[i, j]) + + +def test_IndexedBase_assumptions(): + i = Symbol('i', integer=True) + a = Symbol('a') + A = IndexedBase(a, positive=True) + for c in (A, A[i]): + assert c.is_real + assert c.is_complex + assert not c.is_imaginary + assert c.is_nonnegative + assert c.is_nonzero + assert c.is_commutative + assert log(exp(c)) == c + + assert A != IndexedBase(a) + assert A == IndexedBase(a, positive=True, real=True) + assert A[i] != Indexed(a, i) + + +def test_IndexedBase_assumptions_inheritance(): + I = Symbol('I', integer=True) + I_inherit = IndexedBase(I) + I_explicit = IndexedBase('I', integer=True) + + assert I_inherit.is_integer + assert I_explicit.is_integer + assert I_inherit.label.is_integer + assert I_explicit.label.is_integer + assert I_inherit == I_explicit + + +def test_issue_17652(): + """Regression test issue #17652. + + IndexedBase.label should not upcast subclasses of Symbol + """ + class SubClass(Symbol): + pass + + x = SubClass('X') + assert type(x) == SubClass + base = IndexedBase(x) + assert type(x) == SubClass + assert type(base.label) == SubClass + + +def test_Indexed_constructor(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, j) + assert A == Indexed(Symbol('A'), i, j) + assert A == Indexed(IndexedBase('A'), i, j) + raises(TypeError, lambda: Indexed(A, i, j)) + raises(IndexException, lambda: Indexed("A")) + assert A.free_symbols == {A, A.base.label, i, j} + + +def test_Indexed_func_args(): + i, j = symbols('i j', integer=True) + a = symbols('a') + A = Indexed(a, i, j) + assert A == A.func(*A.args) + + +def test_Indexed_subs(): + i, j, k = symbols('i j k', integer=True) + a, b = symbols('a b') + A = IndexedBase(a) + B = IndexedBase(b) + assert A[i, j] == B[i, j].subs(b, a) + assert A[i, j] == A[i, k].subs(k, j) + + +def test_Indexed_properties(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, j) + assert A.name == 'A[i, j]' + assert A.rank == 2 + assert A.indices == (i, j) + assert A.base == IndexedBase('A') + assert A.ranges == [None, None] + raises(IndexException, lambda: A.shape) + + n, m = symbols('n m', integer=True) + assert Indexed('A', Idx( + i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)] + assert Indexed('A', Idx(i, m), Idx(j, n)).shape == Tuple(m, n) + raises(IndexException, lambda: Indexed("A", Idx(i, m), Idx(j)).shape) + + +def test_Indexed_shape_precedence(): + i, j = symbols('i j', integer=True) + o, p = symbols('o p', integer=True) + n, m = symbols('n m', integer=True) + a = IndexedBase('a', shape=(o, p)) + assert a.shape == Tuple(o, p) + assert Indexed( + a, Idx(i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)] + assert Indexed(a, Idx(i, m), Idx(j, n)).shape == Tuple(o, p) + assert Indexed( + a, Idx(i, m), Idx(j)).ranges == [Tuple(0, m - 1), (None, None)] + assert Indexed(a, Idx(i, m), Idx(j)).shape == Tuple(o, p) + + +def test_complex_indices(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, i + j) + assert A.rank == 2 + assert A.indices == (i, i + j) + + +def test_not_interable(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, i + j) + assert not iterable(A) + + +def test_Indexed_coeff(): + N = Symbol('N', integer=True) + len_y = N + i = Idx('i', len_y-1) + y = IndexedBase('y', shape=(len_y,)) + a = (1/y[i+1]*y[i]).coeff(y[i]) + b = (y[i]/y[i+1]).coeff(y[i]) + assert a == b + + +def test_differentiation(): + from sympy.functions.special.tensor_functions import KroneckerDelta + i, j, k, l = symbols('i j k l', cls=Idx) + a = symbols('a') + m, n = symbols("m, n", integer=True, finite=True) + assert m.is_real + h, L = symbols('h L', cls=IndexedBase) + hi, hj = h[i], h[j] + + expr = hi + assert expr.diff(hj) == KroneckerDelta(i, j) + assert expr.diff(hi) == KroneckerDelta(i, i) + + expr = S(2) * hi + assert expr.diff(hj) == S(2) * KroneckerDelta(i, j) + assert expr.diff(hi) == S(2) * KroneckerDelta(i, i) + assert expr.diff(a) is S.Zero + + assert Sum(expr, (i, -oo, oo)).diff(hj) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo)) + assert Sum(expr.diff(hj), (i, -oo, oo)) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo)) + assert Sum(expr, (i, -oo, oo)).diff(hj).doit() == 2 + + assert Sum(expr.diff(hi), (i, -oo, oo)).doit() == Sum(2, (i, -oo, oo)).doit() + assert Sum(expr, (i, -oo, oo)).diff(hi).doit() is oo + + expr = a * hj * hj / S(2) + assert expr.diff(hi) == a * h[j] * KroneckerDelta(i, j) + assert expr.diff(a) == hj * hj / S(2) + assert expr.diff(a, 2) is S.Zero + + assert Sum(expr, (i, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo)) + assert Sum(expr.diff(hi), (i, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo)) + assert Sum(expr, (i, -oo, oo)).diff(hi).doit() == a*h[j] + + assert Sum(expr, (j, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo)) + assert Sum(expr.diff(hi), (j, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo)) + assert Sum(expr, (j, -oo, oo)).diff(hi).doit() == a*h[i] + + expr = a * sin(hj * hj) + assert expr.diff(hi) == 2*a*cos(hj * hj) * hj * KroneckerDelta(i, j) + assert expr.diff(hj) == 2*a*cos(hj * hj) * hj + + expr = a * L[i, j] * h[j] + assert expr.diff(hi) == a*L[i, j]*KroneckerDelta(i, j) + assert expr.diff(hj) == a*L[i, j] + assert expr.diff(L[i, j]) == a*h[j] + assert expr.diff(L[k, l]) == a*KroneckerDelta(i, k)*KroneckerDelta(j, l)*h[j] + assert expr.diff(L[i, l]) == a*KroneckerDelta(j, l)*h[j] + + assert Sum(expr, (j, -oo, oo)).diff(L[k, l]) == Sum(a * KroneckerDelta(i, k) * KroneckerDelta(j, l) * h[j], (j, -oo, oo)) + assert Sum(expr, (j, -oo, oo)).diff(L[k, l]).doit() == a * KroneckerDelta(i, k) * h[l] + + assert h[m].diff(h[m]) == 1 + assert h[m].diff(h[n]) == KroneckerDelta(m, n) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (m, -oo, oo)) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]).doit() == a + assert Sum(a*h[m], (n, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (n, -oo, oo)) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[m]).doit() == oo*a + + +def test_indexed_series(): + A = IndexedBase("A") + i = symbols("i", integer=True) + assert sin(A[i]).series(A[i]) == A[i] - A[i]**3/6 + A[i]**5/120 + Order(A[i]**6, A[i]) + + +def test_indexed_is_constant(): + A = IndexedBase("A") + i, j, k = symbols("i,j,k") + assert not A[i].is_constant() + assert A[i].is_constant(j) + assert not A[1+2*i, k].is_constant() + assert not A[1+2*i, k].is_constant(i) + assert A[1+2*i, k].is_constant(j) + assert not A[1+2*i, k].is_constant(k) + + +def test_issue_12533(): + d = IndexedBase('d') + assert IndexedBase(range(5)) == Range(0, 5, 1) + assert d[0].subs(Symbol("d"), range(5)) == 0 + assert d[0].subs(d, range(5)) == 0 + assert d[1].subs(d, range(5)) == 1 + assert Indexed(Range(5), 2) == 2 + + +def test_issue_12780(): + n = symbols("n") + i = Idx("i", (0, n)) + raises(TypeError, lambda: i.subs(n, 1.5)) + + +def test_issue_18604(): + m = symbols("m") + assert Idx("i", m).name == 'i' + assert Idx("i", m).lower == 0 + assert Idx("i", m).upper == m - 1 + m = symbols("m", real=False) + raises(TypeError, lambda: Idx("i", m)) + +def test_Subs_with_Indexed(): + A = IndexedBase("A") + i, j, k = symbols("i,j,k") + x, y, z = symbols("x,y,z") + f = Function("f") + + assert Subs(A[i], A[i], A[j]).diff(A[j]) == 1 + assert Subs(A[i], A[i], x).diff(A[i]) == 0 + assert Subs(A[i], A[i], x).diff(A[j]) == 0 + assert Subs(A[i], A[i], x).diff(x) == 1 + assert Subs(A[i], A[i], x).diff(y) == 0 + assert Subs(A[i], A[i], A[j]).diff(A[k]) == KroneckerDelta(j, k) + assert Subs(x, x, A[i]).diff(A[j]) == KroneckerDelta(i, j) + assert Subs(f(A[i]), A[i], x).diff(A[j]) == 0 + assert Subs(f(A[i]), A[i], A[k]).diff(A[j]) == Derivative(f(A[k]), A[k])*KroneckerDelta(j, k) + assert Subs(x, x, A[i]**2).diff(A[j]) == 2*KroneckerDelta(i, j)*A[i] + assert Subs(A[i], A[i], A[j]**2).diff(A[k]) == 2*KroneckerDelta(j, k)*A[j] + + assert Subs(A[i]*x, x, A[i]).diff(A[i]) == 2*A[i] + assert Subs(A[i]*x, x, A[i]).diff(A[j]) == 2*A[i]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[j]).diff(A[i]) == A[j] + A[i]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[j]).diff(A[j]) == A[i] + A[j]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[i]).diff(A[k]) == 2*A[i]*KroneckerDelta(i, k) + assert Subs(A[i]*x, x, A[j]).diff(A[k]) == KroneckerDelta(i, k)*A[j] + KroneckerDelta(j, k)*A[i] + + assert Subs(A[i]*x, A[i], x).diff(A[i]) == 0 + assert Subs(A[i]*x, A[i], x).diff(A[j]) == 0 + assert Subs(A[i]*x, A[j], x).diff(A[i]) == x + assert Subs(A[i]*x, A[j], x).diff(A[j]) == x*KroneckerDelta(i, j) + assert Subs(A[i]*x, A[i], x).diff(A[k]) == 0 + assert Subs(A[i]*x, A[j], x).diff(A[k]) == x*KroneckerDelta(i, k) + + +def test_complicated_derivative_with_Indexed(): + x, y = symbols("x,y", cls=IndexedBase) + sigma = symbols("sigma") + i, j, k = symbols("i,j,k") + m0,m1,m2,m3,m4,m5 = symbols("m0:6") + f = Function("f") + + expr = f((x[i] - y[i])**2/sigma) + _xi_1 = symbols("xi_1", cls=Dummy) + assert expr.diff(x[m0]).dummy_eq( + (x[i] - y[i])*KroneckerDelta(i, m0)*\ + 2*Subs( + Derivative(f(_xi_1), _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma + ) + assert expr.diff(x[m0]).diff(x[m1]).dummy_eq( + 2*KroneckerDelta(i, m0)*\ + KroneckerDelta(i, m1)*Subs( + Derivative(f(_xi_1), _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma + \ + 4*(x[i] - y[i])**2*KroneckerDelta(i, m0)*KroneckerDelta(i, m1)*\ + Subs( + Derivative(f(_xi_1), _xi_1, _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma**2 + ) + + +def test_IndexedBase_commutative(): + t = IndexedBase('t', commutative=False) + u = IndexedBase('u', commutative=False) + v = IndexedBase('v') + assert t[0]*v[0] == v[0]*t[0] + assert t[0]*u[0] != u[0]*t[0] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_printing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3cf7f0591a7012c93354ab7b8d7e010def38bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_printing.py @@ -0,0 +1,13 @@ +from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead +from sympy import I + +def test_printing_TensMul(): + R3 = TensorIndexType('R3', dim=3) + p, q = tensor_indices("p q", R3) + K = TensorHead("K", [R3]) + + assert repr(2*K(p)) == "2*K(p)" + assert repr(-K(p)) == "-K(p)" + assert repr(-2*K(p)*K(q)) == "-2*K(p)*K(q)" + assert repr(-I*K(p)) == "-I*K(p)" + assert repr(I*K(p)) == "I*K(p)" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..3113f5be9bcd32224f3525b5d831b6d7476c39e3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/tensor/tests/test_tensor.py @@ -0,0 +1,2218 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import expand +from sympy.core.numbers import Integer +from sympy.matrices.dense import (Matrix, eye) +from sympy.tensor.indexed import Indexed +from sympy.combinatorics import Permutation +from sympy.core import S, Rational, Symbol, Basic, Add, Wild, Function +from sympy.core.containers import Tuple +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals import integrate +from sympy.tensor.array import Array +from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorSymmetry, \ + get_symmetric_group_sgs, TensorIndex, tensor_mul, TensAdd, \ + riemann_cyclic_replace, riemann_cyclic, TensMul, tensor_heads, \ + TensorManager, TensExpr, TensorHead, canon_bp, \ + tensorhead, tensorsymmetry, TensorType, substitute_indices, \ + WildTensorIndex, WildTensorHead, _WildTensExpr +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy +from sympy.matrices import diag + +def _is_equal(arg1, arg2): + if isinstance(arg1, TensExpr): + return arg1.equals(arg2) + elif isinstance(arg2, TensExpr): + return arg2.equals(arg1) + return arg1 == arg2 + + +#################### Tests from tensor_can.py ####################### +def test_canonicalize_no_slot_sym(): + # A_d0 * B^d0; T_c = A^d0*B_d0 + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, b, d0, d1 = tensor_indices('a,b,d0,d1', Lorentz) + A, B = tensor_heads('A,B', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(-d0)*B(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*B(-L_0)' + + # A^a * B^b; T_c = T + t = A(a)*B(b) + tc = t.canon_bp() + assert tc == t + # B^b * A^a + t1 = B(b)*A(a) + tc = t1.canon_bp() + assert str(tc) == 'A(a)*B(b)' + + # A symmetric + # A^{b}_{d0}*A^{d0, a}; T_c = A^{a d0}*A{b}_{d0} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(b, -d0)*A(d0, a) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0)*A(b, -L_0)' + + # A^{d1}_{d0}*B^d0*C_d1 + # T_c = A^{d0 d1}*B_d0*C_d1 + B, C = tensor_heads('B,C', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(d1, -d0)*B(d0)*C(-d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_0)*C(-L_1)' + + # A without symmetry + # A^{d1}_{d0}*B^d0*C_d1 ord=[d0,-d0,d1,-d1]; g = [2,1,0,3,4,5] + # T_c = A^{d0 d1}*B_d1*C_d0; can = [0,2,3,1,4,5] + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*B(d0)*C(-d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_1)*C(-L_0)' + + # A, B without symmetry + # A^{d1}_{d0}*B_{d1}^{d0} + # T_c = A^{d0 d1}*B_{d0 d1} + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*B(-d1, d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_0, -L_1)' + # A_{d0}^{d1}*B_{d1}^{d0} + # T_c = A^{d0 d1}*B_{d1 d0} + t = A(-d0, d1)*B(-d1, d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_1, -L_0)' + + # A, B, C without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} + # T_c=A^{d0 d1}*B_{a d1}*C_{d0 b} + C = TensorHead('C', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_1)*C(-L_0, -b)' + + # A symmetric, B and C without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} + # T_c = A^{d0 d1}*B_{a d0}*C_{d1 b} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_0)*C(-L_1, -b)' + + # A and C symmetric, B without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} ord=[a,b,d0,-d0,d1,-d1] + # T_c = A^{d0 d1}*B_{a d0}*C_{b d1} + C = TensorHead('C', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_0)*C(-b, -L_1)' + +def test_canonicalize_no_dummies(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, b, c, d = tensor_indices('a, b, c, d', Lorentz) + + # A commuting + # A^c A^b A^a + # T_c = A^a A^b A^c + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(c)*A(b)*A(a) + tc = t.canon_bp() + assert str(tc) == 'A(a)*A(b)*A(c)' + + # A anticommuting + # A^c A^b A^a + # T_c = -A^a A^b A^c + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1), 1) + t = A(c)*A(b)*A(a) + tc = t.canon_bp() + assert str(tc) == '-A(a)*A(b)*A(c)' + + # A commuting and symmetric + # A^{b,d}*A^{c,a} + # T_c = A^{a c}*A^{b d} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(b, d)*A(c, a) + tc = t.canon_bp() + assert str(tc) == 'A(a, c)*A(b, d)' + + # A anticommuting and symmetric + # A^{b,d}*A^{c,a} + # T_c = -A^{a c}*A^{b d} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2), 1) + t = A(b, d)*A(c, a) + tc = t.canon_bp() + assert str(tc) == '-A(a, c)*A(b, d)' + + # A^{c,a}*A^{b,d} + # T_c = A^{a c}*A^{b d} + t = A(c, a)*A(b, d) + tc = t.canon_bp() + assert str(tc) == 'A(a, c)*A(b, d)' + +def test_tensorhead_construction_without_symmetry(): + L = TensorIndexType('Lorentz') + A1 = TensorHead('A', [L, L]) + A2 = TensorHead('A', [L, L], TensorSymmetry.no_symmetry(2)) + assert A1 == A2 + A3 = TensorHead('A', [L, L], TensorSymmetry.fully_symmetric(2)) # Symmetric + assert A1 != A3 + +def test_no_metric_symmetry(): + # no metric symmetry; A no symmetry + # A^d1_d0 * A^d0_d1 + # T_c = A^d0_d1 * A^d1_d0 + Lorentz = TensorIndexType('Lorentz', dummy_name='L', metric_symmetry=0) + d0, d1, d2, d3 = tensor_indices('d:4', Lorentz) + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*A(d0, -d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_0)' + + # A^d1_d2 * A^d0_d3 * A^d2_d1 * A^d3_d0 + # T_c = A^d0_d1 * A^d1_d0 * A^d2_d3 * A^d3_d2 + t = A(d1, -d2)*A(d0, -d3)*A(d2, -d1)*A(d3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_0)*A(L_2, -L_3)*A(L_3, -L_2)' + + # A^d0_d2 * A^d1_d3 * A^d3_d0 * A^d2_d1 + # T_c = A^d0_d1 * A^d1_d2 * A^d2_d3 * A^d3_d0 + t = A(d0, -d1)*A(d1, -d2)*A(d2, -d3)*A(d3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_2)*A(L_2, -L_3)*A(L_3, -L_0)' + +def test_canonicalize1(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Lorentz) + + # A_d0*A^d0; ord = [d0,-d0] + # T_c = A^d0*A_d0 + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(-d0)*A(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*A(-L_0)' + + # A commuting + # A_d0*A_d1*A_d2*A^d2*A^d1*A^d0 + # T_c = A^d0*A_d0*A^d1*A_d1*A^d2*A_d2 + t = A(-d0)*A(-d1)*A(-d2)*A(d2)*A(d1)*A(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*A(-L_0)*A(L_1)*A(-L_1)*A(L_2)*A(-L_2)' + + # A anticommuting + # A_d0*A_d1*A_d2*A^d2*A^d1*A^d0 + # T_c 0 + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1), 1) + t = A(-d0)*A(-d1)*A(-d2)*A(d2)*A(d1)*A(d0) + tc = t.canon_bp() + assert tc == 0 + + # A commuting symmetric + # A^{d0 b}*A^a_d1*A^d1_d0 + # T_c = A^{a d0}*A^{b d1}*A_{d0 d1} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d0, b)*A(a, -d1)*A(d1, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0)*A(b, L_1)*A(-L_0, -L_1)' + + # A, B commuting symmetric + # A^{d0 b}*A^d1_d0*B^a_d1 + # T_c = A^{b d0}*A_d0^d1*B^a_d1 + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d0, b)*A(d1, -d0)*B(a, -d1) + tc = t.canon_bp() + assert str(tc) == 'A(b, L_0)*A(-L_0, L_1)*B(a, -L_1)' + + # A commuting symmetric + # A^{d1 d0 b}*A^{a}_{d1 d0}; ord=[a,b, d0,-d0,d1,-d1] + # T_c = A^{a d0 d1}*A^{b}_{d0 d1} + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3)) + t = A(d1, d0, b)*A(a, -d1, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0, L_1)*A(b, -L_0, -L_1)' + + # A^{d3 d0 d2}*A^a0_{d1 d2}*A^d1_d3^a1*A^{a2 a3}_d0 + # T_c = A^{a0 d0 d1}*A^a1_d0^d2*A^{a2 a3 d3}*A_{d1 d2 d3} + t = A(d3, d0, d2)*A(a0, -d1, -d2)*A(d1, -d3, a1)*A(a2, a3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a0, L_0, L_1)*A(a1, -L_0, L_2)*A(a2, a3, L_3)*A(-L_1, -L_2, -L_3)' + + # A commuting symmetric, B antisymmetric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # in this esxample and in the next three, + # renaming dummy indices and using symmetry of A, + # T = A^{d0 d1 d2} * A_{d0 d1 d3} * B_d2^d3 + # can = 0 + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3)) + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert tc == 0 + + # A anticommuting symmetric, B antisymmetric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = A^{d0 d1 d2} * A_{d0 d1}^d3 * B_{d2 d3} + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1, L_2)*A(-L_0, -L_1, L_3)*B(-L_2, -L_3)' + + # A anticommuting symmetric, B antisymmetric commuting, antisymmetric metric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = -A^{d0 d1 d2} * A_{d0 d1}^d3 * B_{d2 d3} + Spinor = TensorIndexType('Spinor', dummy_name='S', metric_symmetry=-1) + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Spinor) + A = TensorHead('A', [Spinor]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Spinor]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == '-A(S_0, S_1, S_2)*A(-S_0, -S_1, S_3)*B(-S_2, -S_3)' + + # A anticommuting symmetric, B antisymmetric anticommuting, + # no metric symmetry + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = A^{d0 d1 d2} * A_{d0 d1 d3} * B_d2^d3 + Mat = TensorIndexType('Mat', metric_symmetry=0, dummy_name='M') + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Mat) + A = TensorHead('A', [Mat]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Mat]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == 'A(M_0, M_1, M_2)*A(-M_0, -M_1, -M_3)*B(-M_2, M_3)' + + # Gamma anticommuting + # Gamma_{mu nu} * gamma^rho * Gamma^{nu mu alpha} + # T_c = -Gamma^{mu nu} * gamma^rho * Gamma_{alpha mu nu} + alpha, beta, gamma, mu, nu, rho = \ + tensor_indices('alpha,beta,gamma,mu,nu,rho', Lorentz) + Gamma = TensorHead('Gamma', [Lorentz], + TensorSymmetry.fully_symmetric(1), 2) + Gamma2 = TensorHead('Gamma', [Lorentz]*2, + TensorSymmetry.fully_symmetric(-2), 2) + Gamma3 = TensorHead('Gamma', [Lorentz]*3, + TensorSymmetry.fully_symmetric(-3), 2) + t = Gamma2(-mu, -nu)*Gamma(rho)*Gamma3(nu, mu, alpha) + tc = t.canon_bp() + assert str(tc) == '-Gamma(L_0, L_1)*Gamma(rho)*Gamma(alpha, -L_0, -L_1)' + + # Gamma_{mu nu} * Gamma^{gamma beta} * gamma_rho * Gamma^{nu mu alpha} + # T_c = Gamma^{mu nu} * Gamma^{beta gamma} * gamma_rho * Gamma^alpha_{mu nu} + t = Gamma2(mu, nu)*Gamma2(beta, gamma)*Gamma(-rho)*Gamma3(alpha, -mu, -nu) + tc = t.canon_bp() + assert str(tc) == 'Gamma(L_0, L_1)*Gamma(beta, gamma)*Gamma(-rho)*Gamma(alpha, -L_0, -L_1)' + + # f^a_{b,c} antisymmetric in b,c; A_mu^a no symmetry + # f^c_{d a} * f_{c e b} * A_mu^d * A_nu^a * A^{nu e} * A^{mu b} + # g = [8,11,5, 9,13,7, 1,10, 3,4, 2,12, 0,6, 14,15] + # T_c = -f^{a b c} * f_a^{d e} * A^mu_b * A_{mu d} * A^nu_c * A_{nu e} + Flavor = TensorIndexType('Flavor', dummy_name='F') + a, b, c, d, e, ff = tensor_indices('a,b,c,d,e,f', Flavor) + mu, nu = tensor_indices('mu,nu', Lorentz) + f = TensorHead('f', [Flavor]*3, TensorSymmetry.direct_product(1, -2)) + A = TensorHead('A', [Lorentz, Flavor], TensorSymmetry.no_symmetry(2)) + t = f(c, -d, -a)*f(-c, -e, -b)*A(-mu, d)*A(-nu, a)*A(nu, e)*A(mu, b) + tc = t.canon_bp() + assert str(tc) == '-f(F_0, F_1, F_2)*f(-F_0, F_3, F_4)*A(L_0, -F_1)*A(-L_0, -F_3)*A(L_1, -F_2)*A(-L_1, -F_4)' + + +def test_bug_correction_tensor_indices(): + # to make sure that tensor_indices does not return a list if creating + # only one index: + A = TensorIndexType("A") + i = tensor_indices('i', A) + assert not isinstance(i, (tuple, list)) + assert isinstance(i, TensorIndex) + + +def test_riemann_invariants(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11 = \ + tensor_indices('d0:12', Lorentz) + # R^{d0 d1}_{d1 d0}; ord = [d0,-d0,d1,-d1] + # T_c = -R^{d0 d1}_{d0 d1} + R = TensorHead('R', [Lorentz]*4, TensorSymmetry.riemann()) + t = R(d0, d1, -d1, -d0) + tc = t.canon_bp() + assert str(tc) == '-R(L_0, L_1, -L_0, -L_1)' + + # R_d11^d1_d0^d5 * R^{d6 d4 d0}_d5 * R_{d7 d2 d8 d9} * + # R_{d10 d3 d6 d4} * R^{d2 d7 d11}_d1 * R^{d8 d9 d3 d10} + # can = [0,2,4,6, 1,3,8,10, 5,7,12,14, 9,11,16,18, 13,15,20,22, + # 17,19,21>> from sympy.tensor.tensor import TensorIndexType, TensorHead + >>> from sympy.tensor.toperators import PartialDerivative + >>> from sympy import symbols + >>> L = TensorIndexType("L") + >>> A = TensorHead("A", [L]) + >>> B = TensorHead("B", [L]) + >>> i, j, k = symbols("i j k") + + >>> expr = PartialDerivative(A(i), A(j)) + >>> expr + PartialDerivative(A(i), A(j)) + + The ``PartialDerivative`` object behaves like a tensorial expression: + + >>> expr.get_indices() + [i, -j] + + Notice that the deriving variables have opposite valence than the + printed one: ``A(j)`` is printed as covariant, but the index of the + derivative is actually contravariant, i.e. ``-j``. + + Indices can be contracted: + + >>> expr = PartialDerivative(A(i), A(i)) + >>> expr + PartialDerivative(A(L_0), A(L_0)) + >>> expr.get_indices() + [L_0, -L_0] + + The method ``.get_indices()`` always returns all indices (even the + contracted ones). If only uncontracted indices are needed, call + ``.get_free_indices()``: + + >>> expr.get_free_indices() + [] + + Nested partial derivatives are flattened: + + >>> expr = PartialDerivative(PartialDerivative(A(i), A(j)), A(k)) + >>> expr + PartialDerivative(A(i), A(j), A(k)) + >>> expr.get_indices() + [i, -j, -k] + + Replace a derivative with array values: + + >>> from sympy.abc import x, y + >>> from sympy import sin, log + >>> compA = [sin(x), log(x)*y**3] + >>> compB = [x, y] + >>> expr = PartialDerivative(A(i), B(j)) + >>> expr.replace_with_arrays({A(i): compA, B(i): compB}) + [[cos(x), 0], [y**3/x, 3*y**2*log(x)]] + + The returned array is indexed by `(i, -j)`. + + Be careful that other SymPy modules put the indices of the deriving + variables before the indices of the derivand in the derivative result. + For example: + + >>> expr.get_free_indices() + [i, -j] + + >>> from sympy import Matrix, Array + >>> Matrix(compA).diff(Matrix(compB)).reshape(2, 2) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + >>> Array(compA).diff(Array(compB)) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + + These are the transpose of the result of ``PartialDerivative``, + as the matrix and the array modules put the index `-j` before `i` in the + derivative result. An array read with index order `(-j, i)` is indeed the + transpose of the same array read with index order `(i, -j)`. By specifying + the index order to ``.replace_with_arrays`` one can get a compatible + expression: + + >>> expr.replace_with_arrays({A(i): compA, B(i): compB}, [-j, i]) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + """ + + def __new__(cls, expr, *variables): + + # Flatten: + if isinstance(expr, PartialDerivative): + variables = expr.variables + variables + expr = expr.expr + + args, indices, free, dum = cls._contract_indices_for_derivative( + S(expr), variables) + + obj = TensExpr.__new__(cls, *args) + + obj._indices = indices + obj._free = free + obj._dum = dum + return obj + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + @classmethod + def _contract_indices_for_derivative(cls, expr, variables): + variables_opposite_valence = [] + + for i in variables: + if isinstance(i, Tensor): + i_free_indices = i.get_free_indices() + variables_opposite_valence.append( + i.xreplace({k: -k for k in i_free_indices})) + elif isinstance(i, Symbol): + variables_opposite_valence.append(i) + + args, indices, free, dum = TensMul._tensMul_contract_indices( + [expr] + variables_opposite_valence, replace_indices=True) + + for i in range(1, len(args)): + args_i = args[i] + if isinstance(args_i, Tensor): + i_indices = args[i].get_free_indices() + args[i] = args[i].xreplace({k: -k for k in i_indices}) + + return args, indices, free, dum + + def doit(self, **hints): + args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables) + + obj = self.func(*args) + obj._indices = indices + obj._free = free + obj._dum = dum + + return obj + + def _expand_partial_derivative(self): + args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables) + + obj = self.func(*args) + obj._indices = indices + obj._free = free + obj._dum = dum + + result = obj + + if not args[0].free_symbols: + return S.Zero + elif isinstance(obj.expr, TensAdd): + # take care of sums of multi PDs + result = obj.expr.func(*[ + self.func(a, *obj.variables)._expand_partial_derivative() + for a in result.expr.args]) + elif isinstance(obj.expr, TensMul): + # take care of products of multi PDs + if len(obj.variables) == 1: + # derivative with respect to single variable + terms = [] + mulargs = list(obj.expr.args) + for ind in range(len(mulargs)): + if not isinstance(sympify(mulargs[ind]), Number): + # a number coefficient is not considered for + # expansion of PartialDerivative + d = self.func(mulargs[ind], *obj.variables)._expand_partial_derivative() + terms.append(TensMul(*(mulargs[:ind] + + [d] + + mulargs[(ind + 1):]))) + result = TensAdd.fromiter(terms) + else: + # derivative with respect to multiple variables + # decompose: + # partial(expr, (u, v)) + # = partial(partial(expr, u).doit(), v).doit() + result = obj.expr # init with expr + for v in obj.variables: + result = self.func(result, v)._expand_partial_derivative() + # then throw PD on it + + return result + + def _perform_derivative(self): + result = self.expr + for v in self.variables: + if isinstance(result, TensExpr): + result = result._eval_partial_derivative(v) + else: + if v._diff_wrt: + result = result._eval_derivative(v) + else: + result = S.Zero + return result + + def get_indices(self): + return self._indices + + def get_free_indices(self): + free = sorted(self._free, key=lambda x: x[1]) + return [i[0] for i in free] + + def _replace_indices(self, repl): + expr = self.expr.xreplace(repl) + mirrored = {-k: -v for k, v in repl.items()} + variables = [i.xreplace(mirrored) for i in self.variables] + return self.func(expr, *variables) + + @property + def expr(self): + return self.args[0] + + @property + def variables(self): + return self.args[1:] + + def _extract_data(self, replacement_dict): + from .array import derive_by_array, tensorcontraction + indices, array = self.expr._extract_data(replacement_dict) + for variable in self.variables: + var_indices, var_array = variable._extract_data(replacement_dict) + var_indices = [-i for i in var_indices] + coeff_array, var_array = zip(*[i.as_coeff_Mul() for i in var_array]) + dim_before = len(array.shape) + array = derive_by_array(array, var_array) + dim_after = len(array.shape) + dim_increase = dim_after - dim_before + array = permutedims(array, [i + dim_increase for i in range(dim_before)] + list(range(dim_increase))) + array = array.as_mutable() + varindex = var_indices[0] + # Remove coefficients of base vector: + coeff_index = [0] + [slice(None) for i in range(len(indices))] + for i, coeff in enumerate(coeff_array): + coeff_index[0] = i + array[tuple(coeff_index)] /= coeff + if -varindex in indices: + pos = indices.index(-varindex) + array = tensorcontraction(array, (0, pos+1)) + indices.pop(pos) + else: + indices.append(varindex) + return indices, array diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_dyadic.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..2e396fcf2a81af897b59c0065f6b15f5c6933222 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_dyadic.py @@ -0,0 +1,134 @@ +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.simplify.simplify import simplify +from sympy.vector import (CoordSys3D, Vector, Dyadic, + DyadicAdd, DyadicMul, DyadicZero, + BaseDyadic, express) + + +A = CoordSys3D('A') + + +def test_dyadic(): + a, b = symbols('a, b') + assert Dyadic.zero != 0 + assert isinstance(Dyadic.zero, DyadicZero) + assert BaseDyadic(A.i, A.j) != BaseDyadic(A.j, A.i) + assert (BaseDyadic(Vector.zero, A.i) == + BaseDyadic(A.i, Vector.zero) == Dyadic.zero) + + d1 = A.i | A.i + d2 = A.j | A.j + d3 = A.i | A.j + + assert isinstance(d1, BaseDyadic) + d_mul = a*d1 + assert isinstance(d_mul, DyadicMul) + assert d_mul.base_dyadic == d1 + assert d_mul.measure_number == a + assert isinstance(a*d1 + b*d3, DyadicAdd) + assert d1 == A.i.outer(A.i) + assert d3 == A.i.outer(A.j) + v1 = a*A.i - A.k + v2 = A.i + b*A.j + assert v1 | v2 == v1.outer(v2) == a * (A.i|A.i) + (a*b) * (A.i|A.j) +\ + - (A.k|A.i) - b * (A.k|A.j) + assert d1 * 0 == Dyadic.zero + assert d1 != Dyadic.zero + assert d1 * 2 == 2 * (A.i | A.i) + assert d1 / 2. == 0.5 * d1 + + assert d1.dot(0 * d1) == Vector.zero + assert d1 & d2 == Dyadic.zero + assert d1.dot(A.i) == A.i == d1 & A.i + + assert d1.cross(Vector.zero) == Dyadic.zero + assert d1.cross(A.i) == Dyadic.zero + assert d1 ^ A.j == d1.cross(A.j) + assert d1.cross(A.k) == - A.i | A.j + assert d2.cross(A.i) == - A.j | A.k == d2 ^ A.i + + assert A.i ^ d1 == Dyadic.zero + assert A.j.cross(d1) == - A.k | A.i == A.j ^ d1 + assert Vector.zero.cross(d1) == Dyadic.zero + assert A.k ^ d1 == A.j | A.i + assert A.i.dot(d1) == A.i & d1 == A.i + assert A.j.dot(d1) == Vector.zero + assert Vector.zero.dot(d1) == Vector.zero + assert A.j & d2 == A.j + + assert d1.dot(d3) == d1 & d3 == A.i | A.j == d3 + assert d3 & d1 == Dyadic.zero + + q = symbols('q') + B = A.orient_new_axis('B', q, A.k) + assert express(d1, B) == express(d1, B, B) + + expr1 = ((cos(q)**2) * (B.i | B.i) + (-sin(q) * cos(q)) * + (B.i | B.j) + (-sin(q) * cos(q)) * (B.j | B.i) + (sin(q)**2) * + (B.j | B.j)) + assert (express(d1, B) - expr1).simplify() == Dyadic.zero + + expr2 = (cos(q)) * (B.i | A.i) + (-sin(q)) * (B.j | A.i) + assert (express(d1, B, A) - expr2).simplify() == Dyadic.zero + + expr3 = (cos(q)) * (A.i | B.i) + (-sin(q)) * (A.i | B.j) + assert (express(d1, A, B) - expr3).simplify() == Dyadic.zero + + assert d1.to_matrix(A) == Matrix([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert d1.to_matrix(A, B) == Matrix([[cos(q), -sin(q), 0], + [0, 0, 0], + [0, 0, 0]]) + assert d3.to_matrix(A) == Matrix([[0, 1, 0], [0, 0, 0], [0, 0, 0]]) + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + v1 = a * A.i + b * A.j + c * A.k + v2 = d * A.i + e * A.j + f * A.k + d4 = v1.outer(v2) + assert d4.to_matrix(A) == Matrix([[a * d, a * e, a * f], + [b * d, b * e, b * f], + [c * d, c * e, c * f]]) + d5 = v1.outer(v1) + C = A.orient_new_axis('C', q, A.i) + for expected, actual in zip(C.rotation_matrix(A) * d5.to_matrix(A) * \ + C.rotation_matrix(A).T, d5.to_matrix(C)): + assert (expected - actual).simplify() == 0 + + +def test_dyadic_simplify(): + x, y, z, k, n, m, w, f, s, A = symbols('x, y, z, k, n, m, w, f, s, A') + N = CoordSys3D('N') + + dy = N.i | N.i + test1 = (1 / x + 1 / y) * dy + assert (N.i & test1 & N.i) != (x + y) / (x * y) + test1 = test1.simplify() + assert test1.simplify() == simplify(test1) + assert (N.i & test1 & N.i) == (x + y) / (x * y) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * dy + test2 = test2.simplify() + assert (N.i & test2 & N.i) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * x - 2 * (2 + 2 * x)) / (2 + 2 * x)) * dy + test3 = test3.simplify() + assert (N.i & test3 & N.i) == 0 + + test4 = ((-4 * x * y**2 - 2 * y**3 - 2 * x**2 * y) / (x + y)**2) * dy + test4 = test4.simplify() + assert (N.i & test4 & N.i) == -2 * y + + +def test_dyadic_srepr(): + from sympy.printing.repr import srepr + N = CoordSys3D('N') + + dy = N.i | N.j + res = "BaseDyadic(CoordSys3D(Str('N'), Tuple(ImmutableDenseMatrix([["\ + "Integer(1), Integer(0), Integer(0)], [Integer(0), Integer(1), "\ + "Integer(0)], [Integer(0), Integer(0), Integer(1)]]), "\ + "VectorZero())).i, CoordSys3D(Str('N'), Tuple(ImmutableDenseMatrix("\ + "[[Integer(1), Integer(0), Integer(0)], [Integer(0), Integer(1), "\ + "Integer(0)], [Integer(0), Integer(0), Integer(1)]]), VectorZero())).j)" + assert srepr(dy) == res diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_field_functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_field_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..035c2ce0234b81069c5ad8dcb1c74f4de0164a8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_field_functions.py @@ -0,0 +1,321 @@ +from sympy.core.function import Derivative +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.simplify import simplify +from sympy.core.symbol import symbols +from sympy.core import S +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.vector.vector import Dot +from sympy.vector.operators import curl, divergence, gradient, Gradient, Divergence, Cross +from sympy.vector.deloperator import Del +from sympy.vector.functions import (is_conservative, is_solenoidal, + scalar_potential, directional_derivative, + laplacian, scalar_potential_difference) +from sympy.testing.pytest import raises + +C = CoordSys3D('C') +i, j, k = C.base_vectors() +x, y, z = C.base_scalars() +delop = Del() +a, b, c, q = symbols('a b c q') + + +def test_del_operator(): + # Tests for curl + + assert delop ^ Vector.zero == Vector.zero + assert ((delop ^ Vector.zero).doit() == Vector.zero == + curl(Vector.zero)) + assert delop.cross(Vector.zero) == delop ^ Vector.zero + assert (delop ^ i).doit() == Vector.zero + assert delop.cross(2*y**2*j, doit=True) == Vector.zero + assert delop.cross(2*y**2*j) == delop ^ 2*y**2*j + v = x*y*z * (i + j + k) + assert ((delop ^ v).doit() == + (-x*y + x*z)*i + (x*y - y*z)*j + (-x*z + y*z)*k == + curl(v)) + assert delop ^ v == delop.cross(v) + assert (delop.cross(2*x**2*j) == + (Derivative(0, C.y) - Derivative(2*C.x**2, C.z))*C.i + + (-Derivative(0, C.x) + Derivative(0, C.z))*C.j + + (-Derivative(0, C.y) + Derivative(2*C.x**2, C.x))*C.k) + assert (delop.cross(2*x**2*j, doit=True) == 4*x*k == + curl(2*x**2*j)) + + #Tests for divergence + assert delop & Vector.zero is S.Zero == divergence(Vector.zero) + assert (delop & Vector.zero).doit() is S.Zero + assert delop.dot(Vector.zero) == delop & Vector.zero + assert (delop & i).doit() is S.Zero + assert (delop & x**2*i).doit() == 2*x == divergence(x**2*i) + assert (delop.dot(v, doit=True) == x*y + y*z + z*x == + divergence(v)) + assert delop & v == delop.dot(v) + assert delop.dot(1/(x*y*z) * (i + j + k), doit=True) == \ + - 1 / (x*y*z**2) - 1 / (x*y**2*z) - 1 / (x**2*y*z) + v = x*i + y*j + z*k + assert (delop & v == Derivative(C.x, C.x) + + Derivative(C.y, C.y) + Derivative(C.z, C.z)) + assert delop.dot(v, doit=True) == 3 == divergence(v) + assert delop & v == delop.dot(v) + assert simplify((delop & v).doit()) == 3 + + #Tests for gradient + assert (delop.gradient(0, doit=True) == Vector.zero == + gradient(0)) + assert delop.gradient(0) == delop(0) + assert (delop(S.Zero)).doit() == Vector.zero + assert (delop(x) == (Derivative(C.x, C.x))*C.i + + (Derivative(C.x, C.y))*C.j + (Derivative(C.x, C.z))*C.k) + assert (delop(x)).doit() == i == gradient(x) + assert (delop(x*y*z) == + (Derivative(C.x*C.y*C.z, C.x))*C.i + + (Derivative(C.x*C.y*C.z, C.y))*C.j + + (Derivative(C.x*C.y*C.z, C.z))*C.k) + assert (delop.gradient(x*y*z, doit=True) == + y*z*i + z*x*j + x*y*k == + gradient(x*y*z)) + assert delop(x*y*z) == delop.gradient(x*y*z) + assert (delop(2*x**2)).doit() == 4*x*i + assert ((delop(a*sin(y) / x)).doit() == + -a*sin(y)/x**2 * i + a*cos(y)/x * j) + + #Tests for directional derivative + assert (Vector.zero & delop)(a) is S.Zero + assert ((Vector.zero & delop)(a)).doit() is S.Zero + assert ((v & delop)(Vector.zero)).doit() == Vector.zero + assert ((v & delop)(S.Zero)).doit() is S.Zero + assert ((i & delop)(x)).doit() == 1 + assert ((j & delop)(y)).doit() == 1 + assert ((k & delop)(z)).doit() == 1 + assert ((i & delop)(x*y*z)).doit() == y*z + assert ((v & delop)(x)).doit() == x + assert ((v & delop)(x*y*z)).doit() == 3*x*y*z + assert (v & delop)(x + y + z) == C.x + C.y + C.z + assert ((v & delop)(x + y + z)).doit() == x + y + z + assert ((v & delop)(v)).doit() == v + assert ((i & delop)(v)).doit() == i + assert ((j & delop)(v)).doit() == j + assert ((k & delop)(v)).doit() == k + assert ((v & delop)(Vector.zero)).doit() == Vector.zero + + # Tests for laplacian on scalar fields + assert laplacian(x*y*z) is S.Zero + assert laplacian(x**2) == S(2) + assert laplacian(x**2*y**2*z**2) == \ + 2*y**2*z**2 + 2*x**2*z**2 + 2*x**2*y**2 + A = CoordSys3D('A', transformation="spherical", variable_names=["r", "theta", "phi"]) + B = CoordSys3D('B', transformation='cylindrical', variable_names=["r", "theta", "z"]) + assert laplacian(A.r + A.theta + A.phi) == 2/A.r + cos(A.theta)/(A.r**2*sin(A.theta)) + assert laplacian(B.r + B.theta + B.z) == 1/B.r + + # Tests for laplacian on vector fields + assert laplacian(x*y*z*(i + j + k)) == Vector.zero + assert laplacian(x*y**2*z*(i + j + k)) == \ + 2*x*z*i + 2*x*z*j + 2*x*z*k + + +def test_product_rules(): + """ + Tests the six product rules defined with respect to the Del + operator + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Del + + """ + + #Define the scalar and vector functions + f = 2*x*y*z + g = x*y + y*z + z*x + u = x**2*i + 4*j - y**2*z*k + v = 4*i + x*y*z*k + + # First product rule + lhs = delop(f * g, doit=True) + rhs = (f * delop(g) + g * delop(f)).doit() + assert simplify(lhs) == simplify(rhs) + + # Second product rule + lhs = delop(u & v).doit() + rhs = ((u ^ (delop ^ v)) + (v ^ (delop ^ u)) + \ + ((u & delop)(v)) + ((v & delop)(u))).doit() + assert simplify(lhs) == simplify(rhs) + + # Third product rule + lhs = (delop & (f*v)).doit() + rhs = ((f * (delop & v)) + (v & (delop(f)))).doit() + assert simplify(lhs) == simplify(rhs) + + # Fourth product rule + lhs = (delop & (u ^ v)).doit() + rhs = ((v & (delop ^ u)) - (u & (delop ^ v))).doit() + assert simplify(lhs) == simplify(rhs) + + # Fifth product rule + lhs = (delop ^ (f * v)).doit() + rhs = (((delop(f)) ^ v) + (f * (delop ^ v))).doit() + assert simplify(lhs) == simplify(rhs) + + # Sixth product rule + lhs = (delop ^ (u ^ v)).doit() + rhs = (u * (delop & v) - v * (delop & u) + + (v & delop)(u) - (u & delop)(v)).doit() + assert simplify(lhs) == simplify(rhs) + + +P = C.orient_new_axis('P', q, C.k) # type: ignore +scalar_field = 2*x**2*y*z +grad_field = gradient(scalar_field) +vector_field = y**2*i + 3*x*j + 5*y*z*k +curl_field = curl(vector_field) + + +def test_conservative(): + assert is_conservative(Vector.zero) is True + assert is_conservative(i) is True + assert is_conservative(2 * i + 3 * j + 4 * k) is True + assert (is_conservative(y*z*i + x*z*j + x*y*k) is + True) + assert is_conservative(x * j) is False + assert is_conservative(grad_field) is True + assert is_conservative(curl_field) is False + assert (is_conservative(4*x*y*z*i + 2*x**2*z*j) is + False) + assert is_conservative(z*P.i + P.x*k) is True + + +def test_solenoidal(): + assert is_solenoidal(Vector.zero) is True + assert is_solenoidal(i) is True + assert is_solenoidal(2 * i + 3 * j + 4 * k) is True + assert (is_solenoidal(y*z*i + x*z*j + x*y*k) is + True) + assert is_solenoidal(y * j) is False + assert is_solenoidal(grad_field) is False + assert is_solenoidal(curl_field) is True + assert is_solenoidal((-2*y + 3)*k) is True + assert is_solenoidal(cos(q)*i + sin(q)*j + cos(q)*P.k) is True + assert is_solenoidal(z*P.i + P.x*k) is True + + +def test_directional_derivative(): + assert directional_derivative(C.x*C.y*C.z, 3*C.i + 4*C.j + C.k) == C.x*C.y + 4*C.x*C.z + 3*C.y*C.z + assert directional_derivative(5*C.x**2*C.z, 3*C.i + 4*C.j + C.k) == 5*C.x**2 + 30*C.x*C.z + assert directional_derivative(5*C.x**2*C.z, 4*C.j) is S.Zero + + D = CoordSys3D("D", "spherical", variable_names=["r", "theta", "phi"], + vector_names=["e_r", "e_theta", "e_phi"]) + r, theta, phi = D.base_scalars() + e_r, e_theta, e_phi = D.base_vectors() + assert directional_derivative(r**2*e_r, e_r) == 2*r*e_r + assert directional_derivative(5*r**2*phi, 3*e_r + 4*e_theta + e_phi) == 5*r**2 + 30*r*phi + + +def test_scalar_potential(): + assert scalar_potential(Vector.zero, C) == 0 + assert scalar_potential(i, C) == x + assert scalar_potential(j, C) == y + assert scalar_potential(k, C) == z + assert scalar_potential(y*z*i + x*z*j + x*y*k, C) == x*y*z + assert scalar_potential(grad_field, C) == scalar_field + assert scalar_potential(z*P.i + P.x*k, C) == x*z*cos(q) + y*z*sin(q) + assert scalar_potential(z*P.i + P.x*k, P) == P.x*P.z + raises(ValueError, lambda: scalar_potential(x*j, C)) + + +def test_scalar_potential_difference(): + point1 = C.origin.locate_new('P1', 1*i + 2*j + 3*k) + point2 = C.origin.locate_new('P2', 4*i + 5*j + 6*k) + genericpointC = C.origin.locate_new('RP', x*i + y*j + z*k) + genericpointP = P.origin.locate_new('PP', P.x*P.i + P.y*P.j + P.z*P.k) + assert scalar_potential_difference(S.Zero, C, point1, point2) == 0 + assert (scalar_potential_difference(scalar_field, C, C.origin, + genericpointC) == + scalar_field) + assert (scalar_potential_difference(grad_field, C, C.origin, + genericpointC) == + scalar_field) + assert scalar_potential_difference(grad_field, C, point1, point2) == 948 + assert (scalar_potential_difference(y*z*i + x*z*j + + x*y*k, C, point1, + genericpointC) == + x*y*z - 6) + potential_diff_P = (2*P.z*(P.x*sin(q) + P.y*cos(q))* + (P.x*cos(q) - P.y*sin(q))**2) + assert (scalar_potential_difference(grad_field, P, P.origin, + genericpointP).simplify() == + potential_diff_P.simplify()) + + +def test_differential_operators_curvilinear_system(): + A = CoordSys3D('A', transformation="spherical", variable_names=["r", "theta", "phi"]) + B = CoordSys3D('B', transformation='cylindrical', variable_names=["r", "theta", "z"]) + # Test for spherical coordinate system and gradient + assert gradient(3*A.r + 4*A.theta) == 3*A.i + 4/A.r*A.j + assert gradient(3*A.r*A.phi + 4*A.theta) == 3*A.phi*A.i + 4/A.r*A.j + (3/sin(A.theta))*A.k + assert gradient(0*A.r + 0*A.theta+0*A.phi) == Vector.zero + assert gradient(A.r*A.theta*A.phi) == A.theta*A.phi*A.i + A.phi*A.j + (A.theta/sin(A.theta))*A.k + # Test for spherical coordinate system and divergence + assert divergence(A.r * A.i + A.theta * A.j + A.phi * A.k) == \ + (sin(A.theta)*A.r + cos(A.theta)*A.r*A.theta)/(sin(A.theta)*A.r**2) + 3 + 1/(sin(A.theta)*A.r) + assert divergence(3*A.r*A.phi*A.i + A.theta*A.j + A.r*A.theta*A.phi*A.k) == \ + (sin(A.theta)*A.r + cos(A.theta)*A.r*A.theta)/(sin(A.theta)*A.r**2) + 9*A.phi + A.theta/sin(A.theta) + assert divergence(Vector.zero) == 0 + assert divergence(0*A.i + 0*A.j + 0*A.k) == 0 + # Test for spherical coordinate system and curl + assert curl(A.r*A.i + A.theta*A.j + A.phi*A.k) == \ + (cos(A.theta)*A.phi/(sin(A.theta)*A.r))*A.i + (-A.phi/A.r)*A.j + A.theta/A.r*A.k + assert curl(A.r*A.j + A.phi*A.k) == (cos(A.theta)*A.phi/(sin(A.theta)*A.r))*A.i + (-A.phi/A.r)*A.j + 2*A.k + + # Test for cylindrical coordinate system and gradient + assert gradient(0*B.r + 0*B.theta+0*B.z) == Vector.zero + assert gradient(B.r*B.theta*B.z) == B.theta*B.z*B.i + B.z*B.j + B.r*B.theta*B.k + assert gradient(3*B.r) == 3*B.i + assert gradient(2*B.theta) == 2/B.r * B.j + assert gradient(4*B.z) == 4*B.k + # Test for cylindrical coordinate system and divergence + assert divergence(B.r*B.i + B.theta*B.j + B.z*B.k) == 3 + 1/B.r + assert divergence(B.r*B.j + B.z*B.k) == 1 + # Test for cylindrical coordinate system and curl + assert curl(B.r*B.j + B.z*B.k) == 2*B.k + assert curl(3*B.i + 2/B.r*B.j + 4*B.k) == Vector.zero + +def test_mixed_coordinates(): + # gradient + a = CoordSys3D('a') + b = CoordSys3D('b') + c = CoordSys3D('c') + assert gradient(a.x*b.y) == b.y*a.i + a.x*b.j + assert gradient(3*cos(q)*a.x*b.x+a.y*(a.x+(cos(q)+b.x))) ==\ + (a.y + 3*b.x*cos(q))*a.i + (a.x + b.x + cos(q))*a.j + (3*a.x*cos(q) + a.y)*b.i + # Some tests need further work: + # assert gradient(a.x*(cos(a.x+b.x))) == (cos(a.x + b.x))*a.i + a.x*Gradient(cos(a.x + b.x)) + # assert gradient(cos(a.x + b.x)*cos(a.x + b.z)) == Gradient(cos(a.x + b.x)*cos(a.x + b.z)) + assert gradient(a.x**b.y) == Gradient(a.x**b.y) + # assert gradient(cos(a.x+b.y)*a.z) == None + assert gradient(cos(a.x*b.y)) == Gradient(cos(a.x*b.y)) + assert gradient(3*cos(q)*a.x*b.x*a.z*a.y+ b.y*b.z + cos(a.x+a.y)*b.z) == \ + (3*a.y*a.z*b.x*cos(q) - b.z*sin(a.x + a.y))*a.i + \ + (3*a.x*a.z*b.x*cos(q) - b.z*sin(a.x + a.y))*a.j + (3*a.x*a.y*b.x*cos(q))*a.k + \ + (3*a.x*a.y*a.z*cos(q))*b.i + b.z*b.j + (b.y + cos(a.x + a.y))*b.k + # divergence + assert divergence(a.i*a.x+a.j*a.y+a.z*a.k + b.i*b.x+b.j*b.y+b.z*b.k + c.i*c.x+c.j*c.y+c.z*c.k) == S(9) + # assert divergence(3*a.i*a.x*cos(a.x+b.z) + a.j*b.x*c.z) == None + assert divergence(3*a.i*a.x*a.z + b.j*b.x*c.z + 3*a.j*a.z*a.y) == \ + 6*a.z + b.x*Dot(b.j, c.k) + assert divergence(3*cos(q)*a.x*b.x*b.i*c.x) == \ + 3*a.x*b.x*cos(q)*Dot(b.i, c.i) + 3*a.x*c.x*cos(q) + 3*b.x*c.x*cos(q)*Dot(b.i, a.i) + assert divergence(a.x*b.x*c.x*Cross(a.x*a.i, a.y*b.j)) ==\ + a.x*b.x*c.x*Divergence(Cross(a.x*a.i, a.y*b.j)) + \ + b.x*c.x*Dot(Cross(a.x*a.i, a.y*b.j), a.i) + \ + a.x*c.x*Dot(Cross(a.x*a.i, a.y*b.j), b.i) + \ + a.x*b.x*Dot(Cross(a.x*a.i, a.y*b.j), c.i) + assert divergence(a.x*b.x*c.x*(a.x*a.i + b.x*b.i)) == \ + 4*a.x*b.x*c.x +\ + a.x**2*c.x*Dot(a.i, b.i) +\ + a.x**2*b.x*Dot(a.i, c.i) +\ + b.x**2*c.x*Dot(b.i, a.i) +\ + a.x*b.x**2*Dot(b.i, c.i) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdf9821b6c853755ce12d0cbdfa599bd4f312e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_functions.py @@ -0,0 +1,184 @@ +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.functions import express, matrix_to_vector, orthogonalize +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.testing.pytest import raises + +N = CoordSys3D('N') +q1, q2, q3, q4, q5 = symbols('q1 q2 q3 q4 q5') +A = N.orient_new_axis('A', q1, N.k) # type: ignore +B = A.orient_new_axis('B', q2, A.i) +C = B.orient_new_axis('C', q3, B.j) + + +def test_express(): + assert express(Vector.zero, N) == Vector.zero + assert express(S.Zero, N) is S.Zero + assert express(A.i, C) == cos(q3)*C.i + sin(q3)*C.k + assert express(A.j, C) == sin(q2)*sin(q3)*C.i + cos(q2)*C.j - \ + sin(q2)*cos(q3)*C.k + assert express(A.k, C) == -sin(q3)*cos(q2)*C.i + sin(q2)*C.j + \ + cos(q2)*cos(q3)*C.k + assert express(A.i, N) == cos(q1)*N.i + sin(q1)*N.j + assert express(A.j, N) == -sin(q1)*N.i + cos(q1)*N.j + assert express(A.k, N) == N.k + assert express(A.i, A) == A.i + assert express(A.j, A) == A.j + assert express(A.k, A) == A.k + assert express(A.i, B) == B.i + assert express(A.j, B) == cos(q2)*B.j - sin(q2)*B.k + assert express(A.k, B) == sin(q2)*B.j + cos(q2)*B.k + assert express(A.i, C) == cos(q3)*C.i + sin(q3)*C.k + assert express(A.j, C) == sin(q2)*sin(q3)*C.i + cos(q2)*C.j - \ + sin(q2)*cos(q3)*C.k + assert express(A.k, C) == -sin(q3)*cos(q2)*C.i + sin(q2)*C.j + \ + cos(q2)*cos(q3)*C.k + # Check to make sure UnitVectors get converted properly + assert express(N.i, N) == N.i + assert express(N.j, N) == N.j + assert express(N.k, N) == N.k + assert express(N.i, A) == (cos(q1)*A.i - sin(q1)*A.j) + assert express(N.j, A) == (sin(q1)*A.i + cos(q1)*A.j) + assert express(N.k, A) == A.k + assert express(N.i, B) == (cos(q1)*B.i - sin(q1)*cos(q2)*B.j + + sin(q1)*sin(q2)*B.k) + assert express(N.j, B) == (sin(q1)*B.i + cos(q1)*cos(q2)*B.j - + sin(q2)*cos(q1)*B.k) + assert express(N.k, B) == (sin(q2)*B.j + cos(q2)*B.k) + assert express(N.i, C) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*C.i - + sin(q1)*cos(q2)*C.j + + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*C.k) + assert express(N.j, C) == ( + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.i + + cos(q1)*cos(q2)*C.j + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.k) + assert express(N.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k) + + assert express(A.i, N) == (cos(q1)*N.i + sin(q1)*N.j) + assert express(A.j, N) == (-sin(q1)*N.i + cos(q1)*N.j) + assert express(A.k, N) == N.k + assert express(A.i, A) == A.i + assert express(A.j, A) == A.j + assert express(A.k, A) == A.k + assert express(A.i, B) == B.i + assert express(A.j, B) == (cos(q2)*B.j - sin(q2)*B.k) + assert express(A.k, B) == (sin(q2)*B.j + cos(q2)*B.k) + assert express(A.i, C) == (cos(q3)*C.i + sin(q3)*C.k) + assert express(A.j, C) == (sin(q2)*sin(q3)*C.i + cos(q2)*C.j - + sin(q2)*cos(q3)*C.k) + assert express(A.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k) + + assert express(B.i, N) == (cos(q1)*N.i + sin(q1)*N.j) + assert express(B.j, N) == (-sin(q1)*cos(q2)*N.i + + cos(q1)*cos(q2)*N.j + sin(q2)*N.k) + assert express(B.k, N) == (sin(q1)*sin(q2)*N.i - + sin(q2)*cos(q1)*N.j + cos(q2)*N.k) + assert express(B.i, A) == A.i + assert express(B.j, A) == (cos(q2)*A.j + sin(q2)*A.k) + assert express(B.k, A) == (-sin(q2)*A.j + cos(q2)*A.k) + assert express(B.i, B) == B.i + assert express(B.j, B) == B.j + assert express(B.k, B) == B.k + assert express(B.i, C) == (cos(q3)*C.i + sin(q3)*C.k) + assert express(B.j, C) == C.j + assert express(B.k, C) == (-sin(q3)*C.i + cos(q3)*C.k) + + assert express(C.i, N) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*N.i + + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*N.j - + sin(q3)*cos(q2)*N.k) + assert express(C.j, N) == ( + -sin(q1)*cos(q2)*N.i + cos(q1)*cos(q2)*N.j + sin(q2)*N.k) + assert express(C.k, N) == ( + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*N.i + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*N.j + + cos(q2)*cos(q3)*N.k) + assert express(C.i, A) == (cos(q3)*A.i + sin(q2)*sin(q3)*A.j - + sin(q3)*cos(q2)*A.k) + assert express(C.j, A) == (cos(q2)*A.j + sin(q2)*A.k) + assert express(C.k, A) == (sin(q3)*A.i - sin(q2)*cos(q3)*A.j + + cos(q2)*cos(q3)*A.k) + assert express(C.i, B) == (cos(q3)*B.i - sin(q3)*B.k) + assert express(C.j, B) == B.j + assert express(C.k, B) == (sin(q3)*B.i + cos(q3)*B.k) + assert express(C.i, C) == C.i + assert express(C.j, C) == C.j + assert express(C.k, C) == C.k == (C.k) + + # Check to make sure Vectors get converted back to UnitVectors + assert N.i == express((cos(q1)*A.i - sin(q1)*A.j), N).simplify() + assert N.j == express((sin(q1)*A.i + cos(q1)*A.j), N).simplify() + assert N.i == express((cos(q1)*B.i - sin(q1)*cos(q2)*B.j + + sin(q1)*sin(q2)*B.k), N).simplify() + assert N.j == express((sin(q1)*B.i + cos(q1)*cos(q2)*B.j - + sin(q2)*cos(q1)*B.k), N).simplify() + assert N.k == express((sin(q2)*B.j + cos(q2)*B.k), N).simplify() + + + assert A.i == express((cos(q1)*N.i + sin(q1)*N.j), A).simplify() + assert A.j == express((-sin(q1)*N.i + cos(q1)*N.j), A).simplify() + + assert A.j == express((cos(q2)*B.j - sin(q2)*B.k), A).simplify() + assert A.k == express((sin(q2)*B.j + cos(q2)*B.k), A).simplify() + + assert A.i == express((cos(q3)*C.i + sin(q3)*C.k), A).simplify() + assert A.j == express((sin(q2)*sin(q3)*C.i + cos(q2)*C.j - + sin(q2)*cos(q3)*C.k), A).simplify() + + assert A.k == express((-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k), A).simplify() + assert B.i == express((cos(q1)*N.i + sin(q1)*N.j), B).simplify() + assert B.j == express((-sin(q1)*cos(q2)*N.i + + cos(q1)*cos(q2)*N.j + sin(q2)*N.k), B).simplify() + + assert B.k == express((sin(q1)*sin(q2)*N.i - + sin(q2)*cos(q1)*N.j + cos(q2)*N.k), B).simplify() + + assert B.j == express((cos(q2)*A.j + sin(q2)*A.k), B).simplify() + assert B.k == express((-sin(q2)*A.j + cos(q2)*A.k), B).simplify() + assert B.i == express((cos(q3)*C.i + sin(q3)*C.k), B).simplify() + assert B.k == express((-sin(q3)*C.i + cos(q3)*C.k), B).simplify() + assert C.i == express((cos(q3)*A.i + sin(q2)*sin(q3)*A.j - + sin(q3)*cos(q2)*A.k), C).simplify() + assert C.j == express((cos(q2)*A.j + sin(q2)*A.k), C).simplify() + assert C.k == express((sin(q3)*A.i - sin(q2)*cos(q3)*A.j + + cos(q2)*cos(q3)*A.k), C).simplify() + assert C.i == express((cos(q3)*B.i - sin(q3)*B.k), C).simplify() + assert C.k == express((sin(q3)*B.i + cos(q3)*B.k), C).simplify() + + +def test_matrix_to_vector(): + m = Matrix([[1], [2], [3]]) + assert matrix_to_vector(m, C) == C.i + 2*C.j + 3*C.k + m = Matrix([[0], [0], [0]]) + assert matrix_to_vector(m, N) == matrix_to_vector(m, C) == \ + Vector.zero + m = Matrix([[q1], [q2], [q3]]) + assert matrix_to_vector(m, N) == q1*N.i + q2*N.j + q3*N.k + + +def test_orthogonalize(): + C = CoordSys3D('C') + a, b = symbols('a b', integer=True) + i, j, k = C.base_vectors() + v1 = i + 2*j + v2 = 2*i + 3*j + v3 = 3*i + 5*j + v4 = 3*i + j + v5 = 2*i + 2*j + v6 = a*i + b*j + v7 = 4*a*i + 4*b*j + assert orthogonalize(v1, v2) == [C.i + 2*C.j, C.i*Rational(2, 5) + -C.j/5] + # from wikipedia + assert orthogonalize(v4, v5, orthonormal=True) == \ + [(3*sqrt(10))*C.i/10 + (sqrt(10))*C.j/10, (-sqrt(10))*C.i/10 + (3*sqrt(10))*C.j/10] + raises(ValueError, lambda: orthogonalize(v1, v2, v3)) + raises(ValueError, lambda: orthogonalize(v6, v7)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_implicitregion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_implicitregion.py new file mode 100644 index 0000000000000000000000000000000000000000..3686d847a7f165cb5ba9aeb813e5922aaa17e1e0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_implicitregion.py @@ -0,0 +1,90 @@ +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.abc import x, y, z, s, t +from sympy.sets import FiniteSet, EmptySet +from sympy.geometry import Point +from sympy.vector import ImplicitRegion +from sympy.testing.pytest import raises + + +def test_ImplicitRegion(): + ellipse = ImplicitRegion((x, y), (x**2/4 + y**2/16 - 1)) + assert ellipse.equation == x**2/4 + y**2/16 - 1 + assert ellipse.variables == (x, y) + assert ellipse.degree == 2 + r = ImplicitRegion((x, y, z), Eq(x**4 + y**2 - x*y, 6)) + assert r.equation == x**4 + y**2 - x*y - 6 + assert r.variables == (x, y, z) + assert r.degree == 4 + + +def test_regular_point(): + r1 = ImplicitRegion((x,), x**2 - 16) + assert r1.regular_point() == (-4,) + c1 = ImplicitRegion((x, y), x**2 + y**2 - 4) + assert c1.regular_point() == (0, -2) + c2 = ImplicitRegion((x, y), (x - S(5)/2)**2 + y**2 - (S(1)/4)**2) + assert c2.regular_point() == (S(5)/2, -S(1)/4) + c3 = ImplicitRegion((x, y), (y - 5)**2 - 16*(x - 5)) + assert c3.regular_point() == (5, 5) + r2 = ImplicitRegion((x, y), x**2 - 4*x*y - 3*y**2 + 4*x + 8*y - 5) + assert r2.regular_point() == (S(4)/7, S(9)/7) + r3 = ImplicitRegion((x, y), x**2 - 2*x*y + 3*y**2 - 2*x - 5*y + 3/2) + raises(ValueError, lambda: r3.regular_point()) + + +def test_singular_points_and_multiplicty(): + r1 = ImplicitRegion((x, y, z), Eq(x + y + z, 0)) + assert r1.singular_points() == EmptySet + r2 = ImplicitRegion((x, y, z), x*y*z + y**4 -x**2*z**2) + assert r2.singular_points() == FiniteSet((0, 0, z), (x, 0, 0)) + assert r2.multiplicity((0, 0, 0)) == 3 + assert r2.multiplicity((0, 0, 6)) == 2 + r3 = ImplicitRegion((x, y, z), z**2 - x**2 - y**2) + assert r3.singular_points() == FiniteSet((0, 0, 0)) + assert r3.multiplicity((0, 0, 0)) == 2 + r4 = ImplicitRegion((x, y), x**2 + y**2 - 2*x) + assert r4.singular_points() == EmptySet + assert r4.multiplicity(Point(1, 3)) == 0 + + +def test_rational_parametrization(): + p = ImplicitRegion((x,), x - 2) + assert p.rational_parametrization() == (x - 2,) + + line = ImplicitRegion((x, y), Eq(y, 3*x + 2)) + assert line.rational_parametrization() == (x, 3*x + 2) + + circle1 = ImplicitRegion((x, y), (x-2)**2 + (y+3)**2 - 4) + assert circle1.rational_parametrization(parameters=t) == (4*t/(t**2 + 1) + 2, 4*t**2/(t**2 + 1) - 5) + circle2 = ImplicitRegion((x, y), (x - S.Half)**2 + y**2 - (S(1)/2)**2) + + assert circle2.rational_parametrization(parameters=t) == (t/(t**2 + 1) + S(1)/2, t**2/(t**2 + 1) - S(1)/2) + circle3 = ImplicitRegion((x, y), Eq(x**2 + y**2, 2*x)) + assert circle3.rational_parametrization(parameters=(t,)) == (2*t/(t**2 + 1) + 1, 2*t**2/(t**2 + 1) - 1) + + parabola = ImplicitRegion((x, y), (y - 3)**2 - 4*(x + 6)) + assert parabola.rational_parametrization(t) == (-6 + 4/t**2, 3 + 4/t) + + rect_hyperbola = ImplicitRegion((x, y), x*y - 1) + assert rect_hyperbola.rational_parametrization(t) == (-1 + (t + 1)/t, t) + + cubic_curve = ImplicitRegion((x, y), x**3 + x**2 - y**2) + assert cubic_curve.rational_parametrization(parameters=(t)) == (t**2 - 1, t*(t**2 - 1)) + cuspidal = ImplicitRegion((x, y), (x**3 - y**2)) + assert cuspidal.rational_parametrization(t) == (t**2, t**3) + + I = ImplicitRegion((x, y), x**3 + x**2 - y**2) + assert I.rational_parametrization(t) == (t**2 - 1, t*(t**2 - 1)) + + sphere = ImplicitRegion((x, y, z), Eq(x**2 + y**2 + z**2, 2*x)) + assert sphere.rational_parametrization(parameters=(s, t)) == (2/(s**2 + t**2 + 1), 2*t/(s**2 + t**2 + 1), 2*s/(s**2 + t**2 + 1)) + + conic = ImplicitRegion((x, y), Eq(x**2 + 4*x*y + 3*y**2 + x - y + 10, 0)) + assert conic.rational_parametrization(t) == ( + S(17)/2 + 4/(3*t**2 + 4*t + 1), 4*t/(3*t**2 + 4*t + 1) - S(11)/2) + + r1 = ImplicitRegion((x, y), y**2 - x**3 + x) + raises(NotImplementedError, lambda: r1.rational_parametrization()) + r2 = ImplicitRegion((x, y), y**2 - x**3 - x**2 + 1) + raises(NotImplementedError, lambda: r2.rational_parametrization()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_integrals.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..84c900d038e214df1ea59a8cd8fb2929005c3674 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_integrals.py @@ -0,0 +1,106 @@ +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.testing.pytest import raises +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.integrals import ParametricIntegral, vector_integrate +from sympy.vector.parametricregion import ParametricRegion +from sympy.vector.implicitregion import ImplicitRegion +from sympy.abc import x, y, z, u, v, r, t, theta, phi +from sympy.geometry import Point, Segment, Curve, Circle, Polygon, Plane + +C = CoordSys3D('C') + +def test_parametric_lineintegrals(): + halfcircle = ParametricRegion((4*cos(theta), 4*sin(theta)), (theta, -pi/2, pi/2)) + assert ParametricIntegral(C.x*C.y**4, halfcircle) == S(8192)/5 + + curve = ParametricRegion((t, t**2, t**3), (t, 0, 1)) + field1 = 8*C.x**2*C.y*C.z*C.i + 5*C.z*C.j - 4*C.x*C.y*C.k + assert ParametricIntegral(field1, curve) == 1 + line = ParametricRegion((4*t - 1, 2 - 2*t, t), (t, 0, 1)) + assert ParametricIntegral(C.x*C.z*C.i - C.y*C.z*C.k, line) == 3 + + assert ParametricIntegral(4*C.x**3, ParametricRegion((1, t), (t, 0, 2))) == 8 + + helix = ParametricRegion((cos(t), sin(t), 3*t), (t, 0, 4*pi)) + assert ParametricIntegral(C.x*C.y*C.z, helix) == -3*sqrt(10)*pi + + field2 = C.y*C.i + C.z*C.j + C.z*C.k + assert ParametricIntegral(field2, ParametricRegion((cos(t), sin(t), t**2), (t, 0, pi))) == -5*pi/2 + pi**4/2 + +def test_parametric_surfaceintegrals(): + + semisphere = ParametricRegion((2*sin(phi)*cos(theta), 2*sin(phi)*sin(theta), 2*cos(phi)),\ + (theta, 0, 2*pi), (phi, 0, pi/2)) + assert ParametricIntegral(C.z, semisphere) == 8*pi + + cylinder = ParametricRegion((sqrt(3)*cos(theta), sqrt(3)*sin(theta), z), (z, 0, 6), (theta, 0, 2*pi)) + assert ParametricIntegral(C.y, cylinder) == 0 + + cone = ParametricRegion((v*cos(u), v*sin(u), v), (u, 0, 2*pi), (v, 0, 1)) + assert ParametricIntegral(C.x*C.i + C.y*C.j + C.z**4*C.k, cone) == pi/3 + + triangle1 = ParametricRegion((x, y), (x, 0, 2), (y, 0, 10 - 5*x)) + triangle2 = ParametricRegion((x, y), (y, 0, 10 - 5*x), (x, 0, 2)) + assert ParametricIntegral(-15.6*C.y*C.k, triangle1) == ParametricIntegral(-15.6*C.y*C.k, triangle2) + assert ParametricIntegral(C.z, triangle1) == 10*C.z + +def test_parametric_volumeintegrals(): + + cube = ParametricRegion((x, y, z), (x, 0, 1), (y, 0, 1), (z, 0, 1)) + assert ParametricIntegral(1, cube) == 1 + + solidsphere1 = ParametricRegion((r*sin(phi)*cos(theta), r*sin(phi)*sin(theta), r*cos(phi)),\ + (r, 0, 2), (theta, 0, 2*pi), (phi, 0, pi)) + solidsphere2 = ParametricRegion((r*sin(phi)*cos(theta), r*sin(phi)*sin(theta), r*cos(phi)),\ + (r, 0, 2), (phi, 0, pi), (theta, 0, 2*pi)) + assert ParametricIntegral(C.x**2 + C.y**2, solidsphere1) == -256*pi/15 + assert ParametricIntegral(C.x**2 + C.y**2, solidsphere2) == 256*pi/15 + + region_under_plane1 = ParametricRegion((x, y, z), (x, 0, 3), (y, 0, -2*x/3 + 2),\ + (z, 0, 6 - 2*x - 3*y)) + region_under_plane2 = ParametricRegion((x, y, z), (x, 0, 3), (z, 0, 6 - 2*x - 3*y),\ + (y, 0, -2*x/3 + 2)) + + assert ParametricIntegral(C.x*C.i + C.j - 100*C.k, region_under_plane1) == \ + ParametricIntegral(C.x*C.i + C.j - 100*C.k, region_under_plane2) + assert ParametricIntegral(2*C.x, region_under_plane2) == -9 + +def test_vector_integrate(): + halfdisc = ParametricRegion((r*cos(theta), r* sin(theta)), (r, -2, 2), (theta, 0, pi)) + assert vector_integrate(C.x**2, halfdisc) == 4*pi + assert vector_integrate(C.x, ParametricRegion((t, t**2), (t, 2, 3))) == -17*sqrt(17)/12 + 37*sqrt(37)/12 + + assert vector_integrate(C.y**3*C.z, (C.x, 0, 3), (C.y, -1, 4)) == 765*C.z/4 + + s1 = Segment(Point(0, 0), Point(0, 1)) + assert vector_integrate(-15*C.y, s1) == S(-15)/2 + s2 = Segment(Point(4, 3, 9), Point(1, 1, 7)) + assert vector_integrate(C.y*C.i, s2) == -6 + + curve = Curve((sin(t), cos(t)), (t, 0, 2)) + assert vector_integrate(5*C.z, curve) == 10*C.z + + c1 = Circle(Point(2, 3), 6) + assert vector_integrate(C.x*C.y, c1) == 72*pi + c2 = Circle(Point(0, 0), Point(1, 1), Point(1, 0)) + assert vector_integrate(1, c2) == c2.circumference + + triangle = Polygon((0, 0), (1, 0), (1, 1)) + assert vector_integrate(C.x*C.i - 14*C.y*C.j, triangle) == 0 + p1, p2, p3, p4 = [(0, 0), (1, 0), (5, 1), (0, 1)] + poly = Polygon(p1, p2, p3, p4) + assert vector_integrate(-23*C.z, poly) == -161*C.z - 23*sqrt(17)*C.z + + point = Point(2, 3) + assert vector_integrate(C.i*C.y, point) == ParametricIntegral(C.y*C.i, ParametricRegion((2, 3))) + + c3 = ImplicitRegion((x, y), x**2 + y**2 - 4) + assert vector_integrate(45, c3) == 180*pi + c4 = ImplicitRegion((x, y), (x - 3)**2 + (y - 4)**2 - 9) + assert vector_integrate(1, c4) == 6*pi + + pl = Plane(Point(1, 1, 1), Point(2, 3, 4), Point(2, 2, 2)) + raises(ValueError, lambda: vector_integrate(C.x*C.z*C.i + C.k, pl)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_operators.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..5734edadd00547c67d6f864b50afd966ad8392a6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_operators.py @@ -0,0 +1,43 @@ +from sympy.vector import CoordSys3D, Gradient, Divergence, Curl, VectorZero, Laplacian +from sympy.printing.repr import srepr + +R = CoordSys3D('R') +s1 = R.x*R.y*R.z # type: ignore +s2 = R.x + 3*R.y**2 # type: ignore +s3 = R.x**2 + R.y**2 + R.z**2 # type: ignore +v1 = R.x*R.i + R.z*R.z*R.j # type: ignore +v2 = R.x*R.i + R.y*R.j + R.z*R.k # type: ignore +v3 = R.x**2*R.i + R.y**2*R.j + R.z**2*R.k # type: ignore + + +def test_Gradient(): + assert Gradient(s1) == Gradient(R.x*R.y*R.z) + assert Gradient(s2) == Gradient(R.x + 3*R.y**2) + assert Gradient(s1).doit() == R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + assert Gradient(s2).doit() == R.i + 6*R.y*R.j + + +def test_Divergence(): + assert Divergence(v1) == Divergence(R.x*R.i + R.z*R.z*R.j) + assert Divergence(v2) == Divergence(R.x*R.i + R.y*R.j + R.z*R.k) + assert Divergence(v1).doit() == 1 + assert Divergence(v2).doit() == 3 + # issue 22384 + Rc = CoordSys3D('R', transformation='cylindrical') + assert Divergence(Rc.i).doit() == 1/Rc.r + + +def test_Curl(): + assert Curl(v1) == Curl(R.x*R.i + R.z*R.z*R.j) + assert Curl(v2) == Curl(R.x*R.i + R.y*R.j + R.z*R.k) + assert Curl(v1).doit() == (-2*R.z)*R.i + assert Curl(v2).doit() == VectorZero() + + +def test_Laplacian(): + assert Laplacian(s3) == Laplacian(R.x**2 + R.y**2 + R.z**2) + assert Laplacian(v3) == Laplacian(R.x**2*R.i + R.y**2*R.j + R.z**2*R.k) + assert Laplacian(s3).doit() == 6 + assert Laplacian(v3).doit() == 2*R.i + 2*R.j + 2*R.k + assert srepr(Laplacian(s3)) == \ + 'Laplacian(Add(Pow(R.x, Integer(2)), Pow(R.y, Integer(2)), Pow(R.z, Integer(2))))' diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_printing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..ae76905e967bdf93485f135c6a69f968e1208986 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_printing.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +from sympy.core.function import Function +from sympy.integrals.integrals import Integral +from sympy.printing.latex import latex +from sympy.printing.pretty import pretty as xpretty +from sympy.vector import CoordSys3D, Del, Vector, express +from sympy.abc import a, b, c +from sympy.testing.pytest import XFAIL + + +def pretty(expr): + """ASCII pretty-printing""" + return xpretty(expr, use_unicode=False, wrap_line=False) + + +def upretty(expr): + """Unicode pretty-printing""" + return xpretty(expr, use_unicode=True, wrap_line=False) + + +# Initialize the basic and tedious vector/dyadic expressions +# needed for testing. +# Some of the pretty forms shown denote how the expressions just +# above them should look with pretty printing. +N = CoordSys3D('N') +C = N.orient_new_axis('C', a, N.k) # type: ignore +v = [] +d = [] +v.append(Vector.zero) +v.append(N.i) # type: ignore +v.append(-N.i) # type: ignore +v.append(N.i + N.j) # type: ignore +v.append(a*N.i) # type: ignore +v.append(a*N.i - b*N.j) # type: ignore +v.append((a**2 + N.x)*N.i + N.k) # type: ignore +v.append((a**2 + b)*N.i + 3*(C.y - c)*N.k) # type: ignore +f = Function('f') +v.append(N.j - (Integral(f(b)) - C.x**2)*N.k) # type: ignore +upretty_v_8 = """\ + ⎛ 2 ⌠ ⎞ \n\ +j_N + ⎜x_C - ⎮ f(b) db⎟ k_N\n\ + ⎝ ⌡ ⎠ \ +""" +pretty_v_8 = """\ +j_N + / / \\\n\ + | 2 | |\n\ + |x_C - | f(b) db|\n\ + | | |\n\ + \\ / / \ +""" + +v.append(N.i + C.k) # type: ignore +v.append(express(N.i, C)) # type: ignore +v.append((a**2 + b)*N.i + (Integral(f(b)))*N.k) # type: ignore +upretty_v_11 = """\ +⎛ 2 ⎞ ⎛⌠ ⎞ \n\ +⎝a + b⎠ i_N + ⎜⎮ f(b) db⎟ k_N\n\ + ⎝⌡ ⎠ \ +""" +pretty_v_11 = """\ +/ 2 \\ + / / \\\n\ +\\a + b/ i_N| | |\n\ + | | f(b) db|\n\ + | | |\n\ + \\/ / \ +""" + +for x in v: + d.append(x | N.k) # type: ignore +s = 3*N.x**2*C.y # type: ignore +upretty_s = """\ + 2\n\ +3⋅y_C⋅x_N \ +""" +pretty_s = """\ + 2\n\ +3*y_C*x_N \ +""" + +# This is the pretty form for ((a**2 + b)*N.i + 3*(C.y - c)*N.k) | N.k +upretty_d_7 = """\ +⎛ 2 ⎞ \n\ +⎝a + b⎠ (i_N|k_N) + (3⋅y_C - 3⋅c) (k_N|k_N)\ +""" +pretty_d_7 = """\ +/ 2 \\ (i_N|k_N) + (3*y_C - 3*c) (k_N|k_N)\n\ +\\a + b/ \ +""" + + +def test_str_printing(): + assert str(v[0]) == '0' + assert str(v[1]) == 'N.i' + assert str(v[2]) == '(-1)*N.i' + assert str(v[3]) == 'N.i + N.j' + assert str(v[8]) == 'N.j + (C.x**2 - Integral(f(b), b))*N.k' + assert str(v[9]) == 'C.k + N.i' + assert str(s) == '3*C.y*N.x**2' + assert str(d[0]) == '0' + assert str(d[1]) == '(N.i|N.k)' + assert str(d[4]) == 'a*(N.i|N.k)' + assert str(d[5]) == 'a*(N.i|N.k) + (-b)*(N.j|N.k)' + assert str(d[8]) == ('(N.j|N.k) + (C.x**2 - ' + + 'Integral(f(b), b))*(N.k|N.k)') + + +@XFAIL +def test_pretty_printing_ascii(): + assert pretty(v[0]) == '0' + assert pretty(v[1]) == 'i_N' + assert pretty(v[5]) == '(a) i_N + (-b) j_N' + assert pretty(v[8]) == pretty_v_8 + assert pretty(v[2]) == '(-1) i_N' + assert pretty(v[11]) == pretty_v_11 + assert pretty(s) == pretty_s + assert pretty(d[0]) == '(0|0)' + assert pretty(d[5]) == '(a) (i_N|k_N) + (-b) (j_N|k_N)' + assert pretty(d[7]) == pretty_d_7 + assert pretty(d[10]) == '(cos(a)) (i_C|k_N) + (-sin(a)) (j_C|k_N)' + + +def test_pretty_print_unicode_v(): + assert upretty(v[0]) == '0' + assert upretty(v[1]) == 'i_N' + assert upretty(v[5]) == '(a) i_N + (-b) j_N' + # Make sure the printing works in other objects + assert upretty(v[5].args) == '((a) i_N, (-b) j_N)' + assert upretty(v[8]) == upretty_v_8 + assert upretty(v[2]) == '(-1) i_N' + assert upretty(v[11]) == upretty_v_11 + assert upretty(s) == upretty_s + assert upretty(d[0]) == '(0|0)' + assert upretty(d[5]) == '(a) (i_N|k_N) + (-b) (j_N|k_N)' + assert upretty(d[7]) == upretty_d_7 + assert upretty(d[10]) == '(cos(a)) (i_C|k_N) + (-sin(a)) (j_C|k_N)' + + +def test_latex_printing(): + assert latex(v[0]) == '\\mathbf{\\hat{0}}' + assert latex(v[1]) == '\\mathbf{\\hat{i}_{N}}' + assert latex(v[2]) == '- \\mathbf{\\hat{i}_{N}}' + assert latex(v[5]) == ('\\left(a\\right)\\mathbf{\\hat{i}_{N}} + ' + + '\\left(- b\\right)\\mathbf{\\hat{j}_{N}}') + assert latex(v[6]) == ('\\left(\\mathbf{{x}_{N}} + a^{2}\\right)\\mathbf{\\hat{i}_' + + '{N}} + \\mathbf{\\hat{k}_{N}}') + assert latex(v[8]) == ('\\mathbf{\\hat{j}_{N}} + \\left(\\mathbf{{x}_' + + '{C}}^{2} - \\int f{\\left(b \\right)}\\,' + + ' db\\right)\\mathbf{\\hat{k}_{N}}') + assert latex(s) == '3 \\mathbf{{y}_{C}} \\mathbf{{x}_{N}}^{2}' + assert latex(d[0]) == '(\\mathbf{\\hat{0}}|\\mathbf{\\hat{0}})' + assert latex(d[4]) == ('\\left(a\\right)\\left(\\mathbf{\\hat{i}_{N}}{\\middle|}' + + '\\mathbf{\\hat{k}_{N}}\\right)') + assert latex(d[9]) == ('\\left(\\mathbf{\\hat{k}_{C}}{\\middle|}' + + '\\mathbf{\\hat{k}_{N}}\\right) + \\left(' + + '\\mathbf{\\hat{i}_{N}}{\\middle|}\\mathbf{' + + '\\hat{k}_{N}}\\right)') + assert latex(d[11]) == ('\\left(a^{2} + b\\right)\\left(\\mathbf{\\hat{i}_{N}}' + + '{\\middle|}\\mathbf{\\hat{k}_{N}}\\right) + ' + + '\\left(\\int f{\\left(b \\right)}\\, db\\right)\\left(' + + '\\mathbf{\\hat{k}_{N}}{\\middle|}\\mathbf{' + + '\\hat{k}_{N}}\\right)') + +def test_issue_23058(): + from sympy import symbols, sin, cos, pi, UnevaluatedExpr + + delop = Del() + CC_ = CoordSys3D("C") + y = CC_.y + xhat = CC_.i + + t = symbols("t") + ten = symbols("10", positive=True) + eps, mu = 4*pi*ten**(-11), ten**(-5) + + Bx = 2 * ten**(-4) * cos(ten**5 * t) * sin(ten**(-3) * y) + vecB = Bx * xhat + vecE = (1/eps) * Integral(delop.cross(vecB/mu).doit(), t) + vecE = vecE.doit() + + vecB_str = """\ +⎛ ⎛y_C⎞ ⎛ 5 ⎞⎞ \n\ +⎜2⋅sin⎜───⎟⋅cos⎝10 ⋅t⎠⎟ i_C\n\ +⎜ ⎜ 3⎟ ⎟ \n\ +⎜ ⎝10 ⎠ ⎟ \n\ +⎜─────────────────────⎟ \n\ +⎜ 4 ⎟ \n\ +⎝ 10 ⎠ \ +""" + vecE_str = """\ +⎛ 4 ⎛ 5 ⎞ ⎛y_C⎞ ⎞ \n\ +⎜-10 ⋅sin⎝10 ⋅t⎠⋅cos⎜───⎟ ⎟ k_C\n\ +⎜ ⎜ 3⎟ ⎟ \n\ +⎜ ⎝10 ⎠ ⎟ \n\ +⎜─────────────────────────⎟ \n\ +⎝ 2⋅π ⎠ \ +""" + + assert upretty(vecB) == vecB_str + assert upretty(vecE) == vecE_str + + ten = UnevaluatedExpr(10) + eps, mu = 4*pi*ten**(-11), ten**(-5) + + Bx = 2 * ten**(-4) * cos(ten**5 * t) * sin(ten**(-3) * y) + vecB = Bx * xhat + + vecB_str = """\ +⎛ -4 ⎛ 5⎞ ⎛ -3⎞⎞ \n\ +⎝2⋅10 ⋅cos⎝t⋅10 ⎠⋅sin⎝y_C⋅10 ⎠⎠ i_C \ +""" + assert upretty(vecB) == vecB_str + +def test_custom_names(): + A = CoordSys3D('A', vector_names=['x', 'y', 'z'], + variable_names=['i', 'j', 'k']) + assert A.i.__str__() == 'A.i' + assert A.x.__str__() == 'A.x' + assert A.i._pretty_form == 'i_A' + assert A.x._pretty_form == 'x_A' + assert A.i._latex_form == r'\mathbf{{i}_{A}}' + assert A.x._latex_form == r"\mathbf{\hat{x}_{A}}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_vector.py b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..daba6d6a02c87b41a8bf801eee9b9045897d0003 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/sympy/vector/tests/test_vector.py @@ -0,0 +1,342 @@ +from sympy.core import Rational, S, Add, Mul, I +from sympy.simplify import simplify, trigsimp +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.vector.vector import Vector, BaseVector, VectorAdd, \ + VectorMul, VectorZero +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.vector import Cross, Dot, cross +from sympy.testing.pytest import raises +from sympy.vector.kind import VectorKind +from sympy.core.kind import NumberKind +from sympy.testing.pytest import XFAIL + + +C = CoordSys3D('C') + +i, j, k = C.base_vectors() +a, b, c = symbols('a b c') + + +def test_cross(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Cross(v1, v2) == Cross(C.x*C.i + C.z**2*C.j, C.x*C.i + C.y*C.j + C.z*C.k) + assert Cross(v1, v2).doit() == C.z**3*C.i + (-C.x*C.z)*C.j + (C.x*C.y - C.x*C.z**2)*C.k + assert cross(v1, v2) == C.z**3*C.i + (-C.x*C.z)*C.j + (C.x*C.y - C.x*C.z**2)*C.k + assert Cross(v1, v2) == -Cross(v2, v1) + # XXX: Cannot use Cross here. See XFAIL test below: + assert cross(v1, v2) + cross(v2, v1) == Vector.zero + + +@XFAIL +def test_cross_xfail(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Cross(v1, v2) + Cross(v2, v1) == Vector.zero + + +def test_dot(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Dot(v1, v2) == Dot(C.x*C.i + C.z**2*C.j, C.x*C.i + C.y*C.j + C.z*C.k) + assert Dot(v1, v2).doit() == C.x**2 + C.y*C.z**2 + assert Dot(v2, v1).doit() == C.x**2 + C.y*C.z**2 + assert Dot(v1, v2) == Dot(v2, v1) + + +def test_vector_sympy(): + """ + Test whether the Vector framework confirms to the hashing + and equality testing properties of SymPy. + """ + v1 = 3*j + assert v1 == j*3 + assert v1.components == {j: 3} + v2 = 3*i + 4*j + 5*k + v3 = 2*i + 4*j + i + 4*k + k + assert v3 == v2 + assert v3.__hash__() == v2.__hash__() + + +def test_kind(): + assert C.i.kind is VectorKind(NumberKind) + assert C.j.kind is VectorKind(NumberKind) + assert C.k.kind is VectorKind(NumberKind) + + assert C.x.kind is NumberKind + assert C.y.kind is NumberKind + assert C.z.kind is NumberKind + + assert Mul._kind_dispatcher(NumberKind, VectorKind(NumberKind)) is VectorKind(NumberKind) + assert Mul(2, C.i).kind is VectorKind(NumberKind) + + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert v1.kind is VectorKind(NumberKind) + assert v2.kind is VectorKind(NumberKind) + + assert (v1 + v2).kind is VectorKind(NumberKind) + assert Add(v1, v2).kind is VectorKind(NumberKind) + assert Cross(v1, v2).doit().kind is VectorKind(NumberKind) + assert VectorAdd(v1, v2).kind is VectorKind(NumberKind) + assert VectorMul(2, v1).kind is VectorKind(NumberKind) + assert VectorZero().kind is VectorKind(NumberKind) + + assert v1.projection(v2).kind is VectorKind(NumberKind) + assert v2.projection(v1).kind is VectorKind(NumberKind) + + +def test_vectoradd(): + assert isinstance(Add(C.i, C.j), VectorAdd) + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert isinstance(Add(v1, v2), VectorAdd) + + # https://github.com/sympy/sympy/issues/26121 + + E = Matrix([C.i, C.j, C.k]).T + a = Matrix([1, 2, 3]) + av = E*a + + assert av[0].kind == VectorKind() + assert isinstance(av[0], VectorAdd) + + +def test_vector(): + assert isinstance(i, BaseVector) + assert i != j + assert j != k + assert k != i + assert i - i == Vector.zero + assert i + Vector.zero == i + assert i - Vector.zero == i + assert Vector.zero != 0 + assert -Vector.zero == Vector.zero + + v1 = a*i + b*j + c*k + v2 = a**2*i + b**2*j + c**2*k + v3 = v1 + v2 + v4 = 2 * v1 + v5 = a * i + + assert isinstance(v1, VectorAdd) + assert v1 - v1 == Vector.zero + assert v1 + Vector.zero == v1 + assert v1.dot(i) == a + assert v1.dot(j) == b + assert v1.dot(k) == c + assert i.dot(v2) == a**2 + assert j.dot(v2) == b**2 + assert k.dot(v2) == c**2 + assert v3.dot(i) == a**2 + a + assert v3.dot(j) == b**2 + b + assert v3.dot(k) == c**2 + c + + assert v1 + v2 == v2 + v1 + assert v1 - v2 == -1 * (v2 - v1) + assert a * v1 == v1 * a + + assert isinstance(v5, VectorMul) + assert v5.base_vector == i + assert v5.measure_number == a + assert isinstance(v4, Vector) + assert isinstance(v4, VectorAdd) + assert isinstance(v4, Vector) + assert isinstance(Vector.zero, VectorZero) + assert isinstance(Vector.zero, Vector) + assert isinstance(v1 * 0, VectorZero) + + assert v1.to_matrix(C) == Matrix([[a], [b], [c]]) + + assert i.components == {i: 1} + assert v5.components == {i: a} + assert v1.components == {i: a, j: b, k: c} + + assert VectorAdd(v1, Vector.zero) == v1 + assert VectorMul(a, v1) == v1*a + assert VectorMul(1, i) == i + assert VectorAdd(v1, Vector.zero) == v1 + assert VectorMul(0, Vector.zero) == Vector.zero + raises(TypeError, lambda: v1.outer(1)) + raises(TypeError, lambda: v1.dot(1)) + + +def test_vector_magnitude_normalize(): + assert Vector.zero.magnitude() == 0 + assert Vector.zero.normalize() == Vector.zero + + assert i.magnitude() == 1 + assert j.magnitude() == 1 + assert k.magnitude() == 1 + assert i.normalize() == i + assert j.normalize() == j + assert k.normalize() == k + + v1 = a * i + assert v1.normalize() == (a/sqrt(a**2))*i + assert v1.magnitude() == sqrt(a**2) + + v2 = a*i + b*j + c*k + assert v2.magnitude() == sqrt(a**2 + b**2 + c**2) + assert v2.normalize() == v2 / v2.magnitude() + + v3 = i + j + assert v3.normalize() == (sqrt(2)/2)*C.i + (sqrt(2)/2)*C.j + + +def test_vector_simplify(): + A, s, k, m = symbols('A, s, k, m') + + test1 = (1 / a + 1 / b) * i + assert (test1 & i) != (a + b) / (a * b) + test1 = simplify(test1) + assert (test1 & i) == (a + b) / (a * b) + assert test1.simplify() == simplify(test1) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * i + test2 = simplify(test2) + assert (test2 & i) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * a - 2 * (2 + 2 * a)) / (2 + 2 * a)) * i + test3 = simplify(test3) + assert (test3 & i) == 0 + + test4 = ((-4 * a * b**2 - 2 * b**3 - 2 * a**2 * b) / (a + b)**2) * i + test4 = simplify(test4) + assert (test4 & i) == -2 * b + + v = (sin(a)+cos(a))**2*i - j + assert trigsimp(v) == (2*sin(a + pi/4)**2)*i + (-1)*j + assert trigsimp(v) == v.trigsimp() + + assert simplify(Vector.zero) == Vector.zero + + +def test_vector_equals(): + assert (2*i).equals(j) is False + assert i.equals(i) is True + + # https://github.com/sympy/sympy/issues/25915 + A = (sqrt(2) + sqrt(6)) / sqrt(sqrt(3) + 2) + assert (A*i).equals(2*i) is True + assert (A*i).equals(3*i) is False + + # Test comparing vectors in different coordinate systems + D = C.orient_new_axis('D', pi/2, C.k) + assert (D.i).equals(C.j) is True + assert (D.i).equals(C.i) is False + + +def test_vector_conjugate(): + # https://github.com/sympy/sympy/issues/27094 + assert (I*i + (1 + I)*j + 2*k).conjugate() == -I*i + (1 - I)*j + 2*k + + +def test_vector_dot(): + assert i.dot(Vector.zero) == 0 + assert Vector.zero.dot(i) == 0 + assert i & Vector.zero == 0 + + assert i.dot(i) == 1 + assert i.dot(j) == 0 + assert i.dot(k) == 0 + assert i & i == 1 + assert i & j == 0 + assert i & k == 0 + + assert j.dot(i) == 0 + assert j.dot(j) == 1 + assert j.dot(k) == 0 + assert j & i == 0 + assert j & j == 1 + assert j & k == 0 + + assert k.dot(i) == 0 + assert k.dot(j) == 0 + assert k.dot(k) == 1 + assert k & i == 0 + assert k & j == 0 + assert k & k == 1 + + raises(TypeError, lambda: k.dot(1)) + + +def test_vector_cross(): + assert i.cross(Vector.zero) == Vector.zero + assert Vector.zero.cross(i) == Vector.zero + + assert i.cross(i) == Vector.zero + assert i.cross(j) == k + assert i.cross(k) == -j + assert i ^ i == Vector.zero + assert i ^ j == k + assert i ^ k == -j + + assert j.cross(i) == -k + assert j.cross(j) == Vector.zero + assert j.cross(k) == i + assert j ^ i == -k + assert j ^ j == Vector.zero + assert j ^ k == i + + assert k.cross(i) == j + assert k.cross(j) == -i + assert k.cross(k) == Vector.zero + assert k ^ i == j + assert k ^ j == -i + assert k ^ k == Vector.zero + + assert k.cross(1) == Cross(k, 1) + + +def test_projection(): + v1 = i + j + k + v2 = 3*i + 4*j + v3 = 0*i + 0*j + assert v1.projection(v1) == i + j + k + assert v1.projection(v2) == Rational(7, 3)*C.i + Rational(7, 3)*C.j + Rational(7, 3)*C.k + assert v1.projection(v1, scalar=True) == S.One + assert v1.projection(v2, scalar=True) == Rational(7, 3) + assert v3.projection(v1) == Vector.zero + assert v3.projection(v1, scalar=True) == S.Zero + + +def test_vector_diff_integrate(): + f = Function('f') + v = f(a)*C.i + a**2*C.j - C.k + assert Derivative(v, a) == Derivative((f(a))*C.i + + a**2*C.j + (-1)*C.k, a) + assert (diff(v, a) == v.diff(a) == Derivative(v, a).doit() == + (Derivative(f(a), a))*C.i + 2*a*C.j) + assert (Integral(v, a) == (Integral(f(a), a))*C.i + + (Integral(a**2, a))*C.j + (Integral(-1, a))*C.k) + + +def test_vector_args(): + raises(ValueError, lambda: BaseVector(3, C)) + raises(TypeError, lambda: BaseVector(0, Vector.zero)) + + +def test_srepr(): + from sympy.printing.repr import srepr + res = "CoordSys3D(Str('C'), Tuple(ImmutableDenseMatrix([[Integer(1), "\ + "Integer(0), Integer(0)], [Integer(0), Integer(1), Integer(0)], "\ + "[Integer(0), Integer(0), Integer(1)]]), VectorZero())).i" + assert srepr(C.i) == res + + +def test_scalar(): + from sympy.vector import CoordSys3D + C = CoordSys3D('C') + v1 = 3*C.i + 4*C.j + 5*C.k + v2 = 3*C.i - 4*C.j + 5*C.k + assert v1.is_Vector is True + assert v1.is_scalar is False + assert (v1.dot(v2)).is_scalar is True + assert (v1.cross(v2)).is_scalar is False