diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/cache_mapper.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/cache_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..64cc908801d057c75e5eded5bb9ca592180f5e1b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/cache_mapper.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import abc +import hashlib + +from fsspec.implementations.local import make_path_posix + + +class AbstractCacheMapper(abc.ABC): + """Abstract super-class for mappers from remote URLs to local cached + basenames. + """ + + @abc.abstractmethod + def __call__(self, path: str) -> str: + ... + + def __eq__(self, other: object) -> bool: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return isinstance(other, type(self)) + + def __hash__(self) -> int: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return hash(type(self)) + + +class BasenameCacheMapper(AbstractCacheMapper): + """Cache mapper that uses the basename of the remote URL and a fixed number + of directory levels above this. + + The default is zero directory levels, meaning different paths with the same + basename will have the same cached basename. + """ + + def __init__(self, directory_levels: int = 0): + if directory_levels < 0: + raise ValueError( + "BasenameCacheMapper requires zero or positive directory_levels" + ) + self.directory_levels = directory_levels + + # Separator for directories when encoded as strings. + self._separator = "_@_" + + def __call__(self, path: str) -> str: + path = make_path_posix(path) + prefix, *bits = path.rsplit("/", self.directory_levels + 1) + if bits: + return self._separator.join(bits) + else: + return prefix # No separator found, simple filename + + def __eq__(self, other: object) -> bool: + return super().__eq__(other) and self.directory_levels == other.directory_levels + + def __hash__(self) -> int: + return super().__hash__() ^ hash(self.directory_levels) + + +class HashCacheMapper(AbstractCacheMapper): + """Cache mapper that uses a hash of the remote URL.""" + + def __call__(self, path: str) -> str: + return hashlib.sha256(path.encode()).hexdigest() + + +def create_cache_mapper(same_names: bool) -> AbstractCacheMapper: + """Factory method to create cache mapper for backward compatibility with + ``CachingFileSystem`` constructor using ``same_names`` kwarg. + """ + if same_names: + return BasenameCacheMapper() + else: + return HashCacheMapper() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/cache_metadata.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/cache_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..16964c2a7153d40b480dd47513d1129ed27e307b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/cache_metadata.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import os +import pickle +import time +from typing import TYPE_CHECKING + +from fsspec.utils import atomic_write + +try: + import ujson as json +except ImportError: + if not TYPE_CHECKING: + import json + +if TYPE_CHECKING: + from typing import Any, Dict, Iterator, Literal + + from typing_extensions import TypeAlias + + from .cached import CachingFileSystem + + Detail: TypeAlias = Dict[str, Any] + + +class CacheMetadata: + """Cache metadata. + + All reading and writing of cache metadata is performed by this class, + accessing the cached files and blocks is not. + + Metadata is stored in a single file per storage directory in JSON format. + For backward compatibility, also reads metadata stored in pickle format + which is converted to JSON when next saved. + """ + + def __init__(self, storage: list[str]): + """ + + Parameters + ---------- + storage: list[str] + Directories containing cached files, must be at least one. Metadata + is stored in the last of these directories by convention. + """ + if not storage: + raise ValueError("CacheMetadata expects at least one storage location") + + self._storage = storage + self.cached_files: list[Detail] = [{}] + + # Private attribute to force saving of metadata in pickle format rather than + # JSON for use in tests to confirm can read both pickle and JSON formats. + self._force_save_pickle = False + + def _load(self, fn: str) -> Detail: + """Low-level function to load metadata from specific file""" + try: + with open(fn, "r") as f: + return json.load(f) + except ValueError: + with open(fn, "rb") as f: + return pickle.load(f) + + def _save(self, metadata_to_save: Detail, fn: str) -> None: + """Low-level function to save metadata to specific file""" + if self._force_save_pickle: + with atomic_write(fn) as f: + pickle.dump(metadata_to_save, f) + else: + with atomic_write(fn, mode="w") as f: + json.dump(metadata_to_save, f) + + def _scan_locations( + self, writable_only: bool = False + ) -> Iterator[tuple[str, str, bool]]: + """Yield locations (filenames) where metadata is stored, and whether + writable or not. + + Parameters + ---------- + writable: bool + Set to True to only yield writable locations. + + Returns + ------- + Yields (str, str, bool) + """ + n = len(self._storage) + for i, storage in enumerate(self._storage): + writable = i == n - 1 + if writable_only and not writable: + continue + yield os.path.join(storage, "cache"), storage, writable + + def check_file( + self, path: str, cfs: CachingFileSystem | None + ) -> Literal[False] | tuple[Detail, str]: + """If path is in cache return its details, otherwise return ``False``. + + If the optional CachingFileSystem is specified then it is used to + perform extra checks to reject possible matches, such as if they are + too old. + """ + for (fn, base, _), cache in zip(self._scan_locations(), self.cached_files): + if path not in cache: + continue + detail = cache[path].copy() + + if cfs is not None: + if cfs.check_files and detail["uid"] != cfs.fs.ukey(path): + # Wrong file as determined by hash of file properties + continue + if cfs.expiry and time.time() - detail["time"] > cfs.expiry: + # Cached file has expired + continue + + fn = os.path.join(base, detail["fn"]) + if os.path.exists(fn): + return detail, fn + return False + + def clear_expired(self, expiry_time: int) -> tuple[list[str], bool]: + """Remove expired metadata from the cache. + + Returns names of files corresponding to expired metadata and a boolean + flag indicating whether the writable cache is empty. Caller is + responsible for deleting the expired files. + """ + expired_files = [] + for path, detail in self.cached_files[-1].copy().items(): + if time.time() - detail["time"] > expiry_time: + fn = detail.get("fn", "") + if not fn: + raise RuntimeError( + f"Cache metadata does not contain 'fn' for {path}" + ) + fn = os.path.join(self._storage[-1], fn) + expired_files.append(fn) + self.cached_files[-1].pop(path) + + if self.cached_files[-1]: + cache_path = os.path.join(self._storage[-1], "cache") + self._save(self.cached_files[-1], cache_path) + + writable_cache_empty = not self.cached_files[-1] + return expired_files, writable_cache_empty + + def load(self) -> None: + """Load all metadata from disk and store in ``self.cached_files``""" + cached_files = [] + for fn, _, _ in self._scan_locations(): + if os.path.exists(fn): + # TODO: consolidate blocks here + loaded_cached_files = self._load(fn) + for c in loaded_cached_files.values(): + if isinstance(c["blocks"], list): + c["blocks"] = set(c["blocks"]) + cached_files.append(loaded_cached_files) + else: + cached_files.append({}) + self.cached_files = cached_files or [{}] + + def on_close_cached_file(self, f: Any, path: str) -> None: + """Perform side-effect actions on closing a cached file. + + The actual closing of the file is the responsibility of the caller. + """ + # File must be writeble, so in self.cached_files[-1] + c = self.cached_files[-1][path] + if c["blocks"] is not True and len(c["blocks"]) * f.blocksize >= f.size: + c["blocks"] = True + + def pop_file(self, path: str) -> str | None: + """Remove metadata of cached file. + + If path is in the cache, return the filename of the cached file, + otherwise return ``None``. Caller is responsible for deleting the + cached file. + """ + details = self.check_file(path, None) + if not details: + return None + _, fn = details + if fn.startswith(self._storage[-1]): + self.cached_files[-1].pop(path) + self.save() + else: + raise PermissionError( + "Can only delete cached file in last, writable cache location" + ) + return fn + + def save(self) -> None: + """Save metadata to disk""" + for (fn, _, writable), cache in zip(self._scan_locations(), self.cached_files): + if not writable: + continue + + if os.path.exists(fn): + cached_files = self._load(fn) + for k, c in cached_files.items(): + if k in cache: + if c["blocks"] is True or cache[k]["blocks"] is True: + c["blocks"] = True + else: + # self.cached_files[*][*]["blocks"] must continue to + # point to the same set object so that updates + # performed by MMapCache are propagated back to + # self.cached_files. + blocks = cache[k]["blocks"] + blocks.update(c["blocks"]) + c["blocks"] = blocks + c["time"] = max(c["time"], cache[k]["time"]) + c["uid"] = cache[k]["uid"] + + # Files can be added to cache after it was written once + for k, c in cache.items(): + if k not in cached_files: + cached_files[k] = c + else: + cached_files = cache + cache = {k: v.copy() for k, v in cached_files.items()} + for c in cache.values(): + if isinstance(c["blocks"], set): + c["blocks"] = list(c["blocks"]) + self._save(cache, fn) + self.cached_files[-1] = cached_files + + def update_file(self, path: str, detail: Detail) -> None: + """Update metadata for specific file in memory, do not save""" + self.cached_files[-1][path] = detail diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/data.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/data.py new file mode 100644 index 0000000000000000000000000000000000000000..67d16ec387e24c63ecbd3464c6f4f8e223821651 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/data.py @@ -0,0 +1,48 @@ +import base64 +import io +from urllib.parse import unquote + +from fsspec import AbstractFileSystem + + +class DataFileSystem(AbstractFileSystem): + """A handy decoder for data-URLs + + Example + ------- + >>> with fsspec.open("data:,Hello%2C%20World%21") as f: + ... print(f.read()) + b"Hello, World!" + + """ + + protocol = "data" + + def __init__(self, **kwargs): + """No parameters for this filesystem""" + super().__init__(**kwargs) + + def cat_file(self, path, start=None, end=None, **kwargs): + pref, data = path.split(",", 1) + if pref.endswith("base64"): + return base64.b64decode(data)[start:end] + return unquote(data).encode()[start:end] + + def info(self, path, **kwargs): + pref, name = path.split(",", 1) + data = self.cat_file(path) + mime = pref.split(":", 1)[1].split(";", 1)[0] + return {"name": name, "size": len(data), "type": "file", "mimetype": mime} + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + if "r" not in mode: + raise ValueError("Read only filesystem") + return io.BytesIO(self.cat_file(path)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/dbfs.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/dbfs.py new file mode 100644 index 0000000000000000000000000000000000000000..ce9f9eadb798577970ee95530743b4521813ca7c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/dbfs.py @@ -0,0 +1,467 @@ +import base64 +import urllib + +import requests +import requests.exceptions +from requests.adapters import HTTPAdapter, Retry + +from fsspec import AbstractFileSystem +from fsspec.spec import AbstractBufferedFile + + +class DatabricksException(Exception): + """ + Helper class for exceptions raised in this module. + """ + + def __init__(self, error_code, message): + """Create a new DatabricksException""" + super().__init__(message) + + self.error_code = error_code + self.message = message + + +class DatabricksFileSystem(AbstractFileSystem): + """ + Get access to the Databricks filesystem implementation over HTTP. + Can be used inside and outside of a databricks cluster. + """ + + def __init__(self, instance, token, **kwargs): + """ + Create a new DatabricksFileSystem. + + Parameters + ---------- + instance: str + The instance URL of the databricks cluster. + For example for an Azure databricks cluster, this + has the form adb-..azuredatabricks.net. + token: str + Your personal token. Find out more + here: https://docs.databricks.com/dev-tools/api/latest/authentication.html + """ + self.instance = instance + self.token = token + self.session = requests.Session() + self.retries = Retry( + total=10, + backoff_factor=0.05, + status_forcelist=[408, 429, 500, 502, 503, 504], + ) + + self.session.mount("https://", HTTPAdapter(max_retries=self.retries)) + self.session.headers.update({"Authorization": f"Bearer {self.token}"}) + + super().__init__(**kwargs) + + def ls(self, path, detail=True, **kwargs): + """ + List the contents of the given path. + + Parameters + ---------- + path: str + Absolute path + detail: bool + Return not only the list of filenames, + but also additional information on file sizes + and types. + """ + out = self._ls_from_cache(path) + if not out: + try: + r = self._send_to_api( + method="get", endpoint="list", json={"path": path} + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) + + raise e + files = r["files"] + out = [ + { + "name": o["path"], + "type": "directory" if o["is_dir"] else "file", + "size": o["file_size"], + } + for o in files + ] + self.dircache[path] = out + + if detail: + return out + return [o["name"] for o in out] + + def makedirs(self, path, exist_ok=True): + """ + Create a given absolute path and all of its parents. + + Parameters + ---------- + path: str + Absolute path to create + exist_ok: bool + If false, checks if the folder + exists before creating it (and raises an + Exception if this is the case) + """ + if not exist_ok: + try: + # If the following succeeds, the path is already present + self._send_to_api( + method="get", endpoint="get-status", json={"path": path} + ) + raise FileExistsError(f"Path {path} already exists") + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + pass + + try: + self._send_to_api(method="post", endpoint="mkdirs", json={"path": path}) + except DatabricksException as e: + if e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) + + raise e + self.invalidate_cache(self._parent(path)) + + def mkdir(self, path, create_parents=True, **kwargs): + """ + Create a given absolute path and all of its parents. + + Parameters + ---------- + path: str + Absolute path to create + create_parents: bool + Whether to create all parents or not. + "False" is not implemented so far. + """ + if not create_parents: + raise NotImplementedError + + self.mkdirs(path, **kwargs) + + def rm(self, path, recursive=False, **kwargs): + """ + Remove the file or folder at the given absolute path. + + Parameters + ---------- + path: str + Absolute path what to remove + recursive: bool + Recursively delete all files in a folder. + """ + try: + self._send_to_api( + method="post", + endpoint="delete", + json={"path": path, "recursive": recursive}, + ) + except DatabricksException as e: + # This is not really an exception, it just means + # not everything was deleted so far + if e.error_code == "PARTIAL_DELETE": + self.rm(path=path, recursive=recursive) + elif e.error_code == "IO_ERROR": + # Using the same exception as the os module would use here + raise OSError(e.message) + + raise e + self.invalidate_cache(self._parent(path)) + + def mv( + self, source_path, destination_path, recursive=False, maxdepth=None, **kwargs + ): + """ + Move a source to a destination path. + + A note from the original [databricks API manual] + (https://docs.databricks.com/dev-tools/api/latest/dbfs.html#move). + + When moving a large number of files the API call will time out after + approximately 60s, potentially resulting in partially moved data. + Therefore, for operations that move more than 10k files, we strongly + discourage using the DBFS REST API. + + Parameters + ---------- + source_path: str + From where to move (absolute path) + destination_path: str + To where to move (absolute path) + recursive: bool + Not implemented to far. + maxdepth: + Not implemented to far. + """ + if recursive: + raise NotImplementedError + if maxdepth: + raise NotImplementedError + + try: + self._send_to_api( + method="post", + endpoint="move", + json={"source_path": source_path, "destination_path": destination_path}, + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) + elif e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) + + raise e + self.invalidate_cache(self._parent(source_path)) + self.invalidate_cache(self._parent(destination_path)) + + def _open(self, path, mode="rb", block_size="default", **kwargs): + """ + Overwrite the base class method to make sure to create a DBFile. + All arguments are copied from the base method. + + Only the default blocksize is allowed. + """ + return DatabricksFile(self, path, mode=mode, block_size=block_size, **kwargs) + + def _send_to_api(self, method, endpoint, json): + """ + Send the given json to the DBFS API + using a get or post request (specified by the argument `method`). + + Parameters + ---------- + method: str + Which http method to use for communication; "get" or "post". + endpoint: str + Where to send the request to (last part of the API URL) + json: dict + Dictionary of information to send + """ + if method == "post": + session_call = self.session.post + elif method == "get": + session_call = self.session.get + else: + raise ValueError(f"Do not understand method {method}") + + url = urllib.parse.urljoin(f"https://{self.instance}/api/2.0/dbfs/", endpoint) + + r = session_call(url, json=json) + + # The DBFS API will return a json, also in case of an exception. + # We want to preserve this information as good as possible. + try: + r.raise_for_status() + except requests.HTTPError as e: + # try to extract json error message + # if that fails, fall back to the original exception + try: + exception_json = e.response.json() + except Exception: + raise e + + raise DatabricksException(**exception_json) + + return r.json() + + def _create_handle(self, path, overwrite=True): + """ + Internal function to create a handle, which can be used to + write blocks of a file to DBFS. + A handle has a unique identifier which needs to be passed + whenever written during this transaction. + The handle is active for 10 minutes - after that a new + write transaction needs to be created. + Make sure to close the handle after you are finished. + + Parameters + ---------- + path: str + Absolute path for this file. + overwrite: bool + If a file already exist at this location, either overwrite + it or raise an exception. + """ + try: + r = self._send_to_api( + method="post", + endpoint="create", + json={"path": path, "overwrite": overwrite}, + ) + return r["handle"] + except DatabricksException as e: + if e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) + + raise e + + def _close_handle(self, handle): + """ + Close a handle, which was opened by :func:`_create_handle`. + + Parameters + ---------- + handle: str + Which handle to close. + """ + try: + self._send_to_api(method="post", endpoint="close", json={"handle": handle}) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) + + raise e + + def _add_data(self, handle, data): + """ + Upload data to an already opened file handle + (opened by :func:`_create_handle`). + The maximal allowed data size is 1MB after + conversion to base64. + Remember to close the handle when you are finished. + + Parameters + ---------- + handle: str + Which handle to upload data to. + data: bytes + Block of data to add to the handle. + """ + data = base64.b64encode(data).decode() + try: + self._send_to_api( + method="post", + endpoint="add-block", + json={"handle": handle, "data": data}, + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) + elif e.error_code == "MAX_BLOCK_SIZE_EXCEEDED": + raise ValueError(e.message) + + raise e + + def _get_data(self, path, start, end): + """ + Download data in bytes from a given absolute path in a block + from [start, start+length]. + The maximum number of allowed bytes to read is 1MB. + + Parameters + ---------- + path: str + Absolute path to download data from + start: int + Start position of the block + end: int + End position of the block + """ + try: + r = self._send_to_api( + method="get", + endpoint="read", + json={"path": path, "offset": start, "length": end - start}, + ) + return base64.b64decode(r["data"]) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) + elif e.error_code in ["INVALID_PARAMETER_VALUE", "MAX_READ_SIZE_EXCEEDED"]: + raise ValueError(e.message) + + raise e + + def invalidate_cache(self, path=None): + if path is None: + self.dircache.clear() + else: + self.dircache.pop(path, None) + super().invalidate_cache(path) + + +class DatabricksFile(AbstractBufferedFile): + """ + Helper class for files referenced in the DatabricksFileSystem. + """ + + DEFAULT_BLOCK_SIZE = 1 * 2**20 # only allowed block size + + def __init__( + self, + fs, + path, + mode="rb", + block_size="default", + autocommit=True, + cache_type="readahead", + cache_options=None, + **kwargs, + ): + """ + Create a new instance of the DatabricksFile. + + The blocksize needs to be the default one. + """ + if block_size is None or block_size == "default": + block_size = self.DEFAULT_BLOCK_SIZE + + assert ( + block_size == self.DEFAULT_BLOCK_SIZE + ), f"Only the default block size is allowed, not {block_size}" + + super().__init__( + fs, + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_type=cache_type, + cache_options=cache_options or {}, + **kwargs, + ) + + def _initiate_upload(self): + """Internal function to start a file upload""" + self.handle = self.fs._create_handle(self.path) + + def _upload_chunk(self, final=False): + """Internal function to add a chunk of data to a started upload""" + self.buffer.seek(0) + data = self.buffer.getvalue() + + data_chunks = [ + data[start:end] for start, end in self._to_sized_blocks(len(data)) + ] + + for data_chunk in data_chunks: + self.fs._add_data(handle=self.handle, data=data_chunk) + + if final: + self.fs._close_handle(handle=self.handle) + return True + + def _fetch_range(self, start, end): + """Internal function to download a block of data""" + return_buffer = b"" + length = end - start + for chunk_start, chunk_end in self._to_sized_blocks(length, start): + return_buffer += self.fs._get_data( + path=self.path, start=chunk_start, end=chunk_end + ) + + return return_buffer + + def _to_sized_blocks(self, length, start=0): + """Helper function to split a range from 0 to total_length into bloksizes""" + end = start + length + for data_chunk in range(start, end, self.blocksize): + data_start = data_chunk + data_end = min(end, data_chunk + self.blocksize) + yield data_start, data_end diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/ftp.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/ftp.py new file mode 100644 index 0000000000000000000000000000000000000000..415f4844952f362188561a3e41425d364a115400 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/ftp.py @@ -0,0 +1,385 @@ +import os +import sys +import uuid +import warnings +from ftplib import FTP, Error, error_perm +from typing import Any + +from ..spec import AbstractBufferedFile, AbstractFileSystem +from ..utils import infer_storage_options, isfilelike + + +class FTPFileSystem(AbstractFileSystem): + """A filesystem over classic FTP""" + + root_marker = "/" + cachable = False + protocol = "ftp" + + def __init__( + self, + host, + port=21, + username=None, + password=None, + acct=None, + block_size=None, + tempdir=None, + timeout=30, + encoding="utf-8", + **kwargs, + ): + """ + You can use _get_kwargs_from_urls to get some kwargs from + a reasonable FTP url. + + Authentication will be anonymous if username/password are not + given. + + Parameters + ---------- + host: str + The remote server name/ip to connect to + port: int + Port to connect with + username: str or None + If authenticating, the user's identifier + password: str of None + User's password on the server, if using + acct: str or None + Some servers also need an "account" string for auth + block_size: int or None + If given, the read-ahead or write buffer size. + tempdir: str + Directory on remote to put temporary files when in a transaction + timeout: int + Timeout of the ftp connection in seconds + encoding: str + Encoding to use for directories and filenames in FTP connection + """ + super().__init__(**kwargs) + self.host = host + self.port = port + self.tempdir = tempdir or "/tmp" + self.cred = username, password, acct + self.timeout = timeout + self.encoding = encoding + if block_size is not None: + self.blocksize = block_size + else: + self.blocksize = 2**16 + self._connect() + + def _connect(self): + if sys.version_info >= (3, 9): + self.ftp = FTP(timeout=self.timeout, encoding=self.encoding) + elif self.encoding: + warnings.warn("`encoding` not supported for python<3.9, ignoring") + self.ftp = FTP(timeout=self.timeout) + else: + self.ftp = FTP(timeout=self.timeout) + self.ftp.connect(self.host, self.port) + self.ftp.login(*self.cred) + + @classmethod + def _strip_protocol(cls, path): + return "/" + infer_storage_options(path)["path"].lstrip("/").rstrip("/") + + @staticmethod + def _get_kwargs_from_urls(urlpath): + out = infer_storage_options(urlpath) + out.pop("path", None) + out.pop("protocol", None) + return out + + def ls(self, path, detail=True, **kwargs): + path = self._strip_protocol(path) + out = [] + if path not in self.dircache: + try: + try: + out = [ + (fn, details) + for (fn, details) in self.ftp.mlsd(path) + if fn not in [".", ".."] + and details["type"] not in ["pdir", "cdir"] + ] + except error_perm: + out = _mlsd2(self.ftp, path) # Not platform independent + for fn, details in out: + if path == "/": + path = "" # just for forming the names, below + details["name"] = "/".join([path, fn.lstrip("/")]) + if details["type"] == "file": + details["size"] = int(details["size"]) + else: + details["size"] = 0 + if details["type"] == "dir": + details["type"] = "directory" + self.dircache[path] = out + except Error: + try: + info = self.info(path) + if info["type"] == "file": + out = [(path, info)] + except (Error, IndexError): + raise FileNotFoundError(path) + files = self.dircache.get(path, out) + if not detail: + return sorted([fn for fn, details in files]) + return [details for fn, details in files] + + def info(self, path, **kwargs): + # implement with direct method + path = self._strip_protocol(path) + if path == "/": + # special case, since this dir has no real entry + return {"name": "/", "size": 0, "type": "directory"} + files = self.ls(self._parent(path).lstrip("/"), True) + try: + out = [f for f in files if f["name"] == path][0] + except IndexError: + raise FileNotFoundError(path) + return out + + def get_file(self, rpath, lpath, **kwargs): + if self.isdir(rpath): + if not os.path.exists(lpath): + os.mkdir(lpath) + return + if isfilelike(lpath): + outfile = lpath + else: + outfile = open(lpath, "wb") + + def cb(x): + outfile.write(x) + + self.ftp.retrbinary( + f"RETR {rpath}", + blocksize=self.blocksize, + callback=cb, + ) + if not isfilelike(lpath): + outfile.close() + + def cat_file(self, path, start=None, end=None, **kwargs): + if end is not None: + return super().cat_file(path, start, end, **kwargs) + out = [] + + def cb(x): + out.append(x) + + try: + self.ftp.retrbinary( + f"RETR {path}", + blocksize=self.blocksize, + rest=start, + callback=cb, + ) + except (Error, error_perm) as orig_exc: + raise FileNotFoundError(path) from orig_exc + return b"".join(out) + + def _open( + self, + path, + mode="rb", + block_size=None, + cache_options=None, + autocommit=True, + **kwargs, + ): + path = self._strip_protocol(path) + block_size = block_size or self.blocksize + return FTPFile( + self, + path, + mode=mode, + block_size=block_size, + tempdir=self.tempdir, + autocommit=autocommit, + cache_options=cache_options, + ) + + def _rm(self, path): + path = self._strip_protocol(path) + self.ftp.delete(path) + self.invalidate_cache(self._parent(path)) + + def rm(self, path, recursive=False, maxdepth=None): + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth) + for p in reversed(paths): + if self.isfile(p): + self.rm_file(p) + else: + self.rmdir(p) + + def mkdir(self, path: str, create_parents: bool = True, **kwargs: Any) -> None: + path = self._strip_protocol(path) + parent = self._parent(path) + if parent != self.root_marker and not self.exists(parent) and create_parents: + self.mkdir(parent, create_parents=create_parents) + + self.ftp.mkd(path) + self.invalidate_cache(self._parent(path)) + + def makedirs(self, path: str, exist_ok: bool = False) -> None: + path = self._strip_protocol(path) + if self.exists(path): + # NB: "/" does not "exist" as it has no directory entry + if not exist_ok: + raise FileExistsError(f"{path} exists without `exist_ok`") + # exists_ok=True -> no-op + else: + self.mkdir(path, create_parents=True) + + def rmdir(self, path): + path = self._strip_protocol(path) + self.ftp.rmd(path) + self.invalidate_cache(self._parent(path)) + + def mv(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1) + path2 = self._strip_protocol(path2) + self.ftp.rename(path1, path2) + self.invalidate_cache(self._parent(path1)) + self.invalidate_cache(self._parent(path2)) + + def __del__(self): + self.ftp.close() + + def invalidate_cache(self, path=None): + if path is None: + self.dircache.clear() + else: + self.dircache.pop(path, None) + super().invalidate_cache(path) + + +class TransferDone(Exception): + """Internal exception to break out of transfer""" + + pass + + +class FTPFile(AbstractBufferedFile): + """Interact with a remote FTP file with read/write buffering""" + + def __init__( + self, + fs, + path, + mode="rb", + block_size="default", + autocommit=True, + cache_type="readahead", + cache_options=None, + **kwargs, + ): + super().__init__( + fs, + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_type=cache_type, + cache_options=cache_options, + **kwargs, + ) + if not autocommit: + self.target = self.path + self.path = "/".join([kwargs["tempdir"], str(uuid.uuid4())]) + + def commit(self): + self.fs.mv(self.path, self.target) + + def discard(self): + self.fs.rm(self.path) + + def _fetch_range(self, start, end): + """Get bytes between given byte limits + + Implemented by raising an exception in the fetch callback when the + number of bytes received reaches the requested amount. + + Will fail if the server does not respect the REST command on + retrieve requests. + """ + out = [] + total = [0] + + def callback(x): + total[0] += len(x) + if total[0] > end - start: + out.append(x[: (end - start) - total[0]]) + if end < self.size: + raise TransferDone + else: + out.append(x) + + if total[0] == end - start and end < self.size: + raise TransferDone + + try: + self.fs.ftp.retrbinary( + f"RETR {self.path}", + blocksize=self.blocksize, + rest=start, + callback=callback, + ) + except TransferDone: + try: + # stop transfer, we got enough bytes for this block + self.fs.ftp.abort() + self.fs.ftp.getmultiline() + except Error: + self.fs._connect() + + return b"".join(out) + + def _upload_chunk(self, final=False): + self.buffer.seek(0) + self.fs.ftp.storbinary( + f"STOR {self.path}", self.buffer, blocksize=self.blocksize, rest=self.offset + ) + return True + + +def _mlsd2(ftp, path="."): + """ + Fall back to using `dir` instead of `mlsd` if not supported. + + This parses a Linux style `ls -l` response to `dir`, but the response may + be platform dependent. + + Parameters + ---------- + ftp: ftplib.FTP + path: str + Expects to be given path, but defaults to ".". + """ + lines = [] + minfo = [] + ftp.dir(path, lines.append) + for line in lines: + split_line = line.split() + if len(split_line) < 9: + continue + this = ( + split_line[-1], + { + "modify": " ".join(split_line[5:8]), + "unix.owner": split_line[2], + "unix.group": split_line[3], + "unix.mode": split_line[0], + "size": split_line[4], + }, + ) + if "d" == this[1]["unix.mode"][0]: + this[1]["type"] = "dir" + else: + this[1]["type"] = "file" + minfo.append(this) + return minfo diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/libarchive.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/libarchive.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6f145352e1989e0477e259be02d8d7f4d729e2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/libarchive.py @@ -0,0 +1,213 @@ +from contextlib import contextmanager +from ctypes import ( + CFUNCTYPE, + POINTER, + c_int, + c_longlong, + c_void_p, + cast, + create_string_buffer, +) + +import libarchive +import libarchive.ffi as ffi + +from fsspec import open_files +from fsspec.archive import AbstractArchiveFileSystem +from fsspec.implementations.memory import MemoryFile +from fsspec.utils import DEFAULT_BLOCK_SIZE + +# Libarchive requires seekable files or memory only for certain archive +# types. However, since we read the directory first to cache the contents +# and also allow random access to any file, the file-like object needs +# to be seekable no matter what. + +# Seek call-backs (not provided in the libarchive python wrapper) +SEEK_CALLBACK = CFUNCTYPE(c_longlong, c_int, c_void_p, c_longlong, c_int) +read_set_seek_callback = ffi.ffi( + "read_set_seek_callback", [ffi.c_archive_p, SEEK_CALLBACK], c_int, ffi.check_int +) +new_api = hasattr(ffi, "NO_OPEN_CB") + + +@contextmanager +def custom_reader(file, format_name="all", filter_name="all", block_size=ffi.page_size): + """Read an archive from a seekable file-like object. + + The `file` object must support the standard `readinto` and 'seek' methods. + """ + buf = create_string_buffer(block_size) + buf_p = cast(buf, c_void_p) + + def read_func(archive_p, context, ptrptr): + # readinto the buffer, returns number of bytes read + length = file.readinto(buf) + # write the address of the buffer into the pointer + ptrptr = cast(ptrptr, POINTER(c_void_p)) + ptrptr[0] = buf_p + # tell libarchive how much data was written into the buffer + return length + + def seek_func(archive_p, context, offset, whence): + file.seek(offset, whence) + # tell libarchvie the current position + return file.tell() + + read_cb = ffi.READ_CALLBACK(read_func) + seek_cb = SEEK_CALLBACK(seek_func) + + if new_api: + open_cb = ffi.NO_OPEN_CB + close_cb = ffi.NO_CLOSE_CB + else: + open_cb = libarchive.read.OPEN_CALLBACK(ffi.VOID_CB) + close_cb = libarchive.read.CLOSE_CALLBACK(ffi.VOID_CB) + + with libarchive.read.new_archive_read(format_name, filter_name) as archive_p: + read_set_seek_callback(archive_p, seek_cb) + ffi.read_open(archive_p, None, open_cb, read_cb, close_cb) + yield libarchive.read.ArchiveRead(archive_p) + + +class LibArchiveFileSystem(AbstractArchiveFileSystem): + """Compressed archives as a file-system (read-only) + + Supports the following formats: + tar, pax , cpio, ISO9660, zip, mtree, shar, ar, raw, xar, lha/lzh, rar + Microsoft CAB, 7-Zip, WARC + + See the libarchive documentation for further restrictions. + https://www.libarchive.org/ + + Keeps file object open while instance lives. It only works in seekable + file-like objects. In case the filesystem does not support this kind of + file object, it is recommended to cache locally. + + This class is pickleable, but not necessarily thread-safe (depends on the + platform). See libarchive documentation for details. + """ + + root_marker = "" + protocol = "libarchive" + cachable = False + + def __init__( + self, + fo="", + mode="r", + target_protocol=None, + target_options=None, + block_size=DEFAULT_BLOCK_SIZE, + **kwargs, + ): + """ + Parameters + ---------- + fo: str or file-like + Contains ZIP, and must exist. If a str, will fetch file using + :meth:`~fsspec.open_files`, which must return one file exactly. + mode: str + Currently, only 'r' accepted + target_protocol: str (optional) + If ``fo`` is a string, this value can be used to override the + FS protocol inferred from a URL + target_options: dict (optional) + Kwargs passed when instantiating the target FS, if ``fo`` is + a string. + """ + super().__init__(self, **kwargs) + if mode != "r": + raise ValueError("Only read from archive files accepted") + if isinstance(fo, str): + files = open_files(fo, protocol=target_protocol, **(target_options or {})) + if len(files) != 1: + raise ValueError( + f'Path "{fo}" did not resolve to exactly one file: "{files}"' + ) + fo = files[0] + self.of = fo + self.fo = fo.__enter__() # the whole instance is a context + self.block_size = block_size + self.dir_cache = None + + @contextmanager + def _open_archive(self): + self.fo.seek(0) + with custom_reader(self.fo, block_size=self.block_size) as arc: + yield arc + + @classmethod + def _strip_protocol(cls, path): + # file paths are always relative to the archive root + return super()._strip_protocol(path).lstrip("/") + + def _get_dirs(self): + fields = { + "name": "pathname", + "size": "size", + "created": "ctime", + "mode": "mode", + "uid": "uid", + "gid": "gid", + "mtime": "mtime", + } + + if self.dir_cache is not None: + return + + self.dir_cache = {} + list_names = [] + with self._open_archive() as arc: + for entry in arc: + if not entry.isdir and not entry.isfile: + # Skip symbolic links, fifo entries, etc. + continue + self.dir_cache.update( + { + dirname: {"name": dirname, "size": 0, "type": "directory"} + for dirname in self._all_dirnames(set(entry.name)) + } + ) + f = {key: getattr(entry, fields[key]) for key in fields} + f["type"] = "directory" if entry.isdir else "file" + list_names.append(entry.name) + + self.dir_cache[f["name"]] = f + # libarchive does not seem to return an entry for the directories (at least + # not in all formats), so get the directories names from the files names + self.dir_cache.update( + { + dirname: {"name": dirname, "size": 0, "type": "directory"} + for dirname in self._all_dirnames(list_names) + } + ) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + path = self._strip_protocol(path) + if mode != "rb": + raise NotImplementedError + + data = bytes() + with self._open_archive() as arc: + for entry in arc: + if entry.pathname != path: + continue + + if entry.size == 0: + # empty file, so there are no blocks + break + + for block in entry.get_blocks(entry.size): + data = block + break + else: + raise ValueError + return MemoryFile(fs=self, path=path, data=data) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/local.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/local.py new file mode 100644 index 0000000000000000000000000000000000000000..17f96c1b8c0ab0d1ced7a279e7f05a3df22d2861 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/local.py @@ -0,0 +1,418 @@ +import datetime +import io +import logging +import os +import os.path as osp +import re +import shutil +import stat +import tempfile + +from fsspec import AbstractFileSystem +from fsspec.compression import compr +from fsspec.core import get_compression +from fsspec.utils import isfilelike, stringify_path + +logger = logging.getLogger("fsspec.local") + + +class LocalFileSystem(AbstractFileSystem): + """Interface to files on local storage + + Parameters + ---------- + auto_mkdir: bool + Whether, when opening a file, the directory containing it should + be created (if it doesn't already exist). This is assumed by pyarrow + code. + """ + + root_marker = "/" + protocol = "file", "local" + local_file = True + + def __init__(self, auto_mkdir=False, **kwargs): + super().__init__(**kwargs) + self.auto_mkdir = auto_mkdir + + @property + def fsid(self): + return "local" + + def mkdir(self, path, create_parents=True, **kwargs): + path = self._strip_protocol(path) + if self.exists(path): + raise FileExistsError(path) + if create_parents: + self.makedirs(path, exist_ok=True) + else: + os.mkdir(path, **kwargs) + + def makedirs(self, path, exist_ok=False): + path = self._strip_protocol(path) + os.makedirs(path, exist_ok=exist_ok) + + def rmdir(self, path): + path = self._strip_protocol(path) + os.rmdir(path) + + def ls(self, path, detail=False, **kwargs): + path = self._strip_protocol(path) + info = self.info(path) + if info["type"] == "directory": + with os.scandir(path) as it: + infos = [self.info(f) for f in it] + else: + infos = [info] + + if not detail: + return [i["name"] for i in infos] + return infos + + def info(self, path, **kwargs): + if isinstance(path, os.DirEntry): + # scandir DirEntry + out = path.stat(follow_symlinks=False) + link = path.is_symlink() + if path.is_dir(follow_symlinks=False): + t = "directory" + elif path.is_file(follow_symlinks=False): + t = "file" + else: + t = "other" + path = self._strip_protocol(path.path) + else: + # str or path-like + path = self._strip_protocol(path) + out = os.stat(path, follow_symlinks=False) + link = stat.S_ISLNK(out.st_mode) + if link: + out = os.stat(path, follow_symlinks=True) + if stat.S_ISDIR(out.st_mode): + t = "directory" + elif stat.S_ISREG(out.st_mode): + t = "file" + else: + t = "other" + result = { + "name": path, + "size": out.st_size, + "type": t, + "created": out.st_ctime, + "islink": link, + } + for field in ["mode", "uid", "gid", "mtime", "ino", "nlink"]: + result[field] = getattr(out, f"st_{field}") + if result["islink"]: + result["destination"] = os.readlink(path) + try: + out2 = os.stat(path, follow_symlinks=True) + result["size"] = out2.st_size + except OSError: + result["size"] = 0 + return result + + def lexists(self, path, **kwargs): + return osp.lexists(path) + + def cp_file(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1).rstrip("/") + path2 = self._strip_protocol(path2).rstrip("/") + if self.auto_mkdir: + self.makedirs(self._parent(path2), exist_ok=True) + if self.isfile(path1): + shutil.copyfile(path1, path2) + elif self.isdir(path1): + self.mkdirs(path2, exist_ok=True) + else: + raise FileNotFoundError(path1) + + def get_file(self, path1, path2, callback=None, **kwargs): + if isfilelike(path2): + with open(path1, "rb") as f: + shutil.copyfileobj(f, path2) + else: + return self.cp_file(path1, path2, **kwargs) + + def put_file(self, path1, path2, callback=None, **kwargs): + return self.cp_file(path1, path2, **kwargs) + + def mv_file(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1).rstrip("/") + path2 = self._strip_protocol(path2).rstrip("/") + shutil.move(path1, path2) + + def link(self, src, dst, **kwargs): + src = self._strip_protocol(src) + dst = self._strip_protocol(dst) + os.link(src, dst, **kwargs) + + def symlink(self, src, dst, **kwargs): + src = self._strip_protocol(src) + dst = self._strip_protocol(dst) + os.symlink(src, dst, **kwargs) + + def islink(self, path) -> bool: + return os.path.islink(self._strip_protocol(path)) + + def rm_file(self, path): + os.remove(self._strip_protocol(path)) + + def rm(self, path, recursive=False, maxdepth=None): + if not isinstance(path, list): + path = [path] + + for p in path: + p = self._strip_protocol(p).rstrip("/") + if self.isdir(p): + if not recursive: + raise ValueError("Cannot delete directory, set recursive=True") + if osp.abspath(p) == os.getcwd(): + raise ValueError("Cannot delete current working directory") + shutil.rmtree(p) + else: + os.remove(p) + + def unstrip_protocol(self, name): + name = self._strip_protocol(name) # normalise for local/win/... + return f"file://{name}" + + def _open(self, path, mode="rb", block_size=None, **kwargs): + path = self._strip_protocol(path) + if self.auto_mkdir and "w" in mode: + self.makedirs(self._parent(path), exist_ok=True) + return LocalFileOpener(path, mode, fs=self, **kwargs) + + def touch(self, path, truncate=True, **kwargs): + path = self._strip_protocol(path) + if self.auto_mkdir: + self.makedirs(self._parent(path), exist_ok=True) + if self.exists(path): + os.utime(path, None) + else: + open(path, "a").close() + if truncate: + os.truncate(path, 0) + + def created(self, path): + info = self.info(path=path) + return datetime.datetime.fromtimestamp( + info["created"], tz=datetime.timezone.utc + ) + + def modified(self, path): + info = self.info(path=path) + return datetime.datetime.fromtimestamp(info["mtime"], tz=datetime.timezone.utc) + + @classmethod + def _parent(cls, path): + path = cls._strip_protocol(path).rstrip("/") + if "/" in path: + return path.rsplit("/", 1)[0] + else: + return cls.root_marker + + @classmethod + def _strip_protocol(cls, path): + path = stringify_path(path) + if path.startswith("file://"): + path = path[7:] + elif path.startswith("file:"): + path = path[5:] + elif path.startswith("local://"): + path = path[8:] + elif path.startswith("local:"): + path = path[6:] + return make_path_posix(path).rstrip("/") or cls.root_marker + + def _isfilestore(self): + # Inheriting from DaskFileSystem makes this False (S3, etc. were) + # the original motivation. But we are a posix-like file system. + # See https://github.com/dask/dask/issues/5526 + return True + + def chmod(self, path, mode): + path = stringify_path(path) + return os.chmod(path, mode) + + +def make_path_posix(path, sep=os.sep): + """Make path generic""" + if isinstance(path, (list, set, tuple)): + return type(path)(make_path_posix(p) for p in path) + if "~" in path: + path = osp.expanduser(path) + if sep == "/": + # most common fast case for posix + if path.startswith("/"): + return path + if path.startswith("./"): + path = path[2:] + return f"{os.getcwd()}/{path}" + if ( + (sep not in path and "/" not in path) + or (sep == "/" and not path.startswith("/")) + or (sep == "\\" and ":" not in path and not path.startswith("\\\\")) + ): + # relative path like "path" or "rel\\path" (win) or rel/path" + if os.sep == "\\": + # abspath made some more '\\' separators + return make_path_posix(osp.abspath(path)) + else: + return f"{os.getcwd()}/{path}" + if path.startswith("file://"): + path = path[7:] + if re.match("/[A-Za-z]:", path): + # for windows file URI like "file:///C:/folder/file" + # or "file:///C:\\dir\\file" + path = path[1:].replace("\\", "/").replace("//", "/") + if path.startswith("\\\\"): + # special case for windows UNC/DFS-style paths, do nothing, + # just flip the slashes around (case below does not work!) + return path.replace("\\", "/") + if re.match("[A-Za-z]:", path): + # windows full path like "C:\\local\\path" + return path.lstrip("\\").replace("\\", "/").replace("//", "/") + if path.startswith("\\"): + # windows network path like "\\server\\path" + return "/" + path.lstrip("\\").replace("\\", "/").replace("//", "/") + return path + + +def trailing_sep(path): + """Return True if the path ends with a path separator. + + A forward slash is always considered a path separator, even on Operating + Systems that normally use a backslash. + """ + # TODO: if all incoming paths were posix-compliant then separator would + # always be a forward slash, simplifying this function. + # See https://github.com/fsspec/filesystem_spec/pull/1250 + return path.endswith(os.sep) or (os.altsep is not None and path.endswith(os.altsep)) + + +class LocalFileOpener(io.IOBase): + def __init__( + self, path, mode, autocommit=True, fs=None, compression=None, **kwargs + ): + logger.debug("open file: %s", path) + self.path = path + self.mode = mode + self.fs = fs + self.f = None + self.autocommit = autocommit + self.compression = get_compression(path, compression) + self.blocksize = io.DEFAULT_BUFFER_SIZE + self._open() + + def _open(self): + if self.f is None or self.f.closed: + if self.autocommit or "w" not in self.mode: + self.f = open(self.path, mode=self.mode) + if self.compression: + compress = compr[self.compression] + self.f = compress(self.f, mode=self.mode) + else: + # TODO: check if path is writable? + i, name = tempfile.mkstemp() + os.close(i) # we want normal open and normal buffered file + self.temp = name + self.f = open(name, mode=self.mode) + if "w" not in self.mode: + self.size = self.f.seek(0, 2) + self.f.seek(0) + self.f.size = self.size + + def _fetch_range(self, start, end): + # probably only used by cached FS + if "r" not in self.mode: + raise ValueError + self._open() + self.f.seek(start) + return self.f.read(end - start) + + def __setstate__(self, state): + self.f = None + loc = state.pop("loc", None) + self.__dict__.update(state) + if "r" in state["mode"]: + self.f = None + self._open() + self.f.seek(loc) + + def __getstate__(self): + d = self.__dict__.copy() + d.pop("f") + if "r" in self.mode: + d["loc"] = self.f.tell() + else: + if not self.f.closed: + raise ValueError("Cannot serialise open write-mode local file") + return d + + def commit(self): + if self.autocommit: + raise RuntimeError("Can only commit if not already set to autocommit") + shutil.move(self.temp, self.path) + + def discard(self): + if self.autocommit: + raise RuntimeError("Cannot discard if set to autocommit") + os.remove(self.temp) + + def readable(self) -> bool: + return True + + def writable(self) -> bool: + return "r" not in self.mode + + def read(self, *args, **kwargs): + return self.f.read(*args, **kwargs) + + def write(self, *args, **kwargs): + return self.f.write(*args, **kwargs) + + def tell(self, *args, **kwargs): + return self.f.tell(*args, **kwargs) + + def seek(self, *args, **kwargs): + return self.f.seek(*args, **kwargs) + + def seekable(self, *args, **kwargs): + return self.f.seekable(*args, **kwargs) + + def readline(self, *args, **kwargs): + return self.f.readline(*args, **kwargs) + + def readlines(self, *args, **kwargs): + return self.f.readlines(*args, **kwargs) + + def close(self): + return self.f.close() + + def truncate(self, size=None) -> int: + return self.f.truncate(size) + + @property + def closed(self): + return self.f.closed + + def fileno(self): + return self.raw.fileno() + + def flush(self) -> None: + self.f.flush() + + def __iter__(self): + return self.f.__iter__() + + def __getattr__(self, item): + return getattr(self.f, item) + + def __enter__(self): + self._incontext = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._incontext = False + self.f.__exit__(exc_type, exc_value, traceback) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/memory.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..7320df2c359548458be3b34e20e06d251a16b2e1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/memory.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from errno import ENOTEMPTY +from io import BytesIO +from typing import Any, ClassVar + +from fsspec import AbstractFileSystem + +logger = logging.getLogger("fsspec.memoryfs") + + +class MemoryFileSystem(AbstractFileSystem): + """A filesystem based on a dict of BytesIO objects + + This is a global filesystem so instances of this class all point to the same + in memory filesystem. + """ + + store: ClassVar[dict[str, Any]] = {} # global, do not overwrite! + pseudo_dirs = [""] # global, do not overwrite! + protocol = "memory" + root_marker = "/" + + @classmethod + def _strip_protocol(cls, path): + if path.startswith("memory://"): + path = path[len("memory://") :] + if "::" in path or "://" in path: + return path.rstrip("/") + path = path.lstrip("/").rstrip("/") + return "/" + path if path else "" + + def ls(self, path, detail=True, **kwargs): + path = self._strip_protocol(path) + if path in self.store: + # there is a key with this exact name + if not detail: + return [path] + return [ + { + "name": path, + "size": self.store[path].size, + "type": "file", + "created": self.store[path].created.timestamp(), + } + ] + paths = set() + starter = path + "/" + out = [] + for p2 in tuple(self.store): + if p2.startswith(starter): + if "/" not in p2[len(starter) :]: + # exact child + out.append( + { + "name": p2, + "size": self.store[p2].size, + "type": "file", + "created": self.store[p2].created.timestamp(), + } + ) + elif len(p2) > len(starter): + # implied child directory + ppath = starter + p2[len(starter) :].split("/", 1)[0] + if ppath not in paths: + out = out or [] + out.append( + { + "name": ppath, + "size": 0, + "type": "directory", + } + ) + paths.add(ppath) + for p2 in self.pseudo_dirs: + if p2.startswith(starter): + if "/" not in p2[len(starter) :]: + # exact child pdir + if p2 not in paths: + out.append({"name": p2, "size": 0, "type": "directory"}) + paths.add(p2) + else: + # directory implied by deeper pdir + ppath = starter + p2[len(starter) :].split("/", 1)[0] + if ppath not in paths: + out.append({"name": ppath, "size": 0, "type": "directory"}) + paths.add(ppath) + if not out: + if path in self.pseudo_dirs: + # empty dir + return [] + raise FileNotFoundError(path) + if detail: + return out + return sorted([f["name"] for f in out]) + + def mkdir(self, path, create_parents=True, **kwargs): + path = self._strip_protocol(path) + if path in self.store or path in self.pseudo_dirs: + raise FileExistsError(path) + if self._parent(path).strip("/") and self.isfile(self._parent(path)): + raise NotADirectoryError(self._parent(path)) + if create_parents and self._parent(path).strip("/"): + try: + self.mkdir(self._parent(path), create_parents, **kwargs) + except FileExistsError: + pass + if path and path not in self.pseudo_dirs: + self.pseudo_dirs.append(path) + + def makedirs(self, path, exist_ok=False): + try: + self.mkdir(path, create_parents=True) + except FileExistsError: + if not exist_ok: + raise + + def pipe_file(self, path, value, **kwargs): + """Set the bytes of given file + + Avoids copies of the data if possible + """ + self.open(path, "wb", data=value) + + def rmdir(self, path): + path = self._strip_protocol(path) + if path == "": + # silently avoid deleting FS root + return + if path in self.pseudo_dirs: + if not self.ls(path): + self.pseudo_dirs.remove(path) + else: + raise OSError(ENOTEMPTY, "Directory not empty", path) + else: + raise FileNotFoundError(path) + + def info(self, path, **kwargs): + path = self._strip_protocol(path) + if path in self.pseudo_dirs or any( + p.startswith(path + "/") for p in list(self.store) + self.pseudo_dirs + ): + return { + "name": path, + "size": 0, + "type": "directory", + } + elif path in self.store: + filelike = self.store[path] + return { + "name": path, + "size": filelike.size, + "type": "file", + "created": getattr(filelike, "created", None), + } + else: + raise FileNotFoundError(path) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + path = self._strip_protocol(path) + if path in self.pseudo_dirs: + raise IsADirectoryError(path) + parent = path + while len(parent) > 1: + parent = self._parent(parent) + if self.isfile(parent): + raise FileExistsError(parent) + if mode in ["rb", "ab", "r+b"]: + if path in self.store: + f = self.store[path] + if mode == "ab": + # position at the end of file + f.seek(0, 2) + else: + # position at the beginning of file + f.seek(0) + return f + else: + raise FileNotFoundError(path) + elif mode == "wb": + m = MemoryFile(self, path, kwargs.get("data")) + if not self._intrans: + m.commit() + return m + else: + name = self.__class__.__name__ + raise ValueError(f"unsupported file mode for {name}: {mode!r}") + + def cp_file(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1) + path2 = self._strip_protocol(path2) + if self.isfile(path1): + self.store[path2] = MemoryFile( + self, path2, self.store[path1].getvalue() + ) # implicit copy + elif self.isdir(path1): + if path2 not in self.pseudo_dirs: + self.pseudo_dirs.append(path2) + else: + raise FileNotFoundError(path1) + + def cat_file(self, path, start=None, end=None, **kwargs): + path = self._strip_protocol(path) + try: + return bytes(self.store[path].getbuffer()[start:end]) + except KeyError: + raise FileNotFoundError(path) + + def _rm(self, path): + path = self._strip_protocol(path) + try: + del self.store[path] + except KeyError as e: + raise FileNotFoundError(path) from e + + def modified(self, path): + path = self._strip_protocol(path) + try: + return self.store[path].modified + except KeyError: + raise FileNotFoundError(path) + + def created(self, path): + path = self._strip_protocol(path) + try: + return self.store[path].created + except KeyError: + raise FileNotFoundError(path) + + def rm(self, path, recursive=False, maxdepth=None): + if isinstance(path, str): + path = self._strip_protocol(path) + else: + path = [self._strip_protocol(p) for p in path] + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth) + for p in reversed(paths): + # If the expanded path doesn't exist, it is only because the expanded + # path was a directory that does not exist in self.pseudo_dirs. This + # is possible if you directly create files without making the + # directories first. + if not self.exists(p): + continue + if self.isfile(p): + self.rm_file(p) + else: + self.rmdir(p) + + +class MemoryFile(BytesIO): + """A BytesIO which can't close and works as a context manager + + Can initialise with data. Each path should only be active once at any moment. + + No need to provide fs, path if auto-committing (default) + """ + + def __init__(self, fs=None, path=None, data=None): + logger.debug("open file %s", path) + self.fs = fs + self.path = path + self.created = datetime.now(tz=timezone.utc) + self.modified = datetime.now(tz=timezone.utc) + if data: + super().__init__(data) + self.seek(0) + + @property + def size(self): + return self.getbuffer().nbytes + + def __enter__(self): + return self + + def close(self): + pass + + def discard(self): + pass + + def commit(self): + self.fs.store[self.path] = self + self.modified = datetime.now(tz=timezone.utc) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/sftp.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/sftp.py new file mode 100644 index 0000000000000000000000000000000000000000..77f7b370cd246f9a9bfd34141afc3edd728d13c3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/sftp.py @@ -0,0 +1,180 @@ +import datetime +import logging +import os +import types +import uuid +from stat import S_ISDIR, S_ISLNK + +import paramiko + +from .. import AbstractFileSystem +from ..utils import infer_storage_options + +logger = logging.getLogger("fsspec.sftp") + + +class SFTPFileSystem(AbstractFileSystem): + """Files over SFTP/SSH + + Peer-to-peer filesystem over SSH using paramiko. + + Note: if using this with the ``open`` or ``open_files``, with full URLs, + there is no way to tell if a path is relative, so all paths are assumed + to be absolute. + """ + + protocol = "sftp", "ssh" + + def __init__(self, host, **ssh_kwargs): + """ + + Parameters + ---------- + host: str + Hostname or IP as a string + temppath: str + Location on the server to put files, when within a transaction + ssh_kwargs: dict + Parameters passed on to connection. See details in + https://docs.paramiko.org/en/3.3/api/client.html#paramiko.client.SSHClient.connect + May include port, username, password... + """ + if self._cached: + return + super().__init__(**ssh_kwargs) + self.temppath = ssh_kwargs.pop("temppath", "/tmp") # remote temp directory + self.host = host + self.ssh_kwargs = ssh_kwargs + self._connect() + + def _connect(self): + logger.debug("Connecting to SFTP server %s", self.host) + self.client = paramiko.SSHClient() + self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.client.connect(self.host, **self.ssh_kwargs) + self.ftp = self.client.open_sftp() + + @classmethod + def _strip_protocol(cls, path): + return infer_storage_options(path)["path"] + + @staticmethod + def _get_kwargs_from_urls(urlpath): + out = infer_storage_options(urlpath) + out.pop("path", None) + out.pop("protocol", None) + return out + + def mkdir(self, path, create_parents=True, mode=511): + logger.debug("Creating folder %s", path) + if self.exists(path): + raise FileExistsError(f"File exists: {path}") + + if create_parents: + self.makedirs(path) + else: + self.ftp.mkdir(path, mode) + + def makedirs(self, path, exist_ok=False, mode=511): + if self.exists(path) and not exist_ok: + raise FileExistsError(f"File exists: {path}") + + parts = path.split("/") + new_path = "/" if path[:1] == "/" else "" + + for part in parts: + if part: + new_path = f"{new_path}/{part}" if new_path else part + if not self.exists(new_path): + self.ftp.mkdir(new_path, mode) + + def rmdir(self, path): + logger.debug("Removing folder %s", path) + self.ftp.rmdir(path) + + def info(self, path): + stat = self._decode_stat(self.ftp.stat(path)) + stat["name"] = path + return stat + + @staticmethod + def _decode_stat(stat, parent_path=None): + if S_ISDIR(stat.st_mode): + t = "directory" + elif S_ISLNK(stat.st_mode): + t = "link" + else: + t = "file" + out = { + "name": "", + "size": stat.st_size, + "type": t, + "uid": stat.st_uid, + "gid": stat.st_gid, + "time": datetime.datetime.fromtimestamp( + stat.st_atime, tz=datetime.timezone.utc + ), + "mtime": datetime.datetime.fromtimestamp( + stat.st_mtime, tz=datetime.timezone.utc + ), + } + if parent_path: + out["name"] = "/".join([parent_path.rstrip("/"), stat.filename]) + return out + + def ls(self, path, detail=False): + logger.debug("Listing folder %s", path) + stats = [self._decode_stat(stat, path) for stat in self.ftp.listdir_iter(path)] + if detail: + return stats + else: + paths = [stat["name"] for stat in stats] + return sorted(paths) + + def put(self, lpath, rpath, callback=None, **kwargs): + logger.debug("Put file %s into %s", lpath, rpath) + self.ftp.put(lpath, rpath) + + def get_file(self, rpath, lpath, **kwargs): + if self.isdir(rpath): + os.makedirs(lpath, exist_ok=True) + else: + self.ftp.get(self._strip_protocol(rpath), lpath) + + def _open(self, path, mode="rb", block_size=None, **kwargs): + """ + block_size: int or None + If 0, no buffering, if 1, line buffering, if >1, buffer that many + bytes, if None use default from paramiko. + """ + logger.debug("Opening file %s", path) + if kwargs.get("autocommit", True) is False: + # writes to temporary file, move on commit + path2 = "/".join([self.temppath, str(uuid.uuid4())]) + f = self.ftp.open(path2, mode, bufsize=block_size if block_size else -1) + f.temppath = path2 + f.targetpath = path + f.fs = self + f.commit = types.MethodType(commit_a_file, f) + f.discard = types.MethodType(discard_a_file, f) + else: + f = self.ftp.open(path, mode, bufsize=block_size if block_size else -1) + return f + + def _rm(self, path): + if self.isdir(path): + self.ftp.rmdir(path) + else: + self.ftp.remove(path) + + def mv(self, old, new): + logger.debug("Renaming %s into %s", old, new) + self.ftp.posix_rename(old, new) + + +def commit_a_file(self): + self.fs.mv(self.temppath, self.targetpath) + + +def discard_a_file(self): + self.fs._rm(self.temppath) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/tar.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/tar.py new file mode 100644 index 0000000000000000000000000000000000000000..412e5ba4d2cdea7db090dc96412e697909a38d78 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/tar.py @@ -0,0 +1,124 @@ +import logging +import tarfile + +import fsspec +from fsspec.archive import AbstractArchiveFileSystem +from fsspec.compression import compr +from fsspec.utils import infer_compression + +typemap = {b"0": "file", b"5": "directory"} + +logger = logging.getLogger("tar") + + +class TarFileSystem(AbstractArchiveFileSystem): + """Compressed Tar archives as a file-system (read-only) + + Supports the following formats: + tar.gz, tar.bz2, tar.xz + """ + + root_marker = "" + protocol = "tar" + cachable = False + + def __init__( + self, + fo="", + index_store=None, + target_options=None, + target_protocol=None, + compression=None, + **kwargs, + ): + super().__init__(**kwargs) + target_options = target_options or {} + + if isinstance(fo, str): + self.of = fsspec.open(fo, protocol=target_protocol, **target_options) + fo = self.of.open() # keep the reference + + # Try to infer compression. + if compression is None: + name = None + + # Try different ways to get hold of the filename. `fo` might either + # be a `fsspec.LocalFileOpener`, an `io.BufferedReader` or an + # `fsspec.AbstractFileSystem` instance. + try: + # Amended io.BufferedReader or similar. + # This uses a "protocol extension" where original filenames are + # propagated to archive-like filesystems in order to let them + # infer the right compression appropriately. + if hasattr(fo, "original"): + name = fo.original + + # fsspec.LocalFileOpener + elif hasattr(fo, "path"): + name = fo.path + + # io.BufferedReader + elif hasattr(fo, "name"): + name = fo.name + + # fsspec.AbstractFileSystem + elif hasattr(fo, "info"): + name = fo.info()["name"] + + except Exception as ex: + logger.warning( + f"Unable to determine file name, not inferring compression: {ex}" + ) + + if name is not None: + compression = infer_compression(name) + logger.info(f"Inferred compression {compression} from file name {name}") + + if compression is not None: + # TODO: tarfile already implements compression with modes like "'r:gz'", + # but then would seek to offset in the file work? + fo = compr[compression](fo) + + self._fo_ref = fo + self.fo = fo # the whole instance is a context + self.tar = tarfile.TarFile(fileobj=self.fo) + self.dir_cache = None + + self.index_store = index_store + self.index = None + self._index() + + def _index(self): + # TODO: load and set saved index, if exists + out = {} + for ti in self.tar: + info = ti.get_info() + info["type"] = typemap.get(info["type"], "file") + name = ti.get_info()["name"].rstrip("/") + out[name] = (info, ti.offset_data) + + self.index = out + # TODO: save index to self.index_store here, if set + + def _get_dirs(self): + if self.dir_cache is not None: + return + + # This enables ls to get directories as children as well as files + self.dir_cache = { + dirname: {"name": dirname, "size": 0, "type": "directory"} + for dirname in self._all_dirnames(self.tar.getnames()) + } + for member in self.tar.getmembers(): + info = member.get_info() + info["name"] = info["name"].rstrip("/") + info["type"] = typemap.get(info["type"], "file") + self.dir_cache[info["name"]] = info + + def _open(self, path, mode="rb", **kwargs): + if mode != "rb": + raise ValueError("Read-only filesystem implementation") + details, offset = self.index[path] + if details["type"] != "file": + raise ValueError("Can only handle regular files") + return self.tar.extractfile(path) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/webhdfs.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/webhdfs.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6b901d19e5b9e268b36545fc1332f0f098de76 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec/implementations/webhdfs.py @@ -0,0 +1,486 @@ +# https://hadoop.apache.org/docs/r1.0.4/webhdfs.html + +import logging +import os +import secrets +import shutil +import tempfile +import uuid +from contextlib import suppress +from urllib.parse import quote + +import requests + +from ..spec import AbstractBufferedFile, AbstractFileSystem +from ..utils import infer_storage_options, tokenize + +logger = logging.getLogger("webhdfs") + + +class WebHDFS(AbstractFileSystem): + """ + Interface to HDFS over HTTP using the WebHDFS API. Supports also HttpFS gateways. + + Four auth mechanisms are supported: + + insecure: no auth is done, and the user is assumed to be whoever they + say they are (parameter ``user``), or a predefined value such as + "dr.who" if not given + spnego: when kerberos authentication is enabled, auth is negotiated by + requests_kerberos https://github.com/requests/requests-kerberos . + This establishes a session based on existing kinit login and/or + specified principal/password; parameters are passed with ``kerb_kwargs`` + token: uses an existing Hadoop delegation token from another secured + service. Indeed, this client can also generate such tokens when + not insecure. Note that tokens expire, but can be renewed (by a + previously specified user) and may allow for proxying. + basic-auth: used when both parameter ``user`` and parameter ``password`` + are provided. + + """ + + tempdir = str(tempfile.gettempdir()) + protocol = "webhdfs", "webHDFS" + + def __init__( + self, + host, + port=50070, + kerberos=False, + token=None, + user=None, + password=None, + proxy_to=None, + kerb_kwargs=None, + data_proxy=None, + use_https=False, + session_cert=None, + session_verify=True, + **kwargs, + ): + """ + Parameters + ---------- + host: str + Name-node address + port: int + Port for webHDFS + kerberos: bool + Whether to authenticate with kerberos for this connection + token: str or None + If given, use this token on every call to authenticate. A user + and user-proxy may be encoded in the token and should not be also + given + user: str or None + If given, assert the user name to connect with + password: str or None + If given, assert the password to use for basic auth. If password + is provided, user must be provided also + proxy_to: str or None + If given, the user has the authority to proxy, and this value is + the user in who's name actions are taken + kerb_kwargs: dict + Any extra arguments for HTTPKerberosAuth, see + ``_ + data_proxy: dict, callable or None + If given, map data-node addresses. This can be necessary if the + HDFS cluster is behind a proxy, running on Docker or otherwise has + a mismatch between the host-names given by the name-node and the + address by which to refer to them from the client. If a dict, + maps host names ``host->data_proxy[host]``; if a callable, full + URLs are passed, and function must conform to + ``url->data_proxy(url)``. + use_https: bool + Whether to connect to the Name-node using HTTPS instead of HTTP + session_cert: str or Tuple[str, str] or None + Path to a certificate file, or tuple of (cert, key) files to use + for the requests.Session + session_verify: str, bool or None + Path to a certificate file to use for verifying the requests.Session. + kwargs + """ + if self._cached: + return + super().__init__(**kwargs) + self.url = ( + f"{'https' if use_https else 'http'}://{host}:{port}/webhdfs/v1" # noqa + ) + self.kerb = kerberos + self.kerb_kwargs = kerb_kwargs or {} + self.pars = {} + self.proxy = data_proxy or {} + if token is not None: + if user is not None or proxy_to is not None: + raise ValueError( + "If passing a delegation token, must not set " + "user or proxy_to, as these are encoded in the" + " token" + ) + self.pars["delegation"] = token + self.user = user + self.password = password + + if password is not None: + if user is None: + raise ValueError( + "If passing a password, the user must also be" + "set in order to set up the basic-auth" + ) + else: + if user is not None: + self.pars["user.name"] = user + + if proxy_to is not None: + self.pars["doas"] = proxy_to + if kerberos and user is not None: + raise ValueError( + "If using Kerberos auth, do not specify the " + "user, this is handled by kinit." + ) + + self.session_cert = session_cert + self.session_verify = session_verify + + self._connect() + + self._fsid = f"webhdfs_{tokenize(host, port)}" + + @property + def fsid(self): + return self._fsid + + def _connect(self): + self.session = requests.Session() + + if self.session_cert: + self.session.cert = self.session_cert + + self.session.verify = self.session_verify + + if self.kerb: + from requests_kerberos import HTTPKerberosAuth + + self.session.auth = HTTPKerberosAuth(**self.kerb_kwargs) + + if self.user is not None and self.password is not None: + from requests.auth import HTTPBasicAuth + + self.session.auth = HTTPBasicAuth(self.user, self.password) + + def _call(self, op, method="get", path=None, data=None, redirect=True, **kwargs): + url = self._apply_proxy(self.url + quote(path or "", safe="/=")) + args = kwargs.copy() + args.update(self.pars) + args["op"] = op.upper() + logger.debug("sending %s with %s", url, method) + out = self.session.request( + method=method.upper(), + url=url, + params=args, + data=data, + allow_redirects=redirect, + ) + if out.status_code in [400, 401, 403, 404, 500]: + try: + err = out.json() + msg = err["RemoteException"]["message"] + exp = err["RemoteException"]["exception"] + except (ValueError, KeyError): + pass + else: + if exp in ["IllegalArgumentException", "UnsupportedOperationException"]: + raise ValueError(msg) + elif exp in ["SecurityException", "AccessControlException"]: + raise PermissionError(msg) + elif exp in ["FileNotFoundException"]: + raise FileNotFoundError(msg) + else: + raise RuntimeError(msg) + out.raise_for_status() + return out + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + replication=None, + permissions=None, + **kwargs, + ): + """ + + Parameters + ---------- + path: str + File location + mode: str + 'rb', 'wb', etc. + block_size: int + Client buffer size for read-ahead or write buffer + autocommit: bool + If False, writes to temporary file that only gets put in final + location upon commit + replication: int + Number of copies of file on the cluster, write mode only + permissions: str or int + posix permissions, write mode only + kwargs + + Returns + ------- + WebHDFile instance + """ + block_size = block_size or self.blocksize + return WebHDFile( + self, + path, + mode=mode, + block_size=block_size, + tempdir=self.tempdir, + autocommit=autocommit, + replication=replication, + permissions=permissions, + ) + + @staticmethod + def _process_info(info): + info["type"] = info["type"].lower() + info["size"] = info["length"] + return info + + @classmethod + def _strip_protocol(cls, path): + return infer_storage_options(path)["path"] + + @staticmethod + def _get_kwargs_from_urls(urlpath): + out = infer_storage_options(urlpath) + out.pop("path", None) + out.pop("protocol", None) + if "username" in out: + out["user"] = out.pop("username") + return out + + def info(self, path): + out = self._call("GETFILESTATUS", path=path) + info = out.json()["FileStatus"] + info["name"] = path + return self._process_info(info) + + def ls(self, path, detail=False): + out = self._call("LISTSTATUS", path=path) + infos = out.json()["FileStatuses"]["FileStatus"] + for info in infos: + self._process_info(info) + info["name"] = path.rstrip("/") + "/" + info["pathSuffix"] + if detail: + return sorted(infos, key=lambda i: i["name"]) + else: + return sorted(info["name"] for info in infos) + + def content_summary(self, path): + """Total numbers of files, directories and bytes under path""" + out = self._call("GETCONTENTSUMMARY", path=path) + return out.json()["ContentSummary"] + + def ukey(self, path): + """Checksum info of file, giving method and result""" + out = self._call("GETFILECHECKSUM", path=path, redirect=False) + if "Location" in out.headers: + location = self._apply_proxy(out.headers["Location"]) + out2 = self.session.get(location) + out2.raise_for_status() + return out2.json()["FileChecksum"] + else: + out.raise_for_status() + return out.json()["FileChecksum"] + + def home_directory(self): + """Get user's home directory""" + out = self._call("GETHOMEDIRECTORY") + return out.json()["Path"] + + def get_delegation_token(self, renewer=None): + """Retrieve token which can give the same authority to other uses + + Parameters + ---------- + renewer: str or None + User who may use this token; if None, will be current user + """ + if renewer: + out = self._call("GETDELEGATIONTOKEN", renewer=renewer) + else: + out = self._call("GETDELEGATIONTOKEN") + t = out.json()["Token"] + if t is None: + raise ValueError("No token available for this user/security context") + return t["urlString"] + + def renew_delegation_token(self, token): + """Make token live longer. Returns new expiry time""" + out = self._call("RENEWDELEGATIONTOKEN", method="put", token=token) + return out.json()["long"] + + def cancel_delegation_token(self, token): + """Stop the token from being useful""" + self._call("CANCELDELEGATIONTOKEN", method="put", token=token) + + def chmod(self, path, mod): + """Set the permission at path + + Parameters + ---------- + path: str + location to set (file or directory) + mod: str or int + posix epresentation or permission, give as oct string, e.g, '777' + or 0o777 + """ + self._call("SETPERMISSION", method="put", path=path, permission=mod) + + def chown(self, path, owner=None, group=None): + """Change owning user and/or group""" + kwargs = {} + if owner is not None: + kwargs["owner"] = owner + if group is not None: + kwargs["group"] = group + self._call("SETOWNER", method="put", path=path, **kwargs) + + def set_replication(self, path, replication): + """ + Set file replication factor + + Parameters + ---------- + path: str + File location (not for directories) + replication: int + Number of copies of file on the cluster. Should be smaller than + number of data nodes; normally 3 on most systems. + """ + self._call("SETREPLICATION", path=path, method="put", replication=replication) + + def mkdir(self, path, **kwargs): + self._call("MKDIRS", method="put", path=path) + + def makedirs(self, path, exist_ok=False): + if exist_ok is False and self.exists(path): + raise FileExistsError(path) + self.mkdir(path) + + def mv(self, path1, path2, **kwargs): + self._call("RENAME", method="put", path=path1, destination=path2) + + def rm(self, path, recursive=False, **kwargs): + self._call( + "DELETE", + method="delete", + path=path, + recursive="true" if recursive else "false", + ) + + def rm_file(self, path, **kwargs): + self.rm(path) + + def cp_file(self, lpath, rpath, **kwargs): + with self.open(lpath) as lstream: + tmp_fname = "/".join([self._parent(rpath), f".tmp.{secrets.token_hex(16)}"]) + # Perform an atomic copy (stream to a temporary file and + # move it to the actual destination). + try: + with self.open(tmp_fname, "wb") as rstream: + shutil.copyfileobj(lstream, rstream) + self.mv(tmp_fname, rpath) + except BaseException: # noqa + with suppress(FileNotFoundError): + self.rm(tmp_fname) + raise + + def _apply_proxy(self, location): + if self.proxy and callable(self.proxy): + location = self.proxy(location) + elif self.proxy: + # as a dict + for k, v in self.proxy.items(): + location = location.replace(k, v, 1) + return location + + +class WebHDFile(AbstractBufferedFile): + """A file living in HDFS over webHDFS""" + + def __init__(self, fs, path, **kwargs): + super().__init__(fs, path, **kwargs) + kwargs = kwargs.copy() + if kwargs.get("permissions", None) is None: + kwargs.pop("permissions", None) + if kwargs.get("replication", None) is None: + kwargs.pop("replication", None) + self.permissions = kwargs.pop("permissions", 511) + tempdir = kwargs.pop("tempdir") + if kwargs.pop("autocommit", False) is False: + self.target = self.path + self.path = os.path.join(tempdir, str(uuid.uuid4())) + + def _upload_chunk(self, final=False): + """Write one part of a multi-block file upload + + Parameters + ========== + final: bool + This is the last block, so should complete file, if + self.autocommit is True. + """ + out = self.fs.session.post( + self.location, + data=self.buffer.getvalue(), + headers={"content-type": "application/octet-stream"}, + ) + out.raise_for_status() + return True + + def _initiate_upload(self): + """Create remote file/upload""" + kwargs = self.kwargs.copy() + if "a" in self.mode: + op, method = "APPEND", "POST" + else: + op, method = "CREATE", "PUT" + kwargs["overwrite"] = "true" + out = self.fs._call(op, method, self.path, redirect=False, **kwargs) + location = self.fs._apply_proxy(out.headers["Location"]) + if "w" in self.mode: + # create empty file to append to + out2 = self.fs.session.put( + location, headers={"content-type": "application/octet-stream"} + ) + out2.raise_for_status() + # after creating empty file, change location to append to + out2 = self.fs._call("APPEND", "POST", self.path, redirect=False, **kwargs) + self.location = self.fs._apply_proxy(out2.headers["Location"]) + + def _fetch_range(self, start, end): + start = max(start, 0) + end = min(self.size, end) + if start >= end or start >= self.size: + return b"" + out = self.fs._call( + "OPEN", path=self.path, offset=start, length=end - start, redirect=False + ) + out.raise_for_status() + if "Location" in out.headers: + location = out.headers["Location"] + out2 = self.fs.session.get(self.fs._apply_proxy(location)) + return out2.content + else: + return out.content + + def commit(self): + self.fs.mv(self.path, self.target) + + def discard(self): + self.fs.rm(self.path) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/expintegrals.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/expintegrals.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04a7f64b414f3a5c2b6120507defa62d403c24bd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/expintegrals.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/qfunctions.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/qfunctions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3dd60d6637488de170785c7634feff51fe691a0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/qfunctions.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/signals.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/signals.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a5d9c9e7508dc4b2ce46aa929c518f2d4ab549b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/__pycache__/signals.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/qfunctions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/qfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..5a20e53a8b6fa0d8fbc9ad098614d2694998f49a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/functions/qfunctions.py @@ -0,0 +1,280 @@ +from .functions import defun, defun_wrapped + +@defun +def qp(ctx, a, q=None, n=None, **kwargs): + r""" + Evaluates the q-Pochhammer symbol (or q-rising factorial) + + .. math :: + + (a; q)_n = \prod_{k=0}^{n-1} (1-a q^k) + + where `n = \infty` is permitted if `|q| < 1`. Called with two arguments, + ``qp(a,q)`` computes `(a;q)_{\infty}`; with a single argument, ``qp(q)`` + computes `(q;q)_{\infty}`. The special case + + .. math :: + + \phi(q) = (q; q)_{\infty} = \prod_{k=1}^{\infty} (1-q^k) = + \sum_{k=-\infty}^{\infty} (-1)^k q^{(3k^2-k)/2} + + is also known as the Euler function, or (up to a factor `q^{-1/24}`) + the Dedekind eta function. + + **Examples** + + If `n` is a positive integer, the function amounts to a finite product:: + + >>> from mpmath import * + >>> mp.dps = 25; mp.pretty = True + >>> qp(2,3,5) + -725305.0 + >>> fprod(1-2*3**k for k in range(5)) + -725305.0 + >>> qp(2,3,0) + 1.0 + + Complex arguments are allowed:: + + >>> qp(2-1j, 0.75j) + (0.4628842231660149089976379 + 4.481821753552703090628793j) + + The regular Pochhammer symbol `(a)_n` is obtained in the + following limit as `q \to 1`:: + + >>> a, n = 4, 7 + >>> limit(lambda q: qp(q**a,q,n) / (1-q)**n, 1) + 604800.0 + >>> rf(a,n) + 604800.0 + + The Taylor series of the reciprocal Euler function gives + the partition function `P(n)`, i.e. the number of ways of writing + `n` as a sum of positive integers:: + + >>> taylor(lambda q: 1/qp(q), 0, 10) + [1.0, 1.0, 2.0, 3.0, 5.0, 7.0, 11.0, 15.0, 22.0, 30.0, 42.0] + + Special values include:: + + >>> qp(0) + 1.0 + >>> findroot(diffun(qp), -0.4) # location of maximum + -0.4112484791779547734440257 + >>> qp(_) + 1.228348867038575112586878 + + The q-Pochhammer symbol is related to the Jacobi theta functions. + For example, the following identity holds:: + + >>> q = mpf(0.5) # arbitrary + >>> qp(q) + 0.2887880950866024212788997 + >>> root(3,-2)*root(q,-24)*jtheta(2,pi/6,root(q,6)) + 0.2887880950866024212788997 + + """ + a = ctx.convert(a) + if n is None: + n = ctx.inf + else: + n = ctx.convert(n) + if n < 0: + raise ValueError("n cannot be negative") + if q is None: + q = a + else: + q = ctx.convert(q) + if n == 0: + return ctx.one + 0*(a+q) + infinite = (n == ctx.inf) + same = (a == q) + if infinite: + if abs(q) >= 1: + if same and (q == -1 or q == 1): + return ctx.zero * q + raise ValueError("q-function only defined for |q| < 1") + elif q == 0: + return ctx.one - a + maxterms = kwargs.get('maxterms', 50*ctx.prec) + if infinite and same: + # Euler's pentagonal theorem + def terms(): + t = 1 + yield t + k = 1 + x1 = q + x2 = q**2 + while 1: + yield (-1)**k * x1 + yield (-1)**k * x2 + x1 *= q**(3*k+1) + x2 *= q**(3*k+2) + k += 1 + if k > maxterms: + raise ctx.NoConvergence + return ctx.sum_accurately(terms) + # return ctx.nprod(lambda k: 1-a*q**k, [0,n-1]) + def factors(): + k = 0 + r = ctx.one + while 1: + yield 1 - a*r + r *= q + k += 1 + if k >= n: + return + if k > maxterms: + raise ctx.NoConvergence + return ctx.mul_accurately(factors) + +@defun_wrapped +def qgamma(ctx, z, q, **kwargs): + r""" + Evaluates the q-gamma function + + .. math :: + + \Gamma_q(z) = \frac{(q; q)_{\infty}}{(q^z; q)_{\infty}} (1-q)^{1-z}. + + + **Examples** + + Evaluation for real and complex arguments:: + + >>> from mpmath import * + >>> mp.dps = 25; mp.pretty = True + >>> qgamma(4,0.75) + 4.046875 + >>> qgamma(6,6) + 121226245.0 + >>> qgamma(3+4j, 0.5j) + (0.1663082382255199834630088 + 0.01952474576025952984418217j) + + The q-gamma function satisfies a functional equation similar + to that of the ordinary gamma function:: + + >>> q = mpf(0.25) + >>> z = mpf(2.5) + >>> qgamma(z+1,q) + 1.428277424823760954685912 + >>> (1-q**z)/(1-q)*qgamma(z,q) + 1.428277424823760954685912 + + """ + if abs(q) > 1: + return ctx.qgamma(z,1/q)*q**((z-2)*(z-1)*0.5) + return ctx.qp(q, q, None, **kwargs) / \ + ctx.qp(q**z, q, None, **kwargs) * (1-q)**(1-z) + +@defun_wrapped +def qfac(ctx, z, q, **kwargs): + r""" + Evaluates the q-factorial, + + .. math :: + + [n]_q! = (1+q)(1+q+q^2)\cdots(1+q+\cdots+q^{n-1}) + + or more generally + + .. math :: + + [z]_q! = \frac{(q;q)_z}{(1-q)^z}. + + **Examples** + + >>> from mpmath import * + >>> mp.dps = 25; mp.pretty = True + >>> qfac(0,0) + 1.0 + >>> qfac(4,3) + 2080.0 + >>> qfac(5,6) + 121226245.0 + >>> qfac(1+1j, 2+1j) + (0.4370556551322672478613695 + 0.2609739839216039203708921j) + + """ + if ctx.isint(z) and ctx._re(z) > 0: + n = int(ctx._re(z)) + return ctx.qp(q, q, n, **kwargs) / (1-q)**n + return ctx.qgamma(z+1, q, **kwargs) + +@defun +def qhyper(ctx, a_s, b_s, q, z, **kwargs): + r""" + Evaluates the basic hypergeometric series or hypergeometric q-series + + .. math :: + + \,_r\phi_s \left[\begin{matrix} + a_1 & a_2 & \ldots & a_r \\ + b_1 & b_2 & \ldots & b_s + \end{matrix} ; q,z \right] = + \sum_{n=0}^\infty + \frac{(a_1;q)_n, \ldots, (a_r;q)_n} + {(b_1;q)_n, \ldots, (b_s;q)_n} + \left((-1)^n q^{n\choose 2}\right)^{1+s-r} + \frac{z^n}{(q;q)_n} + + where `(a;q)_n` denotes the q-Pochhammer symbol (see :func:`~mpmath.qp`). + + **Examples** + + Evaluation works for real and complex arguments:: + + >>> from mpmath import * + >>> mp.dps = 25; mp.pretty = True + >>> qhyper([0.5], [2.25], 0.25, 4) + -0.1975849091263356009534385 + >>> qhyper([0.5], [2.25], 0.25-0.25j, 4) + (2.806330244925716649839237 + 3.568997623337943121769938j) + >>> qhyper([1+j], [2,3+0.5j], 0.25, 3+4j) + (9.112885171773400017270226 - 1.272756997166375050700388j) + + Comparing with a summation of the defining series, using + :func:`~mpmath.nsum`:: + + >>> b, q, z = 3, 0.25, 0.5 + >>> qhyper([], [b], q, z) + 0.6221136748254495583228324 + >>> nsum(lambda n: z**n / qp(q,q,n)/qp(b,q,n) * q**(n*(n-1)), [0,inf]) + 0.6221136748254495583228324 + + """ + #a_s = [ctx._convert_param(a)[0] for a in a_s] + #b_s = [ctx._convert_param(b)[0] for b in b_s] + #q = ctx._convert_param(q)[0] + a_s = [ctx.convert(a) for a in a_s] + b_s = [ctx.convert(b) for b in b_s] + q = ctx.convert(q) + z = ctx.convert(z) + r = len(a_s) + s = len(b_s) + d = 1+s-r + maxterms = kwargs.get('maxterms', 50*ctx.prec) + def terms(): + t = ctx.one + yield t + qk = 1 + k = 0 + x = 1 + while 1: + for a in a_s: + p = 1 - a*qk + t *= p + for b in b_s: + p = 1 - b*qk + if not p: + raise ValueError + t /= p + t *= z + x *= (-1)**d * qk ** d + qk *= q + t /= (1 - qk) + k += 1 + yield t * x + if k > maxterms: + raise ctx.NoConvergence + return ctx.sum_accurately(terms) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/extratest_gamma.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/extratest_gamma.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcf06e8645e06f5437ee64c40c0140414e23d9a1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/extratest_gamma.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_calculus.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_calculus.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53ebef8320cffbb4a6892d0ba2b95c90658bb0e6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_calculus.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_eigen.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_eigen.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b59e5ac90177b82620705556a035cfbc74c2220e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_eigen.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_power.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_power.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a5c773efd072bb87d94812ad03ac66c77cce618 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_power.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_str.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_str.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f98f391f70f2423b00bc3143c29a8339fac08702 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_str.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/extratest_zeta.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/extratest_zeta.py new file mode 100644 index 0000000000000000000000000000000000000000..582b3d9cbd956b9cdf94309e0e718371fe716101 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/extratest_zeta.py @@ -0,0 +1,30 @@ +from mpmath import zetazero +from timeit import default_timer as clock + +def test_zetazero(): + cases = [\ + (399999999, 156762524.6750591511), + (241389216, 97490234.2276711795), + (526196239, 202950727.691229534), + (542964976, 209039046.578535272), + (1048449112, 388858885.231056486), + (1048449113, 388858885.384337406), + (1048449114, 388858886.002285122), + (1048449115, 388858886.00239369), + (1048449116, 388858886.690745053) + ] + for n, v in cases: + print(n, v) + t1 = clock() + ok = zetazero(n).ae(complex(0.5,v)) + t2 = clock() + print("ok =", ok, ("(time = %s)" % round(t2-t1,3))) + print("Now computing two huge zeros (this may take hours)") + print("Computing zetazero(8637740722917)") + ok = zetazero(8637740722917).ae(complex(0.5,2124447368584.39296466152)) + print("ok =", ok) + ok = zetazero(8637740722918).ae(complex(0.5,2124447368584.39298170604)) + print("ok =", ok) + +if __name__ == "__main__": + test_zetazero() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/runtests.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/runtests.py new file mode 100644 index 0000000000000000000000000000000000000000..70fde272fdc0e05e3d8951edddca380bd36139ab --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/runtests.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python + +""" +python runtests.py -py + Use py.test to run tests (more useful for debugging) + +python runtests.py -coverage + Generate test coverage report. Statistics are written to /tmp + +python runtests.py -profile + Generate profile stats (this is much slower) + +python runtests.py -nogmpy + Run tests without using GMPY even if it exists + +python runtests.py -strict + Enforce extra tests in normalize() + +python runtests.py -local + Insert '../..' at the beginning of sys.path to use local mpmath + +python runtests.py -skip ... + Skip tests from the listed modules + +Additional arguments are used to filter the tests to run. Only files that have +one of the arguments in their name are executed. + +""" + +import sys, os, traceback + +profile = False +if "-profile" in sys.argv: + sys.argv.remove('-profile') + profile = True + +coverage = False +if "-coverage" in sys.argv: + sys.argv.remove('-coverage') + coverage = True + +if "-nogmpy" in sys.argv: + sys.argv.remove('-nogmpy') + os.environ['MPMATH_NOGMPY'] = 'Y' + +if "-strict" in sys.argv: + sys.argv.remove('-strict') + os.environ['MPMATH_STRICT'] = 'Y' + +if "-local" in sys.argv: + sys.argv.remove('-local') + importdir = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), + '../..')) +else: + importdir = '' + +# TODO: add a flag for this +testdir = '' + +def testit(importdir='', testdir='', exit_on_fail=False): + """Run all tests in testdir while importing from importdir.""" + if importdir: + sys.path.insert(1, importdir) + if testdir: + sys.path.insert(1, testdir) + import os.path + import mpmath + print("mpmath imported from %s" % os.path.dirname(mpmath.__file__)) + print("mpmath backend: %s" % mpmath.libmp.backend.BACKEND) + print("mpmath mp class: %s" % repr(mpmath.mp)) + print("mpmath version: %s" % mpmath.__version__) + print("Python version: %s" % sys.version) + print("") + if "-py" in sys.argv: + sys.argv.remove('-py') + import py + py.test.cmdline.main() + else: + import glob + from timeit import default_timer as clock + modules = [] + args = sys.argv[1:] + excluded = [] + if '-skip' in args: + excluded = args[args.index('-skip')+1:] + args = args[:args.index('-skip')] + # search for tests in directory of this file if not otherwise specified + if not testdir: + pattern = os.path.dirname(sys.argv[0]) + else: + pattern = testdir + if pattern: + pattern += '/' + pattern += 'test*.py' + # look for tests (respecting specified filter) + for f in glob.glob(pattern): + name = os.path.splitext(os.path.basename(f))[0] + # If run as a script, only run tests given as args, if any are given + if args and __name__ == "__main__": + ok = False + for arg in args: + if arg in name: + ok = True + break + if not ok: + continue + elif name in excluded: + continue + module = __import__(name) + priority = module.__dict__.get('priority', 100) + if priority == 666: + modules = [[priority, name, module]] + break + modules.append([priority, name, module]) + # execute tests + modules.sort() + tstart = clock() + for priority, name, module in modules: + print(name) + for f in sorted(module.__dict__.keys()): + if f.startswith('test_'): + if coverage and ('numpy' in f): + continue + sys.stdout.write(" " + f[5:].ljust(25) + " ") + t1 = clock() + try: + module.__dict__[f]() + except: + etype, evalue, trb = sys.exc_info() + if etype in (KeyboardInterrupt, SystemExit): + raise + print("") + print("TEST FAILED!") + print("") + traceback.print_exc() + if exit_on_fail: + return + t2 = clock() + print("ok " + " " + ("%.7f" % (t2-t1)) + " s") + tend = clock() + print("") + print("finished tests in " + ("%.2f" % (tend-tstart)) + " seconds") + # clean sys.path + if importdir: + sys.path.remove(importdir) + if testdir: + sys.path.remove(testdir) + +if __name__ == '__main__': + if profile: + import cProfile + cProfile.run("testit('%s', '%s')" % (importdir, testdir), sort=1) + elif coverage: + import trace + tracer = trace.Trace(ignoredirs=[sys.prefix, sys.exec_prefix], + trace=0, count=1) + tracer.run('testit(importdir, testdir)') + r = tracer.results() + r.write_results(show_missing=True, summary=True, coverdir="/tmp") + else: + testit(importdir, testdir) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_functions2.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_functions2.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2d57fcec9be0db4d921b013f24fd6a5e0e9930 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_functions2.py @@ -0,0 +1,2384 @@ +import math +import pytest +from mpmath import * + +def test_bessel(): + mp.dps = 15 + assert j0(1).ae(0.765197686557966551) + assert j0(pi).ae(-0.304242177644093864) + assert j0(1000).ae(0.0247866861524201746) + assert j0(-25).ae(0.0962667832759581162) + assert j1(1).ae(0.440050585744933516) + assert j1(pi).ae(0.284615343179752757) + assert j1(1000).ae(0.00472831190708952392) + assert j1(-25).ae(0.125350249580289905) + assert besselj(5,1).ae(0.000249757730211234431) + assert besselj(5+0j,1).ae(0.000249757730211234431) + assert besselj(5,pi).ae(0.0521411843671184747) + assert besselj(5,1000).ae(0.00502540694523318607) + assert besselj(5,-25).ae(0.0660079953984229934) + assert besselj(-3,2).ae(-0.128943249474402051) + assert besselj(-4,2).ae(0.0339957198075684341) + assert besselj(3,3+2j).ae(0.424718794929639595942 + 0.625665327745785804812j) + assert besselj(0.25,4).ae(-0.374760630804249715) + assert besselj(1+2j,3+4j).ae(0.319247428741872131 - 0.669557748880365678j) + assert (besselj(3, 10**10) * 10**5).ae(0.76765081748139204023) + assert bessely(-0.5, 0) == 0 + assert bessely(0.5, 0) == -inf + assert bessely(1.5, 0) == -inf + assert bessely(0,0) == -inf + assert bessely(-0.4, 0) == -inf + assert bessely(-0.6, 0) == inf + assert bessely(-1, 0) == inf + assert bessely(-1.4, 0) == inf + assert bessely(-1.6, 0) == -inf + assert bessely(-1, 0) == inf + assert bessely(-2, 0) == -inf + assert bessely(-3, 0) == inf + assert bessely(0.5, 0) == -inf + assert bessely(1, 0) == -inf + assert bessely(1.5, 0) == -inf + assert bessely(2, 0) == -inf + assert bessely(2.5, 0) == -inf + assert bessely(3, 0) == -inf + assert bessely(0,0.5).ae(-0.44451873350670655715) + assert bessely(1,0.5).ae(-1.4714723926702430692) + assert bessely(-1,0.5).ae(1.4714723926702430692) + assert bessely(3.5,0.5).ae(-138.86400867242488443) + assert bessely(0,3+4j).ae(4.6047596915010138655-8.8110771408232264208j) + assert bessely(0,j).ae(-0.26803248203398854876+1.26606587775200833560j) + assert (bessely(3, 10**10) * 10**5).ae(0.21755917537013204058) + assert besseli(0,0) == 1 + assert besseli(1,0) == 0 + assert besseli(2,0) == 0 + assert besseli(-1,0) == 0 + assert besseli(-2,0) == 0 + assert besseli(0,0.5).ae(1.0634833707413235193) + assert besseli(1,0.5).ae(0.25789430539089631636) + assert besseli(-1,0.5).ae(0.25789430539089631636) + assert besseli(3.5,0.5).ae(0.00068103597085793815863) + assert besseli(0,3+4j).ae(-3.3924877882755196097-1.3239458916287264815j) + assert besseli(0,j).ae(besselj(0,1)) + assert (besseli(3, 10**10) * mpf(10)**(-4342944813)).ae(4.2996028505491271875) + assert besselk(0,0) == inf + assert besselk(1,0) == inf + assert besselk(2,0) == inf + assert besselk(-1,0) == inf + assert besselk(-2,0) == inf + assert besselk(0,0.5).ae(0.92441907122766586178) + assert besselk(1,0.5).ae(1.6564411200033008937) + assert besselk(-1,0.5).ae(1.6564411200033008937) + assert besselk(3.5,0.5).ae(207.48418747548460607) + assert besselk(0,3+4j).ae(-0.007239051213570155013+0.026510418350267677215j) + assert besselk(0,j).ae(-0.13863371520405399968-1.20196971531720649914j) + assert (besselk(3, 10**10) * mpf(10)**4342944824).ae(1.1628981033356187851) + # test for issue 331, bug reported by Michael Hartmann + for n in range(10,100,10): + mp.dps = n + assert besseli(91.5,24.7708).ae("4.00830632138673963619656140653537080438462342928377020695738635559218797348548092636896796324190271316137982810144874264e-41") + +def test_bessel_zeros(): + mp.dps = 15 + assert besseljzero(0,1).ae(2.40482555769577276869) + assert besseljzero(2,1).ae(5.1356223018406825563) + assert besseljzero(1,50).ae(157.86265540193029781) + assert besseljzero(10,1).ae(14.475500686554541220) + assert besseljzero(0.5,3).ae(9.4247779607693797153) + assert besseljzero(2,1,1).ae(3.0542369282271403228) + assert besselyzero(0,1).ae(0.89357696627916752158) + assert besselyzero(2,1).ae(3.3842417671495934727) + assert besselyzero(1,50).ae(156.29183520147840108) + assert besselyzero(10,1).ae(12.128927704415439387) + assert besselyzero(0.5,3).ae(7.8539816339744830962) + assert besselyzero(2,1,1).ae(5.0025829314460639452) + +def test_hankel(): + mp.dps = 15 + assert hankel1(0,0.5).ae(0.93846980724081290423-0.44451873350670655715j) + assert hankel1(1,0.5).ae(0.2422684576748738864-1.4714723926702430692j) + assert hankel1(-1,0.5).ae(-0.2422684576748738864+1.4714723926702430692j) + assert hankel1(1.5,0.5).ae(0.0917016996256513026-2.5214655504213378514j) + assert hankel1(1.5,3+4j).ae(0.0066806866476728165382-0.0036684231610839127106j) + assert hankel2(0,0.5).ae(0.93846980724081290423+0.44451873350670655715j) + assert hankel2(1,0.5).ae(0.2422684576748738864+1.4714723926702430692j) + assert hankel2(-1,0.5).ae(-0.2422684576748738864-1.4714723926702430692j) + assert hankel2(1.5,0.5).ae(0.0917016996256513026+2.5214655504213378514j) + assert hankel2(1.5,3+4j).ae(14.783528526098567526-7.397390270853446512j) + +def test_struve(): + mp.dps = 15 + assert struveh(2,3).ae(0.74238666967748318564) + assert struveh(-2.5,3).ae(0.41271003220971599344) + assert struvel(2,3).ae(1.7476573277362782744) + assert struvel(-2.5,3).ae(1.5153394466819651377) + +def test_whittaker(): + mp.dps = 15 + assert whitm(2,3,4).ae(49.753745589025246591) + assert whitw(2,3,4).ae(14.111656223052932215) + +def test_kelvin(): + mp.dps = 15 + assert ber(2,3).ae(0.80836846563726819091) + assert ber(3,4).ae(-0.28262680167242600233) + assert ber(-3,2).ae(-0.085611448496796363669) + assert bei(2,3).ae(-0.89102236377977331571) + assert bei(-3,2).ae(-0.14420994155731828415) + assert ker(2,3).ae(0.12839126695733458928) + assert ker(-3,2).ae(-0.29802153400559142783) + assert ker(0.5,3).ae(-0.085662378535217097524) + assert kei(2,3).ae(0.036804426134164634000) + assert kei(-3,2).ae(0.88682069845786731114) + assert kei(0.5,3).ae(0.013633041571314302948) + +def test_hyper_misc(): + mp.dps = 15 + assert hyp0f1(1,0) == 1 + assert hyp1f1(1,2,0) == 1 + assert hyp1f2(1,2,3,0) == 1 + assert hyp2f1(1,2,3,0) == 1 + assert hyp2f2(1,2,3,4,0) == 1 + assert hyp2f3(1,2,3,4,5,0) == 1 + # Degenerate case: 0F0 + assert hyper([],[],0) == 1 + assert hyper([],[],-2).ae(exp(-2)) + # Degenerate case: 1F0 + assert hyper([2],[],1.5) == 4 + # + assert hyp2f1((1,3),(2,3),(5,6),mpf(27)/32).ae(1.6) + assert hyp2f1((1,4),(1,2),(3,4),mpf(80)/81).ae(1.8) + assert hyp2f1((2,3),(1,1),(3,2),(2+j)/3).ae(1.327531603558679093+0.439585080092769253j) + mp.dps = 25 + v = mpc('1.2282306665029814734863026', '-0.1225033830118305184672133') + assert hyper([(3,4),2+j,1],[1,5,j/3],mpf(1)/5+j/8).ae(v) + mp.dps = 15 + +def test_elliptic_integrals(): + mp.dps = 15 + assert ellipk(0).ae(pi/2) + assert ellipk(0.5).ae(gamma(0.25)**2/(4*sqrt(pi))) + assert ellipk(1) == inf + assert ellipk(1+0j) == inf + assert ellipk(-1).ae('1.3110287771460599052') + assert ellipk(-2).ae('1.1714200841467698589') + assert isinstance(ellipk(-2), mpf) + assert isinstance(ellipe(-2), mpf) + assert ellipk(-50).ae('0.47103424540873331679') + mp.dps = 30 + n1 = +fraction(99999,100000) + n2 = +fraction(100001,100000) + mp.dps = 15 + assert ellipk(n1).ae('7.1427724505817781901') + assert ellipk(n2).ae(mpc('7.1427417367963090109', '-1.5707923998261688019')) + assert ellipe(n1).ae('1.0000332138990829170') + v = ellipe(n2) + assert v.real.ae('0.999966786328145474069137') + assert (v.imag*10**6).ae('7.853952181727432') + assert ellipk(2).ae(mpc('1.3110287771460599052', '-1.3110287771460599052')) + assert ellipk(50).ae(mpc('0.22326753950210985451', '-0.47434723226254522087')) + assert ellipk(3+4j).ae(mpc('0.91119556380496500866', '0.63133428324134524388')) + assert ellipk(3-4j).ae(mpc('0.91119556380496500866', '-0.63133428324134524388')) + assert ellipk(-3+4j).ae(mpc('0.95357894880405122483', '0.23093044503746114444')) + assert ellipk(-3-4j).ae(mpc('0.95357894880405122483', '-0.23093044503746114444')) + assert isnan(ellipk(nan)) + assert isnan(ellipe(nan)) + assert ellipk(inf) == 0 + assert isinstance(ellipk(inf), mpc) + assert ellipk(-inf) == 0 + assert ellipk(1+0j) == inf + assert ellipe(0).ae(pi/2) + assert ellipe(0.5).ae(pi**(mpf(3)/2)/gamma(0.25)**2 +gamma(0.25)**2/(8*sqrt(pi))) + assert ellipe(1) == 1 + assert ellipe(1+0j) == 1 + assert ellipe(inf) == mpc(0,inf) + assert ellipe(-inf) == inf + assert ellipe(3+4j).ae(1.4995535209333469543-1.5778790079127582745j) + assert ellipe(3-4j).ae(1.4995535209333469543+1.5778790079127582745j) + assert ellipe(-3+4j).ae(2.5804237855343377803-0.8306096791000413778j) + assert ellipe(-3-4j).ae(2.5804237855343377803+0.8306096791000413778j) + assert ellipe(2).ae(0.59907011736779610372+0.59907011736779610372j) + assert ellipe('1e-1000000000').ae(pi/2) + assert ellipk('1e-1000000000').ae(pi/2) + assert ellipe(-pi).ae(2.4535865983838923) + mp.dps = 50 + assert ellipk(1/pi).ae('1.724756270009501831744438120951614673874904182624739673') + assert ellipe(1/pi).ae('1.437129808135123030101542922290970050337425479058225712') + assert ellipk(-10*pi).ae('0.5519067523886233967683646782286965823151896970015484512') + assert ellipe(-10*pi).ae('5.926192483740483797854383268707108012328213431657645509') + v = ellipk(pi) + assert v.real.ae('0.973089521698042334840454592642137667227167622330325225') + assert v.imag.ae('-1.156151296372835303836814390793087600271609993858798016') + v = ellipe(pi) + assert v.real.ae('0.4632848917264710404078033487934663562998345622611263332') + assert v.imag.ae('1.0637961621753130852473300451583414489944099504180510966') + mp.dps = 15 + +def test_exp_integrals(): + mp.dps = 15 + x = +e + z = e + sqrt(3)*j + assert ei(x).ae(8.21168165538361560) + assert li(x).ae(1.89511781635593676) + assert si(x).ae(1.82104026914756705) + assert ci(x).ae(0.213958001340379779) + assert shi(x).ae(4.11520706247846193) + assert chi(x).ae(4.09647459290515367) + assert fresnels(x).ae(0.437189718149787643) + assert fresnelc(x).ae(0.401777759590243012) + assert airyai(x).ae(0.0108502401568586681) + assert airybi(x).ae(8.98245748585468627) + assert ei(z).ae(3.72597969491314951 + 7.34213212314224421j) + assert li(z).ae(2.28662658112562502 + 1.50427225297269364j) + assert si(z).ae(2.48122029237669054 + 0.12684703275254834j) + assert ci(z).ae(0.169255590269456633 - 0.892020751420780353j) + assert shi(z).ae(1.85810366559344468 + 3.66435842914920263j) + assert chi(z).ae(1.86787602931970484 + 3.67777369399304159j) + assert fresnels(z/3).ae(0.034534397197008182 + 0.754859844188218737j) + assert fresnelc(z/3).ae(1.261581645990027372 + 0.417949198775061893j) + assert airyai(z).ae(-0.0162552579839056062 - 0.0018045715700210556j) + assert airybi(z).ae(-4.98856113282883371 + 2.08558537872180623j) + assert li(0) == 0.0 + assert li(1) == -inf + assert li(inf) == inf + assert isinstance(li(0.7), mpf) + assert si(inf).ae(pi/2) + assert si(-inf).ae(-pi/2) + assert ci(inf) == 0 + assert ci(0) == -inf + assert isinstance(ei(-0.7), mpf) + assert airyai(inf) == 0 + assert airybi(inf) == inf + assert airyai(-inf) == 0 + assert airybi(-inf) == 0 + assert fresnels(inf) == 0.5 + assert fresnelc(inf) == 0.5 + assert fresnels(-inf) == -0.5 + assert fresnelc(-inf) == -0.5 + assert shi(0) == 0 + assert shi(inf) == inf + assert shi(-inf) == -inf + assert chi(0) == -inf + assert chi(inf) == inf + +def test_ei(): + mp.dps = 15 + assert ei(0) == -inf + assert ei(inf) == inf + assert ei(-inf) == -0.0 + assert ei(20+70j).ae(6.1041351911152984397e6 - 2.7324109310519928872e6j) + # tests for the asymptotic expansion + # values checked with Mathematica ExpIntegralEi + mp.dps = 50 + r = ei(20000) + s = '3.8781962825045010930273870085501819470698476975019e+8681' + assert str(r) == s + r = ei(-200) + s = '-6.8852261063076355977108174824557929738368086933303e-90' + assert str(r) == s + r =ei(20000 + 10*j) + sre = '-3.255138234032069402493850638874410725961401274106e+8681' + sim = '-2.1081929993474403520785942429469187647767369645423e+8681' + assert str(r.real) == sre and str(r.imag) == sim + mp.dps = 15 + # More asymptotic expansions + assert chi(-10**6+100j).ae('1.3077239389562548386e+434288 + 7.6808956999707408158e+434287j') + assert shi(-10**6+100j).ae('-1.3077239389562548386e+434288 - 7.6808956999707408158e+434287j') + mp.dps = 15 + assert ei(10j).ae(-0.0454564330044553726+3.2291439210137706686j) + assert ei(100j).ae(-0.0051488251426104921+3.1330217936839529126j) + u = ei(fmul(10**20, j, exact=True)) + assert u.real.ae(-6.4525128526578084421345e-21, abs_eps=0, rel_eps=8*eps) + assert u.imag.ae(pi) + assert ei(-10j).ae(-0.0454564330044553726-3.2291439210137706686j) + assert ei(-100j).ae(-0.0051488251426104921-3.1330217936839529126j) + u = ei(fmul(-10**20, j, exact=True)) + assert u.real.ae(-6.4525128526578084421345e-21, abs_eps=0, rel_eps=8*eps) + assert u.imag.ae(-pi) + assert ei(10+10j).ae(-1576.1504265768517448+436.9192317011328140j) + u = ei(-10+10j) + assert u.real.ae(7.6698978415553488362543e-7, abs_eps=0, rel_eps=8*eps) + assert u.imag.ae(3.141595611735621062025) + +def test_e1(): + mp.dps = 15 + assert e1(0) == inf + assert e1(inf) == 0 + assert e1(-inf) == mpc(-inf, -pi) + assert e1(10j).ae(0.045456433004455372635 + 0.087551267423977430100j) + assert e1(100j).ae(0.0051488251426104921444 - 0.0085708599058403258790j) + assert e1(fmul(10**20, j, exact=True)).ae(6.4525128526578084421e-21 - 7.6397040444172830039e-21j, abs_eps=0, rel_eps=8*eps) + assert e1(-10j).ae(0.045456433004455372635 - 0.087551267423977430100j) + assert e1(-100j).ae(0.0051488251426104921444 + 0.0085708599058403258790j) + assert e1(fmul(-10**20, j, exact=True)).ae(6.4525128526578084421e-21 + 7.6397040444172830039e-21j, abs_eps=0, rel_eps=8*eps) + +def test_expint(): + mp.dps = 15 + assert expint(0,0) == inf + assert expint(0,1).ae(1/e) + assert expint(0,1.5).ae(2/exp(1.5)/3) + assert expint(1,1).ae(-ei(-1)) + assert expint(2,0).ae(1) + assert expint(3,0).ae(1/2.) + assert expint(4,0).ae(1/3.) + assert expint(-2, 0.5).ae(26/sqrt(e)) + assert expint(-1,-1) == 0 + assert expint(-2,-1).ae(-e) + assert expint(5.5, 0).ae(2/9.) + assert expint(2.00000001,0).ae(100000000./100000001) + assert expint(2+3j,4-j).ae(0.0023461179581675065414+0.0020395540604713669262j) + assert expint('1.01', '1e-1000').ae(99.9999999899412802) + assert expint('1.000000000001', 3.5).ae(0.00697013985754701819446) + assert expint(2,3).ae(3*ei(-3)+exp(-3)) + assert (expint(10,20)*10**10).ae(0.694439055541231353) + assert expint(3,inf) == 0 + assert expint(3.2,inf) == 0 + assert expint(3.2+2j,inf) == 0 + assert expint(1,3j).ae(-0.11962978600800032763 + 0.27785620120457163717j) + assert expint(1,3).ae(0.013048381094197037413) + assert expint(1,-3).ae(-ei(3)-pi*j) + #assert expint(3) == expint(1,3) + assert expint(1,-20).ae(-25615652.66405658882 - 3.1415926535897932385j) + assert expint(1000000,0).ae(1./999999) + assert expint(0,2+3j).ae(-0.025019798357114678171 + 0.027980439405104419040j) + assert expint(-1,2+3j).ae(-0.022411973626262070419 + 0.038058922011377716932j) + assert expint(-1.5,0) == inf + +def test_trig_integrals(): + mp.dps = 30 + assert si(mpf(1)/1000000).ae('0.000000999999999999944444444444446111') + assert ci(mpf(1)/1000000).ae('-13.2382948930629912435014366276') + assert si(10**10).ae('1.5707963267075846569685111517747537') + assert ci(10**10).ae('-4.87506025174822653785729773959e-11') + assert si(10**100).ae(pi/2) + assert (ci(10**100)*10**100).ae('-0.372376123661276688262086695553') + assert si(-3) == -si(3) + assert ci(-3).ae(ci(3) + pi*j) + # Test complex structure + mp.dps = 15 + assert mp.ci(50).ae(-0.0056283863241163054402) + assert mp.ci(50+2j).ae(-0.018378282946133067149+0.070352808023688336193j) + assert mp.ci(20j).ae(1.28078263320282943611e7+1.5707963267949j) + assert mp.ci(-2+20j).ae(-4.050116856873293505e6+1.207476188206989909e7j) + assert mp.ci(-50+2j).ae(-0.0183782829461330671+3.0712398455661049023j) + assert mp.ci(-50).ae(-0.0056283863241163054+3.1415926535897932385j) + assert mp.ci(-50-2j).ae(-0.0183782829461330671-3.0712398455661049023j) + assert mp.ci(-2-20j).ae(-4.050116856873293505e6-1.207476188206989909e7j) + assert mp.ci(-20j).ae(1.28078263320282943611e7-1.5707963267949j) + assert mp.ci(50-2j).ae(-0.018378282946133067149-0.070352808023688336193j) + assert mp.si(50).ae(1.5516170724859358947) + assert mp.si(50+2j).ae(1.497884414277228461-0.017515007378437448j) + assert mp.si(20j).ae(1.2807826332028294459e7j) + assert mp.si(-2+20j).ae(-1.20747603112735722103e7-4.050116856873293554e6j) + assert mp.si(-50+2j).ae(-1.497884414277228461-0.017515007378437448j) + assert mp.si(-50).ae(-1.5516170724859358947) + assert mp.si(-50-2j).ae(-1.497884414277228461+0.017515007378437448j) + assert mp.si(-2-20j).ae(-1.20747603112735722103e7+4.050116856873293554e6j) + assert mp.si(-20j).ae(-1.2807826332028294459e7j) + assert mp.si(50-2j).ae(1.497884414277228461+0.017515007378437448j) + assert mp.chi(50j).ae(-0.0056283863241163054+1.5707963267948966192j) + assert mp.chi(-2+50j).ae(-0.0183782829461330671+1.6411491348185849554j) + assert mp.chi(-20).ae(1.28078263320282943611e7+3.1415926535898j) + assert mp.chi(-20-2j).ae(-4.050116856873293505e6+1.20747571696809187053e7j) + assert mp.chi(-2-50j).ae(-0.0183782829461330671-1.6411491348185849554j) + assert mp.chi(-50j).ae(-0.0056283863241163054-1.5707963267948966192j) + assert mp.chi(2-50j).ae(-0.0183782829461330671-1.500443518771208283j) + assert mp.chi(20-2j).ae(-4.050116856873293505e6-1.20747603112735722951e7j) + assert mp.chi(20).ae(1.2807826332028294361e7) + assert mp.chi(2+50j).ae(-0.0183782829461330671+1.500443518771208283j) + assert mp.shi(50j).ae(1.5516170724859358947j) + assert mp.shi(-2+50j).ae(0.017515007378437448+1.497884414277228461j) + assert mp.shi(-20).ae(-1.2807826332028294459e7) + assert mp.shi(-20-2j).ae(4.050116856873293554e6-1.20747603112735722103e7j) + assert mp.shi(-2-50j).ae(0.017515007378437448-1.497884414277228461j) + assert mp.shi(-50j).ae(-1.5516170724859358947j) + assert mp.shi(2-50j).ae(-0.017515007378437448-1.497884414277228461j) + assert mp.shi(20-2j).ae(-4.050116856873293554e6-1.20747603112735722103e7j) + assert mp.shi(20).ae(1.2807826332028294459e7) + assert mp.shi(2+50j).ae(-0.017515007378437448+1.497884414277228461j) + def ae(x,y,tol=1e-12): + return abs(x-y) <= abs(y)*tol + assert fp.ci(fp.inf) == 0 + assert ae(fp.ci(fp.ninf), fp.pi*1j) + assert ae(fp.si(fp.inf), fp.pi/2) + assert ae(fp.si(fp.ninf), -fp.pi/2) + assert fp.si(0) == 0 + assert ae(fp.ci(50), -0.0056283863241163054402) + assert ae(fp.ci(50+2j), -0.018378282946133067149+0.070352808023688336193j) + assert ae(fp.ci(20j), 1.28078263320282943611e7+1.5707963267949j) + assert ae(fp.ci(-2+20j), -4.050116856873293505e6+1.207476188206989909e7j) + assert ae(fp.ci(-50+2j), -0.0183782829461330671+3.0712398455661049023j) + assert ae(fp.ci(-50), -0.0056283863241163054+3.1415926535897932385j) + assert ae(fp.ci(-50-2j), -0.0183782829461330671-3.0712398455661049023j) + assert ae(fp.ci(-2-20j), -4.050116856873293505e6-1.207476188206989909e7j) + assert ae(fp.ci(-20j), 1.28078263320282943611e7-1.5707963267949j) + assert ae(fp.ci(50-2j), -0.018378282946133067149-0.070352808023688336193j) + assert ae(fp.si(50), 1.5516170724859358947) + assert ae(fp.si(50+2j), 1.497884414277228461-0.017515007378437448j) + assert ae(fp.si(20j), 1.2807826332028294459e7j) + assert ae(fp.si(-2+20j), -1.20747603112735722103e7-4.050116856873293554e6j) + assert ae(fp.si(-50+2j), -1.497884414277228461-0.017515007378437448j) + assert ae(fp.si(-50), -1.5516170724859358947) + assert ae(fp.si(-50-2j), -1.497884414277228461+0.017515007378437448j) + assert ae(fp.si(-2-20j), -1.20747603112735722103e7+4.050116856873293554e6j) + assert ae(fp.si(-20j), -1.2807826332028294459e7j) + assert ae(fp.si(50-2j), 1.497884414277228461+0.017515007378437448j) + assert ae(fp.chi(50j), -0.0056283863241163054+1.5707963267948966192j) + assert ae(fp.chi(-2+50j), -0.0183782829461330671+1.6411491348185849554j) + assert ae(fp.chi(-20), 1.28078263320282943611e7+3.1415926535898j) + assert ae(fp.chi(-20-2j), -4.050116856873293505e6+1.20747571696809187053e7j) + assert ae(fp.chi(-2-50j), -0.0183782829461330671-1.6411491348185849554j) + assert ae(fp.chi(-50j), -0.0056283863241163054-1.5707963267948966192j) + assert ae(fp.chi(2-50j), -0.0183782829461330671-1.500443518771208283j) + assert ae(fp.chi(20-2j), -4.050116856873293505e6-1.20747603112735722951e7j) + assert ae(fp.chi(20), 1.2807826332028294361e7) + assert ae(fp.chi(2+50j), -0.0183782829461330671+1.500443518771208283j) + assert ae(fp.shi(50j), 1.5516170724859358947j) + assert ae(fp.shi(-2+50j), 0.017515007378437448+1.497884414277228461j) + assert ae(fp.shi(-20), -1.2807826332028294459e7) + assert ae(fp.shi(-20-2j), 4.050116856873293554e6-1.20747603112735722103e7j) + assert ae(fp.shi(-2-50j), 0.017515007378437448-1.497884414277228461j) + assert ae(fp.shi(-50j), -1.5516170724859358947j) + assert ae(fp.shi(2-50j), -0.017515007378437448-1.497884414277228461j) + assert ae(fp.shi(20-2j), -4.050116856873293554e6-1.20747603112735722103e7j) + assert ae(fp.shi(20), 1.2807826332028294459e7) + assert ae(fp.shi(2+50j), -0.017515007378437448+1.497884414277228461j) + +def test_airy(): + mp.dps = 15 + assert (airyai(10)*10**10).ae(1.1047532552898687) + assert (airybi(10)/10**9).ae(0.45564115354822515) + assert (airyai(1000)*10**9158).ae(9.306933063179556004) + assert (airybi(1000)/10**9154).ae(5.4077118391949465477) + assert airyai(-1000).ae(0.055971895773019918842) + assert airybi(-1000).ae(-0.083264574117080633012) + assert (airyai(100+100j)*10**188).ae(2.9099582462207032076 + 2.353013591706178756j) + assert (airybi(100+100j)/10**185).ae(1.7086751714463652039 - 3.1416590020830804578j) + +def test_hyper_0f1(): + mp.dps = 15 + v = 8.63911136507950465 + assert hyper([],[(1,3)],1.5).ae(v) + assert hyper([],[1/3.],1.5).ae(v) + assert hyp0f1(1/3.,1.5).ae(v) + assert hyp0f1((1,3),1.5).ae(v) + # Asymptotic expansion + assert hyp0f1(3,1e9).ae('4.9679055380347771271e+27455') + assert hyp0f1(3,1e9j).ae('-2.1222788784457702157e+19410 + 5.0840597555401854116e+19410j') + +def test_hyper_1f1(): + mp.dps = 15 + v = 1.2917526488617656673 + assert hyper([(1,2)],[(3,2)],0.7).ae(v) + assert hyper([(1,2)],[(3,2)],0.7+0j).ae(v) + assert hyper([0.5],[(3,2)],0.7).ae(v) + assert hyper([0.5],[1.5],0.7).ae(v) + assert hyper([0.5],[(3,2)],0.7+0j).ae(v) + assert hyper([0.5],[1.5],0.7+0j).ae(v) + assert hyper([(1,2)],[1.5+0j],0.7).ae(v) + assert hyper([0.5+0j],[1.5],0.7).ae(v) + assert hyper([0.5+0j],[1.5+0j],0.7+0j).ae(v) + assert hyp1f1(0.5,1.5,0.7).ae(v) + assert hyp1f1((1,2),1.5,0.7).ae(v) + # Asymptotic expansion + assert hyp1f1(2,3,1e10).ae('2.1555012157015796988e+4342944809') + assert (hyp1f1(2,3,1e10j)*10**10).ae(-0.97501205020039745852 - 1.7462392454512132074j) + # Shouldn't use asymptotic expansion + assert hyp1f1(-2, 1, 10000).ae(49980001) + # Bug + assert hyp1f1(1j,fraction(1,3),0.415-69.739j).ae(25.857588206024346592 + 15.738060264515292063j) + +def test_hyper_2f1(): + mp.dps = 15 + v = 1.0652207633823291032 + assert hyper([(1,2), (3,4)], [2], 0.3).ae(v) + assert hyper([(1,2), 0.75], [2], 0.3).ae(v) + assert hyper([0.5, 0.75], [2.0], 0.3).ae(v) + assert hyper([0.5, 0.75], [2.0], 0.3+0j).ae(v) + assert hyper([0.5+0j, (3,4)], [2.0], 0.3+0j).ae(v) + assert hyper([0.5+0j, (3,4)], [2.0], 0.3).ae(v) + assert hyper([0.5, (3,4)], [2.0+0j], 0.3).ae(v) + assert hyper([0.5+0j, 0.75+0j], [2.0+0j], 0.3+0j).ae(v) + v = 1.09234681096223231717 + 0.18104859169479360380j + assert hyper([(1,2),0.75+j], [2], 0.5).ae(v) + assert hyper([0.5,0.75+j], [2.0], 0.5).ae(v) + assert hyper([0.5,0.75+j], [2.0], 0.5+0j).ae(v) + assert hyper([0.5,0.75+j], [2.0+0j], 0.5+0j).ae(v) + v = 0.9625 - 0.125j + assert hyper([(3,2),-1],[4], 0.1+j/3).ae(v) + assert hyper([1.5,-1.0],[4], 0.1+j/3).ae(v) + assert hyper([1.5,-1.0],[4+0j], 0.1+j/3).ae(v) + assert hyper([1.5+0j,-1.0+0j],[4+0j], 0.1+j/3).ae(v) + v = 1.02111069501693445001 - 0.50402252613466859521j + assert hyper([(2,10),(3,10)],[(4,10)],1.5).ae(v) + assert hyper([0.2,(3,10)],[0.4+0j],1.5).ae(v) + assert hyper([0.2,(3,10)],[0.4+0j],1.5+0j).ae(v) + v = 0.76922501362865848528 + 0.32640579593235886194j + assert hyper([(2,10),(3,10)],[(4,10)],4+2j).ae(v) + assert hyper([0.2,(3,10)],[0.4+0j],4+2j).ae(v) + assert hyper([0.2,(3,10)],[(4,10)],4+2j).ae(v) + +def test_hyper_2f1_hard(): + mp.dps = 15 + # Singular cases + assert hyp2f1(2,-1,-1,3).ae(7) + assert hyp2f1(2,-1,-1,3,eliminate_all=True).ae(0.25) + assert hyp2f1(2,-2,-2,3).ae(34) + assert hyp2f1(2,-2,-2,3,eliminate_all=True).ae(0.25) + assert hyp2f1(2,-2,-3,3) == 14 + assert hyp2f1(2,-3,-2,3) == inf + assert hyp2f1(2,-1.5,-1.5,3) == 0.25 + assert hyp2f1(1,2,3,0) == 1 + assert hyp2f1(0,1,0,0) == 1 + assert hyp2f1(0,0,0,0) == 1 + assert isnan(hyp2f1(1,1,0,0)) + assert hyp2f1(2,-1,-5, 0.25+0.25j).ae(1.1+0.1j) + assert hyp2f1(2,-5,-5, 0.25+0.25j, eliminate=False).ae(163./128 + 125./128*j) + assert hyp2f1(0.7235, -1, -5, 0.3).ae(1.04341) + assert hyp2f1(0.7235, -5, -5, 0.3, eliminate=False).ae(1.2939225017815903812) + assert hyp2f1(-1,-2,4,1) == 1.5 + assert hyp2f1(1,2,-3,1) == inf + assert hyp2f1(-2,-2,1,1) == 6 + assert hyp2f1(1,-2,-4,1).ae(5./3) + assert hyp2f1(0,-6,-4,1) == 1 + assert hyp2f1(0,-3,-4,1) == 1 + assert hyp2f1(0,0,0,1) == 1 + assert hyp2f1(1,0,0,1,eliminate=False) == 1 + assert hyp2f1(1,1,0,1) == inf + assert hyp2f1(1,-6,-4,1) == inf + assert hyp2f1(-7.2,-0.5,-4.5,1) == 0 + assert hyp2f1(-7.2,-1,-2,1).ae(-2.6) + assert hyp2f1(1,-0.5,-4.5, 1) == inf + assert hyp2f1(1,0.5,-4.5, 1) == -inf + # Check evaluation on / close to unit circle + z = exp(j*pi/3) + w = (nthroot(2,3)+1)*exp(j*pi/12)/nthroot(3,4)**3 + assert hyp2f1('1/2','1/6','1/3', z).ae(w) + assert hyp2f1('1/2','1/6','1/3', z.conjugate()).ae(w.conjugate()) + assert hyp2f1(0.25, (1,3), 2, '0.999').ae(1.06826449496030635) + assert hyp2f1(0.25, (1,3), 2, '1.001').ae(1.06867299254830309446-0.00001446586793975874j) + assert hyp2f1(0.25, (1,3), 2, -1).ae(0.96656584492524351673) + assert hyp2f1(0.25, (1,3), 2, j).ae(0.99041766248982072266+0.03777135604180735522j) + assert hyp2f1(2,3,5,'0.99').ae(27.699347904322690602) + assert hyp2f1((3,2),-0.5,3,'0.99').ae(0.68403036843911661388) + assert hyp2f1(2,3,5,1j).ae(0.37290667145974386127+0.59210004902748285917j) + assert fsum([hyp2f1((7,10),(2,3),(-1,2), 0.95*exp(j*k)) for k in range(1,15)]).ae(52.851400204289452922+6.244285013912953225j) + assert fsum([hyp2f1((7,10),(2,3),(-1,2), 1.05*exp(j*k)) for k in range(1,15)]).ae(54.506013786220655330-3.000118813413217097j) + assert fsum([hyp2f1((7,10),(2,3),(-1,2), exp(j*k)) for k in range(1,15)]).ae(55.792077935955314887+1.731986485778500241j) + assert hyp2f1(2,2.5,-3.25,0.999).ae(218373932801217082543180041.33) + # Branches + assert hyp2f1(1,1,2,1.01).ae(4.5595744415723676911-3.1104877758314784539j) + assert hyp2f1(1,1,2,1.01+0.1j).ae(2.4149427480552782484+1.4148224796836938829j) + assert hyp2f1(1,1,2,3+4j).ae(0.14576709331407297807+0.48379185417980360773j) + assert hyp2f1(1,1,2,4).ae(-0.27465307216702742285 - 0.78539816339744830962j) + assert hyp2f1(1,1,2,-4).ae(0.40235947810852509365) + # Other: + # Cancellation with a large parameter involved (bug reported on sage-devel) + assert hyp2f1(112, (51,10), (-9,10), -0.99999).ae(-1.6241361047970862961e-24, abs_eps=0, rel_eps=eps*16) + +def test_hyper_3f2_etc(): + assert hyper([1,2,3],[1.5,8],-1).ae(0.67108992351533333030) + assert hyper([1,2,3,4],[5,6,7], -1).ae(0.90232988035425506008) + assert hyper([1,2,3],[1.25,5], 1).ae(28.924181329701905701) + assert hyper([1,2,3,4],[5,6,7],5).ae(1.5192307344006649499-1.1529845225075537461j) + assert hyper([1,2,3,4,5],[6,7,8,9],-1).ae(0.96288759462882357253) + assert hyper([1,2,3,4,5],[6,7,8,9],1).ae(1.0428697385885855841) + assert hyper([1,2,3,4,5],[6,7,8,9],5).ae(1.33980653631074769423-0.07143405251029226699j) + assert hyper([1,2.79,3.08,4.37],[5.2,6.1,7.3],5).ae(1.0996321464692607231-1.7748052293979985001j) + assert hyper([1,1,1],[1,2],1) == inf + assert hyper([1,1,1],[2,(101,100)],1).ae(100.01621213528313220) + # slow -- covered by doctests + #assert hyper([1,1,1],[2,3],0.9999).ae(1.2897972005319693905) + +def test_hyper_u(): + mp.dps = 15 + assert hyperu(2,-3,0).ae(0.05) + assert hyperu(2,-3.5,0).ae(4./99) + assert hyperu(2,0,0) == 0.5 + assert hyperu(-5,1,0) == -120 + assert hyperu(-5,2,0) == inf + assert hyperu(-5,-2,0) == 0 + assert hyperu(7,7,3).ae(0.00014681269365593503986) #exp(3)*gammainc(-6,3) + assert hyperu(2,-3,4).ae(0.011836478100271995559) + assert hyperu(3,4,5).ae(1./125) + assert hyperu(2,3,0.0625) == 256 + assert hyperu(-1,2,0.25+0.5j) == -1.75+0.5j + assert hyperu(0.5,1.5,7.25).ae(2/sqrt(29)) + assert hyperu(2,6,pi).ae(0.55804439825913399130) + assert (hyperu((3,2),8,100+201j)*10**4).ae(-0.3797318333856738798 - 2.9974928453561707782j) + assert (hyperu((5,2),(-1,2),-5000)*10**10).ae(-5.6681877926881664678j) + # XXX: fails because of undetected cancellation in low level series code + # Alternatively: could use asymptotic series here, if convergence test + # tweaked back to recognize this one + #assert (hyperu((5,2),(-1,2),-500)*10**7).ae(-1.82526906001593252847j) + +def test_hyper_2f0(): + mp.dps = 15 + assert hyper([1,2],[],3) == hyp2f0(1,2,3) + assert hyp2f0(2,3,7).ae(0.0116108068639728714668 - 0.0073727413865865802130j) + assert hyp2f0(2,3,0) == 1 + assert hyp2f0(0,0,0) == 1 + assert hyp2f0(-1,-1,1).ae(2) + assert hyp2f0(-4,1,1.5).ae(62.5) + assert hyp2f0(-4,1,50).ae(147029801) + assert hyp2f0(-4,1,0.0001).ae(0.99960011997600240000) + assert hyp2f0(0.5,0.25,0.001).ae(1.0001251174078538115) + assert hyp2f0(0.5,0.25,3+4j).ae(0.85548875824755163518 + 0.21636041283392292973j) + # Important: cancellation check + assert hyp2f0((1,6),(5,6),-0.02371708245126284498).ae(0.996785723120804309) + # Should be exact; polynomial case + assert hyp2f0(-2,1,0.5+0.5j,zeroprec=200) == 0 + assert hyp2f0(1,-2,0.5+0.5j,zeroprec=200) == 0 + # There used to be a bug in thresholds that made one of the following hang + for d in [15, 50, 80]: + mp.dps = d + assert hyp2f0(1.5, 0.5, 0.009).ae('1.006867007239309717945323585695344927904000945829843527398772456281301440034218290443367270629519483 + 1.238277162240704919639384945859073461954721356062919829456053965502443570466701567100438048602352623e-46j') + +def test_hyper_1f2(): + mp.dps = 15 + assert hyper([1],[2,3],4) == hyp1f2(1,2,3,4) + a1,b1,b2 = (1,10),(2,3),1./16 + assert hyp1f2(a1,b1,b2,10).ae(298.7482725554557568) + assert hyp1f2(a1,b1,b2,100).ae(224128961.48602947604) + assert hyp1f2(a1,b1,b2,1000).ae(1.1669528298622675109e+27) + assert hyp1f2(a1,b1,b2,10000).ae(2.4780514622487212192e+86) + assert hyp1f2(a1,b1,b2,100000).ae(1.3885391458871523997e+274) + assert hyp1f2(a1,b1,b2,1000000).ae('9.8851796978960318255e+867') + assert hyp1f2(a1,b1,b2,10**7).ae('1.1505659189516303646e+2746') + assert hyp1f2(a1,b1,b2,10**8).ae('1.4672005404314334081e+8685') + assert hyp1f2(a1,b1,b2,10**20).ae('3.6888217332150976493e+8685889636') + assert hyp1f2(a1,b1,b2,10*j).ae(-16.163252524618572878 - 44.321567896480184312j) + assert hyp1f2(a1,b1,b2,100*j).ae(61938.155294517848171 + 637349.45215942348739j) + assert hyp1f2(a1,b1,b2,1000*j).ae(8455057657257695958.7 + 6261969266997571510.6j) + assert hyp1f2(a1,b1,b2,10000*j).ae(-8.9771211184008593089e+60 + 4.6550528111731631456e+59j) + assert hyp1f2(a1,b1,b2,100000*j).ae(2.6398091437239324225e+193 + 4.1658080666870618332e+193j) + assert hyp1f2(a1,b1,b2,1000000*j).ae('3.5999042951925965458e+613 + 1.5026014707128947992e+613j') + assert hyp1f2(a1,b1,b2,10**7*j).ae('-8.3208715051623234801e+1939 - 3.6752883490851869429e+1941j') + assert hyp1f2(a1,b1,b2,10**8*j).ae('2.0724195707891484454e+6140 - 1.3276619482724266387e+6141j') + assert hyp1f2(a1,b1,b2,10**20*j).ae('-1.1734497974795488504e+6141851462 + 1.1498106965385471542e+6141851462j') + +def test_hyper_2f3(): + mp.dps = 15 + assert hyper([1,2],[3,4,5],6) == hyp2f3(1,2,3,4,5,6) + a1,a2,b1,b2,b3 = (1,10),(2,3),(3,10), 2, 1./16 + # Check asymptotic expansion + assert hyp2f3(a1,a2,b1,b2,b3,10).ae(128.98207160698659976) + assert hyp2f3(a1,a2,b1,b2,b3,1000).ae(6.6309632883131273141e25) + assert hyp2f3(a1,a2,b1,b2,b3,10000).ae(4.6863639362713340539e84) + assert hyp2f3(a1,a2,b1,b2,b3,100000).ae(8.6632451236103084119e271) + assert hyp2f3(a1,a2,b1,b2,b3,10**6).ae('2.0291718386574980641e865') + assert hyp2f3(a1,a2,b1,b2,b3,10**7).ae('7.7639836665710030977e2742') + assert hyp2f3(a1,a2,b1,b2,b3,10**8).ae('3.2537462584071268759e8681') + assert hyp2f3(a1,a2,b1,b2,b3,10**20).ae('1.2966030542911614163e+8685889627') + assert hyp2f3(a1,a2,b1,b2,b3,10*j).ae(-18.551602185587547854 - 13.348031097874113552j) + assert hyp2f3(a1,a2,b1,b2,b3,100*j).ae(78634.359124504488695 + 74459.535945281973996j) + assert hyp2f3(a1,a2,b1,b2,b3,1000*j).ae(597682550276527901.59 - 65136194809352613.078j) + assert hyp2f3(a1,a2,b1,b2,b3,10000*j).ae(-1.1779696326238582496e+59 + 1.2297607505213133872e+59j) + assert hyp2f3(a1,a2,b1,b2,b3,100000*j).ae(2.9844228969804380301e+191 + 7.5587163231490273296e+190j) + assert hyp2f3(a1,a2,b1,b2,b3,1000000*j).ae('7.4859161049322370311e+610 - 2.8467477015940090189e+610j') + assert hyp2f3(a1,a2,b1,b2,b3,10**7*j).ae('-1.7477645579418800826e+1938 - 1.7606522995808116405e+1938j') + assert hyp2f3(a1,a2,b1,b2,b3,10**8*j).ae('-1.6932731942958401784e+6137 - 2.4521909113114629368e+6137j') + assert hyp2f3(a1,a2,b1,b2,b3,10**20*j).ae('-2.0988815677627225449e+6141851451 + 5.7708223542739208681e+6141851452j') + +def test_hyper_2f2(): + mp.dps = 15 + assert hyper([1,2],[3,4],5) == hyp2f2(1,2,3,4,5) + a1,a2,b1,b2 = (3,10),4,(1,2),1./16 + assert hyp2f2(a1,a2,b1,b2,10).ae(448225936.3377556696) + assert hyp2f2(a1,a2,b1,b2,10000).ae('1.2012553712966636711e+4358') + assert hyp2f2(a1,a2,b1,b2,-20000).ae(-0.04182343755661214626) + assert hyp2f2(a1,a2,b1,b2,10**20).ae('1.1148680024303263661e+43429448190325182840') + +def test_orthpoly(): + mp.dps = 15 + assert jacobi(-4,2,3,0.7).ae(22800./4913) + assert jacobi(3,2,4,5.5) == 4133.125 + assert jacobi(1.5,5/6.,4,0).ae(-1.0851951434075508417) + assert jacobi(-2, 1, 2, 4).ae(-0.16) + assert jacobi(2, -1, 2.5, 4).ae(34.59375) + #assert jacobi(2, -1, 2, 4) == 28.5 + assert legendre(5, 7) == 129367 + assert legendre(0.5,0).ae(0.53935260118837935667) + assert legendre(-1,-1) == 1 + assert legendre(0,-1) == 1 + assert legendre(0, 1) == 1 + assert legendre(1, -1) == -1 + assert legendre(7, 1) == 1 + assert legendre(7, -1) == -1 + assert legendre(8,1.5).ae(15457523./32768) + assert legendre(j,-j).ae(2.4448182735671431011 + 0.6928881737669934843j) + assert chebyu(5,1) == 6 + assert chebyt(3,2) == 26 + assert legendre(3.5,-1) == inf + assert legendre(4.5,-1) == -inf + assert legendre(3.5+1j,-1) == mpc(inf,inf) + assert legendre(4.5+1j,-1) == mpc(-inf,-inf) + assert laguerre(4, -2, 3).ae(-1.125) + assert laguerre(3, 1+j, 0.5).ae(0.2291666666666666667 + 2.5416666666666666667j) + +def test_hermite(): + mp.dps = 15 + assert hermite(-2, 0).ae(0.5) + assert hermite(-1, 0).ae(0.88622692545275801365) + assert hermite(0, 0).ae(1) + assert hermite(1, 0) == 0 + assert hermite(2, 0).ae(-2) + assert hermite(0, 2).ae(1) + assert hermite(1, 2).ae(4) + assert hermite(1, -2).ae(-4) + assert hermite(2, -2).ae(14) + assert hermite(0.5, 0).ae(0.69136733903629335053) + assert hermite(9, 0) == 0 + assert hermite(4,4).ae(3340) + assert hermite(3,4).ae(464) + assert hermite(-4,4).ae(0.00018623860287512396181) + assert hermite(-3,4).ae(0.0016540169879668766270) + assert hermite(9, 2.5j).ae(13638725j) + assert hermite(9, -2.5j).ae(-13638725j) + assert hermite(9, 100).ae(511078883759363024000) + assert hermite(9, -100).ae(-511078883759363024000) + assert hermite(9, 100j).ae(512922083920643024000j) + assert hermite(9, -100j).ae(-512922083920643024000j) + assert hermite(-9.5, 2.5j).ae(-2.9004951258126778174e-6 + 1.7601372934039951100e-6j) + assert hermite(-9.5, -2.5j).ae(-2.9004951258126778174e-6 - 1.7601372934039951100e-6j) + assert hermite(-9.5, 100).ae(1.3776300722767084162e-22, abs_eps=0, rel_eps=eps) + assert hermite(-9.5, -100).ae('1.3106082028470671626e4355') + assert hermite(-9.5, 100j).ae(-9.7900218581864768430e-23 - 9.7900218581864768430e-23j, abs_eps=0, rel_eps=eps) + assert hermite(-9.5, -100j).ae(-9.7900218581864768430e-23 + 9.7900218581864768430e-23j, abs_eps=0, rel_eps=eps) + assert hermite(2+3j, -1-j).ae(851.3677063883687676 - 1496.4373467871007997j) + +def test_gegenbauer(): + mp.dps = 15 + assert gegenbauer(1,2,3).ae(12) + assert gegenbauer(2,3,4).ae(381) + assert gegenbauer(0,0,0) == 0 + assert gegenbauer(2,-1,3) == 0 + assert gegenbauer(-7, 0.5, 3).ae(8989) + assert gegenbauer(1, -0.5, 3).ae(-3) + assert gegenbauer(1, -1.5, 3).ae(-9) + assert gegenbauer(1, -0.5, 3).ae(-3) + assert gegenbauer(-0.5, -0.5, 3).ae(-2.6383553159023906245) + assert gegenbauer(2+3j, 1-j, 3+4j).ae(14.880536623203696780 + 20.022029711598032898j) + #assert gegenbauer(-2, -0.5, 3).ae(-12) + +def test_legenp(): + mp.dps = 15 + assert legenp(2,0,4) == legendre(2,4) + assert legenp(-2, -1, 0.5).ae(0.43301270189221932338) + assert legenp(-2, -1, 0.5, type=3).ae(0.43301270189221932338j) + assert legenp(-2, 1, 0.5).ae(-0.86602540378443864676) + assert legenp(2+j, 3+4j, -j).ae(134742.98773236786148 + 429782.72924463851745j) + assert legenp(2+j, 3+4j, -j, type=3).ae(802.59463394152268507 - 251.62481308942906447j) + assert legenp(2,4,3).ae(0) + assert legenp(2,4,3,type=3).ae(0) + assert legenp(2,1,0.5).ae(-1.2990381056766579701) + assert legenp(2,1,0.5,type=3).ae(1.2990381056766579701j) + assert legenp(3,2,3).ae(-360) + assert legenp(3,3,3).ae(240j*2**0.5) + assert legenp(3,4,3).ae(0) + assert legenp(0,0.5,2).ae(0.52503756790433198939 - 0.52503756790433198939j) + assert legenp(-1,-0.5,2).ae(0.60626116232846498110 + 0.60626116232846498110j) + assert legenp(-2,0.5,2).ae(1.5751127037129959682 - 1.5751127037129959682j) + assert legenp(-2,0.5,-0.5).ae(-0.85738275810499171286) + +def test_legenq(): + mp.dps = 15 + f = legenq + # Evaluation at poles + assert isnan(f(3,2,1)) + assert isnan(f(3,2,-1)) + assert isnan(f(3,2,1,type=3)) + assert isnan(f(3,2,-1,type=3)) + # Evaluation at 0 + assert f(0,1,0,type=2).ae(-1) + assert f(-2,2,0,type=2,zeroprec=200).ae(0) + assert f(1.5,3,0,type=2).ae(-2.2239343475841951023) + assert f(0,1,0,type=3).ae(j) + assert f(-2,2,0,type=3,zeroprec=200).ae(0) + assert f(1.5,3,0,type=3).ae(2.2239343475841951022*(1-1j)) + # Standard case, degree 0 + assert f(0,0,-1.5).ae(-0.8047189562170501873 + 1.5707963267948966192j) + assert f(0,0,-0.5).ae(-0.54930614433405484570) + assert f(0,0,0,zeroprec=200).ae(0) + assert f(0,0,0.5).ae(0.54930614433405484570) + assert f(0,0,1.5).ae(0.8047189562170501873 - 1.5707963267948966192j) + assert f(0,0,-1.5,type=3).ae(-0.80471895621705018730) + assert f(0,0,-0.5,type=3).ae(-0.5493061443340548457 - 1.5707963267948966192j) + assert f(0,0,0,type=3).ae(-1.5707963267948966192j) + assert f(0,0,0.5,type=3).ae(0.5493061443340548457 - 1.5707963267948966192j) + assert f(0,0,1.5,type=3).ae(0.80471895621705018730) + # Standard case, degree 1 + assert f(1,0,-1.5).ae(0.2070784343255752810 - 2.3561944901923449288j) + assert f(1,0,-0.5).ae(-0.72534692783297257715) + assert f(1,0,0).ae(-1) + assert f(1,0,0.5).ae(-0.72534692783297257715) + assert f(1,0,1.5).ae(0.2070784343255752810 - 2.3561944901923449288j) + # Standard case, degree 2 + assert f(2,0,-1.5).ae(-0.0635669991240192885 + 4.5160394395353277803j) + assert f(2,0,-0.5).ae(0.81866326804175685571) + assert f(2,0,0,zeroprec=200).ae(0) + assert f(2,0,0.5).ae(-0.81866326804175685571) + assert f(2,0,1.5).ae(0.0635669991240192885 - 4.5160394395353277803j) + # Misc orders and degrees + assert f(2,3,1.5,type=2).ae(-5.7243340223994616228j) + assert f(2,3,1.5,type=3).ae(-5.7243340223994616228) + assert f(2,3,0.5,type=2).ae(-12.316805742712016310) + assert f(2,3,0.5,type=3).ae(-12.316805742712016310j) + assert f(2,3,-1.5,type=2).ae(-5.7243340223994616228j) + assert f(2,3,-1.5,type=3).ae(5.7243340223994616228) + assert f(2,3,-0.5,type=2).ae(-12.316805742712016310) + assert f(2,3,-0.5,type=3).ae(-12.316805742712016310j) + assert f(2+3j, 3+4j, 0.5, type=3).ae(0.0016119404873235186807 - 0.0005885900510718119836j) + assert f(2+3j, 3+4j, -1.5, type=3).ae(0.008451400254138808670 + 0.020645193304593235298j) + assert f(-2.5,1,-1.5).ae(3.9553395527435335749j) + assert f(-2.5,1,-0.5).ae(1.9290561746445456908) + assert f(-2.5,1,0).ae(1.2708196271909686299) + assert f(-2.5,1,0.5).ae(-0.31584812990742202869) + assert f(-2.5,1,1.5).ae(-3.9553395527435335742 + 0.2993235655044701706j) + assert f(-2.5,1,-1.5,type=3).ae(0.29932356550447017254j) + assert f(-2.5,1,-0.5,type=3).ae(-0.3158481299074220287 - 1.9290561746445456908j) + assert f(-2.5,1,0,type=3).ae(1.2708196271909686292 - 1.2708196271909686299j) + assert f(-2.5,1,0.5,type=3).ae(1.9290561746445456907 + 0.3158481299074220287j) + assert f(-2.5,1,1.5,type=3).ae(-0.29932356550447017254) + +def test_agm(): + mp.dps = 15 + assert agm(0,0) == 0 + assert agm(0,1) == 0 + assert agm(1,1) == 1 + assert agm(7,7) == 7 + assert agm(j,j) == j + assert (1/agm(1,sqrt(2))).ae(0.834626841674073186) + assert agm(1,2).ae(1.4567910310469068692) + assert agm(1,3).ae(1.8636167832448965424) + assert agm(1,j).ae(0.599070117367796104+0.599070117367796104j) + assert agm(2) == agm(1,2) + assert agm(-3,4).ae(0.63468509766550907+1.3443087080896272j) + +def test_gammainc(): + mp.dps = 15 + assert gammainc(2,5).ae(6*exp(-5)) + assert gammainc(2,0,5).ae(1-6*exp(-5)) + assert gammainc(2,3,5).ae(-6*exp(-5)+4*exp(-3)) + assert gammainc(-2.5,-0.5).ae(-0.9453087204829418812-5.3164237738936178621j) + assert gammainc(0,2,4).ae(0.045121158298212213088) + assert gammainc(0,3).ae(0.013048381094197037413) + assert gammainc(0,2+j,1-j).ae(0.00910653685850304839-0.22378752918074432574j) + assert gammainc(0,1-j).ae(0.00028162445198141833+0.17932453503935894015j) + assert gammainc(3,4,5,True).ae(0.11345128607046320253) + assert gammainc(3.5,0,inf).ae(gamma(3.5)) + assert gammainc(-150.5,500).ae('6.9825435345798951153e-627') + assert gammainc(-150.5,800).ae('4.6885137549474089431e-788') + assert gammainc(-3.5, -20.5).ae(0.27008820585226911 - 1310.31447140574997636j) + assert gammainc(-3.5, -200.5).ae(0.27008820585226911 - 5.3264597096208368435e76j) # XXX real part + assert gammainc(0,0,2) == inf + assert gammainc(1,b=1).ae(0.6321205588285576784) + assert gammainc(3,2,2) == 0 + assert gammainc(2,3+j,3-j).ae(-0.28135485191849314194j) + assert gammainc(4+0j,1).ae(5.8860710587430771455) + # GH issue #301 + assert gammainc(-1,-1).ae(-0.8231640121031084799 + 3.1415926535897932385j) + assert gammainc(-2,-1).ae(1.7707229202810768576 - 1.5707963267948966192j) + assert gammainc(-3,-1).ae(-1.4963349162467073643 + 0.5235987755982988731j) + assert gammainc(-4,-1).ae(1.05365418617643814992 - 0.13089969389957471827j) + # Regularized upper gamma + assert isnan(gammainc(0, 0, regularized=True)) + assert gammainc(-1, 0, regularized=True) == inf + assert gammainc(1, 0, regularized=True) == 1 + assert gammainc(0, 5, regularized=True) == 0 + assert gammainc(0, 2+3j, regularized=True) == 0 + assert gammainc(0, 5000, regularized=True) == 0 + assert gammainc(0, 10**30, regularized=True) == 0 + assert gammainc(-1, 5, regularized=True) == 0 + assert gammainc(-1, 5000, regularized=True) == 0 + assert gammainc(-1, 10**30, regularized=True) == 0 + assert gammainc(-1, -5, regularized=True) == 0 + assert gammainc(-1, -5000, regularized=True) == 0 + assert gammainc(-1, -10**30, regularized=True) == 0 + assert gammainc(-1, 3+4j, regularized=True) == 0 + assert gammainc(1, 5, regularized=True).ae(exp(-5)) + assert gammainc(1, 5000, regularized=True).ae(exp(-5000)) + assert gammainc(1, 10**30, regularized=True).ae(exp(-10**30)) + assert gammainc(1, 3+4j, regularized=True).ae(exp(-3-4j)) + assert gammainc(-1000000,2).ae('1.3669297209397347754e-301037', abs_eps=0, rel_eps=8*eps) + assert gammainc(-1000000,2,regularized=True) == 0 + assert gammainc(-1000000,3+4j).ae('-1.322575609404222361e-698979 - 4.9274570591854533273e-698978j', abs_eps=0, rel_eps=8*eps) + assert gammainc(-1000000,3+4j,regularized=True) == 0 + assert gammainc(2+3j, 4+5j, regularized=True).ae(0.085422013530993285774-0.052595379150390078503j) + assert gammainc(1000j, 1000j, regularized=True).ae(0.49702647628921131761 + 0.00297355675013575341j) + # Generalized + assert gammainc(3,4,2) == -gammainc(3,2,4) + assert gammainc(4, 2, 3).ae(1.2593494302978947396) + assert gammainc(4, 2, 3, regularized=True).ae(0.20989157171631578993) + assert gammainc(0, 2, 3).ae(0.035852129613864082155) + assert gammainc(0, 2, 3, regularized=True) == 0 + assert gammainc(-1, 2, 3).ae(0.015219822548487616132) + assert gammainc(-1, 2, 3, regularized=True) == 0 + assert gammainc(0, 2, 3).ae(0.035852129613864082155) + assert gammainc(0, 2, 3, regularized=True) == 0 + # Should use upper gammas + assert gammainc(5, 10000, 12000).ae('1.1359381951461801687e-4327', abs_eps=0, rel_eps=8*eps) + # Should use lower gammas + assert gammainc(10000, 2, 3).ae('8.1244514125995785934e4765') + # GH issue 306 + assert gammainc(3,-1-1j) == 0 + assert gammainc(3,-1+1j) == 0 + assert gammainc(2,-1) == 0 + assert gammainc(2,-1+0j) == 0 + assert gammainc(2+0j,-1) == 0 + +def test_gammainc_expint_n(): + # These tests are intended to check all cases of the low-level code + # for upper gamma and expint with small integer index. + # Need to cover positive/negative arguments; small/large/huge arguments + # for both positive and negative indices, as well as indices 0 and 1 + # which may be special-cased + mp.dps = 15 + assert expint(-3,3.5).ae(0.021456366563296693987) + assert expint(-2,3.5).ae(0.014966633183073309405) + assert expint(-1,3.5).ae(0.011092916359219041088) + assert expint(0,3.5).ae(0.0086278238349481430685) + assert expint(1,3.5).ae(0.0069701398575483929193) + assert expint(2,3.5).ae(0.0058018939208991255223) + assert expint(3,3.5).ae(0.0049453773495857807058) + assert expint(-3,-3.5).ae(-4.6618170604073311319) + assert expint(-2,-3.5).ae(-5.5996974157555515963) + assert expint(-1,-3.5).ae(-6.7582555017739415818) + assert expint(0,-3.5).ae(-9.4615577024835182145) + assert expint(1,-3.5).ae(-13.925353995152335292 - 3.1415926535897932385j) + assert expint(2,-3.5).ae(-15.62328702434085977 - 10.995574287564276335j) + assert expint(3,-3.5).ae(-10.783026313250347722 - 19.242255003237483586j) + assert expint(-3,350).ae(2.8614825451252838069e-155, abs_eps=0, rel_eps=8*eps) + assert expint(-2,350).ae(2.8532837224504675901e-155, abs_eps=0, rel_eps=8*eps) + assert expint(-1,350).ae(2.8451316155828634555e-155, abs_eps=0, rel_eps=8*eps) + assert expint(0,350).ae(2.8370258275042797989e-155, abs_eps=0, rel_eps=8*eps) + assert expint(1,350).ae(2.8289659656701459404e-155, abs_eps=0, rel_eps=8*eps) + assert expint(2,350).ae(2.8209516419468505006e-155, abs_eps=0, rel_eps=8*eps) + assert expint(3,350).ae(2.8129824725501272171e-155, abs_eps=0, rel_eps=8*eps) + assert expint(-3,-350).ae(-2.8528796154044839443e+149) + assert expint(-2,-350).ae(-2.8610072121701264351e+149) + assert expint(-1,-350).ae(-2.8691813842677537647e+149) + assert expint(0,-350).ae(-2.8774025343659421709e+149) + u = expint(1,-350) + assert u.ae(-2.8856710698020863568e+149) + assert u.imag.ae(-3.1415926535897932385) + u = expint(2,-350) + assert u.ae(-2.8939874026504650534e+149) + assert u.imag.ae(-1099.5574287564276335) + u = expint(3,-350) + assert u.ae(-2.9023519497915044349e+149) + assert u.imag.ae(-192422.55003237483586) + assert expint(-3,350000000000000000000000).ae('2.1592908471792544286e-152003068666138139677919', abs_eps=0, rel_eps=8*eps) + assert expint(-2,350000000000000000000000).ae('2.1592908471792544286e-152003068666138139677919', abs_eps=0, rel_eps=8*eps) + assert expint(-1,350000000000000000000000).ae('2.1592908471792544286e-152003068666138139677919', abs_eps=0, rel_eps=8*eps) + assert expint(0,350000000000000000000000).ae('2.1592908471792544286e-152003068666138139677919', abs_eps=0, rel_eps=8*eps) + assert expint(1,350000000000000000000000).ae('2.1592908471792544286e-152003068666138139677919', abs_eps=0, rel_eps=8*eps) + assert expint(2,350000000000000000000000).ae('2.1592908471792544286e-152003068666138139677919', abs_eps=0, rel_eps=8*eps) + assert expint(3,350000000000000000000000).ae('2.1592908471792544286e-152003068666138139677919', abs_eps=0, rel_eps=8*eps) + assert expint(-3,-350000000000000000000000).ae('-3.7805306852415755699e+152003068666138139677871') + assert expint(-2,-350000000000000000000000).ae('-3.7805306852415755699e+152003068666138139677871') + assert expint(-1,-350000000000000000000000).ae('-3.7805306852415755699e+152003068666138139677871') + assert expint(0,-350000000000000000000000).ae('-3.7805306852415755699e+152003068666138139677871') + u = expint(1,-350000000000000000000000) + assert u.ae('-3.7805306852415755699e+152003068666138139677871') + assert u.imag.ae(-3.1415926535897932385) + u = expint(2,-350000000000000000000000) + assert u.imag.ae(-1.0995574287564276335e+24) + assert u.ae('-3.7805306852415755699e+152003068666138139677871') + u = expint(3,-350000000000000000000000) + assert u.imag.ae(-1.9242255003237483586e+47) + assert u.ae('-3.7805306852415755699e+152003068666138139677871') + # Small case; no branch cut + assert gammainc(-3,3.5).ae(0.00010020262545203707109) + assert gammainc(-2,3.5).ae(0.00040370427343557393517) + assert gammainc(-1,3.5).ae(0.0016576839773997501492) + assert gammainc(0,3.5).ae(0.0069701398575483929193) + assert gammainc(1,3.5).ae(0.03019738342231850074) + assert gammainc(2,3.5).ae(0.13588822540043325333) + assert gammainc(3,3.5).ae(0.64169439772426814072) + # Small case; with branch cut + assert gammainc(-3,-3.5).ae(0.03595832954467563286 + 0.52359877559829887308j) + assert gammainc(-2,-3.5).ae(-0.88024704597962022221 - 1.5707963267948966192j) + assert gammainc(-1,-3.5).ae(4.4637962926688170771 + 3.1415926535897932385j) + assert gammainc(0,-3.5).ae(-13.925353995152335292 - 3.1415926535897932385j) + assert gammainc(1,-3.5).ae(33.115451958692313751) + assert gammainc(2,-3.5).ae(-82.788629896730784377) + assert gammainc(3,-3.5).ae(240.08702670051927469) + # Asymptotic case; no branch cut + assert gammainc(-3,350).ae(6.5424095113340358813e-163, abs_eps=0, rel_eps=8*eps) + assert gammainc(-2,350).ae(2.296312222489899769e-160, abs_eps=0, rel_eps=8*eps) + assert gammainc(-1,350).ae(8.059861834133858573e-158, abs_eps=0, rel_eps=8*eps) + assert gammainc(0,350).ae(2.8289659656701459404e-155, abs_eps=0, rel_eps=8*eps) + assert gammainc(1,350).ae(9.9295903962649792963e-153, abs_eps=0, rel_eps=8*eps) + assert gammainc(2,350).ae(3.485286229089007733e-150, abs_eps=0, rel_eps=8*eps) + assert gammainc(3,350).ae(1.2233453960006379793e-147, abs_eps=0, rel_eps=8*eps) + # Asymptotic case; branch cut + u = gammainc(-3,-350) + assert u.ae(6.7889565783842895085e+141) + assert u.imag.ae(0.52359877559829887308) + u = gammainc(-2,-350) + assert u.ae(-2.3692668977889832121e+144) + assert u.imag.ae(-1.5707963267948966192) + u = gammainc(-1,-350) + assert u.ae(8.2685354361441858669e+146) + assert u.imag.ae(3.1415926535897932385) + u = gammainc(0,-350) + assert u.ae(-2.8856710698020863568e+149) + assert u.imag.ae(-3.1415926535897932385) + u = gammainc(1,-350) + assert u.ae(1.0070908870280797598e+152) + assert u.imag == 0 + u = gammainc(2,-350) + assert u.ae(-3.5147471957279983618e+154) + assert u.imag == 0 + u = gammainc(3,-350) + assert u.ae(1.2266568422179417091e+157) + assert u.imag == 0 + # Extreme asymptotic case + assert gammainc(-3,350000000000000000000000).ae('5.0362468738874738859e-152003068666138139677990', abs_eps=0, rel_eps=8*eps) + assert gammainc(-2,350000000000000000000000).ae('1.7626864058606158601e-152003068666138139677966', abs_eps=0, rel_eps=8*eps) + assert gammainc(-1,350000000000000000000000).ae('6.1694024205121555102e-152003068666138139677943', abs_eps=0, rel_eps=8*eps) + assert gammainc(0,350000000000000000000000).ae('2.1592908471792544286e-152003068666138139677919', abs_eps=0, rel_eps=8*eps) + assert gammainc(1,350000000000000000000000).ae('7.5575179651273905e-152003068666138139677896', abs_eps=0, rel_eps=8*eps) + assert gammainc(2,350000000000000000000000).ae('2.645131287794586675e-152003068666138139677872', abs_eps=0, rel_eps=8*eps) + assert gammainc(3,350000000000000000000000).ae('9.2579595072810533625e-152003068666138139677849', abs_eps=0, rel_eps=8*eps) + u = gammainc(-3,-350000000000000000000000) + assert u.ae('8.8175642804468234866e+152003068666138139677800') + assert u.imag.ae(0.52359877559829887308) + u = gammainc(-2,-350000000000000000000000) + assert u.ae('-3.0861474981563882203e+152003068666138139677824') + assert u.imag.ae(-1.5707963267948966192) + u = gammainc(-1,-350000000000000000000000) + assert u.ae('1.0801516243547358771e+152003068666138139677848') + assert u.imag.ae(3.1415926535897932385) + u = gammainc(0,-350000000000000000000000) + assert u.ae('-3.7805306852415755699e+152003068666138139677871') + assert u.imag.ae(-3.1415926535897932385) + assert gammainc(1,-350000000000000000000000).ae('1.3231857398345514495e+152003068666138139677895') + assert gammainc(2,-350000000000000000000000).ae('-4.6311500894209300731e+152003068666138139677918') + assert gammainc(3,-350000000000000000000000).ae('1.6209025312973255256e+152003068666138139677942') + +def test_incomplete_beta(): + mp.dps = 15 + assert betainc(-2,-3,0.5,0.75).ae(63.4305673311255413583969) + assert betainc(4.5,0.5+2j,2.5,6).ae(0.2628801146130621387903065 + 0.5162565234467020592855378j) + assert betainc(4,5,0,6).ae(90747.77142857142857142857) + +def test_erf(): + mp.dps = 15 + assert erf(0) == 0 + assert erf(1).ae(0.84270079294971486934) + assert erf(3+4j).ae(-120.186991395079444098 - 27.750337293623902498j) + assert erf(-4-3j).ae(-0.99991066178539168236 + 0.00004972026054496604j) + assert erf(pi).ae(0.99999112385363235839) + assert erf(1j).ae(1.6504257587975428760j) + assert erf(-1j).ae(-1.6504257587975428760j) + assert isinstance(erf(1), mpf) + assert isinstance(erf(-1), mpf) + assert isinstance(erf(0), mpf) + assert isinstance(erf(0j), mpc) + assert erf(inf) == 1 + assert erf(-inf) == -1 + assert erfi(0) == 0 + assert erfi(1/pi).ae(0.371682698493894314) + assert erfi(inf) == inf + assert erfi(-inf) == -inf + assert erf(1+0j) == erf(1) + assert erfc(1+0j) == erfc(1) + assert erf(0.2+0.5j).ae(1 - erfc(0.2+0.5j)) + assert erfc(0) == 1 + assert erfc(1).ae(1-erf(1)) + assert erfc(-1).ae(1-erf(-1)) + assert erfc(1/pi).ae(1-erf(1/pi)) + assert erfc(-10) == 2 + assert erfc(-1000000) == 2 + assert erfc(-inf) == 2 + assert erfc(inf) == 0 + assert isnan(erfc(nan)) + assert (erfc(10**4)*mpf(10)**43429453).ae('3.63998738656420') + assert erf(8+9j).ae(-1072004.2525062051158 + 364149.91954310255423j) + assert erfc(8+9j).ae(1072005.2525062051158 - 364149.91954310255423j) + assert erfc(-8-9j).ae(-1072003.2525062051158 + 364149.91954310255423j) + mp.dps = 50 + # This one does not use the asymptotic series + assert (erfc(10)*10**45).ae('2.0884875837625447570007862949577886115608181193212') + # This one does + assert (erfc(50)*10**1088).ae('2.0709207788416560484484478751657887929322509209954') + mp.dps = 15 + assert str(erfc(10**50)) == '3.66744826532555e-4342944819032518276511289189166050822943970058036665661144537831658646492088707747292249493384317534' + assert erfinv(0) == 0 + assert erfinv(0.5).ae(0.47693627620446987338) + assert erfinv(-0.5).ae(-0.47693627620446987338) + assert erfinv(1) == inf + assert erfinv(-1) == -inf + assert erf(erfinv(0.95)).ae(0.95) + assert erf(erfinv(0.999999999995)).ae(0.999999999995) + assert erf(erfinv(-0.999999999995)).ae(-0.999999999995) + mp.dps = 50 + assert erf(erfinv('0.99999999999999999999999999999995')).ae('0.99999999999999999999999999999995') + assert erf(erfinv('0.999999999999999999999999999999995')).ae('0.999999999999999999999999999999995') + assert erf(erfinv('-0.999999999999999999999999999999995')).ae('-0.999999999999999999999999999999995') + mp.dps = 15 + # Complex asymptotic expansions + v = erfc(50j) + assert v.real == 1 + assert v.imag.ae('-6.1481820666053078736e+1083') + assert erfc(-100+5j).ae(2) + assert (erfc(100+5j)*10**4335).ae(2.3973567853824133572 - 3.9339259530609420597j) + assert erfc(100+100j).ae(0.00065234366376857698698 - 0.0039357263629214118437j) + +def test_pdf(): + mp.dps = 15 + assert npdf(-inf) == 0 + assert npdf(inf) == 0 + assert npdf(5,0,2).ae(npdf(5+4,4,2)) + assert quadts(lambda x: npdf(x,-0.5,0.8), [-inf, inf]) == 1 + assert ncdf(0) == 0.5 + assert ncdf(3,3) == 0.5 + assert ncdf(-inf) == 0 + assert ncdf(inf) == 1 + assert ncdf(10) == 1 + # Verify that this is computed accurately + assert (ncdf(-10)*10**24).ae(7.619853024160526) + +def test_lambertw(): + mp.dps = 15 + assert lambertw(0) == 0 + assert lambertw(0+0j) == 0 + assert lambertw(inf) == inf + assert isnan(lambertw(nan)) + assert lambertw(inf,1).real == inf + assert lambertw(inf,1).imag.ae(2*pi) + assert lambertw(-inf,1).real == inf + assert lambertw(-inf,1).imag.ae(3*pi) + assert lambertw(0,-1) == -inf + assert lambertw(0,1) == -inf + assert lambertw(0,3) == -inf + assert lambertw(e).ae(1) + assert lambertw(1).ae(0.567143290409783873) + assert lambertw(-pi/2).ae(j*pi/2) + assert lambertw(-log(2)/2).ae(-log(2)) + assert lambertw(0.25).ae(0.203888354702240164) + assert lambertw(-0.25).ae(-0.357402956181388903) + assert lambertw(-1./10000,0).ae(-0.000100010001500266719) + assert lambertw(-0.25,-1).ae(-2.15329236411034965) + assert lambertw(0.25,-1).ae(-3.00899800997004620-4.07652978899159763j) + assert lambertw(-0.25,-1).ae(-2.15329236411034965) + assert lambertw(0.25,1).ae(-3.00899800997004620+4.07652978899159763j) + assert lambertw(-0.25,1).ae(-3.48973228422959210+7.41405453009603664j) + assert lambertw(-4).ae(0.67881197132094523+1.91195078174339937j) + assert lambertw(-4,1).ae(-0.66743107129800988+7.76827456802783084j) + assert lambertw(-4,-1).ae(0.67881197132094523-1.91195078174339937j) + assert lambertw(1000).ae(5.24960285240159623) + assert lambertw(1000,1).ae(4.91492239981054535+5.44652615979447070j) + assert lambertw(1000,-1).ae(4.91492239981054535-5.44652615979447070j) + assert lambertw(1000,5).ae(3.5010625305312892+29.9614548941181328j) + assert lambertw(3+4j).ae(1.281561806123775878+0.533095222020971071j) + assert lambertw(-0.4+0.4j).ae(-0.10396515323290657+0.61899273315171632j) + assert lambertw(3+4j,1).ae(-0.11691092896595324+5.61888039871282334j) + assert lambertw(3+4j,-1).ae(0.25856740686699742-3.85211668616143559j) + assert lambertw(-0.5,-1).ae(-0.794023632344689368-0.770111750510379110j) + assert lambertw(-1./10000,1).ae(-11.82350837248724344+6.80546081842002101j) + assert lambertw(-1./10000,-1).ae(-11.6671145325663544) + assert lambertw(-1./10000,-2).ae(-11.82350837248724344-6.80546081842002101j) + assert lambertw(-1./100000,4).ae(-14.9186890769540539+26.1856750178782046j) + assert lambertw(-1./100000,5).ae(-15.0931437726379218666+32.5525721210262290086j) + assert lambertw((2+j)/10).ae(0.173704503762911669+0.071781336752835511j) + assert lambertw((2+j)/10,1).ae(-3.21746028349820063+4.56175438896292539j) + assert lambertw((2+j)/10,-1).ae(-3.03781405002993088-3.53946629633505737j) + assert lambertw((2+j)/10,4).ae(-4.6878509692773249+23.8313630697683291j) + assert lambertw(-(2+j)/10).ae(-0.226933772515757933-0.164986470020154580j) + assert lambertw(-(2+j)/10,1).ae(-2.43569517046110001+0.76974067544756289j) + assert lambertw(-(2+j)/10,-1).ae(-3.54858738151989450-6.91627921869943589j) + assert lambertw(-(2+j)/10,4).ae(-4.5500846928118151+20.6672982215434637j) + mp.dps = 50 + assert lambertw(pi).ae('1.073658194796149172092178407024821347547745350410314531') + mp.dps = 15 + # Former bug in generated branch + assert lambertw(-0.5+0.002j).ae(-0.78917138132659918344 + 0.76743539379990327749j) + assert lambertw(-0.5-0.002j).ae(-0.78917138132659918344 - 0.76743539379990327749j) + assert lambertw(-0.448+0.4j).ae(-0.11855133765652382241 + 0.66570534313583423116j) + assert lambertw(-0.448-0.4j).ae(-0.11855133765652382241 - 0.66570534313583423116j) + assert lambertw(-0.65475+0.0001j).ae(-0.61053421111385310898+1.0396534993944097723803j) + # Huge branch index + w = lambertw(1,10**20) + assert w.real.ae(-47.889578926290259164) + assert w.imag.ae(6.2831853071795864769e+20) + +def test_lambertw_hard(): + def check(x,y): + y = convert(y) + type_ok = True + if isinstance(y, mpf): + type_ok = isinstance(x, mpf) + real_ok = abs(x.real-y.real) <= abs(y.real)*8*eps + imag_ok = abs(x.imag-y.imag) <= abs(y.imag)*8*eps + #print x, y, abs(x.real-y.real), abs(x.imag-y.imag) + return real_ok and imag_ok + # Evaluation near 0 + mp.dps = 15 + assert check(lambertw(1e-10), 9.999999999000000000e-11) + assert check(lambertw(-1e-10), -1.000000000100000000e-10) + assert check(lambertw(1e-10j), 9.999999999999999999733e-21 + 9.99999999999999999985e-11j) + assert check(lambertw(-1e-10j), 9.999999999999999999733e-21 - 9.99999999999999999985e-11j) + assert check(lambertw(1e-10,1), -26.303186778379041559 + 3.265093911703828397j) + assert check(lambertw(-1e-10,1), -26.326236166739163892 + 6.526183280686333315j) + assert check(lambertw(1e-10j,1), -26.312931726911421551 + 4.896366881798013421j) + assert check(lambertw(-1e-10j,1), -26.297238779529035066 + 1.632807161345576513j) + assert check(lambertw(1e-10,-1), -26.303186778379041559 - 3.265093911703828397j) + assert check(lambertw(-1e-10,-1), -26.295238819246925694) + assert check(lambertw(1e-10j,-1), -26.297238779529035028 - 1.6328071613455765135j) + assert check(lambertw(-1e-10j,-1), -26.312931726911421551 - 4.896366881798013421j) + # Test evaluation very close to the branch point -1/e + # on the -1, 0, and 1 branches + add = lambda x, y: fadd(x,y,exact=True) + sub = lambda x, y: fsub(x,y,exact=True) + addj = lambda x, y: fadd(x,fmul(y,1j,exact=True),exact=True) + subj = lambda x, y: fadd(x,fmul(y,-1j,exact=True),exact=True) + mp.dps = 1500 + a = -1/e + 10*eps + d3 = mpf('1e-3') + d10 = mpf('1e-10') + d20 = mpf('1e-20') + d40 = mpf('1e-40') + d80 = mpf('1e-80') + d300 = mpf('1e-300') + d1000 = mpf('1e-1000') + mp.dps = 15 + # ---- Branch 0 ---- + # -1/e + eps + assert check(lambertw(add(a,d3)), -0.92802015005456704876) + assert check(lambertw(add(a,d10)), -0.99997668374140088071) + assert check(lambertw(add(a,d20)), -0.99999999976683560186) + assert lambertw(add(a,d40)) == -1 + assert lambertw(add(a,d80)) == -1 + assert lambertw(add(a,d300)) == -1 + assert lambertw(add(a,d1000)) == -1 + # -1/e - eps + assert check(lambertw(sub(a,d3)), -0.99819016149860989001+0.07367191188934638577j) + assert check(lambertw(sub(a,d10)), -0.9999999998187812114595992+0.0000233164398140346109194j) + assert check(lambertw(sub(a,d20)), -0.99999999999999999998187+2.331643981597124203344e-10j) + assert check(lambertw(sub(a,d40)), -1.0+2.33164398159712420336e-20j) + assert check(lambertw(sub(a,d80)), -1.0+2.33164398159712420336e-40j) + assert check(lambertw(sub(a,d300)), -1.0+2.33164398159712420336e-150j) + assert check(lambertw(sub(a,d1000)), mpc(-1,'2.33164398159712420336e-500')) + # -1/e + eps*j + assert check(lambertw(addj(a,d3)), -0.94790387486938526634+0.05036819639190132490j) + assert check(lambertw(addj(a,d10)), -0.9999835127872943680999899+0.0000164870314895821225256j) + assert check(lambertw(addj(a,d20)), -0.999999999835127872929987+1.64872127051890935830e-10j) + assert check(lambertw(addj(a,d40)), -0.9999999999999999999835+1.6487212707001281468305e-20j) + assert check(lambertw(addj(a,d80)), -1.0 + 1.64872127070012814684865e-40j) + assert check(lambertw(addj(a,d300)), -1.0 + 1.64872127070012814684865e-150j) + assert check(lambertw(addj(a,d1000)), mpc(-1.0,'1.64872127070012814684865e-500')) + # -1/e - eps*j + assert check(lambertw(subj(a,d3)), -0.94790387486938526634-0.05036819639190132490j) + assert check(lambertw(subj(a,d10)), -0.9999835127872943680999899-0.0000164870314895821225256j) + assert check(lambertw(subj(a,d20)), -0.999999999835127872929987-1.64872127051890935830e-10j) + assert check(lambertw(subj(a,d40)), -0.9999999999999999999835-1.6487212707001281468305e-20j) + assert check(lambertw(subj(a,d80)), -1.0 - 1.64872127070012814684865e-40j) + assert check(lambertw(subj(a,d300)), -1.0 - 1.64872127070012814684865e-150j) + assert check(lambertw(subj(a,d1000)), mpc(-1.0,'-1.64872127070012814684865e-500')) + # ---- Branch 1 ---- + assert check(lambertw(addj(a,d3),1), -3.088501303219933378005990 + 7.458676867597474813950098j) + assert check(lambertw(addj(a,d80),1), -3.088843015613043855957087 + 7.461489285654254556906117j) + assert check(lambertw(addj(a,d300),1), -3.088843015613043855957087 + 7.461489285654254556906117j) + assert check(lambertw(addj(a,d1000),1), -3.088843015613043855957087 + 7.461489285654254556906117j) + assert check(lambertw(subj(a,d3),1), -1.0520914180450129534365906 + 0.0539925638125450525673175j) + assert check(lambertw(subj(a,d10),1), -1.0000164872127056318529390 + 0.000016487393927159250398333077j) + assert check(lambertw(subj(a,d20),1), -1.0000000001648721270700128 + 1.64872127088134693542628e-10j) + assert check(lambertw(subj(a,d40),1), -1.000000000000000000016487 + 1.64872127070012814686677e-20j) + assert check(lambertw(subj(a,d80),1), -1.0 + 1.64872127070012814684865e-40j) + assert check(lambertw(subj(a,d300),1), -1.0 + 1.64872127070012814684865e-150j) + assert check(lambertw(subj(a,d1000),1), mpc(-1.0, '1.64872127070012814684865e-500')) + # ---- Branch -1 ---- + # -1/e + eps + assert check(lambertw(add(a,d3),-1), -1.075608941186624989414945) + assert check(lambertw(add(a,d10),-1), -1.000023316621036696460620) + assert check(lambertw(add(a,d20),-1), -1.000000000233164398177834) + assert lambertw(add(a,d40),-1) == -1 + assert lambertw(add(a,d80),-1) == -1 + assert lambertw(add(a,d300),-1) == -1 + assert lambertw(add(a,d1000),-1) == -1 + # -1/e - eps + assert check(lambertw(sub(a,d3),-1), -0.99819016149860989001-0.07367191188934638577j) + assert check(lambertw(sub(a,d10),-1), -0.9999999998187812114595992-0.0000233164398140346109194j) + assert check(lambertw(sub(a,d20),-1), -0.99999999999999999998187-2.331643981597124203344e-10j) + assert check(lambertw(sub(a,d40),-1), -1.0-2.33164398159712420336e-20j) + assert check(lambertw(sub(a,d80),-1), -1.0-2.33164398159712420336e-40j) + assert check(lambertw(sub(a,d300),-1), -1.0-2.33164398159712420336e-150j) + assert check(lambertw(sub(a,d1000),-1), mpc(-1,'-2.33164398159712420336e-500')) + # -1/e + eps*j + assert check(lambertw(addj(a,d3),-1), -1.0520914180450129534365906 - 0.0539925638125450525673175j) + assert check(lambertw(addj(a,d10),-1), -1.0000164872127056318529390 - 0.0000164873939271592503983j) + assert check(lambertw(addj(a,d20),-1), -1.0000000001648721270700 - 1.64872127088134693542628e-10j) + assert check(lambertw(addj(a,d40),-1), -1.00000000000000000001648 - 1.6487212707001281468667726e-20j) + assert check(lambertw(addj(a,d80),-1), -1.0 - 1.64872127070012814684865e-40j) + assert check(lambertw(addj(a,d300),-1), -1.0 - 1.64872127070012814684865e-150j) + assert check(lambertw(addj(a,d1000),-1), mpc(-1.0,'-1.64872127070012814684865e-500')) + # -1/e - eps*j + assert check(lambertw(subj(a,d3),-1), -3.088501303219933378005990-7.458676867597474813950098j) + assert check(lambertw(subj(a,d10),-1), -3.088843015579260686911033-7.461489285372968780020716j) + assert check(lambertw(subj(a,d20),-1), -3.088843015613043855953708-7.461489285654254556877988j) + assert check(lambertw(subj(a,d40),-1), -3.088843015613043855957087-7.461489285654254556906117j) + assert check(lambertw(subj(a,d80),-1), -3.088843015613043855957087 - 7.461489285654254556906117j) + assert check(lambertw(subj(a,d300),-1), -3.088843015613043855957087 - 7.461489285654254556906117j) + assert check(lambertw(subj(a,d1000),-1), -3.088843015613043855957087 - 7.461489285654254556906117j) + # One more case, testing higher precision + mp.dps = 500 + x = -1/e + mpf('1e-13') + ans = "-0.99999926266961377166355784455394913638782494543377383"\ + "744978844374498153493943725364881490261187530235150668593869563"\ + "168276697689459394902153960200361935311512317183678882" + mp.dps = 15 + assert lambertw(x).ae(ans) + mp.dps = 50 + assert lambertw(x).ae(ans) + mp.dps = 150 + assert lambertw(x).ae(ans) + +def test_meijerg(): + mp.dps = 15 + assert meijerg([[2,3],[1]],[[0.5,2],[3,4]], 2.5).ae(4.2181028074787439386) + assert meijerg([[],[1+j]],[[1],[1]], 3+4j).ae(271.46290321152464592 - 703.03330399954820169j) + assert meijerg([[0.25],[1]],[[0.5],[2]],0) == 0 + assert meijerg([[0],[]],[[0,0,'1/3','2/3'], []], '2/27').ae(2.2019391389653314120) + # Verify 1/z series being used + assert meijerg([[-3],[-0.5]], [[-1],[-2.5]], -0.5).ae(-1.338096165935754898687431) + assert meijerg([[1-(-1)],[1-(-2.5)]], [[1-(-3)],[1-(-0.5)]], -2.0).ae(-1.338096165935754898687431) + assert meijerg([[-3],[-0.5]], [[-1],[-2.5]], -1).ae(-(pi+4)/(4*pi)) + a = 2.5 + b = 1.25 + for z in [mpf(0.25), mpf(2)]: + x1 = hyp1f1(a,b,z) + x2 = gamma(b)/gamma(a)*meijerg([[1-a],[]],[[0],[1-b]],-z) + x3 = gamma(b)/gamma(a)*meijerg([[1-0],[1-(1-b)]],[[1-(1-a)],[]],-1/z) + assert x1.ae(x2) + assert x1.ae(x3) + +def test_appellf1(): + mp.dps = 15 + assert appellf1(2,-2,1,1,2,3).ae(-1.75) + assert appellf1(2,1,-2,1,2,3).ae(-8) + assert appellf1(2,1,-2,1,0.5,0.25).ae(1.5) + assert appellf1(-2,1,3,2,3,3).ae(19) + assert appellf1(1,2,3,4,0.5,0.125).ae( 1.53843285792549786518) + +def test_coulomb(): + # Note: most tests are doctests + # Test for a bug: + mp.dps = 15 + assert coulombg(mpc(-5,0),2,3).ae(20.087729487721430394) + +def test_hyper_param_accuracy(): + mp.dps = 15 + As = [n+1e-10 for n in range(-5,-1)] + Bs = [n+1e-10 for n in range(-12,-5)] + assert hyper(As,Bs,10).ae(-381757055858.652671927) + assert legenp(0.5, 100, 0.25).ae(-2.4124576567211311755e+144) + assert (hyp1f1(1000,1,-100)*10**24).ae(5.2589445437370169113) + assert (hyp2f1(10, -900, 10.5, 0.99)*10**24).ae(1.9185370579660768203) + assert (hyp2f1(1000,1.5,-3.5,-1.5)*10**385).ae(-2.7367529051334000764) + assert hyp2f1(-5, 10, 3, 0.5, zeroprec=500) == 0 + assert (hyp1f1(-10000, 1000, 100)*10**424).ae(-3.1046080515824859974) + assert (hyp2f1(1000,1.5,-3.5,-0.75,maxterms=100000)*10**231).ae(-4.0534790813913998643) + assert legenp(2, 3, 0.25) == 0 + pytest.raises(ValueError, lambda: hypercomb(lambda a: [([],[],[],[],[a],[-a],0.5)], [3])) + assert hypercomb(lambda a: [([],[],[],[],[a],[-a],0.5)], [3], infprec=200) == inf + assert meijerg([[],[]],[[0,0,0,0],[]],0.1).ae(1.5680822343832351418) + assert (besselk(400,400)*10**94).ae(1.4387057277018550583) + mp.dps = 5 + (hyp1f1(-5000.5, 1500, 100)*10**185).ae(8.5185229673381935522) + (hyp1f1(-5000, 1500, 100)*10**185).ae(9.1501213424563944311) + mp.dps = 15 + (hyp1f1(-5000.5, 1500, 100)*10**185).ae(8.5185229673381935522) + (hyp1f1(-5000, 1500, 100)*10**185).ae(9.1501213424563944311) + assert hyp0f1(fadd(-20,'1e-100',exact=True), 0.25).ae(1.85014429040102783e+49) + assert hyp0f1((-20*10**100+1, 10**100), 0.25).ae(1.85014429040102783e+49) + +def test_hypercomb_zero_pow(): + # check that 0^0 = 1 + assert hypercomb(lambda a: (([0],[a],[],[],[],[],0),), [0]) == 1 + assert meijerg([[-1.5],[]],[[0],[-0.75]],0).ae(1.4464090846320771425) + +def test_spherharm(): + mp.dps = 15 + t = 0.5; r = 0.25 + assert spherharm(0,0,t,r).ae(0.28209479177387814347) + assert spherharm(1,-1,t,r).ae(0.16048941205971996369 - 0.04097967481096344271j) + assert spherharm(1,0,t,r).ae(0.42878904414183579379) + assert spherharm(1,1,t,r).ae(-0.16048941205971996369 - 0.04097967481096344271j) + assert spherharm(2,-2,t,r).ae(0.077915886919031181734 - 0.042565643022253962264j) + assert spherharm(2,-1,t,r).ae(0.31493387233497459884 - 0.08041582001959297689j) + assert spherharm(2,0,t,r).ae(0.41330596756220761898) + assert spherharm(2,1,t,r).ae(-0.31493387233497459884 - 0.08041582001959297689j) + assert spherharm(2,2,t,r).ae(0.077915886919031181734 + 0.042565643022253962264j) + assert spherharm(3,-3,t,r).ae(0.033640236589690881646 - 0.031339125318637082197j) + assert spherharm(3,-2,t,r).ae(0.18091018743101461963 - 0.09883168583167010241j) + assert spherharm(3,-1,t,r).ae(0.42796713930907320351 - 0.10927795157064962317j) + assert spherharm(3,0,t,r).ae(0.27861659336351639787) + assert spherharm(3,1,t,r).ae(-0.42796713930907320351 - 0.10927795157064962317j) + assert spherharm(3,2,t,r).ae(0.18091018743101461963 + 0.09883168583167010241j) + assert spherharm(3,3,t,r).ae(-0.033640236589690881646 - 0.031339125318637082197j) + assert spherharm(0,-1,t,r) == 0 + assert spherharm(0,-2,t,r) == 0 + assert spherharm(0,1,t,r) == 0 + assert spherharm(0,2,t,r) == 0 + assert spherharm(1,2,t,r) == 0 + assert spherharm(1,3,t,r) == 0 + assert spherharm(1,-2,t,r) == 0 + assert spherharm(1,-3,t,r) == 0 + assert spherharm(2,3,t,r) == 0 + assert spherharm(2,4,t,r) == 0 + assert spherharm(2,-3,t,r) == 0 + assert spherharm(2,-4,t,r) == 0 + assert spherharm(3,4.5,0.5,0.25).ae(-22.831053442240790148 + 10.910526059510013757j) + assert spherharm(2+3j, 1-j, 1+j, 3+4j).ae(-2.6582752037810116935 - 1.0909214905642160211j) + assert spherharm(-6,2.5,t,r).ae(0.39383644983851448178 + 0.28414687085358299021j) + assert spherharm(-3.5, 3, 0.5, 0.25).ae(0.014516852987544698924 - 0.015582769591477628495j) + assert spherharm(-3, 3, 0.5, 0.25) == 0 + assert spherharm(-6, 3, 0.5, 0.25).ae(-0.16544349818782275459 - 0.15412657723253924562j) + assert spherharm(-6, 1.5, 0.5, 0.25).ae(0.032208193499767402477 + 0.012678000924063664921j) + assert spherharm(3,0,0,1).ae(0.74635266518023078283) + assert spherharm(3,-2,0,1) == 0 + assert spherharm(3,-2,1,1).ae(-0.16270707338254028971 - 0.35552144137546777097j) + +def test_qfunctions(): + mp.dps = 15 + assert qp(2,3,100).ae('2.7291482267247332183e2391') + +def test_issue_239(): + mp.prec = 150 + x = ldexp(2476979795053773,-52) + assert betainc(206, 385, 0, 0.55, 1).ae('0.99999999999999999999996570910644857895771110649954') + mp.dps = 15 + pytest.raises(ValueError, lambda: hyp2f1(-5,5,0.5,0.5)) + +# Extra stress testing for Bessel functions +# Reference zeros generated with the aid of scipy.special +# jn_zero, jnp_zero, yn_zero, ynp_zero + +V = 15 +M = 15 + +jn_small_zeros = \ +[[2.4048255576957728, + 5.5200781102863106, + 8.6537279129110122, + 11.791534439014282, + 14.930917708487786, + 18.071063967910923, + 21.211636629879259, + 24.352471530749303, + 27.493479132040255, + 30.634606468431975, + 33.775820213573569, + 36.917098353664044, + 40.058425764628239, + 43.19979171317673, + 46.341188371661814], + [3.8317059702075123, + 7.0155866698156188, + 10.173468135062722, + 13.323691936314223, + 16.470630050877633, + 19.615858510468242, + 22.760084380592772, + 25.903672087618383, + 29.046828534916855, + 32.189679910974404, + 35.332307550083865, + 38.474766234771615, + 41.617094212814451, + 44.759318997652822, + 47.901460887185447], + [5.1356223018406826, + 8.4172441403998649, + 11.619841172149059, + 14.795951782351261, + 17.959819494987826, + 21.116997053021846, + 24.270112313573103, + 27.420573549984557, + 30.569204495516397, + 33.7165195092227, + 36.86285651128381, + 40.008446733478192, + 43.153453778371463, + 46.297996677236919, + 49.442164110416873], + [6.3801618959239835, + 9.7610231299816697, + 13.015200721698434, + 16.223466160318768, + 19.409415226435012, + 22.582729593104442, + 25.748166699294978, + 28.908350780921758, + 32.064852407097709, + 35.218670738610115, + 38.370472434756944, + 41.520719670406776, + 44.669743116617253, + 47.817785691533302, + 50.965029906205183], + [7.5883424345038044, + 11.064709488501185, + 14.37253667161759, + 17.615966049804833, + 20.826932956962388, + 24.01901952477111, + 27.199087765981251, + 30.371007667117247, + 33.537137711819223, + 36.699001128744649, + 39.857627302180889, + 43.01373772335443, + 46.167853512924375, + 49.320360686390272, + 52.471551398458023], + [8.771483815959954, + 12.338604197466944, + 15.700174079711671, + 18.980133875179921, + 22.217799896561268, + 25.430341154222704, + 28.626618307291138, + 31.811716724047763, + 34.988781294559295, + 38.159868561967132, + 41.326383254047406, + 44.489319123219673, + 47.649399806697054, + 50.80716520300633, + 53.963026558378149], + [9.9361095242176849, + 13.589290170541217, + 17.003819667816014, + 20.320789213566506, + 23.58608443558139, + 26.820151983411405, + 30.033722386570469, + 33.233041762847123, + 36.422019668258457, + 39.603239416075404, + 42.778481613199507, + 45.949015998042603, + 49.11577372476426, + 52.279453903601052, + 55.440592068853149], + [11.086370019245084, + 14.821268727013171, + 18.287582832481726, + 21.641541019848401, + 24.934927887673022, + 28.191188459483199, + 31.42279419226558, + 34.637089352069324, + 37.838717382853611, + 41.030773691585537, + 44.21540850526126, + 47.394165755570512, + 50.568184679795566, + 53.738325371963291, + 56.905249991978781], + [12.225092264004655, + 16.037774190887709, + 19.554536430997055, + 22.94517313187462, + 26.266814641176644, + 29.54565967099855, + 32.795800037341462, + 36.025615063869571, + 39.240447995178135, + 42.443887743273558, + 45.638444182199141, + 48.825930381553857, + 52.007691456686903, + 55.184747939289049, + 58.357889025269694], + [13.354300477435331, + 17.241220382489128, + 20.807047789264107, + 24.233885257750552, + 27.583748963573006, + 30.885378967696675, + 34.154377923855096, + 37.400099977156589, + 40.628553718964528, + 43.843801420337347, + 47.048700737654032, + 50.245326955305383, + 53.435227157042058, + 56.619580266508436, + 59.799301630960228], + [14.475500686554541, + 18.433463666966583, + 22.046985364697802, + 25.509450554182826, + 28.887375063530457, + 32.211856199712731, + 35.499909205373851, + 38.761807017881651, + 42.004190236671805, + 45.231574103535045, + 48.447151387269394, + 51.653251668165858, + 54.851619075963349, + 58.043587928232478, + 61.230197977292681], + [15.589847884455485, + 19.61596690396692, + 23.275853726263409, + 26.773322545509539, + 30.17906117878486, + 33.526364075588624, + 36.833571341894905, + 40.111823270954241, + 43.368360947521711, + 46.608132676274944, + 49.834653510396724, + 53.050498959135054, + 56.257604715114484, + 59.457456908388002, + 62.651217388202912], + [16.698249933848246, + 20.789906360078443, + 24.494885043881354, + 28.026709949973129, + 31.45996003531804, + 34.829986990290238, + 38.156377504681354, + 41.451092307939681, + 44.721943543191147, + 47.974293531269048, + 51.211967004101068, + 54.437776928325074, + 57.653844811906946, + 60.8618046824805, + 64.062937824850136], + [17.801435153282442, + 21.95624406783631, + 25.705103053924724, + 29.270630441874802, + 32.731053310978403, + 36.123657666448762, + 39.469206825243883, + 42.780439265447158, + 46.06571091157561, + 49.330780096443524, + 52.579769064383396, + 55.815719876305778, + 59.040934037249271, + 62.257189393731728, + 65.465883797232125], + [18.899997953174024, + 23.115778347252756, + 26.907368976182104, + 30.505950163896036, + 33.993184984781542, + 37.408185128639695, + 40.772827853501868, + 44.100590565798301, + 47.400347780543231, + 50.678236946479898, + 53.93866620912693, + 57.184898598119301, + 60.419409852130297, + 63.644117508962281, + 66.860533012260103]] + +jnp_small_zeros = \ +[[0.0, + 3.8317059702075123, + 7.0155866698156188, + 10.173468135062722, + 13.323691936314223, + 16.470630050877633, + 19.615858510468242, + 22.760084380592772, + 25.903672087618383, + 29.046828534916855, + 32.189679910974404, + 35.332307550083865, + 38.474766234771615, + 41.617094212814451, + 44.759318997652822], + [1.8411837813406593, + 5.3314427735250326, + 8.5363163663462858, + 11.706004902592064, + 14.863588633909033, + 18.015527862681804, + 21.16436985918879, + 24.311326857210776, + 27.457050571059246, + 30.601922972669094, + 33.746182898667383, + 36.889987409236811, + 40.033444053350675, + 43.176628965448822, + 46.319597561173912], + [3.0542369282271403, + 6.7061331941584591, + 9.9694678230875958, + 13.170370856016123, + 16.347522318321783, + 19.512912782488205, + 22.671581772477426, + 25.826037141785263, + 28.977672772993679, + 32.127327020443474, + 35.275535050674691, + 38.422654817555906, + 41.568934936074314, + 44.714553532819734, + 47.859641607992093], + [4.2011889412105285, + 8.0152365983759522, + 11.345924310743006, + 14.585848286167028, + 17.78874786606647, + 20.9724769365377, + 24.144897432909265, + 27.310057930204349, + 30.470268806290424, + 33.626949182796679, + 36.781020675464386, + 39.933108623659488, + 43.083652662375079, + 46.232971081836478, + 49.381300092370349], + [5.3175531260839944, + 9.2823962852416123, + 12.681908442638891, + 15.964107037731551, + 19.196028800048905, + 22.401032267689004, + 25.589759681386733, + 28.767836217666503, + 31.938539340972783, + 35.103916677346764, + 38.265316987088158, + 41.423666498500732, + 44.579623137359257, + 47.733667523865744, + 50.886159153182682], + [6.4156163757002403, + 10.519860873772308, + 13.9871886301403, + 17.312842487884625, + 20.575514521386888, + 23.803581476593863, + 27.01030789777772, + 30.20284907898166, + 33.385443901010121, + 36.560777686880356, + 39.730640230067416, + 42.896273163494417, + 46.058566273567043, + 49.218174614666636, + 52.375591529563596], + [7.501266144684147, + 11.734935953042708, + 15.268181461097873, + 18.637443009666202, + 21.931715017802236, + 25.183925599499626, + 28.409776362510085, + 31.617875716105035, + 34.81339298429743, + 37.999640897715301, + 41.178849474321413, + 44.352579199070217, + 47.521956905768113, + 50.687817781723741, + 53.85079463676896], + [8.5778364897140741, + 12.932386237089576, + 16.529365884366944, + 19.941853366527342, + 23.268052926457571, + 26.545032061823576, + 29.790748583196614, + 33.015178641375142, + 36.224380548787162, + 39.422274578939259, + 42.611522172286684, + 45.793999658055002, + 48.971070951900596, + 52.143752969301988, + 55.312820330403446], + [9.6474216519972168, + 14.115518907894618, + 17.774012366915256, + 21.229062622853124, + 24.587197486317681, + 27.889269427955092, + 31.155326556188325, + 34.39662855427218, + 37.620078044197086, + 40.830178681822041, + 44.030010337966153, + 47.221758471887113, + 50.407020967034367, + 53.586995435398319, + 56.762598475105272], + [10.711433970699945, + 15.28673766733295, + 19.004593537946053, + 22.501398726777283, + 25.891277276839136, + 29.218563499936081, + 32.505247352375523, + 35.763792928808799, + 39.001902811514218, + 42.224638430753279, + 45.435483097475542, + 48.636922645305525, + 51.830783925834728, + 55.01844255063594, + 58.200955824859509], + [11.770876674955582, + 16.447852748486498, + 20.223031412681701, + 23.760715860327448, + 27.182021527190532, + 30.534504754007074, + 33.841965775135715, + 37.118000423665604, + 40.371068905333891, + 43.606764901379516, + 46.828959446564562, + 50.040428970943456, + 53.243223214220535, + 56.438892058982552, + 59.628631306921512], + [12.826491228033465, + 17.600266557468326, + 21.430854238060294, + 25.008518704644261, + 28.460857279654847, + 31.838424458616998, + 35.166714427392629, + 38.460388720328256, + 41.728625562624312, + 44.977526250903469, + 48.211333836373288, + 51.433105171422278, + 54.645106240447105, + 57.849056857839799, + 61.046288512821078], + [13.878843069697276, + 18.745090916814406, + 22.629300302835503, + 26.246047773946584, + 29.72897816891134, + 33.131449953571661, + 36.480548302231658, + 39.791940718940855, + 43.075486800191012, + 46.337772104541405, + 49.583396417633095, + 52.815686826850452, + 56.037118687012179, + 59.249577075517968, + 62.454525995970462], + [14.928374492964716, + 19.88322436109951, + 23.81938909003628, + 27.474339750968247, + 30.987394331665278, + 34.414545662167183, + 37.784378506209499, + 41.113512376883377, + 44.412454519229281, + 47.688252845993366, + 50.945849245830813, + 54.188831071035124, + 57.419876154678179, + 60.641030026538746, + 63.853885828967512], + [15.975438807484321, + 21.015404934568315, + 25.001971500138194, + 28.694271223110755, + 32.236969407878118, + 35.688544091185301, + 39.078998185245057, + 42.425854432866141, + 45.740236776624833, + 49.029635055514276, + 52.299319390331728, + 55.553127779547459, + 58.793933759028134, + 62.02393848337554, + 65.244860767043859]] + +yn_small_zeros = \ +[[0.89357696627916752, + 3.9576784193148579, + 7.0860510603017727, + 10.222345043496417, + 13.361097473872763, + 16.500922441528091, + 19.64130970088794, + 22.782028047291559, + 25.922957653180923, + 29.064030252728398, + 32.205204116493281, + 35.346452305214321, + 38.487756653081537, + 41.629104466213808, + 44.770486607221993], + [2.197141326031017, + 5.4296810407941351, + 8.5960058683311689, + 11.749154830839881, + 14.897442128336725, + 18.043402276727856, + 21.188068934142213, + 24.331942571356912, + 27.475294980449224, + 30.618286491641115, + 33.761017796109326, + 36.90355531614295, + 40.045944640266876, + 43.188218097393211, + 46.330399250701687], + [3.3842417671495935, + 6.7938075132682675, + 10.023477979360038, + 13.209986710206416, + 16.378966558947457, + 19.539039990286384, + 22.69395593890929, + 25.845613720902269, + 28.995080395650151, + 32.143002257627551, + 35.289793869635804, + 38.435733485446343, + 41.581014867297885, + 44.725777117640461, + 47.870122696676504], + [4.5270246611496439, + 8.0975537628604907, + 11.396466739595867, + 14.623077742393873, + 17.81845523294552, + 20.997284754187761, + 24.166235758581828, + 27.328799850405162, + 30.486989604098659, + 33.642049384702463, + 36.794791029185579, + 39.945767226378749, + 43.095367507846703, + 46.2438744334407, + 49.391498015725107], + [5.6451478942208959, + 9.3616206152445429, + 12.730144474090465, + 15.999627085382479, + 19.22442895931681, + 22.424810599698521, + 25.610267054939328, + 28.785893657666548, + 31.954686680031668, + 35.118529525584828, + 38.278668089521758, + 41.435960629910073, + 44.591018225353424, + 47.744288086361052, + 50.896105199722123], + [6.7471838248710219, + 10.597176726782031, + 14.033804104911233, + 17.347086393228382, + 20.602899017175335, + 23.826536030287532, + 27.030134937138834, + 30.220335654231385, + 33.401105611047908, + 36.574972486670962, + 39.743627733020277, + 42.908248189569535, + 46.069679073215439, + 49.228543693445843, + 52.385312123112282], + [7.8377378223268716, + 11.811037107609447, + 15.313615118517857, + 18.670704965906724, + 21.958290897126571, + 25.206207715021249, + 28.429037095235496, + 31.634879502950644, + 34.828638524084437, + 38.013473399691765, + 41.19151880917741, + 44.364272633271975, + 47.53281875312084, + 50.697961822183806, + 53.860312300118388], + [8.919605734873789, + 13.007711435388313, + 16.573915129085334, + 19.974342312352426, + 23.293972585596648, + 26.5667563757203, + 29.809531451608321, + 33.031769327150685, + 36.239265816598239, + 39.435790312675323, + 42.623910919472727, + 45.805442883111651, + 48.981708325514764, + 52.153694518185572, + 55.322154420959698], + [9.9946283820824834, + 14.190361295800141, + 17.817887841179873, + 21.26093227125945, + 24.612576377421522, + 27.910524883974868, + 31.173701563441602, + 34.412862242025045, + 37.634648706110989, + 40.843415321050884, + 44.04214994542435, + 47.232978012841169, + 50.417456447370186, + 53.596753874948731, + 56.771765754432457], + [11.064090256031013, + 15.361301343575925, + 19.047949646361388, + 22.532765416313869, + 25.91620496332662, + 29.2394205079349, + 32.523270869465881, + 35.779715464475261, + 39.016196664616095, + 42.237627509803703, + 45.4474001519274, + 48.647941127433196, + 51.841036928216499, + 55.028034667184916, + 58.209970905250097], + [12.128927704415439, + 16.522284394784426, + 20.265984501212254, + 23.791669719454272, + 27.206568881574774, + 30.555020011020762, + 33.859683872746356, + 37.133649760307504, + 40.385117593813002, + 43.619533085646856, + 46.840676630553575, + 50.051265851897857, + 53.253310556711732, + 56.448332488918971, + 59.637507005589829], + [13.189846995683845, + 17.674674253171487, + 21.473493977824902, + 25.03913093040942, + 28.485081336558058, + 31.858644293774859, + 35.184165245422787, + 38.475796636190897, + 41.742455848758449, + 44.990096293791186, + 48.222870660068338, + 51.443777308699826, + 54.655042589416311, + 57.858358441436511, + 61.055036135780528], + [14.247395665073945, + 18.819555894710682, + 22.671697117872794, + 26.276375544903892, + 29.752925495549038, + 33.151412708998983, + 36.497763772987645, + 39.807134090704376, + 43.089121522203808, + 46.350163579538652, + 49.594769786270069, + 52.82620892320143, + 56.046916910756961, + 59.258751140598783, + 62.463155567737854], + [15.30200785858925, + 19.957808654258601, + 23.861599172945054, + 27.504429642227545, + 31.011103429019229, + 34.434283425782942, + 37.801385632318459, + 41.128514139788358, + 44.425913324440663, + 47.700482714581842, + 50.957073905278458, + 54.199216028087261, + 57.429547607017405, + 60.65008661807661, + 63.862406280068586], + [16.354034360047551, + 21.090156519983806, + 25.044040298785627, + 28.724161640881914, + 32.260472459522644, + 35.708083982611664, + 39.095820003878235, + 42.440684315990936, + 45.75353669045622, + 49.041718113283529, + 52.310408280968073, + 55.56338698149062, + 58.803488508906895, + 62.032886550960831, + 65.253280088312461]] + +ynp_small_zeros = \ +[[2.197141326031017, + 5.4296810407941351, + 8.5960058683311689, + 11.749154830839881, + 14.897442128336725, + 18.043402276727856, + 21.188068934142213, + 24.331942571356912, + 27.475294980449224, + 30.618286491641115, + 33.761017796109326, + 36.90355531614295, + 40.045944640266876, + 43.188218097393211, + 46.330399250701687], + [3.6830228565851777, + 6.9414999536541757, + 10.123404655436613, + 13.285758156782854, + 16.440058007293282, + 19.590241756629495, + 22.738034717396327, + 25.884314618788867, + 29.029575819372535, + 32.174118233366201, + 35.318134458192094, + 38.461753870997549, + 41.605066618873108, + 44.74813744908079, + 47.891014070791065], + [5.0025829314460639, + 8.3507247014130795, + 11.574195465217647, + 14.760909306207676, + 17.931285939466855, + 21.092894504412739, + 24.249231678519058, + 27.402145837145258, + 30.552708880564553, + 33.70158627151572, + 36.849213419846257, + 39.995887376143356, + 43.141817835750686, + 46.287157097544201, + 49.432018469138281], + [6.2536332084598136, + 9.6987879841487711, + 12.972409052292216, + 16.19044719506921, + 19.38238844973613, + 22.559791857764261, + 25.728213194724094, + 28.890678419054777, + 32.048984005266337, + 35.204266606440635, + 38.357281675961019, + 41.508551443818436, + 44.658448731963676, + 47.807246956681162, + 50.95515126455207], + [7.4649217367571329, + 11.005169149809189, + 14.3317235192331, + 17.58443601710272, + 20.801062338411128, + 23.997004122902644, + 27.179886689853435, + 30.353960608554323, + 33.521797098666792, + 36.685048382072301, + 39.844826969405863, + 43.001910515625288, + 46.15685955107263, + 49.310088614282257, + 52.461911043685864], + [8.6495562436971983, + 12.280868725807848, + 15.660799304540377, + 18.949739756016503, + 22.192841809428241, + 25.409072788867674, + 28.608039283077593, + 31.795195353138159, + 34.973890634255288, + 38.14630522169358, + 41.313923188794905, + 44.477791768537617, + 47.638672065035628, + 50.797131066967842, + 53.953600129601663], + [9.8147970120105779, + 13.532811875789828, + 16.965526446046053, + 20.291285512443867, + 23.56186260680065, + 26.799499736027237, + 30.015665481543419, + 33.216968050039509, + 36.407516858984748, + 39.590015243560459, + 42.766320595957378, + 45.937754257017323, + 49.105283450953203, + 52.269633324547373, + 55.431358715604255], + [10.965152105242974, + 14.765687379508912, + 18.250123150217555, + 21.612750053384621, + 24.911310600813573, + 28.171051927637585, + 31.40518108895689, + 34.621401012564177, + 37.824552065973114, + 41.017847386464902, + 44.203512240871601, + 47.3831408366063, + 50.557907466622796, + 53.728697478957026, + 56.896191727313342], + [12.103641941939539, + 15.982840905145284, + 19.517731005559611, + 22.916962141504605, + 26.243700855690533, + 29.525960140695407, + 32.778568197561124, + 36.010261572392516, + 39.226578757802172, + 42.43122493258747, + 45.626783824134354, + 48.815117837929515, + 51.997606404328863, + 55.175294723956816, + 58.348990221754937], + [13.232403808592215, + 17.186756572616758, + 20.770762917490496, + 24.206152448722253, + 27.561059462697153, + 30.866053571250639, + 34.137476603379774, + 37.385039772270268, + 40.614946085165892, + 43.831373184731238, + 47.037251786726299, + 50.234705848765229, + 53.425316228549359, + 56.610286079882087, + 59.790548623216652], + [14.35301374369987, + 18.379337301642568, + 22.011118775283494, + 25.482116178696707, + 28.865046588695164, + 32.192853922166294, + 35.483296655830277, + 38.747005493021857, + 41.990815194320955, + 45.219355876831731, + 48.435892856078888, + 51.642803925173029, + 54.84186659475857, + 58.034439083840155, + 61.221578745109862], + [15.466672066554263, + 19.562077985759503, + 23.240325531101082, + 26.746322986645901, + 30.157042415639891, + 33.507642948240263, + 36.817212798512775, + 40.097251300178642, + 43.355193847719752, + 46.596103410173672, + 49.823567279972794, + 53.040208868780832, + 56.247996968470062, + 59.448441365714251, + 62.642721301357187], + [16.574317035530872, + 20.73617763753932, + 24.459631728238804, + 27.999993668839644, + 31.438208790267783, + 34.811512070805535, + 38.140243708611251, + 41.436725143893739, + 44.708963264433333, + 47.962435051891027, + 51.201037321915983, + 54.427630745992975, + 57.644369734615238, + 60.852911791989989, + 64.054555435720397], + [17.676697936439624, + 21.9026148697762, + 25.670073356263225, + 29.244155124266438, + 32.709534477396028, + 36.105399554497548, + 39.453272918267025, + 42.766255701958017, + 46.052899215578358, + 49.319076602061401, + 52.568982147952547, + 55.805705507386287, + 59.031580956740466, + 62.248409689597653, + 65.457606670836759], + [18.774423978290318, + 23.06220035979272, + 26.872520985976736, + 30.479680663499762, + 33.971869047372436, + 37.390118854896324, + 40.757072537673599, + 44.086572292170345, + 47.387688809191869, + 50.66667461073936, + 53.928009929563275, + 57.175005343085052, + 60.410169281219877, + 63.635442539153021, + 66.85235358587768]] + +@pytest.mark.slow +def test_bessel_zeros_extra(): + mp.dps = 15 + for v in range(V): + for m in range(1,M+1): + print(v, m, "of", V, M) + # Twice to test cache (if used) + assert besseljzero(v,m).ae(jn_small_zeros[v][m-1]) + assert besseljzero(v,m).ae(jn_small_zeros[v][m-1]) + assert besseljzero(v,m,1).ae(jnp_small_zeros[v][m-1]) + assert besseljzero(v,m,1).ae(jnp_small_zeros[v][m-1]) + assert besselyzero(v,m).ae(yn_small_zeros[v][m-1]) + assert besselyzero(v,m).ae(yn_small_zeros[v][m-1]) + assert besselyzero(v,m,1).ae(ynp_small_zeros[v][m-1]) + assert besselyzero(v,m,1).ae(ynp_small_zeros[v][m-1]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_ode.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_ode.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6dbffa79cfd4ca6dbf14f8591296ee48b16682 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_ode.py @@ -0,0 +1,73 @@ +#from mpmath.calculus import ODE_step_euler, ODE_step_rk4, odeint, arange +from mpmath import odefun, cos, sin, mpf, sinc, mp + +''' +solvers = [ODE_step_euler, ODE_step_rk4] + +def test_ode1(): + """ + Let's solve: + + x'' + w**2 * x = 0 + + i.e. x1 = x, x2 = x1': + + x1' = x2 + x2' = -x1 + """ + def derivs((x1, x2), t): + return x2, -x1 + + for solver in solvers: + t = arange(0, 3.1415926, 0.005) + sol = odeint(derivs, (0., 1.), t, solver) + x1 = [a[0] for a in sol] + x2 = [a[1] for a in sol] + # the result is x1 = sin(t), x2 = cos(t) + # let's just check the end points for t = pi + assert abs(x1[-1]) < 1e-2 + assert abs(x2[-1] - (-1)) < 1e-2 + +def test_ode2(): + """ + Let's solve: + + x' - x = 0 + + i.e. x = exp(x) + + """ + def derivs((x), t): + return x + + for solver in solvers: + t = arange(0, 1, 1e-3) + sol = odeint(derivs, (1.,), t, solver) + x = [a[0] for a in sol] + # the result is x = exp(t) + # let's just check the end point for t = 1, i.e. x = e + assert abs(x[-1] - 2.718281828) < 1e-2 +''' + +def test_odefun_rational(): + mp.dps = 15 + # A rational function + f = lambda t: 1/(1+mpf(t)**2) + g = odefun(lambda x, y: [-2*x*y[0]**2], 0, [f(0)]) + assert f(2).ae(g(2)[0]) + +def test_odefun_sinc_large(): + mp.dps = 15 + # Sinc function; test for large x + f = sinc + g = odefun(lambda x, y: [(cos(x)-y[0])/x], 1, [f(1)], tol=0.01, degree=5) + assert abs(f(100) - g(100)[0])/f(100) < 0.01 + +def test_odefun_harmonic(): + mp.dps = 15 + # Harmonic oscillator + f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0]) + for x in [0, 1, 2.5, 8, 3.7]: # we go back to 3.7 to check caching + c, s = f(x) + assert c.ae(cos(x)) + assert s.ae(sin(x)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_quad.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_quad.py new file mode 100644 index 0000000000000000000000000000000000000000..fc71c5f5ef9c0ecd876c988e7d033b321f065cdc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_quad.py @@ -0,0 +1,95 @@ +import pytest +from mpmath import * + +def ae(a, b): + return abs(a-b) < 10**(-mp.dps+5) + +def test_basic_integrals(): + for prec in [15, 30, 100]: + mp.dps = prec + assert ae(quadts(lambda x: x**3 - 3*x**2, [-2, 4]), -12) + assert ae(quadgl(lambda x: x**3 - 3*x**2, [-2, 4]), -12) + assert ae(quadts(sin, [0, pi]), 2) + assert ae(quadts(sin, [0, 2*pi]), 0) + assert ae(quadts(exp, [-inf, -1]), 1/e) + assert ae(quadts(lambda x: exp(-x), [0, inf]), 1) + assert ae(quadts(lambda x: exp(-x*x), [-inf, inf]), sqrt(pi)) + assert ae(quadts(lambda x: 1/(1+x*x), [-1, 1]), pi/2) + assert ae(quadts(lambda x: 1/(1+x*x), [-inf, inf]), pi) + assert ae(quadts(lambda x: 2*sqrt(1-x*x), [-1, 1]), pi) + mp.dps = 15 + +def test_multiple_intervals(): + y,err = quad(lambda x: sign(x), [-0.5, 0.9, 1], maxdegree=2, error=True) + assert abs(y-0.5) < 2*err + +def test_quad_symmetry(): + assert quadts(sin, [-1, 1]) == 0 + assert quadgl(sin, [-1, 1]) == 0 + +def test_quad_infinite_mirror(): + # Check mirrored infinite interval + assert ae(quad(lambda x: exp(-x*x), [inf,-inf]), -sqrt(pi)) + assert ae(quad(lambda x: exp(x), [0,-inf]), -1) + +def test_quadgl_linear(): + assert quadgl(lambda x: x, [0, 1], maxdegree=1).ae(0.5) + +def test_complex_integration(): + assert quadts(lambda x: x, [0, 1+j]).ae(j) + +def test_quadosc(): + mp.dps = 15 + assert quadosc(lambda x: sin(x)/x, [0, inf], period=2*pi).ae(pi/2) + +# Double integrals +def test_double_trivial(): + assert ae(quadts(lambda x, y: x, [0, 1], [0, 1]), 0.5) + assert ae(quadts(lambda x, y: x, [-1, 1], [-1, 1]), 0.0) + +def test_double_1(): + assert ae(quadts(lambda x, y: cos(x+y/2), [-pi/2, pi/2], [0, pi]), 4) + +def test_double_2(): + assert ae(quadts(lambda x, y: (x-1)/((1-x*y)*log(x*y)), [0, 1], [0, 1]), euler) + +def test_double_3(): + assert ae(quadts(lambda x, y: 1/sqrt(1+x*x+y*y), [-1, 1], [-1, 1]), 4*log(2+sqrt(3))-2*pi/3) + +def test_double_4(): + assert ae(quadts(lambda x, y: 1/(1-x*x * y*y), [0, 1], [0, 1]), pi**2 / 8) + +def test_double_5(): + assert ae(quadts(lambda x, y: 1/(1-x*y), [0, 1], [0, 1]), pi**2 / 6) + +def test_double_6(): + assert ae(quadts(lambda x, y: exp(-(x+y)), [0, inf], [0, inf]), 1) + +def test_double_7(): + assert ae(quadts(lambda x, y: exp(-x*x-y*y), [-inf, inf], [-inf, inf]), pi) + + +# Test integrals from "Experimentation in Mathematics" by Borwein, +# Bailey & Girgensohn +def test_expmath_integrals(): + for prec in [15, 30, 50]: + mp.dps = prec + assert ae(quadts(lambda x: x/sinh(x), [0, inf]), pi**2 / 4) + assert ae(quadts(lambda x: log(x)**2 / (1+x**2), [0, inf]), pi**3 / 8) + assert ae(quadts(lambda x: (1+x**2)/(1+x**4), [0, inf]), pi/sqrt(2)) + assert ae(quadts(lambda x: log(x)/cosh(x)**2, [0, inf]), log(pi)-2*log(2)-euler) + assert ae(quadts(lambda x: log(1+x**3)/(1-x+x**2), [0, inf]), 2*pi*log(3)/sqrt(3)) + assert ae(quadts(lambda x: log(x)**2 / (x**2+x+1), [0, 1]), 8*pi**3 / (81*sqrt(3))) + assert ae(quadts(lambda x: log(cos(x))**2, [0, pi/2]), pi/2 * (log(2)**2+pi**2/12)) + assert ae(quadts(lambda x: x**2 / sin(x)**2, [0, pi/2]), pi*log(2)) + assert ae(quadts(lambda x: x**2/sqrt(exp(x)-1), [0, inf]), 4*pi*(log(2)**2 + pi**2/12)) + assert ae(quadts(lambda x: x*exp(-x)*sqrt(1-exp(-2*x)), [0, inf]), pi*(1+2*log(2))/8) + mp.dps = 15 + +# Do not reach full accuracy +@pytest.mark.xfail +def test_expmath_fail(): + assert ae(quadts(lambda x: sqrt(tan(x)), [0, pi/2]), pi*sqrt(2)/2) + assert ae(quadts(lambda x: atan(x)/(x*sqrt(1-x**2)), [0, 1]), pi*log(1+sqrt(2))/2) + assert ae(quadts(lambda x: log(1+x**2)/x**2, [0, 1]), pi/2-log(2)) + assert ae(quadts(lambda x: x**2/((1+x**4)*sqrt(1-x**4)), [0, 1]), pi/8) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_rootfinding.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_rootfinding.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3c06463682eb1fd60efeb75b809bbb932a241c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_rootfinding.py @@ -0,0 +1,91 @@ +import pytest +from mpmath import * +from mpmath.calculus.optimization import Secant, Muller, Bisection, Illinois, \ + Pegasus, Anderson, Ridder, ANewton, Newton, MNewton, MDNewton + +def test_findroot(): + # old tests, assuming secant + mp.dps = 15 + assert findroot(lambda x: 4*x-3, mpf(5)).ae(0.75) + assert findroot(sin, mpf(3)).ae(pi) + assert findroot(sin, (mpf(3), mpf(3.14))).ae(pi) + assert findroot(lambda x: x*x+1, mpc(2+2j)).ae(1j) + # test all solvers with 1 starting point + f = lambda x: cos(x) + for solver in [Newton, Secant, MNewton, Muller, ANewton]: + x = findroot(f, 2., solver=solver) + assert abs(f(x)) < eps + # test all solvers with interval of 2 points + for solver in [Secant, Muller, Bisection, Illinois, Pegasus, Anderson, + Ridder]: + x = findroot(f, (1., 2.), solver=solver) + assert abs(f(x)) < eps + # test types + f = lambda x: (x - 2)**2 + + assert isinstance(findroot(f, 1, tol=1e-10), mpf) + assert isinstance(iv.findroot(f, 1., tol=1e-10), iv.mpf) + assert isinstance(fp.findroot(f, 1, tol=1e-10), float) + assert isinstance(fp.findroot(f, 1+0j, tol=1e-10), complex) + + # issue 401 + with pytest.raises(ValueError): + with workprec(2): + findroot(lambda x: x**2 - 4456178*x + 60372201703370, + mpc(real='5.278e+13', imag='-5.278e+13')) + + # issue 192 + with pytest.raises(ValueError): + findroot(lambda x: -1, 0) + + # issue 387 + with pytest.raises(ValueError): + findroot(lambda p: (1 - p)**30 - 1, 0.9) + +def test_bisection(): + # issue 273 + assert findroot(lambda x: x**2-1,(0,2),solver='bisect') == 1 + +def test_mnewton(): + f = lambda x: polyval([1,3,3,1],x) + x = findroot(f, -0.9, solver='mnewton') + assert abs(f(x)) < eps + +def test_anewton(): + f = lambda x: (x - 2)**100 + x = findroot(f, 1., solver=ANewton) + assert abs(f(x)) < eps + +def test_muller(): + f = lambda x: (2 + x)**3 + 2 + x = findroot(f, 1., solver=Muller) + assert abs(f(x)) < eps + +def test_multiplicity(): + for i in range(1, 5): + assert multiplicity(lambda x: (x - 1)**i, 1) == i + assert multiplicity(lambda x: x**2, 1) == 0 + +def test_multidimensional(): + def f(*x): + return [3*x[0]**2-2*x[1]**2-1, x[0]**2-2*x[0]+x[1]**2+2*x[1]-8] + assert mnorm(jacobian(f, (1,-2)) - matrix([[6,8],[0,-2]]),1) < 1.e-7 + for x, error in MDNewton(mp, f, (1,-2), verbose=0, + norm=lambda x: norm(x, inf)): + pass + assert norm(f(*x), 2) < 1e-14 + # The Chinese mathematician Zhu Shijie was the very first to solve this + # nonlinear system 700 years ago + f1 = lambda x, y: -x + 2*y + f2 = lambda x, y: (x**2 + x*(y**2 - 2) - 4*y) / (x + 4) + f3 = lambda x, y: sqrt(x**2 + y**2) + def f(x, y): + f1x = f1(x, y) + return (f2(x, y) - f1x, f3(x, y) - f1x) + x = findroot(f, (10, 10)) + assert [int(round(i)) for i in x] == [3, 4] + +def test_trivial(): + assert findroot(lambda x: 0, 1) == 1 + assert findroot(lambda x: x, 0) == 0 + #assert findroot(lambda x, y: x + y, (1, -1)) == (1, -1) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverDn.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverDn.h new file mode 100644 index 0000000000000000000000000000000000000000..ad79838d00753056d131cad719ddc1d1ff4e9561 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverDn.h @@ -0,0 +1,4874 @@ +/* + * Copyright 2014 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +/* cuSolverDN : Dense Linear Algebra Library + +*/ + +#if !defined(CUSOLVERDN_H_) + #define CUSOLVERDN_H_ + +struct cusolverDnContext; +typedef struct cusolverDnContext *cusolverDnHandle_t; + +struct syevjInfo; +typedef struct syevjInfo *syevjInfo_t; + +struct gesvdjInfo; +typedef struct gesvdjInfo *gesvdjInfo_t; + +//------------------------------------------------------ +// opaque cusolverDnIRS structure for IRS solver +struct cusolverDnIRSParams; +typedef struct cusolverDnIRSParams *cusolverDnIRSParams_t; + +struct cusolverDnIRSInfos; +typedef struct cusolverDnIRSInfos *cusolverDnIRSInfos_t; +//------------------------------------------------------ + +struct cusolverDnParams; +typedef struct cusolverDnParams *cusolverDnParams_t; + +typedef enum { + CUSOLVERDN_GETRF = 0, + CUSOLVERDN_POTRF = 1 +} cusolverDnFunction_t; + + #include + + #include "cuComplex.h" /* import complex data type */ + #include "cublas_v2.h" + #include "cusolver_common.h" + + /*******************************************************************************/ + #ifdef __cplusplus +extern "C" { + #endif + + cusolverStatus_t CUSOLVERAPI cusolverDnCreate(cusolverDnHandle_t *handle); + cusolverStatus_t CUSOLVERAPI cusolverDnDestroy(cusolverDnHandle_t handle); + cusolverStatus_t CUSOLVERAPI + cusolverDnSetStream(cusolverDnHandle_t handle, cudaStream_t streamId); + cusolverStatus_t CUSOLVERAPI + cusolverDnGetStream(cusolverDnHandle_t handle, cudaStream_t *streamId); + + //============================================================ + // IRS headers + //============================================================ + + // ============================================================================= + // IRS helper function API + // ============================================================================= + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSParamsCreate(cusolverDnIRSParams_t *params_ptr); + + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSParamsDestroy(cusolverDnIRSParams_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetRefinementSolver( + cusolverDnIRSParams_t params, + cusolverIRSRefinement_t refinement_solver); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverMainPrecision( + cusolverDnIRSParams_t params, + cusolverPrecType_t solver_main_precision); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverLowestPrecision( + cusolverDnIRSParams_t params, + cusolverPrecType_t solver_lowest_precision); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverPrecisions( + cusolverDnIRSParams_t params, + cusolverPrecType_t solver_main_precision, + cusolverPrecType_t solver_lowest_precision); + + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSParamsSetTol(cusolverDnIRSParams_t params, double val); + + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSParamsSetTolInner(cusolverDnIRSParams_t params, double val); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetMaxIters( + cusolverDnIRSParams_t params, + cusolver_int_t maxiters); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetMaxItersInner( + cusolverDnIRSParams_t params, + cusolver_int_t maxiters_inner); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsGetMaxIters( + cusolverDnIRSParams_t params, + cusolver_int_t * maxiters); + + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSParamsEnableFallback(cusolverDnIRSParams_t params); + + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSParamsDisableFallback(cusolverDnIRSParams_t params); + + // ============================================================================= + // cusolverDnIRSInfos prototypes + // ============================================================================= + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSInfosDestroy(cusolverDnIRSInfos_t infos); + + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSInfosCreate(cusolverDnIRSInfos_t *infos_ptr); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetNiters( + cusolverDnIRSInfos_t infos, + cusolver_int_t * niters); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetOuterNiters( + cusolverDnIRSInfos_t infos, + cusolver_int_t * outer_niters); + + cusolverStatus_t CUSOLVERAPI + cusolverDnIRSInfosRequestResidual(cusolverDnIRSInfos_t infos); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetResidualHistory( + cusolverDnIRSInfos_t infos, + void ** residual_history); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetMaxIters( + cusolverDnIRSInfos_t infos, + cusolver_int_t * maxiters); + + //============================================================ + // IRS functions API + //============================================================ + + /*******************************************************************************/ /* + * [ZZ, ZC, ZK, ZE, ZY, CC, CK, CE, CY, DD, DS, DH, DB, DX, SS, SH, SB, SX]gesv + * users API Prototypes */ + /*******************************************************************************/ + cusolverStatus_t CUSOLVERAPI cusolverDnZZgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZCgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZKgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZEgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZYgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCCgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCEgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCKgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCYgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDDgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDSgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDHgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDBgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDXgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnSSgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnSHgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnSBgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnSXgesv( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + /*******************************************************************************/ + + /*******************************************************************************/ /* + * [ZZ, ZC, ZK, ZE, ZY, CC, CK, CE, CY, DD, DS, DH, DB, DX, SS, SH, SB, SX]gesv_bufferSize + * users API Prototypes */ + /*******************************************************************************/ + cusolverStatus_t CUSOLVERAPI cusolverDnZZgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnZCgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnZKgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnZEgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnZYgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnCCgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnCKgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnCEgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnCYgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDDgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDSgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDHgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDBgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDXgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnSSgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnSHgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnSBgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnSXgesv_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + cusolver_int_t * dipiv, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + /*******************************************************************************/ + + /*******************************************************************************/ /* + * [ZZ, ZC, ZK, ZE, ZY, CC, CK, CE, CY, DD, DS, DH, DB, DX, SS, SH, SB, SX]gels + * users API Prototypes */ + /*******************************************************************************/ + cusolverStatus_t CUSOLVERAPI cusolverDnZZgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZCgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZKgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZEgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZYgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCCgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCKgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCEgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCYgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDDgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDSgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDHgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDBgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDXgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnSSgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnSHgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnSBgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnSXgels( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * iter, + cusolver_int_t * d_info); + /*******************************************************************************/ + + /*******************************************************************************/ /* + * [ZZ, ZC, ZK, ZE, ZY, CC, CK, CE, CY, DD, DS, DH, DB, DX, SS, SH, SB, SX]gels_bufferSize + * API prototypes */ + /*******************************************************************************/ + cusolverStatus_t CUSOLVERAPI cusolverDnZZgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnZCgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnZKgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnZEgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnZYgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuDoubleComplex * dA, + cusolver_int_t ldda, + cuDoubleComplex * dB, + cusolver_int_t lddb, + cuDoubleComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnCCgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnCKgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnCEgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnCYgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + cuComplex * dA, + cusolver_int_t ldda, + cuComplex * dB, + cusolver_int_t lddb, + cuComplex * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDDgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDSgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDHgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDBgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnDXgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + double * dA, + cusolver_int_t ldda, + double * dB, + cusolver_int_t lddb, + double * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnSSgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnSHgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnSBgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnSXgels_bufferSize( + cusolverDnHandle_t handle, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + float * dA, + cusolver_int_t ldda, + float * dB, + cusolver_int_t lddb, + float * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t * lwork_bytes); + /*******************************************************************************/ + + /*******************************************************************************/ /* + * expert users API for IRS Prototypes + * */ + /*******************************************************************************/ + cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgesv( + cusolverDnHandle_t handle, + cusolverDnIRSParams_t gesv_irs_params, + cusolverDnIRSInfos_t gesv_irs_infos, + cusolver_int_t n, + cusolver_int_t nrhs, + void * dA, + cusolver_int_t ldda, + void * dB, + cusolver_int_t lddb, + void * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * niters, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgesv_bufferSize( + cusolverDnHandle_t handle, + cusolverDnIRSParams_t params, + cusolver_int_t n, + cusolver_int_t nrhs, + size_t * lwork_bytes); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgels( + cusolverDnHandle_t handle, + cusolverDnIRSParams_t gels_irs_params, + cusolverDnIRSInfos_t gels_irs_infos, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + void * dA, + cusolver_int_t ldda, + void * dB, + cusolver_int_t lddb, + void * dX, + cusolver_int_t lddx, + void * dWorkspace, + size_t lwork_bytes, + cusolver_int_t * niters, + cusolver_int_t * d_info); + + cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgels_bufferSize( + cusolverDnHandle_t handle, + cusolverDnIRSParams_t params, + cusolver_int_t m, + cusolver_int_t n, + cusolver_int_t nrhs, + size_t * lwork_bytes); + /*******************************************************************************/ + + /* Cholesky factorization and its solver */ + cusolverStatus_t CUSOLVERAPI cusolverDnSpotrf_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDpotrf_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCpotrf_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZpotrf_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSpotrf( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * Workspace, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDpotrf( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * Workspace, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnCpotrf( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + cuComplex * Workspace, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZpotrf( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * Workspace, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnSpotrs( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + int nrhs, + const float * A, + int lda, + float * B, + int ldb, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDpotrs( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + int nrhs, + const double * A, + int lda, + double * B, + int ldb, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnCpotrs( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + int nrhs, + const cuComplex * A, + int lda, + cuComplex * B, + int ldb, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZpotrs( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + int nrhs, + const cuDoubleComplex *A, + int lda, + cuDoubleComplex * B, + int ldb, + int * devInfo); + + /* batched Cholesky factorization and its solver */ + cusolverStatus_t CUSOLVERAPI cusolverDnSpotrfBatched( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * Aarray[], + int lda, + int * infoArray, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnDpotrfBatched( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * Aarray[], + int lda, + int * infoArray, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnCpotrfBatched( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * Aarray[], + int lda, + int * infoArray, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnZpotrfBatched( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * Aarray[], + int lda, + int * infoArray, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnSpotrsBatched( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + int nrhs, /* only support rhs = 1*/ + float * A[], + int lda, + float * B[], + int ldb, + int * d_info, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnDpotrsBatched( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + int nrhs, /* only support rhs = 1*/ + double * A[], + int lda, + double * B[], + int ldb, + int * d_info, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnCpotrsBatched( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + int nrhs, /* only support rhs = 1*/ + cuComplex * A[], + int lda, + cuComplex * B[], + int ldb, + int * d_info, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnZpotrsBatched( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + int nrhs, /* only support rhs = 1*/ + cuDoubleComplex * A[], + int lda, + cuDoubleComplex * B[], + int ldb, + int * d_info, + int batchSize); + + /* s.p.d. matrix inversion (POTRI) and auxiliary routines (TRTRI and LAUUM) */ + cusolverStatus_t CUSOLVERAPI cusolverDnSpotri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDpotri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCpotri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZpotri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSpotri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDpotri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnCpotri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + cuComplex * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZpotri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnXtrtri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + cublasDiagType_t diag, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXtrtri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + cublasDiagType_t diag, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * devInfo); + + /* lauum, auxiliar routine for s.p.d matrix inversion */ + cusolverStatus_t CUSOLVERAPI cusolverDnSlauum_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDlauum_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnClauum_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZlauum_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSlauum( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDlauum( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnClauum( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + cuComplex * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZlauum( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * work, + int lwork, + int * devInfo); + + /* LU Factorization */ + cusolverStatus_t CUSOLVERAPI cusolverDnSgetrf_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + float * A, + int lda, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgetrf_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + double * A, + int lda, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgetrf_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + cuComplex * A, + int lda, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgetrf_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + cuDoubleComplex * A, + int lda, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgetrf( + cusolverDnHandle_t handle, + int m, + int n, + float * A, + int lda, + float * Workspace, + int * devIpiv, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgetrf( + cusolverDnHandle_t handle, + int m, + int n, + double * A, + int lda, + double * Workspace, + int * devIpiv, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgetrf( + cusolverDnHandle_t handle, + int m, + int n, + cuComplex * A, + int lda, + cuComplex * Workspace, + int * devIpiv, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgetrf( + cusolverDnHandle_t handle, + int m, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * Workspace, + int * devIpiv, + int * devInfo); + + /* Row pivoting */ + cusolverStatus_t CUSOLVERAPI cusolverDnSlaswp( + cusolverDnHandle_t handle, + int n, + float * A, + int lda, + int k1, + int k2, + const int * devIpiv, + int incx); + + cusolverStatus_t CUSOLVERAPI cusolverDnDlaswp( + cusolverDnHandle_t handle, + int n, + double * A, + int lda, + int k1, + int k2, + const int * devIpiv, + int incx); + + cusolverStatus_t CUSOLVERAPI cusolverDnClaswp( + cusolverDnHandle_t handle, + int n, + cuComplex * A, + int lda, + int k1, + int k2, + const int * devIpiv, + int incx); + + cusolverStatus_t CUSOLVERAPI cusolverDnZlaswp( + cusolverDnHandle_t handle, + int n, + cuDoubleComplex * A, + int lda, + int k1, + int k2, + const int * devIpiv, + int incx); + + /* LU solve */ + cusolverStatus_t CUSOLVERAPI cusolverDnSgetrs( + cusolverDnHandle_t handle, + cublasOperation_t trans, + int n, + int nrhs, + const float * A, + int lda, + const int * devIpiv, + float * B, + int ldb, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgetrs( + cusolverDnHandle_t handle, + cublasOperation_t trans, + int n, + int nrhs, + const double * A, + int lda, + const int * devIpiv, + double * B, + int ldb, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgetrs( + cusolverDnHandle_t handle, + cublasOperation_t trans, + int n, + int nrhs, + const cuComplex * A, + int lda, + const int * devIpiv, + cuComplex * B, + int ldb, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgetrs( + cusolverDnHandle_t handle, + cublasOperation_t trans, + int n, + int nrhs, + const cuDoubleComplex *A, + int lda, + const int * devIpiv, + cuDoubleComplex * B, + int ldb, + int * devInfo); + + /* QR factorization */ + cusolverStatus_t CUSOLVERAPI cusolverDnSgeqrf_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + float * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgeqrf_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + double * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgeqrf_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + cuComplex * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgeqrf_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + cuDoubleComplex * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgeqrf( + cusolverDnHandle_t handle, + int m, + int n, + float * A, + int lda, + float * TAU, + float * Workspace, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgeqrf( + cusolverDnHandle_t handle, + int m, + int n, + double * A, + int lda, + double * TAU, + double * Workspace, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgeqrf( + cusolverDnHandle_t handle, + int m, + int n, + cuComplex * A, + int lda, + cuComplex * TAU, + cuComplex * Workspace, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgeqrf( + cusolverDnHandle_t handle, + int m, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * TAU, + cuDoubleComplex * Workspace, + int Lwork, + int * devInfo); + + /* generate unitary matrix Q from QR factorization */ + cusolverStatus_t CUSOLVERAPI cusolverDnSorgqr_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int k, + const float * A, + int lda, + const float * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDorgqr_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int k, + const double * A, + int lda, + const double * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCungqr_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int k, + const cuComplex * A, + int lda, + const cuComplex * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZungqr_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int k, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSorgqr( + cusolverDnHandle_t handle, + int m, + int n, + int k, + float * A, + int lda, + const float * tau, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDorgqr( + cusolverDnHandle_t handle, + int m, + int n, + int k, + double * A, + int lda, + const double * tau, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCungqr( + cusolverDnHandle_t handle, + int m, + int n, + int k, + cuComplex * A, + int lda, + const cuComplex * tau, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZungqr( + cusolverDnHandle_t handle, + int m, + int n, + int k, + cuDoubleComplex * A, + int lda, + const cuDoubleComplex *tau, + cuDoubleComplex * work, + int lwork, + int * info); + + /* compute Q**T*b in solve min||A*x = b|| */ + cusolverStatus_t CUSOLVERAPI cusolverDnSormqr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasOperation_t trans, + int m, + int n, + int k, + const float * A, + int lda, + const float * tau, + const float * C, + int ldc, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDormqr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasOperation_t trans, + int m, + int n, + int k, + const double * A, + int lda, + const double * tau, + const double * C, + int ldc, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCunmqr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasOperation_t trans, + int m, + int n, + int k, + const cuComplex * A, + int lda, + const cuComplex * tau, + const cuComplex * C, + int ldc, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZunmqr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasOperation_t trans, + int m, + int n, + int k, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *tau, + const cuDoubleComplex *C, + int ldc, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSormqr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasOperation_t trans, + int m, + int n, + int k, + const float * A, + int lda, + const float * tau, + float * C, + int ldc, + float * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDormqr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasOperation_t trans, + int m, + int n, + int k, + const double * A, + int lda, + const double * tau, + double * C, + int ldc, + double * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnCunmqr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasOperation_t trans, + int m, + int n, + int k, + const cuComplex * A, + int lda, + const cuComplex * tau, + cuComplex * C, + int ldc, + cuComplex * work, + int lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZunmqr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasOperation_t trans, + int m, + int n, + int k, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *tau, + cuDoubleComplex * C, + int ldc, + cuDoubleComplex * work, + int lwork, + int * devInfo); + + /* L*D*L**T,U*D*U**T factorization */ + cusolverStatus_t CUSOLVERAPI cusolverDnSsytrf_bufferSize( + cusolverDnHandle_t handle, + int n, + float * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsytrf_bufferSize( + cusolverDnHandle_t handle, + int n, + double * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCsytrf_bufferSize( + cusolverDnHandle_t handle, + int n, + cuComplex * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZsytrf_bufferSize( + cusolverDnHandle_t handle, + int n, + cuDoubleComplex * A, + int lda, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsytrf( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + int * ipiv, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsytrf( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + int * ipiv, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCsytrf( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + int * ipiv, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZsytrf( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + int * ipiv, + cuDoubleComplex * work, + int lwork, + int * info); + + /* Symmetric indefinite solve (SYTRS) */ + cusolverStatus_t CUSOLVERAPI cusolverDnXsytrs_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + int64_t nrhs, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + const int64_t * ipiv, + cudaDataType dataTypeB, + void * B, + int64_t ldb, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXsytrs( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int64_t n, + int64_t nrhs, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + const int64_t * ipiv, + cudaDataType dataTypeB, + void * B, + int64_t ldb, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * info); + + /* Symmetric indefinite inversion (sytri) */ + cusolverStatus_t CUSOLVERAPI cusolverDnSsytri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + const int * ipiv, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsytri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + const int * ipiv, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCsytri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + const int * ipiv, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZsytri_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + const int * ipiv, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsytri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + const int * ipiv, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsytri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + const int * ipiv, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCsytri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + const int * ipiv, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZsytri( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + const int * ipiv, + cuDoubleComplex * work, + int lwork, + int * info); + + /* bidiagonal factorization */ + cusolverStatus_t CUSOLVERAPI cusolverDnSgebrd_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgebrd_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgebrd_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgebrd_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int * Lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgebrd( + cusolverDnHandle_t handle, + int m, + int n, + float * A, + int lda, + float * D, + float * E, + float * TAUQ, + float * TAUP, + float * Work, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgebrd( + cusolverDnHandle_t handle, + int m, + int n, + double * A, + int lda, + double * D, + double * E, + double * TAUQ, + double * TAUP, + double * Work, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgebrd( + cusolverDnHandle_t handle, + int m, + int n, + cuComplex * A, + int lda, + float * D, + float * E, + cuComplex * TAUQ, + cuComplex * TAUP, + cuComplex * Work, + int Lwork, + int * devInfo); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgebrd( + cusolverDnHandle_t handle, + int m, + int n, + cuDoubleComplex * A, + int lda, + double * D, + double * E, + cuDoubleComplex * TAUQ, + cuDoubleComplex * TAUP, + cuDoubleComplex * Work, + int Lwork, + int * devInfo); + + /* generates one of the unitary matrices Q or P**T determined by GEBRD*/ + cusolverStatus_t CUSOLVERAPI cusolverDnSorgbr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + int m, + int n, + int k, + const float * A, + int lda, + const float * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDorgbr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + int m, + int n, + int k, + const double * A, + int lda, + const double * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCungbr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + int m, + int n, + int k, + const cuComplex * A, + int lda, + const cuComplex * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZungbr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + int m, + int n, + int k, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSorgbr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + int m, + int n, + int k, + float * A, + int lda, + const float * tau, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDorgbr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + int m, + int n, + int k, + double * A, + int lda, + const double * tau, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCungbr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + int m, + int n, + int k, + cuComplex * A, + int lda, + const cuComplex * tau, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZungbr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + int m, + int n, + int k, + cuDoubleComplex * A, + int lda, + const cuDoubleComplex *tau, + cuDoubleComplex * work, + int lwork, + int * info); + + /* tridiagonal factorization */ + cusolverStatus_t CUSOLVERAPI cusolverDnSsytrd_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + const float * d, + const float * e, + const float * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsytrd_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + const double * d, + const double * e, + const double * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnChetrd_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + const float * d, + const float * e, + const cuComplex * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZhetrd_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + const double * d, + const double * e, + const cuDoubleComplex *tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsytrd( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * d, + float * e, + float * tau, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsytrd( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * d, + double * e, + double * tau, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnChetrd( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + float * d, + float * e, + cuComplex * tau, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZhetrd( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + double * d, + double * e, + cuDoubleComplex * tau, + cuDoubleComplex * work, + int lwork, + int * info); + + /* generate unitary Q comes from sytrd */ + cusolverStatus_t CUSOLVERAPI cusolverDnSorgtr_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + const float * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDorgtr_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + const double * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCungtr_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + const cuComplex * tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZungtr_bufferSize( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *tau, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSorgtr( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + const float * tau, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDorgtr( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + const double * tau, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCungtr( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + const cuComplex * tau, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZungtr( + cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + const cuDoubleComplex *tau, + cuDoubleComplex * work, + int lwork, + int * info); + + /* compute op(Q)*C or C*op(Q) where Q comes from sytrd */ + cusolverStatus_t CUSOLVERAPI cusolverDnSormtr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + int m, + int n, + const float * A, + int lda, + const float * tau, + const float * C, + int ldc, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDormtr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + int m, + int n, + const double * A, + int lda, + const double * tau, + const double * C, + int ldc, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCunmtr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + int m, + int n, + const cuComplex * A, + int lda, + const cuComplex * tau, + const cuComplex * C, + int ldc, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZunmtr_bufferSize( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + int m, + int n, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *tau, + const cuDoubleComplex *C, + int ldc, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSormtr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + int m, + int n, + float * A, + int lda, + float * tau, + float * C, + int ldc, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDormtr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + int m, + int n, + double * A, + int lda, + double * tau, + double * C, + int ldc, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCunmtr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + int m, + int n, + cuComplex * A, + int lda, + cuComplex * tau, + cuComplex * C, + int ldc, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZunmtr( + cusolverDnHandle_t handle, + cublasSideMode_t side, + cublasFillMode_t uplo, + cublasOperation_t trans, + int m, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * tau, + cuDoubleComplex * C, + int ldc, + cuDoubleComplex * work, + int lwork, + int * info); + + /* singular value decomposition, A = U * Sigma * V^H */ + cusolverStatus_t CUSOLVERAPI cusolverDnSgesvd_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgesvd_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgesvd_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgesvd_bufferSize( + cusolverDnHandle_t handle, + int m, + int n, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgesvd( + cusolverDnHandle_t handle, + signed char jobu, + signed char jobvt, + int m, + int n, + float * A, + int lda, + float * S, + float * U, + int ldu, + float * VT, + int ldvt, + float * work, + int lwork, + float * rwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgesvd( + cusolverDnHandle_t handle, + signed char jobu, + signed char jobvt, + int m, + int n, + double * A, + int lda, + double * S, + double * U, + int ldu, + double * VT, + int ldvt, + double * work, + int lwork, + double * rwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgesvd( + cusolverDnHandle_t handle, + signed char jobu, + signed char jobvt, + int m, + int n, + cuComplex * A, + int lda, + float * S, + cuComplex * U, + int ldu, + cuComplex * VT, + int ldvt, + cuComplex * work, + int lwork, + float * rwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgesvd( + cusolverDnHandle_t handle, + signed char jobu, + signed char jobvt, + int m, + int n, + cuDoubleComplex * A, + int lda, + double * S, + cuDoubleComplex * U, + int ldu, + cuDoubleComplex * VT, + int ldvt, + cuDoubleComplex * work, + int lwork, + double * rwork, + int * info); + + /* standard symmetric eigenvalue solver, A*x = lambda*x, by divide-and-conquer + */ + cusolverStatus_t CUSOLVERAPI cusolverDnSsyevd_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + const float * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsyevd_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + const double * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCheevd_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + const float * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZheevd_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + const double * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsyevd( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * W, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsyevd( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * W, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCheevd( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + float * W, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZheevd( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + double * W, + cuDoubleComplex * work, + int lwork, + int * info); + + /* standard selective symmetric eigenvalue solver, A*x = lambda*x, by + * divide-and-conquer */ + cusolverStatus_t CUSOLVERAPI cusolverDnSsyevdx_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + float vl, + float vu, + int il, + int iu, + int * meig, + const float * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsyevdx_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + double vl, + double vu, + int il, + int iu, + int * meig, + const double * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnCheevdx_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + float vl, + float vu, + int il, + int iu, + int * meig, + const float * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZheevdx_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + double vl, + double vu, + int il, + int iu, + int * meig, + const double * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsyevdx( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float vl, + float vu, + int il, + int iu, + int * meig, + float * W, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsyevdx( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double vl, + double vu, + int il, + int iu, + int * meig, + double * W, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCheevdx( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + float vl, + float vu, + int il, + int iu, + int * meig, + float * W, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZheevdx( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + double vl, + double vu, + int il, + int iu, + int * meig, + double * W, + cuDoubleComplex * work, + int lwork, + int * info); + + /* selective generalized symmetric eigenvalue solver, A*x = lambda*B*x, by + * divide-and-conquer */ + cusolverStatus_t CUSOLVERAPI cusolverDnSsygvdx_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + const float * B, + int ldb, + float vl, + float vu, + int il, + int iu, + int * meig, + const float * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsygvdx_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + const double * B, + int ldb, + double vl, + double vu, + int il, + int iu, + int * meig, + const double * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnChegvdx_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + const cuComplex * B, + int ldb, + float vl, + float vu, + int il, + int iu, + int * meig, + const float * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZhegvdx_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + double vl, + double vu, + int il, + int iu, + int * meig, + const double * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsygvdx( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * B, + int ldb, + float vl, + float vu, + int il, + int iu, + int * meig, + float * W, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsygvdx( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * B, + int ldb, + double vl, + double vu, + int il, + int iu, + int * meig, + double * W, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnChegvdx( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + cuComplex * B, + int ldb, + float vl, + float vu, + int il, + int iu, + int * meig, + float * W, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZhegvdx( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * B, + int ldb, + double vl, + double vu, + int il, + int iu, + int * meig, + double * W, + cuDoubleComplex * work, + int lwork, + int * info); + + /* generalized symmetric eigenvalue solver, A*x = lambda*B*x, by + * divide-and-conquer */ + cusolverStatus_t CUSOLVERAPI cusolverDnSsygvd_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + const float * B, + int ldb, + const float * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsygvd_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + const double * B, + int ldb, + const double * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnChegvd_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + const cuComplex * B, + int ldb, + const float * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnZhegvd_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const double * W, + int * lwork); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsygvd( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * B, + int ldb, + float * W, + float * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsygvd( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * B, + int ldb, + double * W, + double * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnChegvd( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + cuComplex * B, + int ldb, + float * W, + cuComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnZhegvd( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * B, + int ldb, + double * W, + cuDoubleComplex * work, + int lwork, + int * info); + + cusolverStatus_t CUSOLVERAPI cusolverDnCreateSyevjInfo(syevjInfo_t *info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDestroySyevjInfo(syevjInfo_t info); + + cusolverStatus_t CUSOLVERAPI + cusolverDnXsyevjSetTolerance(syevjInfo_t info, double tolerance); + + cusolverStatus_t CUSOLVERAPI + cusolverDnXsyevjSetMaxSweeps(syevjInfo_t info, int max_sweeps); + + cusolverStatus_t CUSOLVERAPI + cusolverDnXsyevjSetSortEig(syevjInfo_t info, int sort_eig); + + cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjGetResidual( + cusolverDnHandle_t handle, + syevjInfo_t info, + double * residual); + + cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjGetSweeps( + cusolverDnHandle_t handle, + syevjInfo_t info, + int * executed_sweeps); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsyevjBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + const float * W, + int * lwork, + syevjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsyevjBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + const double * W, + int * lwork, + syevjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnCheevjBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + const float * W, + int * lwork, + syevjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnZheevjBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + const double * W, + int * lwork, + syevjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsyevjBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * W, + float * work, + int lwork, + int * info, + syevjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsyevjBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * W, + double * work, + int lwork, + int * info, + syevjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnCheevjBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + float * W, + cuComplex * work, + int lwork, + int * info, + syevjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnZheevjBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + double * W, + cuDoubleComplex * work, + int lwork, + int * info, + syevjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsyevj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + const float * W, + int * lwork, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsyevj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + const double * W, + int * lwork, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnCheevj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + const float * W, + int * lwork, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnZheevj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + const double * W, + int * lwork, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsyevj( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * W, + float * work, + int lwork, + int * info, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsyevj( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * W, + double * work, + int lwork, + int * info, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnCheevj( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + float * W, + cuComplex * work, + int lwork, + int * info, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnZheevj( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + double * W, + cuDoubleComplex * work, + int lwork, + int * info, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsygvj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const float * A, + int lda, + const float * B, + int ldb, + const float * W, + int * lwork, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsygvj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const double * A, + int lda, + const double * B, + int ldb, + const double * W, + int * lwork, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnChegvj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuComplex * A, + int lda, + const cuComplex * B, + int ldb, + const float * W, + int * lwork, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnZhegvj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + const cuDoubleComplex *A, + int lda, + const cuDoubleComplex *B, + int ldb, + const double * W, + int * lwork, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnSsygvj( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + float * A, + int lda, + float * B, + int ldb, + float * W, + float * work, + int lwork, + int * info, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnDsygvj( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + double * A, + int lda, + double * B, + int ldb, + double * W, + double * work, + int lwork, + int * info, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnChegvj( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuComplex * A, + int lda, + cuComplex * B, + int ldb, + float * W, + cuComplex * work, + int lwork, + int * info, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnZhegvj( + cusolverDnHandle_t handle, + cusolverEigType_t itype, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int n, + cuDoubleComplex * A, + int lda, + cuDoubleComplex * B, + int ldb, + double * W, + cuDoubleComplex * work, + int lwork, + int * info, + syevjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnCreateGesvdjInfo(gesvdjInfo_t *info); + + cusolverStatus_t CUSOLVERAPI cusolverDnDestroyGesvdjInfo(gesvdjInfo_t info); + + cusolverStatus_t CUSOLVERAPI + cusolverDnXgesvdjSetTolerance(gesvdjInfo_t info, double tolerance); + + cusolverStatus_t CUSOLVERAPI + cusolverDnXgesvdjSetMaxSweeps(gesvdjInfo_t info, int max_sweeps); + + cusolverStatus_t CUSOLVERAPI + cusolverDnXgesvdjSetSortEig(gesvdjInfo_t info, int sort_svd); + + cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjGetResidual( + cusolverDnHandle_t handle, + gesvdjInfo_t info, + double * residual); + + cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjGetSweeps( + cusolverDnHandle_t handle, + gesvdjInfo_t info, + int * executed_sweeps); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int m, + int n, + const float * A, + int lda, + const float * S, + const float * U, + int ldu, + const float * V, + int ldv, + int * lwork, + gesvdjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int m, + int n, + const double * A, + int lda, + const double * S, + const double * U, + int ldu, + const double * V, + int ldv, + int * lwork, + gesvdjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int m, + int n, + const cuComplex * A, + int lda, + const float * S, + const cuComplex * U, + int ldu, + const cuComplex * V, + int ldv, + int * lwork, + gesvdjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int m, + int n, + const cuDoubleComplex *A, + int lda, + const double * S, + const cuDoubleComplex *U, + int ldu, + const cuDoubleComplex *V, + int ldv, + int * lwork, + gesvdjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdjBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int m, + int n, + float * A, + int lda, + float * S, + float * U, + int ldu, + float * V, + int ldv, + float * work, + int lwork, + int * info, + gesvdjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdjBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int m, + int n, + double * A, + int lda, + double * S, + double * U, + int ldu, + double * V, + int ldv, + double * work, + int lwork, + int * info, + gesvdjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdjBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int m, + int n, + cuComplex * A, + int lda, + float * S, + cuComplex * U, + int ldu, + cuComplex * V, + int ldv, + cuComplex * work, + int lwork, + int * info, + gesvdjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdjBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int m, + int n, + cuDoubleComplex * A, + int lda, + double * S, + cuDoubleComplex * U, + int ldu, + cuDoubleComplex * V, + int ldv, + cuDoubleComplex * work, + int lwork, + int * info, + gesvdjInfo_t params, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int econ, + int m, + int n, + const float * A, + int lda, + const float * S, + const float * U, + int ldu, + const float * V, + int ldv, + int * lwork, + gesvdjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int econ, + int m, + int n, + const double * A, + int lda, + const double * S, + const double * U, + int ldu, + const double * V, + int ldv, + int * lwork, + gesvdjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int econ, + int m, + int n, + const cuComplex * A, + int lda, + const float * S, + const cuComplex * U, + int ldu, + const cuComplex * V, + int ldv, + int * lwork, + gesvdjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdj_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int econ, + int m, + int n, + const cuDoubleComplex *A, + int lda, + const double * S, + const cuDoubleComplex *U, + int ldu, + const cuDoubleComplex *V, + int ldv, + int * lwork, + gesvdjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdj( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int econ, + int m, + int n, + float * A, + int lda, + float * S, + float * U, + int ldu, + float * V, + int ldv, + float * work, + int lwork, + int * info, + gesvdjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdj( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int econ, + int m, + int n, + double * A, + int lda, + double * S, + double * U, + int ldu, + double * V, + int ldv, + double * work, + int lwork, + int * info, + gesvdjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdj( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int econ, + int m, + int n, + cuComplex * A, + int lda, + float * S, + cuComplex * U, + int ldu, + cuComplex * V, + int ldv, + cuComplex * work, + int lwork, + int * info, + gesvdjInfo_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdj( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int econ, + int m, + int n, + cuDoubleComplex * A, + int lda, + double * S, + cuDoubleComplex * U, + int ldu, + cuDoubleComplex * V, + int ldv, + cuDoubleComplex * work, + int lwork, + int * info, + gesvdjInfo_t params); + + /* batched approximate SVD */ + + cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int rank, + int m, + int n, + const float * d_A, + int lda, + long long int strideA, + const float * d_S, + long long int strideS, + const float * d_U, + int ldu, + long long int strideU, + const float * d_V, + int ldv, + long long int strideV, + int * lwork, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int rank, + int m, + int n, + const double * d_A, + int lda, + long long int strideA, + const double * d_S, + long long int strideS, + const double * d_U, + int ldu, + long long int strideU, + const double * d_V, + int ldv, + long long int strideV, + int * lwork, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int rank, + int m, + int n, + const cuComplex * d_A, + int lda, + long long int strideA, + const float * d_S, + long long int strideS, + const cuComplex * d_U, + int ldu, + long long int strideU, + const cuComplex * d_V, + int ldv, + long long int strideV, + int * lwork, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int rank, + int m, + int n, + const cuDoubleComplex *d_A, + int lda, + long long int strideA, + const double * d_S, + long long int strideS, + const cuDoubleComplex *d_U, + int ldu, + long long int strideU, + const cuDoubleComplex *d_V, + int ldv, + long long int strideV, + int * lwork, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdaStridedBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int rank, + int m, + int n, + const float * d_A, + int lda, + long long int strideA, + float * d_S, + long long int strideS, + float * d_U, + int ldu, + long long int strideU, + float * d_V, + int ldv, + long long int strideV, + float * d_work, + int lwork, + int * d_info, + double * h_R_nrmF, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdaStridedBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int rank, + int m, + int n, + const double * d_A, + int lda, + long long int strideA, + double * d_S, + long long int strideS, + double * d_U, + int ldu, + long long int strideU, + double * d_V, + int ldv, + long long int strideV, + double * d_work, + int lwork, + int * d_info, + double * h_R_nrmF, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdaStridedBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int rank, + int m, + int n, + const cuComplex * d_A, + int lda, + long long int strideA, + float * d_S, + long long int strideS, + cuComplex * d_U, + int ldu, + long long int strideU, + cuComplex * d_V, + int ldv, + long long int strideV, + cuComplex * d_work, + int lwork, + int * d_info, + double * h_R_nrmF, + int batchSize); + + cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdaStridedBatched( + cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + int rank, + int m, + int n, + const cuDoubleComplex *d_A, + int lda, + long long int strideA, + double * d_S, + long long int strideS, + cuDoubleComplex * d_U, + int ldu, + long long int strideU, + cuDoubleComplex * d_V, + int ldv, + long long int strideV, + cuDoubleComplex * d_work, + int lwork, + int * d_info, + double * h_R_nrmF, + int batchSize); + + cusolverStatus_t CUSOLVERAPI + cusolverDnCreateParams(cusolverDnParams_t *params); + + cusolverStatus_t CUSOLVERAPI + cusolverDnDestroyParams(cusolverDnParams_t params); + + cusolverStatus_t CUSOLVERAPI cusolverDnSetAdvOptions( + cusolverDnParams_t params, + cusolverDnFunction_t function, + cusolverAlgMode_t algo); + + /* 64-bit API for POTRF */ + CUSOLVER_DEPRECATED(cusolverDnXpotrf_bufferSize) + cusolverStatus_t CUSOLVERAPI cusolverDnPotrf_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType computeType, + size_t * workspaceInBytes); + + CUSOLVER_DEPRECATED(cusolverDnXpotrf) + cusolverStatus_t CUSOLVERAPI cusolverDnPotrf( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType computeType, + void * pBuffer, + size_t workspaceInBytes, + int * info); + + /* 64-bit API for POTRS */ + CUSOLVER_DEPRECATED(cusolverDnXpotrs) + cusolverStatus_t CUSOLVERAPI cusolverDnPotrs( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cublasFillMode_t uplo, + int64_t n, + int64_t nrhs, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeB, + void * B, + int64_t ldb, + int * info); + + /* 64-bit API for GEQRF */ + CUSOLVER_DEPRECATED(cusolverDnXgeqrf_bufferSize) + cusolverStatus_t CUSOLVERAPI cusolverDnGeqrf_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeTau, + const void * tau, + cudaDataType computeType, + size_t * workspaceInBytes); + + CUSOLVER_DEPRECATED(cusolverDnXgeqrf) + cusolverStatus_t CUSOLVERAPI cusolverDnGeqrf( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType dataTypeTau, + void * tau, + cudaDataType computeType, + void * pBuffer, + size_t workspaceInBytes, + int * info); + + /* 64-bit API for GETRF */ + CUSOLVER_DEPRECATED(cusolverDnXgetrf_bufferSize) + cusolverStatus_t CUSOLVERAPI cusolverDnGetrf_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType computeType, + size_t * workspaceInBytes); + + CUSOLVER_DEPRECATED(cusolverDnXgetrf) + cusolverStatus_t CUSOLVERAPI cusolverDnGetrf( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + int64_t * ipiv, + cudaDataType computeType, + void * pBuffer, + size_t workspaceInBytes, + int * info); + + /* 64-bit API for GETRS */ + CUSOLVER_DEPRECATED(cusolverDnXgetrs) + cusolverStatus_t CUSOLVERAPI cusolverDnGetrs( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cublasOperation_t trans, + int64_t n, + int64_t nrhs, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + const int64_t * ipiv, + cudaDataType dataTypeB, + void * B, + int64_t ldb, + int * info); + + /* 64-bit API for SYEVD */ + CUSOLVER_DEPRECATED(cusolverDnXsyevd_bufferSize) + cusolverStatus_t CUSOLVERAPI cusolverDnSyevd_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeW, + const void * W, + cudaDataType computeType, + size_t * workspaceInBytes); + + CUSOLVER_DEPRECATED(cusolverDnXsyevd) + cusolverStatus_t CUSOLVERAPI cusolverDnSyevd( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType dataTypeW, + void * W, + cudaDataType computeType, + void * pBuffer, + size_t workspaceInBytes, + int * info); + + /* 64-bit API for SYEVDX */ + CUSOLVER_DEPRECATED(cusolverDnXsyevdx_bufferSize) + cusolverStatus_t CUSOLVERAPI cusolverDnSyevdx_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + void * vl, + void * vu, + int64_t il, + int64_t iu, + int64_t * h_meig, + cudaDataType dataTypeW, + const void * W, + cudaDataType computeType, + size_t * workspaceInBytes); + + CUSOLVER_DEPRECATED(cusolverDnXsyevdx) + cusolverStatus_t CUSOLVERAPI cusolverDnSyevdx( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + void * vl, + void * vu, + int64_t il, + int64_t iu, + int64_t * meig64, + cudaDataType dataTypeW, + void * W, + cudaDataType computeType, + void * pBuffer, + size_t workspaceInBytes, + int * info); + + /* 64-bit API for GESVD */ + CUSOLVER_DEPRECATED(cusolverDnXgesvd_bufferSize) + cusolverStatus_t CUSOLVERAPI cusolverDnGesvd_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + signed char jobu, + signed char jobvt, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeS, + const void * S, + cudaDataType dataTypeU, + const void * U, + int64_t ldu, + cudaDataType dataTypeVT, + const void * VT, + int64_t ldvt, + cudaDataType computeType, + size_t * workspaceInBytes); + + CUSOLVER_DEPRECATED(cusolverDnXgesvd) + cusolverStatus_t CUSOLVERAPI cusolverDnGesvd( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + signed char jobu, + signed char jobvt, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType dataTypeS, + void * S, + cudaDataType dataTypeU, + void * U, + int64_t ldu, + cudaDataType dataTypeVT, + void * VT, + int64_t ldvt, + cudaDataType computeType, + void * pBuffer, + size_t workspaceInBytes, + int * info); + + /* + * new 64-bit API + */ + /* 64-bit API for POTRF */ + cusolverStatus_t CUSOLVERAPI cusolverDnXpotrf_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType computeType, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXpotrf( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType computeType, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * info); + + /* 64-bit API for POTRS */ + cusolverStatus_t CUSOLVERAPI cusolverDnXpotrs( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cublasFillMode_t uplo, + int64_t n, + int64_t nrhs, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeB, + void * B, + int64_t ldb, + int * info); + + /* 64-bit API for GEQRF */ + cusolverStatus_t CUSOLVERAPI cusolverDnXgeqrf_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeTau, + const void * tau, + cudaDataType computeType, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXgeqrf( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType dataTypeTau, + void * tau, + cudaDataType computeType, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * info); + + /* 64-bit API for GETRF */ + cusolverStatus_t CUSOLVERAPI cusolverDnXgetrf_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType computeType, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXgetrf( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + int64_t * ipiv, + cudaDataType computeType, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * info); + + /* 64-bit API for GETRS */ + cusolverStatus_t CUSOLVERAPI cusolverDnXgetrs( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cublasOperation_t trans, + int64_t n, + int64_t nrhs, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + const int64_t * ipiv, + cudaDataType dataTypeB, + void * B, + int64_t ldb, + int * info); + + /* 64-bit API for SYEVD */ + cusolverStatus_t CUSOLVERAPI cusolverDnXsyevd_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeW, + const void * W, + cudaDataType computeType, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXsyevd( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType dataTypeW, + void * W, + cudaDataType computeType, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * info); + + /* 64-bit API for SYEVDX */ + cusolverStatus_t CUSOLVERAPI cusolverDnXsyevdx_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + void * vl, + void * vu, + int64_t il, + int64_t iu, + int64_t * h_meig, + cudaDataType dataTypeW, + const void * W, + cudaDataType computeType, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXsyevdx( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cusolverEigRange_t range, + cublasFillMode_t uplo, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + void * vl, + void * vu, + int64_t il, + int64_t iu, + int64_t * meig64, + cudaDataType dataTypeW, + void * W, + cudaDataType computeType, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * info); + + /* 64-bit API for GESVD */ + cusolverStatus_t CUSOLVERAPI cusolverDnXgesvd_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + signed char jobu, + signed char jobvt, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeS, + const void * S, + cudaDataType dataTypeU, + const void * U, + int64_t ldu, + cudaDataType dataTypeVT, + const void * VT, + int64_t ldvt, + cudaDataType computeType, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXgesvd( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + signed char jobu, + signed char jobvt, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType dataTypeS, + void * S, + cudaDataType dataTypeU, + void * U, + int64_t ldu, + cudaDataType dataTypeVT, + void * VT, + int64_t ldvt, + cudaDataType computeType, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * info); + + /* 64-bit API for GESVDP */ + cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdp_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + int econ, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeS, + const void * S, + cudaDataType dataTypeU, + const void * U, + int64_t ldu, + cudaDataType dataTypeV, + const void * V, + int64_t ldv, + cudaDataType computeType, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdp( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + int econ, + int64_t m, + int64_t n, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType dataTypeS, + void * S, + cudaDataType dataTypeU, + void * U, + int64_t ldu, + cudaDataType dataTypeV, + void * V, + int64_t ldv, + cudaDataType computeType, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * d_info, + double * h_err_sigma); + + cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdr_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + signed char jobu, + signed char jobv, + int64_t m, + int64_t n, + int64_t k, + int64_t p, + int64_t niters, + cudaDataType dataTypeA, + const void * A, + int64_t lda, + cudaDataType dataTypeSrand, + const void * Srand, + cudaDataType dataTypeUrand, + const void * Urand, + int64_t ldUrand, + cudaDataType dataTypeVrand, + const void * Vrand, + int64_t ldVrand, + cudaDataType computeType, + size_t * workspaceInBytesOnDevice, + size_t * workspaceInBytesOnHost); + + cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdr( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + signed char jobu, + signed char jobv, + int64_t m, + int64_t n, + int64_t k, + int64_t p, + int64_t niters, + cudaDataType dataTypeA, + void * A, + int64_t lda, + cudaDataType dataTypeSrand, + void * Srand, + cudaDataType dataTypeUrand, + void * Urand, + int64_t ldUrand, + cudaDataType dataTypeVrand, + void * Vrand, + int64_t ldVrand, + cudaDataType computeType, + void * bufferOnDevice, + size_t workspaceInBytesOnDevice, + void * bufferOnHost, + size_t workspaceInBytesOnHost, + int * d_info); + + typedef void (*cusolverDnLoggerCallback_t)( + int logLevel, + const char *functionName, + const char *message); + + cusolverStatus_t CUSOLVERAPI + cusolverDnLoggerSetCallback(cusolverDnLoggerCallback_t callback); + + cusolverStatus_t CUSOLVERAPI cusolverDnLoggerSetFile(FILE *file); + + cusolverStatus_t CUSOLVERAPI cusolverDnLoggerOpenFile(const char *logFile); + + cusolverStatus_t CUSOLVERAPI cusolverDnLoggerSetLevel(int level); + + cusolverStatus_t CUSOLVERAPI cusolverDnLoggerSetMask(int mask); + + cusolverStatus_t CUSOLVERAPI cusolverDnLoggerForceDisable(); + + #if defined(__cplusplus) +} + #endif /* __cplusplus */ + +#endif /* !defined(CUDENSE_H_) */ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolver_common.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolver_common.h new file mode 100644 index 0000000000000000000000000000000000000000..148be384bcabf2eec1b65f8e7a16c3efb4f3da66 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolver_common.h @@ -0,0 +1,266 @@ +/* + * Copyright 2014 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#if !defined(CUSOLVER_COMMON_H_) + #define CUSOLVER_COMMON_H_ + + #include "library_types.h" + + #ifndef CUSOLVERAPI + #ifdef _WIN32 + #define CUSOLVERAPI __stdcall + #else + #define CUSOLVERAPI + #endif + #endif + + #if defined(_MSC_VER) +typedef __int64 int64_t; + #else + #include + #endif + +typedef int cusolver_int_t; + + #define CUSOLVER_VER_MAJOR 11 + #define CUSOLVER_VER_MINOR 4 + #define CUSOLVER_VER_PATCH 1 + #define CUSOLVER_VER_BUILD 48 + #define CUSOLVER_VERSION \ + (CUSOLVER_VER_MAJOR * 1000 + CUSOLVER_VER_MINOR * 100 + CUSOLVER_VER_PATCH) + + /* + * disable this macro to proceed old API + */ + #define DISABLE_CUSOLVER_DEPRECATED + +//------------------------------------------------------------------------------ + + #if !defined(_MSC_VER) + #define CUSOLVER_CPP_VERSION __cplusplus + #elif _MSC_FULL_VER >= 190024210 // Visual Studio 2015 Update 3 + #define CUSOLVER_CPP_VERSION _MSVC_LANG + #else + #define CUSOLVER_CPP_VERSION 0 + #endif + +//------------------------------------------------------------------------------ + + #if !defined(DISABLE_CUSOLVER_DEPRECATED) + + #if CUSOLVER_CPP_VERSION >= 201402L + + #define CUSOLVER_DEPRECATED(new_func) \ + [[deprecated("please use " #new_func " instead")]] + + #elif defined(_MSC_VER) + + #define CUSOLVER_DEPRECATED(new_func) \ + __declspec(deprecated("please use " #new_func " instead")) + + #elif defined(__INTEL_COMPILER) || defined(__clang__) || \ + (defined(__GNUC__) && \ + (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5))) + + #define CUSOLVER_DEPRECATED(new_func) \ + __attribute__((deprecated("please use " #new_func " instead"))) + + #elif defined(__GNUC__) || defined(__xlc__) + + #define CUSOLVER_DEPRECATED(new_func) __attribute__((deprecated)) + + #else + + #define CUSOLVER_DEPRECATED(new_func) + + #endif // defined(__cplusplus) && __cplusplus >= 201402L + //------------------------------------------------------------------------------ + + #if CUSOLVER_CPP_VERSION >= 201703L + + #define CUSOLVER_DEPRECATED_ENUM(new_enum) \ + [[deprecated("please use " #new_enum " instead")]] + + #elif defined(__clang__) || \ + (defined(__GNUC__) && __GNUC__ >= 6 && !defined(__PGI)) + + #define CUSOLVER_DEPRECATED_ENUM(new_enum) \ + __attribute__((deprecated("please use " #new_enum " instead"))) + + #else + + #define CUSOLVER_DEPRECATED_ENUM(new_enum) + + #endif // defined(__cplusplus) && __cplusplus >= 201402L + + #else // defined(DISABLE_CUSOLVER_DEPRECATED) + + #define CUSOLVER_DEPRECATED(new_func) + #define CUSOLVER_DEPRECATED_ENUM(new_enum) + + #endif // !defined(DISABLE_CUSOLVER_DEPRECATED) + + #undef CUSOLVER_CPP_VERSION + + #if defined(__cplusplus) +extern "C" { + #endif /* __cplusplus */ + + typedef enum { + CUSOLVER_STATUS_SUCCESS = 0, + CUSOLVER_STATUS_NOT_INITIALIZED = 1, + CUSOLVER_STATUS_ALLOC_FAILED = 2, + CUSOLVER_STATUS_INVALID_VALUE = 3, + CUSOLVER_STATUS_ARCH_MISMATCH = 4, + CUSOLVER_STATUS_MAPPING_ERROR = 5, + CUSOLVER_STATUS_EXECUTION_FAILED = 6, + CUSOLVER_STATUS_INTERNAL_ERROR = 7, + CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED = 8, + CUSOLVER_STATUS_NOT_SUPPORTED = 9, + CUSOLVER_STATUS_ZERO_PIVOT = 10, + CUSOLVER_STATUS_INVALID_LICENSE = 11, + CUSOLVER_STATUS_IRS_PARAMS_NOT_INITIALIZED = 12, + CUSOLVER_STATUS_IRS_PARAMS_INVALID = 13, + CUSOLVER_STATUS_IRS_PARAMS_INVALID_PREC = 14, + CUSOLVER_STATUS_IRS_PARAMS_INVALID_REFINE = 15, + CUSOLVER_STATUS_IRS_PARAMS_INVALID_MAXITER = 16, + CUSOLVER_STATUS_IRS_INTERNAL_ERROR = 20, + CUSOLVER_STATUS_IRS_NOT_SUPPORTED = 21, + CUSOLVER_STATUS_IRS_OUT_OF_RANGE = 22, + CUSOLVER_STATUS_IRS_NRHS_NOT_SUPPORTED_FOR_REFINE_GMRES = 23, + CUSOLVER_STATUS_IRS_INFOS_NOT_INITIALIZED = 25, + CUSOLVER_STATUS_IRS_INFOS_NOT_DESTROYED = 26, + CUSOLVER_STATUS_IRS_MATRIX_SINGULAR = 30, + CUSOLVER_STATUS_INVALID_WORKSPACE = 31 + } cusolverStatus_t; + + typedef enum { + CUSOLVER_EIG_TYPE_1 = 1, + CUSOLVER_EIG_TYPE_2 = 2, + CUSOLVER_EIG_TYPE_3 = 3 + } cusolverEigType_t; + + typedef enum { + CUSOLVER_EIG_MODE_NOVECTOR = 0, + CUSOLVER_EIG_MODE_VECTOR = 1 + } cusolverEigMode_t; + + typedef enum { + CUSOLVER_EIG_RANGE_ALL = 1001, + CUSOLVER_EIG_RANGE_I = 1002, + CUSOLVER_EIG_RANGE_V = 1003, + } cusolverEigRange_t; + + typedef enum { + CUSOLVER_INF_NORM = 104, + CUSOLVER_MAX_NORM = 105, + CUSOLVER_ONE_NORM = 106, + CUSOLVER_FRO_NORM = 107, + } cusolverNorm_t; + + typedef enum { + CUSOLVER_IRS_REFINE_NOT_SET = 1100, + CUSOLVER_IRS_REFINE_NONE = 1101, + CUSOLVER_IRS_REFINE_CLASSICAL = 1102, + CUSOLVER_IRS_REFINE_CLASSICAL_GMRES = 1103, + CUSOLVER_IRS_REFINE_GMRES = 1104, + CUSOLVER_IRS_REFINE_GMRES_GMRES = 1105, + CUSOLVER_IRS_REFINE_GMRES_NOPCOND = 1106, + + CUSOLVER_PREC_DD = 1150, + CUSOLVER_PREC_SS = 1151, + CUSOLVER_PREC_SHT = 1152, + + } cusolverIRSRefinement_t; + + typedef enum { + CUSOLVER_R_8I = 1201, + CUSOLVER_R_8U = 1202, + CUSOLVER_R_64F = 1203, + CUSOLVER_R_32F = 1204, + CUSOLVER_R_16F = 1205, + CUSOLVER_R_16BF = 1206, + CUSOLVER_R_TF32 = 1207, + CUSOLVER_R_AP = 1208, + CUSOLVER_C_8I = 1211, + CUSOLVER_C_8U = 1212, + CUSOLVER_C_64F = 1213, + CUSOLVER_C_32F = 1214, + CUSOLVER_C_16F = 1215, + CUSOLVER_C_16BF = 1216, + CUSOLVER_C_TF32 = 1217, + CUSOLVER_C_AP = 1218, + } cusolverPrecType_t; + + typedef enum { + CUSOLVER_ALG_0 = 0, /* default algorithm */ + CUSOLVER_ALG_1 = 1, + CUSOLVER_ALG_2 = 2 + } cusolverAlgMode_t; + + typedef enum { + CUBLAS_STOREV_COLUMNWISE = 0, + CUBLAS_STOREV_ROWWISE = 1 + } cusolverStorevMode_t; + + typedef enum { + CUBLAS_DIRECT_FORWARD = 0, + CUBLAS_DIRECT_BACKWARD = 1 + } cusolverDirectMode_t; + + cusolverStatus_t CUSOLVERAPI + cusolverGetProperty(libraryPropertyType type, int *value); + + cusolverStatus_t CUSOLVERAPI cusolverGetVersion(int *version); + + #if defined(__cplusplus) +} + #endif /* __cplusplus */ + +#endif // CUSOLVER_COMMON_H_ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/lib/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/lib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b487de38a55b7c1bb628e6bbd748257bb7b4594 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/lib/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66929086120aa815f7037183ec9c82dba2b3efcb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/include/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/include/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/include/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/include/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa207d18aeca7371e9bc277dba246963aaae9a93 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/include/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/include/cusparse_v2.h b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/include/cusparse_v2.h new file mode 100644 index 0000000000000000000000000000000000000000..f889e1f569d46d1116fe6e302429b3855de43c21 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/include/cusparse_v2.h @@ -0,0 +1,54 @@ +/* + * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ +#if !defined(CUSPARSE_V2_H_) +#define CUSPARSE_V2_H_ + +#include "cusparse.h" + +#endif diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/lib/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusparse/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/lib/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/lib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75a92d3cf6dcf5e7b4622944b690b60dab095eb8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/lib/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa232b6cabdbfc835d76a77ea26250a8b7ef0783 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__init__.py @@ -0,0 +1,128 @@ +import contextlib +import functools +import os +import sys +from typing import TYPE_CHECKING, List, Optional, Type, cast + +from pip._internal.utils.misc import strtobool + +from .base import BaseDistribution, BaseEnvironment, FilesystemWheel, MemoryWheel, Wheel + +if TYPE_CHECKING: + from typing import Literal, Protocol +else: + Protocol = object + +__all__ = [ + "BaseDistribution", + "BaseEnvironment", + "FilesystemWheel", + "MemoryWheel", + "Wheel", + "get_default_environment", + "get_environment", + "get_wheel_distribution", + "select_backend", +] + + +def _should_use_importlib_metadata() -> bool: + """Whether to use the ``importlib.metadata`` or ``pkg_resources`` backend. + + By default, pip uses ``importlib.metadata`` on Python 3.11+, and + ``pkg_resourcess`` otherwise. This can be overridden by a couple of ways: + + * If environment variable ``_PIP_USE_IMPORTLIB_METADATA`` is set, it + dictates whether ``importlib.metadata`` is used, regardless of Python + version. + * On Python 3.11+, Python distributors can patch ``importlib.metadata`` + to add a global constant ``_PIP_USE_IMPORTLIB_METADATA = False``. This + makes pip use ``pkg_resources`` (unless the user set the aforementioned + environment variable to *True*). + """ + with contextlib.suppress(KeyError, ValueError): + return bool(strtobool(os.environ["_PIP_USE_IMPORTLIB_METADATA"])) + if sys.version_info < (3, 11): + return False + import importlib.metadata + + return bool(getattr(importlib.metadata, "_PIP_USE_IMPORTLIB_METADATA", True)) + + +class Backend(Protocol): + NAME: 'Literal["importlib", "pkg_resources"]' + Distribution: Type[BaseDistribution] + Environment: Type[BaseEnvironment] + + +@functools.lru_cache(maxsize=None) +def select_backend() -> Backend: + if _should_use_importlib_metadata(): + from . import importlib + + return cast(Backend, importlib) + from . import pkg_resources + + return cast(Backend, pkg_resources) + + +def get_default_environment() -> BaseEnvironment: + """Get the default representation for the current environment. + + This returns an Environment instance from the chosen backend. The default + Environment instance should be built from ``sys.path`` and may use caching + to share instance state accorss calls. + """ + return select_backend().Environment.default() + + +def get_environment(paths: Optional[List[str]]) -> BaseEnvironment: + """Get a representation of the environment specified by ``paths``. + + This returns an Environment instance from the chosen backend based on the + given import paths. The backend must build a fresh instance representing + the state of installed distributions when this function is called. + """ + return select_backend().Environment.from_paths(paths) + + +def get_directory_distribution(directory: str) -> BaseDistribution: + """Get the distribution metadata representation in the specified directory. + + This returns a Distribution instance from the chosen backend based on + the given on-disk ``.dist-info`` directory. + """ + return select_backend().Distribution.from_directory(directory) + + +def get_wheel_distribution(wheel: Wheel, canonical_name: str) -> BaseDistribution: + """Get the representation of the specified wheel's distribution metadata. + + This returns a Distribution instance from the chosen backend based on + the given wheel's ``.dist-info`` directory. + + :param canonical_name: Normalized project name of the given wheel. + """ + return select_backend().Distribution.from_wheel(wheel, canonical_name) + + +def get_metadata_distribution( + metadata_contents: bytes, + filename: str, + canonical_name: str, +) -> BaseDistribution: + """Get the dist representation of the specified METADATA file contents. + + This returns a Distribution instance from the chosen backend sourced from the data + in `metadata_contents`. + + :param metadata_contents: Contents of a METADATA file within a dist, or one served + via PEP 658. + :param filename: Filename for the dist this metadata represents. + :param canonical_name: Normalized project name of the given dist. + """ + return select_backend().Distribution.from_metadata_file_contents( + metadata_contents, + filename, + canonical_name, + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69255cd410bdb8639743f122c84f8e4d34a105f0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/_json.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/_json.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2255eb48041f95118c0cada4fc7281050b6cd5a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/_json.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/base.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88d0f4522e8ec1c5bb233dfb30a0ff98cd0bd3e7 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/__pycache__/base.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/_json.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/_json.py new file mode 100644 index 0000000000000000000000000000000000000000..9097dd58590549f2b5389f1303cd35961526815b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/_json.py @@ -0,0 +1,84 @@ +# Extracted from https://github.com/pfmoore/pkg_metadata + +from email.header import Header, decode_header, make_header +from email.message import Message +from typing import Any, Dict, List, Union, cast + +METADATA_FIELDS = [ + # Name, Multiple-Use + ("Metadata-Version", False), + ("Name", False), + ("Version", False), + ("Dynamic", True), + ("Platform", True), + ("Supported-Platform", True), + ("Summary", False), + ("Description", False), + ("Description-Content-Type", False), + ("Keywords", False), + ("Home-page", False), + ("Download-URL", False), + ("Author", False), + ("Author-email", False), + ("Maintainer", False), + ("Maintainer-email", False), + ("License", False), + ("Classifier", True), + ("Requires-Dist", True), + ("Requires-Python", False), + ("Requires-External", True), + ("Project-URL", True), + ("Provides-Extra", True), + ("Provides-Dist", True), + ("Obsoletes-Dist", True), +] + + +def json_name(field: str) -> str: + return field.lower().replace("-", "_") + + +def msg_to_json(msg: Message) -> Dict[str, Any]: + """Convert a Message object into a JSON-compatible dictionary.""" + + def sanitise_header(h: Union[Header, str]) -> str: + if isinstance(h, Header): + chunks = [] + for bytes, encoding in decode_header(h): + if encoding == "unknown-8bit": + try: + # See if UTF-8 works + bytes.decode("utf-8") + encoding = "utf-8" + except UnicodeDecodeError: + # If not, latin1 at least won't fail + encoding = "latin1" + chunks.append((bytes, encoding)) + return str(make_header(chunks)) + return str(h) + + result = {} + for field, multi in METADATA_FIELDS: + if field not in msg: + continue + key = json_name(field) + if multi: + value: Union[str, List[str]] = [ + sanitise_header(v) for v in msg.get_all(field) # type: ignore + ] + else: + value = sanitise_header(msg.get(field)) # type: ignore + if key == "keywords": + # Accept both comma-separated and space-separated + # forms, for better compatibility with old data. + if "," in value: + value = [v.strip() for v in value.split(",")] + else: + value = value.split() + result[key] = value + + payload = cast(str, msg.get_payload()) + if payload: + result["description"] = payload + + return result diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/importlib/_dists.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/importlib/_dists.py new file mode 100644 index 0000000000000000000000000000000000000000..36cd326232e932f0dc8b38a733fd3ef38658cf94 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/importlib/_dists.py @@ -0,0 +1,221 @@ +import email.message +import importlib.metadata +import pathlib +import zipfile +from typing import ( + Collection, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + Sequence, + cast, +) + +from pip._vendor.packaging.requirements import Requirement +from pip._vendor.packaging.utils import NormalizedName, canonicalize_name +from pip._vendor.packaging.version import Version +from pip._vendor.packaging.version import parse as parse_version + +from pip._internal.exceptions import InvalidWheel, UnsupportedWheel +from pip._internal.metadata.base import ( + BaseDistribution, + BaseEntryPoint, + InfoPath, + Wheel, +) +from pip._internal.utils.misc import normalize_path +from pip._internal.utils.packaging import get_requirement +from pip._internal.utils.temp_dir import TempDirectory +from pip._internal.utils.wheel import parse_wheel, read_wheel_metadata_file + +from ._compat import ( + BasePath, + get_dist_canonical_name, + parse_name_and_version_from_info_directory, +) + + +class WheelDistribution(importlib.metadata.Distribution): + """An ``importlib.metadata.Distribution`` read from a wheel. + + Although ``importlib.metadata.PathDistribution`` accepts ``zipfile.Path``, + its implementation is too "lazy" for pip's needs (we can't keep the ZipFile + handle open for the entire lifetime of the distribution object). + + This implementation eagerly reads the entire metadata directory into the + memory instead, and operates from that. + """ + + def __init__( + self, + files: Mapping[pathlib.PurePosixPath, bytes], + info_location: pathlib.PurePosixPath, + ) -> None: + self._files = files + self.info_location = info_location + + @classmethod + def from_zipfile( + cls, + zf: zipfile.ZipFile, + name: str, + location: str, + ) -> "WheelDistribution": + info_dir, _ = parse_wheel(zf, name) + paths = ( + (name, pathlib.PurePosixPath(name.split("/", 1)[-1])) + for name in zf.namelist() + if name.startswith(f"{info_dir}/") + ) + files = { + relpath: read_wheel_metadata_file(zf, fullpath) + for fullpath, relpath in paths + } + info_location = pathlib.PurePosixPath(location, info_dir) + return cls(files, info_location) + + def iterdir(self, path: InfoPath) -> Iterator[pathlib.PurePosixPath]: + # Only allow iterating through the metadata directory. + if pathlib.PurePosixPath(str(path)) in self._files: + return iter(self._files) + raise FileNotFoundError(path) + + def read_text(self, filename: str) -> Optional[str]: + try: + data = self._files[pathlib.PurePosixPath(filename)] + except KeyError: + return None + try: + text = data.decode("utf-8") + except UnicodeDecodeError as e: + wheel = self.info_location.parent + error = f"Error decoding metadata for {wheel}: {e} in {filename} file" + raise UnsupportedWheel(error) + return text + + +class Distribution(BaseDistribution): + def __init__( + self, + dist: importlib.metadata.Distribution, + info_location: Optional[BasePath], + installed_location: Optional[BasePath], + ) -> None: + self._dist = dist + self._info_location = info_location + self._installed_location = installed_location + + @classmethod + def from_directory(cls, directory: str) -> BaseDistribution: + info_location = pathlib.Path(directory) + dist = importlib.metadata.Distribution.at(info_location) + return cls(dist, info_location, info_location.parent) + + @classmethod + def from_metadata_file_contents( + cls, + metadata_contents: bytes, + filename: str, + project_name: str, + ) -> BaseDistribution: + # Generate temp dir to contain the metadata file, and write the file contents. + temp_dir = pathlib.Path( + TempDirectory(kind="metadata", globally_managed=True).path + ) + metadata_path = temp_dir / "METADATA" + metadata_path.write_bytes(metadata_contents) + # Construct dist pointing to the newly created directory. + dist = importlib.metadata.Distribution.at(metadata_path.parent) + return cls(dist, metadata_path.parent, None) + + @classmethod + def from_wheel(cls, wheel: Wheel, name: str) -> BaseDistribution: + try: + with wheel.as_zipfile() as zf: + dist = WheelDistribution.from_zipfile(zf, name, wheel.location) + except zipfile.BadZipFile as e: + raise InvalidWheel(wheel.location, name) from e + return cls(dist, dist.info_location, pathlib.PurePosixPath(wheel.location)) + + @property + def location(self) -> Optional[str]: + if self._info_location is None: + return None + return str(self._info_location.parent) + + @property + def info_location(self) -> Optional[str]: + if self._info_location is None: + return None + return str(self._info_location) + + @property + def installed_location(self) -> Optional[str]: + if self._installed_location is None: + return None + return normalize_path(str(self._installed_location)) + + @property + def canonical_name(self) -> NormalizedName: + return get_dist_canonical_name(self._dist) + + @property + def version(self) -> Version: + if version := parse_name_and_version_from_info_directory(self._dist)[1]: + return parse_version(version) + return parse_version(self._dist.version) + + @property + def raw_version(self) -> str: + return self._dist.version + + def is_file(self, path: InfoPath) -> bool: + return self._dist.read_text(str(path)) is not None + + def iter_distutils_script_names(self) -> Iterator[str]: + # A distutils installation is always "flat" (not in e.g. egg form), so + # if this distribution's info location is NOT a pathlib.Path (but e.g. + # zipfile.Path), it can never contain any distutils scripts. + if not isinstance(self._info_location, pathlib.Path): + return + for child in self._info_location.joinpath("scripts").iterdir(): + yield child.name + + def read_text(self, path: InfoPath) -> str: + content = self._dist.read_text(str(path)) + if content is None: + raise FileNotFoundError(path) + return content + + def iter_entry_points(self) -> Iterable[BaseEntryPoint]: + # importlib.metadata's EntryPoint structure sasitfies BaseEntryPoint. + return self._dist.entry_points + + def _metadata_impl(self) -> email.message.Message: + # From Python 3.10+, importlib.metadata declares PackageMetadata as the + # return type. This protocol is unfortunately a disaster now and misses + # a ton of fields that we need, including get() and get_payload(). We + # rely on the implementation that the object is actually a Message now, + # until upstream can improve the protocol. (python/cpython#94952) + return cast(email.message.Message, self._dist.metadata) + + def iter_provided_extras(self) -> Iterable[NormalizedName]: + return [ + canonicalize_name(extra) + for extra in self.metadata.get_all("Provides-Extra", []) + ] + + def iter_dependencies(self, extras: Collection[str] = ()) -> Iterable[Requirement]: + contexts: Sequence[Dict[str, str]] = [{"extra": e} for e in extras] + for req_string in self.metadata.get_all("Requires-Dist", []): + # strip() because email.message.Message.get_all() may return a leading \n + # in case a long header was wrapped. + req = get_requirement(req_string.strip()) + if not req.marker: + yield req + elif not extras and req.marker.evaluate({"extra": ""}): + yield req + elif any(req.marker.evaluate(context) for context in contexts): + yield req diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/pkg_resources.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/pkg_resources.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea84f93a6fb8f2d04230d70491eac7809672031 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/metadata/pkg_resources.py @@ -0,0 +1,301 @@ +import email.message +import email.parser +import logging +import os +import zipfile +from typing import ( + Collection, + Iterable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, +) + +from pip._vendor import pkg_resources +from pip._vendor.packaging.requirements import Requirement +from pip._vendor.packaging.utils import NormalizedName, canonicalize_name +from pip._vendor.packaging.version import Version +from pip._vendor.packaging.version import parse as parse_version + +from pip._internal.exceptions import InvalidWheel, NoneMetadataError, UnsupportedWheel +from pip._internal.utils.egg_link import egg_link_path_from_location +from pip._internal.utils.misc import display_path, normalize_path +from pip._internal.utils.wheel import parse_wheel, read_wheel_metadata_file + +from .base import ( + BaseDistribution, + BaseEntryPoint, + BaseEnvironment, + InfoPath, + Wheel, +) + +__all__ = ["NAME", "Distribution", "Environment"] + +logger = logging.getLogger(__name__) + +NAME = "pkg_resources" + + +class EntryPoint(NamedTuple): + name: str + value: str + group: str + + +class InMemoryMetadata: + """IMetadataProvider that reads metadata files from a dictionary. + + This also maps metadata decoding exceptions to our internal exception type. + """ + + def __init__(self, metadata: Mapping[str, bytes], wheel_name: str) -> None: + self._metadata = metadata + self._wheel_name = wheel_name + + def has_metadata(self, name: str) -> bool: + return name in self._metadata + + def get_metadata(self, name: str) -> str: + try: + return self._metadata[name].decode() + except UnicodeDecodeError as e: + # Augment the default error with the origin of the file. + raise UnsupportedWheel( + f"Error decoding metadata for {self._wheel_name}: {e} in {name} file" + ) + + def get_metadata_lines(self, name: str) -> Iterable[str]: + return pkg_resources.yield_lines(self.get_metadata(name)) + + def metadata_isdir(self, name: str) -> bool: + return False + + def metadata_listdir(self, name: str) -> List[str]: + return [] + + def run_script(self, script_name: str, namespace: str) -> None: + pass + + +class Distribution(BaseDistribution): + def __init__(self, dist: pkg_resources.Distribution) -> None: + self._dist = dist + # This is populated lazily, to avoid loading metadata for all possible + # distributions eagerly. + self.__extra_mapping: Optional[Mapping[NormalizedName, str]] = None + + @property + def _extra_mapping(self) -> Mapping[NormalizedName, str]: + if self.__extra_mapping is None: + self.__extra_mapping = { + canonicalize_name(extra): extra for extra in self._dist.extras + } + + return self.__extra_mapping + + @classmethod + def from_directory(cls, directory: str) -> BaseDistribution: + dist_dir = directory.rstrip(os.sep) + + # Build a PathMetadata object, from path to metadata. :wink: + base_dir, dist_dir_name = os.path.split(dist_dir) + metadata = pkg_resources.PathMetadata(base_dir, dist_dir) + + # Determine the correct Distribution object type. + if dist_dir.endswith(".egg-info"): + dist_cls = pkg_resources.Distribution + dist_name = os.path.splitext(dist_dir_name)[0] + else: + assert dist_dir.endswith(".dist-info") + dist_cls = pkg_resources.DistInfoDistribution + dist_name = os.path.splitext(dist_dir_name)[0].split("-")[0] + + dist = dist_cls(base_dir, project_name=dist_name, metadata=metadata) + return cls(dist) + + @classmethod + def from_metadata_file_contents( + cls, + metadata_contents: bytes, + filename: str, + project_name: str, + ) -> BaseDistribution: + metadata_dict = { + "METADATA": metadata_contents, + } + dist = pkg_resources.DistInfoDistribution( + location=filename, + metadata=InMemoryMetadata(metadata_dict, filename), + project_name=project_name, + ) + return cls(dist) + + @classmethod + def from_wheel(cls, wheel: Wheel, name: str) -> BaseDistribution: + try: + with wheel.as_zipfile() as zf: + info_dir, _ = parse_wheel(zf, name) + metadata_dict = { + path.split("/", 1)[-1]: read_wheel_metadata_file(zf, path) + for path in zf.namelist() + if path.startswith(f"{info_dir}/") + } + except zipfile.BadZipFile as e: + raise InvalidWheel(wheel.location, name) from e + except UnsupportedWheel as e: + raise UnsupportedWheel(f"{name} has an invalid wheel, {e}") + dist = pkg_resources.DistInfoDistribution( + location=wheel.location, + metadata=InMemoryMetadata(metadata_dict, wheel.location), + project_name=name, + ) + return cls(dist) + + @property + def location(self) -> Optional[str]: + return self._dist.location + + @property + def installed_location(self) -> Optional[str]: + egg_link = egg_link_path_from_location(self.raw_name) + if egg_link: + location = egg_link + elif self.location: + location = self.location + else: + return None + return normalize_path(location) + + @property + def info_location(self) -> Optional[str]: + return self._dist.egg_info + + @property + def installed_by_distutils(self) -> bool: + # A distutils-installed distribution is provided by FileMetadata. This + # provider has a "path" attribute not present anywhere else. Not the + # best introspection logic, but pip has been doing this for a long time. + try: + return bool(self._dist._provider.path) + except AttributeError: + return False + + @property + def canonical_name(self) -> NormalizedName: + return canonicalize_name(self._dist.project_name) + + @property + def version(self) -> Version: + return parse_version(self._dist.version) + + @property + def raw_version(self) -> str: + return self._dist.version + + def is_file(self, path: InfoPath) -> bool: + return self._dist.has_metadata(str(path)) + + def iter_distutils_script_names(self) -> Iterator[str]: + yield from self._dist.metadata_listdir("scripts") + + def read_text(self, path: InfoPath) -> str: + name = str(path) + if not self._dist.has_metadata(name): + raise FileNotFoundError(name) + content = self._dist.get_metadata(name) + if content is None: + raise NoneMetadataError(self, name) + return content + + def iter_entry_points(self) -> Iterable[BaseEntryPoint]: + for group, entries in self._dist.get_entry_map().items(): + for name, entry_point in entries.items(): + name, _, value = str(entry_point).partition("=") + yield EntryPoint(name=name.strip(), value=value.strip(), group=group) + + def _metadata_impl(self) -> email.message.Message: + """ + :raises NoneMetadataError: if the distribution reports `has_metadata()` + True but `get_metadata()` returns None. + """ + if isinstance(self._dist, pkg_resources.DistInfoDistribution): + metadata_name = "METADATA" + else: + metadata_name = "PKG-INFO" + try: + metadata = self.read_text(metadata_name) + except FileNotFoundError: + if self.location: + displaying_path = display_path(self.location) + else: + displaying_path = repr(self.location) + logger.warning("No metadata found in %s", displaying_path) + metadata = "" + feed_parser = email.parser.FeedParser() + feed_parser.feed(metadata) + return feed_parser.close() + + def iter_dependencies(self, extras: Collection[str] = ()) -> Iterable[Requirement]: + if extras: + relevant_extras = set(self._extra_mapping) & set( + map(canonicalize_name, extras) + ) + extras = [self._extra_mapping[extra] for extra in relevant_extras] + return self._dist.requires(extras) + + def iter_provided_extras(self) -> Iterable[NormalizedName]: + return self._extra_mapping.keys() + + +class Environment(BaseEnvironment): + def __init__(self, ws: pkg_resources.WorkingSet) -> None: + self._ws = ws + + @classmethod + def default(cls) -> BaseEnvironment: + return cls(pkg_resources.working_set) + + @classmethod + def from_paths(cls, paths: Optional[List[str]]) -> BaseEnvironment: + return cls(pkg_resources.WorkingSet(paths)) + + def _iter_distributions(self) -> Iterator[BaseDistribution]: + for dist in self._ws: + yield Distribution(dist) + + def _search_distribution(self, name: str) -> Optional[BaseDistribution]: + """Find a distribution matching the ``name`` in the environment. + + This searches from *all* distributions available in the environment, to + match the behavior of ``pkg_resources.get_distribution()``. + """ + canonical_name = canonicalize_name(name) + for dist in self.iter_all_distributions(): + if dist.canonical_name == canonical_name: + return dist + return None + + def get_distribution(self, name: str) -> Optional[BaseDistribution]: + # Search the distribution by looking through the working set. + dist = self._search_distribution(name) + if dist: + return dist + + # If distribution could not be found, call working_set.require to + # update the working set, and try to find the distribution again. + # This might happen for e.g. when you install a package twice, once + # using setup.py develop and again using setup.py install. Now when + # running pip uninstall twice, the package gets removed from the + # working set in the first uninstall, so we have to populate the + # working set again so that pip knows about it and the packages gets + # picked up and is successfully uninstalled the second time too. + try: + # We didn't pass in any version specifiers, so this can never + # raise pkg_resources.VersionConflict. + self._ws.require(name) + except pkg_resources.DistributionNotFound: + return None + return self._search_distribution(name) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1699c72a72f3d3c56cc93dab1ab68eca2fae8569 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/download.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/download.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f99afd12817cb950a2cd76dc5b22d8ba3eb1059a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/download.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/session.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/session.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..965256c3fa79000686aef265b1739ad170ec688b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/session.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..820d521e413e054263160bc01120fcde132a2c6f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/xmlrpc.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/xmlrpc.py new file mode 100644 index 0000000000000000000000000000000000000000..22ec8d2f4a6b276aedf2249d50ff10a21922513c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_internal/network/xmlrpc.py @@ -0,0 +1,62 @@ +"""xmlrpclib.Transport implementation +""" + +import logging +import urllib.parse +import xmlrpc.client +from typing import TYPE_CHECKING, Tuple + +from pip._internal.exceptions import NetworkConnectionError +from pip._internal.network.session import PipSession +from pip._internal.network.utils import raise_for_status + +if TYPE_CHECKING: + from xmlrpc.client import _HostType, _Marshallable + + from _typeshed import SizedBuffer + +logger = logging.getLogger(__name__) + + +class PipXmlrpcTransport(xmlrpc.client.Transport): + """Provide a `xmlrpclib.Transport` implementation via a `PipSession` + object. + """ + + def __init__( + self, index_url: str, session: PipSession, use_datetime: bool = False + ) -> None: + super().__init__(use_datetime) + index_parts = urllib.parse.urlparse(index_url) + self._scheme = index_parts.scheme + self._session = session + + def request( + self, + host: "_HostType", + handler: str, + request_body: "SizedBuffer", + verbose: bool = False, + ) -> Tuple["_Marshallable", ...]: + assert isinstance(host, str) + parts = (self._scheme, host, handler, None, None, None) + url = urllib.parse.urlunparse(parts) + try: + headers = {"Content-Type": "text/xml"} + response = self._session.post( + url, + data=request_body, + headers=headers, + stream=True, + ) + raise_for_status(response) + self.verbose = verbose + return self.parse_response(response.raw) + except NetworkConnectionError as exc: + assert exc.response + logger.critical( + "HTTP error %s while getting %s", + exc.response.status_code, + url, + ) + raise diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf0d6c6d30e3747ae08789e6c94d31d009f4bbf2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2012-2023 Vinay Sajip. +# Licensed to the Python Software Foundation under a contributor agreement. +# See LICENSE.txt and CONTRIBUTORS.txt. +# +import logging + +__version__ = '0.3.9' + + +class DistlibException(Exception): + pass + + +try: + from logging import NullHandler +except ImportError: # pragma: no cover + + class NullHandler(logging.Handler): + + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None + + +logger = logging.getLogger(__name__) +logger.addHandler(NullHandler()) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f2efd9ac2485ac8f3801c63f7e01b82a605f190 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/compat.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/compat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11ff6e5aa7e9facbe2bb38063cad5ef0722038ed Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/compat.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/database.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/database.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb34c5bfd0363fcb4eb0460dbe29ba51aa09a3d8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/database.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/manifest.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/manifest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f337752ce6581757ba6ed94034d820b57ff8808a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/manifest.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/database.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/database.py new file mode 100644 index 0000000000000000000000000000000000000000..c0f896a7d8597f279837961f64e85e962afb026d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/database.py @@ -0,0 +1,1329 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2012-2023 The Python Software Foundation. +# See LICENSE.txt and CONTRIBUTORS.txt. +# +"""PEP 376 implementation.""" + +from __future__ import unicode_literals + +import base64 +import codecs +import contextlib +import hashlib +import logging +import os +import posixpath +import sys +import zipimport + +from . import DistlibException, resources +from .compat import StringIO +from .version import get_scheme, UnsupportedVersionError +from .metadata import (Metadata, METADATA_FILENAME, WHEEL_METADATA_FILENAME, LEGACY_METADATA_FILENAME) +from .util import (parse_requirement, cached_property, parse_name_and_version, read_exports, write_exports, CSVReader, + CSVWriter) + +__all__ = [ + 'Distribution', 'BaseInstalledDistribution', 'InstalledDistribution', 'EggInfoDistribution', 'DistributionPath' +] + +logger = logging.getLogger(__name__) + +EXPORTS_FILENAME = 'pydist-exports.json' +COMMANDS_FILENAME = 'pydist-commands.json' + +DIST_FILES = ('INSTALLER', METADATA_FILENAME, 'RECORD', 'REQUESTED', 'RESOURCES', EXPORTS_FILENAME, 'SHARED') + +DISTINFO_EXT = '.dist-info' + + +class _Cache(object): + """ + A simple cache mapping names and .dist-info paths to distributions + """ + + def __init__(self): + """ + Initialise an instance. There is normally one for each DistributionPath. + """ + self.name = {} + self.path = {} + self.generated = False + + def clear(self): + """ + Clear the cache, setting it to its initial state. + """ + self.name.clear() + self.path.clear() + self.generated = False + + def add(self, dist): + """ + Add a distribution to the cache. + :param dist: The distribution to add. + """ + if dist.path not in self.path: + self.path[dist.path] = dist + self.name.setdefault(dist.key, []).append(dist) + + +class DistributionPath(object): + """ + Represents a set of distributions installed on a path (typically sys.path). + """ + + def __init__(self, path=None, include_egg=False): + """ + Create an instance from a path, optionally including legacy (distutils/ + setuptools/distribute) distributions. + :param path: The path to use, as a list of directories. If not specified, + sys.path is used. + :param include_egg: If True, this instance will look for and return legacy + distributions as well as those based on PEP 376. + """ + if path is None: + path = sys.path + self.path = path + self._include_dist = True + self._include_egg = include_egg + + self._cache = _Cache() + self._cache_egg = _Cache() + self._cache_enabled = True + self._scheme = get_scheme('default') + + def _get_cache_enabled(self): + return self._cache_enabled + + def _set_cache_enabled(self, value): + self._cache_enabled = value + + cache_enabled = property(_get_cache_enabled, _set_cache_enabled) + + def clear_cache(self): + """ + Clears the internal cache. + """ + self._cache.clear() + self._cache_egg.clear() + + def _yield_distributions(self): + """ + Yield .dist-info and/or .egg(-info) distributions. + """ + # We need to check if we've seen some resources already, because on + # some Linux systems (e.g. some Debian/Ubuntu variants) there are + # symlinks which alias other files in the environment. + seen = set() + for path in self.path: + finder = resources.finder_for_path(path) + if finder is None: + continue + r = finder.find('') + if not r or not r.is_container: + continue + rset = sorted(r.resources) + for entry in rset: + r = finder.find(entry) + if not r or r.path in seen: + continue + try: + if self._include_dist and entry.endswith(DISTINFO_EXT): + possible_filenames = [METADATA_FILENAME, WHEEL_METADATA_FILENAME, LEGACY_METADATA_FILENAME] + for metadata_filename in possible_filenames: + metadata_path = posixpath.join(entry, metadata_filename) + pydist = finder.find(metadata_path) + if pydist: + break + else: + continue + + with contextlib.closing(pydist.as_stream()) as stream: + metadata = Metadata(fileobj=stream, scheme='legacy') + logger.debug('Found %s', r.path) + seen.add(r.path) + yield new_dist_class(r.path, metadata=metadata, env=self) + elif self._include_egg and entry.endswith(('.egg-info', '.egg')): + logger.debug('Found %s', r.path) + seen.add(r.path) + yield old_dist_class(r.path, self) + except Exception as e: + msg = 'Unable to read distribution at %s, perhaps due to bad metadata: %s' + logger.warning(msg, r.path, e) + import warnings + warnings.warn(msg % (r.path, e), stacklevel=2) + + def _generate_cache(self): + """ + Scan the path for distributions and populate the cache with + those that are found. + """ + gen_dist = not self._cache.generated + gen_egg = self._include_egg and not self._cache_egg.generated + if gen_dist or gen_egg: + for dist in self._yield_distributions(): + if isinstance(dist, InstalledDistribution): + self._cache.add(dist) + else: + self._cache_egg.add(dist) + + if gen_dist: + self._cache.generated = True + if gen_egg: + self._cache_egg.generated = True + + @classmethod + def distinfo_dirname(cls, name, version): + """ + The *name* and *version* parameters are converted into their + filename-escaped form, i.e. any ``'-'`` characters are replaced + with ``'_'`` other than the one in ``'dist-info'`` and the one + separating the name from the version number. + + :parameter name: is converted to a standard distribution name by replacing + any runs of non- alphanumeric characters with a single + ``'-'``. + :type name: string + :parameter version: is converted to a standard version string. Spaces + become dots, and all other non-alphanumeric characters + (except dots) become dashes, with runs of multiple + dashes condensed to a single dash. + :type version: string + :returns: directory name + :rtype: string""" + name = name.replace('-', '_') + return '-'.join([name, version]) + DISTINFO_EXT + + def get_distributions(self): + """ + Provides an iterator that looks for distributions and returns + :class:`InstalledDistribution` or + :class:`EggInfoDistribution` instances for each one of them. + + :rtype: iterator of :class:`InstalledDistribution` and + :class:`EggInfoDistribution` instances + """ + if not self._cache_enabled: + for dist in self._yield_distributions(): + yield dist + else: + self._generate_cache() + + for dist in self._cache.path.values(): + yield dist + + if self._include_egg: + for dist in self._cache_egg.path.values(): + yield dist + + def get_distribution(self, name): + """ + Looks for a named distribution on the path. + + This function only returns the first result found, as no more than one + value is expected. If nothing is found, ``None`` is returned. + + :rtype: :class:`InstalledDistribution`, :class:`EggInfoDistribution` + or ``None`` + """ + result = None + name = name.lower() + if not self._cache_enabled: + for dist in self._yield_distributions(): + if dist.key == name: + result = dist + break + else: + self._generate_cache() + + if name in self._cache.name: + result = self._cache.name[name][0] + elif self._include_egg and name in self._cache_egg.name: + result = self._cache_egg.name[name][0] + return result + + def provides_distribution(self, name, version=None): + """ + Iterates over all distributions to find which distributions provide *name*. + If a *version* is provided, it will be used to filter the results. + + This function only returns the first result found, since no more than + one values are expected. If the directory is not found, returns ``None``. + + :parameter version: a version specifier that indicates the version + required, conforming to the format in ``PEP-345`` + + :type name: string + :type version: string + """ + matcher = None + if version is not None: + try: + matcher = self._scheme.matcher('%s (%s)' % (name, version)) + except ValueError: + raise DistlibException('invalid name or version: %r, %r' % (name, version)) + + for dist in self.get_distributions(): + # We hit a problem on Travis where enum34 was installed and doesn't + # have a provides attribute ... + if not hasattr(dist, 'provides'): + logger.debug('No "provides": %s', dist) + else: + provided = dist.provides + + for p in provided: + p_name, p_ver = parse_name_and_version(p) + if matcher is None: + if p_name == name: + yield dist + break + else: + if p_name == name and matcher.match(p_ver): + yield dist + break + + def get_file_path(self, name, relative_path): + """ + Return the path to a resource file. + """ + dist = self.get_distribution(name) + if dist is None: + raise LookupError('no distribution named %r found' % name) + return dist.get_resource_path(relative_path) + + def get_exported_entries(self, category, name=None): + """ + Return all of the exported entries in a particular category. + + :param category: The category to search for entries. + :param name: If specified, only entries with that name are returned. + """ + for dist in self.get_distributions(): + r = dist.exports + if category in r: + d = r[category] + if name is not None: + if name in d: + yield d[name] + else: + for v in d.values(): + yield v + + +class Distribution(object): + """ + A base class for distributions, whether installed or from indexes. + Either way, it must have some metadata, so that's all that's needed + for construction. + """ + + build_time_dependency = False + """ + Set to True if it's known to be only a build-time dependency (i.e. + not needed after installation). + """ + + requested = False + """A boolean that indicates whether the ``REQUESTED`` metadata file is + present (in other words, whether the package was installed by user + request or it was installed as a dependency).""" + + def __init__(self, metadata): + """ + Initialise an instance. + :param metadata: The instance of :class:`Metadata` describing this + distribution. + """ + self.metadata = metadata + self.name = metadata.name + self.key = self.name.lower() # for case-insensitive comparisons + self.version = metadata.version + self.locator = None + self.digest = None + self.extras = None # additional features requested + self.context = None # environment marker overrides + self.download_urls = set() + self.digests = {} + + @property + def source_url(self): + """ + The source archive download URL for this distribution. + """ + return self.metadata.source_url + + download_url = source_url # Backward compatibility + + @property + def name_and_version(self): + """ + A utility property which displays the name and version in parentheses. + """ + return '%s (%s)' % (self.name, self.version) + + @property + def provides(self): + """ + A set of distribution names and versions provided by this distribution. + :return: A set of "name (version)" strings. + """ + plist = self.metadata.provides + s = '%s (%s)' % (self.name, self.version) + if s not in plist: + plist.append(s) + return plist + + def _get_requirements(self, req_attr): + md = self.metadata + reqts = getattr(md, req_attr) + logger.debug('%s: got requirements %r from metadata: %r', self.name, req_attr, reqts) + return set(md.get_requirements(reqts, extras=self.extras, env=self.context)) + + @property + def run_requires(self): + return self._get_requirements('run_requires') + + @property + def meta_requires(self): + return self._get_requirements('meta_requires') + + @property + def build_requires(self): + return self._get_requirements('build_requires') + + @property + def test_requires(self): + return self._get_requirements('test_requires') + + @property + def dev_requires(self): + return self._get_requirements('dev_requires') + + def matches_requirement(self, req): + """ + Say if this instance matches (fulfills) a requirement. + :param req: The requirement to match. + :rtype req: str + :return: True if it matches, else False. + """ + # Requirement may contain extras - parse to lose those + # from what's passed to the matcher + r = parse_requirement(req) + scheme = get_scheme(self.metadata.scheme) + try: + matcher = scheme.matcher(r.requirement) + except UnsupportedVersionError: + # XXX compat-mode if cannot read the version + logger.warning('could not read version %r - using name only', req) + name = req.split()[0] + matcher = scheme.matcher(name) + + name = matcher.key # case-insensitive + + result = False + for p in self.provides: + p_name, p_ver = parse_name_and_version(p) + if p_name != name: + continue + try: + result = matcher.match(p_ver) + break + except UnsupportedVersionError: + pass + return result + + def __repr__(self): + """ + Return a textual representation of this instance, + """ + if self.source_url: + suffix = ' [%s]' % self.source_url + else: + suffix = '' + return '' % (self.name, self.version, suffix) + + def __eq__(self, other): + """ + See if this distribution is the same as another. + :param other: The distribution to compare with. To be equal to one + another. distributions must have the same type, name, + version and source_url. + :return: True if it is the same, else False. + """ + if type(other) is not type(self): + result = False + else: + result = (self.name == other.name and self.version == other.version and self.source_url == other.source_url) + return result + + def __hash__(self): + """ + Compute hash in a way which matches the equality test. + """ + return hash(self.name) + hash(self.version) + hash(self.source_url) + + +class BaseInstalledDistribution(Distribution): + """ + This is the base class for installed distributions (whether PEP 376 or + legacy). + """ + + hasher = None + + def __init__(self, metadata, path, env=None): + """ + Initialise an instance. + :param metadata: An instance of :class:`Metadata` which describes the + distribution. This will normally have been initialised + from a metadata file in the ``path``. + :param path: The path of the ``.dist-info`` or ``.egg-info`` + directory for the distribution. + :param env: This is normally the :class:`DistributionPath` + instance where this distribution was found. + """ + super(BaseInstalledDistribution, self).__init__(metadata) + self.path = path + self.dist_path = env + + def get_hash(self, data, hasher=None): + """ + Get the hash of some data, using a particular hash algorithm, if + specified. + + :param data: The data to be hashed. + :type data: bytes + :param hasher: The name of a hash implementation, supported by hashlib, + or ``None``. Examples of valid values are ``'sha1'``, + ``'sha224'``, ``'sha384'``, '``sha256'``, ``'md5'`` and + ``'sha512'``. If no hasher is specified, the ``hasher`` + attribute of the :class:`InstalledDistribution` instance + is used. If the hasher is determined to be ``None``, MD5 + is used as the hashing algorithm. + :returns: The hash of the data. If a hasher was explicitly specified, + the returned hash will be prefixed with the specified hasher + followed by '='. + :rtype: str + """ + if hasher is None: + hasher = self.hasher + if hasher is None: + hasher = hashlib.md5 + prefix = '' + else: + hasher = getattr(hashlib, hasher) + prefix = '%s=' % self.hasher + digest = hasher(data).digest() + digest = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii') + return '%s%s' % (prefix, digest) + + +class InstalledDistribution(BaseInstalledDistribution): + """ + Created with the *path* of the ``.dist-info`` directory provided to the + constructor. It reads the metadata contained in ``pydist.json`` when it is + instantiated., or uses a passed in Metadata instance (useful for when + dry-run mode is being used). + """ + + hasher = 'sha256' + + def __init__(self, path, metadata=None, env=None): + self.modules = [] + self.finder = finder = resources.finder_for_path(path) + if finder is None: + raise ValueError('finder unavailable for %s' % path) + if env and env._cache_enabled and path in env._cache.path: + metadata = env._cache.path[path].metadata + elif metadata is None: + r = finder.find(METADATA_FILENAME) + # Temporary - for Wheel 0.23 support + if r is None: + r = finder.find(WHEEL_METADATA_FILENAME) + # Temporary - for legacy support + if r is None: + r = finder.find(LEGACY_METADATA_FILENAME) + if r is None: + raise ValueError('no %s found in %s' % (METADATA_FILENAME, path)) + with contextlib.closing(r.as_stream()) as stream: + metadata = Metadata(fileobj=stream, scheme='legacy') + + super(InstalledDistribution, self).__init__(metadata, path, env) + + if env and env._cache_enabled: + env._cache.add(self) + + r = finder.find('REQUESTED') + self.requested = r is not None + p = os.path.join(path, 'top_level.txt') + if os.path.exists(p): + with open(p, 'rb') as f: + data = f.read().decode('utf-8') + self.modules = data.splitlines() + + def __repr__(self): + return '' % (self.name, self.version, self.path) + + def __str__(self): + return "%s %s" % (self.name, self.version) + + def _get_records(self): + """ + Get the list of installed files for the distribution + :return: A list of tuples of path, hash and size. Note that hash and + size might be ``None`` for some entries. The path is exactly + as stored in the file (which is as in PEP 376). + """ + results = [] + r = self.get_distinfo_resource('RECORD') + with contextlib.closing(r.as_stream()) as stream: + with CSVReader(stream=stream) as record_reader: + # Base location is parent dir of .dist-info dir + # base_location = os.path.dirname(self.path) + # base_location = os.path.abspath(base_location) + for row in record_reader: + missing = [None for i in range(len(row), 3)] + path, checksum, size = row + missing + # if not os.path.isabs(path): + # path = path.replace('/', os.sep) + # path = os.path.join(base_location, path) + results.append((path, checksum, size)) + return results + + @cached_property + def exports(self): + """ + Return the information exported by this distribution. + :return: A dictionary of exports, mapping an export category to a dict + of :class:`ExportEntry` instances describing the individual + export entries, and keyed by name. + """ + result = {} + r = self.get_distinfo_resource(EXPORTS_FILENAME) + if r: + result = self.read_exports() + return result + + def read_exports(self): + """ + Read exports data from a file in .ini format. + + :return: A dictionary of exports, mapping an export category to a list + of :class:`ExportEntry` instances describing the individual + export entries. + """ + result = {} + r = self.get_distinfo_resource(EXPORTS_FILENAME) + if r: + with contextlib.closing(r.as_stream()) as stream: + result = read_exports(stream) + return result + + def write_exports(self, exports): + """ + Write a dictionary of exports to a file in .ini format. + :param exports: A dictionary of exports, mapping an export category to + a list of :class:`ExportEntry` instances describing the + individual export entries. + """ + rf = self.get_distinfo_file(EXPORTS_FILENAME) + with open(rf, 'w') as f: + write_exports(exports, f) + + def get_resource_path(self, relative_path): + """ + NOTE: This API may change in the future. + + Return the absolute path to a resource file with the given relative + path. + + :param relative_path: The path, relative to .dist-info, of the resource + of interest. + :return: The absolute path where the resource is to be found. + """ + r = self.get_distinfo_resource('RESOURCES') + with contextlib.closing(r.as_stream()) as stream: + with CSVReader(stream=stream) as resources_reader: + for relative, destination in resources_reader: + if relative == relative_path: + return destination + raise KeyError('no resource file with relative path %r ' + 'is installed' % relative_path) + + def list_installed_files(self): + """ + Iterates over the ``RECORD`` entries and returns a tuple + ``(path, hash, size)`` for each line. + + :returns: iterator of (path, hash, size) + """ + for result in self._get_records(): + yield result + + def write_installed_files(self, paths, prefix, dry_run=False): + """ + Writes the ``RECORD`` file, using the ``paths`` iterable passed in. Any + existing ``RECORD`` file is silently overwritten. + + prefix is used to determine when to write absolute paths. + """ + prefix = os.path.join(prefix, '') + base = os.path.dirname(self.path) + base_under_prefix = base.startswith(prefix) + base = os.path.join(base, '') + record_path = self.get_distinfo_file('RECORD') + logger.info('creating %s', record_path) + if dry_run: + return None + with CSVWriter(record_path) as writer: + for path in paths: + if os.path.isdir(path) or path.endswith(('.pyc', '.pyo')): + # do not put size and hash, as in PEP-376 + hash_value = size = '' + else: + size = '%d' % os.path.getsize(path) + with open(path, 'rb') as fp: + hash_value = self.get_hash(fp.read()) + if path.startswith(base) or (base_under_prefix and path.startswith(prefix)): + path = os.path.relpath(path, base) + writer.writerow((path, hash_value, size)) + + # add the RECORD file itself + if record_path.startswith(base): + record_path = os.path.relpath(record_path, base) + writer.writerow((record_path, '', '')) + return record_path + + def check_installed_files(self): + """ + Checks that the hashes and sizes of the files in ``RECORD`` are + matched by the files themselves. Returns a (possibly empty) list of + mismatches. Each entry in the mismatch list will be a tuple consisting + of the path, 'exists', 'size' or 'hash' according to what didn't match + (existence is checked first, then size, then hash), the expected + value and the actual value. + """ + mismatches = [] + base = os.path.dirname(self.path) + record_path = self.get_distinfo_file('RECORD') + for path, hash_value, size in self.list_installed_files(): + if not os.path.isabs(path): + path = os.path.join(base, path) + if path == record_path: + continue + if not os.path.exists(path): + mismatches.append((path, 'exists', True, False)) + elif os.path.isfile(path): + actual_size = str(os.path.getsize(path)) + if size and actual_size != size: + mismatches.append((path, 'size', size, actual_size)) + elif hash_value: + if '=' in hash_value: + hasher = hash_value.split('=', 1)[0] + else: + hasher = None + + with open(path, 'rb') as f: + actual_hash = self.get_hash(f.read(), hasher) + if actual_hash != hash_value: + mismatches.append((path, 'hash', hash_value, actual_hash)) + return mismatches + + @cached_property + def shared_locations(self): + """ + A dictionary of shared locations whose keys are in the set 'prefix', + 'purelib', 'platlib', 'scripts', 'headers', 'data' and 'namespace'. + The corresponding value is the absolute path of that category for + this distribution, and takes into account any paths selected by the + user at installation time (e.g. via command-line arguments). In the + case of the 'namespace' key, this would be a list of absolute paths + for the roots of namespace packages in this distribution. + + The first time this property is accessed, the relevant information is + read from the SHARED file in the .dist-info directory. + """ + result = {} + shared_path = os.path.join(self.path, 'SHARED') + if os.path.isfile(shared_path): + with codecs.open(shared_path, 'r', encoding='utf-8') as f: + lines = f.read().splitlines() + for line in lines: + key, value = line.split('=', 1) + if key == 'namespace': + result.setdefault(key, []).append(value) + else: + result[key] = value + return result + + def write_shared_locations(self, paths, dry_run=False): + """ + Write shared location information to the SHARED file in .dist-info. + :param paths: A dictionary as described in the documentation for + :meth:`shared_locations`. + :param dry_run: If True, the action is logged but no file is actually + written. + :return: The path of the file written to. + """ + shared_path = os.path.join(self.path, 'SHARED') + logger.info('creating %s', shared_path) + if dry_run: + return None + lines = [] + for key in ('prefix', 'lib', 'headers', 'scripts', 'data'): + path = paths[key] + if os.path.isdir(paths[key]): + lines.append('%s=%s' % (key, path)) + for ns in paths.get('namespace', ()): + lines.append('namespace=%s' % ns) + + with codecs.open(shared_path, 'w', encoding='utf-8') as f: + f.write('\n'.join(lines)) + return shared_path + + def get_distinfo_resource(self, path): + if path not in DIST_FILES: + raise DistlibException('invalid path for a dist-info file: ' + '%r at %r' % (path, self.path)) + finder = resources.finder_for_path(self.path) + if finder is None: + raise DistlibException('Unable to get a finder for %s' % self.path) + return finder.find(path) + + def get_distinfo_file(self, path): + """ + Returns a path located under the ``.dist-info`` directory. Returns a + string representing the path. + + :parameter path: a ``'/'``-separated path relative to the + ``.dist-info`` directory or an absolute path; + If *path* is an absolute path and doesn't start + with the ``.dist-info`` directory path, + a :class:`DistlibException` is raised + :type path: str + :rtype: str + """ + # Check if it is an absolute path # XXX use relpath, add tests + if path.find(os.sep) >= 0: + # it's an absolute path? + distinfo_dirname, path = path.split(os.sep)[-2:] + if distinfo_dirname != self.path.split(os.sep)[-1]: + raise DistlibException('dist-info file %r does not belong to the %r %s ' + 'distribution' % (path, self.name, self.version)) + + # The file must be relative + if path not in DIST_FILES: + raise DistlibException('invalid path for a dist-info file: ' + '%r at %r' % (path, self.path)) + + return os.path.join(self.path, path) + + def list_distinfo_files(self): + """ + Iterates over the ``RECORD`` entries and returns paths for each line if + the path is pointing to a file located in the ``.dist-info`` directory + or one of its subdirectories. + + :returns: iterator of paths + """ + base = os.path.dirname(self.path) + for path, checksum, size in self._get_records(): + # XXX add separator or use real relpath algo + if not os.path.isabs(path): + path = os.path.join(base, path) + if path.startswith(self.path): + yield path + + def __eq__(self, other): + return (isinstance(other, InstalledDistribution) and self.path == other.path) + + # See http://docs.python.org/reference/datamodel#object.__hash__ + __hash__ = object.__hash__ + + +class EggInfoDistribution(BaseInstalledDistribution): + """Created with the *path* of the ``.egg-info`` directory or file provided + to the constructor. It reads the metadata contained in the file itself, or + if the given path happens to be a directory, the metadata is read from the + file ``PKG-INFO`` under that directory.""" + + requested = True # as we have no way of knowing, assume it was + shared_locations = {} + + def __init__(self, path, env=None): + + def set_name_and_version(s, n, v): + s.name = n + s.key = n.lower() # for case-insensitive comparisons + s.version = v + + self.path = path + self.dist_path = env + if env and env._cache_enabled and path in env._cache_egg.path: + metadata = env._cache_egg.path[path].metadata + set_name_and_version(self, metadata.name, metadata.version) + else: + metadata = self._get_metadata(path) + + # Need to be set before caching + set_name_and_version(self, metadata.name, metadata.version) + + if env and env._cache_enabled: + env._cache_egg.add(self) + super(EggInfoDistribution, self).__init__(metadata, path, env) + + def _get_metadata(self, path): + requires = None + + def parse_requires_data(data): + """Create a list of dependencies from a requires.txt file. + + *data*: the contents of a setuptools-produced requires.txt file. + """ + reqs = [] + lines = data.splitlines() + for line in lines: + line = line.strip() + # sectioned files have bare newlines (separating sections) + if not line: # pragma: no cover + continue + if line.startswith('['): # pragma: no cover + logger.warning('Unexpected line: quitting requirement scan: %r', line) + break + r = parse_requirement(line) + if not r: # pragma: no cover + logger.warning('Not recognised as a requirement: %r', line) + continue + if r.extras: # pragma: no cover + logger.warning('extra requirements in requires.txt are ' + 'not supported') + if not r.constraints: + reqs.append(r.name) + else: + cons = ', '.join('%s%s' % c for c in r.constraints) + reqs.append('%s (%s)' % (r.name, cons)) + return reqs + + def parse_requires_path(req_path): + """Create a list of dependencies from a requires.txt file. + + *req_path*: the path to a setuptools-produced requires.txt file. + """ + + reqs = [] + try: + with codecs.open(req_path, 'r', 'utf-8') as fp: + reqs = parse_requires_data(fp.read()) + except IOError: + pass + return reqs + + tl_path = tl_data = None + if path.endswith('.egg'): + if os.path.isdir(path): + p = os.path.join(path, 'EGG-INFO') + meta_path = os.path.join(p, 'PKG-INFO') + metadata = Metadata(path=meta_path, scheme='legacy') + req_path = os.path.join(p, 'requires.txt') + tl_path = os.path.join(p, 'top_level.txt') + requires = parse_requires_path(req_path) + else: + # FIXME handle the case where zipfile is not available + zipf = zipimport.zipimporter(path) + fileobj = StringIO(zipf.get_data('EGG-INFO/PKG-INFO').decode('utf8')) + metadata = Metadata(fileobj=fileobj, scheme='legacy') + try: + data = zipf.get_data('EGG-INFO/requires.txt') + tl_data = zipf.get_data('EGG-INFO/top_level.txt').decode('utf-8') + requires = parse_requires_data(data.decode('utf-8')) + except IOError: + requires = None + elif path.endswith('.egg-info'): + if os.path.isdir(path): + req_path = os.path.join(path, 'requires.txt') + requires = parse_requires_path(req_path) + path = os.path.join(path, 'PKG-INFO') + tl_path = os.path.join(path, 'top_level.txt') + metadata = Metadata(path=path, scheme='legacy') + else: + raise DistlibException('path must end with .egg-info or .egg, ' + 'got %r' % path) + + if requires: + metadata.add_requirements(requires) + # look for top-level modules in top_level.txt, if present + if tl_data is None: + if tl_path is not None and os.path.exists(tl_path): + with open(tl_path, 'rb') as f: + tl_data = f.read().decode('utf-8') + if not tl_data: + tl_data = [] + else: + tl_data = tl_data.splitlines() + self.modules = tl_data + return metadata + + def __repr__(self): + return '' % (self.name, self.version, self.path) + + def __str__(self): + return "%s %s" % (self.name, self.version) + + def check_installed_files(self): + """ + Checks that the hashes and sizes of the files in ``RECORD`` are + matched by the files themselves. Returns a (possibly empty) list of + mismatches. Each entry in the mismatch list will be a tuple consisting + of the path, 'exists', 'size' or 'hash' according to what didn't match + (existence is checked first, then size, then hash), the expected + value and the actual value. + """ + mismatches = [] + record_path = os.path.join(self.path, 'installed-files.txt') + if os.path.exists(record_path): + for path, _, _ in self.list_installed_files(): + if path == record_path: + continue + if not os.path.exists(path): + mismatches.append((path, 'exists', True, False)) + return mismatches + + def list_installed_files(self): + """ + Iterates over the ``installed-files.txt`` entries and returns a tuple + ``(path, hash, size)`` for each line. + + :returns: a list of (path, hash, size) + """ + + def _md5(path): + f = open(path, 'rb') + try: + content = f.read() + finally: + f.close() + return hashlib.md5(content).hexdigest() + + def _size(path): + return os.stat(path).st_size + + record_path = os.path.join(self.path, 'installed-files.txt') + result = [] + if os.path.exists(record_path): + with codecs.open(record_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + p = os.path.normpath(os.path.join(self.path, line)) + # "./" is present as a marker between installed files + # and installation metadata files + if not os.path.exists(p): + logger.warning('Non-existent file: %s', p) + if p.endswith(('.pyc', '.pyo')): + continue + # otherwise fall through and fail + if not os.path.isdir(p): + result.append((p, _md5(p), _size(p))) + result.append((record_path, None, None)) + return result + + def list_distinfo_files(self, absolute=False): + """ + Iterates over the ``installed-files.txt`` entries and returns paths for + each line if the path is pointing to a file located in the + ``.egg-info`` directory or one of its subdirectories. + + :parameter absolute: If *absolute* is ``True``, each returned path is + transformed into a local absolute path. Otherwise the + raw value from ``installed-files.txt`` is returned. + :type absolute: boolean + :returns: iterator of paths + """ + record_path = os.path.join(self.path, 'installed-files.txt') + if os.path.exists(record_path): + skip = True + with codecs.open(record_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line == './': + skip = False + continue + if not skip: + p = os.path.normpath(os.path.join(self.path, line)) + if p.startswith(self.path): + if absolute: + yield p + else: + yield line + + def __eq__(self, other): + return (isinstance(other, EggInfoDistribution) and self.path == other.path) + + # See http://docs.python.org/reference/datamodel#object.__hash__ + __hash__ = object.__hash__ + + +new_dist_class = InstalledDistribution +old_dist_class = EggInfoDistribution + + +class DependencyGraph(object): + """ + Represents a dependency graph between distributions. + + The dependency relationships are stored in an ``adjacency_list`` that maps + distributions to a list of ``(other, label)`` tuples where ``other`` + is a distribution and the edge is labeled with ``label`` (i.e. the version + specifier, if such was provided). Also, for more efficient traversal, for + every distribution ``x``, a list of predecessors is kept in + ``reverse_list[x]``. An edge from distribution ``a`` to + distribution ``b`` means that ``a`` depends on ``b``. If any missing + dependencies are found, they are stored in ``missing``, which is a + dictionary that maps distributions to a list of requirements that were not + provided by any other distributions. + """ + + def __init__(self): + self.adjacency_list = {} + self.reverse_list = {} + self.missing = {} + + def add_distribution(self, distribution): + """Add the *distribution* to the graph. + + :type distribution: :class:`distutils2.database.InstalledDistribution` + or :class:`distutils2.database.EggInfoDistribution` + """ + self.adjacency_list[distribution] = [] + self.reverse_list[distribution] = [] + # self.missing[distribution] = [] + + def add_edge(self, x, y, label=None): + """Add an edge from distribution *x* to distribution *y* with the given + *label*. + + :type x: :class:`distutils2.database.InstalledDistribution` or + :class:`distutils2.database.EggInfoDistribution` + :type y: :class:`distutils2.database.InstalledDistribution` or + :class:`distutils2.database.EggInfoDistribution` + :type label: ``str`` or ``None`` + """ + self.adjacency_list[x].append((y, label)) + # multiple edges are allowed, so be careful + if x not in self.reverse_list[y]: + self.reverse_list[y].append(x) + + def add_missing(self, distribution, requirement): + """ + Add a missing *requirement* for the given *distribution*. + + :type distribution: :class:`distutils2.database.InstalledDistribution` + or :class:`distutils2.database.EggInfoDistribution` + :type requirement: ``str`` + """ + logger.debug('%s missing %r', distribution, requirement) + self.missing.setdefault(distribution, []).append(requirement) + + def _repr_dist(self, dist): + return '%s %s' % (dist.name, dist.version) + + def repr_node(self, dist, level=1): + """Prints only a subgraph""" + output = [self._repr_dist(dist)] + for other, label in self.adjacency_list[dist]: + dist = self._repr_dist(other) + if label is not None: + dist = '%s [%s]' % (dist, label) + output.append(' ' * level + str(dist)) + suboutput = self.repr_node(other, level + 1) + subs = suboutput.split('\n') + output.extend(subs[1:]) + return '\n'.join(output) + + def to_dot(self, f, skip_disconnected=True): + """Writes a DOT output for the graph to the provided file *f*. + + If *skip_disconnected* is set to ``True``, then all distributions + that are not dependent on any other distribution are skipped. + + :type f: has to support ``file``-like operations + :type skip_disconnected: ``bool`` + """ + disconnected = [] + + f.write("digraph dependencies {\n") + for dist, adjs in self.adjacency_list.items(): + if len(adjs) == 0 and not skip_disconnected: + disconnected.append(dist) + for other, label in adjs: + if label is not None: + f.write('"%s" -> "%s" [label="%s"]\n' % (dist.name, other.name, label)) + else: + f.write('"%s" -> "%s"\n' % (dist.name, other.name)) + if not skip_disconnected and len(disconnected) > 0: + f.write('subgraph disconnected {\n') + f.write('label = "Disconnected"\n') + f.write('bgcolor = red\n') + + for dist in disconnected: + f.write('"%s"' % dist.name) + f.write('\n') + f.write('}\n') + f.write('}\n') + + def topological_sort(self): + """ + Perform a topological sort of the graph. + :return: A tuple, the first element of which is a topologically sorted + list of distributions, and the second element of which is a + list of distributions that cannot be sorted because they have + circular dependencies and so form a cycle. + """ + result = [] + # Make a shallow copy of the adjacency list + alist = {} + for k, v in self.adjacency_list.items(): + alist[k] = v[:] + while True: + # See what we can remove in this run + to_remove = [] + for k, v in list(alist.items())[:]: + if not v: + to_remove.append(k) + del alist[k] + if not to_remove: + # What's left in alist (if anything) is a cycle. + break + # Remove from the adjacency list of others + for k, v in alist.items(): + alist[k] = [(d, r) for d, r in v if d not in to_remove] + logger.debug('Moving to result: %s', ['%s (%s)' % (d.name, d.version) for d in to_remove]) + result.extend(to_remove) + return result, list(alist.keys()) + + def __repr__(self): + """Representation of the graph""" + output = [] + for dist, adjs in self.adjacency_list.items(): + output.append(self.repr_node(dist)) + return '\n'.join(output) + + +def make_graph(dists, scheme='default'): + """Makes a dependency graph from the given distributions. + + :parameter dists: a list of distributions + :type dists: list of :class:`distutils2.database.InstalledDistribution` and + :class:`distutils2.database.EggInfoDistribution` instances + :rtype: a :class:`DependencyGraph` instance + """ + scheme = get_scheme(scheme) + graph = DependencyGraph() + provided = {} # maps names to lists of (version, dist) tuples + + # first, build the graph and find out what's provided + for dist in dists: + graph.add_distribution(dist) + + for p in dist.provides: + name, version = parse_name_and_version(p) + logger.debug('Add to provided: %s, %s, %s', name, version, dist) + provided.setdefault(name, []).append((version, dist)) + + # now make the edges + for dist in dists: + requires = (dist.run_requires | dist.meta_requires | dist.build_requires | dist.dev_requires) + for req in requires: + try: + matcher = scheme.matcher(req) + except UnsupportedVersionError: + # XXX compat-mode if cannot read the version + logger.warning('could not read version %r - using name only', req) + name = req.split()[0] + matcher = scheme.matcher(name) + + name = matcher.key # case-insensitive + + matched = False + if name in provided: + for version, provider in provided[name]: + try: + match = matcher.match(version) + except UnsupportedVersionError: + match = False + + if match: + graph.add_edge(dist, provider, req) + matched = True + break + if not matched: + graph.add_missing(dist, req) + return graph + + +def get_dependent_dists(dists, dist): + """Recursively generate a list of distributions from *dists* that are + dependent on *dist*. + + :param dists: a list of distributions + :param dist: a distribution, member of *dists* for which we are interested + """ + if dist not in dists: + raise DistlibException('given distribution %r is not a member ' + 'of the list' % dist.name) + graph = make_graph(dists) + + dep = [dist] # dependent distributions + todo = graph.reverse_list[dist] # list of nodes we should inspect + + while todo: + d = todo.pop() + dep.append(d) + for succ in graph.reverse_list[d]: + if succ not in dep: + todo.append(succ) + + dep.pop(0) # remove dist from dep, was there to prevent infinite loops + return dep + + +def get_required_dists(dists, dist): + """Recursively generate a list of distributions from *dists* that are + required by *dist*. + + :param dists: a list of distributions + :param dist: a distribution, member of *dists* for which we are interested + in finding the dependencies. + """ + if dist not in dists: + raise DistlibException('given distribution %r is not a member ' + 'of the list' % dist.name) + graph = make_graph(dists) + + req = set() # required distributions + todo = graph.adjacency_list[dist] # list of nodes we should inspect + seen = set(t[0] for t in todo) # already added to todo + + while todo: + d = todo.pop()[0] + req.add(d) + pred_list = graph.adjacency_list[d] + for pred in pred_list: + d = pred[0] + if d not in req and d not in seen: + seen.add(d) + todo.append(pred) + return req + + +def make_dist(name, version, **kwargs): + """ + A convenience method for making a dist given just a name and version. + """ + summary = kwargs.pop('summary', 'Placeholder for summary') + md = Metadata(**kwargs) + md.name = name + md.version = version + md.summary = summary or 'Placeholder for summary' + return Distribution(md) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/locators.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/locators.py new file mode 100644 index 0000000000000000000000000000000000000000..222c1bf3e90da230fbd2e0891cf2724e6f23a943 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/locators.py @@ -0,0 +1,1295 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2012-2023 Vinay Sajip. +# Licensed to the Python Software Foundation under a contributor agreement. +# See LICENSE.txt and CONTRIBUTORS.txt. +# + +import gzip +from io import BytesIO +import json +import logging +import os +import posixpath +import re +try: + import threading +except ImportError: # pragma: no cover + import dummy_threading as threading +import zlib + +from . import DistlibException +from .compat import (urljoin, urlparse, urlunparse, url2pathname, pathname2url, queue, quote, unescape, build_opener, + HTTPRedirectHandler as BaseRedirectHandler, text_type, Request, HTTPError, URLError) +from .database import Distribution, DistributionPath, make_dist +from .metadata import Metadata, MetadataInvalidError +from .util import (cached_property, ensure_slash, split_filename, get_project_data, parse_requirement, + parse_name_and_version, ServerProxy, normalize_name) +from .version import get_scheme, UnsupportedVersionError +from .wheel import Wheel, is_compatible + +logger = logging.getLogger(__name__) + +HASHER_HASH = re.compile(r'^(\w+)=([a-f0-9]+)') +CHARSET = re.compile(r';\s*charset\s*=\s*(.*)\s*$', re.I) +HTML_CONTENT_TYPE = re.compile('text/html|application/x(ht)?ml') +DEFAULT_INDEX = 'https://pypi.org/pypi' + + +def get_all_distribution_names(url=None): + """ + Return all distribution names known by an index. + :param url: The URL of the index. + :return: A list of all known distribution names. + """ + if url is None: + url = DEFAULT_INDEX + client = ServerProxy(url, timeout=3.0) + try: + return client.list_packages() + finally: + client('close')() + + +class RedirectHandler(BaseRedirectHandler): + """ + A class to work around a bug in some Python 3.2.x releases. + """ + + # There's a bug in the base version for some 3.2.x + # (e.g. 3.2.2 on Ubuntu Oneiric). If a Location header + # returns e.g. /abc, it bails because it says the scheme '' + # is bogus, when actually it should use the request's + # URL for the scheme. See Python issue #13696. + def http_error_302(self, req, fp, code, msg, headers): + # Some servers (incorrectly) return multiple Location headers + # (so probably same goes for URI). Use first header. + newurl = None + for key in ('location', 'uri'): + if key in headers: + newurl = headers[key] + break + if newurl is None: # pragma: no cover + return + urlparts = urlparse(newurl) + if urlparts.scheme == '': + newurl = urljoin(req.get_full_url(), newurl) + if hasattr(headers, 'replace_header'): + headers.replace_header(key, newurl) + else: + headers[key] = newurl + return BaseRedirectHandler.http_error_302(self, req, fp, code, msg, headers) + + http_error_301 = http_error_303 = http_error_307 = http_error_302 + + +class Locator(object): + """ + A base class for locators - things that locate distributions. + """ + source_extensions = ('.tar.gz', '.tar.bz2', '.tar', '.zip', '.tgz', '.tbz') + binary_extensions = ('.egg', '.exe', '.whl') + excluded_extensions = ('.pdf', ) + + # A list of tags indicating which wheels you want to match. The default + # value of None matches against the tags compatible with the running + # Python. If you want to match other values, set wheel_tags on a locator + # instance to a list of tuples (pyver, abi, arch) which you want to match. + wheel_tags = None + + downloadable_extensions = source_extensions + ('.whl', ) + + def __init__(self, scheme='default'): + """ + Initialise an instance. + :param scheme: Because locators look for most recent versions, they + need to know the version scheme to use. This specifies + the current PEP-recommended scheme - use ``'legacy'`` + if you need to support existing distributions on PyPI. + """ + self._cache = {} + self.scheme = scheme + # Because of bugs in some of the handlers on some of the platforms, + # we use our own opener rather than just using urlopen. + self.opener = build_opener(RedirectHandler()) + # If get_project() is called from locate(), the matcher instance + # is set from the requirement passed to locate(). See issue #18 for + # why this can be useful to know. + self.matcher = None + self.errors = queue.Queue() + + def get_errors(self): + """ + Return any errors which have occurred. + """ + result = [] + while not self.errors.empty(): # pragma: no cover + try: + e = self.errors.get(False) + result.append(e) + except self.errors.Empty: + continue + self.errors.task_done() + return result + + def clear_errors(self): + """ + Clear any errors which may have been logged. + """ + # Just get the errors and throw them away + self.get_errors() + + def clear_cache(self): + self._cache.clear() + + def _get_scheme(self): + return self._scheme + + def _set_scheme(self, value): + self._scheme = value + + scheme = property(_get_scheme, _set_scheme) + + def _get_project(self, name): + """ + For a given project, get a dictionary mapping available versions to Distribution + instances. + + This should be implemented in subclasses. + + If called from a locate() request, self.matcher will be set to a + matcher for the requirement to satisfy, otherwise it will be None. + """ + raise NotImplementedError('Please implement in the subclass') + + def get_distribution_names(self): + """ + Return all the distribution names known to this locator. + """ + raise NotImplementedError('Please implement in the subclass') + + def get_project(self, name): + """ + For a given project, get a dictionary mapping available versions to Distribution + instances. + + This calls _get_project to do all the work, and just implements a caching layer on top. + """ + if self._cache is None: # pragma: no cover + result = self._get_project(name) + elif name in self._cache: + result = self._cache[name] + else: + self.clear_errors() + result = self._get_project(name) + self._cache[name] = result + return result + + def score_url(self, url): + """ + Give an url a score which can be used to choose preferred URLs + for a given project release. + """ + t = urlparse(url) + basename = posixpath.basename(t.path) + compatible = True + is_wheel = basename.endswith('.whl') + is_downloadable = basename.endswith(self.downloadable_extensions) + if is_wheel: + compatible = is_compatible(Wheel(basename), self.wheel_tags) + return (t.scheme == 'https', 'pypi.org' in t.netloc, is_downloadable, is_wheel, compatible, basename) + + def prefer_url(self, url1, url2): + """ + Choose one of two URLs where both are candidates for distribution + archives for the same version of a distribution (for example, + .tar.gz vs. zip). + + The current implementation favours https:// URLs over http://, archives + from PyPI over those from other locations, wheel compatibility (if a + wheel) and then the archive name. + """ + result = url2 + if url1: + s1 = self.score_url(url1) + s2 = self.score_url(url2) + if s1 > s2: + result = url1 + if result != url2: + logger.debug('Not replacing %r with %r', url1, url2) + else: + logger.debug('Replacing %r with %r', url1, url2) + return result + + def split_filename(self, filename, project_name): + """ + Attempt to split a filename in project name, version and Python version. + """ + return split_filename(filename, project_name) + + def convert_url_to_download_info(self, url, project_name): + """ + See if a URL is a candidate for a download URL for a project (the URL + has typically been scraped from an HTML page). + + If it is, a dictionary is returned with keys "name", "version", + "filename" and "url"; otherwise, None is returned. + """ + + def same_project(name1, name2): + return normalize_name(name1) == normalize_name(name2) + + result = None + scheme, netloc, path, params, query, frag = urlparse(url) + if frag.lower().startswith('egg='): # pragma: no cover + logger.debug('%s: version hint in fragment: %r', project_name, frag) + m = HASHER_HASH.match(frag) + if m: + algo, digest = m.groups() + else: + algo, digest = None, None + origpath = path + if path and path[-1] == '/': # pragma: no cover + path = path[:-1] + if path.endswith('.whl'): + try: + wheel = Wheel(path) + if not is_compatible(wheel, self.wheel_tags): + logger.debug('Wheel not compatible: %s', path) + else: + if project_name is None: + include = True + else: + include = same_project(wheel.name, project_name) + if include: + result = { + 'name': wheel.name, + 'version': wheel.version, + 'filename': wheel.filename, + 'url': urlunparse((scheme, netloc, origpath, params, query, '')), + 'python-version': ', '.join(['.'.join(list(v[2:])) for v in wheel.pyver]), + } + except Exception: # pragma: no cover + logger.warning('invalid path for wheel: %s', path) + elif not path.endswith(self.downloadable_extensions): # pragma: no cover + logger.debug('Not downloadable: %s', path) + else: # downloadable extension + path = filename = posixpath.basename(path) + for ext in self.downloadable_extensions: + if path.endswith(ext): + path = path[:-len(ext)] + t = self.split_filename(path, project_name) + if not t: # pragma: no cover + logger.debug('No match for project/version: %s', path) + else: + name, version, pyver = t + if not project_name or same_project(project_name, name): + result = { + 'name': name, + 'version': version, + 'filename': filename, + 'url': urlunparse((scheme, netloc, origpath, params, query, '')), + } + if pyver: # pragma: no cover + result['python-version'] = pyver + break + if result and algo: + result['%s_digest' % algo] = digest + return result + + def _get_digest(self, info): + """ + Get a digest from a dictionary by looking at a "digests" dictionary + or keys of the form 'algo_digest'. + + Returns a 2-tuple (algo, digest) if found, else None. Currently + looks only for SHA256, then MD5. + """ + result = None + if 'digests' in info: + digests = info['digests'] + for algo in ('sha256', 'md5'): + if algo in digests: + result = (algo, digests[algo]) + break + if not result: + for algo in ('sha256', 'md5'): + key = '%s_digest' % algo + if key in info: + result = (algo, info[key]) + break + return result + + def _update_version_data(self, result, info): + """ + Update a result dictionary (the final result from _get_project) with a + dictionary for a specific version, which typically holds information + gleaned from a filename or URL for an archive for the distribution. + """ + name = info.pop('name') + version = info.pop('version') + if version in result: + dist = result[version] + md = dist.metadata + else: + dist = make_dist(name, version, scheme=self.scheme) + md = dist.metadata + dist.digest = digest = self._get_digest(info) + url = info['url'] + result['digests'][url] = digest + if md.source_url != info['url']: + md.source_url = self.prefer_url(md.source_url, url) + result['urls'].setdefault(version, set()).add(url) + dist.locator = self + result[version] = dist + + def locate(self, requirement, prereleases=False): + """ + Find the most recent distribution which matches the given + requirement. + + :param requirement: A requirement of the form 'foo (1.0)' or perhaps + 'foo (>= 1.0, < 2.0, != 1.3)' + :param prereleases: If ``True``, allow pre-release versions + to be located. Otherwise, pre-release versions + are not returned. + :return: A :class:`Distribution` instance, or ``None`` if no such + distribution could be located. + """ + result = None + r = parse_requirement(requirement) + if r is None: # pragma: no cover + raise DistlibException('Not a valid requirement: %r' % requirement) + scheme = get_scheme(self.scheme) + self.matcher = matcher = scheme.matcher(r.requirement) + logger.debug('matcher: %s (%s)', matcher, type(matcher).__name__) + versions = self.get_project(r.name) + if len(versions) > 2: # urls and digests keys are present + # sometimes, versions are invalid + slist = [] + vcls = matcher.version_class + for k in versions: + if k in ('urls', 'digests'): + continue + try: + if not matcher.match(k): + pass # logger.debug('%s did not match %r', matcher, k) + else: + if prereleases or not vcls(k).is_prerelease: + slist.append(k) + except Exception: # pragma: no cover + logger.warning('error matching %s with %r', matcher, k) + pass # slist.append(k) + if len(slist) > 1: + slist = sorted(slist, key=scheme.key) + if slist: + logger.debug('sorted list: %s', slist) + version = slist[-1] + result = versions[version] + if result: + if r.extras: + result.extras = r.extras + result.download_urls = versions.get('urls', {}).get(version, set()) + d = {} + sd = versions.get('digests', {}) + for url in result.download_urls: + if url in sd: # pragma: no cover + d[url] = sd[url] + result.digests = d + self.matcher = None + return result + + +class PyPIRPCLocator(Locator): + """ + This locator uses XML-RPC to locate distributions. It therefore + cannot be used with simple mirrors (that only mirror file content). + """ + + def __init__(self, url, **kwargs): + """ + Initialise an instance. + + :param url: The URL to use for XML-RPC. + :param kwargs: Passed to the superclass constructor. + """ + super(PyPIRPCLocator, self).__init__(**kwargs) + self.base_url = url + self.client = ServerProxy(url, timeout=3.0) + + def get_distribution_names(self): + """ + Return all the distribution names known to this locator. + """ + return set(self.client.list_packages()) + + def _get_project(self, name): + result = {'urls': {}, 'digests': {}} + versions = self.client.package_releases(name, True) + for v in versions: + urls = self.client.release_urls(name, v) + data = self.client.release_data(name, v) + metadata = Metadata(scheme=self.scheme) + metadata.name = data['name'] + metadata.version = data['version'] + metadata.license = data.get('license') + metadata.keywords = data.get('keywords', []) + metadata.summary = data.get('summary') + dist = Distribution(metadata) + if urls: + info = urls[0] + metadata.source_url = info['url'] + dist.digest = self._get_digest(info) + dist.locator = self + result[v] = dist + for info in urls: + url = info['url'] + digest = self._get_digest(info) + result['urls'].setdefault(v, set()).add(url) + result['digests'][url] = digest + return result + + +class PyPIJSONLocator(Locator): + """ + This locator uses PyPI's JSON interface. It's very limited in functionality + and probably not worth using. + """ + + def __init__(self, url, **kwargs): + super(PyPIJSONLocator, self).__init__(**kwargs) + self.base_url = ensure_slash(url) + + def get_distribution_names(self): + """ + Return all the distribution names known to this locator. + """ + raise NotImplementedError('Not available from this locator') + + def _get_project(self, name): + result = {'urls': {}, 'digests': {}} + url = urljoin(self.base_url, '%s/json' % quote(name)) + try: + resp = self.opener.open(url) + data = resp.read().decode() # for now + d = json.loads(data) + md = Metadata(scheme=self.scheme) + data = d['info'] + md.name = data['name'] + md.version = data['version'] + md.license = data.get('license') + md.keywords = data.get('keywords', []) + md.summary = data.get('summary') + dist = Distribution(md) + dist.locator = self + # urls = d['urls'] + result[md.version] = dist + for info in d['urls']: + url = info['url'] + dist.download_urls.add(url) + dist.digests[url] = self._get_digest(info) + result['urls'].setdefault(md.version, set()).add(url) + result['digests'][url] = self._get_digest(info) + # Now get other releases + for version, infos in d['releases'].items(): + if version == md.version: + continue # already done + omd = Metadata(scheme=self.scheme) + omd.name = md.name + omd.version = version + odist = Distribution(omd) + odist.locator = self + result[version] = odist + for info in infos: + url = info['url'] + odist.download_urls.add(url) + odist.digests[url] = self._get_digest(info) + result['urls'].setdefault(version, set()).add(url) + result['digests'][url] = self._get_digest(info) + + +# for info in urls: +# md.source_url = info['url'] +# dist.digest = self._get_digest(info) +# dist.locator = self +# for info in urls: +# url = info['url'] +# result['urls'].setdefault(md.version, set()).add(url) +# result['digests'][url] = self._get_digest(info) + except Exception as e: + self.errors.put(text_type(e)) + logger.exception('JSON fetch failed: %s', e) + return result + + +class Page(object): + """ + This class represents a scraped HTML page. + """ + # The following slightly hairy-looking regex just looks for the contents of + # an anchor link, which has an attribute "href" either immediately preceded + # or immediately followed by a "rel" attribute. The attribute values can be + # declared with double quotes, single quotes or no quotes - which leads to + # the length of the expression. + _href = re.compile( + """ +(rel\\s*=\\s*(?:"(?P[^"]*)"|'(?P[^']*)'|(?P[^>\\s\n]*))\\s+)? +href\\s*=\\s*(?:"(?P[^"]*)"|'(?P[^']*)'|(?P[^>\\s\n]*)) +(\\s+rel\\s*=\\s*(?:"(?P[^"]*)"|'(?P[^']*)'|(?P[^>\\s\n]*)))? +""", re.I | re.S | re.X) + _base = re.compile(r"""]+)""", re.I | re.S) + + def __init__(self, data, url): + """ + Initialise an instance with the Unicode page contents and the URL they + came from. + """ + self.data = data + self.base_url = self.url = url + m = self._base.search(self.data) + if m: + self.base_url = m.group(1) + + _clean_re = re.compile(r'[^a-z0-9$&+,/:;=?@.#%_\\|-]', re.I) + + @cached_property + def links(self): + """ + Return the URLs of all the links on a page together with information + about their "rel" attribute, for determining which ones to treat as + downloads and which ones to queue for further scraping. + """ + + def clean(url): + "Tidy up an URL." + scheme, netloc, path, params, query, frag = urlparse(url) + return urlunparse((scheme, netloc, quote(path), params, query, frag)) + + result = set() + for match in self._href.finditer(self.data): + d = match.groupdict('') + rel = (d['rel1'] or d['rel2'] or d['rel3'] or d['rel4'] or d['rel5'] or d['rel6']) + url = d['url1'] or d['url2'] or d['url3'] + url = urljoin(self.base_url, url) + url = unescape(url) + url = self._clean_re.sub(lambda m: '%%%2x' % ord(m.group(0)), url) + result.add((url, rel)) + # We sort the result, hoping to bring the most recent versions + # to the front + result = sorted(result, key=lambda t: t[0], reverse=True) + return result + + +class SimpleScrapingLocator(Locator): + """ + A locator which scrapes HTML pages to locate downloads for a distribution. + This runs multiple threads to do the I/O; performance is at least as good + as pip's PackageFinder, which works in an analogous fashion. + """ + + # These are used to deal with various Content-Encoding schemes. + decoders = { + 'deflate': zlib.decompress, + 'gzip': lambda b: gzip.GzipFile(fileobj=BytesIO(b)).read(), + 'none': lambda b: b, + } + + def __init__(self, url, timeout=None, num_workers=10, **kwargs): + """ + Initialise an instance. + :param url: The root URL to use for scraping. + :param timeout: The timeout, in seconds, to be applied to requests. + This defaults to ``None`` (no timeout specified). + :param num_workers: The number of worker threads you want to do I/O, + This defaults to 10. + :param kwargs: Passed to the superclass. + """ + super(SimpleScrapingLocator, self).__init__(**kwargs) + self.base_url = ensure_slash(url) + self.timeout = timeout + self._page_cache = {} + self._seen = set() + self._to_fetch = queue.Queue() + self._bad_hosts = set() + self.skip_externals = False + self.num_workers = num_workers + self._lock = threading.RLock() + # See issue #45: we need to be resilient when the locator is used + # in a thread, e.g. with concurrent.futures. We can't use self._lock + # as it is for coordinating our internal threads - the ones created + # in _prepare_threads. + self._gplock = threading.RLock() + self.platform_check = False # See issue #112 + + def _prepare_threads(self): + """ + Threads are created only when get_project is called, and terminate + before it returns. They are there primarily to parallelise I/O (i.e. + fetching web pages). + """ + self._threads = [] + for i in range(self.num_workers): + t = threading.Thread(target=self._fetch) + t.daemon = True + t.start() + self._threads.append(t) + + def _wait_threads(self): + """ + Tell all the threads to terminate (by sending a sentinel value) and + wait for them to do so. + """ + # Note that you need two loops, since you can't say which + # thread will get each sentinel + for t in self._threads: + self._to_fetch.put(None) # sentinel + for t in self._threads: + t.join() + self._threads = [] + + def _get_project(self, name): + result = {'urls': {}, 'digests': {}} + with self._gplock: + self.result = result + self.project_name = name + url = urljoin(self.base_url, '%s/' % quote(name)) + self._seen.clear() + self._page_cache.clear() + self._prepare_threads() + try: + logger.debug('Queueing %s', url) + self._to_fetch.put(url) + self._to_fetch.join() + finally: + self._wait_threads() + del self.result + return result + + platform_dependent = re.compile(r'\b(linux_(i\d86|x86_64|arm\w+)|' + r'win(32|_amd64)|macosx_?\d+)\b', re.I) + + def _is_platform_dependent(self, url): + """ + Does an URL refer to a platform-specific download? + """ + return self.platform_dependent.search(url) + + def _process_download(self, url): + """ + See if an URL is a suitable download for a project. + + If it is, register information in the result dictionary (for + _get_project) about the specific version it's for. + + Note that the return value isn't actually used other than as a boolean + value. + """ + if self.platform_check and self._is_platform_dependent(url): + info = None + else: + info = self.convert_url_to_download_info(url, self.project_name) + logger.debug('process_download: %s -> %s', url, info) + if info: + with self._lock: # needed because self.result is shared + self._update_version_data(self.result, info) + return info + + def _should_queue(self, link, referrer, rel): + """ + Determine whether a link URL from a referring page and with a + particular "rel" attribute should be queued for scraping. + """ + scheme, netloc, path, _, _, _ = urlparse(link) + if path.endswith(self.source_extensions + self.binary_extensions + self.excluded_extensions): + result = False + elif self.skip_externals and not link.startswith(self.base_url): + result = False + elif not referrer.startswith(self.base_url): + result = False + elif rel not in ('homepage', 'download'): + result = False + elif scheme not in ('http', 'https', 'ftp'): + result = False + elif self._is_platform_dependent(link): + result = False + else: + host = netloc.split(':', 1)[0] + if host.lower() == 'localhost': + result = False + else: + result = True + logger.debug('should_queue: %s (%s) from %s -> %s', link, rel, referrer, result) + return result + + def _fetch(self): + """ + Get a URL to fetch from the work queue, get the HTML page, examine its + links for download candidates and candidates for further scraping. + + This is a handy method to run in a thread. + """ + while True: + url = self._to_fetch.get() + try: + if url: + page = self.get_page(url) + if page is None: # e.g. after an error + continue + for link, rel in page.links: + if link not in self._seen: + try: + self._seen.add(link) + if (not self._process_download(link) and self._should_queue(link, url, rel)): + logger.debug('Queueing %s from %s', link, url) + self._to_fetch.put(link) + except MetadataInvalidError: # e.g. invalid versions + pass + except Exception as e: # pragma: no cover + self.errors.put(text_type(e)) + finally: + # always do this, to avoid hangs :-) + self._to_fetch.task_done() + if not url: + # logger.debug('Sentinel seen, quitting.') + break + + def get_page(self, url): + """ + Get the HTML for an URL, possibly from an in-memory cache. + + XXX TODO Note: this cache is never actually cleared. It's assumed that + the data won't get stale over the lifetime of a locator instance (not + necessarily true for the default_locator). + """ + # http://peak.telecommunity.com/DevCenter/EasyInstall#package-index-api + scheme, netloc, path, _, _, _ = urlparse(url) + if scheme == 'file' and os.path.isdir(url2pathname(path)): + url = urljoin(ensure_slash(url), 'index.html') + + if url in self._page_cache: + result = self._page_cache[url] + logger.debug('Returning %s from cache: %s', url, result) + else: + host = netloc.split(':', 1)[0] + result = None + if host in self._bad_hosts: + logger.debug('Skipping %s due to bad host %s', url, host) + else: + req = Request(url, headers={'Accept-encoding': 'identity'}) + try: + logger.debug('Fetching %s', url) + resp = self.opener.open(req, timeout=self.timeout) + logger.debug('Fetched %s', url) + headers = resp.info() + content_type = headers.get('Content-Type', '') + if HTML_CONTENT_TYPE.match(content_type): + final_url = resp.geturl() + data = resp.read() + encoding = headers.get('Content-Encoding') + if encoding: + decoder = self.decoders[encoding] # fail if not found + data = decoder(data) + encoding = 'utf-8' + m = CHARSET.search(content_type) + if m: + encoding = m.group(1) + try: + data = data.decode(encoding) + except UnicodeError: # pragma: no cover + data = data.decode('latin-1') # fallback + result = Page(data, final_url) + self._page_cache[final_url] = result + except HTTPError as e: + if e.code != 404: + logger.exception('Fetch failed: %s: %s', url, e) + except URLError as e: # pragma: no cover + logger.exception('Fetch failed: %s: %s', url, e) + with self._lock: + self._bad_hosts.add(host) + except Exception as e: # pragma: no cover + logger.exception('Fetch failed: %s: %s', url, e) + finally: + self._page_cache[url] = result # even if None (failure) + return result + + _distname_re = re.compile(']*>([^<]+)<') + + def get_distribution_names(self): + """ + Return all the distribution names known to this locator. + """ + result = set() + page = self.get_page(self.base_url) + if not page: + raise DistlibException('Unable to get %s' % self.base_url) + for match in self._distname_re.finditer(page.data): + result.add(match.group(1)) + return result + + +class DirectoryLocator(Locator): + """ + This class locates distributions in a directory tree. + """ + + def __init__(self, path, **kwargs): + """ + Initialise an instance. + :param path: The root of the directory tree to search. + :param kwargs: Passed to the superclass constructor, + except for: + * recursive - if True (the default), subdirectories are + recursed into. If False, only the top-level directory + is searched, + """ + self.recursive = kwargs.pop('recursive', True) + super(DirectoryLocator, self).__init__(**kwargs) + path = os.path.abspath(path) + if not os.path.isdir(path): # pragma: no cover + raise DistlibException('Not a directory: %r' % path) + self.base_dir = path + + def should_include(self, filename, parent): + """ + Should a filename be considered as a candidate for a distribution + archive? As well as the filename, the directory which contains it + is provided, though not used by the current implementation. + """ + return filename.endswith(self.downloadable_extensions) + + def _get_project(self, name): + result = {'urls': {}, 'digests': {}} + for root, dirs, files in os.walk(self.base_dir): + for fn in files: + if self.should_include(fn, root): + fn = os.path.join(root, fn) + url = urlunparse(('file', '', pathname2url(os.path.abspath(fn)), '', '', '')) + info = self.convert_url_to_download_info(url, name) + if info: + self._update_version_data(result, info) + if not self.recursive: + break + return result + + def get_distribution_names(self): + """ + Return all the distribution names known to this locator. + """ + result = set() + for root, dirs, files in os.walk(self.base_dir): + for fn in files: + if self.should_include(fn, root): + fn = os.path.join(root, fn) + url = urlunparse(('file', '', pathname2url(os.path.abspath(fn)), '', '', '')) + info = self.convert_url_to_download_info(url, None) + if info: + result.add(info['name']) + if not self.recursive: + break + return result + + +class JSONLocator(Locator): + """ + This locator uses special extended metadata (not available on PyPI) and is + the basis of performant dependency resolution in distlib. Other locators + require archive downloads before dependencies can be determined! As you + might imagine, that can be slow. + """ + + def get_distribution_names(self): + """ + Return all the distribution names known to this locator. + """ + raise NotImplementedError('Not available from this locator') + + def _get_project(self, name): + result = {'urls': {}, 'digests': {}} + data = get_project_data(name) + if data: + for info in data.get('files', []): + if info['ptype'] != 'sdist' or info['pyversion'] != 'source': + continue + # We don't store summary in project metadata as it makes + # the data bigger for no benefit during dependency + # resolution + dist = make_dist(data['name'], + info['version'], + summary=data.get('summary', 'Placeholder for summary'), + scheme=self.scheme) + md = dist.metadata + md.source_url = info['url'] + # TODO SHA256 digest + if 'digest' in info and info['digest']: + dist.digest = ('md5', info['digest']) + md.dependencies = info.get('requirements', {}) + dist.exports = info.get('exports', {}) + result[dist.version] = dist + result['urls'].setdefault(dist.version, set()).add(info['url']) + return result + + +class DistPathLocator(Locator): + """ + This locator finds installed distributions in a path. It can be useful for + adding to an :class:`AggregatingLocator`. + """ + + def __init__(self, distpath, **kwargs): + """ + Initialise an instance. + + :param distpath: A :class:`DistributionPath` instance to search. + """ + super(DistPathLocator, self).__init__(**kwargs) + assert isinstance(distpath, DistributionPath) + self.distpath = distpath + + def _get_project(self, name): + dist = self.distpath.get_distribution(name) + if dist is None: + result = {'urls': {}, 'digests': {}} + else: + result = { + dist.version: dist, + 'urls': { + dist.version: set([dist.source_url]) + }, + 'digests': { + dist.version: set([None]) + } + } + return result + + +class AggregatingLocator(Locator): + """ + This class allows you to chain and/or merge a list of locators. + """ + + def __init__(self, *locators, **kwargs): + """ + Initialise an instance. + + :param locators: The list of locators to search. + :param kwargs: Passed to the superclass constructor, + except for: + * merge - if False (the default), the first successful + search from any of the locators is returned. If True, + the results from all locators are merged (this can be + slow). + """ + self.merge = kwargs.pop('merge', False) + self.locators = locators + super(AggregatingLocator, self).__init__(**kwargs) + + def clear_cache(self): + super(AggregatingLocator, self).clear_cache() + for locator in self.locators: + locator.clear_cache() + + def _set_scheme(self, value): + self._scheme = value + for locator in self.locators: + locator.scheme = value + + scheme = property(Locator.scheme.fget, _set_scheme) + + def _get_project(self, name): + result = {} + for locator in self.locators: + d = locator.get_project(name) + if d: + if self.merge: + files = result.get('urls', {}) + digests = result.get('digests', {}) + # next line could overwrite result['urls'], result['digests'] + result.update(d) + df = result.get('urls') + if files and df: + for k, v in files.items(): + if k in df: + df[k] |= v + else: + df[k] = v + dd = result.get('digests') + if digests and dd: + dd.update(digests) + else: + # See issue #18. If any dists are found and we're looking + # for specific constraints, we only return something if + # a match is found. For example, if a DirectoryLocator + # returns just foo (1.0) while we're looking for + # foo (>= 2.0), we'll pretend there was nothing there so + # that subsequent locators can be queried. Otherwise we + # would just return foo (1.0) which would then lead to a + # failure to find foo (>= 2.0), because other locators + # weren't searched. Note that this only matters when + # merge=False. + if self.matcher is None: + found = True + else: + found = False + for k in d: + if self.matcher.match(k): + found = True + break + if found: + result = d + break + return result + + def get_distribution_names(self): + """ + Return all the distribution names known to this locator. + """ + result = set() + for locator in self.locators: + try: + result |= locator.get_distribution_names() + except NotImplementedError: + pass + return result + + +# We use a legacy scheme simply because most of the dists on PyPI use legacy +# versions which don't conform to PEP 440. +default_locator = AggregatingLocator( + # JSONLocator(), # don't use as PEP 426 is withdrawn + SimpleScrapingLocator('https://pypi.org/simple/', timeout=3.0), + scheme='legacy') + +locate = default_locator.locate + + +class DependencyFinder(object): + """ + Locate dependencies for distributions. + """ + + def __init__(self, locator=None): + """ + Initialise an instance, using the specified locator + to locate distributions. + """ + self.locator = locator or default_locator + self.scheme = get_scheme(self.locator.scheme) + + def add_distribution(self, dist): + """ + Add a distribution to the finder. This will update internal information + about who provides what. + :param dist: The distribution to add. + """ + logger.debug('adding distribution %s', dist) + name = dist.key + self.dists_by_name[name] = dist + self.dists[(name, dist.version)] = dist + for p in dist.provides: + name, version = parse_name_and_version(p) + logger.debug('Add to provided: %s, %s, %s', name, version, dist) + self.provided.setdefault(name, set()).add((version, dist)) + + def remove_distribution(self, dist): + """ + Remove a distribution from the finder. This will update internal + information about who provides what. + :param dist: The distribution to remove. + """ + logger.debug('removing distribution %s', dist) + name = dist.key + del self.dists_by_name[name] + del self.dists[(name, dist.version)] + for p in dist.provides: + name, version = parse_name_and_version(p) + logger.debug('Remove from provided: %s, %s, %s', name, version, dist) + s = self.provided[name] + s.remove((version, dist)) + if not s: + del self.provided[name] + + def get_matcher(self, reqt): + """ + Get a version matcher for a requirement. + :param reqt: The requirement + :type reqt: str + :return: A version matcher (an instance of + :class:`distlib.version.Matcher`). + """ + try: + matcher = self.scheme.matcher(reqt) + except UnsupportedVersionError: # pragma: no cover + # XXX compat-mode if cannot read the version + name = reqt.split()[0] + matcher = self.scheme.matcher(name) + return matcher + + def find_providers(self, reqt): + """ + Find the distributions which can fulfill a requirement. + + :param reqt: The requirement. + :type reqt: str + :return: A set of distribution which can fulfill the requirement. + """ + matcher = self.get_matcher(reqt) + name = matcher.key # case-insensitive + result = set() + provided = self.provided + if name in provided: + for version, provider in provided[name]: + try: + match = matcher.match(version) + except UnsupportedVersionError: + match = False + + if match: + result.add(provider) + break + return result + + def try_to_replace(self, provider, other, problems): + """ + Attempt to replace one provider with another. This is typically used + when resolving dependencies from multiple sources, e.g. A requires + (B >= 1.0) while C requires (B >= 1.1). + + For successful replacement, ``provider`` must meet all the requirements + which ``other`` fulfills. + + :param provider: The provider we are trying to replace with. + :param other: The provider we're trying to replace. + :param problems: If False is returned, this will contain what + problems prevented replacement. This is currently + a tuple of the literal string 'cantreplace', + ``provider``, ``other`` and the set of requirements + that ``provider`` couldn't fulfill. + :return: True if we can replace ``other`` with ``provider``, else + False. + """ + rlist = self.reqts[other] + unmatched = set() + for s in rlist: + matcher = self.get_matcher(s) + if not matcher.match(provider.version): + unmatched.add(s) + if unmatched: + # can't replace other with provider + problems.add(('cantreplace', provider, other, frozenset(unmatched))) + result = False + else: + # can replace other with provider + self.remove_distribution(other) + del self.reqts[other] + for s in rlist: + self.reqts.setdefault(provider, set()).add(s) + self.add_distribution(provider) + result = True + return result + + def find(self, requirement, meta_extras=None, prereleases=False): + """ + Find a distribution and all distributions it depends on. + + :param requirement: The requirement specifying the distribution to + find, or a Distribution instance. + :param meta_extras: A list of meta extras such as :test:, :build: and + so on. + :param prereleases: If ``True``, allow pre-release versions to be + returned - otherwise, don't return prereleases + unless they're all that's available. + + Return a set of :class:`Distribution` instances and a set of + problems. + + The distributions returned should be such that they have the + :attr:`required` attribute set to ``True`` if they were + from the ``requirement`` passed to ``find()``, and they have the + :attr:`build_time_dependency` attribute set to ``True`` unless they + are post-installation dependencies of the ``requirement``. + + The problems should be a tuple consisting of the string + ``'unsatisfied'`` and the requirement which couldn't be satisfied + by any distribution known to the locator. + """ + + self.provided = {} + self.dists = {} + self.dists_by_name = {} + self.reqts = {} + + meta_extras = set(meta_extras or []) + if ':*:' in meta_extras: + meta_extras.remove(':*:') + # :meta: and :run: are implicitly included + meta_extras |= set([':test:', ':build:', ':dev:']) + + if isinstance(requirement, Distribution): + dist = odist = requirement + logger.debug('passed %s as requirement', odist) + else: + dist = odist = self.locator.locate(requirement, prereleases=prereleases) + if dist is None: + raise DistlibException('Unable to locate %r' % requirement) + logger.debug('located %s', odist) + dist.requested = True + problems = set() + todo = set([dist]) + install_dists = set([odist]) + while todo: + dist = todo.pop() + name = dist.key # case-insensitive + if name not in self.dists_by_name: + self.add_distribution(dist) + else: + # import pdb; pdb.set_trace() + other = self.dists_by_name[name] + if other != dist: + self.try_to_replace(dist, other, problems) + + ireqts = dist.run_requires | dist.meta_requires + sreqts = dist.build_requires + ereqts = set() + if meta_extras and dist in install_dists: + for key in ('test', 'build', 'dev'): + e = ':%s:' % key + if e in meta_extras: + ereqts |= getattr(dist, '%s_requires' % key) + all_reqts = ireqts | sreqts | ereqts + for r in all_reqts: + providers = self.find_providers(r) + if not providers: + logger.debug('No providers found for %r', r) + provider = self.locator.locate(r, prereleases=prereleases) + # If no provider is found and we didn't consider + # prereleases, consider them now. + if provider is None and not prereleases: + provider = self.locator.locate(r, prereleases=True) + if provider is None: + logger.debug('Cannot satisfy %r', r) + problems.add(('unsatisfied', r)) + else: + n, v = provider.key, provider.version + if (n, v) not in self.dists: + todo.add(provider) + providers.add(provider) + if r in ireqts and dist in install_dists: + install_dists.add(provider) + logger.debug('Adding %s to install_dists', provider.name_and_version) + for p in providers: + name = p.key + if name not in self.dists_by_name: + self.reqts.setdefault(p, set()).add(r) + else: + other = self.dists_by_name[name] + if other != p: + # see if other can be replaced by p + self.try_to_replace(p, other, problems) + + dists = set(self.dists.values()) + for dist in dists: + dist.build_time_dependency = dist not in install_dists + if dist.build_time_dependency: + logger.debug('%s is a build-time dependency only.', dist.name_and_version) + logger.debug('find done for %s', odist) + return dists, problems diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t32.exe b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t32.exe new file mode 100644 index 0000000000000000000000000000000000000000..52154f0be32cc2bdbf98af131d477900667d0abd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t32.exe differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__main__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..efb7fb79bf64bbe58d3a92cd14ac67562dba26d4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__main__.py @@ -0,0 +1,273 @@ +import colorsys +import io +from time import process_time + +from pip._vendor.rich import box +from pip._vendor.rich.color import Color +from pip._vendor.rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult +from pip._vendor.rich.markdown import Markdown +from pip._vendor.rich.measure import Measurement +from pip._vendor.rich.pretty import Pretty +from pip._vendor.rich.segment import Segment +from pip._vendor.rich.style import Style +from pip._vendor.rich.syntax import Syntax +from pip._vendor.rich.table import Table +from pip._vendor.rich.text import Text + + +class ColorBox: + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + for y in range(0, 5): + for x in range(options.max_width): + h = x / options.max_width + l = 0.1 + ((y / 5) * 0.7) + r1, g1, b1 = colorsys.hls_to_rgb(h, l, 1.0) + r2, g2, b2 = colorsys.hls_to_rgb(h, l + 0.7 / 10, 1.0) + bgcolor = Color.from_rgb(r1 * 255, g1 * 255, b1 * 255) + color = Color.from_rgb(r2 * 255, g2 * 255, b2 * 255) + yield Segment("▄", Style(color=color, bgcolor=bgcolor)) + yield Segment.line() + + def __rich_measure__( + self, console: "Console", options: ConsoleOptions + ) -> Measurement: + return Measurement(1, options.max_width) + + +def make_test_card() -> Table: + """Get a renderable that demonstrates a number of features.""" + table = Table.grid(padding=1, pad_edge=True) + table.title = "Rich features" + table.add_column("Feature", no_wrap=True, justify="center", style="bold red") + table.add_column("Demonstration") + + color_table = Table( + box=None, + expand=False, + show_header=False, + show_edge=False, + pad_edge=False, + ) + color_table.add_row( + ( + "✓ [bold green]4-bit color[/]\n" + "✓ [bold blue]8-bit color[/]\n" + "✓ [bold magenta]Truecolor (16.7 million)[/]\n" + "✓ [bold yellow]Dumb terminals[/]\n" + "✓ [bold cyan]Automatic color conversion" + ), + ColorBox(), + ) + + table.add_row("Colors", color_table) + + table.add_row( + "Styles", + "All ansi styles: [bold]bold[/], [dim]dim[/], [italic]italic[/italic], [underline]underline[/], [strike]strikethrough[/], [reverse]reverse[/], and even [blink]blink[/].", + ) + + lorem = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque in metus sed sapien ultricies pretium a at justo. Maecenas luctus velit et auctor maximus." + lorem_table = Table.grid(padding=1, collapse_padding=True) + lorem_table.pad_edge = False + lorem_table.add_row( + Text(lorem, justify="left", style="green"), + Text(lorem, justify="center", style="yellow"), + Text(lorem, justify="right", style="blue"), + Text(lorem, justify="full", style="red"), + ) + table.add_row( + "Text", + Group( + Text.from_markup( + """Word wrap text. Justify [green]left[/], [yellow]center[/], [blue]right[/] or [red]full[/].\n""" + ), + lorem_table, + ), + ) + + def comparison(renderable1: RenderableType, renderable2: RenderableType) -> Table: + table = Table(show_header=False, pad_edge=False, box=None, expand=True) + table.add_column("1", ratio=1) + table.add_column("2", ratio=1) + table.add_row(renderable1, renderable2) + return table + + table.add_row( + "Asian\nlanguage\nsupport", + ":flag_for_china: 该库支持中文,日文和韩文文本!\n:flag_for_japan: ライブラリは中国語、日本語、韓国語のテキストをサポートしています\n:flag_for_south_korea: 이 라이브러리는 중국어, 일본어 및 한국어 텍스트를 지원합니다", + ) + + markup_example = ( + "[bold magenta]Rich[/] supports a simple [i]bbcode[/i]-like [b]markup[/b] for [yellow]color[/], [underline]style[/], and emoji! " + ":+1: :apple: :ant: :bear: :baguette_bread: :bus: " + ) + table.add_row("Markup", markup_example) + + example_table = Table( + show_edge=False, + show_header=True, + expand=False, + row_styles=["none", "dim"], + box=box.SIMPLE, + ) + example_table.add_column("[green]Date", style="green", no_wrap=True) + example_table.add_column("[blue]Title", style="blue") + example_table.add_column( + "[cyan]Production Budget", + style="cyan", + justify="right", + no_wrap=True, + ) + example_table.add_column( + "[magenta]Box Office", + style="magenta", + justify="right", + no_wrap=True, + ) + example_table.add_row( + "Dec 20, 2019", + "Star Wars: The Rise of Skywalker", + "$275,000,000", + "$375,126,118", + ) + example_table.add_row( + "May 25, 2018", + "[b]Solo[/]: A Star Wars Story", + "$275,000,000", + "$393,151,347", + ) + example_table.add_row( + "Dec 15, 2017", + "Star Wars Ep. VIII: The Last Jedi", + "$262,000,000", + "[bold]$1,332,539,889[/bold]", + ) + example_table.add_row( + "May 19, 1999", + "Star Wars Ep. [b]I[/b]: [i]The phantom Menace", + "$115,000,000", + "$1,027,044,677", + ) + + table.add_row("Tables", example_table) + + code = '''\ +def iter_last(values: Iterable[T]) -> Iterable[Tuple[bool, T]]: + """Iterate and generate a tuple with a flag for last value.""" + iter_values = iter(values) + try: + previous_value = next(iter_values) + except StopIteration: + return + for value in iter_values: + yield False, previous_value + previous_value = value + yield True, previous_value''' + + pretty_data = { + "foo": [ + 3.1427, + ( + "Paul Atreides", + "Vladimir Harkonnen", + "Thufir Hawat", + ), + ], + "atomic": (False, True, None), + } + table.add_row( + "Syntax\nhighlighting\n&\npretty\nprinting", + comparison( + Syntax(code, "python3", line_numbers=True, indent_guides=True), + Pretty(pretty_data, indent_guides=True), + ), + ) + + markdown_example = """\ +# Markdown + +Supports much of the *markdown* __syntax__! + +- Headers +- Basic formatting: **bold**, *italic*, `code` +- Block quotes +- Lists, and more... + """ + table.add_row( + "Markdown", comparison("[cyan]" + markdown_example, Markdown(markdown_example)) + ) + + table.add_row( + "+more!", + """Progress bars, columns, styled logging handler, tracebacks, etc...""", + ) + return table + + +if __name__ == "__main__": # pragma: no cover + console = Console( + file=io.StringIO(), + force_terminal=True, + ) + test_card = make_test_card() + + # Print once to warm cache + start = process_time() + console.print(test_card) + pre_cache_taken = round((process_time() - start) * 1000.0, 1) + + console.file = io.StringIO() + + start = process_time() + console.print(test_card) + taken = round((process_time() - start) * 1000.0, 1) + + c = Console(record=True) + c.print(test_card) + + print(f"rendered in {pre_cache_taken}ms (cold cache)") + print(f"rendered in {taken}ms (warm cache)") + + from pip._vendor.rich.panel import Panel + + console = Console() + + sponsor_message = Table.grid(padding=1) + sponsor_message.add_column(style="green", justify="right") + sponsor_message.add_column(no_wrap=True) + + sponsor_message.add_row( + "Textualize", + "[u blue link=https://github.com/textualize]https://github.com/textualize", + ) + sponsor_message.add_row( + "Twitter", + "[u blue link=https://twitter.com/willmcgugan]https://twitter.com/willmcgugan", + ) + + intro_message = Text.from_markup( + """\ +We hope you enjoy using Rich! + +Rich is maintained with [red]:heart:[/] by [link=https://www.textualize.io]Textualize.io[/] + +- Will McGugan""" + ) + + message = Table.grid(padding=2) + message.add_column() + message.add_column(no_wrap=True) + message.add_row(intro_message, sponsor_message) + + console.print( + Panel.fit( + message, + box=box.ROUNDED, + padding=(1, 2), + title="[b red]Thanks for trying out Rich!", + border_style="bright_blue", + ), + justify="center", + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/_timer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ca6be03c43054caaa3660998273ebf704345dd --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/_timer.py @@ -0,0 +1,19 @@ +""" +Timer context manager, only used in debug. + +""" + +from time import time + +import contextlib +from typing import Generator + + +@contextlib.contextmanager +def timer(subject: str = "time") -> Generator[None, None, None]: + """print the elapsed time. (only used in debugging)""" + start = time() + yield + elapsed = time() - start + elapsed_ms = elapsed * 1000 + print(f"{subject} elapsed {elapsed_ms:.1f}ms") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/align.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/align.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b734fd728b9e3758379c6812eeadcf36174b69 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/align.py @@ -0,0 +1,311 @@ +import sys +from itertools import chain +from typing import TYPE_CHECKING, Iterable, Optional + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from pip._vendor.typing_extensions import Literal # pragma: no cover + +from .constrain import Constrain +from .jupyter import JupyterMixin +from .measure import Measurement +from .segment import Segment +from .style import StyleType + +if TYPE_CHECKING: + from .console import Console, ConsoleOptions, RenderableType, RenderResult + +AlignMethod = Literal["left", "center", "right"] +VerticalAlignMethod = Literal["top", "middle", "bottom"] + + +class Align(JupyterMixin): + """Align a renderable by adding spaces if necessary. + + Args: + renderable (RenderableType): A console renderable. + align (AlignMethod): One of "left", "center", or "right"" + style (StyleType, optional): An optional style to apply to the background. + vertical (Optional[VerticalAlignMethod], optional): Optional vertical align, one of "top", "middle", or "bottom". Defaults to None. + pad (bool, optional): Pad the right with spaces. Defaults to True. + width (int, optional): Restrict contents to given width, or None to use default width. Defaults to None. + height (int, optional): Set height of align renderable, or None to fit to contents. Defaults to None. + + Raises: + ValueError: if ``align`` is not one of the expected values. + """ + + def __init__( + self, + renderable: "RenderableType", + align: AlignMethod = "left", + style: Optional[StyleType] = None, + *, + vertical: Optional[VerticalAlignMethod] = None, + pad: bool = True, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> None: + if align not in ("left", "center", "right"): + raise ValueError( + f'invalid value for align, expected "left", "center", or "right" (not {align!r})' + ) + if vertical is not None and vertical not in ("top", "middle", "bottom"): + raise ValueError( + f'invalid value for vertical, expected "top", "middle", or "bottom" (not {vertical!r})' + ) + self.renderable = renderable + self.align = align + self.style = style + self.vertical = vertical + self.pad = pad + self.width = width + self.height = height + + def __repr__(self) -> str: + return f"Align({self.renderable!r}, {self.align!r})" + + @classmethod + def left( + cls, + renderable: "RenderableType", + style: Optional[StyleType] = None, + *, + vertical: Optional[VerticalAlignMethod] = None, + pad: bool = True, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> "Align": + """Align a renderable to the left.""" + return cls( + renderable, + "left", + style=style, + vertical=vertical, + pad=pad, + width=width, + height=height, + ) + + @classmethod + def center( + cls, + renderable: "RenderableType", + style: Optional[StyleType] = None, + *, + vertical: Optional[VerticalAlignMethod] = None, + pad: bool = True, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> "Align": + """Align a renderable to the center.""" + return cls( + renderable, + "center", + style=style, + vertical=vertical, + pad=pad, + width=width, + height=height, + ) + + @classmethod + def right( + cls, + renderable: "RenderableType", + style: Optional[StyleType] = None, + *, + vertical: Optional[VerticalAlignMethod] = None, + pad: bool = True, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> "Align": + """Align a renderable to the right.""" + return cls( + renderable, + "right", + style=style, + vertical=vertical, + pad=pad, + width=width, + height=height, + ) + + def __rich_console__( + self, console: "Console", options: "ConsoleOptions" + ) -> "RenderResult": + align = self.align + width = console.measure(self.renderable, options=options).maximum + rendered = console.render( + Constrain( + self.renderable, width if self.width is None else min(width, self.width) + ), + options.update(height=None), + ) + lines = list(Segment.split_lines(rendered)) + width, height = Segment.get_shape(lines) + lines = Segment.set_shape(lines, width, height) + new_line = Segment.line() + excess_space = options.max_width - width + style = console.get_style(self.style) if self.style is not None else None + + def generate_segments() -> Iterable[Segment]: + if excess_space <= 0: + # Exact fit + for line in lines: + yield from line + yield new_line + + elif align == "left": + # Pad on the right + pad = Segment(" " * excess_space, style) if self.pad else None + for line in lines: + yield from line + if pad: + yield pad + yield new_line + + elif align == "center": + # Pad left and right + left = excess_space // 2 + pad = Segment(" " * left, style) + pad_right = ( + Segment(" " * (excess_space - left), style) if self.pad else None + ) + for line in lines: + if left: + yield pad + yield from line + if pad_right: + yield pad_right + yield new_line + + elif align == "right": + # Padding on left + pad = Segment(" " * excess_space, style) + for line in lines: + yield pad + yield from line + yield new_line + + blank_line = ( + Segment(f"{' ' * (self.width or options.max_width)}\n", style) + if self.pad + else Segment("\n") + ) + + def blank_lines(count: int) -> Iterable[Segment]: + if count > 0: + for _ in range(count): + yield blank_line + + vertical_height = self.height or options.height + iter_segments: Iterable[Segment] + if self.vertical and vertical_height is not None: + if self.vertical == "top": + bottom_space = vertical_height - height + iter_segments = chain(generate_segments(), blank_lines(bottom_space)) + elif self.vertical == "middle": + top_space = (vertical_height - height) // 2 + bottom_space = vertical_height - top_space - height + iter_segments = chain( + blank_lines(top_space), + generate_segments(), + blank_lines(bottom_space), + ) + else: # self.vertical == "bottom": + top_space = vertical_height - height + iter_segments = chain(blank_lines(top_space), generate_segments()) + else: + iter_segments = generate_segments() + if self.style: + style = console.get_style(self.style) + iter_segments = Segment.apply_style(iter_segments, style) + yield from iter_segments + + def __rich_measure__( + self, console: "Console", options: "ConsoleOptions" + ) -> Measurement: + measurement = Measurement.get(console, options, self.renderable) + return measurement + + +class VerticalCenter(JupyterMixin): + """Vertically aligns a renderable. + + Warn: + This class is deprecated and may be removed in a future version. Use Align class with + `vertical="middle"`. + + Args: + renderable (RenderableType): A renderable object. + """ + + def __init__( + self, + renderable: "RenderableType", + style: Optional[StyleType] = None, + ) -> None: + self.renderable = renderable + self.style = style + + def __repr__(self) -> str: + return f"VerticalCenter({self.renderable!r})" + + def __rich_console__( + self, console: "Console", options: "ConsoleOptions" + ) -> "RenderResult": + style = console.get_style(self.style) if self.style is not None else None + lines = console.render_lines( + self.renderable, options.update(height=None), pad=False + ) + width, _height = Segment.get_shape(lines) + new_line = Segment.line() + height = options.height or options.size.height + top_space = (height - len(lines)) // 2 + bottom_space = height - top_space - len(lines) + blank_line = Segment(f"{' ' * width}", style) + + def blank_lines(count: int) -> Iterable[Segment]: + for _ in range(count): + yield blank_line + yield new_line + + if top_space > 0: + yield from blank_lines(top_space) + for line in lines: + yield from line + yield new_line + if bottom_space > 0: + yield from blank_lines(bottom_space) + + def __rich_measure__( + self, console: "Console", options: "ConsoleOptions" + ) -> Measurement: + measurement = Measurement.get(console, options, self.renderable) + return measurement + + +if __name__ == "__main__": # pragma: no cover + from pip._vendor.rich.console import Console, Group + from pip._vendor.rich.highlighter import ReprHighlighter + from pip._vendor.rich.panel import Panel + + highlighter = ReprHighlighter() + console = Console() + + panel = Panel( + Group( + Align.left(highlighter("align='left'")), + Align.center(highlighter("align='center'")), + Align.right(highlighter("align='right'")), + ), + width=60, + style="on dark_blue", + title="Align", + ) + + console.print( + Align.center(panel, vertical="middle", style="on red", height=console.height) + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/box.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/box.py new file mode 100644 index 0000000000000000000000000000000000000000..0511a9e48bae3df44289f5ab51dd711db42fe722 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/box.py @@ -0,0 +1,480 @@ +import sys +from typing import TYPE_CHECKING, Iterable, List + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from pip._vendor.typing_extensions import Literal # pragma: no cover + + +from ._loop import loop_last + +if TYPE_CHECKING: + from pip._vendor.rich.console import ConsoleOptions + + +class Box: + """Defines characters to render boxes. + + ┌─┬┐ top + │ ││ head + ├─┼┤ head_row + │ ││ mid + ├─┼┤ row + ├─┼┤ foot_row + │ ││ foot + └─┴┘ bottom + + Args: + box (str): Characters making up box. + ascii (bool, optional): True if this box uses ascii characters only. Default is False. + """ + + def __init__(self, box: str, *, ascii: bool = False) -> None: + self._box = box + self.ascii = ascii + line1, line2, line3, line4, line5, line6, line7, line8 = box.splitlines() + # top + self.top_left, self.top, self.top_divider, self.top_right = iter(line1) + # head + self.head_left, _, self.head_vertical, self.head_right = iter(line2) + # head_row + ( + self.head_row_left, + self.head_row_horizontal, + self.head_row_cross, + self.head_row_right, + ) = iter(line3) + + # mid + self.mid_left, _, self.mid_vertical, self.mid_right = iter(line4) + # row + self.row_left, self.row_horizontal, self.row_cross, self.row_right = iter(line5) + # foot_row + ( + self.foot_row_left, + self.foot_row_horizontal, + self.foot_row_cross, + self.foot_row_right, + ) = iter(line6) + # foot + self.foot_left, _, self.foot_vertical, self.foot_right = iter(line7) + # bottom + self.bottom_left, self.bottom, self.bottom_divider, self.bottom_right = iter( + line8 + ) + + def __repr__(self) -> str: + return "Box(...)" + + def __str__(self) -> str: + return self._box + + def substitute(self, options: "ConsoleOptions", safe: bool = True) -> "Box": + """Substitute this box for another if it won't render due to platform issues. + + Args: + options (ConsoleOptions): Console options used in rendering. + safe (bool, optional): Substitute this for another Box if there are known problems + displaying on the platform (currently only relevant on Windows). Default is True. + + Returns: + Box: A different Box or the same Box. + """ + box = self + if options.legacy_windows and safe: + box = LEGACY_WINDOWS_SUBSTITUTIONS.get(box, box) + if options.ascii_only and not box.ascii: + box = ASCII + return box + + def get_plain_headed_box(self) -> "Box": + """If this box uses special characters for the borders of the header, then + return the equivalent box that does not. + + Returns: + Box: The most similar Box that doesn't use header-specific box characters. + If the current Box already satisfies this criterion, then it's returned. + """ + return PLAIN_HEADED_SUBSTITUTIONS.get(self, self) + + def get_top(self, widths: Iterable[int]) -> str: + """Get the top of a simple box. + + Args: + widths (List[int]): Widths of columns. + + Returns: + str: A string of box characters. + """ + + parts: List[str] = [] + append = parts.append + append(self.top_left) + for last, width in loop_last(widths): + append(self.top * width) + if not last: + append(self.top_divider) + append(self.top_right) + return "".join(parts) + + def get_row( + self, + widths: Iterable[int], + level: Literal["head", "row", "foot", "mid"] = "row", + edge: bool = True, + ) -> str: + """Get the top of a simple box. + + Args: + width (List[int]): Widths of columns. + + Returns: + str: A string of box characters. + """ + if level == "head": + left = self.head_row_left + horizontal = self.head_row_horizontal + cross = self.head_row_cross + right = self.head_row_right + elif level == "row": + left = self.row_left + horizontal = self.row_horizontal + cross = self.row_cross + right = self.row_right + elif level == "mid": + left = self.mid_left + horizontal = " " + cross = self.mid_vertical + right = self.mid_right + elif level == "foot": + left = self.foot_row_left + horizontal = self.foot_row_horizontal + cross = self.foot_row_cross + right = self.foot_row_right + else: + raise ValueError("level must be 'head', 'row' or 'foot'") + + parts: List[str] = [] + append = parts.append + if edge: + append(left) + for last, width in loop_last(widths): + append(horizontal * width) + if not last: + append(cross) + if edge: + append(right) + return "".join(parts) + + def get_bottom(self, widths: Iterable[int]) -> str: + """Get the bottom of a simple box. + + Args: + widths (List[int]): Widths of columns. + + Returns: + str: A string of box characters. + """ + + parts: List[str] = [] + append = parts.append + append(self.bottom_left) + for last, width in loop_last(widths): + append(self.bottom * width) + if not last: + append(self.bottom_divider) + append(self.bottom_right) + return "".join(parts) + + +# fmt: off +ASCII: Box = Box( + "+--+\n" + "| ||\n" + "|-+|\n" + "| ||\n" + "|-+|\n" + "|-+|\n" + "| ||\n" + "+--+\n", + ascii=True, +) + +ASCII2: Box = Box( + "+-++\n" + "| ||\n" + "+-++\n" + "| ||\n" + "+-++\n" + "+-++\n" + "| ||\n" + "+-++\n", + ascii=True, +) + +ASCII_DOUBLE_HEAD: Box = Box( + "+-++\n" + "| ||\n" + "+=++\n" + "| ||\n" + "+-++\n" + "+-++\n" + "| ||\n" + "+-++\n", + ascii=True, +) + +SQUARE: Box = Box( + "┌─┬┐\n" + "│ ││\n" + "├─┼┤\n" + "│ ││\n" + "├─┼┤\n" + "├─┼┤\n" + "│ ││\n" + "└─┴┘\n" +) + +SQUARE_DOUBLE_HEAD: Box = Box( + "┌─┬┐\n" + "│ ││\n" + "╞═╪╡\n" + "│ ││\n" + "├─┼┤\n" + "├─┼┤\n" + "│ ││\n" + "└─┴┘\n" +) + +MINIMAL: Box = Box( + " ╷ \n" + " │ \n" + "╶─┼╴\n" + " │ \n" + "╶─┼╴\n" + "╶─┼╴\n" + " │ \n" + " ╵ \n" +) + + +MINIMAL_HEAVY_HEAD: Box = Box( + " ╷ \n" + " │ \n" + "╺━┿╸\n" + " │ \n" + "╶─┼╴\n" + "╶─┼╴\n" + " │ \n" + " ╵ \n" +) + +MINIMAL_DOUBLE_HEAD: Box = Box( + " ╷ \n" + " │ \n" + " ═╪ \n" + " │ \n" + " ─┼ \n" + " ─┼ \n" + " │ \n" + " ╵ \n" +) + + +SIMPLE: Box = Box( + " \n" + " \n" + " ── \n" + " \n" + " \n" + " ── \n" + " \n" + " \n" +) + +SIMPLE_HEAD: Box = Box( + " \n" + " \n" + " ── \n" + " \n" + " \n" + " \n" + " \n" + " \n" +) + + +SIMPLE_HEAVY: Box = Box( + " \n" + " \n" + " ━━ \n" + " \n" + " \n" + " ━━ \n" + " \n" + " \n" +) + + +HORIZONTALS: Box = Box( + " ── \n" + " \n" + " ── \n" + " \n" + " ── \n" + " ── \n" + " \n" + " ── \n" +) + +ROUNDED: Box = Box( + "╭─┬╮\n" + "│ ││\n" + "├─┼┤\n" + "│ ││\n" + "├─┼┤\n" + "├─┼┤\n" + "│ ││\n" + "╰─┴╯\n" +) + +HEAVY: Box = Box( + "┏━┳┓\n" + "┃ ┃┃\n" + "┣━╋┫\n" + "┃ ┃┃\n" + "┣━╋┫\n" + "┣━╋┫\n" + "┃ ┃┃\n" + "┗━┻┛\n" +) + +HEAVY_EDGE: Box = Box( + "┏━┯┓\n" + "┃ │┃\n" + "┠─┼┨\n" + "┃ │┃\n" + "┠─┼┨\n" + "┠─┼┨\n" + "┃ │┃\n" + "┗━┷┛\n" +) + +HEAVY_HEAD: Box = Box( + "┏━┳┓\n" + "┃ ┃┃\n" + "┡━╇┩\n" + "│ ││\n" + "├─┼┤\n" + "├─┼┤\n" + "│ ││\n" + "└─┴┘\n" +) + +DOUBLE: Box = Box( + "╔═╦╗\n" + "║ ║║\n" + "╠═╬╣\n" + "║ ║║\n" + "╠═╬╣\n" + "╠═╬╣\n" + "║ ║║\n" + "╚═╩╝\n" +) + +DOUBLE_EDGE: Box = Box( + "╔═╤╗\n" + "║ │║\n" + "╟─┼╢\n" + "║ │║\n" + "╟─┼╢\n" + "╟─┼╢\n" + "║ │║\n" + "╚═╧╝\n" +) + +MARKDOWN: Box = Box( + " \n" + "| ||\n" + "|-||\n" + "| ||\n" + "|-||\n" + "|-||\n" + "| ||\n" + " \n", + ascii=True, +) +# fmt: on + +# Map Boxes that don't render with raster fonts on to equivalent that do +LEGACY_WINDOWS_SUBSTITUTIONS = { + ROUNDED: SQUARE, + MINIMAL_HEAVY_HEAD: MINIMAL, + SIMPLE_HEAVY: SIMPLE, + HEAVY: SQUARE, + HEAVY_EDGE: SQUARE, + HEAVY_HEAD: SQUARE, +} + +# Map headed boxes to their headerless equivalents +PLAIN_HEADED_SUBSTITUTIONS = { + HEAVY_HEAD: SQUARE, + SQUARE_DOUBLE_HEAD: SQUARE, + MINIMAL_DOUBLE_HEAD: MINIMAL, + MINIMAL_HEAVY_HEAD: MINIMAL, + ASCII_DOUBLE_HEAD: ASCII2, +} + + +if __name__ == "__main__": # pragma: no cover + from pip._vendor.rich.columns import Columns + from pip._vendor.rich.panel import Panel + + from . import box as box + from .console import Console + from .table import Table + from .text import Text + + console = Console(record=True) + + BOXES = [ + "ASCII", + "ASCII2", + "ASCII_DOUBLE_HEAD", + "SQUARE", + "SQUARE_DOUBLE_HEAD", + "MINIMAL", + "MINIMAL_HEAVY_HEAD", + "MINIMAL_DOUBLE_HEAD", + "SIMPLE", + "SIMPLE_HEAD", + "SIMPLE_HEAVY", + "HORIZONTALS", + "ROUNDED", + "HEAVY", + "HEAVY_EDGE", + "HEAVY_HEAD", + "DOUBLE", + "DOUBLE_EDGE", + "MARKDOWN", + ] + + console.print(Panel("[bold green]Box Constants", style="green"), justify="center") + console.print() + + columns = Columns(expand=True, padding=2) + for box_name in sorted(BOXES): + table = Table( + show_footer=True, style="dim", border_style="not dim", expand=True + ) + table.add_column("Header 1", "Footer 1") + table.add_column("Header 2", "Footer 2") + table.add_row("Cell", "Cell") + table.add_row("Cell", "Cell") + table.box = getattr(box, box_name) + table.title = Text(f"box.{box_name}", style="magenta") + columns.add_renderable(table) + console.print(columns) + + # console.save_svg("box.svg") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/jupyter.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/jupyter.py new file mode 100644 index 0000000000000000000000000000000000000000..22f4d716ac9764ee18005b9b852946d614152375 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/jupyter.py @@ -0,0 +1,101 @@ +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Sequence + +if TYPE_CHECKING: + from pip._vendor.rich.console import ConsoleRenderable + +from . import get_console +from .segment import Segment +from .terminal_theme import DEFAULT_TERMINAL_THEME + +if TYPE_CHECKING: + from pip._vendor.rich.console import ConsoleRenderable + +JUPYTER_HTML_FORMAT = """\ +
{code}
+""" + + +class JupyterRenderable: + """A shim to write html to Jupyter notebook.""" + + def __init__(self, html: str, text: str) -> None: + self.html = html + self.text = text + + def _repr_mimebundle_( + self, include: Sequence[str], exclude: Sequence[str], **kwargs: Any + ) -> Dict[str, str]: + data = {"text/plain": self.text, "text/html": self.html} + if include: + data = {k: v for (k, v) in data.items() if k in include} + if exclude: + data = {k: v for (k, v) in data.items() if k not in exclude} + return data + + +class JupyterMixin: + """Add to an Rich renderable to make it render in Jupyter notebook.""" + + __slots__ = () + + def _repr_mimebundle_( + self: "ConsoleRenderable", + include: Sequence[str], + exclude: Sequence[str], + **kwargs: Any, + ) -> Dict[str, str]: + console = get_console() + segments = list(console.render(self, console.options)) + html = _render_segments(segments) + text = console._render_buffer(segments) + data = {"text/plain": text, "text/html": html} + if include: + data = {k: v for (k, v) in data.items() if k in include} + if exclude: + data = {k: v for (k, v) in data.items() if k not in exclude} + return data + + +def _render_segments(segments: Iterable[Segment]) -> str: + def escape(text: str) -> str: + """Escape html.""" + return text.replace("&", "&").replace("<", "<").replace(">", ">") + + fragments: List[str] = [] + append_fragment = fragments.append + theme = DEFAULT_TERMINAL_THEME + for text, style, control in Segment.simplify(segments): + if control: + continue + text = escape(text) + if style: + rule = style.get_html_style(theme) + text = f'{text}' if rule else text + if style.link: + text = f'
{text}' + append_fragment(text) + + code = "".join(fragments) + html = JUPYTER_HTML_FORMAT.format(code=code) + + return html + + +def display(segments: Iterable[Segment], text: str) -> None: + """Render segments to Jupyter.""" + html = _render_segments(segments) + jupyter_renderable = JupyterRenderable(html, text) + try: + from IPython.display import display as ipython_display + + ipython_display(jupyter_renderable) + except ModuleNotFoundError: + # Handle the case where the Console has force_jupyter=True, + # but IPython is not installed. + pass + + +def print(*args: Any, **kwargs: Any) -> None: + """Proxy for Console print.""" + console = get_console() + return console.print(*args, **kwargs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/live.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/live.py new file mode 100644 index 0000000000000000000000000000000000000000..f0529a781cf8c247eadb994fd251c0d00433dcac --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/live.py @@ -0,0 +1,375 @@ +import sys +from threading import Event, RLock, Thread +from types import TracebackType +from typing import IO, Any, Callable, List, Optional, TextIO, Type, cast + +from . import get_console +from .console import Console, ConsoleRenderable, RenderableType, RenderHook +from .control import Control +from .file_proxy import FileProxy +from .jupyter import JupyterMixin +from .live_render import LiveRender, VerticalOverflowMethod +from .screen import Screen +from .text import Text + + +class _RefreshThread(Thread): + """A thread that calls refresh() at regular intervals.""" + + def __init__(self, live: "Live", refresh_per_second: float) -> None: + self.live = live + self.refresh_per_second = refresh_per_second + self.done = Event() + super().__init__(daemon=True) + + def stop(self) -> None: + self.done.set() + + def run(self) -> None: + while not self.done.wait(1 / self.refresh_per_second): + with self.live._lock: + if not self.done.is_set(): + self.live.refresh() + + +class Live(JupyterMixin, RenderHook): + """Renders an auto-updating live display of any given renderable. + + Args: + renderable (RenderableType, optional): The renderable to live display. Defaults to displaying nothing. + console (Console, optional): Optional Console instance. Default will an internal Console instance writing to stdout. + screen (bool, optional): Enable alternate screen mode. Defaults to False. + auto_refresh (bool, optional): Enable auto refresh. If disabled, you will need to call `refresh()` or `update()` with refresh flag. Defaults to True + refresh_per_second (float, optional): Number of times per second to refresh the live display. Defaults to 4. + transient (bool, optional): Clear the renderable on exit (has no effect when screen=True). Defaults to False. + redirect_stdout (bool, optional): Enable redirection of stdout, so ``print`` may be used. Defaults to True. + redirect_stderr (bool, optional): Enable redirection of stderr. Defaults to True. + vertical_overflow (VerticalOverflowMethod, optional): How to handle renderable when it is too tall for the console. Defaults to "ellipsis". + get_renderable (Callable[[], RenderableType], optional): Optional callable to get renderable. Defaults to None. + """ + + def __init__( + self, + renderable: Optional[RenderableType] = None, + *, + console: Optional[Console] = None, + screen: bool = False, + auto_refresh: bool = True, + refresh_per_second: float = 4, + transient: bool = False, + redirect_stdout: bool = True, + redirect_stderr: bool = True, + vertical_overflow: VerticalOverflowMethod = "ellipsis", + get_renderable: Optional[Callable[[], RenderableType]] = None, + ) -> None: + assert refresh_per_second > 0, "refresh_per_second must be > 0" + self._renderable = renderable + self.console = console if console is not None else get_console() + self._screen = screen + self._alt_screen = False + + self._redirect_stdout = redirect_stdout + self._redirect_stderr = redirect_stderr + self._restore_stdout: Optional[IO[str]] = None + self._restore_stderr: Optional[IO[str]] = None + + self._lock = RLock() + self.ipy_widget: Optional[Any] = None + self.auto_refresh = auto_refresh + self._started: bool = False + self.transient = True if screen else transient + + self._refresh_thread: Optional[_RefreshThread] = None + self.refresh_per_second = refresh_per_second + + self.vertical_overflow = vertical_overflow + self._get_renderable = get_renderable + self._live_render = LiveRender( + self.get_renderable(), vertical_overflow=vertical_overflow + ) + + @property + def is_started(self) -> bool: + """Check if live display has been started.""" + return self._started + + def get_renderable(self) -> RenderableType: + renderable = ( + self._get_renderable() + if self._get_renderable is not None + else self._renderable + ) + return renderable or "" + + def start(self, refresh: bool = False) -> None: + """Start live rendering display. + + Args: + refresh (bool, optional): Also refresh. Defaults to False. + """ + with self._lock: + if self._started: + return + self.console.set_live(self) + self._started = True + if self._screen: + self._alt_screen = self.console.set_alt_screen(True) + self.console.show_cursor(False) + self._enable_redirect_io() + self.console.push_render_hook(self) + if refresh: + try: + self.refresh() + except Exception: + # If refresh fails, we want to stop the redirection of sys.stderr, + # so the error stacktrace is properly displayed in the terminal. + # (or, if the code that calls Rich captures the exception and wants to display something, + # let this be displayed in the terminal). + self.stop() + raise + if self.auto_refresh: + self._refresh_thread = _RefreshThread(self, self.refresh_per_second) + self._refresh_thread.start() + + def stop(self) -> None: + """Stop live rendering display.""" + with self._lock: + if not self._started: + return + self.console.clear_live() + self._started = False + + if self.auto_refresh and self._refresh_thread is not None: + self._refresh_thread.stop() + self._refresh_thread = None + # allow it to fully render on the last even if overflow + self.vertical_overflow = "visible" + with self.console: + try: + if not self._alt_screen and not self.console.is_jupyter: + self.refresh() + finally: + self._disable_redirect_io() + self.console.pop_render_hook() + if not self._alt_screen and self.console.is_terminal: + self.console.line() + self.console.show_cursor(True) + if self._alt_screen: + self.console.set_alt_screen(False) + + if self.transient and not self._alt_screen: + self.console.control(self._live_render.restore_cursor()) + if self.ipy_widget is not None and self.transient: + self.ipy_widget.close() # pragma: no cover + + def __enter__(self) -> "Live": + self.start(refresh=self._renderable is not None) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.stop() + + def _enable_redirect_io(self) -> None: + """Enable redirecting of stdout / stderr.""" + if self.console.is_terminal or self.console.is_jupyter: + if self._redirect_stdout and not isinstance(sys.stdout, FileProxy): + self._restore_stdout = sys.stdout + sys.stdout = cast("TextIO", FileProxy(self.console, sys.stdout)) + if self._redirect_stderr and not isinstance(sys.stderr, FileProxy): + self._restore_stderr = sys.stderr + sys.stderr = cast("TextIO", FileProxy(self.console, sys.stderr)) + + def _disable_redirect_io(self) -> None: + """Disable redirecting of stdout / stderr.""" + if self._restore_stdout: + sys.stdout = cast("TextIO", self._restore_stdout) + self._restore_stdout = None + if self._restore_stderr: + sys.stderr = cast("TextIO", self._restore_stderr) + self._restore_stderr = None + + @property + def renderable(self) -> RenderableType: + """Get the renderable that is being displayed + + Returns: + RenderableType: Displayed renderable. + """ + renderable = self.get_renderable() + return Screen(renderable) if self._alt_screen else renderable + + def update(self, renderable: RenderableType, *, refresh: bool = False) -> None: + """Update the renderable that is being displayed + + Args: + renderable (RenderableType): New renderable to use. + refresh (bool, optional): Refresh the display. Defaults to False. + """ + if isinstance(renderable, str): + renderable = self.console.render_str(renderable) + with self._lock: + self._renderable = renderable + if refresh: + self.refresh() + + def refresh(self) -> None: + """Update the display of the Live Render.""" + with self._lock: + self._live_render.set_renderable(self.renderable) + if self.console.is_jupyter: # pragma: no cover + try: + from IPython.display import display + from ipywidgets import Output + except ImportError: + import warnings + + warnings.warn('install "ipywidgets" for Jupyter support') + else: + if self.ipy_widget is None: + self.ipy_widget = Output() + display(self.ipy_widget) + + with self.ipy_widget: + self.ipy_widget.clear_output(wait=True) + self.console.print(self._live_render.renderable) + elif self.console.is_terminal and not self.console.is_dumb_terminal: + with self.console: + self.console.print(Control()) + elif ( + not self._started and not self.transient + ): # if it is finished allow files or dumb-terminals to see final result + with self.console: + self.console.print(Control()) + + def process_renderables( + self, renderables: List[ConsoleRenderable] + ) -> List[ConsoleRenderable]: + """Process renderables to restore cursor and display progress.""" + self._live_render.vertical_overflow = self.vertical_overflow + if self.console.is_interactive: + # lock needs acquiring as user can modify live_render renderable at any time unlike in Progress. + with self._lock: + reset = ( + Control.home() + if self._alt_screen + else self._live_render.position_cursor() + ) + renderables = [reset, *renderables, self._live_render] + elif ( + not self._started and not self.transient + ): # if it is finished render the final output for files or dumb_terminals + renderables = [*renderables, self._live_render] + + return renderables + + +if __name__ == "__main__": # pragma: no cover + import random + import time + from itertools import cycle + from typing import Dict, List, Tuple + + from .align import Align + from .console import Console + from .live import Live as Live + from .panel import Panel + from .rule import Rule + from .syntax import Syntax + from .table import Table + + console = Console() + + syntax = Syntax( + '''def loop_last(values: Iterable[T]) -> Iterable[Tuple[bool, T]]: + """Iterate and generate a tuple with a flag for last value.""" + iter_values = iter(values) + try: + previous_value = next(iter_values) + except StopIteration: + return + for value in iter_values: + yield False, previous_value + previous_value = value + yield True, previous_value''', + "python", + line_numbers=True, + ) + + table = Table("foo", "bar", "baz") + table.add_row("1", "2", "3") + + progress_renderables = [ + "You can make the terminal shorter and taller to see the live table hide" + "Text may be printed while the progress bars are rendering.", + Panel("In fact, [i]any[/i] renderable will work"), + "Such as [magenta]tables[/]...", + table, + "Pretty printed structures...", + {"type": "example", "text": "Pretty printed"}, + "Syntax...", + syntax, + Rule("Give it a try!"), + ] + + examples = cycle(progress_renderables) + + exchanges = [ + "SGD", + "MYR", + "EUR", + "USD", + "AUD", + "JPY", + "CNH", + "HKD", + "CAD", + "INR", + "DKK", + "GBP", + "RUB", + "NZD", + "MXN", + "IDR", + "TWD", + "THB", + "VND", + ] + with Live(console=console) as live_table: + exchange_rate_dict: Dict[Tuple[str, str], float] = {} + + for index in range(100): + select_exchange = exchanges[index % len(exchanges)] + + for exchange in exchanges: + if exchange == select_exchange: + continue + time.sleep(0.4) + if random.randint(0, 10) < 1: + console.log(next(examples)) + exchange_rate_dict[(select_exchange, exchange)] = 200 / ( + (random.random() * 320) + 1 + ) + if len(exchange_rate_dict) > len(exchanges) - 1: + exchange_rate_dict.pop(list(exchange_rate_dict.keys())[0]) + table = Table(title="Exchange Rates") + + table.add_column("Source Currency") + table.add_column("Destination Currency") + table.add_column("Exchange Rate") + + for (source, dest), exchange_rate in exchange_rate_dict.items(): + table.add_row( + source, + dest, + Text( + f"{exchange_rate:.4f}", + style="red" if exchange_rate < 1.0 else "green", + ), + ) + + live_table.update(Align.center(table)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/logging.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..91368dda78aad590837aa12023dee67e224709ba --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/logging.py @@ -0,0 +1,289 @@ +import logging +from datetime import datetime +from logging import Handler, LogRecord +from pathlib import Path +from types import ModuleType +from typing import ClassVar, Iterable, List, Optional, Type, Union + +from pip._vendor.rich._null_file import NullFile + +from . import get_console +from ._log_render import FormatTimeCallable, LogRender +from .console import Console, ConsoleRenderable +from .highlighter import Highlighter, ReprHighlighter +from .text import Text +from .traceback import Traceback + + +class RichHandler(Handler): + """A logging handler that renders output with Rich. The time / level / message and file are displayed in columns. + The level is color coded, and the message is syntax highlighted. + + Note: + Be careful when enabling console markup in log messages if you have configured logging for libraries not + under your control. If a dependency writes messages containing square brackets, it may not produce the intended output. + + Args: + level (Union[int, str], optional): Log level. Defaults to logging.NOTSET. + console (:class:`~rich.console.Console`, optional): Optional console instance to write logs. + Default will use a global console instance writing to stdout. + show_time (bool, optional): Show a column for the time. Defaults to True. + omit_repeated_times (bool, optional): Omit repetition of the same time. Defaults to True. + show_level (bool, optional): Show a column for the level. Defaults to True. + show_path (bool, optional): Show the path to the original log call. Defaults to True. + enable_link_path (bool, optional): Enable terminal link of path column to file. Defaults to True. + highlighter (Highlighter, optional): Highlighter to style log messages, or None to use ReprHighlighter. Defaults to None. + markup (bool, optional): Enable console markup in log messages. Defaults to False. + rich_tracebacks (bool, optional): Enable rich tracebacks with syntax highlighting and formatting. Defaults to False. + tracebacks_width (Optional[int], optional): Number of characters used to render tracebacks, or None for full width. Defaults to None. + tracebacks_extra_lines (int, optional): Additional lines of code to render tracebacks, or None for full width. Defaults to None. + tracebacks_theme (str, optional): Override pygments theme used in traceback. + tracebacks_word_wrap (bool, optional): Enable word wrapping of long tracebacks lines. Defaults to True. + tracebacks_show_locals (bool, optional): Enable display of locals in tracebacks. Defaults to False. + tracebacks_suppress (Sequence[Union[str, ModuleType]]): Optional sequence of modules or paths to exclude from traceback. + locals_max_length (int, optional): Maximum length of containers before abbreviating, or None for no abbreviation. + Defaults to 10. + locals_max_string (int, optional): Maximum length of string before truncating, or None to disable. Defaults to 80. + log_time_format (Union[str, TimeFormatterCallable], optional): If ``log_time`` is enabled, either string for strftime or callable that formats the time. Defaults to "[%x %X] ". + keywords (List[str], optional): List of words to highlight instead of ``RichHandler.KEYWORDS``. + """ + + KEYWORDS: ClassVar[Optional[List[str]]] = [ + "GET", + "POST", + "HEAD", + "PUT", + "DELETE", + "OPTIONS", + "TRACE", + "PATCH", + ] + HIGHLIGHTER_CLASS: ClassVar[Type[Highlighter]] = ReprHighlighter + + def __init__( + self, + level: Union[int, str] = logging.NOTSET, + console: Optional[Console] = None, + *, + show_time: bool = True, + omit_repeated_times: bool = True, + show_level: bool = True, + show_path: bool = True, + enable_link_path: bool = True, + highlighter: Optional[Highlighter] = None, + markup: bool = False, + rich_tracebacks: bool = False, + tracebacks_width: Optional[int] = None, + tracebacks_extra_lines: int = 3, + tracebacks_theme: Optional[str] = None, + tracebacks_word_wrap: bool = True, + tracebacks_show_locals: bool = False, + tracebacks_suppress: Iterable[Union[str, ModuleType]] = (), + locals_max_length: int = 10, + locals_max_string: int = 80, + log_time_format: Union[str, FormatTimeCallable] = "[%x %X]", + keywords: Optional[List[str]] = None, + ) -> None: + super().__init__(level=level) + self.console = console or get_console() + self.highlighter = highlighter or self.HIGHLIGHTER_CLASS() + self._log_render = LogRender( + show_time=show_time, + show_level=show_level, + show_path=show_path, + time_format=log_time_format, + omit_repeated_times=omit_repeated_times, + level_width=None, + ) + self.enable_link_path = enable_link_path + self.markup = markup + self.rich_tracebacks = rich_tracebacks + self.tracebacks_width = tracebacks_width + self.tracebacks_extra_lines = tracebacks_extra_lines + self.tracebacks_theme = tracebacks_theme + self.tracebacks_word_wrap = tracebacks_word_wrap + self.tracebacks_show_locals = tracebacks_show_locals + self.tracebacks_suppress = tracebacks_suppress + self.locals_max_length = locals_max_length + self.locals_max_string = locals_max_string + self.keywords = keywords + + def get_level_text(self, record: LogRecord) -> Text: + """Get the level name from the record. + + Args: + record (LogRecord): LogRecord instance. + + Returns: + Text: A tuple of the style and level name. + """ + level_name = record.levelname + level_text = Text.styled( + level_name.ljust(8), f"logging.level.{level_name.lower()}" + ) + return level_text + + def emit(self, record: LogRecord) -> None: + """Invoked by logging.""" + message = self.format(record) + traceback = None + if ( + self.rich_tracebacks + and record.exc_info + and record.exc_info != (None, None, None) + ): + exc_type, exc_value, exc_traceback = record.exc_info + assert exc_type is not None + assert exc_value is not None + traceback = Traceback.from_exception( + exc_type, + exc_value, + exc_traceback, + width=self.tracebacks_width, + extra_lines=self.tracebacks_extra_lines, + theme=self.tracebacks_theme, + word_wrap=self.tracebacks_word_wrap, + show_locals=self.tracebacks_show_locals, + locals_max_length=self.locals_max_length, + locals_max_string=self.locals_max_string, + suppress=self.tracebacks_suppress, + ) + message = record.getMessage() + if self.formatter: + record.message = record.getMessage() + formatter = self.formatter + if hasattr(formatter, "usesTime") and formatter.usesTime(): + record.asctime = formatter.formatTime(record, formatter.datefmt) + message = formatter.formatMessage(record) + + message_renderable = self.render_message(record, message) + log_renderable = self.render( + record=record, traceback=traceback, message_renderable=message_renderable + ) + if isinstance(self.console.file, NullFile): + # Handles pythonw, where stdout/stderr are null, and we return NullFile + # instance from Console.file. In this case, we still want to make a log record + # even though we won't be writing anything to a file. + self.handleError(record) + else: + try: + self.console.print(log_renderable) + except Exception: + self.handleError(record) + + def render_message(self, record: LogRecord, message: str) -> "ConsoleRenderable": + """Render message text in to Text. + + Args: + record (LogRecord): logging Record. + message (str): String containing log message. + + Returns: + ConsoleRenderable: Renderable to display log message. + """ + use_markup = getattr(record, "markup", self.markup) + message_text = Text.from_markup(message) if use_markup else Text(message) + + highlighter = getattr(record, "highlighter", self.highlighter) + if highlighter: + message_text = highlighter(message_text) + + if self.keywords is None: + self.keywords = self.KEYWORDS + + if self.keywords: + message_text.highlight_words(self.keywords, "logging.keyword") + + return message_text + + def render( + self, + *, + record: LogRecord, + traceback: Optional[Traceback], + message_renderable: "ConsoleRenderable", + ) -> "ConsoleRenderable": + """Render log for display. + + Args: + record (LogRecord): logging Record. + traceback (Optional[Traceback]): Traceback instance or None for no Traceback. + message_renderable (ConsoleRenderable): Renderable (typically Text) containing log message contents. + + Returns: + ConsoleRenderable: Renderable to display log. + """ + path = Path(record.pathname).name + level = self.get_level_text(record) + time_format = None if self.formatter is None else self.formatter.datefmt + log_time = datetime.fromtimestamp(record.created) + + log_renderable = self._log_render( + self.console, + [message_renderable] if not traceback else [message_renderable, traceback], + log_time=log_time, + time_format=time_format, + level=level, + path=path, + line_no=record.lineno, + link_path=record.pathname if self.enable_link_path else None, + ) + return log_renderable + + +if __name__ == "__main__": # pragma: no cover + from time import sleep + + FORMAT = "%(message)s" + # FORMAT = "%(asctime)-15s - %(levelname)s - %(message)s" + logging.basicConfig( + level="NOTSET", + format=FORMAT, + datefmt="[%X]", + handlers=[RichHandler(rich_tracebacks=True, tracebacks_show_locals=True)], + ) + log = logging.getLogger("rich") + + log.info("Server starting...") + log.info("Listening on http://127.0.0.1:8080") + sleep(1) + + log.info("GET /index.html 200 1298") + log.info("GET /imgs/backgrounds/back1.jpg 200 54386") + log.info("GET /css/styles.css 200 54386") + log.warning("GET /favicon.ico 404 242") + sleep(1) + + log.debug( + "JSONRPC request\n--> %r\n<-- %r", + { + "version": "1.1", + "method": "confirmFruitPurchase", + "params": [["apple", "orange", "mangoes", "pomelo"], 1.123], + "id": "194521489", + }, + {"version": "1.1", "result": True, "error": None, "id": "194521489"}, + ) + log.debug( + "Loading configuration file /adasd/asdasd/qeqwe/qwrqwrqwr/sdgsdgsdg/werwerwer/dfgerert/ertertert/ertetert/werwerwer" + ) + log.error("Unable to find 'pomelo' in database!") + log.info("POST /jsonrpc/ 200 65532") + log.info("POST /admin/ 401 42234") + log.warning("password was rejected for admin site.") + + def divide() -> None: + number = 1 + divisor = 0 + foos = ["foo"] * 100 + log.debug("in divide") + try: + number / divisor + except: + log.exception("An error of some kind occurred!") + + divide() + sleep(1) + log.critical("Out of memory!") + log.info("Server exited with code=-1") + log.info("[bold]EXITING...[/bold]", extra=dict(markup=True)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/py.typed b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/segment.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/segment.py new file mode 100644 index 0000000000000000000000000000000000000000..93edbbdeb72a49c6ebd2dad3c0e03210c2b23bc7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/segment.py @@ -0,0 +1,738 @@ +from enum import IntEnum +from functools import lru_cache +from itertools import filterfalse +from logging import getLogger +from operator import attrgetter +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from .cells import ( + _is_single_cell_widths, + cached_cell_len, + cell_len, + get_character_cell_size, + set_cell_size, +) +from .repr import Result, rich_repr +from .style import Style + +if TYPE_CHECKING: + from .console import Console, ConsoleOptions, RenderResult + +log = getLogger("rich") + + +class ControlType(IntEnum): + """Non-printable control codes which typically translate to ANSI codes.""" + + BELL = 1 + CARRIAGE_RETURN = 2 + HOME = 3 + CLEAR = 4 + SHOW_CURSOR = 5 + HIDE_CURSOR = 6 + ENABLE_ALT_SCREEN = 7 + DISABLE_ALT_SCREEN = 8 + CURSOR_UP = 9 + CURSOR_DOWN = 10 + CURSOR_FORWARD = 11 + CURSOR_BACKWARD = 12 + CURSOR_MOVE_TO_COLUMN = 13 + CURSOR_MOVE_TO = 14 + ERASE_IN_LINE = 15 + SET_WINDOW_TITLE = 16 + + +ControlCode = Union[ + Tuple[ControlType], + Tuple[ControlType, Union[int, str]], + Tuple[ControlType, int, int], +] + + +@rich_repr() +class Segment(NamedTuple): + """A piece of text with associated style. Segments are produced by the Console render process and + are ultimately converted in to strings to be written to the terminal. + + Args: + text (str): A piece of text. + style (:class:`~rich.style.Style`, optional): An optional style to apply to the text. + control (Tuple[ControlCode], optional): Optional sequence of control codes. + + Attributes: + cell_length (int): The cell length of this Segment. + """ + + text: str + style: Optional[Style] = None + control: Optional[Sequence[ControlCode]] = None + + @property + def cell_length(self) -> int: + """The number of terminal cells required to display self.text. + + Returns: + int: A number of cells. + """ + text, _style, control = self + return 0 if control else cell_len(text) + + def __rich_repr__(self) -> Result: + yield self.text + if self.control is None: + if self.style is not None: + yield self.style + else: + yield self.style + yield self.control + + def __bool__(self) -> bool: + """Check if the segment contains text.""" + return bool(self.text) + + @property + def is_control(self) -> bool: + """Check if the segment contains control codes.""" + return self.control is not None + + @classmethod + @lru_cache(1024 * 16) + def _split_cells(cls, segment: "Segment", cut: int) -> Tuple["Segment", "Segment"]: + text, style, control = segment + _Segment = Segment + + cell_length = segment.cell_length + if cut >= cell_length: + return segment, _Segment("", style, control) + + cell_size = get_character_cell_size + + pos = int((cut / cell_length) * (len(text) - 1)) + + before = text[:pos] + cell_pos = cell_len(before) + if cell_pos == cut: + return ( + _Segment(before, style, control), + _Segment(text[pos:], style, control), + ) + while pos < len(text): + char = text[pos] + pos += 1 + cell_pos += cell_size(char) + before = text[:pos] + if cell_pos == cut: + return ( + _Segment(before, style, control), + _Segment(text[pos:], style, control), + ) + if cell_pos > cut: + return ( + _Segment(before[: pos - 1] + " ", style, control), + _Segment(" " + text[pos:], style, control), + ) + + raise AssertionError("Will never reach here") + + def split_cells(self, cut: int) -> Tuple["Segment", "Segment"]: + """Split segment in to two segments at the specified column. + + If the cut point falls in the middle of a 2-cell wide character then it is replaced + by two spaces, to preserve the display width of the parent segment. + + Returns: + Tuple[Segment, Segment]: Two segments. + """ + text, style, control = self + + if _is_single_cell_widths(text): + # Fast path with all 1 cell characters + if cut >= len(text): + return self, Segment("", style, control) + return ( + Segment(text[:cut], style, control), + Segment(text[cut:], style, control), + ) + + return self._split_cells(self, cut) + + @classmethod + def line(cls) -> "Segment": + """Make a new line segment.""" + return cls("\n") + + @classmethod + def apply_style( + cls, + segments: Iterable["Segment"], + style: Optional[Style] = None, + post_style: Optional[Style] = None, + ) -> Iterable["Segment"]: + """Apply style(s) to an iterable of segments. + + Returns an iterable of segments where the style is replaced by ``style + segment.style + post_style``. + + Args: + segments (Iterable[Segment]): Segments to process. + style (Style, optional): Base style. Defaults to None. + post_style (Style, optional): Style to apply on top of segment style. Defaults to None. + + Returns: + Iterable[Segments]: A new iterable of segments (possibly the same iterable). + """ + result_segments = segments + if style: + apply = style.__add__ + result_segments = ( + cls(text, None if control else apply(_style), control) + for text, _style, control in result_segments + ) + if post_style: + result_segments = ( + cls( + text, + ( + None + if control + else (_style + post_style if _style else post_style) + ), + control, + ) + for text, _style, control in result_segments + ) + return result_segments + + @classmethod + def filter_control( + cls, segments: Iterable["Segment"], is_control: bool = False + ) -> Iterable["Segment"]: + """Filter segments by ``is_control`` attribute. + + Args: + segments (Iterable[Segment]): An iterable of Segment instances. + is_control (bool, optional): is_control flag to match in search. + + Returns: + Iterable[Segment]: And iterable of Segment instances. + + """ + if is_control: + return filter(attrgetter("control"), segments) + else: + return filterfalse(attrgetter("control"), segments) + + @classmethod + def split_lines(cls, segments: Iterable["Segment"]) -> Iterable[List["Segment"]]: + """Split a sequence of segments in to a list of lines. + + Args: + segments (Iterable[Segment]): Segments potentially containing line feeds. + + Yields: + Iterable[List[Segment]]: Iterable of segment lists, one per line. + """ + line: List[Segment] = [] + append = line.append + + for segment in segments: + if "\n" in segment.text and not segment.control: + text, style, _ = segment + while text: + _text, new_line, text = text.partition("\n") + if _text: + append(cls(_text, style)) + if new_line: + yield line + line = [] + append = line.append + else: + append(segment) + if line: + yield line + + @classmethod + def split_and_crop_lines( + cls, + segments: Iterable["Segment"], + length: int, + style: Optional[Style] = None, + pad: bool = True, + include_new_lines: bool = True, + ) -> Iterable[List["Segment"]]: + """Split segments in to lines, and crop lines greater than a given length. + + Args: + segments (Iterable[Segment]): An iterable of segments, probably + generated from console.render. + length (int): Desired line length. + style (Style, optional): Style to use for any padding. + pad (bool): Enable padding of lines that are less than `length`. + + Returns: + Iterable[List[Segment]]: An iterable of lines of segments. + """ + line: List[Segment] = [] + append = line.append + + adjust_line_length = cls.adjust_line_length + new_line_segment = cls("\n") + + for segment in segments: + if "\n" in segment.text and not segment.control: + text, segment_style, _ = segment + while text: + _text, new_line, text = text.partition("\n") + if _text: + append(cls(_text, segment_style)) + if new_line: + cropped_line = adjust_line_length( + line, length, style=style, pad=pad + ) + if include_new_lines: + cropped_line.append(new_line_segment) + yield cropped_line + line.clear() + else: + append(segment) + if line: + yield adjust_line_length(line, length, style=style, pad=pad) + + @classmethod + def adjust_line_length( + cls, + line: List["Segment"], + length: int, + style: Optional[Style] = None, + pad: bool = True, + ) -> List["Segment"]: + """Adjust a line to a given width (cropping or padding as required). + + Args: + segments (Iterable[Segment]): A list of segments in a single line. + length (int): The desired width of the line. + style (Style, optional): The style of padding if used (space on the end). Defaults to None. + pad (bool, optional): Pad lines with spaces if they are shorter than `length`. Defaults to True. + + Returns: + List[Segment]: A line of segments with the desired length. + """ + line_length = sum(segment.cell_length for segment in line) + new_line: List[Segment] + + if line_length < length: + if pad: + new_line = line + [cls(" " * (length - line_length), style)] + else: + new_line = line[:] + elif line_length > length: + new_line = [] + append = new_line.append + line_length = 0 + for segment in line: + segment_length = segment.cell_length + if line_length + segment_length < length or segment.control: + append(segment) + line_length += segment_length + else: + text, segment_style, _ = segment + text = set_cell_size(text, length - line_length) + append(cls(text, segment_style)) + break + else: + new_line = line[:] + return new_line + + @classmethod + def get_line_length(cls, line: List["Segment"]) -> int: + """Get the length of list of segments. + + Args: + line (List[Segment]): A line encoded as a list of Segments (assumes no '\\\\n' characters), + + Returns: + int: The length of the line. + """ + _cell_len = cell_len + return sum(_cell_len(text) for text, style, control in line if not control) + + @classmethod + def get_shape(cls, lines: List[List["Segment"]]) -> Tuple[int, int]: + """Get the shape (enclosing rectangle) of a list of lines. + + Args: + lines (List[List[Segment]]): A list of lines (no '\\\\n' characters). + + Returns: + Tuple[int, int]: Width and height in characters. + """ + get_line_length = cls.get_line_length + max_width = max(get_line_length(line) for line in lines) if lines else 0 + return (max_width, len(lines)) + + @classmethod + def set_shape( + cls, + lines: List[List["Segment"]], + width: int, + height: Optional[int] = None, + style: Optional[Style] = None, + new_lines: bool = False, + ) -> List[List["Segment"]]: + """Set the shape of a list of lines (enclosing rectangle). + + Args: + lines (List[List[Segment]]): A list of lines. + width (int): Desired width. + height (int, optional): Desired height or None for no change. + style (Style, optional): Style of any padding added. + new_lines (bool, optional): Padded lines should include "\n". Defaults to False. + + Returns: + List[List[Segment]]: New list of lines. + """ + _height = height or len(lines) + + blank = ( + [cls(" " * width + "\n", style)] if new_lines else [cls(" " * width, style)] + ) + + adjust_line_length = cls.adjust_line_length + shaped_lines = lines[:_height] + shaped_lines[:] = [ + adjust_line_length(line, width, style=style) for line in lines + ] + if len(shaped_lines) < _height: + shaped_lines.extend([blank] * (_height - len(shaped_lines))) + return shaped_lines + + @classmethod + def align_top( + cls: Type["Segment"], + lines: List[List["Segment"]], + width: int, + height: int, + style: Style, + new_lines: bool = False, + ) -> List[List["Segment"]]: + """Aligns lines to top (adds extra lines to bottom as required). + + Args: + lines (List[List[Segment]]): A list of lines. + width (int): Desired width. + height (int, optional): Desired height or None for no change. + style (Style): Style of any padding added. + new_lines (bool, optional): Padded lines should include "\n". Defaults to False. + + Returns: + List[List[Segment]]: New list of lines. + """ + extra_lines = height - len(lines) + if not extra_lines: + return lines[:] + lines = lines[:height] + blank = cls(" " * width + "\n", style) if new_lines else cls(" " * width, style) + lines = lines + [[blank]] * extra_lines + return lines + + @classmethod + def align_bottom( + cls: Type["Segment"], + lines: List[List["Segment"]], + width: int, + height: int, + style: Style, + new_lines: bool = False, + ) -> List[List["Segment"]]: + """Aligns render to bottom (adds extra lines above as required). + + Args: + lines (List[List[Segment]]): A list of lines. + width (int): Desired width. + height (int, optional): Desired height or None for no change. + style (Style): Style of any padding added. Defaults to None. + new_lines (bool, optional): Padded lines should include "\n". Defaults to False. + + Returns: + List[List[Segment]]: New list of lines. + """ + extra_lines = height - len(lines) + if not extra_lines: + return lines[:] + lines = lines[:height] + blank = cls(" " * width + "\n", style) if new_lines else cls(" " * width, style) + lines = [[blank]] * extra_lines + lines + return lines + + @classmethod + def align_middle( + cls: Type["Segment"], + lines: List[List["Segment"]], + width: int, + height: int, + style: Style, + new_lines: bool = False, + ) -> List[List["Segment"]]: + """Aligns lines to middle (adds extra lines to above and below as required). + + Args: + lines (List[List[Segment]]): A list of lines. + width (int): Desired width. + height (int, optional): Desired height or None for no change. + style (Style): Style of any padding added. + new_lines (bool, optional): Padded lines should include "\n". Defaults to False. + + Returns: + List[List[Segment]]: New list of lines. + """ + extra_lines = height - len(lines) + if not extra_lines: + return lines[:] + lines = lines[:height] + blank = cls(" " * width + "\n", style) if new_lines else cls(" " * width, style) + top_lines = extra_lines // 2 + bottom_lines = extra_lines - top_lines + lines = [[blank]] * top_lines + lines + [[blank]] * bottom_lines + return lines + + @classmethod + def simplify(cls, segments: Iterable["Segment"]) -> Iterable["Segment"]: + """Simplify an iterable of segments by combining contiguous segments with the same style. + + Args: + segments (Iterable[Segment]): An iterable of segments. + + Returns: + Iterable[Segment]: A possibly smaller iterable of segments that will render the same way. + """ + iter_segments = iter(segments) + try: + last_segment = next(iter_segments) + except StopIteration: + return + + _Segment = Segment + for segment in iter_segments: + if last_segment.style == segment.style and not segment.control: + last_segment = _Segment( + last_segment.text + segment.text, last_segment.style + ) + else: + yield last_segment + last_segment = segment + yield last_segment + + @classmethod + def strip_links(cls, segments: Iterable["Segment"]) -> Iterable["Segment"]: + """Remove all links from an iterable of styles. + + Args: + segments (Iterable[Segment]): An iterable segments. + + Yields: + Segment: Segments with link removed. + """ + for segment in segments: + if segment.control or segment.style is None: + yield segment + else: + text, style, _control = segment + yield cls(text, style.update_link(None) if style else None) + + @classmethod + def strip_styles(cls, segments: Iterable["Segment"]) -> Iterable["Segment"]: + """Remove all styles from an iterable of segments. + + Args: + segments (Iterable[Segment]): An iterable segments. + + Yields: + Segment: Segments with styles replace with None + """ + for text, _style, control in segments: + yield cls(text, None, control) + + @classmethod + def remove_color(cls, segments: Iterable["Segment"]) -> Iterable["Segment"]: + """Remove all color from an iterable of segments. + + Args: + segments (Iterable[Segment]): An iterable segments. + + Yields: + Segment: Segments with colorless style. + """ + + cache: Dict[Style, Style] = {} + for text, style, control in segments: + if style: + colorless_style = cache.get(style) + if colorless_style is None: + colorless_style = style.without_color + cache[style] = colorless_style + yield cls(text, colorless_style, control) + else: + yield cls(text, None, control) + + @classmethod + def divide( + cls, segments: Iterable["Segment"], cuts: Iterable[int] + ) -> Iterable[List["Segment"]]: + """Divides an iterable of segments in to portions. + + Args: + cuts (Iterable[int]): Cell positions where to divide. + + Yields: + [Iterable[List[Segment]]]: An iterable of Segments in List. + """ + split_segments: List["Segment"] = [] + add_segment = split_segments.append + + iter_cuts = iter(cuts) + + while True: + cut = next(iter_cuts, -1) + if cut == -1: + return [] + if cut != 0: + break + yield [] + pos = 0 + + segments_clear = split_segments.clear + segments_copy = split_segments.copy + + _cell_len = cached_cell_len + for segment in segments: + text, _style, control = segment + while text: + end_pos = pos if control else pos + _cell_len(text) + if end_pos < cut: + add_segment(segment) + pos = end_pos + break + + if end_pos == cut: + add_segment(segment) + yield segments_copy() + segments_clear() + pos = end_pos + + cut = next(iter_cuts, -1) + if cut == -1: + if split_segments: + yield segments_copy() + return + + break + + else: + before, segment = segment.split_cells(cut - pos) + text, _style, control = segment + add_segment(before) + yield segments_copy() + segments_clear() + pos = cut + + cut = next(iter_cuts, -1) + if cut == -1: + if split_segments: + yield segments_copy() + return + + yield segments_copy() + + +class Segments: + """A simple renderable to render an iterable of segments. This class may be useful if + you want to print segments outside of a __rich_console__ method. + + Args: + segments (Iterable[Segment]): An iterable of segments. + new_lines (bool, optional): Add new lines between segments. Defaults to False. + """ + + def __init__(self, segments: Iterable[Segment], new_lines: bool = False) -> None: + self.segments = list(segments) + self.new_lines = new_lines + + def __rich_console__( + self, console: "Console", options: "ConsoleOptions" + ) -> "RenderResult": + if self.new_lines: + line = Segment.line() + for segment in self.segments: + yield segment + yield line + else: + yield from self.segments + + +class SegmentLines: + def __init__(self, lines: Iterable[List[Segment]], new_lines: bool = False) -> None: + """A simple renderable containing a number of lines of segments. May be used as an intermediate + in rendering process. + + Args: + lines (Iterable[List[Segment]]): Lists of segments forming lines. + new_lines (bool, optional): Insert new lines after each line. Defaults to False. + """ + self.lines = list(lines) + self.new_lines = new_lines + + def __rich_console__( + self, console: "Console", options: "ConsoleOptions" + ) -> "RenderResult": + if self.new_lines: + new_line = Segment.line() + for line in self.lines: + yield from line + yield new_line + else: + for line in self.lines: + yield from line + + +if __name__ == "__main__": # pragma: no cover + from pip._vendor.rich.console import Console + from pip._vendor.rich.syntax import Syntax + from pip._vendor.rich.text import Text + + code = """from rich.console import Console +console = Console() +text = Text.from_markup("Hello, [bold magenta]World[/]!") +console.print(text)""" + + text = Text.from_markup("Hello, [bold magenta]World[/]!") + + console = Console() + + console.rule("rich.Segment") + console.print( + "A Segment is the last step in the Rich render process before generating text with ANSI codes." + ) + console.print("\nConsider the following code:\n") + console.print(Syntax(code, "python", line_numbers=True)) + console.print() + console.print( + "When you call [b]print()[/b], Rich [i]renders[/i] the object in to the following:\n" + ) + fragments = list(console.render(text)) + console.print(fragments) + console.print() + console.print("The Segments are then processed to produce the following output:\n") + console.print(text) + console.print( + "\nYou will only need to know this if you are implementing your own Rich renderables." + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/terminal_theme.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/terminal_theme.py new file mode 100644 index 0000000000000000000000000000000000000000..565e9d960f8604c487e063ad9ed3f6f63027f3b4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/terminal_theme.py @@ -0,0 +1,153 @@ +from typing import List, Optional, Tuple + +from .color_triplet import ColorTriplet +from .palette import Palette + +_ColorTuple = Tuple[int, int, int] + + +class TerminalTheme: + """A color theme used when exporting console content. + + Args: + background (Tuple[int, int, int]): The background color. + foreground (Tuple[int, int, int]): The foreground (text) color. + normal (List[Tuple[int, int, int]]): A list of 8 normal intensity colors. + bright (List[Tuple[int, int, int]], optional): A list of 8 bright colors, or None + to repeat normal intensity. Defaults to None. + """ + + def __init__( + self, + background: _ColorTuple, + foreground: _ColorTuple, + normal: List[_ColorTuple], + bright: Optional[List[_ColorTuple]] = None, + ) -> None: + self.background_color = ColorTriplet(*background) + self.foreground_color = ColorTriplet(*foreground) + self.ansi_colors = Palette(normal + (bright or normal)) + + +DEFAULT_TERMINAL_THEME = TerminalTheme( + (255, 255, 255), + (0, 0, 0), + [ + (0, 0, 0), + (128, 0, 0), + (0, 128, 0), + (128, 128, 0), + (0, 0, 128), + (128, 0, 128), + (0, 128, 128), + (192, 192, 192), + ], + [ + (128, 128, 128), + (255, 0, 0), + (0, 255, 0), + (255, 255, 0), + (0, 0, 255), + (255, 0, 255), + (0, 255, 255), + (255, 255, 255), + ], +) + +MONOKAI = TerminalTheme( + (12, 12, 12), + (217, 217, 217), + [ + (26, 26, 26), + (244, 0, 95), + (152, 224, 36), + (253, 151, 31), + (157, 101, 255), + (244, 0, 95), + (88, 209, 235), + (196, 197, 181), + (98, 94, 76), + ], + [ + (244, 0, 95), + (152, 224, 36), + (224, 213, 97), + (157, 101, 255), + (244, 0, 95), + (88, 209, 235), + (246, 246, 239), + ], +) +DIMMED_MONOKAI = TerminalTheme( + (25, 25, 25), + (185, 188, 186), + [ + (58, 61, 67), + (190, 63, 72), + (135, 154, 59), + (197, 166, 53), + (79, 118, 161), + (133, 92, 141), + (87, 143, 164), + (185, 188, 186), + (136, 137, 135), + ], + [ + (251, 0, 31), + (15, 114, 47), + (196, 112, 51), + (24, 109, 227), + (251, 0, 103), + (46, 112, 109), + (253, 255, 185), + ], +) +NIGHT_OWLISH = TerminalTheme( + (255, 255, 255), + (64, 63, 83), + [ + (1, 22, 39), + (211, 66, 62), + (42, 162, 152), + (218, 170, 1), + (72, 118, 214), + (64, 63, 83), + (8, 145, 106), + (122, 129, 129), + (122, 129, 129), + ], + [ + (247, 110, 110), + (73, 208, 197), + (218, 194, 107), + (92, 167, 228), + (105, 112, 152), + (0, 201, 144), + (152, 159, 177), + ], +) + +SVG_EXPORT_THEME = TerminalTheme( + (41, 41, 41), + (197, 200, 198), + [ + (75, 78, 85), + (204, 85, 90), + (152, 168, 75), + (208, 179, 68), + (96, 138, 177), + (152, 114, 159), + (104, 160, 179), + (197, 200, 198), + (154, 155, 153), + ], + [ + (255, 38, 39), + (0, 130, 61), + (208, 132, 66), + (25, 132, 233), + (255, 44, 122), + (57, 130, 128), + (253, 253, 197), + ], +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_ops.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fe396da3fb90e621542fb98b6a48d1eebc7df138 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_ops.py @@ -0,0 +1,322 @@ +import inspect + +from torch._custom_op.impl import ( + _custom_op_with_schema, + _find_custom_op, + infer_schema, + parse_qualname, + validate_namespace, +) +from torch.library import get_ctx + +__all__ = [ + "custom_op", + "impl", + "impl_abstract", + "get_ctx", + "impl_save_for_backward", + "impl_backward", +] + + +def custom_op(qualname, func_or_schema=None): + r"""Register a new custom operator + + In PyTorch, defining an op (short for "operator") is a two step-process: + - we need to define the op (by providing an operator name and schema) + - we need to implement behavior for how the operator interacts with + various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. + + This entrypoint defines the custom operator (the first step) + you must then perform the second step by calling various + ``impl_*`` APIs. + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Arguments: + qualname (str): Should be a string that looks like + "namespace::operator_name". Operators in PyTorch need a namespace to + avoid name collisions; a given operator may only be created once. + If you are writing a Python library, we recommend the namespace to + be the name of your top-level module. + func_or_schema (Union[Callable, str]): Each PyTorch operator needs a + schema that tells PyTorch the types of the inputs/outputs. + If this is a Callable, we will automatically infer the schema from + the type annotations on the function (see examples). Otherwise, + if you don't want to use type annotations, you may provide us the + schema string. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Step 1: define the custom op. + >>> # We need to provide the API a "prototype function" + >>> # (a function that returns NotImplementedError), from which + >>> # we will infer the types of the inputs and outputs. + >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> raise NotImplementedError() + >>> + >>> # The custom op is now accessible via the torch.ops module: + >>> torch.ops.mylibrary.numpy_sin + >>> + >>> # Step 2: Register an implementation for various PyTorch subsystems + >>> + >>> # Register an implementation for CPU tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu") + >>> def numpy_sin_impl_cpu(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Register an implementation for CUDA tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda") + >>> def numpy_sin_impl_cuda(x): + >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cuda + + """ + ns, name = parse_qualname(qualname) + validate_namespace(ns) + + def inner(func): + if not inspect.isfunction(func): + raise ValueError( + f"custom_op(...)(func): Expected `func` to be a Python " + f"function, got: {type(func)}" + ) + + if func.__name__ != name: + raise ValueError( + f"custom_op(qualname='{qualname}', ...)(func): expected `func` " + f"to have name '{name}' but got '{func.__name__}'. " + f"Please either change the name of `func` or the qualname that " + f"is passed to `custom_op`" + ) + + schema = infer_schema(func) + _custom_op_with_schema(qualname, schema) + return func + + if func_or_schema is None: + return inner + if isinstance(func_or_schema, str): + _custom_op_with_schema(qualname, func_or_schema) + else: + return inner(func_or_schema) + + +def impl(qualname, *, device_types=("cpu", "cuda"), func=None): + r"""Register an implementation for a device type for this custom op. + + If the op is passed multiple Tensor inputs with different device + types, it will dispatch to the registered implementation for the highest + priority device type among those present. + The supported device types, in order of priority, are {'cuda', 'cpu'}. + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Arguments: + device_types (str or Iterable[str]): the device type(s) to register the function for. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Step 1: define the custom op. + >>> # We need to provide the API a "prototype function" + >>> # (a function that returns NotImplementedError), from which + >>> # we will infer the types of the inputs and outputs. + >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos") + >>> def numpy_cos(x: Tensor) -> Tensor: + >>> raise NotImplementedError() + >>> + >>> # The custom op is now accessible via the torch.ops module: + >>> torch.ops.mylibrary.numpy_cos + >>> + >>> # Step 2: Register an implementation for various PyTorch subsystems + >>> + >>> # Register an implementation for CPU tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu") + >>> def numpy_cos_impl_cpu(x): + >>> return torch.from_numpy(np.cos(x.numpy())) + >>> + >>> # Register an implementation for CUDA tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda") + >>> def numpy_cos_impl_cuda(x): + >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cuda + + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl(device_types, _stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def impl_abstract(qualname, *, func=None): + r"""Register an abstract implementation for this operator. + + An "abstract implementation" specifies the behavior of this operator on + Tensors that carry no data. Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + The abstract implementation has the same signature as the operator. + It is run for both FakeTensors and meta tensors. To write an abstract + implementation, assume that all Tensor inputs to the operator are + regular CPU/CUDA/Meta tensors, but they do not have storage, and + you are trying to return regular CPU/CUDA/Meta tensor(s) as output. + The abstract implementation must consist of only PyTorch operations + (and may not directly access the storage or data of any input or + intermediate Tensors). + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Examples:: + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch._custom_ops.custom_op("mylibrary::custom_linear") + >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> raise NotImplementedError() + >>> + >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear") + >>> def custom_linear_abstract(x, weight): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> + >>> return (x @ weight.t()) + bias + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero') + >>> def custom_nonzero(x: Tensor) -> Tensor: + >>> ... + >>> + >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero") + >>> def custom_nonzero_abstract(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch._custom_ops.get_ctx() + >>> nnz = ctx.create_unbacked_symint() + >>> shape = [x.dim(), nnz] + >>> result = x.new_empty(shape, dtype=torch.long) + >>> return result + >>> + >>> @torch._custom_ops.impl("mylibrary::custom_nonzero") + >>> def custom_nonzero_impl(x): + >>> x_np = to_numpy(x) + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> # unbacked symbolic ints in PyTorch must be >= 2, so we + >>> # constrain the range to at least 2 + >>> if res.shape[0] <= 1: + >>> raise RuntimeError("not supported") + >>> return torch.tensor(res, device=x.device) + + """ + import torch.library + + return torch.library.impl_abstract(qualname, func, _stacklevel=2) + + +def impl_save_for_backward(qualname, *, func=None): + r"""Register a function that tells us what to save for backward. + + Please see :func:`impl_backward` for more details. + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl_save_for_backward(_stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def impl_backward(qualname, output_differentiability=None, *, func=None): + r"""Registers a backward formula for an operator. + + In order for an operator to work with autograd, you need to register + a backward formula. There are two pieces to this: + 1. You must give us a function to specify what to save for backward. + Call this the "save for backward" function. + 2. You must give us a function that computes gradients. Call this the + "backward" function. + + Use `impl_save_for_backward` to define a "save for backward" function + that specifies what gets saved for backward. The function should accept + two arguments ``(inputs, output)`` and return the quantities to be saved + for backward. + + During runtime, when you call the operator in a forwards pass, PyTorch + will invoke the "save for backward" function with the inputs and output + of the operator. + + Use `impl_backward` to define the "backward" function. The backward + function must accept ``(ctx, saved, *grads)``: + - ``ctx`` is a context object where we may provide information + - ``saved`` is exactly what gets returned from the "save for backward" + function + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + + The backward function must return a dict that maps the name of + an input to the operator to its corresponding gradient. All inputs that + were declared to be Tensors in the operator definition must be accounted + for in the dict. The gradient may be a Tensor or None. + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl_backward(output_differentiability, _stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def _destroy(qualname): + """De-registers a custom op. For testing purposes only""" + custom_op = _find_custom_op(qualname) + custom_op._destroy() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..7c03a9f6e36af4d78dc3a3be990176eb978be3c4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -0,0 +1,418 @@ +from typing import Any, Dict, List, Optional +import torch +from collections import defaultdict +from torch import nn +import copy +from ...sparsifier.utils import fqn_to_module, module_to_fqn +import warnings + +__all__ = ['ActivationSparsifier'] + + +class ActivationSparsifier: + r""" + The Activation sparsifier class aims to sparsify/prune activations in a neural + network. The idea is to attach the sparsifier to a layer (or layers) and it + zeroes out the activations based on the mask_fn (or sparsification function) + input by the user. + The mask_fn is applied once all the inputs are aggregated and reduced i.e. + mask = mask_fn(reduce_fn(aggregate_fn(activations))) + + Note:: + The sparsification mask is computed on the input **before it goes through the attached layer**. + + Args: + model (nn.Module): + The model whose layers will be sparsified. The layers that needs to be + sparsified should be added separately using the register_layer() function + aggregate_fn (Optional, Callable): + default aggregate_fn that is used if not specified while registering the layer. + specifies how inputs should be aggregated over time. + The aggregate_fn should usually take 2 torch tensors and return the aggregated tensor. + Example + def add_agg_fn(tensor1, tensor2): return tensor1 + tensor2 + reduce_fn (Optional, Callable): + default reduce_fn that is used if not specified while registering the layer. + reduce_fn will be called on the aggregated tensor i.e. the tensor obtained after + calling agg_fn() on all inputs. + Example + def mean_reduce_fn(agg_tensor): return agg_tensor.mean(dim=0) + mask_fn (Optional, Callable): + default mask_fn that is used to create the sparsification mask using the tensor obtained after + calling the reduce_fn(). This is used by default if a custom one is passed in the + register_layer(). + Note that the mask_fn() definition should contain the sparse arguments that is passed in sparse_config + arguments. + features (Optional, list): + default selected features to sparsify. + If this is non-empty, then the mask_fn will be applied for each feature of the input. + For example, + mask = [mask_fn(reduce_fn(aggregated_fn(input[feature])) for feature in features] + feature_dim (Optional, int): + default dimension of input features. Again, features along this dim will be chosen + for sparsification. + sparse_config (Dict): + Default configuration for the mask_fn. This config will be passed + with the mask_fn() + + Example: + >>> # xdoctest: +SKIP + >>> model = SomeModel() + >>> act_sparsifier = ActivationSparsifier(...) # init activation sparsifier + >>> # Initialize aggregate_fn + >>> def agg_fn(x, y): + >>> return x + y + >>> + >>> # Initialize reduce_fn + >>> def reduce_fn(x): + >>> return torch.mean(x, dim=0) + >>> + >>> # Initialize mask_fn + >>> def mask_fn(data): + >>> return torch.eye(data.shape).to(data.device) + >>> + >>> + >>> act_sparsifier.register_layer(model.some_layer, aggregate_fn=agg_fn, reduce_fn=reduce_fn, mask_fn=mask_fn) + >>> + >>> # start training process + >>> for _ in [...]: + >>> # epoch starts + >>> # model.forward(), compute_loss() and model.backwards() + >>> # epoch ends + >>> act_sparsifier.step() + >>> # end training process + >>> sparsifier.squash_mask() + """ + def __init__(self, model: nn.Module, aggregate_fn=None, reduce_fn=None, mask_fn=None, + features=None, feature_dim=None, **sparse_config): + self.model = model + self.defaults: Dict[str, Any] = defaultdict() + self.defaults['sparse_config'] = sparse_config + + # functions + self.defaults['aggregate_fn'] = aggregate_fn + self.defaults['reduce_fn'] = reduce_fn + self.defaults['mask_fn'] = mask_fn + + # default feature and feature_dim + self.defaults['features'] = features + self.defaults['feature_dim'] = feature_dim + + self.data_groups: Dict[str, Dict] = defaultdict(dict) # contains all relevant info w.r.t each registered layer + + self.state: Dict[str, Any] = defaultdict(dict) # layer name -> mask + + @staticmethod + def _safe_rail_checks(args): + """Makes sure that some of the functions and attributes are not passed incorrectly + """ + + # if features are not None, then feature_dim must not be None + features, feature_dim = args['features'], args['feature_dim'] + if features is not None: + assert feature_dim is not None, "need feature dim to select features" + + # all the *_fns should be callable + fn_keys = ['aggregate_fn', 'reduce_fn', 'mask_fn'] + for key in fn_keys: + fn = args[key] + assert callable(fn), 'function should be callable' + + def _aggregate_hook(self, name): + """Returns hook that computes aggregate of activations passing through. + """ + + # gather some data + feature_dim = self.data_groups[name]['feature_dim'] + features = self.data_groups[name]['features'] + agg_fn = self.data_groups[name]['aggregate_fn'] + + def hook(module, input) -> None: + input_data = input[0] + + data = self.data_groups[name].get('data') # aggregated data + if features is None: + # no features associated, data should not be a list + if data is None: + data = torch.zeros_like(input_data) + self.state[name]['mask'] = torch.ones_like(input_data) + out_data = agg_fn(data, input_data) + else: + # data should be a list [aggregated over each feature only] + if data is None: + out_data = [0 for _ in range(0, len(features))] # create one incase of 1st forward + self.state[name]['mask'] = [0 for _ in range(0, len(features))] + else: + out_data = data # a list + + # compute aggregate over each feature + for feature_idx in range(len(features)): + # each feature is either a list or scalar, convert it to torch tensor + feature_tensor = torch.Tensor([features[feature_idx]]).long().to(input_data.device) + data_feature = torch.index_select(input_data, feature_dim, feature_tensor) + if data is None: + curr_data = torch.zeros_like(data_feature) + self.state[name]['mask'][feature_idx] = torch.ones_like(data_feature) + else: + curr_data = data[feature_idx] + out_data[feature_idx] = agg_fn(curr_data, data_feature) + self.data_groups[name]['data'] = out_data + return hook + + def register_layer(self, layer: nn.Module, aggregate_fn=None, reduce_fn=None, + mask_fn=None, features=None, feature_dim=None, **sparse_config): + r""" + Registers a layer for sparsification. The layer should be part of self.model. + Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn + and store the aggregated activations that is input over each step. + + Note:: + - There is no need to pass in the name of the layer as it is automatically computed as per + the fqn convention. + + - All the functions (fn) passed as argument will be called at a dim, feature level. + """ + name = module_to_fqn(self.model, layer) + assert name is not None, "layer not found in the model" # satisfy mypy + + if name in self.data_groups: # unregister layer if already present + warnings.warn("layer already attached to the sparsifier, deregistering the layer and registering with new config") + self.unregister_layer(name=name) + + local_args = copy.deepcopy(self.defaults) + update_dict = { + 'aggregate_fn': aggregate_fn, + 'reduce_fn': reduce_fn, + 'mask_fn': mask_fn, + 'features': features, + 'feature_dim': feature_dim, + 'layer': layer + } + local_args.update((arg, val) for arg, val in update_dict.items() if val is not None) + local_args['sparse_config'].update(sparse_config) + + self._safe_rail_checks(local_args) + + self.data_groups[name] = local_args + agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name)) + + self.state[name]['mask'] = None # mask will be created when model forward is called. + + # attach agg hook + self.data_groups[name]['hook'] = agg_hook + + # for serialization purposes, we know whether aggregate_hook is attached + # or sparsify_hook() + self.data_groups[name]['hook_state'] = "aggregate" # aggregate hook is attached + + def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None): + """ + Returns mask associated to the layer. + + The mask is + - a torch tensor is features for that layer is None. + - a list of torch tensors for each feature, otherwise + + Note:: + The shape of the mask is unknown until model.forward() is applied. + Hence, if get_mask() is called before model.forward(), an + error will be raised. + """ + assert name is not None or layer is not None, "Need at least name or layer obj to retrieve mask" + + if name is None: + assert layer is not None + name = module_to_fqn(self.model, layer) + assert name is not None, "layer not found in the specified model" + + if name not in self.state: + raise ValueError("Error: layer with the given name not found") + + mask = self.state[name].get('mask', None) + + if mask is None: + raise ValueError("Error: shape unknown, call layer() routine at least once to infer mask") + return mask + + def unregister_layer(self, name): + """Detaches the sparsifier from the layer + """ + + # detach any hooks attached + self.data_groups[name]['hook'].remove() + + # pop from the state dict + self.state.pop(name) + + # pop from the data groups + self.data_groups.pop(name) + + def step(self): + """Internally calls the update_mask() function for each layer + """ + with torch.no_grad(): + for name, configs in self.data_groups.items(): + data = configs['data'] + self.update_mask(name, data, configs) + + self.data_groups[name].pop('data') # reset the accumulated data + + def update_mask(self, name, data, configs): + """ + Called for each registered layer and does the following- + 1. apply reduce_fn on the aggregated activations + 2. use mask_fn to compute the sparsification mask + + Note: + the reduce_fn and mask_fn is called for each feature, dim over the data + """ + mask = self.get_mask(name) + sparse_config = configs['sparse_config'] + features = configs['features'] + reduce_fn = configs['reduce_fn'] + mask_fn = configs['mask_fn'] + if features is None: + data = reduce_fn(data) + mask.data = mask_fn(data, **sparse_config) + else: + for feature_idx in range(len(features)): + data_feature = reduce_fn(data[feature_idx]) + mask[feature_idx].data = mask_fn(data_feature, **sparse_config) + + def _sparsify_hook(self, name): + """Returns hook that applies sparsification mask to input entering the attached layer + """ + mask = self.get_mask(name) + features = self.data_groups[name]['features'] + feature_dim = self.data_groups[name]['feature_dim'] + + def hook(module, input): + input_data = input[0] + if features is None: + # apply to all the features + return input_data * mask + else: + # apply per feature, feature_dim + for feature_idx in range(0, len(features)): + feature = torch.Tensor([features[feature_idx]]).long().to(input_data.device) + sparsified = torch.index_select(input_data, feature_dim, feature) * mask[feature_idx] + input_data.index_copy_(feature_dim, feature, sparsified) + return input_data + return hook + + def squash_mask(self, attach_sparsify_hook=True, **kwargs): + """ + Unregisters aggregate hook that was applied earlier and registers sparsification hooks if + attach_sparsify_hook = True. + """ + for name, configs in self.data_groups.items(): + # unhook agg hook + configs['hook'].remove() + configs.pop('hook') + self.data_groups[name]['hook_state'] = "None" + if attach_sparsify_hook: + configs['hook'] = configs['layer'].register_forward_pre_hook(self._sparsify_hook(name)) + configs['hook_state'] = "sparsify" # signals that sparsify hook is now attached + + def _get_serializable_data_groups(self): + """Exclude hook and layer from the config keys before serializing + + TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing. + For time-being, functions are treated the same way as other attributes + """ + data_groups: Dict[str, Any] = defaultdict() + for name, config in self.data_groups.items(): + new_config = {key: value for key, value in config.items() if key not in ['hook', 'layer']} + data_groups[name] = new_config + return data_groups + + def _convert_mask(self, states_dict, sparse_coo=True): + r"""Converts the mask to sparse coo or dense depending on the `sparse_coo` argument. + If `sparse_coo=True`, then the mask is stored as sparse coo else dense tensor + """ + states = copy.deepcopy(states_dict) + for state in states.values(): + if state['mask'] is not None: + if isinstance(state['mask'], List): + for idx in range(len(state['mask'])): + if sparse_coo: + state['mask'][idx] = state['mask'][idx].to_sparse_coo() + else: + state['mask'][idx] = state['mask'][idx].to_dense() + else: + if sparse_coo: + state['mask'] = state['mask'].to_sparse_coo() + else: + state['mask'] = state['mask'].to_dense() + return states + + def state_dict(self) -> Dict[str, Any]: + r"""Returns the state of the sparsifier as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a dictionary containing all config information for each + layer + * defaults - the default config while creating the constructor + """ + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return { + 'state': state, + 'data_groups': data_groups, + 'defaults': self.defaults + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + """ + state = state_dict['state'] + data_groups, defaults = state_dict['data_groups'], state_dict['defaults'] + + self.__set_state__({'state': state, 'data_groups': data_groups, 'defaults': defaults}) + + def __get_state__(self) -> Dict[str, Any]: + + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return { + 'defaults': self.defaults, + 'state': state, + 'data_groups': data_groups, + } + + def __set_state__(self, state: Dict[str, Any]) -> None: + state['state'] = self._convert_mask(state['state'], sparse_coo=False) # convert mask to dense tensor + self.__dict__.update(state) + + # need to attach layer and hook info into the data_groups + for name, config in self.data_groups.items(): + # fetch layer + layer = fqn_to_module(self.model, name) + assert layer is not None # satisfy mypy + + # if agg_mode is True, then layer in aggregate mode + if "hook_state" in config and config['hook_state'] == "aggregate": + hook = layer.register_forward_pre_hook(self._aggregate_hook(name)) + + elif "hook_state" in config and config["hook_state"] == "sparsify": + hook = layer.register_forward_pre_hook(self._sparsify_hook(name)) + + config['layer'] = layer + config['hook'] = hook # type: ignore[possibly-undefined] + + def __repr__(self): + format_string = self.__class__.__name__ + ' (' + for name, config in self.data_groups.items(): + format_string += '\n' + format_string += '\tData Group\n' + format_string += f'\t name: {name}\n' + for key in sorted(config.keys()): + if key in ['data', 'hook', 'reduce_fn', 'mask_fn', 'aggregate_fn']: + continue + format_string += f'\t {key}: {config[key]}\n' + format_string += ')' + return format_string diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bce0b123bcd3172b88b7179c83cb89ab1fc25064 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/return_types.pyi b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/return_types.pyi new file mode 100644 index 0000000000000000000000000000000000000000..b6fbf20793518de3e942211b9bca2935f25f3711 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/return_types.pyi @@ -0,0 +1,437 @@ +# @generated from torch/_C/return_types.pyi + +from typing import ( + Any, + Callable, + ContextManager, + Iterator, + List, + Literal, + NamedTuple, + NoReturn, + Optional, + overload, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) + +from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor, SymInt +from torch.types import ( + _bool, + _device, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, + Number, +) + +class _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tuple[Tensor, Tensor]): + @property + def output(self) -> Tensor: ... + @property + def mask(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _fused_moving_avg_obs_fq_helper(Tuple[Tensor, Tensor]): + @property + def output(self) -> Tensor: ... + @property + def mask(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_det(Tuple[Tensor, Tensor, Tensor]): + @property + def result(self) -> Tensor: ... + @property + def LU(self) -> Tensor: ... + @property + def pivots(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor]): ... + n_fields: _int = 3 + n_sequeunce_fields: _int = 3 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_eigh(Tuple[Tensor, Tensor]): + @property + def eigenvalues(self) -> Tensor: ... + @property + def eigenvectors(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_slogdet(Tuple[Tensor, Tensor, Tensor, Tensor]): + @property + def sign(self) -> Tensor: ... + @property + def logabsdet(self) -> Tensor: ... + @property + def LU(self) -> Tensor: ... + @property + def pivots(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor, Tensor]): ... + n_fields: _int = 4 + n_sequeunce_fields: _int = 4 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_solve_ex(Tuple[Tensor, Tensor, Tensor, Tensor]): + @property + def result(self) -> Tensor: ... + @property + def LU(self) -> Tensor: ... + @property + def pivots(self) -> Tensor: ... + @property + def info(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor, Tensor]): ... + n_fields: _int = 4 + n_sequeunce_fields: _int = 4 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _linalg_svd(Tuple[Tensor, Tensor, Tensor]): + @property + def U(self) -> Tensor: ... + @property + def S(self) -> Tensor: ... + @property + def Vh(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor]): ... + n_fields: _int = 3 + n_sequeunce_fields: _int = 3 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _lu_with_info(Tuple[Tensor, Tensor, Tensor]): + @property + def LU(self) -> Tensor: ... + @property + def pivots(self) -> Tensor: ... + @property + def info(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor]): ... + n_fields: _int = 3 + n_sequeunce_fields: _int = 3 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _scaled_dot_product_cudnn_attention(Tuple[Tensor, Tensor, Tensor, Tensor]): + @property + def output(self) -> Tensor: ... + @property + def logsumexp(self) -> Tensor: ... + @property + def philox_seed(self) -> Tensor: ... + @property + def philox_offset(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor, Tensor]): ... + n_fields: _int = 4 + n_sequeunce_fields: _int = 4 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _scaled_dot_product_efficient_attention(Tuple[Tensor, Tensor, Tensor, Tensor]): + @property + def output(self) -> Tensor: ... + @property + def log_sumexp(self) -> Tensor: ... + @property + def philox_seed(self) -> Tensor: ... + @property + def philox_offset(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor, Tensor]): ... + n_fields: _int = 4 + n_sequeunce_fields: _int = 4 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _scaled_dot_product_flash_attention(Tuple[Tensor, Tensor, Tensor, Tensor, Union[_int, SymInt], Union[_int, SymInt], Tensor, Tensor, Tensor]): + @property + def output(self) -> Tensor: ... + @property + def logsumexp(self) -> Tensor: ... + @property + def cum_seq_q(self) -> Tensor: ... + @property + def cum_seq_k(self) -> Tensor: ... + @property + def max_q(self) -> Union[_int, SymInt]: ... + @property + def max_k(self) -> Union[_int, SymInt]: ... + @property + def philox_seed(self) -> Tensor: ... + @property + def philox_offset(self) -> Tensor: ... + @property + def debug_attn_mask(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor, Tensor, Union[_int, SymInt], Union[_int, SymInt], Tensor, Tensor, Tensor]): ... + n_fields: _int = 9 + n_sequeunce_fields: _int = 9 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _scaled_dot_product_flash_attention_for_cpu(Tuple[Tensor, Tensor]): + @property + def output(self) -> Tensor: ... + @property + def logsumexp(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class _unpack_dual(Tuple[Tensor, Tensor]): + @property + def primal(self) -> Tensor: ... + @property + def tangent(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class aminmax(Tuple[Tensor, Tensor]): + @property + def min(self) -> Tensor: ... + @property + def max(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class cummax(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class cummin(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class frexp(Tuple[Tensor, Tensor]): + @property + def mantissa(self) -> Tensor: ... + @property + def exponent(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class geqrf(Tuple[Tensor, Tensor]): + @property + def a(self) -> Tensor: ... + @property + def tau(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class histogram(Tuple[Tensor, Tensor]): + @property + def hist(self) -> Tensor: ... + @property + def bin_edges(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class histogramdd(Tuple[Tensor, Tuple[Tensor, ...]]): + @property + def hist(self) -> Tensor: ... + @property + def bin_edges(self) -> Tuple[Tensor, ...]: ... + def __new__(cls, sequence: Tuple[Tensor, Tuple[Tensor, ...]]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class kthvalue(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class lu_unpack(Tuple[Tensor, Tensor, Tensor]): + @property + def P(self) -> Tensor: ... + @property + def L(self) -> Tensor: ... + @property + def U(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor]): ... + n_fields: _int = 3 + n_sequeunce_fields: _int = 3 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class max(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class median(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class min(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class mode(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class nanmedian(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class qr(Tuple[Tensor, Tensor]): + @property + def Q(self) -> Tensor: ... + @property + def R(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class slogdet(Tuple[Tensor, Tensor]): + @property + def sign(self) -> Tensor: ... + @property + def logabsdet(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class sort(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class svd(Tuple[Tensor, Tensor, Tensor]): + @property + def U(self) -> Tensor: ... + @property + def S(self) -> Tensor: ... + @property + def V(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor, Tensor]): ... + n_fields: _int = 3 + n_sequeunce_fields: _int = 3 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class topk(Tuple[Tensor, Tensor]): + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +class triangular_solve(Tuple[Tensor, Tensor]): + @property + def solution(self) -> Tensor: ... + @property + def cloned_coefficient(self) -> Tensor: ... + def __new__(cls, sequence: Tuple[Tensor, Tensor]): ... + n_fields: _int = 2 + n_sequeunce_fields: _int = 2 + n_unnamed_fields: _int = 0 + def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing + +all_return_types: List[Type] = [] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/types.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/types.py new file mode 100644 index 0000000000000000000000000000000000000000..22c01e3bb9795ec2ca23d6149ebbbfc0ab19bb7e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/types.py @@ -0,0 +1,79 @@ +import torch +from typing import Any, List, Optional, Sequence, Tuple, Union + +import builtins + +# Convenience aliases for common composite types that we need +# to talk about in PyTorch + +_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]] +_TensorOrTensorsOrGradEdge = Union[ + torch.Tensor, Sequence[torch.Tensor], + "torch.autograd.graph.GradientEdge", + Sequence["torch.autograd.graph.GradientEdge"]] + +# In some cases, these basic types are shadowed by corresponding +# top-level values. The underscore variants let us refer to these +# types. See https://github.com/python/mypy/issues/4146 for why these +# workarounds is necessary +_int = builtins.int +_float = builtins.float +_bool = builtins.bool +_complex = builtins.complex + +_dtype = torch.dtype +_device = torch.device +_qscheme = torch.qscheme +_size = Union[torch.Size, List[_int], Tuple[_int, ...]] +_layout = torch.layout +_dispatchkey = Union[str, torch._C.DispatchKey] + +# Meta-type for "numeric" things; matches our docs +Number = Union[builtins.int, builtins.float, builtins.bool] + +# Meta-type for "device-like" things. Not to be confused with 'device' (a +# literal device object). This nomenclature is consistent with PythonArgParser. +# None means use the default device (typically CPU) +Device = Optional[Union[_device, str, _int]] +del Optional + +# Storage protocol implemented by ${Type}StorageBase classes + +class Storage: + _cdata: int + device: torch.device + dtype: torch.dtype + _torch_load_uninitialized: bool + + def __deepcopy__(self, memo) -> 'Storage': # type: ignore[empty-body] + ... + + def _new_shared(self, int) -> 'Storage': # type: ignore[empty-body] + ... + + def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None: + ... + + def element_size(self) -> int: # type: ignore[empty-body] + ... + + def is_shared(self) -> bool: # type: ignore[empty-body] + ... + + def share_memory_(self) -> 'Storage': # type: ignore[empty-body] + ... + + def nbytes(self) -> int: # type: ignore[empty-body] + ... + + def cpu(self) -> 'Storage': # type: ignore[empty-body] + ... + + def data_ptr(self) -> int: # type: ignore[empty-body] + ... + + def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage': # type: ignore[empty-body] + ... + + def _new_with_file(self, f: Any, element_size: int) -> 'Storage': # type: ignore[empty-body] + ... diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/triton-2.3.0.dist-info/INSTALLER b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/triton-2.3.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/triton-2.3.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip