diff --git a/.gitattributes b/.gitattributes index 754a5a1108104dcf55ac998358e29e8406d88011..d1633cccc4f25e42dddc72fb3ad125848b7171f4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -397,3 +397,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/ .venv/lib/python3.11/site-packages/mistral_common/data/tokenizer.model.v1 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/mistral_common/data/mistral_instruct_tokenizer_240216.model.v2 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/numpy/lib/tests/__pycache__/test_io.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a85527f929d0c6562e450f9c6f288c26959a0055 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07faac212a7a262c6ea1fffc03378750fe6bb57a142cd64278e8827f652c7424 +size 390546 diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__init__.py b/.venv/lib/python3.11/site-packages/numpy/polynomial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e7baf2c683e27fca27f81e72c348fe8d225089 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/__init__.py @@ -0,0 +1,185 @@ +""" +A sub-package for efficiently dealing with polynomials. + +Within the documentation for this sub-package, a "finite power series," +i.e., a polynomial (also referred to simply as a "series") is represented +by a 1-D numpy array of the polynomial's coefficients, ordered from lowest +order term to highest. For example, array([1,2,3]) represents +``P_0 + 2*P_1 + 3*P_2``, where P_n is the n-th order basis polynomial +applicable to the specific module in question, e.g., `polynomial` (which +"wraps" the "standard" basis) or `chebyshev`. For optimal performance, +all operations on polynomials, including evaluation at an argument, are +implemented as operations on the coefficients. Additional (module-specific) +information can be found in the docstring for the module of interest. + +This package provides *convenience classes* for each of six different kinds +of polynomials: + + ======================== ================ + **Name** **Provides** + ======================== ================ + `~polynomial.Polynomial` Power series + `~chebyshev.Chebyshev` Chebyshev series + `~legendre.Legendre` Legendre series + `~laguerre.Laguerre` Laguerre series + `~hermite.Hermite` Hermite series + `~hermite_e.HermiteE` HermiteE series + ======================== ================ + +These *convenience classes* provide a consistent interface for creating, +manipulating, and fitting data with polynomials of different bases. +The convenience classes are the preferred interface for the `~numpy.polynomial` +package, and are available from the ``numpy.polynomial`` namespace. +This eliminates the need to navigate to the corresponding submodules, e.g. +``np.polynomial.Polynomial`` or ``np.polynomial.Chebyshev`` instead of +``np.polynomial.polynomial.Polynomial`` or +``np.polynomial.chebyshev.Chebyshev``, respectively. 
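+
+For instance (a minimal illustrative sketch), a series can be constructed
+and evaluated directly from the top-level namespace::
+
+    >>> import numpy as np
+    >>> p = np.polynomial.Polynomial([1, 2, 3])   # 1 + 2*x + 3*x**2
+    >>> p(2)
+    17.0
+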
+The classes provide a more consistent and concise interface than the +type-specific functions defined in the submodules for each type of polynomial. +For example, to fit a Chebyshev polynomial with degree ``1`` to data given +by arrays ``xdata`` and ``ydata``, the +`~chebyshev.Chebyshev.fit` class method:: + + >>> from numpy.polynomial import Chebyshev + >>> c = Chebyshev.fit(xdata, ydata, deg=1) + +is preferred over the `chebyshev.chebfit` function from the +``np.polynomial.chebyshev`` module:: + + >>> from numpy.polynomial.chebyshev import chebfit + >>> c = chebfit(xdata, ydata, deg=1) + +See :doc:`routines.polynomials.classes` for more details. + +Convenience Classes +=================== + +The following lists the various constants and methods common to all of +the classes representing the various kinds of polynomials. In the following, +the term ``Poly`` represents any one of the convenience classes (e.g. +`~polynomial.Polynomial`, `~chebyshev.Chebyshev`, `~hermite.Hermite`, etc.) +while the lowercase ``p`` represents an **instance** of a polynomial class. + +Constants +--------- + +- ``Poly.domain`` -- Default domain +- ``Poly.window`` -- Default window +- ``Poly.basis_name`` -- String used to represent the basis +- ``Poly.maxpower`` -- Maximum value ``n`` such that ``p**n`` is allowed +- ``Poly.nickname`` -- String used in printing + +Creation +-------- + +Methods for creating polynomial instances. + +- ``Poly.basis(degree)`` -- Basis polynomial of given degree +- ``Poly.identity()`` -- ``p`` where ``p(x) = x`` for all ``x`` +- ``Poly.fit(x, y, deg)`` -- ``p`` of degree ``deg`` with coefficients + determined by the least-squares fit to the data ``x``, ``y`` +- ``Poly.fromroots(roots)`` -- ``p`` with specified roots +- ``p.copy()`` -- Create a copy of ``p`` + +Conversion +---------- + +Methods for converting a polynomial instance of one kind to another. + +- ``p.cast(Poly)`` -- Convert ``p`` to instance of kind ``Poly`` +- ``p.convert(Poly)`` -- Convert ``p`` to instance of kind ``Poly`` or map + between ``domain`` and ``window`` + +Calculus +-------- +- ``p.deriv()`` -- Take the derivative of ``p`` +- ``p.integ()`` -- Integrate ``p`` + +Validation +---------- +- ``Poly.has_samecoef(p1, p2)`` -- Check if coefficients match +- ``Poly.has_samedomain(p1, p2)`` -- Check if domains match +- ``Poly.has_sametype(p1, p2)`` -- Check if types match +- ``Poly.has_samewindow(p1, p2)`` -- Check if windows match + +Misc +---- +- ``p.linspace()`` -- Return ``x, p(x)`` at equally-spaced points in ``domain`` +- ``p.mapparms()`` -- Return the parameters for the linear mapping between + ``domain`` and ``window``. +- ``p.roots()`` -- Return the roots of `p`. +- ``p.trim()`` -- Remove trailing coefficients. +- ``p.cutdeg(degree)`` -- Truncate p to given degree +- ``p.truncate(size)`` -- Truncate p to given size + +""" +from .polynomial import Polynomial +from .chebyshev import Chebyshev +from .legendre import Legendre +from .hermite import Hermite +from .hermite_e import HermiteE +from .laguerre import Laguerre + +__all__ = [ + "set_default_printstyle", + "polynomial", "Polynomial", + "chebyshev", "Chebyshev", + "legendre", "Legendre", + "hermite", "Hermite", + "hermite_e", "HermiteE", + "laguerre", "Laguerre", +] + + +def set_default_printstyle(style): + """ + Set the default format for the string representation of polynomials. + + Values for ``style`` must be valid inputs to ``__format__``, i.e. 'ascii' + or 'unicode'. 
+ + Parameters + ---------- + style : str + Format string for default printing style. Must be either 'ascii' or + 'unicode'. + + Notes + ----- + The default format depends on the platform: 'unicode' is used on + Unix-based systems and 'ascii' on Windows. This determination is based on + default font support for the unicode superscript and subscript ranges. + + Examples + -------- + >>> p = np.polynomial.Polynomial([1, 2, 3]) + >>> c = np.polynomial.Chebyshev([1, 2, 3]) + >>> np.polynomial.set_default_printstyle('unicode') + >>> print(p) + 1.0 + 2.0·x + 3.0·x² + >>> print(c) + 1.0 + 2.0·T₁(x) + 3.0·T₂(x) + >>> np.polynomial.set_default_printstyle('ascii') + >>> print(p) + 1.0 + 2.0 x + 3.0 x**2 + >>> print(c) + 1.0 + 2.0 T_1(x) + 3.0 T_2(x) + >>> # Formatting supersedes all class/package-level defaults + >>> print(f"{p:unicode}") + 1.0 + 2.0·x + 3.0·x² + """ + if style not in ('unicode', 'ascii'): + raise ValueError( + f"Unsupported format string '{style}'. Valid options are 'ascii' " + f"and 'unicode'" + ) + _use_unicode = True + if style == 'ascii': + _use_unicode = False + from ._polybase import ABCPolyBase + ABCPolyBase._use_unicode = _use_unicode + + +from numpy._pytesttester import PytestTester +test = PytestTester(__name__) +del PytestTester diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2f3fbe868cff83069ebffc27916977c03ef6126 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/_polybase.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/_polybase.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85c2fefd782041f7f6c5f77e57c1abb8333c174c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/_polybase.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/chebyshev.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/chebyshev.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bc5523cc2495b09852b49e5bb8371adc2ef1900 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/chebyshev.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eee601a605e9a6016aa7872943b199458dd8b64c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite_e.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite_e.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..104446c8d3a9c639fdaffb7bf8057b6b5247daa8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite_e.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/laguerre.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/laguerre.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faeddcdc242f9d5e7067538bf25095835e50df8d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/laguerre.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/legendre.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/legendre.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2fdf0e455a78ea30abc9d86b57fe999c932bac2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/legendre.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polynomial.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polynomial.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57a285b3d2aa13a545cf26d77113a4f6ab039377 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polynomial.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polyutils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polyutils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19e5bd57a2dd589d2957a18e61464482eef6c1da Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polyutils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/setup.cpython-311.pyc b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/setup.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..257fed74c3036316f3855adf5d54b4276d353239 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/setup.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.py b/.venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.py new file mode 100644 index 0000000000000000000000000000000000000000..9730574cf22e22823aaa0c77be9e630425cb2f79 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.py @@ -0,0 +1,1206 @@ +""" +Abstract base class for the various polynomial classes. + +The ABCPolyBase class provides the methods needed to implement the common API +for the various polynomial classes. It operates as a mixin, but uses the +abc module from the stdlib, hence it is only available for Python >= 2.6. + +""" +import os +import abc +import numbers + +import numpy as np +from . import polyutils as pu + +__all__ = ['ABCPolyBase'] + +class ABCPolyBase(abc.ABC): + """An abstract base class for immutable series classes. + + ABCPolyBase provides the standard Python numerical methods + '+', '-', '*', '//', '%', 'divmod', '**', and '()' along with the + methods listed below. + + .. versionadded:: 1.9.0 + + Parameters + ---------- + coef : array_like + Series coefficients in order of increasing degree, i.e., + ``(1, 2, 3)`` gives ``1*P_0(x) + 2*P_1(x) + 3*P_2(x)``, where + ``P_i`` is the basis polynomial of degree ``i``. + domain : (2,) array_like, optional + Domain to use. The interval ``[domain[0], domain[1]]`` is mapped + to the interval ``[window[0], window[1]]`` by shifting and scaling. + The default value is the derived class domain.
+ window : (2,) array_like, optional + Window, see domain for its use. The default value is the + derived class window. + symbol : str, optional + Symbol used to represent the independent variable in string + representations of the polynomial expression, e.g. for printing. + The symbol must be a valid Python identifier. Default value is 'x'. + + .. versionadded:: 1.24 + + Attributes + ---------- + coef : (N,) ndarray + Series coefficients in order of increasing degree. + domain : (2,) ndarray + Domain that is mapped to window. + window : (2,) ndarray + Window that domain is mapped to. + symbol : str + Symbol representing the independent variable. + + Class Attributes + ---------------- + maxpower : int + Maximum power allowed, i.e., the largest number ``n`` such that + ``p(x)**n`` is allowed. This is to limit runaway polynomial size. + domain : (2,) ndarray + Default domain of the class. + window : (2,) ndarray + Default window of the class. + + """ + + # Not hashable + __hash__ = None + + # Opt out of numpy ufuncs and Python ops with ndarray subclasses. + __array_ufunc__ = None + + # Limit runaway size. T_n^m has degree n*m + maxpower = 100 + + # Unicode character mappings for improved __str__ + _superscript_mapping = str.maketrans({ + "0": "⁰", + "1": "¹", + "2": "²", + "3": "³", + "4": "⁴", + "5": "⁵", + "6": "⁶", + "7": "⁷", + "8": "⁸", + "9": "⁹" + }) + _subscript_mapping = str.maketrans({ + "0": "₀", + "1": "₁", + "2": "₂", + "3": "₃", + "4": "₄", + "5": "₅", + "6": "₆", + "7": "₇", + "8": "₈", + "9": "₉" + }) + # Some fonts don't support full unicode character ranges necessary for + # the full set of superscripts and subscripts, including common/default + # fonts in Windows shells/terminals. Therefore, default to ascii-only + # printing on windows. + _use_unicode = not os.name == 'nt' + + @property + def symbol(self): + return self._symbol + + @property + @abc.abstractmethod + def domain(self): + pass + + @property + @abc.abstractmethod + def window(self): + pass + + @property + @abc.abstractmethod + def basis_name(self): + pass + + @staticmethod + @abc.abstractmethod + def _add(c1, c2): + pass + + @staticmethod + @abc.abstractmethod + def _sub(c1, c2): + pass + + @staticmethod + @abc.abstractmethod + def _mul(c1, c2): + pass + + @staticmethod + @abc.abstractmethod + def _div(c1, c2): + pass + + @staticmethod + @abc.abstractmethod + def _pow(c, pow, maxpower=None): + pass + + @staticmethod + @abc.abstractmethod + def _val(x, c): + pass + + @staticmethod + @abc.abstractmethod + def _int(c, m, k, lbnd, scl): + pass + + @staticmethod + @abc.abstractmethod + def _der(c, m, scl): + pass + + @staticmethod + @abc.abstractmethod + def _fit(x, y, deg, rcond, full): + pass + + @staticmethod + @abc.abstractmethod + def _line(off, scl): + pass + + @staticmethod + @abc.abstractmethod + def _roots(c): + pass + + @staticmethod + @abc.abstractmethod + def _fromroots(r): + pass + + def has_samecoef(self, other): + """Check if coefficients match. + + .. versionadded:: 1.6.0 + + Parameters + ---------- + other : class instance + The other class must have the ``coef`` attribute. + + Returns + ------- + bool : boolean + True if the coefficients are the same, False otherwise. + + """ + if len(self.coef) != len(other.coef): + return False + elif not np.all(self.coef == other.coef): + return False + else: + return True + + def has_samedomain(self, other): + """Check if domains match. + + .. 
versionadded:: 1.6.0 + + Parameters + ---------- + other : class instance + The other class must have the ``domain`` attribute. + + Returns + ------- + bool : boolean + True if the domains are the same, False otherwise. + + """ + return np.all(self.domain == other.domain) + + def has_samewindow(self, other): + """Check if windows match. + + .. versionadded:: 1.6.0 + + Parameters + ---------- + other : class instance + The other class must have the ``window`` attribute. + + Returns + ------- + bool : boolean + True if the windows are the same, False otherwise. + + """ + return np.all(self.window == other.window) + + def has_sametype(self, other): + """Check if types match. + + .. versionadded:: 1.7.0 + + Parameters + ---------- + other : object + Class instance. + + Returns + ------- + bool : boolean + True if other is the same class as self, False otherwise. + + """ + return isinstance(other, self.__class__) + + def _get_coefficients(self, other): + """Interpret other as polynomial coefficients. + + The `other` argument is checked to see if it is of the same + class as self with identical domain and window. If so, + return its coefficients, otherwise return `other`. + + .. versionadded:: 1.9.0 + + Parameters + ---------- + other : anything + Object to be checked. + + Returns + ------- + coef + The coefficients of `other` if it is a compatible instance + of ABCPolyBase, otherwise `other`. + + Raises + ------ + TypeError + When `other` is an incompatible instance of ABCPolyBase. + ValueError + When `other` is a compatible instance but uses a different + symbol. + + """ + if isinstance(other, ABCPolyBase): + if not isinstance(other, self.__class__): + raise TypeError("Polynomial types differ") + elif not np.all(self.domain == other.domain): + raise TypeError("Domains differ") + elif not np.all(self.window == other.window): + raise TypeError("Windows differ") + elif self.symbol != other.symbol: + raise ValueError("Polynomial symbols differ") + return other.coef + return other + + def __init__(self, coef, domain=None, window=None, symbol='x'): + [coef] = pu.as_series([coef], trim=False) + self.coef = coef + + if domain is not None: + [domain] = pu.as_series([domain], trim=False) + if len(domain) != 2: + raise ValueError("Domain has wrong number of elements.") + self.domain = domain + + if window is not None: + [window] = pu.as_series([window], trim=False) + if len(window) != 2: + raise ValueError("Window has wrong number of elements.") + self.window = window + + # Validation for symbol + try: + if not symbol.isidentifier(): + raise ValueError( + "Symbol string must be a valid Python identifier" + ) + # If a user passes in something other than a string, the above + # results in an AttributeError. Catch this and raise a more + # informative exception + except AttributeError: + raise TypeError("Symbol must be a non-empty string") + + self._symbol = symbol + + def __repr__(self): + coef = repr(self.coef)[6:-1] + domain = repr(self.domain)[6:-1] + window = repr(self.window)[6:-1] + name = self.__class__.__name__ + return (f"{name}({coef}, domain={domain}, window={window}, " + f"symbol='{self.symbol}')") + + def __format__(self, fmt_str): + if fmt_str == '': + return self.__str__() + if fmt_str not in ('ascii', 'unicode'): + raise ValueError( + f"Unsupported format string '{fmt_str}' passed to " + f"{self.__class__}.__format__.
Valid options are " + f"'ascii' and 'unicode'" + ) + if fmt_str == 'ascii': + return self._generate_string(self._str_term_ascii) + return self._generate_string(self._str_term_unicode) + + def __str__(self): + if self._use_unicode: + return self._generate_string(self._str_term_unicode) + return self._generate_string(self._str_term_ascii) + + def _generate_string(self, term_method): + """ + Generate the full string representation of the polynomial, using + ``term_method`` to generate each polynomial term. + """ + # Get configuration for line breaks + linewidth = np.get_printoptions().get('linewidth', 75) + if linewidth < 1: + linewidth = 1 + out = pu.format_float(self.coef[0]) + for i, coef in enumerate(self.coef[1:]): + out += " " + power = str(i + 1) + # Polynomial coefficient + # The coefficient array can be an object array with elements that + # will raise a TypeError with >= 0 (e.g. strings or Python + # complex). In this case, represent the coefficient as-is. + try: + if coef >= 0: + next_term = f"+ " + pu.format_float(coef, parens=True) + else: + next_term = f"- " + pu.format_float(-coef, parens=True) + except TypeError: + next_term = f"+ {coef}" + # Polynomial term + next_term += term_method(power, self.symbol) + # Length of the current line with next term added + line_len = len(out.split('\n')[-1]) + len(next_term) + # If not the last term in the polynomial, it will be two + # characters longer due to the +/- with the next term + if i < len(self.coef[1:]) - 1: + line_len += 2 + # Handle linebreaking + if line_len >= linewidth: + next_term = next_term.replace(" ", "\n", 1) + out += next_term + return out + + @classmethod + def _str_term_unicode(cls, i, arg_str): + """ + String representation of single polynomial term using unicode + characters for superscripts and subscripts. + """ + if cls.basis_name is None: + raise NotImplementedError( + "Subclasses must define either a basis_name, or override " + "_str_term_unicode(cls, i, arg_str)" + ) + return (f"·{cls.basis_name}{i.translate(cls._subscript_mapping)}" + f"({arg_str})") + + @classmethod + def _str_term_ascii(cls, i, arg_str): + """ + String representation of a single polynomial term using ** and _ to + represent superscripts and subscripts, respectively. 
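+
+        For example, with a Chebyshev-like basis (``basis_name = "T"``),
+        term ``2`` in variable ``x`` renders as `` T_2(x)``.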
+ """ + if cls.basis_name is None: + raise NotImplementedError( + "Subclasses must define either a basis_name, or override " + "_str_term_ascii(cls, i, arg_str)" + ) + return f" {cls.basis_name}_{i}({arg_str})" + + @classmethod + def _repr_latex_term(cls, i, arg_str, needs_parens): + if cls.basis_name is None: + raise NotImplementedError( + "Subclasses must define either a basis name, or override " + "_repr_latex_term(i, arg_str, needs_parens)") + # since we always add parens, we don't care if the expression needs them + return f"{{{cls.basis_name}}}_{{{i}}}({arg_str})" + + @staticmethod + def _repr_latex_scalar(x, parens=False): + # TODO: we're stuck with disabling math formatting until we handle + # exponents in this function + return r'\text{{{}}}'.format(pu.format_float(x, parens=parens)) + + def _repr_latex_(self): + # get the scaled argument string to the basis functions + off, scale = self.mapparms() + if off == 0 and scale == 1: + term = self.symbol + needs_parens = False + elif scale == 1: + term = f"{self._repr_latex_scalar(off)} + {self.symbol}" + needs_parens = True + elif off == 0: + term = f"{self._repr_latex_scalar(scale)}{self.symbol}" + needs_parens = True + else: + term = ( + f"{self._repr_latex_scalar(off)} + " + f"{self._repr_latex_scalar(scale)}{self.symbol}" + ) + needs_parens = True + + mute = r"\color{{LightGray}}{{{}}}".format + + parts = [] + for i, c in enumerate(self.coef): + # prevent duplication of + and - signs + if i == 0: + coef_str = f"{self._repr_latex_scalar(c)}" + elif not isinstance(c, numbers.Real): + coef_str = f" + ({self._repr_latex_scalar(c)})" + elif not np.signbit(c): + coef_str = f" + {self._repr_latex_scalar(c, parens=True)}" + else: + coef_str = f" - {self._repr_latex_scalar(-c, parens=True)}" + + # produce the string for the term + term_str = self._repr_latex_term(i, term, needs_parens) + if term_str == '1': + part = coef_str + else: + part = rf"{coef_str}\,{term_str}" + + if c == 0: + part = mute(part) + + parts.append(part) + + if parts: + body = ''.join(parts) + else: + # in case somehow there are no coefficients at all + body = '0' + + return rf"${self.symbol} \mapsto {body}$" + + + + # Pickle and copy + + def __getstate__(self): + ret = self.__dict__.copy() + ret['coef'] = self.coef.copy() + ret['domain'] = self.domain.copy() + ret['window'] = self.window.copy() + ret['symbol'] = self.symbol + return ret + + def __setstate__(self, dict): + self.__dict__ = dict + + # Call + + def __call__(self, arg): + off, scl = pu.mapparms(self.domain, self.window) + arg = off + scl*arg + return self._val(arg, self.coef) + + def __iter__(self): + return iter(self.coef) + + def __len__(self): + return len(self.coef) + + # Numeric properties. 
+ + def __neg__(self): + return self.__class__( + -self.coef, self.domain, self.window, self.symbol + ) + + def __pos__(self): + return self + + def __add__(self, other): + othercoef = self._get_coefficients(other) + try: + coef = self._add(self.coef, othercoef) + except Exception: + return NotImplemented + return self.__class__(coef, self.domain, self.window, self.symbol) + + def __sub__(self, other): + othercoef = self._get_coefficients(other) + try: + coef = self._sub(self.coef, othercoef) + except Exception: + return NotImplemented + return self.__class__(coef, self.domain, self.window, self.symbol) + + def __mul__(self, other): + othercoef = self._get_coefficients(other) + try: + coef = self._mul(self.coef, othercoef) + except Exception: + return NotImplemented + return self.__class__(coef, self.domain, self.window, self.symbol) + + def __truediv__(self, other): + # there is no true divide if the rhs is not a Number, although it + # could return the first n elements of an infinite series. + # It is hard to see where n would come from, though. + if not isinstance(other, numbers.Number) or isinstance(other, bool): + raise TypeError( + f"unsupported types for true division: " + f"'{type(self)}', '{type(other)}'" + ) + return self.__floordiv__(other) + + def __floordiv__(self, other): + res = self.__divmod__(other) + if res is NotImplemented: + return res + return res[0] + + def __mod__(self, other): + res = self.__divmod__(other) + if res is NotImplemented: + return res + return res[1] + + def __divmod__(self, other): + othercoef = self._get_coefficients(other) + try: + quo, rem = self._div(self.coef, othercoef) + except ZeroDivisionError: + raise + except Exception: + return NotImplemented + quo = self.__class__(quo, self.domain, self.window, self.symbol) + rem = self.__class__(rem, self.domain, self.window, self.symbol) + return quo, rem + + def __pow__(self, other): + coef = self._pow(self.coef, other, maxpower=self.maxpower) + res = self.__class__(coef, self.domain, self.window, self.symbol) + return res + + def __radd__(self, other): + try: + coef = self._add(other, self.coef) + except Exception: + return NotImplemented + return self.__class__(coef, self.domain, self.window, self.symbol) + + def __rsub__(self, other): + try: + coef = self._sub(other, self.coef) + except Exception: + return NotImplemented + return self.__class__(coef, self.domain, self.window, self.symbol) + + def __rmul__(self, other): + try: + coef = self._mul(other, self.coef) + except Exception: + return NotImplemented + return self.__class__(coef, self.domain, self.window, self.symbol) + + def __rdiv__(self, other): + # set to __floordiv__ /. + return self.__rfloordiv__(other) + + def __rtruediv__(self, other): + # An instance of ABCPolyBase is not considered a + # Number. 
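+        # so dividing a scalar by a polynomial has no defined result;
+        # returning NotImplemented makes Python raise the usual TypeError.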
+ return NotImplemented + + def __rfloordiv__(self, other): + res = self.__rdivmod__(other) + if res is NotImplemented: + return res + return res[0] + + def __rmod__(self, other): + res = self.__rdivmod__(other) + if res is NotImplemented: + return res + return res[1] + + def __rdivmod__(self, other): + try: + quo, rem = self._div(other, self.coef) + except ZeroDivisionError: + raise + except Exception: + return NotImplemented + quo = self.__class__(quo, self.domain, self.window, self.symbol) + rem = self.__class__(rem, self.domain, self.window, self.symbol) + return quo, rem + + def __eq__(self, other): + res = (isinstance(other, self.__class__) and + np.all(self.domain == other.domain) and + np.all(self.window == other.window) and + (self.coef.shape == other.coef.shape) and + np.all(self.coef == other.coef) and + (self.symbol == other.symbol)) + return res + + def __ne__(self, other): + return not self.__eq__(other) + + # + # Extra methods. + # + + def copy(self): + """Return a copy. + + Returns + ------- + new_series : series + Copy of self. + + """ + return self.__class__(self.coef, self.domain, self.window, self.symbol) + + def degree(self): + """The degree of the series. + + .. versionadded:: 1.5.0 + + Returns + ------- + degree : int + Degree of the series, one less than the number of coefficients. + + Examples + -------- + + Create a polynomial object for ``1 + 7*x + 4*x**2``: + + >>> poly = np.polynomial.Polynomial([1, 7, 4]) + >>> print(poly) + 1.0 + 7.0·x + 4.0·x² + >>> poly.degree() + 2 + + Note that this method does not check for non-zero coefficients. + You must trim the polynomial to remove any trailing zeroes: + + >>> poly = np.polynomial.Polynomial([1, 7, 0]) + >>> print(poly) + 1.0 + 7.0·x + 0.0·x² + >>> poly.degree() + 2 + >>> poly.trim().degree() + 1 + + """ + return len(self) - 1 + + def cutdeg(self, deg): + """Truncate series to the given degree. + + Reduce the degree of the series to `deg` by discarding the + high order terms. If `deg` is greater than the current degree a + copy of the current series is returned. This can be useful in least + squares where the coefficients of the high degree terms may be very + small. + + .. versionadded:: 1.5.0 + + Parameters + ---------- + deg : non-negative int + The series is reduced to degree `deg` by discarding the high + order terms. The value of `deg` must be a non-negative integer. + + Returns + ------- + new_series : series + New instance of series with reduced degree. + + """ + return self.truncate(deg + 1) + + def trim(self, tol=0): + """Remove trailing coefficients. + + Remove trailing coefficients until a coefficient is reached whose + absolute value is greater than `tol` or the beginning of the series is + reached. If all the coefficients would be removed the series is set + to ``[0]``. A new series instance is returned with the new + coefficients. The current instance remains unchanged. + + Parameters + ---------- + tol : non-negative number + All trailing coefficients less than `tol` will be removed. + + Returns + ------- + new_series : series + New instance of series with trimmed coefficients. + + """ + coef = pu.trimcoef(self.coef, tol) + return self.__class__(coef, self.domain, self.window, self.symbol) + + def truncate(self, size): + """Truncate series to length `size`. + + Reduce the series to length `size` by discarding the high + degree terms. The value of `size` must be a positive integer. This + can be useful in least squares where the coefficients of the + high degree terms may be very small.
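+
+        For example, ``p.truncate(3)`` keeps only the first three
+        coefficients of ``p``, i.e. the terms of degree 0 through 2.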
+ + Parameters + ---------- + size : positive int + The series is reduced to length `size` by discarding the high + degree terms. The value of `size` must be a positive integer. + + Returns + ------- + new_series : series + New instance of series with truncated coefficients. + + """ + isize = int(size) + if isize != size or isize < 1: + raise ValueError("size must be a positive integer") + if isize >= len(self.coef): + coef = self.coef + else: + coef = self.coef[:isize] + return self.__class__(coef, self.domain, self.window, self.symbol) + + def convert(self, domain=None, kind=None, window=None): + """Convert series to a different kind and/or domain and/or window. + + Parameters + ---------- + domain : array_like, optional + The domain of the converted series. If the value is None, + the default domain of `kind` is used. + kind : class, optional + The polynomial series type class to which the current instance + should be converted. If kind is None, then the class of the + current instance is used. + window : array_like, optional + The window of the converted series. If the value is None, + the default window of `kind` is used. + + Returns + ------- + new_series : series + The returned class can be of a different type than the current + instance and/or have a different domain and/or different + window. + + Notes + ----- + Conversion between domains and class types can result in + numerically ill defined series. + + """ + if kind is None: + kind = self.__class__ + if domain is None: + domain = kind.domain + if window is None: + window = kind.window + return self(kind.identity(domain, window=window, symbol=self.symbol)) + + def mapparms(self): + """Return the mapping parameters. + + The returned values define a linear map ``off + scl*x`` that is + applied to the input arguments before the series is evaluated. The + map depends on the ``domain`` and ``window``; if the current + ``domain`` is equal to the ``window`` the resulting map is the + identity. If the coefficients of the series instance are to be + used by themselves outside this class, then the linear function + must be substituted for the ``x`` in the standard representation of + the base polynomials. + + Returns + ------- + off, scl : float or complex + The mapping function is defined by ``off + scl*x``. + + Notes + ----- + If the current domain is the interval ``[l1, r1]`` and the window + is ``[l2, r2]``, then the linear mapping function ``L`` is + defined by the equations:: + + L(l1) = l2 + L(r1) = r2 + + """ + return pu.mapparms(self.domain, self.window) + + def integ(self, m=1, k=[], lbnd=None): + """Integrate. + + Return a series instance that is the definite integral of the + current series. + + Parameters + ---------- + m : non-negative int + The number of integrations to perform. + k : array_like + Integration constants. The first constant is applied to the + first integration, the second to the second, and so on. The + list of values must be less than or equal to `m` in length and any + missing values are set to zero. + lbnd : Scalar + The lower bound of the definite integral. + + Returns + ------- + new_series : series + A new series representing the integral. The domain is the same + as the domain of the integrated series. + + """ + off, scl = self.mapparms() + if lbnd is None: + lbnd = 0 + else: + lbnd = off + scl*lbnd + coef = self._int(self.coef, m, k, lbnd, 1./scl) + return self.__class__(coef, self.domain, self.window, self.symbol) + + def deriv(self, m=1): + """Differentiate.
+ + Return a series instance that is the derivative of the current + series. + + Parameters + ---------- + m : non-negative int + Find the derivative of order `m`. + + Returns + ------- + new_series : series + A new series representing the derivative. The domain is the same + as the domain of the differentiated series. + + """ + off, scl = self.mapparms() + coef = self._der(self.coef, m, scl) + return self.__class__(coef, self.domain, self.window, self.symbol) + + def roots(self): + """Return the roots of the series polynomial. + + Compute the roots for the series. Note that the accuracy of the + roots decreases the further outside the `domain` they lie. + + Returns + ------- + roots : ndarray + Array containing the roots of the series. + + """ + roots = self._roots(self.coef) + return pu.mapdomain(roots, self.window, self.domain) + + def linspace(self, n=100, domain=None): + """Return x, y values at equally spaced points in domain. + + Returns the x, y values at `n` linearly spaced points across the + domain. Here y is the value of the polynomial at the points x. By + default the domain is the same as that of the series instance. + This method is intended mostly as a plotting aid. + + .. versionadded:: 1.5.0 + + Parameters + ---------- + n : int, optional + Number of point pairs to return. The default value is 100. + domain : {None, array_like}, optional + If not None, the specified domain is used instead of that of + the calling instance. It should be of the form ``[beg,end]``. + The default is None, in which case the domain of the instance + is used. + + Returns + ------- + x, y : ndarray + x is equal to linspace(self.domain[0], self.domain[1], n) and + y is the series evaluated at each element of x. + + """ + if domain is None: + domain = self.domain + x = np.linspace(domain[0], domain[1], n) + y = self(x) + return x, y + + @classmethod + def fit(cls, x, y, deg, domain=None, rcond=None, full=False, w=None, + window=None, symbol='x'): + """Least squares fit to data. + + Return a series instance that is the least squares fit to the data + `y` sampled at `x`. The domain of the returned instance can be + specified and this will often result in a superior fit with less + chance of ill conditioning. + + Parameters + ---------- + x : array_like, shape (M,) + x-coordinates of the M sample points ``(x[i], y[i])``. + y : array_like, shape (M,) + y-coordinates of the M sample points ``(x[i], y[i])``. + deg : int or 1-D array_like + Degree(s) of the fitting polynomials. If `deg` is a single integer + all terms up to and including the `deg`'th term are included in the + fit. For NumPy versions >= 1.11.0 a list of integers specifying the + degrees of the terms to include may be used instead. + domain : {None, [beg, end], []}, optional + Domain to use for the returned series. If ``None``, + then a minimal domain that covers the points `x` is chosen. If + ``[]`` the class domain is used. The default value was the + class domain in NumPy 1.4 and ``None`` in later versions. + The ``[]`` option was added in numpy 1.5.0. + rcond : float, optional + Relative condition number of the fit. Singular values smaller + than this relative to the largest singular value will be + ignored. The default value is len(x)*eps, where eps is the + relative precision of the float type, about 2e-16 in most + cases. + full : bool, optional + Switch determining nature of return value. When it is False + (the default) just the coefficients are returned, when True + diagnostic information from the singular value decomposition is + also returned.
+ w : array_like, shape (M,), optional + Weights. If not None, the weight ``w[i]`` applies to the unsquared + residual ``y[i] - y_hat[i]`` at ``x[i]``. Ideally the weights are + chosen so that the errors of the products ``w[i]*y[i]`` all have + the same variance. When using inverse-variance weighting, use + ``w[i] = 1/sigma(y[i])``. The default value is None. + + .. versionadded:: 1.5.0 + window : {[beg, end]}, optional + Window to use for the returned series. The default + value is the default class window. + + .. versionadded:: 1.6.0 + symbol : str, optional + Symbol representing the independent variable. Default is 'x'. + + Returns + ------- + new_series : series + A series that represents the least squares fit to the data and + has the domain and window specified in the call. If the + coefficients for the unscaled and unshifted basis polynomials are + of interest, do ``new_series.convert().coef``. + + [resid, rank, sv, rcond] : list + These values are only returned if ``full == True`` + + - resid -- sum of squared residuals of the least squares fit + - rank -- the numerical rank of the scaled Vandermonde matrix + - sv -- singular values of the scaled Vandermonde matrix + - rcond -- value of `rcond`. + + For more details, see `linalg.lstsq`. + + """ + if domain is None: + domain = pu.getdomain(x) + elif type(domain) is list and len(domain) == 0: + domain = cls.domain + + if window is None: + window = cls.window + + xnew = pu.mapdomain(x, domain, window) + res = cls._fit(xnew, y, deg, w=w, rcond=rcond, full=full) + if full: + [coef, status] = res + return ( + cls(coef, domain=domain, window=window, symbol=symbol), status + ) + else: + coef = res + return cls(coef, domain=domain, window=window, symbol=symbol) + + @classmethod + def fromroots(cls, roots, domain=[], window=None, symbol='x'): + """Return series instance that has the specified roots. + + Returns a series representing the product + ``(x - r[0])*(x - r[1])*...*(x - r[n-1])``, where ``r`` is a + list of roots. + + Parameters + ---------- + roots : array_like + List of roots. + domain : {[], None, array_like}, optional + Domain for the resulting series. If None the domain is the + interval from the smallest root to the largest. If [] the + domain is the class domain. The default is []. + window : {None, array_like}, optional + Window for the returned series. If None the class window is + used. The default is None. + symbol : str, optional + Symbol representing the independent variable. Default is 'x'. + + Returns + ------- + new_series : series + Series with the specified roots. + + """ + [roots] = pu.as_series([roots], trim=False) + if domain is None: + domain = pu.getdomain(roots) + elif type(domain) is list and len(domain) == 0: + domain = cls.domain + + if window is None: + window = cls.window + + deg = len(roots) + off, scl = pu.mapparms(domain, window) + rnew = off + scl*roots + coef = cls._fromroots(rnew) / scl**deg + return cls(coef, domain=domain, window=window, symbol=symbol) + + @classmethod + def identity(cls, domain=None, window=None, symbol='x'): + """Identity function. + + If ``p`` is the returned series, then ``p(x) == x`` for all + values of x. + + Parameters + ---------- + domain : {None, array_like}, optional + If given, the array must be of the form ``[beg, end]``, where + ``beg`` and ``end`` are the endpoints of the domain. If None is + given then the class domain is used. The default is None.
+ window : {None, array_like}, optional + If given, the resulting array must be of the form + ``[beg, end]``, where ``beg`` and ``end`` are the endpoints of + the window. If None is given then the class window is used. The + default is None. + symbol : str, optional + Symbol representing the independent variable. Default is 'x'. + + Returns + ------- + new_series : series + Series representing the identity. + + """ + if domain is None: + domain = cls.domain + if window is None: + window = cls.window + off, scl = pu.mapparms(window, domain) + coef = cls._line(off, scl) + return cls(coef, domain, window, symbol) + + @classmethod + def basis(cls, deg, domain=None, window=None, symbol='x'): + """Series basis polynomial of degree `deg`. + + Returns the series representing the basis polynomial of degree `deg`. + + .. versionadded:: 1.7.0 + + Parameters + ---------- + deg : int + Degree of the basis polynomial for the series. Must be >= 0. + domain : {None, array_like}, optional + If given, the array must be of the form ``[beg, end]``, where + ``beg`` and ``end`` are the endpoints of the domain. If None is + given then the class domain is used. The default is None. + window : {None, array_like}, optional + If given, the resulting array must be of the form + ``[beg, end]``, where ``beg`` and ``end`` are the endpoints of + the window. If None is given then the class window is used. The + default is None. + symbol : str, optional + Symbol representing the independent variable. Default is 'x'. + + Returns + ------- + new_series : series + A series with the coefficient of the `deg` term set to one and + all others zero. + + """ + if domain is None: + domain = cls.domain + if window is None: + window = cls.window + ideg = int(deg) + + if ideg != deg or ideg < 0: + raise ValueError("deg must be non-negative integer") + return cls([0]*ideg + [1], domain, window, symbol) + + @classmethod + def cast(cls, series, domain=None, window=None): + """Convert series to series of this class. + + The `series` is expected to be an instance of some polynomial + series of one of the types supported by the numpy.polynomial + module, but could be some other class that supports the convert + method. + + .. versionadded:: 1.7.0 + + Parameters + ---------- + series : series + The series instance to be converted. + domain : {None, array_like}, optional + If given, the array must be of the form ``[beg, end]``, where + ``beg`` and ``end`` are the endpoints of the domain. If None is + given then the class domain is used. The default is None. + window : {None, array_like}, optional + If given, the resulting array must be of the form + ``[beg, end]``, where ``beg`` and ``end`` are the endpoints of + the window. If None is given then the class window is used. The + default is None. + + Returns + ------- + new_series : series + A series of the same kind as the calling class and equal to + `series` when evaluated.
+ + See Also + -------- + convert : similar instance method + + """ + if domain is None: + domain = cls.domain + if window is None: + window = cls.window + return series.convert(domain, cls, window) diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.pyi b/.venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.pyi new file mode 100644 index 0000000000000000000000000000000000000000..25c740dbedd02ca6c3f6e1beb155876a967cb57c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.pyi @@ -0,0 +1,71 @@ +import abc +from typing import Any, ClassVar + +__all__: list[str] + +class ABCPolyBase(abc.ABC): + __hash__: ClassVar[None] # type: ignore[assignment] + __array_ufunc__: ClassVar[None] + maxpower: ClassVar[int] + coef: Any + @property + def symbol(self) -> str: ... + @property + @abc.abstractmethod + def domain(self): ... + @property + @abc.abstractmethod + def window(self): ... + @property + @abc.abstractmethod + def basis_name(self): ... + def has_samecoef(self, other): ... + def has_samedomain(self, other): ... + def has_samewindow(self, other): ... + def has_sametype(self, other): ... + def __init__(self, coef, domain=..., window=..., symbol: str = ...) -> None: ... + def __format__(self, fmt_str): ... + def __call__(self, arg): ... + def __iter__(self): ... + def __len__(self): ... + def __neg__(self): ... + def __pos__(self): ... + def __add__(self, other): ... + def __sub__(self, other): ... + def __mul__(self, other): ... + def __truediv__(self, other): ... + def __floordiv__(self, other): ... + def __mod__(self, other): ... + def __divmod__(self, other): ... + def __pow__(self, other): ... + def __radd__(self, other): ... + def __rsub__(self, other): ... + def __rmul__(self, other): ... + def __rdiv__(self, other): ... + def __rtruediv__(self, other): ... + def __rfloordiv__(self, other): ... + def __rmod__(self, other): ... + def __rdivmod__(self, other): ... + def __eq__(self, other): ... + def __ne__(self, other): ... + def copy(self): ... + def degree(self): ... + def cutdeg(self, deg): ... + def trim(self, tol=...): ... + def truncate(self, size): ... + def convert(self, domain=..., kind=..., window=...): ... + def mapparms(self): ... + def integ(self, m=..., k = ..., lbnd=...): ... + def deriv(self, m=...): ... + def roots(self): ... + def linspace(self, n=..., domain=...): ... + @classmethod + def fit(cls, x, y, deg, domain=..., rcond=..., full=..., w=..., window=...): ... + @classmethod + def fromroots(cls, roots, domain = ..., window=...): ... + @classmethod + def identity(cls, domain=..., window=...): ... + @classmethod + def basis(cls, deg, domain=..., window=...): ... + @classmethod + def cast(cls, series, domain=..., window=...): ... diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/chebyshev.py b/.venv/lib/python3.11/site-packages/numpy/polynomial/chebyshev.py new file mode 100644 index 0000000000000000000000000000000000000000..efbe13e0cadb27e29bea430a858dea5110621a0c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/chebyshev.py @@ -0,0 +1,2082 @@ +""" +==================================================== +Chebyshev Series (:mod:`numpy.polynomial.chebyshev`) +==================================================== + +This module provides a number of objects (mostly functions) useful for +dealing with Chebyshev series, including a `Chebyshev` class that +encapsulates the usual arithmetic operations. 
(General information +on how this module represents and works with such polynomials is in the +docstring for its "parent" sub-package, `numpy.polynomial`). + +Classes +------- + +.. autosummary:: + :toctree: generated/ + + Chebyshev + + +Constants +--------- + +.. autosummary:: + :toctree: generated/ + + chebdomain + chebzero + chebone + chebx + +Arithmetic +---------- + +.. autosummary:: + :toctree: generated/ + + chebadd + chebsub + chebmulx + chebmul + chebdiv + chebpow + chebval + chebval2d + chebval3d + chebgrid2d + chebgrid3d + +Calculus +-------- + +.. autosummary:: + :toctree: generated/ + + chebder + chebint + +Misc Functions +-------------- + +.. autosummary:: + :toctree: generated/ + + chebfromroots + chebroots + chebvander + chebvander2d + chebvander3d + chebgauss + chebweight + chebcompanion + chebfit + chebpts1 + chebpts2 + chebtrim + chebline + cheb2poly + poly2cheb + chebinterpolate + +See also +-------- +`numpy.polynomial` + +Notes +----- +The implementations of multiplication, division, integration, and +differentiation use the algebraic identities [1]_: + +.. math:: + T_n(x) = \\frac{z^n + z^{-n}}{2} \\\\ + z\\frac{dx}{dz} = \\frac{z - z^{-1}}{2}. + +where + +.. math:: x = \\frac{z + z^{-1}}{2}. + +These identities allow a Chebyshev series to be expressed as a finite, +symmetric Laurent series. In this module, this sort of Laurent series +is referred to as a "z-series." + +References +---------- +.. [1] A. T. Benjamin, et al., "Combinatorial Trigonometry with Chebyshev + Polynomials," *Journal of Statistical Planning and Inference 14*, 2008 + (https://web.archive.org/web/20080221202153/https://www.math.hmc.edu/~benjamin/papers/CombTrig.pdf, pg. 4) + +""" +import numpy as np +import numpy.linalg as la +from numpy.core.multiarray import normalize_axis_index + +from . import polyutils as pu +from ._polybase import ABCPolyBase + +__all__ = [ + 'chebzero', 'chebone', 'chebx', 'chebdomain', 'chebline', 'chebadd', + 'chebsub', 'chebmulx', 'chebmul', 'chebdiv', 'chebpow', 'chebval', + 'chebder', 'chebint', 'cheb2poly', 'poly2cheb', 'chebfromroots', + 'chebvander', 'chebfit', 'chebtrim', 'chebroots', 'chebpts1', + 'chebpts2', 'Chebyshev', 'chebval2d', 'chebval3d', 'chebgrid2d', + 'chebgrid3d', 'chebvander2d', 'chebvander3d', 'chebcompanion', + 'chebgauss', 'chebweight', 'chebinterpolate'] + +chebtrim = pu.trimcoef + +# +# A collection of functions for manipulating z-series. These are private +# functions and do minimal error checking. +# + +def _cseries_to_zseries(c): + """Convert Chebyshev series to z-series. + + Convert a Chebyshev series to the equivalent z-series. The result is + never an empty array. The dtype of the return is the same as that of + the input. No checks are run on the arguments as this routine is for + internal use. + + Parameters + ---------- + c : 1-D ndarray + Chebyshev coefficients, ordered from low to high + + Returns + ------- + zs : 1-D ndarray + Odd length symmetric z-series, ordered from low to high. + + """ + n = c.size + zs = np.zeros(2*n-1, dtype=c.dtype) + zs[n-1:] = c/2 + return zs + zs[::-1] + + +def _zseries_to_cseries(zs): + """Convert z-series to a Chebyshev series. + + Convert a z series to the equivalent Chebyshev series. The result is + never an empty array. The dtype of the return is the same as that of + the input. No checks are run on the arguments as this routine is for + internal use. + + Parameters + ---------- + zs : 1-D ndarray + Odd length symmetric z-series, ordered from low to high. 
+ + Returns + ------- + c : 1-D ndarray + Chebyshev coefficients, ordered from low to high. + + """ + n = (zs.size + 1)//2 + c = zs[n-1:].copy() + c[1:n] *= 2 + return c + + +def _zseries_mul(z1, z2): + """Multiply two z-series. + + Multiply two z-series to produce a z-series. + + Parameters + ---------- + z1, z2 : 1-D ndarray + The arrays must be 1-D but this is not checked. + + Returns + ------- + product : 1-D ndarray + The product z-series. + + Notes + ----- + This is simply convolution. If symmetric/anti-symmetric z-series are + denoted by S/A then the following rules apply: + + S*S, A*A -> S + S*A, A*S -> A + + """ + return np.convolve(z1, z2) + + +def _zseries_div(z1, z2): + """Divide the first z-series by the second. + + Divide `z1` by `z2` and return the quotient and remainder as z-series. + Warning: this implementation only applies when both z1 and z2 have the + same symmetry, which is sufficient for present purposes. + + Parameters + ---------- + z1, z2 : 1-D ndarray + The arrays must be 1-D and have the same symmetry, but this is not + checked. + + Returns + ------- + + (quotient, remainder) : 1-D ndarrays + Quotient and remainder as z-series. + + Notes + ----- + This is not the same as polynomial division on account of the desired form + of the remainder. If symmetric/anti-symmetric z-series are denoted by S/A + then the following rules apply: + + S/S -> S,S + A/A -> S,A + + The restriction to types of the same symmetry could be fixed but seems like + unneeded generality. There is no natural form for the remainder in the case + where there is no symmetry. + + """ + z1 = z1.copy() + z2 = z2.copy() + lc1 = len(z1) + lc2 = len(z2) + if lc2 == 1: + z1 /= z2 + return z1, z1[:1]*0 + elif lc1 < lc2: + return z1[:1]*0, z1 + else: + dlen = lc1 - lc2 + scl = z2[0] + z2 /= scl + quo = np.empty(dlen + 1, dtype=z1.dtype) + i = 0 + j = dlen + while i < j: + r = z1[i] + quo[i] = z1[i] + quo[dlen - i] = r + tmp = r*z2 + z1[i:i+lc2] -= tmp + z1[j:j+lc2] -= tmp + i += 1 + j -= 1 + r = z1[i] + quo[i] = r + tmp = r*z2 + z1[i:i+lc2] -= tmp + quo /= scl + rem = z1[i+1:i-1+lc2].copy() + return quo, rem + + +def _zseries_der(zs): + """Differentiate a z-series. + + The derivative is with respect to x, not z. This is achieved using the + chain rule and the value of dx/dz given in the module notes. + + Parameters + ---------- + zs : z-series + The z-series to differentiate. + + Returns + ------- + derivative : z-series + The derivative + + Notes + ----- + The zseries for x (ns) has been multiplied by two in order to avoid + using floats that are incompatible with Decimal and likely other + specialized scalar types. This scaling has been compensated by + multiplying the value of zs by two also so that the two cancels in the + division. + + """ + n = len(zs)//2 + ns = np.array([-1, 0, 1], dtype=zs.dtype) + zs *= np.arange(-n, n+1)*2 + d, r = _zseries_div(zs, ns) + return d + + +def _zseries_int(zs): + """Integrate a z-series. + + The integral is with respect to x, not z. This is achieved by a change + of variable using dx/dz given in the module notes. + + Parameters + ---------- + zs : z-series + The z-series to integrate + + Returns + ------- + integral : z-series + The indefinite integral + + Notes + ----- + The zseries for x (ns) has been multiplied by two in order to avoid + using floats that are incompatible with Decimal and likely other + specialized scalar types. This scaling has been compensated by + dividing the resulting zs by two. 
+ + """ + n = 1 + len(zs)//2 + ns = np.array([-1, 0, 1], dtype=zs.dtype) + zs = _zseries_mul(zs, ns) + div = np.arange(-n, n+1)*2 + zs[:n] /= div[:n] + zs[n+1:] /= div[n+1:] + zs[n] = 0 + return zs + +# +# Chebyshev series functions +# + + +def poly2cheb(pol): + """ + Convert a polynomial to a Chebyshev series. + + Convert an array representing the coefficients of a polynomial (relative + to the "standard" basis) ordered from lowest degree to highest, to an + array of the coefficients of the equivalent Chebyshev series, ordered + from lowest to highest degree. + + Parameters + ---------- + pol : array_like + 1-D array containing the polynomial coefficients + + Returns + ------- + c : ndarray + 1-D array containing the coefficients of the equivalent Chebyshev + series. + + See Also + -------- + cheb2poly + + Notes + ----- + The easy way to do conversions between polynomial basis sets + is to use the convert method of a class instance. + + Examples + -------- + >>> from numpy import polynomial as P + >>> p = P.Polynomial(range(4)) + >>> p + Polynomial([0., 1., 2., 3.], domain=[-1, 1], window=[-1, 1]) + >>> c = p.convert(kind=P.Chebyshev) + >>> c + Chebyshev([1. , 3.25, 1. , 0.75], domain=[-1., 1.], window=[-1., 1.]) + >>> P.chebyshev.poly2cheb(range(4)) + array([1. , 3.25, 1. , 0.75]) + + """ + [pol] = pu.as_series([pol]) + deg = len(pol) - 1 + res = 0 + for i in range(deg, -1, -1): + res = chebadd(chebmulx(res), pol[i]) + return res + + +def cheb2poly(c): + """ + Convert a Chebyshev series to a polynomial. + + Convert an array representing the coefficients of a Chebyshev series, + ordered from lowest degree to highest, to an array of the coefficients + of the equivalent polynomial (relative to the "standard" basis) ordered + from lowest to highest degree. + + Parameters + ---------- + c : array_like + 1-D array containing the Chebyshev series coefficients, ordered + from lowest order term to highest. + + Returns + ------- + pol : ndarray + 1-D array containing the coefficients of the equivalent polynomial + (relative to the "standard" basis) ordered from lowest order term + to highest. + + See Also + -------- + poly2cheb + + Notes + ----- + The easy way to do conversions between polynomial basis sets + is to use the convert method of a class instance. + + Examples + -------- + >>> from numpy import polynomial as P + >>> c = P.Chebyshev(range(4)) + >>> c + Chebyshev([0., 1., 2., 3.], domain=[-1, 1], window=[-1, 1]) + >>> p = c.convert(kind=P.Polynomial) + >>> p + Polynomial([-2., -8., 4., 12.], domain=[-1., 1.], window=[-1., 1.]) + >>> P.chebyshev.cheb2poly(range(4)) + array([-2., -8., 4., 12.]) + + """ + from .polynomial import polyadd, polysub, polymulx + + [c] = pu.as_series([c]) + n = len(c) + if n < 3: + return c + else: + c0 = c[-2] + c1 = c[-1] + # i is the current degree of c1 + for i in range(n - 1, 1, -1): + tmp = c0 + c0 = polysub(c[i - 2], c1) + c1 = polyadd(tmp, polymulx(c1)*2) + return polyadd(c0, polymulx(c1)) + + +# +# These are constant arrays are of integer type so as to be compatible +# with the widest range of other types, such as Decimal. +# + +# Chebyshev default domain. +chebdomain = np.array([-1, 1]) + +# Chebyshev coefficients representing zero. +chebzero = np.array([0]) + +# Chebyshev coefficients representing one. +chebone = np.array([1]) + +# Chebyshev coefficients representing the identity x. +chebx = np.array([0, 1]) + + +def chebline(off, scl): + """ + Chebyshev series whose graph is a straight line. 
+ + Parameters + ---------- + off, scl : scalars + The specified line is given by ``off + scl*x``. + + Returns + ------- + y : ndarray + This module's representation of the Chebyshev series for + ``off + scl*x``. + + See Also + -------- + numpy.polynomial.polynomial.polyline + numpy.polynomial.legendre.legline + numpy.polynomial.laguerre.lagline + numpy.polynomial.hermite.hermline + numpy.polynomial.hermite_e.hermeline + + Examples + -------- + >>> import numpy.polynomial.chebyshev as C + >>> C.chebline(3,2) + array([3, 2]) + >>> C.chebval(-3, C.chebline(3,2)) # should be -3 + -3.0 + + """ + if scl != 0: + return np.array([off, scl]) + else: + return np.array([off]) + + +def chebfromroots(roots): + """ + Generate a Chebyshev series with given roots. + + The function returns the coefficients of the polynomial + + .. math:: p(x) = (x - r_0) * (x - r_1) * ... * (x - r_n), + + in Chebyshev form, where the `r_n` are the roots specified in `roots`. + If a zero has multiplicity n, then it must appear in `roots` n times. + For instance, if 2 is a root of multiplicity three and 3 is a root of + multiplicity 2, then `roots` looks something like [2, 2, 2, 3, 3]. The + roots can appear in any order. + + If the returned coefficients are `c`, then + + .. math:: p(x) = c_0 + c_1 * T_1(x) + ... + c_n * T_n(x) + + The coefficient of the last term is not generally 1 for monic + polynomials in Chebyshev form. + + Parameters + ---------- + roots : array_like + Sequence containing the roots. + + Returns + ------- + out : ndarray + 1-D array of coefficients. If all roots are real then `out` is a + real array, if some of the roots are complex, then `out` is complex + even if all the coefficients in the result are real (see Examples + below). + + See Also + -------- + numpy.polynomial.polynomial.polyfromroots + numpy.polynomial.legendre.legfromroots + numpy.polynomial.laguerre.lagfromroots + numpy.polynomial.hermite.hermfromroots + numpy.polynomial.hermite_e.hermefromroots + + Examples + -------- + >>> import numpy.polynomial.chebyshev as C + >>> C.chebfromroots((-1,0,1)) # x^3 - x relative to the standard basis + array([ 0. , -0.25, 0. , 0.25]) + >>> j = complex(0,1) + >>> C.chebfromroots((-j,j)) # x^2 + 1 relative to the standard basis + array([1.5+0.j, 0. +0.j, 0.5+0.j]) + + """ + return pu._fromroots(chebline, chebmul, roots) + + +def chebadd(c1, c2): + """ + Add one Chebyshev series to another. + + Returns the sum of two Chebyshev series `c1` + `c2`. The arguments + are sequences of coefficients ordered from lowest order term to + highest, i.e., [1,2,3] represents the series ``T_0 + 2*T_1 + 3*T_2``. + + Parameters + ---------- + c1, c2 : array_like + 1-D arrays of Chebyshev series coefficients ordered from low to + high. + + Returns + ------- + out : ndarray + Array representing the Chebyshev series of their sum. + + See Also + -------- + chebsub, chebmulx, chebmul, chebdiv, chebpow + + Notes + ----- + Unlike multiplication, division, etc., the sum of two Chebyshev series + is a Chebyshev series (without having to "reproject" the result onto + the basis set) so addition, just like that of "standard" polynomials, + is simply "component-wise." + + Examples + -------- + >>> from numpy.polynomial import chebyshev as C + >>> c1 = (1,2,3) + >>> c2 = (3,2,1) + >>> C.chebadd(c1,c2) + array([4., 4., 4.]) + + """ + return pu._add(c1, c2) + + +def chebsub(c1, c2): + """ + Subtract one Chebyshev series from another. + + Returns the difference of two Chebyshev series `c1` - `c2`. 
The
+ sequences of coefficients are from lowest order term to highest, i.e.,
+ [1,2,3] represents the series ``T_0 + 2*T_1 + 3*T_2``.
+
+ Parameters
+ ----------
+ c1, c2 : array_like
+ 1-D arrays of Chebyshev series coefficients ordered from low to
+ high.
+
+ Returns
+ -------
+ out : ndarray
+ Of Chebyshev series coefficients representing their difference.
+
+ See Also
+ --------
+ chebadd, chebmulx, chebmul, chebdiv, chebpow
+
+ Notes
+ -----
+ Unlike multiplication, division, etc., the difference of two Chebyshev
+ series is a Chebyshev series (without having to "reproject" the result
+ onto the basis set) so subtraction, just like that of "standard"
+ polynomials, is simply "component-wise."
+
+ Examples
+ --------
+ >>> from numpy.polynomial import chebyshev as C
+ >>> c1 = (1,2,3)
+ >>> c2 = (3,2,1)
+ >>> C.chebsub(c1,c2)
+ array([-2., 0., 2.])
+ >>> C.chebsub(c2,c1) # -C.chebsub(c1,c2)
+ array([ 2., 0., -2.])
+
+ """
+ return pu._sub(c1, c2)
+
+
+def chebmulx(c):
+ """Multiply a Chebyshev series by x.
+
+ Multiply the polynomial `c` by x, where x is the independent
+ variable.
+
+ Parameters
+ ----------
+ c : array_like
+ 1-D array of Chebyshev series coefficients ordered from low to
+ high.
+
+ Returns
+ -------
+ out : ndarray
+ Array representing the result of the multiplication.
+
+ Notes
+ -----
+
+ .. versionadded:: 1.5.0
+
+ Examples
+ --------
+ >>> from numpy.polynomial import chebyshev as C
+ >>> C.chebmulx([1,2,3])
+ array([1. , 2.5, 1. , 1.5])
+
+ """
+ # c is a trimmed copy
+ [c] = pu.as_series([c])
+ # The zero series needs special treatment
+ if len(c) == 1 and c[0] == 0:
+ return c
+
+ prd = np.empty(len(c) + 1, dtype=c.dtype)
+ prd[0] = c[0]*0
+ prd[1] = c[0]
+ if len(c) > 1:
+ tmp = c[1:]/2
+ prd[2:] = tmp
+ prd[0:-2] += tmp
+ return prd
+
+
+def chebmul(c1, c2):
+ """
+ Multiply one Chebyshev series by another.
+
+ Returns the product of two Chebyshev series `c1` * `c2`. The arguments
+ are sequences of coefficients, from lowest order "term" to highest,
+ e.g., [1,2,3] represents the series ``T_0 + 2*T_1 + 3*T_2``.
+
+ Parameters
+ ----------
+ c1, c2 : array_like
+ 1-D arrays of Chebyshev series coefficients ordered from low to
+ high.
+
+ Returns
+ -------
+ out : ndarray
+ Of Chebyshev series coefficients representing their product.
+
+ See Also
+ --------
+ chebadd, chebsub, chebmulx, chebdiv, chebpow
+
+ Notes
+ -----
+ In general, the (polynomial) product of two C-series results in terms
+ that are not in the Chebyshev polynomial basis set. Thus, to express
+ the product as a C-series, it is typically necessary to "reproject"
+ the product onto said basis set, which typically produces
+ "unintuitive" (but correct) results; see Examples section below.
+
+ Examples
+ --------
+ >>> from numpy.polynomial import chebyshev as C
+ >>> c1 = (1,2,3)
+ >>> c2 = (3,2,1)
+ >>> C.chebmul(c1,c2) # multiplication requires "reprojection"
+ array([ 6.5, 12. , 12. , 4. , 1.5])
+
+ """
+ # c1, c2 are trimmed copies
+ [c1, c2] = pu.as_series([c1, c2])
+ z1 = _cseries_to_zseries(c1)
+ z2 = _cseries_to_zseries(c2)
+ prd = _zseries_mul(z1, z2)
+ ret = _zseries_to_cseries(prd)
+ return pu.trimseq(ret)
+
+
+def chebdiv(c1, c2):
+ """
+ Divide one Chebyshev series by another.
+
+ Returns the quotient-with-remainder of two Chebyshev series
+ `c1` / `c2`. The arguments are sequences of coefficients from lowest
+ order "term" to highest, e.g., [1,2,3] represents the series
+ ``T_0 + 2*T_1 + 3*T_2``.
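+ If ``quo, rem = chebdiv(c1, c2)``, then up to roundoff the pieces
+ satisfy ``c1 == chebadd(chebmul(quo, c2), rem)``.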
+ + Parameters + ---------- + c1, c2 : array_like + 1-D arrays of Chebyshev series coefficients ordered from low to + high. + + Returns + ------- + [quo, rem] : ndarrays + Of Chebyshev series coefficients representing the quotient and + remainder. + + See Also + -------- + chebadd, chebsub, chebmulx, chebmul, chebpow + + Notes + ----- + In general, the (polynomial) division of one C-series by another + results in quotient and remainder terms that are not in the Chebyshev + polynomial basis set. Thus, to express these results as C-series, it + is typically necessary to "reproject" the results onto said basis + set, which typically produces "unintuitive" (but correct) results; + see Examples section below. + + Examples + -------- + >>> from numpy.polynomial import chebyshev as C + >>> c1 = (1,2,3) + >>> c2 = (3,2,1) + >>> C.chebdiv(c1,c2) # quotient "intuitive," remainder not + (array([3.]), array([-8., -4.])) + >>> c2 = (0,1,2,3) + >>> C.chebdiv(c2,c1) # neither "intuitive" + (array([0., 2.]), array([-2., -4.])) + + """ + # c1, c2 are trimmed copies + [c1, c2] = pu.as_series([c1, c2]) + if c2[-1] == 0: + raise ZeroDivisionError() + + # note: this is more efficient than `pu._div(chebmul, c1, c2)` + lc1 = len(c1) + lc2 = len(c2) + if lc1 < lc2: + return c1[:1]*0, c1 + elif lc2 == 1: + return c1/c2[-1], c1[:1]*0 + else: + z1 = _cseries_to_zseries(c1) + z2 = _cseries_to_zseries(c2) + quo, rem = _zseries_div(z1, z2) + quo = pu.trimseq(_zseries_to_cseries(quo)) + rem = pu.trimseq(_zseries_to_cseries(rem)) + return quo, rem + + +def chebpow(c, pow, maxpower=16): + """Raise a Chebyshev series to a power. + + Returns the Chebyshev series `c` raised to the power `pow`. The + argument `c` is a sequence of coefficients ordered from low to high. + i.e., [1,2,3] is the series ``T_0 + 2*T_1 + 3*T_2.`` + + Parameters + ---------- + c : array_like + 1-D array of Chebyshev series coefficients ordered from low to + high. + pow : integer + Power to which the series will be raised + maxpower : integer, optional + Maximum power allowed. This is mainly to limit growth of the series + to unmanageable size. Default is 16 + + Returns + ------- + coef : ndarray + Chebyshev series of power. + + See Also + -------- + chebadd, chebsub, chebmulx, chebmul, chebdiv + + Examples + -------- + >>> from numpy.polynomial import chebyshev as C + >>> C.chebpow([1, 2, 3, 4], 2) + array([15.5, 22. , 16. , ..., 12.5, 12. , 8. ]) + + """ + # note: this is more efficient than `pu._pow(chebmul, c1, c2)`, as it + # avoids converting between z and c series repeatedly + + # c is a trimmed copy + [c] = pu.as_series([c]) + power = int(pow) + if power != pow or power < 0: + raise ValueError("Power must be a non-negative integer.") + elif maxpower is not None and power > maxpower: + raise ValueError("Power is too large") + elif power == 0: + return np.array([1], dtype=c.dtype) + elif power == 1: + return c + else: + # This can be made more efficient by using powers of two + # in the usual way. + zs = _cseries_to_zseries(c) + prd = zs + for i in range(2, power + 1): + prd = np.convolve(prd, zs) + return _zseries_to_cseries(prd) + + +def chebder(c, m=1, scl=1, axis=0): + """ + Differentiate a Chebyshev series. + + Returns the Chebyshev series coefficients `c` differentiated `m` times + along `axis`. At each iteration the result is multiplied by `scl` (the + scaling factor is for use in a linear change of variable). 
The argument + `c` is an array of coefficients from low to high degree along each + axis, e.g., [1,2,3] represents the series ``1*T_0 + 2*T_1 + 3*T_2`` + while [[1,2],[1,2]] represents ``1*T_0(x)*T_0(y) + 1*T_1(x)*T_0(y) + + 2*T_0(x)*T_1(y) + 2*T_1(x)*T_1(y)`` if axis=0 is ``x`` and axis=1 is + ``y``. + + Parameters + ---------- + c : array_like + Array of Chebyshev series coefficients. If c is multidimensional + the different axis correspond to different variables with the + degree in each axis given by the corresponding index. + m : int, optional + Number of derivatives taken, must be non-negative. (Default: 1) + scl : scalar, optional + Each differentiation is multiplied by `scl`. The end result is + multiplication by ``scl**m``. This is for use in a linear change of + variable. (Default: 1) + axis : int, optional + Axis over which the derivative is taken. (Default: 0). + + .. versionadded:: 1.7.0 + + Returns + ------- + der : ndarray + Chebyshev series of the derivative. + + See Also + -------- + chebint + + Notes + ----- + In general, the result of differentiating a C-series needs to be + "reprojected" onto the C-series basis set. Thus, typically, the + result of this function is "unintuitive," albeit correct; see Examples + section below. + + Examples + -------- + >>> from numpy.polynomial import chebyshev as C + >>> c = (1,2,3,4) + >>> C.chebder(c) + array([14., 12., 24.]) + >>> C.chebder(c,3) + array([96.]) + >>> C.chebder(c,scl=-1) + array([-14., -12., -24.]) + >>> C.chebder(c,2,-1) + array([12., 96.]) + + """ + c = np.array(c, ndmin=1, copy=True) + if c.dtype.char in '?bBhHiIlLqQpP': + c = c.astype(np.double) + cnt = pu._deprecate_as_int(m, "the order of derivation") + iaxis = pu._deprecate_as_int(axis, "the axis") + if cnt < 0: + raise ValueError("The order of derivation must be non-negative") + iaxis = normalize_axis_index(iaxis, c.ndim) + + if cnt == 0: + return c + + c = np.moveaxis(c, iaxis, 0) + n = len(c) + if cnt >= n: + c = c[:1]*0 + else: + for i in range(cnt): + n = n - 1 + c *= scl + der = np.empty((n,) + c.shape[1:], dtype=c.dtype) + for j in range(n, 2, -1): + der[j - 1] = (2*j)*c[j] + c[j - 2] += (j*c[j])/(j - 2) + if n > 1: + der[1] = 4*c[2] + der[0] = c[1] + c = der + c = np.moveaxis(c, 0, iaxis) + return c + + +def chebint(c, m=1, k=[], lbnd=0, scl=1, axis=0): + """ + Integrate a Chebyshev series. + + Returns the Chebyshev series coefficients `c` integrated `m` times from + `lbnd` along `axis`. At each iteration the resulting series is + **multiplied** by `scl` and an integration constant, `k`, is added. + The scaling factor is for use in a linear change of variable. ("Buyer + beware": note that, depending on what one is doing, one may want `scl` + to be the reciprocal of what one might expect; for more information, + see the Notes section below.) The argument `c` is an array of + coefficients from low to high degree along each axis, e.g., [1,2,3] + represents the series ``T_0 + 2*T_1 + 3*T_2`` while [[1,2],[1,2]] + represents ``1*T_0(x)*T_0(y) + 1*T_1(x)*T_0(y) + 2*T_0(x)*T_1(y) + + 2*T_1(x)*T_1(y)`` if axis=0 is ``x`` and axis=1 is ``y``. + + Parameters + ---------- + c : array_like + Array of Chebyshev series coefficients. If c is multidimensional + the different axis correspond to different variables with the + degree in each axis given by the corresponding index. + m : int, optional + Order of integration, must be positive. (Default: 1) + k : {[], list, scalar}, optional + Integration constant(s). 
The value of the first integral at zero + is the first value in the list, the value of the second integral + at zero is the second value, etc. If ``k == []`` (the default), + all constants are set to zero. If ``m == 1``, a single scalar can + be given instead of a list. + lbnd : scalar, optional + The lower bound of the integral. (Default: 0) + scl : scalar, optional + Following each integration the result is *multiplied* by `scl` + before the integration constant is added. (Default: 1) + axis : int, optional + Axis over which the integral is taken. (Default: 0). + + .. versionadded:: 1.7.0 + + Returns + ------- + S : ndarray + C-series coefficients of the integral. + + Raises + ------ + ValueError + If ``m < 1``, ``len(k) > m``, ``np.ndim(lbnd) != 0``, or + ``np.ndim(scl) != 0``. + + See Also + -------- + chebder + + Notes + ----- + Note that the result of each integration is *multiplied* by `scl`. + Why is this important to note? Say one is making a linear change of + variable :math:`u = ax + b` in an integral relative to `x`. Then + :math:`dx = du/a`, so one will need to set `scl` equal to + :math:`1/a`- perhaps not what one would have first thought. + + Also note that, in general, the result of integrating a C-series needs + to be "reprojected" onto the C-series basis set. Thus, typically, + the result of this function is "unintuitive," albeit correct; see + Examples section below. + + Examples + -------- + >>> from numpy.polynomial import chebyshev as C + >>> c = (1,2,3) + >>> C.chebint(c) + array([ 0.5, -0.5, 0.5, 0.5]) + >>> C.chebint(c,3) + array([ 0.03125 , -0.1875 , 0.04166667, -0.05208333, 0.01041667, # may vary + 0.00625 ]) + >>> C.chebint(c, k=3) + array([ 3.5, -0.5, 0.5, 0.5]) + >>> C.chebint(c,lbnd=-2) + array([ 8.5, -0.5, 0.5, 0.5]) + >>> C.chebint(c,scl=-2) + array([-1., 1., -1., -1.]) + + """ + c = np.array(c, ndmin=1, copy=True) + if c.dtype.char in '?bBhHiIlLqQpP': + c = c.astype(np.double) + if not np.iterable(k): + k = [k] + cnt = pu._deprecate_as_int(m, "the order of integration") + iaxis = pu._deprecate_as_int(axis, "the axis") + if cnt < 0: + raise ValueError("The order of integration must be non-negative") + if len(k) > cnt: + raise ValueError("Too many integration constants") + if np.ndim(lbnd) != 0: + raise ValueError("lbnd must be a scalar.") + if np.ndim(scl) != 0: + raise ValueError("scl must be a scalar.") + iaxis = normalize_axis_index(iaxis, c.ndim) + + if cnt == 0: + return c + + c = np.moveaxis(c, iaxis, 0) + k = list(k) + [0]*(cnt - len(k)) + for i in range(cnt): + n = len(c) + c *= scl + if n == 1 and np.all(c[0] == 0): + c[0] += k[i] + else: + tmp = np.empty((n + 1,) + c.shape[1:], dtype=c.dtype) + tmp[0] = c[0]*0 + tmp[1] = c[0] + if n > 1: + tmp[2] = c[1]/4 + for j in range(2, n): + tmp[j + 1] = c[j]/(2*(j + 1)) + tmp[j - 1] -= c[j]/(2*(j - 1)) + tmp[0] += k[i] - chebval(lbnd, tmp) + c = tmp + c = np.moveaxis(c, 0, iaxis) + return c + + +def chebval(x, c, tensor=True): + """ + Evaluate a Chebyshev series at points x. + + If `c` is of length `n + 1`, this function returns the value: + + .. math:: p(x) = c_0 * T_0(x) + c_1 * T_1(x) + ... + c_n * T_n(x) + + The parameter `x` is converted to an array only if it is a tuple or a + list, otherwise it is treated as a scalar. In either case, either `x` + or its elements must support multiplication and addition both with + themselves and with the elements of `c`. + + If `c` is a 1-D array, then `p(x)` will have the same shape as `x`. 
If + `c` is multidimensional, then the shape of the result depends on the + value of `tensor`. If `tensor` is true the shape will be c.shape[1:] + + x.shape. If `tensor` is false the shape will be c.shape[1:]. Note that + scalars have shape (,). + + Trailing zeros in the coefficients will be used in the evaluation, so + they should be avoided if efficiency is a concern. + + Parameters + ---------- + x : array_like, compatible object + If `x` is a list or tuple, it is converted to an ndarray, otherwise + it is left unchanged and treated as a scalar. In either case, `x` + or its elements must support addition and multiplication with + themselves and with the elements of `c`. + c : array_like + Array of coefficients ordered so that the coefficients for terms of + degree n are contained in c[n]. If `c` is multidimensional the + remaining indices enumerate multiple polynomials. In the two + dimensional case the coefficients may be thought of as stored in + the columns of `c`. + tensor : boolean, optional + If True, the shape of the coefficient array is extended with ones + on the right, one for each dimension of `x`. Scalars have dimension 0 + for this action. The result is that every column of coefficients in + `c` is evaluated for every element of `x`. If False, `x` is broadcast + over the columns of `c` for the evaluation. This keyword is useful + when `c` is multidimensional. The default value is True. + + .. versionadded:: 1.7.0 + + Returns + ------- + values : ndarray, algebra_like + The shape of the return value is described above. + + See Also + -------- + chebval2d, chebgrid2d, chebval3d, chebgrid3d + + Notes + ----- + The evaluation uses Clenshaw recursion, aka synthetic division. + + """ + c = np.array(c, ndmin=1, copy=True) + if c.dtype.char in '?bBhHiIlLqQpP': + c = c.astype(np.double) + if isinstance(x, (tuple, list)): + x = np.asarray(x) + if isinstance(x, np.ndarray) and tensor: + c = c.reshape(c.shape + (1,)*x.ndim) + + if len(c) == 1: + c0 = c[0] + c1 = 0 + elif len(c) == 2: + c0 = c[0] + c1 = c[1] + else: + x2 = 2*x + c0 = c[-2] + c1 = c[-1] + for i in range(3, len(c) + 1): + tmp = c0 + c0 = c[-i] - c1 + c1 = tmp + c1*x2 + return c0 + c1*x + + +def chebval2d(x, y, c): + """ + Evaluate a 2-D Chebyshev series at points (x, y). + + This function returns the values: + + .. math:: p(x,y) = \\sum_{i,j} c_{i,j} * T_i(x) * T_j(y) + + The parameters `x` and `y` are converted to arrays only if they are + tuples or a lists, otherwise they are treated as a scalars and they + must have the same shape after conversion. In either case, either `x` + and `y` or their elements must support multiplication and addition both + with themselves and with the elements of `c`. + + If `c` is a 1-D array a one is implicitly appended to its shape to make + it 2-D. The shape of the result will be c.shape[2:] + x.shape. + + Parameters + ---------- + x, y : array_like, compatible objects + The two dimensional series is evaluated at the points `(x, y)`, + where `x` and `y` must have the same shape. If `x` or `y` is a list + or tuple, it is first converted to an ndarray, otherwise it is left + unchanged and if it isn't an ndarray it is treated as a scalar. + c : array_like + Array of coefficients ordered so that the coefficient of the term + of multi-degree i,j is contained in ``c[i,j]``. If `c` has + dimension greater than 2 the remaining indices enumerate multiple + sets of coefficients. 
+ + Returns + ------- + values : ndarray, compatible object + The values of the two dimensional Chebyshev series at points formed + from pairs of corresponding values from `x` and `y`. + + See Also + -------- + chebval, chebgrid2d, chebval3d, chebgrid3d + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + return pu._valnd(chebval, c, x, y) + + +def chebgrid2d(x, y, c): + """ + Evaluate a 2-D Chebyshev series on the Cartesian product of x and y. + + This function returns the values: + + .. math:: p(a,b) = \\sum_{i,j} c_{i,j} * T_i(a) * T_j(b), + + where the points `(a, b)` consist of all pairs formed by taking + `a` from `x` and `b` from `y`. The resulting points form a grid with + `x` in the first dimension and `y` in the second. + + The parameters `x` and `y` are converted to arrays only if they are + tuples or a lists, otherwise they are treated as a scalars. In either + case, either `x` and `y` or their elements must support multiplication + and addition both with themselves and with the elements of `c`. + + If `c` has fewer than two dimensions, ones are implicitly appended to + its shape to make it 2-D. The shape of the result will be c.shape[2:] + + x.shape + y.shape. + + Parameters + ---------- + x, y : array_like, compatible objects + The two dimensional series is evaluated at the points in the + Cartesian product of `x` and `y`. If `x` or `y` is a list or + tuple, it is first converted to an ndarray, otherwise it is left + unchanged and, if it isn't an ndarray, it is treated as a scalar. + c : array_like + Array of coefficients ordered so that the coefficient of the term of + multi-degree i,j is contained in `c[i,j]`. If `c` has dimension + greater than two the remaining indices enumerate multiple sets of + coefficients. + + Returns + ------- + values : ndarray, compatible object + The values of the two dimensional Chebyshev series at points in the + Cartesian product of `x` and `y`. + + See Also + -------- + chebval, chebval2d, chebval3d, chebgrid3d + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + return pu._gridnd(chebval, c, x, y) + + +def chebval3d(x, y, z, c): + """ + Evaluate a 3-D Chebyshev series at points (x, y, z). + + This function returns the values: + + .. math:: p(x,y,z) = \\sum_{i,j,k} c_{i,j,k} * T_i(x) * T_j(y) * T_k(z) + + The parameters `x`, `y`, and `z` are converted to arrays only if + they are tuples or a lists, otherwise they are treated as a scalars and + they must have the same shape after conversion. In either case, either + `x`, `y`, and `z` or their elements must support multiplication and + addition both with themselves and with the elements of `c`. + + If `c` has fewer than 3 dimensions, ones are implicitly appended to its + shape to make it 3-D. The shape of the result will be c.shape[3:] + + x.shape. + + Parameters + ---------- + x, y, z : array_like, compatible object + The three dimensional series is evaluated at the points + `(x, y, z)`, where `x`, `y`, and `z` must have the same shape. If + any of `x`, `y`, or `z` is a list or tuple, it is first converted + to an ndarray, otherwise it is left unchanged and if it isn't an + ndarray it is treated as a scalar. + c : array_like + Array of coefficients ordered so that the coefficient of the term of + multi-degree i,j,k is contained in ``c[i,j,k]``. If `c` has dimension + greater than 3 the remaining indices enumerate multiple sets of + coefficients. 
+
+ Returns
+ -------
+ values : ndarray, compatible object
+ The values of the multidimensional polynomial on points formed with
+ triples of corresponding values from `x`, `y`, and `z`.
+
+ See Also
+ --------
+ chebval, chebval2d, chebgrid2d, chebgrid3d
+
+ Notes
+ -----
+
+ .. versionadded:: 1.7.0
+
+ """
+ return pu._valnd(chebval, c, x, y, z)
+
+
+def chebgrid3d(x, y, z, c):
+ """
+ Evaluate a 3-D Chebyshev series on the Cartesian product of x, y, and z.
+
+ This function returns the values:
+
+ .. math:: p(a,b,c) = \\sum_{i,j,k} c_{i,j,k} * T_i(a) * T_j(b) * T_k(c)
+
+ where the points `(a, b, c)` consist of all triples formed by taking
+ `a` from `x`, `b` from `y`, and `c` from `z`. The resulting points form
+ a grid with `x` in the first dimension, `y` in the second, and `z` in
+ the third.
+
+ The parameters `x`, `y`, and `z` are converted to arrays only if they
+ are tuples or lists, otherwise they are treated as scalars. In
+ either case, either `x`, `y`, and `z` or their elements must support
+ multiplication and addition both with themselves and with the elements
+ of `c`.
+
+ If `c` has fewer than three dimensions, ones are implicitly appended to
+ its shape to make it 3-D. The shape of the result will be c.shape[3:] +
+ x.shape + y.shape + z.shape.
+
+ Parameters
+ ----------
+ x, y, z : array_like, compatible objects
+ The three dimensional series is evaluated at the points in the
+ Cartesian product of `x`, `y`, and `z`. If `x`, `y`, or `z` is a
+ list or tuple, it is first converted to an ndarray, otherwise it is
+ left unchanged and, if it isn't an ndarray, it is treated as a
+ scalar.
+ c : array_like
+ Array of coefficients ordered so that the coefficient of the term of
+ multi-degree i,j,k is contained in ``c[i,j,k]``. If `c` has dimension
+ greater than three the remaining indices enumerate multiple sets of
+ coefficients.
+
+ Returns
+ -------
+ values : ndarray, compatible object
+ The values of the three dimensional Chebyshev series at points in the
+ Cartesian product of `x`, `y`, and `z`.
+
+ See Also
+ --------
+ chebval, chebval2d, chebgrid2d, chebval3d
+
+ Notes
+ -----
+
+ .. versionadded:: 1.7.0
+
+ """
+ return pu._gridnd(chebval, c, x, y, z)
+
+
+def chebvander(x, deg):
+ """Pseudo-Vandermonde matrix of given degree.
+
+ Returns the pseudo-Vandermonde matrix of degree `deg` and sample points
+ `x`. The pseudo-Vandermonde matrix is defined by
+
+ .. math:: V[..., i] = T_i(x),
+
+ where `0 <= i <= deg`. The leading indices of `V` index the elements of
+ `x` and the last index is the degree of the Chebyshev polynomial.
+
+ If `c` is a 1-D array of coefficients of length `n + 1` and `V` is the
+ matrix ``V = chebvander(x, n)``, then ``np.dot(V, c)`` and
+ ``chebval(x, c)`` are the same up to roundoff. This equivalence is
+ useful both for least squares fitting and for the evaluation of a large
+ number of Chebyshev series of the same degree and sample points.
+
+ Parameters
+ ----------
+ x : array_like
+ Array of points. The dtype is converted to float64 or complex128
+ depending on whether any of the elements are complex. If `x` is
+ scalar it is converted to a 1-D array.
+ deg : int
+ Degree of the resulting matrix.
+
+ Returns
+ -------
+ vander : ndarray
+ The pseudo Vandermonde matrix. The shape of the returned matrix is
+ ``x.shape + (deg + 1,)``, where the last index is the degree of the
+ corresponding Chebyshev polynomial. The dtype will be the same as
+ the converted `x`.
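+
+ Examples
+ --------
+ A small hand-checked sketch at three sample points (the values follow
+ from ``T_0 = 1``, ``T_1 = x``, ``T_2 = 2x**2 - 1`` and
+ ``T_3 = 4x**3 - 3x``):
+
+ >>> import numpy.polynomial.chebyshev as C
+ >>> C.chebvander([-1., 0., 1.], 3)
+ array([[ 1., -1., 1., -1.],
+ [ 1., 0., -1., 0.],
+ [ 1., 1., 1., 1.]])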
+ + """ + ideg = pu._deprecate_as_int(deg, "deg") + if ideg < 0: + raise ValueError("deg must be non-negative") + + x = np.array(x, copy=False, ndmin=1) + 0.0 + dims = (ideg + 1,) + x.shape + dtyp = x.dtype + v = np.empty(dims, dtype=dtyp) + # Use forward recursion to generate the entries. + v[0] = x*0 + 1 + if ideg > 0: + x2 = 2*x + v[1] = x + for i in range(2, ideg + 1): + v[i] = v[i-1]*x2 - v[i-2] + return np.moveaxis(v, 0, -1) + + +def chebvander2d(x, y, deg): + """Pseudo-Vandermonde matrix of given degrees. + + Returns the pseudo-Vandermonde matrix of degrees `deg` and sample + points `(x, y)`. The pseudo-Vandermonde matrix is defined by + + .. math:: V[..., (deg[1] + 1)*i + j] = T_i(x) * T_j(y), + + where `0 <= i <= deg[0]` and `0 <= j <= deg[1]`. The leading indices of + `V` index the points `(x, y)` and the last index encodes the degrees of + the Chebyshev polynomials. + + If ``V = chebvander2d(x, y, [xdeg, ydeg])``, then the columns of `V` + correspond to the elements of a 2-D coefficient array `c` of shape + (xdeg + 1, ydeg + 1) in the order + + .. math:: c_{00}, c_{01}, c_{02} ... , c_{10}, c_{11}, c_{12} ... + + and ``np.dot(V, c.flat)`` and ``chebval2d(x, y, c)`` will be the same + up to roundoff. This equivalence is useful both for least squares + fitting and for the evaluation of a large number of 2-D Chebyshev + series of the same degrees and sample points. + + Parameters + ---------- + x, y : array_like + Arrays of point coordinates, all of the same shape. The dtypes + will be converted to either float64 or complex128 depending on + whether any of the elements are complex. Scalars are converted to + 1-D arrays. + deg : list of ints + List of maximum degrees of the form [x_deg, y_deg]. + + Returns + ------- + vander2d : ndarray + The shape of the returned matrix is ``x.shape + (order,)``, where + :math:`order = (deg[0]+1)*(deg[1]+1)`. The dtype will be the same + as the converted `x` and `y`. + + See Also + -------- + chebvander, chebvander3d, chebval2d, chebval3d + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + return pu._vander_nd_flat((chebvander, chebvander), (x, y), deg) + + +def chebvander3d(x, y, z, deg): + """Pseudo-Vandermonde matrix of given degrees. + + Returns the pseudo-Vandermonde matrix of degrees `deg` and sample + points `(x, y, z)`. If `l, m, n` are the given degrees in `x, y, z`, + then The pseudo-Vandermonde matrix is defined by + + .. math:: V[..., (m+1)(n+1)i + (n+1)j + k] = T_i(x)*T_j(y)*T_k(z), + + where `0 <= i <= l`, `0 <= j <= m`, and `0 <= j <= n`. The leading + indices of `V` index the points `(x, y, z)` and the last index encodes + the degrees of the Chebyshev polynomials. + + If ``V = chebvander3d(x, y, z, [xdeg, ydeg, zdeg])``, then the columns + of `V` correspond to the elements of a 3-D coefficient array `c` of + shape (xdeg + 1, ydeg + 1, zdeg + 1) in the order + + .. math:: c_{000}, c_{001}, c_{002},... , c_{010}, c_{011}, c_{012},... + + and ``np.dot(V, c.flat)`` and ``chebval3d(x, y, z, c)`` will be the + same up to roundoff. This equivalence is useful both for least squares + fitting and for the evaluation of a large number of 3-D Chebyshev + series of the same degrees and sample points. + + Parameters + ---------- + x, y, z : array_like + Arrays of point coordinates, all of the same shape. The dtypes will + be converted to either float64 or complex128 depending on whether + any of the elements are complex. Scalars are converted to 1-D + arrays. + deg : list of ints + List of maximum degrees of the form [x_deg, y_deg, z_deg]. 
+
+ Returns
+ -------
+ vander3d : ndarray
+ The shape of the returned matrix is ``x.shape + (order,)``, where
+ :math:`order = (deg[0]+1)*(deg[1]+1)*(deg[2]+1)`. The dtype will
+ be the same as the converted `x`, `y`, and `z`.
+
+ See Also
+ --------
+ chebvander, chebvander2d, chebval2d, chebval3d
+
+ Notes
+ -----
+
+ .. versionadded:: 1.7.0
+
+ """
+ return pu._vander_nd_flat((chebvander, chebvander, chebvander), (x, y, z), deg)
+
+
+def chebfit(x, y, deg, rcond=None, full=False, w=None):
+ """
+ Least squares fit of Chebyshev series to data.
+
+ Return the coefficients of a Chebyshev series of degree `deg` that is the
+ least squares fit to the data values `y` given at points `x`. If `y` is
+ 1-D the returned coefficients will also be 1-D. If `y` is 2-D multiple
+ fits are done, one for each column of `y`, and the resulting
+ coefficients are stored in the corresponding columns of a 2-D return.
+ The fitted polynomial(s) are in the form
+
+ .. math:: p(x) = c_0 + c_1 * T_1(x) + ... + c_n * T_n(x),
+
+ where `n` is `deg`.
+
+ Parameters
+ ----------
+ x : array_like, shape (M,)
+ x-coordinates of the M sample points ``(x[i], y[i])``.
+ y : array_like, shape (M,) or (M, K)
+ y-coordinates of the sample points. Several data sets of sample
+ points sharing the same x-coordinates can be fitted at once by
+ passing in a 2D-array that contains one dataset per column.
+ deg : int or 1-D array_like
+ Degree(s) of the fitting polynomials. If `deg` is a single integer,
+ all terms up to and including the `deg`'th term are included in the
+ fit. For NumPy versions >= 1.11.0 a list of integers specifying the
+ degrees of the terms to include may be used instead.
+ rcond : float, optional
+ Relative condition number of the fit. Singular values smaller than
+ this relative to the largest singular value will be ignored. The
+ default value is len(x)*eps, where eps is the relative precision of
+ the float type, about 2e-16 in most cases.
+ full : bool, optional
+ Switch determining nature of return value. When it is False (the
+ default) just the coefficients are returned, when True diagnostic
+ information from the singular value decomposition is also returned.
+ w : array_like, shape (`M`,), optional
+ Weights. If not None, the weight ``w[i]`` applies to the unsquared
+ residual ``y[i] - y_hat[i]`` at ``x[i]``. Ideally the weights are
+ chosen so that the errors of the products ``w[i]*y[i]`` all have the
+ same variance. When using inverse-variance weighting, use
+ ``w[i] = 1/sigma(y[i])``. The default value is None.
+
+ .. versionadded:: 1.5.0
+
+ Returns
+ -------
+ coef : ndarray, shape (`deg` + 1,) or (`deg` + 1, K)
+ Chebyshev coefficients ordered from low to high. If `y` was 2-D,
+ the coefficients for the data in column k of `y` are in column
+ `k`.
+
+ [residuals, rank, singular_values, rcond] : list
+ These values are only returned if ``full == True``
+
+ - residuals -- sum of squared residuals of the least squares fit
+ - rank -- the numerical rank of the scaled Vandermonde matrix
+ - singular_values -- singular values of the scaled Vandermonde matrix
+ - rcond -- value of `rcond`.
+
+ For more details, see `numpy.linalg.lstsq`.
+
+ Warns
+ -----
+ RankWarning
+ The rank of the coefficient matrix in the least-squares fit is
+ deficient. The warning is only raised if ``full == False``.
The + warnings can be turned off by + + >>> import warnings + >>> warnings.simplefilter('ignore', np.RankWarning) + + See Also + -------- + numpy.polynomial.polynomial.polyfit + numpy.polynomial.legendre.legfit + numpy.polynomial.laguerre.lagfit + numpy.polynomial.hermite.hermfit + numpy.polynomial.hermite_e.hermefit + chebval : Evaluates a Chebyshev series. + chebvander : Vandermonde matrix of Chebyshev series. + chebweight : Chebyshev weight function. + numpy.linalg.lstsq : Computes a least-squares fit from the matrix. + scipy.interpolate.UnivariateSpline : Computes spline fits. + + Notes + ----- + The solution is the coefficients of the Chebyshev series `p` that + minimizes the sum of the weighted squared errors + + .. math:: E = \\sum_j w_j^2 * |y_j - p(x_j)|^2, + + where :math:`w_j` are the weights. This problem is solved by setting up + as the (typically) overdetermined matrix equation + + .. math:: V(x) * c = w * y, + + where `V` is the weighted pseudo Vandermonde matrix of `x`, `c` are the + coefficients to be solved for, `w` are the weights, and `y` are the + observed values. This equation is then solved using the singular value + decomposition of `V`. + + If some of the singular values of `V` are so small that they are + neglected, then a `RankWarning` will be issued. This means that the + coefficient values may be poorly determined. Using a lower order fit + will usually get rid of the warning. The `rcond` parameter can also be + set to a value smaller than its default, but the resulting fit may be + spurious and have large contributions from roundoff error. + + Fits using Chebyshev series are usually better conditioned than fits + using power series, but much can depend on the distribution of the + sample points and the smoothness of the data. If the quality of the fit + is inadequate splines may be a good alternative. + + References + ---------- + .. [1] Wikipedia, "Curve fitting", + https://en.wikipedia.org/wiki/Curve_fitting + + Examples + -------- + + """ + return pu._fit(chebvander, x, y, deg, rcond, full, w) + + +def chebcompanion(c): + """Return the scaled companion matrix of c. + + The basis polynomials are scaled so that the companion matrix is + symmetric when `c` is a Chebyshev basis polynomial. This provides + better eigenvalue estimates than the unscaled case and for basis + polynomials the eigenvalues are guaranteed to be real if + `numpy.linalg.eigvalsh` is used to obtain them. + + Parameters + ---------- + c : array_like + 1-D array of Chebyshev series coefficients ordered from low to high + degree. + + Returns + ------- + mat : ndarray + Scaled companion matrix of dimensions (deg, deg). + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + # c is a trimmed copy + [c] = pu.as_series([c]) + if len(c) < 2: + raise ValueError('Series must have maximum degree of at least 1.') + if len(c) == 2: + return np.array([[-c[0]/c[1]]]) + + n = len(c) - 1 + mat = np.zeros((n, n), dtype=c.dtype) + scl = np.array([1.] + [np.sqrt(.5)]*(n-1)) + top = mat.reshape(-1)[1::n+1] + bot = mat.reshape(-1)[n::n+1] + top[0] = np.sqrt(.5) + top[1:] = 1/2 + bot[...] = top + mat[:, -1] -= (c[:-1]/c[-1])*(scl/scl[-1])*.5 + return mat + + +def chebroots(c): + """ + Compute the roots of a Chebyshev series. + + Return the roots (a.k.a. "zeros") of the polynomial + + .. math:: p(x) = \\sum_i c[i] * T_i(x). + + Parameters + ---------- + c : 1-D array_like + 1-D array of coefficients. + + Returns + ------- + out : ndarray + Array of the roots of the series. 
If all the roots are real,
+ then `out` is also real, otherwise it is complex.
+
+ See Also
+ --------
+ numpy.polynomial.polynomial.polyroots
+ numpy.polynomial.legendre.legroots
+ numpy.polynomial.laguerre.lagroots
+ numpy.polynomial.hermite.hermroots
+ numpy.polynomial.hermite_e.hermeroots
+
+ Notes
+ -----
+ The root estimates are obtained as the eigenvalues of the companion
+ matrix. Roots far from the origin of the complex plane may have large
+ errors due to the numerical instability of the series for such
+ values. Roots with multiplicity greater than 1 will also show larger
+ errors as the value of the series near such points is relatively
+ insensitive to errors in the roots. Isolated roots near the origin can
+ be improved by a few iterations of Newton's method.
+
+ The Chebyshev series basis polynomials aren't powers of `x` so the
+ results of this function may seem unintuitive.
+
+ Examples
+ --------
+ >>> import numpy.polynomial.chebyshev as cheb
+ >>> cheb.chebroots((-1, 1,-1, 1)) # T3 - T2 + T1 - T0 has real roots
+ array([ -5.00000000e-01, 2.60860684e-17, 1.00000000e+00]) # may vary
+
+ """
+ # c is a trimmed copy
+ [c] = pu.as_series([c])
+ if len(c) < 2:
+ return np.array([], dtype=c.dtype)
+ if len(c) == 2:
+ return np.array([-c[0]/c[1]])
+
+ # rotated companion matrix reduces error
+ m = chebcompanion(c)[::-1,::-1]
+ r = la.eigvals(m)
+ r.sort()
+ return r
+
+
+def chebinterpolate(func, deg, args=()):
+ """Interpolate a function at the Chebyshev points of the first kind.
+
+ Returns the Chebyshev series that interpolates `func` at the Chebyshev
+ points of the first kind in the interval [-1, 1]. The interpolating
+ series tends to a minmax approximation to `func` with increasing `deg`
+ if the function is continuous in the interval.
+
+ .. versionadded:: 1.14.0
+
+ Parameters
+ ----------
+ func : function
+ The function to be approximated. It must be a function of a single
+ variable of the form ``f(x, a, b, c...)``, where ``a, b, c...`` are
+ extra arguments passed in the `args` parameter.
+ deg : int
+ Degree of the interpolating polynomial.
+ args : tuple, optional
+ Extra arguments to be used in the function call. Default is no extra
+ arguments.
+
+ Returns
+ -------
+ coef : ndarray, shape (deg + 1,)
+ Chebyshev coefficients of the interpolating series ordered from low to
+ high.
+
+ Examples
+ --------
+ >>> import numpy.polynomial.chebyshev as C
+ >>> C.chebinterpolate(lambda x: np.tanh(x) + 0.5, 8)
+ array([ 5.00000000e-01, 8.11675684e-01, -9.86864911e-17,
+ -5.42457905e-02, -2.71387850e-16, 4.51658839e-03,
+ 2.46716228e-17, -3.79694221e-04, -3.26899002e-16])
+
+ Notes
+ -----
+
+ The Chebyshev polynomials used in the interpolation are orthogonal when
+ sampled at the Chebyshev points of the first kind. If it is desired to
+ constrain some of the coefficients they can simply be set to the desired
+ value after the interpolation; no new interpolation or fit is needed. This
+ is especially useful if it is known a priori that some of the coefficients
+ are zero. For instance, if the function is even then the coefficients of
+ the terms of odd degree in the result can be set to zero.
+
+ """
+ deg = np.asarray(deg)
+
+ # check arguments.
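+ # (a sketch of the idea: by the discrete orthogonality of the Chebyshev
+ # polynomials over the points of the first kind, the interpolant's
+ # coefficients are c_j = (2/order)*sum_k f(x_k)*T_j(x_k), with c_0
+ # getting half that weight; the dot product with the pseudo-Vandermonde
+ # matrix below computes exactly these sums.)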
+ if deg.ndim > 0 or deg.dtype.kind not in 'iu' or deg.size == 0: + raise TypeError("deg must be an int") + if deg < 0: + raise ValueError("expected deg >= 0") + + order = deg + 1 + xcheb = chebpts1(order) + yfunc = func(xcheb, *args) + m = chebvander(xcheb, deg) + c = np.dot(m.T, yfunc) + c[0] /= order + c[1:] /= 0.5*order + + return c + + +def chebgauss(deg): + """ + Gauss-Chebyshev quadrature. + + Computes the sample points and weights for Gauss-Chebyshev quadrature. + These sample points and weights will correctly integrate polynomials of + degree :math:`2*deg - 1` or less over the interval :math:`[-1, 1]` with + the weight function :math:`f(x) = 1/\\sqrt{1 - x^2}`. + + Parameters + ---------- + deg : int + Number of sample points and weights. It must be >= 1. + + Returns + ------- + x : ndarray + 1-D ndarray containing the sample points. + y : ndarray + 1-D ndarray containing the weights. + + Notes + ----- + + .. versionadded:: 1.7.0 + + The results have only been tested up to degree 100, higher degrees may + be problematic. For Gauss-Chebyshev there are closed form solutions for + the sample points and weights. If n = `deg`, then + + .. math:: x_i = \\cos(\\pi (2 i - 1) / (2 n)) + + .. math:: w_i = \\pi / n + + """ + ideg = pu._deprecate_as_int(deg, "deg") + if ideg <= 0: + raise ValueError("deg must be a positive integer") + + x = np.cos(np.pi * np.arange(1, 2*ideg, 2) / (2.0*ideg)) + w = np.ones(ideg)*(np.pi/ideg) + + return x, w + + +def chebweight(x): + """ + The weight function of the Chebyshev polynomials. + + The weight function is :math:`1/\\sqrt{1 - x^2}` and the interval of + integration is :math:`[-1, 1]`. The Chebyshev polynomials are + orthogonal, but not normalized, with respect to this weight function. + + Parameters + ---------- + x : array_like + Values at which the weight function will be computed. + + Returns + ------- + w : ndarray + The weight function at `x`. + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + w = 1./(np.sqrt(1. + x) * np.sqrt(1. - x)) + return w + + +def chebpts1(npts): + """ + Chebyshev points of the first kind. + + The Chebyshev points of the first kind are the points ``cos(x)``, + where ``x = [pi*(k + .5)/npts for k in range(npts)]``. + + Parameters + ---------- + npts : int + Number of sample points desired. + + Returns + ------- + pts : ndarray + The Chebyshev points of the first kind. + + See Also + -------- + chebpts2 + + Notes + ----- + + .. versionadded:: 1.5.0 + + """ + _npts = int(npts) + if _npts != npts: + raise ValueError("npts must be integer") + if _npts < 1: + raise ValueError("npts must be >= 1") + + x = 0.5 * np.pi / _npts * np.arange(-_npts+1, _npts+1, 2) + return np.sin(x) + + +def chebpts2(npts): + """ + Chebyshev points of the second kind. + + The Chebyshev points of the second kind are the points ``cos(x)``, + where ``x = [pi*k/(npts - 1) for k in range(npts)]`` sorted in ascending + order. + + Parameters + ---------- + npts : int + Number of sample points desired. + + Returns + ------- + pts : ndarray + The Chebyshev points of the second kind. + + Notes + ----- + + .. versionadded:: 1.5.0 + + """ + _npts = int(npts) + if _npts != npts: + raise ValueError("npts must be integer") + if _npts < 2: + raise ValueError("npts must be >= 2") + + x = np.linspace(-np.pi, 0, _npts) + return np.cos(x) + + +# +# Chebyshev series class +# + +class Chebyshev(ABCPolyBase): + """A Chebyshev series class. 
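+
+ For example (a quick sketch):
+
+ >>> from numpy.polynomial import Chebyshev
+ >>> p = Chebyshev([1, 2, 3]) # 1*T_0 + 2*T_1 + 3*T_2
+ >>> p(0) # T_0(0)=1, T_1(0)=0, T_2(0)=-1
+ -2.0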
+
+ The Chebyshev class provides the standard Python numerical methods
+ '+', '-', '*', '//', '%', 'divmod', '**', and '()' as well as the
+ methods listed below.
+
+ Parameters
+ ----------
+ coef : array_like
+ Chebyshev coefficients in order of increasing degree, i.e.,
+ ``(1, 2, 3)`` gives ``1*T_0(x) + 2*T_1(x) + 3*T_2(x)``.
+ domain : (2,) array_like, optional
+ Domain to use. The interval ``[domain[0], domain[1]]`` is mapped
+ to the interval ``[window[0], window[1]]`` by shifting and scaling.
+ The default value is [-1, 1].
+ window : (2,) array_like, optional
+ Window, see `domain` for its use. The default value is [-1, 1].
+
+ .. versionadded:: 1.6.0
+ symbol : str, optional
+ Symbol used to represent the independent variable in string
+ representations of the polynomial expression, e.g. for printing.
+ The symbol must be a valid Python identifier. Default value is 'x'.
+
+ .. versionadded:: 1.24
+
+ """
+ # Virtual Functions
+ _add = staticmethod(chebadd)
+ _sub = staticmethod(chebsub)
+ _mul = staticmethod(chebmul)
+ _div = staticmethod(chebdiv)
+ _pow = staticmethod(chebpow)
+ _val = staticmethod(chebval)
+ _int = staticmethod(chebint)
+ _der = staticmethod(chebder)
+ _fit = staticmethod(chebfit)
+ _line = staticmethod(chebline)
+ _roots = staticmethod(chebroots)
+ _fromroots = staticmethod(chebfromroots)
+
+ @classmethod
+ def interpolate(cls, func, deg, domain=None, args=()):
+ """Interpolate a function at the Chebyshev points of the first kind.
+
+ Returns the series that interpolates `func` at the Chebyshev points of
+ the first kind scaled and shifted to the `domain`. The resulting series
+ tends to a minmax approximation of `func` when the function is
+ continuous in the domain.
+
+ .. versionadded:: 1.14.0
+
+ Parameters
+ ----------
+ func : function
+ The function to be interpolated. It must be a function of a single
+ variable of the form ``f(x, a, b, c...)``, where ``a, b, c...`` are
+ extra arguments passed in the `args` parameter.
+ deg : int
+ Degree of the interpolating polynomial.
+ domain : {None, [beg, end]}, optional
+ Domain over which `func` is interpolated. The default is None, in
+ which case the domain is [-1, 1].
+ args : tuple, optional
+ Extra arguments to be used in the function call. Default is no
+ extra arguments.
+
+ Returns
+ -------
+ polynomial : Chebyshev instance
+ Interpolating Chebyshev instance.
+
+ Notes
+ -----
+ See `numpy.polynomial.chebyshev.chebinterpolate` for more details.
+
+ """
+ if domain is None:
+ domain = cls.domain
+ xfunc = lambda x: func(pu.mapdomain(x, cls.window, domain), *args)
+ coef = chebinterpolate(xfunc, deg)
+ return cls(coef, domain=domain)
+
+ # Virtual properties
+ domain = np.array(chebdomain)
+ window = np.array(chebdomain)
+ basis_name = 'T'
diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/hermite.pyi b/.venv/lib/python3.11/site-packages/numpy/polynomial/hermite.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..0d3556d696410689b4614138ad4cf1f6c2283a9c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/hermite.pyi
@@ -0,0 +1,46 @@
+from typing import Any
+
+from numpy import ndarray, dtype, int_, float_
+from numpy.polynomial._polybase import ABCPolyBase
+from numpy.polynomial.polyutils import trimcoef
+
+__all__: list[str]
+
+hermtrim = trimcoef
+
+def poly2herm(pol): ...
+def herm2poly(c): ...
+ +hermdomain: ndarray[Any, dtype[int_]] +hermzero: ndarray[Any, dtype[int_]] +hermone: ndarray[Any, dtype[int_]] +hermx: ndarray[Any, dtype[float_]] + +def hermline(off, scl): ... +def hermfromroots(roots): ... +def hermadd(c1, c2): ... +def hermsub(c1, c2): ... +def hermmulx(c): ... +def hermmul(c1, c2): ... +def hermdiv(c1, c2): ... +def hermpow(c, pow, maxpower=...): ... +def hermder(c, m=..., scl=..., axis=...): ... +def hermint(c, m=..., k = ..., lbnd=..., scl=..., axis=...): ... +def hermval(x, c, tensor=...): ... +def hermval2d(x, y, c): ... +def hermgrid2d(x, y, c): ... +def hermval3d(x, y, z, c): ... +def hermgrid3d(x, y, z, c): ... +def hermvander(x, deg): ... +def hermvander2d(x, y, deg): ... +def hermvander3d(x, y, z, deg): ... +def hermfit(x, y, deg, rcond=..., full=..., w=...): ... +def hermcompanion(c): ... +def hermroots(c): ... +def hermgauss(deg): ... +def hermweight(x): ... + +class Hermite(ABCPolyBase): + domain: Any + window: Any + basis_name: Any diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/polynomial.py b/.venv/lib/python3.11/site-packages/numpy/polynomial/polynomial.py new file mode 100644 index 0000000000000000000000000000000000000000..ceadff0bf4ed32f8bbbb9f208bf4d84946efe195 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/polynomial.py @@ -0,0 +1,1542 @@ +""" +================================================= +Power Series (:mod:`numpy.polynomial.polynomial`) +================================================= + +This module provides a number of objects (mostly functions) useful for +dealing with polynomials, including a `Polynomial` class that +encapsulates the usual arithmetic operations. (General information +on how this module represents and works with polynomial objects is in +the docstring for its "parent" sub-package, `numpy.polynomial`). + +Classes +------- +.. autosummary:: + :toctree: generated/ + + Polynomial + +Constants +--------- +.. autosummary:: + :toctree: generated/ + + polydomain + polyzero + polyone + polyx + +Arithmetic +---------- +.. autosummary:: + :toctree: generated/ + + polyadd + polysub + polymulx + polymul + polydiv + polypow + polyval + polyval2d + polyval3d + polygrid2d + polygrid3d + +Calculus +-------- +.. autosummary:: + :toctree: generated/ + + polyder + polyint + +Misc Functions +-------------- +.. autosummary:: + :toctree: generated/ + + polyfromroots + polyroots + polyvalfromroots + polyvander + polyvander2d + polyvander3d + polycompanion + polyfit + polytrim + polyline + +See Also +-------- +`numpy.polynomial` + +""" +__all__ = [ + 'polyzero', 'polyone', 'polyx', 'polydomain', 'polyline', 'polyadd', + 'polysub', 'polymulx', 'polymul', 'polydiv', 'polypow', 'polyval', + 'polyvalfromroots', 'polyder', 'polyint', 'polyfromroots', 'polyvander', + 'polyfit', 'polytrim', 'polyroots', 'Polynomial', 'polyval2d', 'polyval3d', + 'polygrid2d', 'polygrid3d', 'polyvander2d', 'polyvander3d'] + +import numpy as np +import numpy.linalg as la +from numpy.core.multiarray import normalize_axis_index + +from . import polyutils as pu +from ._polybase import ABCPolyBase + +polytrim = pu.trimcoef + +# +# These are constant arrays are of integer type so as to be compatible +# with the widest range of other types, such as Decimal. +# + +# Polynomial default domain. +polydomain = np.array([-1, 1]) + +# Polynomial coefficients representing zero. +polyzero = np.array([0]) + +# Polynomial coefficients representing one. +polyone = np.array([1]) + +# Polynomial coefficients representing the identity x. 
+polyx = np.array([0, 1]) + +# +# Polynomial series functions +# + + +def polyline(off, scl): + """ + Returns an array representing a linear polynomial. + + Parameters + ---------- + off, scl : scalars + The "y-intercept" and "slope" of the line, respectively. + + Returns + ------- + y : ndarray + This module's representation of the linear polynomial ``off + + scl*x``. + + See Also + -------- + numpy.polynomial.chebyshev.chebline + numpy.polynomial.legendre.legline + numpy.polynomial.laguerre.lagline + numpy.polynomial.hermite.hermline + numpy.polynomial.hermite_e.hermeline + + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> P.polyline(1,-1) + array([ 1, -1]) + >>> P.polyval(1, P.polyline(1,-1)) # should be 0 + 0.0 + + """ + if scl != 0: + return np.array([off, scl]) + else: + return np.array([off]) + + +def polyfromroots(roots): + """ + Generate a monic polynomial with given roots. + + Return the coefficients of the polynomial + + .. math:: p(x) = (x - r_0) * (x - r_1) * ... * (x - r_n), + + where the ``r_n`` are the roots specified in `roots`. If a zero has + multiplicity n, then it must appear in `roots` n times. For instance, + if 2 is a root of multiplicity three and 3 is a root of multiplicity 2, + then `roots` looks something like [2, 2, 2, 3, 3]. The roots can appear + in any order. + + If the returned coefficients are `c`, then + + .. math:: p(x) = c_0 + c_1 * x + ... + x^n + + The coefficient of the last term is 1 for monic polynomials in this + form. + + Parameters + ---------- + roots : array_like + Sequence containing the roots. + + Returns + ------- + out : ndarray + 1-D array of the polynomial's coefficients If all the roots are + real, then `out` is also real, otherwise it is complex. (see + Examples below). + + See Also + -------- + numpy.polynomial.chebyshev.chebfromroots + numpy.polynomial.legendre.legfromroots + numpy.polynomial.laguerre.lagfromroots + numpy.polynomial.hermite.hermfromroots + numpy.polynomial.hermite_e.hermefromroots + + Notes + ----- + The coefficients are determined by multiplying together linear factors + of the form ``(x - r_i)``, i.e. + + .. math:: p(x) = (x - r_0) (x - r_1) ... (x - r_n) + + where ``n == len(roots) - 1``; note that this implies that ``1`` is always + returned for :math:`a_n`. + + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> P.polyfromroots((-1,0,1)) # x(x - 1)(x + 1) = x^3 - x + array([ 0., -1., 0., 1.]) + >>> j = complex(0,1) + >>> P.polyfromroots((-j,j)) # complex returned, though values are real + array([1.+0.j, 0.+0.j, 1.+0.j]) + + """ + return pu._fromroots(polyline, polymul, roots) + + +def polyadd(c1, c2): + """ + Add one polynomial to another. + + Returns the sum of two polynomials `c1` + `c2`. The arguments are + sequences of coefficients from lowest order term to highest, i.e., + [1,2,3] represents the polynomial ``1 + 2*x + 3*x**2``. + + Parameters + ---------- + c1, c2 : array_like + 1-D arrays of polynomial coefficients ordered from low to high. + + Returns + ------- + out : ndarray + The coefficient array representing their sum. + + See Also + -------- + polysub, polymulx, polymul, polydiv, polypow + + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> c1 = (1,2,3) + >>> c2 = (3,2,1) + >>> sum = P.polyadd(c1,c2); sum + array([4., 4., 4.]) + >>> P.polyval(2, sum) # 4 + 4(2) + 4(2**2) + 28.0 + + """ + return pu._add(c1, c2) + + +def polysub(c1, c2): + """ + Subtract one polynomial from another. 
+ + Returns the difference of two polynomials `c1` - `c2`. The arguments + are sequences of coefficients from lowest order term to highest, i.e., + [1,2,3] represents the polynomial ``1 + 2*x + 3*x**2``. + + Parameters + ---------- + c1, c2 : array_like + 1-D arrays of polynomial coefficients ordered from low to + high. + + Returns + ------- + out : ndarray + Of coefficients representing their difference. + + See Also + -------- + polyadd, polymulx, polymul, polydiv, polypow + + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> c1 = (1,2,3) + >>> c2 = (3,2,1) + >>> P.polysub(c1,c2) + array([-2., 0., 2.]) + >>> P.polysub(c2,c1) # -P.polysub(c1,c2) + array([ 2., 0., -2.]) + + """ + return pu._sub(c1, c2) + + +def polymulx(c): + """Multiply a polynomial by x. + + Multiply the polynomial `c` by x, where x is the independent + variable. + + + Parameters + ---------- + c : array_like + 1-D array of polynomial coefficients ordered from low to + high. + + Returns + ------- + out : ndarray + Array representing the result of the multiplication. + + See Also + -------- + polyadd, polysub, polymul, polydiv, polypow + + Notes + ----- + + .. versionadded:: 1.5.0 + + """ + # c is a trimmed copy + [c] = pu.as_series([c]) + # The zero series needs special treatment + if len(c) == 1 and c[0] == 0: + return c + + prd = np.empty(len(c) + 1, dtype=c.dtype) + prd[0] = c[0]*0 + prd[1:] = c + return prd + + +def polymul(c1, c2): + """ + Multiply one polynomial by another. + + Returns the product of two polynomials `c1` * `c2`. The arguments are + sequences of coefficients, from lowest order term to highest, e.g., + [1,2,3] represents the polynomial ``1 + 2*x + 3*x**2.`` + + Parameters + ---------- + c1, c2 : array_like + 1-D arrays of coefficients representing a polynomial, relative to the + "standard" basis, and ordered from lowest order term to highest. + + Returns + ------- + out : ndarray + Of the coefficients of their product. + + See Also + -------- + polyadd, polysub, polymulx, polydiv, polypow + + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> c1 = (1,2,3) + >>> c2 = (3,2,1) + >>> P.polymul(c1,c2) + array([ 3., 8., 14., 8., 3.]) + + """ + # c1, c2 are trimmed copies + [c1, c2] = pu.as_series([c1, c2]) + ret = np.convolve(c1, c2) + return pu.trimseq(ret) + + +def polydiv(c1, c2): + """ + Divide one polynomial by another. + + Returns the quotient-with-remainder of two polynomials `c1` / `c2`. + The arguments are sequences of coefficients, from lowest order term + to highest, e.g., [1,2,3] represents ``1 + 2*x + 3*x**2``. + + Parameters + ---------- + c1, c2 : array_like + 1-D arrays of polynomial coefficients ordered from low to high. + + Returns + ------- + [quo, rem] : ndarrays + Of coefficient series representing the quotient and remainder. 
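+ Up to roundoff the pair satisfies
+ ``c1 == polyadd(polymul(quo, c2), rem)``.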
+ + See Also + -------- + polyadd, polysub, polymulx, polymul, polypow + + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> c1 = (1,2,3) + >>> c2 = (3,2,1) + >>> P.polydiv(c1,c2) + (array([3.]), array([-8., -4.])) + >>> P.polydiv(c2,c1) + (array([ 0.33333333]), array([ 2.66666667, 1.33333333])) # may vary + + """ + # c1, c2 are trimmed copies + [c1, c2] = pu.as_series([c1, c2]) + if c2[-1] == 0: + raise ZeroDivisionError() + + # note: this is more efficient than `pu._div(polymul, c1, c2)` + lc1 = len(c1) + lc2 = len(c2) + if lc1 < lc2: + return c1[:1]*0, c1 + elif lc2 == 1: + return c1/c2[-1], c1[:1]*0 + else: + dlen = lc1 - lc2 + scl = c2[-1] + c2 = c2[:-1]/scl + i = dlen + j = lc1 - 1 + while i >= 0: + c1[i:j] -= c2*c1[j] + i -= 1 + j -= 1 + return c1[j+1:]/scl, pu.trimseq(c1[:j+1]) + + +def polypow(c, pow, maxpower=None): + """Raise a polynomial to a power. + + Returns the polynomial `c` raised to the power `pow`. The argument + `c` is a sequence of coefficients ordered from low to high. i.e., + [1,2,3] is the series ``1 + 2*x + 3*x**2.`` + + Parameters + ---------- + c : array_like + 1-D array of array of series coefficients ordered from low to + high degree. + pow : integer + Power to which the series will be raised + maxpower : integer, optional + Maximum power allowed. This is mainly to limit growth of the series + to unmanageable size. Default is 16 + + Returns + ------- + coef : ndarray + Power series of power. + + See Also + -------- + polyadd, polysub, polymulx, polymul, polydiv + + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> P.polypow([1,2,3], 2) + array([ 1., 4., 10., 12., 9.]) + + """ + # note: this is more efficient than `pu._pow(polymul, c1, c2)`, as it + # avoids calling `as_series` repeatedly + return pu._pow(np.convolve, c, pow, maxpower) + + +def polyder(c, m=1, scl=1, axis=0): + """ + Differentiate a polynomial. + + Returns the polynomial coefficients `c` differentiated `m` times along + `axis`. At each iteration the result is multiplied by `scl` (the + scaling factor is for use in a linear change of variable). The + argument `c` is an array of coefficients from low to high degree along + each axis, e.g., [1,2,3] represents the polynomial ``1 + 2*x + 3*x**2`` + while [[1,2],[1,2]] represents ``1 + 1*x + 2*y + 2*x*y`` if axis=0 is + ``x`` and axis=1 is ``y``. + + Parameters + ---------- + c : array_like + Array of polynomial coefficients. If c is multidimensional the + different axis correspond to different variables with the degree + in each axis given by the corresponding index. + m : int, optional + Number of derivatives taken, must be non-negative. (Default: 1) + scl : scalar, optional + Each differentiation is multiplied by `scl`. The end result is + multiplication by ``scl**m``. This is for use in a linear change + of variable. (Default: 1) + axis : int, optional + Axis over which the derivative is taken. (Default: 0). + + .. versionadded:: 1.7.0 + + Returns + ------- + der : ndarray + Polynomial coefficients of the derivative. 
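+        Along `axis` the result is shorter than `c` by `m` (or is the zero
+        series of length 1 when ``m >= c.shape[axis]``).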
+ + See Also + -------- + polyint + + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> c = (1,2,3,4) # 1 + 2x + 3x**2 + 4x**3 + >>> P.polyder(c) # (d/dx)(c) = 2 + 6x + 12x**2 + array([ 2., 6., 12.]) + >>> P.polyder(c,3) # (d**3/dx**3)(c) = 24 + array([24.]) + >>> P.polyder(c,scl=-1) # (d/d(-x))(c) = -2 - 6x - 12x**2 + array([ -2., -6., -12.]) + >>> P.polyder(c,2,-1) # (d**2/d(-x)**2)(c) = 6 + 24x + array([ 6., 24.]) + + """ + c = np.array(c, ndmin=1, copy=True) + if c.dtype.char in '?bBhHiIlLqQpP': + # astype fails with NA + c = c + 0.0 + cdt = c.dtype + cnt = pu._deprecate_as_int(m, "the order of derivation") + iaxis = pu._deprecate_as_int(axis, "the axis") + if cnt < 0: + raise ValueError("The order of derivation must be non-negative") + iaxis = normalize_axis_index(iaxis, c.ndim) + + if cnt == 0: + return c + + c = np.moveaxis(c, iaxis, 0) + n = len(c) + if cnt >= n: + c = c[:1]*0 + else: + for i in range(cnt): + n = n - 1 + c *= scl + der = np.empty((n,) + c.shape[1:], dtype=cdt) + for j in range(n, 0, -1): + der[j - 1] = j*c[j] + c = der + c = np.moveaxis(c, 0, iaxis) + return c + + +def polyint(c, m=1, k=[], lbnd=0, scl=1, axis=0): + """ + Integrate a polynomial. + + Returns the polynomial coefficients `c` integrated `m` times from + `lbnd` along `axis`. At each iteration the resulting series is + **multiplied** by `scl` and an integration constant, `k`, is added. + The scaling factor is for use in a linear change of variable. ("Buyer + beware": note that, depending on what one is doing, one may want `scl` + to be the reciprocal of what one might expect; for more information, + see the Notes section below.) The argument `c` is an array of + coefficients, from low to high degree along each axis, e.g., [1,2,3] + represents the polynomial ``1 + 2*x + 3*x**2`` while [[1,2],[1,2]] + represents ``1 + 1*x + 2*y + 2*x*y`` if axis=0 is ``x`` and axis=1 is + ``y``. + + Parameters + ---------- + c : array_like + 1-D array of polynomial coefficients, ordered from low to high. + m : int, optional + Order of integration, must be positive. (Default: 1) + k : {[], list, scalar}, optional + Integration constant(s). The value of the first integral at zero + is the first value in the list, the value of the second integral + at zero is the second value, etc. If ``k == []`` (the default), + all constants are set to zero. If ``m == 1``, a single scalar can + be given instead of a list. + lbnd : scalar, optional + The lower bound of the integral. (Default: 0) + scl : scalar, optional + Following each integration the result is *multiplied* by `scl` + before the integration constant is added. (Default: 1) + axis : int, optional + Axis over which the integral is taken. (Default: 0). + + .. versionadded:: 1.7.0 + + Returns + ------- + S : ndarray + Coefficient array of the integral. + + Raises + ------ + ValueError + If ``m < 1``, ``len(k) > m``, ``np.ndim(lbnd) != 0``, or + ``np.ndim(scl) != 0``. + + See Also + -------- + polyder + + Notes + ----- + Note that the result of each integration is *multiplied* by `scl`. Why + is this important to note? Say one is making a linear change of + variable :math:`u = ax + b` in an integral relative to `x`. Then + :math:`dx = du/a`, so one will need to set `scl` equal to + :math:`1/a` - perhaps not what one would have first thought. 
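+
+    For example, take :math:`u = 2x`, so that :math:`a = 2` and `scl`
+    should be ``0.5``. Integrating the constant series ``[1]``, i.e.
+    :math:`p(u) = 1`, with respect to `x` then yields :math:`x = u/2`:
+
+    >>> from numpy.polynomial import polynomial as P
+    >>> P.polyint([1], scl=.5)
+    array([0. , 0.5])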
+ + Examples + -------- + >>> from numpy.polynomial import polynomial as P + >>> c = (1,2,3) + >>> P.polyint(c) # should return array([0, 1, 1, 1]) + array([0., 1., 1., 1.]) + >>> P.polyint(c,3) # should return array([0, 0, 0, 1/6, 1/12, 1/20]) + array([ 0. , 0. , 0. , 0.16666667, 0.08333333, # may vary + 0.05 ]) + >>> P.polyint(c,k=3) # should return array([3, 1, 1, 1]) + array([3., 1., 1., 1.]) + >>> P.polyint(c,lbnd=-2) # should return array([6, 1, 1, 1]) + array([6., 1., 1., 1.]) + >>> P.polyint(c,scl=-2) # should return array([0, -2, -2, -2]) + array([ 0., -2., -2., -2.]) + + """ + c = np.array(c, ndmin=1, copy=True) + if c.dtype.char in '?bBhHiIlLqQpP': + # astype doesn't preserve mask attribute. + c = c + 0.0 + cdt = c.dtype + if not np.iterable(k): + k = [k] + cnt = pu._deprecate_as_int(m, "the order of integration") + iaxis = pu._deprecate_as_int(axis, "the axis") + if cnt < 0: + raise ValueError("The order of integration must be non-negative") + if len(k) > cnt: + raise ValueError("Too many integration constants") + if np.ndim(lbnd) != 0: + raise ValueError("lbnd must be a scalar.") + if np.ndim(scl) != 0: + raise ValueError("scl must be a scalar.") + iaxis = normalize_axis_index(iaxis, c.ndim) + + if cnt == 0: + return c + + k = list(k) + [0]*(cnt - len(k)) + c = np.moveaxis(c, iaxis, 0) + for i in range(cnt): + n = len(c) + c *= scl + if n == 1 and np.all(c[0] == 0): + c[0] += k[i] + else: + tmp = np.empty((n + 1,) + c.shape[1:], dtype=cdt) + tmp[0] = c[0]*0 + tmp[1] = c[0] + for j in range(1, n): + tmp[j + 1] = c[j]/(j + 1) + tmp[0] += k[i] - polyval(lbnd, tmp) + c = tmp + c = np.moveaxis(c, 0, iaxis) + return c + + +def polyval(x, c, tensor=True): + """ + Evaluate a polynomial at points x. + + If `c` is of length `n + 1`, this function returns the value + + .. math:: p(x) = c_0 + c_1 * x + ... + c_n * x^n + + The parameter `x` is converted to an array only if it is a tuple or a + list, otherwise it is treated as a scalar. In either case, either `x` + or its elements must support multiplication and addition both with + themselves and with the elements of `c`. + + If `c` is a 1-D array, then `p(x)` will have the same shape as `x`. If + `c` is multidimensional, then the shape of the result depends on the + value of `tensor`. If `tensor` is true the shape will be c.shape[1:] + + x.shape. If `tensor` is false the shape will be c.shape[1:]. Note that + scalars have shape (,). + + Trailing zeros in the coefficients will be used in the evaluation, so + they should be avoided if efficiency is a concern. + + Parameters + ---------- + x : array_like, compatible object + If `x` is a list or tuple, it is converted to an ndarray, otherwise + it is left unchanged and treated as a scalar. In either case, `x` + or its elements must support addition and multiplication with + with themselves and with the elements of `c`. + c : array_like + Array of coefficients ordered so that the coefficients for terms of + degree n are contained in c[n]. If `c` is multidimensional the + remaining indices enumerate multiple polynomials. In the two + dimensional case the coefficients may be thought of as stored in + the columns of `c`. + tensor : boolean, optional + If True, the shape of the coefficient array is extended with ones + on the right, one for each dimension of `x`. Scalars have dimension 0 + for this action. The result is that every column of coefficients in + `c` is evaluated for every element of `x`. If False, `x` is broadcast + over the columns of `c` for the evaluation. 
This keyword is useful + when `c` is multidimensional. The default value is True. + + .. versionadded:: 1.7.0 + + Returns + ------- + values : ndarray, compatible object + The shape of the returned array is described above. + + See Also + -------- + polyval2d, polygrid2d, polyval3d, polygrid3d + + Notes + ----- + The evaluation uses Horner's method. + + Examples + -------- + >>> from numpy.polynomial.polynomial import polyval + >>> polyval(1, [1,2,3]) + 6.0 + >>> a = np.arange(4).reshape(2,2) + >>> a + array([[0, 1], + [2, 3]]) + >>> polyval(a, [1,2,3]) + array([[ 1., 6.], + [17., 34.]]) + >>> coef = np.arange(4).reshape(2,2) # multidimensional coefficients + >>> coef + array([[0, 1], + [2, 3]]) + >>> polyval([1,2], coef, tensor=True) + array([[2., 4.], + [4., 7.]]) + >>> polyval([1,2], coef, tensor=False) + array([2., 7.]) + + """ + c = np.array(c, ndmin=1, copy=False) + if c.dtype.char in '?bBhHiIlLqQpP': + # astype fails with NA + c = c + 0.0 + if isinstance(x, (tuple, list)): + x = np.asarray(x) + if isinstance(x, np.ndarray) and tensor: + c = c.reshape(c.shape + (1,)*x.ndim) + + c0 = c[-1] + x*0 + for i in range(2, len(c) + 1): + c0 = c[-i] + c0*x + return c0 + + +def polyvalfromroots(x, r, tensor=True): + """ + Evaluate a polynomial specified by its roots at points x. + + If `r` is of length `N`, this function returns the value + + .. math:: p(x) = \\prod_{n=1}^{N} (x - r_n) + + The parameter `x` is converted to an array only if it is a tuple or a + list, otherwise it is treated as a scalar. In either case, either `x` + or its elements must support multiplication and addition both with + themselves and with the elements of `r`. + + If `r` is a 1-D array, then `p(x)` will have the same shape as `x`. If `r` + is multidimensional, then the shape of the result depends on the value of + `tensor`. If `tensor` is ``True`` the shape will be r.shape[1:] + x.shape; + that is, each polynomial is evaluated at every value of `x`. If `tensor` is + ``False``, the shape will be r.shape[1:]; that is, each polynomial is + evaluated only for the corresponding broadcast value of `x`. Note that + scalars have shape (,). + + .. versionadded:: 1.12 + + Parameters + ---------- + x : array_like, compatible object + If `x` is a list or tuple, it is converted to an ndarray, otherwise + it is left unchanged and treated as a scalar. In either case, `x` + or its elements must support addition and multiplication with + with themselves and with the elements of `r`. + r : array_like + Array of roots. If `r` is multidimensional the first index is the + root index, while the remaining indices enumerate multiple + polynomials. For instance, in the two dimensional case the roots + of each polynomial may be thought of as stored in the columns of `r`. + tensor : boolean, optional + If True, the shape of the roots array is extended with ones on the + right, one for each dimension of `x`. Scalars have dimension 0 for this + action. The result is that every column of coefficients in `r` is + evaluated for every element of `x`. If False, `x` is broadcast over the + columns of `r` for the evaluation. This keyword is useful when `r` is + multidimensional. The default value is True. + + Returns + ------- + values : ndarray, compatible object + The shape of the returned array is described above. 
+ + See Also + -------- + polyroots, polyfromroots, polyval + + Examples + -------- + >>> from numpy.polynomial.polynomial import polyvalfromroots + >>> polyvalfromroots(1, [1,2,3]) + 0.0 + >>> a = np.arange(4).reshape(2,2) + >>> a + array([[0, 1], + [2, 3]]) + >>> polyvalfromroots(a, [-1, 0, 1]) + array([[-0., 0.], + [ 6., 24.]]) + >>> r = np.arange(-2, 2).reshape(2,2) # multidimensional coefficients + >>> r # each column of r defines one polynomial + array([[-2, -1], + [ 0, 1]]) + >>> b = [-2, 1] + >>> polyvalfromroots(b, r, tensor=True) + array([[-0., 3.], + [ 3., 0.]]) + >>> polyvalfromroots(b, r, tensor=False) + array([-0., 0.]) + """ + r = np.array(r, ndmin=1, copy=False) + if r.dtype.char in '?bBhHiIlLqQpP': + r = r.astype(np.double) + if isinstance(x, (tuple, list)): + x = np.asarray(x) + if isinstance(x, np.ndarray): + if tensor: + r = r.reshape(r.shape + (1,)*x.ndim) + elif x.ndim >= r.ndim: + raise ValueError("x.ndim must be < r.ndim when tensor == False") + return np.prod(x - r, axis=0) + + +def polyval2d(x, y, c): + """ + Evaluate a 2-D polynomial at points (x, y). + + This function returns the value + + .. math:: p(x,y) = \\sum_{i,j} c_{i,j} * x^i * y^j + + The parameters `x` and `y` are converted to arrays only if they are + tuples or a lists, otherwise they are treated as a scalars and they + must have the same shape after conversion. In either case, either `x` + and `y` or their elements must support multiplication and addition both + with themselves and with the elements of `c`. + + If `c` has fewer than two dimensions, ones are implicitly appended to + its shape to make it 2-D. The shape of the result will be c.shape[2:] + + x.shape. + + Parameters + ---------- + x, y : array_like, compatible objects + The two dimensional series is evaluated at the points `(x, y)`, + where `x` and `y` must have the same shape. If `x` or `y` is a list + or tuple, it is first converted to an ndarray, otherwise it is left + unchanged and, if it isn't an ndarray, it is treated as a scalar. + c : array_like + Array of coefficients ordered so that the coefficient of the term + of multi-degree i,j is contained in `c[i,j]`. If `c` has + dimension greater than two the remaining indices enumerate multiple + sets of coefficients. + + Returns + ------- + values : ndarray, compatible object + The values of the two dimensional polynomial at points formed with + pairs of corresponding values from `x` and `y`. + + See Also + -------- + polyval, polygrid2d, polyval3d, polygrid3d + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + return pu._valnd(polyval, c, x, y) + + +def polygrid2d(x, y, c): + """ + Evaluate a 2-D polynomial on the Cartesian product of x and y. + + This function returns the values: + + .. math:: p(a,b) = \\sum_{i,j} c_{i,j} * a^i * b^j + + where the points `(a, b)` consist of all pairs formed by taking + `a` from `x` and `b` from `y`. The resulting points form a grid with + `x` in the first dimension and `y` in the second. + + The parameters `x` and `y` are converted to arrays only if they are + tuples or a lists, otherwise they are treated as a scalars. In either + case, either `x` and `y` or their elements must support multiplication + and addition both with themselves and with the elements of `c`. + + If `c` has fewer than two dimensions, ones are implicitly appended to + its shape to make it 2-D. The shape of the result will be c.shape[2:] + + x.shape + y.shape. 
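+
+    For example, if ``c`` has shape ``(3, 4)`` while ``x`` and ``y`` have
+    shapes ``(5,)`` and ``(6,)``, the result has shape ``(5, 6)``:
+
+    >>> import numpy as np
+    >>> from numpy.polynomial import polynomial as P
+    >>> P.polygrid2d(np.ones(5), np.ones(6), np.ones((3, 4))).shape
+    (5, 6)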
+ + Parameters + ---------- + x, y : array_like, compatible objects + The two dimensional series is evaluated at the points in the + Cartesian product of `x` and `y`. If `x` or `y` is a list or + tuple, it is first converted to an ndarray, otherwise it is left + unchanged and, if it isn't an ndarray, it is treated as a scalar. + c : array_like + Array of coefficients ordered so that the coefficients for terms of + degree i,j are contained in ``c[i,j]``. If `c` has dimension + greater than two the remaining indices enumerate multiple sets of + coefficients. + + Returns + ------- + values : ndarray, compatible object + The values of the two dimensional polynomial at points in the Cartesian + product of `x` and `y`. + + See Also + -------- + polyval, polyval2d, polyval3d, polygrid3d + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + return pu._gridnd(polyval, c, x, y) + + +def polyval3d(x, y, z, c): + """ + Evaluate a 3-D polynomial at points (x, y, z). + + This function returns the values: + + .. math:: p(x,y,z) = \\sum_{i,j,k} c_{i,j,k} * x^i * y^j * z^k + + The parameters `x`, `y`, and `z` are converted to arrays only if + they are tuples or a lists, otherwise they are treated as a scalars and + they must have the same shape after conversion. In either case, either + `x`, `y`, and `z` or their elements must support multiplication and + addition both with themselves and with the elements of `c`. + + If `c` has fewer than 3 dimensions, ones are implicitly appended to its + shape to make it 3-D. The shape of the result will be c.shape[3:] + + x.shape. + + Parameters + ---------- + x, y, z : array_like, compatible object + The three dimensional series is evaluated at the points + `(x, y, z)`, where `x`, `y`, and `z` must have the same shape. If + any of `x`, `y`, or `z` is a list or tuple, it is first converted + to an ndarray, otherwise it is left unchanged and if it isn't an + ndarray it is treated as a scalar. + c : array_like + Array of coefficients ordered so that the coefficient of the term of + multi-degree i,j,k is contained in ``c[i,j,k]``. If `c` has dimension + greater than 3 the remaining indices enumerate multiple sets of + coefficients. + + Returns + ------- + values : ndarray, compatible object + The values of the multidimensional polynomial on points formed with + triples of corresponding values from `x`, `y`, and `z`. + + See Also + -------- + polyval, polyval2d, polygrid2d, polygrid3d + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + return pu._valnd(polyval, c, x, y, z) + + +def polygrid3d(x, y, z, c): + """ + Evaluate a 3-D polynomial on the Cartesian product of x, y and z. + + This function returns the values: + + .. math:: p(a,b,c) = \\sum_{i,j,k} c_{i,j,k} * a^i * b^j * c^k + + where the points `(a, b, c)` consist of all triples formed by taking + `a` from `x`, `b` from `y`, and `c` from `z`. The resulting points form + a grid with `x` in the first dimension, `y` in the second, and `z` in + the third. + + The parameters `x`, `y`, and `z` are converted to arrays only if they + are tuples or a lists, otherwise they are treated as a scalars. In + either case, either `x`, `y`, and `z` or their elements must support + multiplication and addition both with themselves and with the elements + of `c`. + + If `c` has fewer than three dimensions, ones are implicitly appended to + its shape to make it 3-D. The shape of the result will be c.shape[3:] + + x.shape + y.shape + z.shape. 
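+
+    In particular, for 1-D `x`, `y`, and `z` and a 3-D `c`, element
+    ``[i, j, k]`` of the result equals ``polyval3d(x[i], y[j], z[k], c)``.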
+
+    Parameters
+    ----------
+    x, y, z : array_like, compatible objects
+        The three dimensional series is evaluated at the points in the
+        Cartesian product of `x`, `y`, and `z`. If `x`, `y`, or `z` is a
+        list or tuple, it is first converted to an ndarray, otherwise it is
+        left unchanged and, if it isn't an ndarray, it is treated as a
+        scalar.
+    c : array_like
+        Array of coefficients ordered so that the coefficients for terms of
+        degree i,j,k are contained in ``c[i,j,k]``. If `c` has dimension
+        greater than three the remaining indices enumerate multiple sets of
+        coefficients.
+
+    Returns
+    -------
+    values : ndarray, compatible object
+        The values of the three dimensional polynomial at points in the
+        Cartesian product of `x`, `y`, and `z`.
+
+    See Also
+    --------
+    polyval, polyval2d, polygrid2d, polyval3d
+
+    Notes
+    -----
+
+    .. versionadded:: 1.7.0
+
+    """
+    return pu._gridnd(polyval, c, x, y, z)
+
+
+def polyvander(x, deg):
+    """Vandermonde matrix of given degree.
+
+    Returns the Vandermonde matrix of degree `deg` and sample points
+    `x`. The Vandermonde matrix is defined by
+
+    .. math:: V[..., i] = x^i,
+
+    where `0 <= i <= deg`. The leading indices of `V` index the elements of
+    `x` and the last index is the power of `x`.
+
+    If `c` is a 1-D array of coefficients of length `n + 1` and `V` is the
+    matrix ``V = polyvander(x, n)``, then ``np.dot(V, c)`` and
+    ``polyval(x, c)`` are the same up to roundoff. This equivalence is
+    useful both for least squares fitting and for the evaluation of a large
+    number of polynomials of the same degree and sample points.
+
+    Parameters
+    ----------
+    x : array_like
+        Array of points. The dtype is converted to float64 or complex128
+        depending on whether any of the elements are complex. If `x` is
+        scalar it is converted to a 1-D array.
+    deg : int
+        Degree of the resulting matrix.
+
+    Returns
+    -------
+    vander : ndarray
+        The Vandermonde matrix. The shape of the returned matrix is
+        ``x.shape + (deg + 1,)``, where the last index is the power of `x`.
+        The dtype will be the same as the converted `x`.
+
+    See Also
+    --------
+    polyvander2d, polyvander3d
+
+    """
+    ideg = pu._deprecate_as_int(deg, "deg")
+    if ideg < 0:
+        raise ValueError("deg must be non-negative")
+
+    x = np.array(x, copy=False, ndmin=1) + 0.0
+    dims = (ideg + 1,) + x.shape
+    dtyp = x.dtype
+    v = np.empty(dims, dtype=dtyp)
+    v[0] = x*0 + 1
+    if ideg > 0:
+        v[1] = x
+        for i in range(2, ideg + 1):
+            v[i] = v[i-1]*x
+    return np.moveaxis(v, 0, -1)
+
+
+def polyvander2d(x, y, deg):
+    """Pseudo-Vandermonde matrix of given degrees.
+
+    Returns the pseudo-Vandermonde matrix of degrees `deg` and sample
+    points `(x, y)`. The pseudo-Vandermonde matrix is defined by
+
+    .. math:: V[..., (deg[1] + 1)*i + j] = x^i * y^j,
+
+    where `0 <= i <= deg[0]` and `0 <= j <= deg[1]`. The leading indices of
+    `V` index the points `(x, y)` and the last index encodes the powers of
+    `x` and `y`.
+
+    If ``V = polyvander2d(x, y, [xdeg, ydeg])``, then the columns of `V`
+    correspond to the elements of a 2-D coefficient array `c` of shape
+    (xdeg + 1, ydeg + 1) in the order
+
+    .. math:: c_{00}, c_{01}, c_{02} ... , c_{10}, c_{11}, c_{12} ...
+
+    and ``np.dot(V, c.flat)`` and ``polyval2d(x, y, c)`` will be the same
+    up to roundoff. This equivalence is useful both for least squares
+    fitting and for the evaluation of a large number of 2-D polynomials
+    of the same degrees and sample points.
+
+    Parameters
+    ----------
+    x, y : array_like
+        Arrays of point coordinates, all of the same shape.
The dtypes
+        will be converted to either float64 or complex128 depending on
+        whether any of the elements are complex. Scalars are converted to
+        1-D arrays.
+    deg : list of ints
+        List of maximum degrees of the form [x_deg, y_deg].
+
+    Returns
+    -------
+    vander2d : ndarray
+        The shape of the returned matrix is ``x.shape + (order,)``, where
+        :math:`order = (deg[0]+1)*(deg[1]+1)`. The dtype will be the same
+        as the converted `x` and `y`.
+
+    See Also
+    --------
+    polyvander, polyvander3d, polyval2d, polyval3d
+
+    """
+    return pu._vander_nd_flat((polyvander, polyvander), (x, y), deg)
+
+
+def polyvander3d(x, y, z, deg):
+    """Pseudo-Vandermonde matrix of given degrees.
+
+    Returns the pseudo-Vandermonde matrix of degrees `deg` and sample
+    points `(x, y, z)`. If `l, m, n` are the given degrees in `x, y, z`,
+    then the pseudo-Vandermonde matrix is defined by
+
+    .. math:: V[..., (m+1)(n+1)i + (n+1)j + k] = x^i * y^j * z^k,
+
+    where `0 <= i <= l`, `0 <= j <= m`, and `0 <= k <= n`. The leading
+    indices of `V` index the points `(x, y, z)` and the last index encodes
+    the powers of `x`, `y`, and `z`.
+
+    If ``V = polyvander3d(x, y, z, [xdeg, ydeg, zdeg])``, then the columns
+    of `V` correspond to the elements of a 3-D coefficient array `c` of
+    shape (xdeg + 1, ydeg + 1, zdeg + 1) in the order
+
+    .. math:: c_{000}, c_{001}, c_{002},... , c_{010}, c_{011}, c_{012},...
+
+    and ``np.dot(V, c.flat)`` and ``polyval3d(x, y, z, c)`` will be the
+    same up to roundoff. This equivalence is useful both for least squares
+    fitting and for the evaluation of a large number of 3-D polynomials
+    of the same degrees and sample points.
+
+    Parameters
+    ----------
+    x, y, z : array_like
+        Arrays of point coordinates, all of the same shape. The dtypes will
+        be converted to either float64 or complex128 depending on whether
+        any of the elements are complex. Scalars are converted to 1-D
+        arrays.
+    deg : list of ints
+        List of maximum degrees of the form [x_deg, y_deg, z_deg].
+
+    Returns
+    -------
+    vander3d : ndarray
+        The shape of the returned matrix is ``x.shape + (order,)``, where
+        :math:`order = (deg[0]+1)*(deg[1]+1)*(deg[2]+1)`. The dtype will
+        be the same as the converted `x`, `y`, and `z`.
+
+    See Also
+    --------
+    polyvander, polyvander2d, polyval2d, polyval3d
+
+    Notes
+    -----
+
+    .. versionadded:: 1.7.0
+
+    """
+    return pu._vander_nd_flat((polyvander, polyvander, polyvander), (x, y, z), deg)
+
+
+def polyfit(x, y, deg, rcond=None, full=False, w=None):
+    """
+    Least-squares fit of a polynomial to data.
+
+    Return the coefficients of a polynomial of degree `deg` that is the
+    least squares fit to the data values `y` given at points `x`. If `y` is
+    1-D the returned coefficients will also be 1-D. If `y` is 2-D multiple
+    fits are done, one for each column of `y`, and the resulting
+    coefficients are stored in the corresponding columns of a 2-D return.
+    The fitted polynomial(s) are in the form
+
+    .. math:: p(x) = c_0 + c_1 * x + ... + c_n * x^n,
+
+    where `n` is `deg`.
+
+    Parameters
+    ----------
+    x : array_like, shape (`M`,)
+        x-coordinates of the `M` sample (data) points ``(x[i], y[i])``.
+    y : array_like, shape (`M`,) or (`M`, `K`)
+        y-coordinates of the sample points. Several sets of sample points
+        sharing the same x-coordinates can be (independently) fit with one
+        call to `polyfit` by passing in for `y` a 2-D array that contains
+        one data set per column.
+    deg : int or 1-D array_like
+        Degree(s) of the fitting polynomials.
If `deg` is a single integer + all terms up to and including the `deg`'th term are included in the + fit. For NumPy versions >= 1.11.0 a list of integers specifying the + degrees of the terms to include may be used instead. + rcond : float, optional + Relative condition number of the fit. Singular values smaller + than `rcond`, relative to the largest singular value, will be + ignored. The default value is ``len(x)*eps``, where `eps` is the + relative precision of the platform's float type, about 2e-16 in + most cases. + full : bool, optional + Switch determining the nature of the return value. When ``False`` + (the default) just the coefficients are returned; when ``True``, + diagnostic information from the singular value decomposition (used + to solve the fit's matrix equation) is also returned. + w : array_like, shape (`M`,), optional + Weights. If not None, the weight ``w[i]`` applies to the unsquared + residual ``y[i] - y_hat[i]`` at ``x[i]``. Ideally the weights are + chosen so that the errors of the products ``w[i]*y[i]`` all have the + same variance. When using inverse-variance weighting, use + ``w[i] = 1/sigma(y[i])``. The default value is None. + + .. versionadded:: 1.5.0 + + Returns + ------- + coef : ndarray, shape (`deg` + 1,) or (`deg` + 1, `K`) + Polynomial coefficients ordered from low to high. If `y` was 2-D, + the coefficients in column `k` of `coef` represent the polynomial + fit to the data in `y`'s `k`-th column. + + [residuals, rank, singular_values, rcond] : list + These values are only returned if ``full == True`` + + - residuals -- sum of squared residuals of the least squares fit + - rank -- the numerical rank of the scaled Vandermonde matrix + - singular_values -- singular values of the scaled Vandermonde matrix + - rcond -- value of `rcond`. + + For more details, see `numpy.linalg.lstsq`. + + Raises + ------ + RankWarning + Raised if the matrix in the least-squares fit is rank deficient. + The warning is only raised if ``full == False``. The warnings can + be turned off by: + + >>> import warnings + >>> warnings.simplefilter('ignore', np.RankWarning) + + See Also + -------- + numpy.polynomial.chebyshev.chebfit + numpy.polynomial.legendre.legfit + numpy.polynomial.laguerre.lagfit + numpy.polynomial.hermite.hermfit + numpy.polynomial.hermite_e.hermefit + polyval : Evaluates a polynomial. + polyvander : Vandermonde matrix for powers. + numpy.linalg.lstsq : Computes a least-squares fit from the matrix. + scipy.interpolate.UnivariateSpline : Computes spline fits. + + Notes + ----- + The solution is the coefficients of the polynomial `p` that minimizes + the sum of the weighted squared errors + + .. math:: E = \\sum_j w_j^2 * |y_j - p(x_j)|^2, + + where the :math:`w_j` are the weights. This problem is solved by + setting up the (typically) over-determined matrix equation: + + .. math:: V(x) * c = w * y, + + where `V` is the weighted pseudo Vandermonde matrix of `x`, `c` are the + coefficients to be solved for, `w` are the weights, and `y` are the + observed values. This equation is then solved using the singular value + decomposition of `V`. + + If some of the singular values of `V` are so small that they are + neglected (and `full` == ``False``), a `RankWarning` will be raised. + This means that the coefficient values may be poorly determined. 
+ Fitting to a lower order polynomial will usually get rid of the warning + (but may not be what you want, of course; if you have independent + reason(s) for choosing the degree which isn't working, you may have to: + a) reconsider those reasons, and/or b) reconsider the quality of your + data). The `rcond` parameter can also be set to a value smaller than + its default, but the resulting fit may be spurious and have large + contributions from roundoff error. + + Polynomial fits using double precision tend to "fail" at about + (polynomial) degree 20. Fits using Chebyshev or Legendre series are + generally better conditioned, but much can still depend on the + distribution of the sample points and the smoothness of the data. If + the quality of the fit is inadequate, splines may be a good + alternative. + + Examples + -------- + >>> np.random.seed(123) + >>> from numpy.polynomial import polynomial as P + >>> x = np.linspace(-1,1,51) # x "data": [-1, -0.96, ..., 0.96, 1] + >>> y = x**3 - x + np.random.randn(len(x)) # x^3 - x + Gaussian noise + >>> c, stats = P.polyfit(x,y,3,full=True) + >>> np.random.seed(123) + >>> c # c[0], c[2] should be approx. 0, c[1] approx. -1, c[3] approx. 1 + array([ 0.01909725, -1.30598256, -0.00577963, 1.02644286]) # may vary + >>> stats # note the large SSR, explaining the rather poor results + [array([ 38.06116253]), 4, array([ 1.38446749, 1.32119158, 0.50443316, # may vary + 0.28853036]), 1.1324274851176597e-014] + + Same thing without the added noise + + >>> y = x**3 - x + >>> c, stats = P.polyfit(x,y,3,full=True) + >>> c # c[0], c[2] should be "very close to 0", c[1] ~= -1, c[3] ~= 1 + array([-6.36925336e-18, -1.00000000e+00, -4.08053781e-16, 1.00000000e+00]) + >>> stats # note the minuscule SSR + [array([ 7.46346754e-31]), 4, array([ 1.38446749, 1.32119158, # may vary + 0.50443316, 0.28853036]), 1.1324274851176597e-014] + + """ + return pu._fit(polyvander, x, y, deg, rcond, full, w) + + +def polycompanion(c): + """ + Return the companion matrix of c. + + The companion matrix for power series cannot be made symmetric by + scaling the basis, so this function differs from those for the + orthogonal polynomials. + + Parameters + ---------- + c : array_like + 1-D array of polynomial coefficients ordered from low to high + degree. + + Returns + ------- + mat : ndarray + Companion matrix of dimensions (deg, deg). + + Notes + ----- + + .. versionadded:: 1.7.0 + + """ + # c is a trimmed copy + [c] = pu.as_series([c]) + if len(c) < 2: + raise ValueError('Series must have maximum degree of at least 1.') + if len(c) == 2: + return np.array([[-c[0]/c[1]]]) + + n = len(c) - 1 + mat = np.zeros((n, n), dtype=c.dtype) + bot = mat.reshape(-1)[n::n+1] + bot[...] = 1 + mat[:, -1] -= c[:-1]/c[-1] + return mat + + +def polyroots(c): + """ + Compute the roots of a polynomial. + + Return the roots (a.k.a. "zeros") of the polynomial + + .. math:: p(x) = \\sum_i c[i] * x^i. + + Parameters + ---------- + c : 1-D array_like + 1-D array of polynomial coefficients. + + Returns + ------- + out : ndarray + Array of the roots of the polynomial. If all the roots are real, + then `out` is also real, otherwise it is complex. 
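+        The roots are returned sorted (complex roots are ordered first by
+        real part, then by imaginary part).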
+ + See Also + -------- + numpy.polynomial.chebyshev.chebroots + numpy.polynomial.legendre.legroots + numpy.polynomial.laguerre.lagroots + numpy.polynomial.hermite.hermroots + numpy.polynomial.hermite_e.hermeroots + + Notes + ----- + The root estimates are obtained as the eigenvalues of the companion + matrix, Roots far from the origin of the complex plane may have large + errors due to the numerical instability of the power series for such + values. Roots with multiplicity greater than 1 will also show larger + errors as the value of the series near such points is relatively + insensitive to errors in the roots. Isolated roots near the origin can + be improved by a few iterations of Newton's method. + + Examples + -------- + >>> import numpy.polynomial.polynomial as poly + >>> poly.polyroots(poly.polyfromroots((-1,0,1))) + array([-1., 0., 1.]) + >>> poly.polyroots(poly.polyfromroots((-1,0,1))).dtype + dtype('float64') + >>> j = complex(0,1) + >>> poly.polyroots(poly.polyfromroots((-j,0,j))) + array([ 0.00000000e+00+0.j, 0.00000000e+00+1.j, 2.77555756e-17-1.j]) # may vary + + """ + # c is a trimmed copy + [c] = pu.as_series([c]) + if len(c) < 2: + return np.array([], dtype=c.dtype) + if len(c) == 2: + return np.array([-c[0]/c[1]]) + + # rotated companion matrix reduces error + m = polycompanion(c)[::-1,::-1] + r = la.eigvals(m) + r.sort() + return r + + +# +# polynomial class +# + +class Polynomial(ABCPolyBase): + """A power series class. + + The Polynomial class provides the standard Python numerical methods + '+', '-', '*', '//', '%', 'divmod', '**', and '()' as well as the + attributes and methods listed in the `ABCPolyBase` documentation. + + Parameters + ---------- + coef : array_like + Polynomial coefficients in order of increasing degree, i.e., + ``(1, 2, 3)`` give ``1 + 2*x + 3*x**2``. + domain : (2,) array_like, optional + Domain to use. The interval ``[domain[0], domain[1]]`` is mapped + to the interval ``[window[0], window[1]]`` by shifting and scaling. + The default value is [-1, 1]. + window : (2,) array_like, optional + Window, see `domain` for its use. The default value is [-1, 1]. + + .. versionadded:: 1.6.0 + symbol : str, optional + Symbol used to represent the independent variable in string + representations of the polynomial expression, e.g. for printing. + The symbol must be a valid Python identifier. Default value is 'x'. + + .. 
versionadded:: 1.24 + + """ + # Virtual Functions + _add = staticmethod(polyadd) + _sub = staticmethod(polysub) + _mul = staticmethod(polymul) + _div = staticmethod(polydiv) + _pow = staticmethod(polypow) + _val = staticmethod(polyval) + _int = staticmethod(polyint) + _der = staticmethod(polyder) + _fit = staticmethod(polyfit) + _line = staticmethod(polyline) + _roots = staticmethod(polyroots) + _fromroots = staticmethod(polyfromroots) + + # Virtual properties + domain = np.array(polydomain) + window = np.array(polydomain) + basis_name = None + + @classmethod + def _str_term_unicode(cls, i, arg_str): + if i == '1': + return f"·{arg_str}" + else: + return f"·{arg_str}{i.translate(cls._superscript_mapping)}" + + @staticmethod + def _str_term_ascii(i, arg_str): + if i == '1': + return f" {arg_str}" + else: + return f" {arg_str}**{i}" + + @staticmethod + def _repr_latex_term(i, arg_str, needs_parens): + if needs_parens: + arg_str = rf"\left({arg_str}\right)" + if i == 0: + return '1' + elif i == 1: + return arg_str + else: + return f"{arg_str}^{{{i}}}" diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/polyutils.py b/.venv/lib/python3.11/site-packages/numpy/polynomial/polyutils.py new file mode 100644 index 0000000000000000000000000000000000000000..4829138920169efc5b18b20be4a7d7c9509ba7fb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/polyutils.py @@ -0,0 +1,789 @@ +""" +Utility classes and functions for the polynomial modules. + +This module provides: error and warning objects; a polynomial base class; +and some routines used in both the `polynomial` and `chebyshev` modules. + +Warning objects +--------------- + +.. autosummary:: + :toctree: generated/ + + RankWarning raised in least-squares fit for rank-deficient matrix. + +Functions +--------- + +.. autosummary:: + :toctree: generated/ + + as_series convert list of array_likes into 1-D arrays of common type. + trimseq remove trailing zeros. + trimcoef remove small trailing coefficients. + getdomain return the domain appropriate for a given set of abscissae. + mapdomain maps points between domains. + mapparms parameters of the linear map between domains. + +""" +import operator +import functools +import warnings + +import numpy as np + +from numpy.core.multiarray import dragon4_positional, dragon4_scientific +from numpy.core.umath import absolute + +__all__ = [ + 'RankWarning', 'as_series', 'trimseq', + 'trimcoef', 'getdomain', 'mapdomain', 'mapparms', + 'format_float'] + +# +# Warnings and Exceptions +# + +class RankWarning(UserWarning): + """Issued by chebfit when the design matrix is rank deficient.""" + pass + +# +# Helper functions to convert inputs to 1-D arrays +# +def trimseq(seq): + """Remove small Poly series coefficients. + + Parameters + ---------- + seq : sequence + Sequence of Poly series coefficients. This routine fails for + empty sequences. + + Returns + ------- + series : sequence + Subsequence with trailing zeros removed. If the resulting sequence + would be empty, return the first element. The returned sequence may + or may not be a view. + + Notes + ----- + Do not lose the type info if the sequence contains unknown objects. + + """ + if len(seq) == 0: + return seq + else: + for i in range(len(seq) - 1, -1, -1): + if seq[i] != 0: + break + return seq[:i+1] + + +def as_series(alist, trim=True): + """ + Return argument as a list of 1-d arrays. + + The returned list contains array(s) of dtype double, complex double, or + object. 
A 1-d argument of shape ``(N,)`` is parsed into ``N`` arrays of + size one; a 2-d argument of shape ``(M,N)`` is parsed into ``M`` arrays + of size ``N`` (i.e., is "parsed by row"); and a higher dimensional array + raises a Value Error if it is not first reshaped into either a 1-d or 2-d + array. + + Parameters + ---------- + alist : array_like + A 1- or 2-d array_like + trim : boolean, optional + When True, trailing zeros are removed from the inputs. + When False, the inputs are passed through intact. + + Returns + ------- + [a1, a2,...] : list of 1-D arrays + A copy of the input data as a list of 1-d arrays. + + Raises + ------ + ValueError + Raised when `as_series` cannot convert its input to 1-d arrays, or at + least one of the resulting arrays is empty. + + Examples + -------- + >>> from numpy.polynomial import polyutils as pu + >>> a = np.arange(4) + >>> pu.as_series(a) + [array([0.]), array([1.]), array([2.]), array([3.])] + >>> b = np.arange(6).reshape((2,3)) + >>> pu.as_series(b) + [array([0., 1., 2.]), array([3., 4., 5.])] + + >>> pu.as_series((1, np.arange(3), np.arange(2, dtype=np.float16))) + [array([1.]), array([0., 1., 2.]), array([0., 1.])] + + >>> pu.as_series([2, [1.1, 0.]]) + [array([2.]), array([1.1])] + + >>> pu.as_series([2, [1.1, 0.]], trim=False) + [array([2.]), array([1.1, 0. ])] + + """ + arrays = [np.array(a, ndmin=1, copy=False) for a in alist] + if min([a.size for a in arrays]) == 0: + raise ValueError("Coefficient array is empty") + if any(a.ndim != 1 for a in arrays): + raise ValueError("Coefficient array is not 1-d") + if trim: + arrays = [trimseq(a) for a in arrays] + + if any(a.dtype == np.dtype(object) for a in arrays): + ret = [] + for a in arrays: + if a.dtype != np.dtype(object): + tmp = np.empty(len(a), dtype=np.dtype(object)) + tmp[:] = a[:] + ret.append(tmp) + else: + ret.append(a.copy()) + else: + try: + dtype = np.common_type(*arrays) + except Exception as e: + raise ValueError("Coefficient arrays have no common type") from e + ret = [np.array(a, copy=True, dtype=dtype) for a in arrays] + return ret + + +def trimcoef(c, tol=0): + """ + Remove "small" "trailing" coefficients from a polynomial. + + "Small" means "small in absolute value" and is controlled by the + parameter `tol`; "trailing" means highest order coefficient(s), e.g., in + ``[0, 1, 1, 0, 0]`` (which represents ``0 + x + x**2 + 0*x**3 + 0*x**4``) + both the 3-rd and 4-th order coefficients would be "trimmed." + + Parameters + ---------- + c : array_like + 1-d array of coefficients, ordered from lowest order to highest. + tol : number, optional + Trailing (i.e., highest order) elements with absolute value less + than or equal to `tol` (default value is zero) are removed. + + Returns + ------- + trimmed : ndarray + 1-d array with trailing zeros removed. If the resulting series + would be empty, a series containing a single zero is returned. 
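+        The returned array is always a new array; the input `c` is never
+        modified in place.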
+ + Raises + ------ + ValueError + If `tol` < 0 + + See Also + -------- + trimseq + + Examples + -------- + >>> from numpy.polynomial import polyutils as pu + >>> pu.trimcoef((0,0,3,0,5,0,0)) + array([0., 0., 3., 0., 5.]) + >>> pu.trimcoef((0,0,1e-3,0,1e-5,0,0),1e-3) # item == tol is trimmed + array([0.]) + >>> i = complex(0,1) # works for complex + >>> pu.trimcoef((3e-4,1e-3*(1-i),5e-4,2e-5*(1+i)), 1e-3) + array([0.0003+0.j , 0.001 -0.001j]) + + """ + if tol < 0: + raise ValueError("tol must be non-negative") + + [c] = as_series([c]) + [ind] = np.nonzero(np.abs(c) > tol) + if len(ind) == 0: + return c[:1]*0 + else: + return c[:ind[-1] + 1].copy() + +def getdomain(x): + """ + Return a domain suitable for given abscissae. + + Find a domain suitable for a polynomial or Chebyshev series + defined at the values supplied. + + Parameters + ---------- + x : array_like + 1-d array of abscissae whose domain will be determined. + + Returns + ------- + domain : ndarray + 1-d array containing two values. If the inputs are complex, then + the two returned points are the lower left and upper right corners + of the smallest rectangle (aligned with the axes) in the complex + plane containing the points `x`. If the inputs are real, then the + two points are the ends of the smallest interval containing the + points `x`. + + See Also + -------- + mapparms, mapdomain + + Examples + -------- + >>> from numpy.polynomial import polyutils as pu + >>> points = np.arange(4)**2 - 5; points + array([-5, -4, -1, 4]) + >>> pu.getdomain(points) + array([-5., 4.]) + >>> c = np.exp(complex(0,1)*np.pi*np.arange(12)/6) # unit circle + >>> pu.getdomain(c) + array([-1.-1.j, 1.+1.j]) + + """ + [x] = as_series([x], trim=False) + if x.dtype.char in np.typecodes['Complex']: + rmin, rmax = x.real.min(), x.real.max() + imin, imax = x.imag.min(), x.imag.max() + return np.array((complex(rmin, imin), complex(rmax, imax))) + else: + return np.array((x.min(), x.max())) + +def mapparms(old, new): + """ + Linear map parameters between domains. + + Return the parameters of the linear map ``offset + scale*x`` that maps + `old` to `new` such that ``old[i] -> new[i]``, ``i = 0, 1``. + + Parameters + ---------- + old, new : array_like + Domains. Each domain must (successfully) convert to a 1-d array + containing precisely two values. + + Returns + ------- + offset, scale : scalars + The map ``L(x) = offset + scale*x`` maps the first domain to the + second. + + See Also + -------- + getdomain, mapdomain + + Notes + ----- + Also works for complex numbers, and thus can be used to calculate the + parameters required to map any line in the complex plane to any other + line therein. + + Examples + -------- + >>> from numpy.polynomial import polyutils as pu + >>> pu.mapparms((-1,1),(-1,1)) + (0.0, 1.0) + >>> pu.mapparms((1,-1),(-1,1)) + (-0.0, -1.0) + >>> i = complex(0,1) + >>> pu.mapparms((-i,-1),(1,i)) + ((1+1j), (1-0j)) + + """ + oldlen = old[1] - old[0] + newlen = new[1] - new[0] + off = (old[1]*new[0] - old[0]*new[1])/oldlen + scl = newlen/oldlen + return off, scl + +def mapdomain(x, old, new): + """ + Apply linear map to input points. + + The linear map ``offset + scale*x`` that maps the domain `old` to + the domain `new` is applied to the points `x`. + + Parameters + ---------- + x : array_like + Points to be mapped. If `x` is a subtype of ndarray the subtype + will be preserved. + old, new : array_like + The two domains that determine the map. Each must (successfully) + convert to 1-d arrays containing precisely two values. 
+ + Returns + ------- + x_out : ndarray + Array of points of the same shape as `x`, after application of the + linear map between the two domains. + + See Also + -------- + getdomain, mapparms + + Notes + ----- + Effectively, this implements: + + .. math:: + x\\_out = new[0] + m(x - old[0]) + + where + + .. math:: + m = \\frac{new[1]-new[0]}{old[1]-old[0]} + + Examples + -------- + >>> from numpy.polynomial import polyutils as pu + >>> old_domain = (-1,1) + >>> new_domain = (0,2*np.pi) + >>> x = np.linspace(-1,1,6); x + array([-1. , -0.6, -0.2, 0.2, 0.6, 1. ]) + >>> x_out = pu.mapdomain(x, old_domain, new_domain); x_out + array([ 0. , 1.25663706, 2.51327412, 3.76991118, 5.02654825, # may vary + 6.28318531]) + >>> x - pu.mapdomain(x_out, new_domain, old_domain) + array([0., 0., 0., 0., 0., 0.]) + + Also works for complex numbers (and thus can be used to map any line in + the complex plane to any other line therein). + + >>> i = complex(0,1) + >>> old = (-1 - i, 1 + i) + >>> new = (-1 + i, 1 - i) + >>> z = np.linspace(old[0], old[1], 6); z + array([-1. -1.j , -0.6-0.6j, -0.2-0.2j, 0.2+0.2j, 0.6+0.6j, 1. +1.j ]) + >>> new_z = pu.mapdomain(z, old, new); new_z + array([-1.0+1.j , -0.6+0.6j, -0.2+0.2j, 0.2-0.2j, 0.6-0.6j, 1.0-1.j ]) # may vary + + """ + x = np.asanyarray(x) + off, scl = mapparms(old, new) + return off + scl*x + + +def _nth_slice(i, ndim): + sl = [np.newaxis] * ndim + sl[i] = slice(None) + return tuple(sl) + + +def _vander_nd(vander_fs, points, degrees): + r""" + A generalization of the Vandermonde matrix for N dimensions + + The result is built by combining the results of 1d Vandermonde matrices, + + .. math:: + W[i_0, \ldots, i_M, j_0, \ldots, j_N] = \prod_{k=0}^N{V_k(x_k)[i_0, \ldots, i_M, j_k]} + + where + + .. math:: + N &= \texttt{len(points)} = \texttt{len(degrees)} = \texttt{len(vander\_fs)} \\ + M &= \texttt{points[k].ndim} \\ + V_k &= \texttt{vander\_fs[k]} \\ + x_k &= \texttt{points[k]} \\ + 0 \le j_k &\le \texttt{degrees[k]} + + Expanding the one-dimensional :math:`V_k` functions gives: + + .. math:: + W[i_0, \ldots, i_M, j_0, \ldots, j_N] = \prod_{k=0}^N{B_{k, j_k}(x_k[i_0, \ldots, i_M])} + + where :math:`B_{k,m}` is the m'th basis of the polynomial construction used along + dimension :math:`k`. For a regular polynomial, :math:`B_{k, m}(x) = P_m(x) = x^m`. + + Parameters + ---------- + vander_fs : Sequence[function(array_like, int) -> ndarray] + The 1d vander function to use for each axis, such as ``polyvander`` + points : Sequence[array_like] + Arrays of point coordinates, all of the same shape. The dtypes + will be converted to either float64 or complex128 depending on + whether any of the elements are complex. Scalars are converted to + 1-D arrays. + This must be the same length as `vander_fs`. + degrees : Sequence[int] + The maximum degree (inclusive) to use for each axis. + This must be the same length as `vander_fs`. + + Returns + ------- + vander_nd : ndarray + An array of shape ``points[0].shape + tuple(d + 1 for d in degrees)``. 
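+
+    Notes
+    -----
+    The public ``polyvander2d`` and ``polyvander3d`` functions are
+    implemented with `_vander_nd_flat`, which flattens the trailing
+    ``len(degrees)`` axes of this result into a single axis.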
+ """ + n_dims = len(vander_fs) + if n_dims != len(points): + raise ValueError( + f"Expected {n_dims} dimensions of sample points, got {len(points)}") + if n_dims != len(degrees): + raise ValueError( + f"Expected {n_dims} dimensions of degrees, got {len(degrees)}") + if n_dims == 0: + raise ValueError("Unable to guess a dtype or shape when no points are given") + + # convert to the same shape and type + points = tuple(np.array(tuple(points), copy=False) + 0.0) + + # produce the vandermonde matrix for each dimension, placing the last + # axis of each in an independent trailing axis of the output + vander_arrays = ( + vander_fs[i](points[i], degrees[i])[(...,) + _nth_slice(i, n_dims)] + for i in range(n_dims) + ) + + # we checked this wasn't empty already, so no `initial` needed + return functools.reduce(operator.mul, vander_arrays) + + +def _vander_nd_flat(vander_fs, points, degrees): + """ + Like `_vander_nd`, but flattens the last ``len(degrees)`` axes into a single axis + + Used to implement the public ``vanderd`` functions. + """ + v = _vander_nd(vander_fs, points, degrees) + return v.reshape(v.shape[:-len(degrees)] + (-1,)) + + +def _fromroots(line_f, mul_f, roots): + """ + Helper function used to implement the ``fromroots`` functions. + + Parameters + ---------- + line_f : function(float, float) -> ndarray + The ``line`` function, such as ``polyline`` + mul_f : function(array_like, array_like) -> ndarray + The ``mul`` function, such as ``polymul`` + roots + See the ``fromroots`` functions for more detail + """ + if len(roots) == 0: + return np.ones(1) + else: + [roots] = as_series([roots], trim=False) + roots.sort() + p = [line_f(-r, 1) for r in roots] + n = len(p) + while n > 1: + m, r = divmod(n, 2) + tmp = [mul_f(p[i], p[i+m]) for i in range(m)] + if r: + tmp[0] = mul_f(tmp[0], p[-1]) + p = tmp + n = m + return p[0] + + +def _valnd(val_f, c, *args): + """ + Helper function used to implement the ``vald`` functions. + + Parameters + ---------- + val_f : function(array_like, array_like, tensor: bool) -> array_like + The ``val`` function, such as ``polyval`` + c, args + See the ``vald`` functions for more detail + """ + args = [np.asanyarray(a) for a in args] + shape0 = args[0].shape + if not all((a.shape == shape0 for a in args[1:])): + if len(args) == 3: + raise ValueError('x, y, z are incompatible') + elif len(args) == 2: + raise ValueError('x, y are incompatible') + else: + raise ValueError('ordinates are incompatible') + it = iter(args) + x0 = next(it) + + # use tensor on only the first + c = val_f(x0, c) + for xi in it: + c = val_f(xi, c, tensor=False) + return c + + +def _gridnd(val_f, c, *args): + """ + Helper function used to implement the ``gridd`` functions. + + Parameters + ---------- + val_f : function(array_like, array_like, tensor: bool) -> array_like + The ``val`` function, such as ``polyval`` + c, args + See the ``gridd`` functions for more detail + """ + for xi in args: + c = val_f(xi, c) + return c + + +def _div(mul_f, c1, c2): + """ + Helper function used to implement the ``div`` functions. + + Implementation uses repeated subtraction of c2 multiplied by the nth basis. + For some polynomial types, a more efficient approach may be possible. 
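+
+    At each step the remainder's leading coefficient is eliminated by
+    subtracting a multiple of `c2` shifted to the matching degree with
+    ``mul_f([0]*i + [1], c2)``, so only a generic ``mul`` operation is
+    needed.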
+ + Parameters + ---------- + mul_f : function(array_like, array_like) -> array_like + The ``mul`` function, such as ``polymul`` + c1, c2 + See the ``div`` functions for more detail + """ + # c1, c2 are trimmed copies + [c1, c2] = as_series([c1, c2]) + if c2[-1] == 0: + raise ZeroDivisionError() + + lc1 = len(c1) + lc2 = len(c2) + if lc1 < lc2: + return c1[:1]*0, c1 + elif lc2 == 1: + return c1/c2[-1], c1[:1]*0 + else: + quo = np.empty(lc1 - lc2 + 1, dtype=c1.dtype) + rem = c1 + for i in range(lc1 - lc2, - 1, -1): + p = mul_f([0]*i + [1], c2) + q = rem[-1]/p[-1] + rem = rem[:-1] - q*p[:-1] + quo[i] = q + return quo, trimseq(rem) + + +def _add(c1, c2): + """ Helper function used to implement the ``add`` functions. """ + # c1, c2 are trimmed copies + [c1, c2] = as_series([c1, c2]) + if len(c1) > len(c2): + c1[:c2.size] += c2 + ret = c1 + else: + c2[:c1.size] += c1 + ret = c2 + return trimseq(ret) + + +def _sub(c1, c2): + """ Helper function used to implement the ``sub`` functions. """ + # c1, c2 are trimmed copies + [c1, c2] = as_series([c1, c2]) + if len(c1) > len(c2): + c1[:c2.size] -= c2 + ret = c1 + else: + c2 = -c2 + c2[:c1.size] += c1 + ret = c2 + return trimseq(ret) + + +def _fit(vander_f, x, y, deg, rcond=None, full=False, w=None): + """ + Helper function used to implement the ``fit`` functions. + + Parameters + ---------- + vander_f : function(array_like, int) -> ndarray + The 1d vander function, such as ``polyvander`` + c1, c2 + See the ``fit`` functions for more detail + """ + x = np.asarray(x) + 0.0 + y = np.asarray(y) + 0.0 + deg = np.asarray(deg) + + # check arguments. + if deg.ndim > 1 or deg.dtype.kind not in 'iu' or deg.size == 0: + raise TypeError("deg must be an int or non-empty 1-D array of int") + if deg.min() < 0: + raise ValueError("expected deg >= 0") + if x.ndim != 1: + raise TypeError("expected 1D vector for x") + if x.size == 0: + raise TypeError("expected non-empty vector for x") + if y.ndim < 1 or y.ndim > 2: + raise TypeError("expected 1D or 2D array for y") + if len(x) != len(y): + raise TypeError("expected x and y to have same length") + + if deg.ndim == 0: + lmax = deg + order = lmax + 1 + van = vander_f(x, lmax) + else: + deg = np.sort(deg) + lmax = deg[-1] + order = len(deg) + van = vander_f(x, lmax)[:, deg] + + # set up the least squares matrices in transposed form + lhs = van.T + rhs = y.T + if w is not None: + w = np.asarray(w) + 0.0 + if w.ndim != 1: + raise TypeError("expected 1D vector for w") + if len(x) != len(w): + raise TypeError("expected x and w to have same length") + # apply weights. Don't use inplace operations as they + # can cause problems with NA. + lhs = lhs * w + rhs = rhs * w + + # set rcond + if rcond is None: + rcond = len(x)*np.finfo(x.dtype).eps + + # Determine the norms of the design matrix columns. + if issubclass(lhs.dtype.type, np.complexfloating): + scl = np.sqrt((np.square(lhs.real) + np.square(lhs.imag)).sum(1)) + else: + scl = np.sqrt(np.square(lhs).sum(1)) + scl[scl == 0] = 1 + + # Solve the least squares problem. 
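+    # lhs and rhs were built transposed (one row per basis function), so
+    # they are transposed back for lstsq. Dividing by scl gives each column
+    # of the design matrix unit Euclidean norm, which improves conditioning;
+    # the solution is rescaled to undo this immediately below.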
+ c, resids, rank, s = np.linalg.lstsq(lhs.T/scl, rhs.T, rcond) + c = (c.T/scl).T + + # Expand c to include non-fitted coefficients which are set to zero + if deg.ndim > 0: + if c.ndim == 2: + cc = np.zeros((lmax+1, c.shape[1]), dtype=c.dtype) + else: + cc = np.zeros(lmax+1, dtype=c.dtype) + cc[deg] = c + c = cc + + # warn on rank reduction + if rank != order and not full: + msg = "The fit may be poorly conditioned" + warnings.warn(msg, RankWarning, stacklevel=2) + + if full: + return c, [resids, rank, s, rcond] + else: + return c + + +def _pow(mul_f, c, pow, maxpower): + """ + Helper function used to implement the ``pow`` functions. + + Parameters + ---------- + mul_f : function(array_like, array_like) -> ndarray + The ``mul`` function, such as ``polymul`` + c : array_like + 1-D array of array of series coefficients + pow, maxpower + See the ``pow`` functions for more detail + """ + # c is a trimmed copy + [c] = as_series([c]) + power = int(pow) + if power != pow or power < 0: + raise ValueError("Power must be a non-negative integer.") + elif maxpower is not None and power > maxpower: + raise ValueError("Power is too large") + elif power == 0: + return np.array([1], dtype=c.dtype) + elif power == 1: + return c + else: + # This can be made more efficient by using powers of two + # in the usual way. + prd = c + for i in range(2, power + 1): + prd = mul_f(prd, c) + return prd + + +def _deprecate_as_int(x, desc): + """ + Like `operator.index`, but emits a deprecation warning when passed a float + + Parameters + ---------- + x : int-like, or float with integral value + Value to interpret as an integer + desc : str + description to include in any error message + + Raises + ------ + TypeError : if x is a non-integral float or non-numeric + DeprecationWarning : if x is an integral float + """ + try: + return operator.index(x) + except TypeError as e: + # Numpy 1.17.0, 2019-03-11 + try: + ix = int(x) + except TypeError: + pass + else: + if ix == x: + warnings.warn( + f"In future, this will raise TypeError, as {desc} will " + "need to be an integer not just an integral float.", + DeprecationWarning, + stacklevel=3 + ) + return ix + + raise TypeError(f"{desc} must be an integer") from e + + +def format_float(x, parens=False): + if not np.issubdtype(type(x), np.floating): + return str(x) + + opts = np.get_printoptions() + + if np.isnan(x): + return opts['nanstr'] + elif np.isinf(x): + return opts['infstr'] + + exp_format = False + if x != 0: + a = absolute(x) + if a >= 1.e8 or a < 10**min(0, -(opts['precision']-1)//2): + exp_format = True + + trim, unique = '0', True + if opts['floatmode'] == 'fixed': + trim, unique = 'k', False + + if exp_format: + s = dragon4_scientific(x, precision=opts['precision'], + unique=unique, trim=trim, + sign=opts['sign'] == '+') + if parens: + s = '(' + s + ')' + else: + s = dragon4_positional(x, precision=opts['precision'], + fractional=True, + unique=unique, trim=trim, + sign=opts['sign'] == '+') + return s diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_chebyshev.py b/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_chebyshev.py new file mode 100644 index 0000000000000000000000000000000000000000..2f54bebfdb27d54f436378e4ab6d6c8f2426dd90 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_chebyshev.py @@ -0,0 +1,619 @@ +"""Tests for chebyshev module. 
+ +""" +from functools import reduce + +import numpy as np +import numpy.polynomial.chebyshev as cheb +from numpy.polynomial.polynomial import polyval +from numpy.testing import ( + assert_almost_equal, assert_raises, assert_equal, assert_, + ) + + +def trim(x): + return cheb.chebtrim(x, tol=1e-6) + +T0 = [1] +T1 = [0, 1] +T2 = [-1, 0, 2] +T3 = [0, -3, 0, 4] +T4 = [1, 0, -8, 0, 8] +T5 = [0, 5, 0, -20, 0, 16] +T6 = [-1, 0, 18, 0, -48, 0, 32] +T7 = [0, -7, 0, 56, 0, -112, 0, 64] +T8 = [1, 0, -32, 0, 160, 0, -256, 0, 128] +T9 = [0, 9, 0, -120, 0, 432, 0, -576, 0, 256] + +Tlist = [T0, T1, T2, T3, T4, T5, T6, T7, T8, T9] + + +class TestPrivate: + + def test__cseries_to_zseries(self): + for i in range(5): + inp = np.array([2] + [1]*i, np.double) + tgt = np.array([.5]*i + [2] + [.5]*i, np.double) + res = cheb._cseries_to_zseries(inp) + assert_equal(res, tgt) + + def test__zseries_to_cseries(self): + for i in range(5): + inp = np.array([.5]*i + [2] + [.5]*i, np.double) + tgt = np.array([2] + [1]*i, np.double) + res = cheb._zseries_to_cseries(inp) + assert_equal(res, tgt) + + +class TestConstants: + + def test_chebdomain(self): + assert_equal(cheb.chebdomain, [-1, 1]) + + def test_chebzero(self): + assert_equal(cheb.chebzero, [0]) + + def test_chebone(self): + assert_equal(cheb.chebone, [1]) + + def test_chebx(self): + assert_equal(cheb.chebx, [0, 1]) + + +class TestArithmetic: + + def test_chebadd(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + tgt = np.zeros(max(i, j) + 1) + tgt[i] += 1 + tgt[j] += 1 + res = cheb.chebadd([0]*i + [1], [0]*j + [1]) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + def test_chebsub(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + tgt = np.zeros(max(i, j) + 1) + tgt[i] += 1 + tgt[j] -= 1 + res = cheb.chebsub([0]*i + [1], [0]*j + [1]) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + def test_chebmulx(self): + assert_equal(cheb.chebmulx([0]), [0]) + assert_equal(cheb.chebmulx([1]), [0, 1]) + for i in range(1, 5): + ser = [0]*i + [1] + tgt = [0]*(i - 1) + [.5, 0, .5] + assert_equal(cheb.chebmulx(ser), tgt) + + def test_chebmul(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + tgt = np.zeros(i + j + 1) + tgt[i + j] += .5 + tgt[abs(i - j)] += .5 + res = cheb.chebmul([0]*i + [1], [0]*j + [1]) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + def test_chebdiv(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + ci = [0]*i + [1] + cj = [0]*j + [1] + tgt = cheb.chebadd(ci, cj) + quo, rem = cheb.chebdiv(tgt, ci) + res = cheb.chebadd(cheb.chebmul(quo, ci), rem) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + def test_chebpow(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + c = np.arange(i + 1) + tgt = reduce(cheb.chebmul, [c]*j, np.array([1])) + res = cheb.chebpow(c, j) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + +class TestEvaluation: + # coefficients of 1 + 2*x + 3*x**2 + c1d = np.array([2.5, 2., 1.5]) + c2d = np.einsum('i,j->ij', c1d, c1d) + c3d = np.einsum('i,j,k->ijk', c1d, c1d, c1d) + + # some random values in [-1, 1) + x = np.random.random((3, 5))*2 - 1 + y = polyval(x, [1., 2., 3.]) + + def test_chebval(self): + #check empty input + assert_equal(cheb.chebval([], [1]).size, 0) + + #check normal input) + x = np.linspace(-1, 1) + y = [polyval(x, c) for c in Tlist] + for i in range(10): + msg = f"At i={i}" + tgt = y[i] + res = cheb.chebval(x, [0]*i + [1]) + assert_almost_equal(res, tgt, err_msg=msg) + + 
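+        # chebval broadcasts over x and keeps its shape: evaluating on an
+        # input of shape (2, 2) returns an array of shape (2, 2). The loop
+        # below checks this for 0-d, 1-d and 2-d inputs.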
#check that shape is preserved + for i in range(3): + dims = [2]*i + x = np.zeros(dims) + assert_equal(cheb.chebval(x, [1]).shape, dims) + assert_equal(cheb.chebval(x, [1, 0]).shape, dims) + assert_equal(cheb.chebval(x, [1, 0, 0]).shape, dims) + + def test_chebval2d(self): + x1, x2, x3 = self.x + y1, y2, y3 = self.y + + #test exceptions + assert_raises(ValueError, cheb.chebval2d, x1, x2[:2], self.c2d) + + #test values + tgt = y1*y2 + res = cheb.chebval2d(x1, x2, self.c2d) + assert_almost_equal(res, tgt) + + #test shape + z = np.ones((2, 3)) + res = cheb.chebval2d(z, z, self.c2d) + assert_(res.shape == (2, 3)) + + def test_chebval3d(self): + x1, x2, x3 = self.x + y1, y2, y3 = self.y + + #test exceptions + assert_raises(ValueError, cheb.chebval3d, x1, x2, x3[:2], self.c3d) + + #test values + tgt = y1*y2*y3 + res = cheb.chebval3d(x1, x2, x3, self.c3d) + assert_almost_equal(res, tgt) + + #test shape + z = np.ones((2, 3)) + res = cheb.chebval3d(z, z, z, self.c3d) + assert_(res.shape == (2, 3)) + + def test_chebgrid2d(self): + x1, x2, x3 = self.x + y1, y2, y3 = self.y + + #test values + tgt = np.einsum('i,j->ij', y1, y2) + res = cheb.chebgrid2d(x1, x2, self.c2d) + assert_almost_equal(res, tgt) + + #test shape + z = np.ones((2, 3)) + res = cheb.chebgrid2d(z, z, self.c2d) + assert_(res.shape == (2, 3)*2) + + def test_chebgrid3d(self): + x1, x2, x3 = self.x + y1, y2, y3 = self.y + + #test values + tgt = np.einsum('i,j,k->ijk', y1, y2, y3) + res = cheb.chebgrid3d(x1, x2, x3, self.c3d) + assert_almost_equal(res, tgt) + + #test shape + z = np.ones((2, 3)) + res = cheb.chebgrid3d(z, z, z, self.c3d) + assert_(res.shape == (2, 3)*3) + + +class TestIntegral: + + def test_chebint(self): + # check exceptions + assert_raises(TypeError, cheb.chebint, [0], .5) + assert_raises(ValueError, cheb.chebint, [0], -1) + assert_raises(ValueError, cheb.chebint, [0], 1, [0, 0]) + assert_raises(ValueError, cheb.chebint, [0], lbnd=[0]) + assert_raises(ValueError, cheb.chebint, [0], scl=[0]) + assert_raises(TypeError, cheb.chebint, [0], axis=.5) + + # test integration of zero polynomial + for i in range(2, 5): + k = [0]*(i - 2) + [1] + res = cheb.chebint([0], m=i, k=k) + assert_almost_equal(res, [0, 1]) + + # check single integration with integration constant + for i in range(5): + scl = i + 1 + pol = [0]*i + [1] + tgt = [i] + [0]*i + [1/scl] + chebpol = cheb.poly2cheb(pol) + chebint = cheb.chebint(chebpol, m=1, k=[i]) + res = cheb.cheb2poly(chebint) + assert_almost_equal(trim(res), trim(tgt)) + + # check single integration with integration constant and lbnd + for i in range(5): + scl = i + 1 + pol = [0]*i + [1] + chebpol = cheb.poly2cheb(pol) + chebint = cheb.chebint(chebpol, m=1, k=[i], lbnd=-1) + assert_almost_equal(cheb.chebval(-1, chebint), i) + + # check single integration with integration constant and scaling + for i in range(5): + scl = i + 1 + pol = [0]*i + [1] + tgt = [i] + [0]*i + [2/scl] + chebpol = cheb.poly2cheb(pol) + chebint = cheb.chebint(chebpol, m=1, k=[i], scl=2) + res = cheb.cheb2poly(chebint) + assert_almost_equal(trim(res), trim(tgt)) + + # check multiple integrations with default k + for i in range(5): + for j in range(2, 5): + pol = [0]*i + [1] + tgt = pol[:] + for k in range(j): + tgt = cheb.chebint(tgt, m=1) + res = cheb.chebint(pol, m=j) + assert_almost_equal(trim(res), trim(tgt)) + + # check multiple integrations with defined k + for i in range(5): + for j in range(2, 5): + pol = [0]*i + [1] + tgt = pol[:] + for k in range(j): + tgt = cheb.chebint(tgt, m=1, k=[k]) + res = cheb.chebint(pol, m=j, 
k=list(range(j))) + assert_almost_equal(trim(res), trim(tgt)) + + # check multiple integrations with lbnd + for i in range(5): + for j in range(2, 5): + pol = [0]*i + [1] + tgt = pol[:] + for k in range(j): + tgt = cheb.chebint(tgt, m=1, k=[k], lbnd=-1) + res = cheb.chebint(pol, m=j, k=list(range(j)), lbnd=-1) + assert_almost_equal(trim(res), trim(tgt)) + + # check multiple integrations with scaling + for i in range(5): + for j in range(2, 5): + pol = [0]*i + [1] + tgt = pol[:] + for k in range(j): + tgt = cheb.chebint(tgt, m=1, k=[k], scl=2) + res = cheb.chebint(pol, m=j, k=list(range(j)), scl=2) + assert_almost_equal(trim(res), trim(tgt)) + + def test_chebint_axis(self): + # check that axis keyword works + c2d = np.random.random((3, 4)) + + tgt = np.vstack([cheb.chebint(c) for c in c2d.T]).T + res = cheb.chebint(c2d, axis=0) + assert_almost_equal(res, tgt) + + tgt = np.vstack([cheb.chebint(c) for c in c2d]) + res = cheb.chebint(c2d, axis=1) + assert_almost_equal(res, tgt) + + tgt = np.vstack([cheb.chebint(c, k=3) for c in c2d]) + res = cheb.chebint(c2d, k=3, axis=1) + assert_almost_equal(res, tgt) + + +class TestDerivative: + + def test_chebder(self): + # check exceptions + assert_raises(TypeError, cheb.chebder, [0], .5) + assert_raises(ValueError, cheb.chebder, [0], -1) + + # check that zeroth derivative does nothing + for i in range(5): + tgt = [0]*i + [1] + res = cheb.chebder(tgt, m=0) + assert_equal(trim(res), trim(tgt)) + + # check that derivation is the inverse of integration + for i in range(5): + for j in range(2, 5): + tgt = [0]*i + [1] + res = cheb.chebder(cheb.chebint(tgt, m=j), m=j) + assert_almost_equal(trim(res), trim(tgt)) + + # check derivation with scaling + for i in range(5): + for j in range(2, 5): + tgt = [0]*i + [1] + res = cheb.chebder(cheb.chebint(tgt, m=j, scl=2), m=j, scl=.5) + assert_almost_equal(trim(res), trim(tgt)) + + def test_chebder_axis(self): + # check that axis keyword works + c2d = np.random.random((3, 4)) + + tgt = np.vstack([cheb.chebder(c) for c in c2d.T]).T + res = cheb.chebder(c2d, axis=0) + assert_almost_equal(res, tgt) + + tgt = np.vstack([cheb.chebder(c) for c in c2d]) + res = cheb.chebder(c2d, axis=1) + assert_almost_equal(res, tgt) + + +class TestVander: + # some random values in [-1, 1) + x = np.random.random((3, 5))*2 - 1 + + def test_chebvander(self): + # check for 1d x + x = np.arange(3) + v = cheb.chebvander(x, 3) + assert_(v.shape == (3, 4)) + for i in range(4): + coef = [0]*i + [1] + assert_almost_equal(v[..., i], cheb.chebval(x, coef)) + + # check for 2d x + x = np.array([[1, 2], [3, 4], [5, 6]]) + v = cheb.chebvander(x, 3) + assert_(v.shape == (3, 2, 4)) + for i in range(4): + coef = [0]*i + [1] + assert_almost_equal(v[..., i], cheb.chebval(x, coef)) + + def test_chebvander2d(self): + # also tests chebval2d for non-square coefficient array + x1, x2, x3 = self.x + c = np.random.random((2, 3)) + van = cheb.chebvander2d(x1, x2, [1, 2]) + tgt = cheb.chebval2d(x1, x2, c) + res = np.dot(van, c.flat) + assert_almost_equal(res, tgt) + + # check shape + van = cheb.chebvander2d([x1], [x2], [1, 2]) + assert_(van.shape == (1, 5, 6)) + + def test_chebvander3d(self): + # also tests chebval3d for non-square coefficient array + x1, x2, x3 = self.x + c = np.random.random((2, 3, 4)) + van = cheb.chebvander3d(x1, x2, x3, [1, 2, 3]) + tgt = cheb.chebval3d(x1, x2, x3, c) + res = np.dot(van, c.flat) + assert_almost_equal(res, tgt) + + # check shape + van = cheb.chebvander3d([x1], [x2], [x3], [1, 2, 3]) + assert_(van.shape == (1, 5, 24)) + + +class 
TestFitting: + + def test_chebfit(self): + def f(x): + return x*(x - 1)*(x - 2) + + def f2(x): + return x**4 + x**2 + 1 + + # Test exceptions + assert_raises(ValueError, cheb.chebfit, [1], [1], -1) + assert_raises(TypeError, cheb.chebfit, [[1]], [1], 0) + assert_raises(TypeError, cheb.chebfit, [], [1], 0) + assert_raises(TypeError, cheb.chebfit, [1], [[[1]]], 0) + assert_raises(TypeError, cheb.chebfit, [1, 2], [1], 0) + assert_raises(TypeError, cheb.chebfit, [1], [1, 2], 0) + assert_raises(TypeError, cheb.chebfit, [1], [1], 0, w=[[1]]) + assert_raises(TypeError, cheb.chebfit, [1], [1], 0, w=[1, 1]) + assert_raises(ValueError, cheb.chebfit, [1], [1], [-1,]) + assert_raises(ValueError, cheb.chebfit, [1], [1], [2, -1, 6]) + assert_raises(TypeError, cheb.chebfit, [1], [1], []) + + # Test fit + x = np.linspace(0, 2) + y = f(x) + # + coef3 = cheb.chebfit(x, y, 3) + assert_equal(len(coef3), 4) + assert_almost_equal(cheb.chebval(x, coef3), y) + coef3 = cheb.chebfit(x, y, [0, 1, 2, 3]) + assert_equal(len(coef3), 4) + assert_almost_equal(cheb.chebval(x, coef3), y) + # + coef4 = cheb.chebfit(x, y, 4) + assert_equal(len(coef4), 5) + assert_almost_equal(cheb.chebval(x, coef4), y) + coef4 = cheb.chebfit(x, y, [0, 1, 2, 3, 4]) + assert_equal(len(coef4), 5) + assert_almost_equal(cheb.chebval(x, coef4), y) + # check things still work if deg is not in strict increasing + coef4 = cheb.chebfit(x, y, [2, 3, 4, 1, 0]) + assert_equal(len(coef4), 5) + assert_almost_equal(cheb.chebval(x, coef4), y) + # + coef2d = cheb.chebfit(x, np.array([y, y]).T, 3) + assert_almost_equal(coef2d, np.array([coef3, coef3]).T) + coef2d = cheb.chebfit(x, np.array([y, y]).T, [0, 1, 2, 3]) + assert_almost_equal(coef2d, np.array([coef3, coef3]).T) + # test weighting + w = np.zeros_like(x) + yw = y.copy() + w[1::2] = 1 + y[0::2] = 0 + wcoef3 = cheb.chebfit(x, yw, 3, w=w) + assert_almost_equal(wcoef3, coef3) + wcoef3 = cheb.chebfit(x, yw, [0, 1, 2, 3], w=w) + assert_almost_equal(wcoef3, coef3) + # + wcoef2d = cheb.chebfit(x, np.array([yw, yw]).T, 3, w=w) + assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T) + wcoef2d = cheb.chebfit(x, np.array([yw, yw]).T, [0, 1, 2, 3], w=w) + assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T) + # test scaling with complex values x points whose square + # is zero when summed. + x = [1, 1j, -1, -1j] + assert_almost_equal(cheb.chebfit(x, x, 1), [0, 1]) + assert_almost_equal(cheb.chebfit(x, x, [0, 1]), [0, 1]) + # test fitting only even polynomials + x = np.linspace(-1, 1) + y = f2(x) + coef1 = cheb.chebfit(x, y, 4) + assert_almost_equal(cheb.chebval(x, coef1), y) + coef2 = cheb.chebfit(x, y, [0, 2, 4]) + assert_almost_equal(cheb.chebval(x, coef2), y) + assert_almost_equal(coef1, coef2) + + +class TestInterpolate: + + def f(self, x): + return x * (x - 1) * (x - 2) + + def test_raises(self): + assert_raises(ValueError, cheb.chebinterpolate, self.f, -1) + assert_raises(TypeError, cheb.chebinterpolate, self.f, 10.) 
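+
+    # chebinterpolate samples the function at the deg + 1 Chebyshev points
+    # of the first kind, so it reproduces any polynomial of degree <= deg
+    # exactly; test_approximation below relies on this. A sketch (assuming
+    # numpy imported as np):
+    #
+    #     c = cheb.chebinterpolate(lambda t: t**2, 2)
+    #     np.allclose(cheb.chebval(0.5, c), 0.25)  # -> True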
+ + def test_dimensions(self): + for deg in range(1, 5): + assert_(cheb.chebinterpolate(self.f, deg).shape == (deg + 1,)) + + def test_approximation(self): + + def powx(x, p): + return x**p + + x = np.linspace(-1, 1, 10) + for deg in range(0, 10): + for p in range(0, deg + 1): + c = cheb.chebinterpolate(powx, deg, (p,)) + assert_almost_equal(cheb.chebval(x, c), powx(x, p), decimal=12) + + +class TestCompanion: + + def test_raises(self): + assert_raises(ValueError, cheb.chebcompanion, []) + assert_raises(ValueError, cheb.chebcompanion, [1]) + + def test_dimensions(self): + for i in range(1, 5): + coef = [0]*i + [1] + assert_(cheb.chebcompanion(coef).shape == (i, i)) + + def test_linear_root(self): + assert_(cheb.chebcompanion([1, 2])[0, 0] == -.5) + + +class TestGauss: + + def test_100(self): + x, w = cheb.chebgauss(100) + + # test orthogonality. Note that the results need to be normalized, + # otherwise the huge values that can arise from fast growing + # functions like Laguerre can be very confusing. + v = cheb.chebvander(x, 99) + vv = np.dot(v.T * w, v) + vd = 1/np.sqrt(vv.diagonal()) + vv = vd[:, None] * vv * vd + assert_almost_equal(vv, np.eye(100)) + + # check that the integral of 1 is correct + tgt = np.pi + assert_almost_equal(w.sum(), tgt) + + +class TestMisc: + + def test_chebfromroots(self): + res = cheb.chebfromroots([]) + assert_almost_equal(trim(res), [1]) + for i in range(1, 5): + roots = np.cos(np.linspace(-np.pi, 0, 2*i + 1)[1::2]) + tgt = [0]*i + [1] + res = cheb.chebfromroots(roots)*2**(i-1) + assert_almost_equal(trim(res), trim(tgt)) + + def test_chebroots(self): + assert_almost_equal(cheb.chebroots([1]), []) + assert_almost_equal(cheb.chebroots([1, 2]), [-.5]) + for i in range(2, 5): + tgt = np.linspace(-1, 1, i) + res = cheb.chebroots(cheb.chebfromroots(tgt)) + assert_almost_equal(trim(res), trim(tgt)) + + def test_chebtrim(self): + coef = [2, -1, 1, 0] + + # Test exceptions + assert_raises(ValueError, cheb.chebtrim, coef, -1) + + # Test results + assert_equal(cheb.chebtrim(coef), coef[:-1]) + assert_equal(cheb.chebtrim(coef, 1), coef[:-3]) + assert_equal(cheb.chebtrim(coef, 2), [0]) + + def test_chebline(self): + assert_equal(cheb.chebline(3, 4), [3, 4]) + + def test_cheb2poly(self): + for i in range(10): + assert_almost_equal(cheb.cheb2poly([0]*i + [1]), Tlist[i]) + + def test_poly2cheb(self): + for i in range(10): + assert_almost_equal(cheb.poly2cheb(Tlist[i]), [0]*i + [1]) + + def test_weight(self): + x = np.linspace(-1, 1, 11)[1:-1] + tgt = 1./(np.sqrt(1 + x) * np.sqrt(1 - x)) + res = cheb.chebweight(x) + assert_almost_equal(res, tgt) + + def test_chebpts1(self): + #test exceptions + assert_raises(ValueError, cheb.chebpts1, 1.5) + assert_raises(ValueError, cheb.chebpts1, 0) + + #test points + tgt = [0] + assert_almost_equal(cheb.chebpts1(1), tgt) + tgt = [-0.70710678118654746, 0.70710678118654746] + assert_almost_equal(cheb.chebpts1(2), tgt) + tgt = [-0.86602540378443871, 0, 0.86602540378443871] + assert_almost_equal(cheb.chebpts1(3), tgt) + tgt = [-0.9238795325, -0.3826834323, 0.3826834323, 0.9238795325] + assert_almost_equal(cheb.chebpts1(4), tgt) + + def test_chebpts2(self): + #test exceptions + assert_raises(ValueError, cheb.chebpts2, 1.5) + assert_raises(ValueError, cheb.chebpts2, 1) + + #test points + tgt = [-1, 1] + assert_almost_equal(cheb.chebpts2(2), tgt) + tgt = [-1, 0, 1] + assert_almost_equal(cheb.chebpts2(3), tgt) + tgt = [-1, -0.5, .5, 1] + assert_almost_equal(cheb.chebpts2(4), tgt) + tgt = [-1.0, -0.707106781187, 0, 0.707106781187, 1.0] + 
assert_almost_equal(cheb.chebpts2(5), tgt) diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_hermite_e.py b/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_hermite_e.py new file mode 100644 index 0000000000000000000000000000000000000000..2d262a3306222bd79f682b09763b0bd2b90ba8fe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_hermite_e.py @@ -0,0 +1,556 @@ +"""Tests for hermite_e module. + +""" +from functools import reduce + +import numpy as np +import numpy.polynomial.hermite_e as herme +from numpy.polynomial.polynomial import polyval +from numpy.testing import ( + assert_almost_equal, assert_raises, assert_equal, assert_, + ) + +He0 = np.array([1]) +He1 = np.array([0, 1]) +He2 = np.array([-1, 0, 1]) +He3 = np.array([0, -3, 0, 1]) +He4 = np.array([3, 0, -6, 0, 1]) +He5 = np.array([0, 15, 0, -10, 0, 1]) +He6 = np.array([-15, 0, 45, 0, -15, 0, 1]) +He7 = np.array([0, -105, 0, 105, 0, -21, 0, 1]) +He8 = np.array([105, 0, -420, 0, 210, 0, -28, 0, 1]) +He9 = np.array([0, 945, 0, -1260, 0, 378, 0, -36, 0, 1]) + +Helist = [He0, He1, He2, He3, He4, He5, He6, He7, He8, He9] + + +def trim(x): + return herme.hermetrim(x, tol=1e-6) + + +class TestConstants: + + def test_hermedomain(self): + assert_equal(herme.hermedomain, [-1, 1]) + + def test_hermezero(self): + assert_equal(herme.hermezero, [0]) + + def test_hermeone(self): + assert_equal(herme.hermeone, [1]) + + def test_hermex(self): + assert_equal(herme.hermex, [0, 1]) + + +class TestArithmetic: + x = np.linspace(-3, 3, 100) + + def test_hermeadd(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + tgt = np.zeros(max(i, j) + 1) + tgt[i] += 1 + tgt[j] += 1 + res = herme.hermeadd([0]*i + [1], [0]*j + [1]) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + def test_hermesub(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + tgt = np.zeros(max(i, j) + 1) + tgt[i] += 1 + tgt[j] -= 1 + res = herme.hermesub([0]*i + [1], [0]*j + [1]) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + def test_hermemulx(self): + assert_equal(herme.hermemulx([0]), [0]) + assert_equal(herme.hermemulx([1]), [0, 1]) + for i in range(1, 5): + ser = [0]*i + [1] + tgt = [0]*(i - 1) + [i, 0, 1] + assert_equal(herme.hermemulx(ser), tgt) + + def test_hermemul(self): + # check values of result + for i in range(5): + pol1 = [0]*i + [1] + val1 = herme.hermeval(self.x, pol1) + for j in range(5): + msg = f"At i={i}, j={j}" + pol2 = [0]*j + [1] + val2 = herme.hermeval(self.x, pol2) + pol3 = herme.hermemul(pol1, pol2) + val3 = herme.hermeval(self.x, pol3) + assert_(len(pol3) == i + j + 1, msg) + assert_almost_equal(val3, val1*val2, err_msg=msg) + + def test_hermediv(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + ci = [0]*i + [1] + cj = [0]*j + [1] + tgt = herme.hermeadd(ci, cj) + quo, rem = herme.hermediv(tgt, ci) + res = herme.hermeadd(herme.hermemul(quo, ci), rem) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + def test_hermepow(self): + for i in range(5): + for j in range(5): + msg = f"At i={i}, j={j}" + c = np.arange(i + 1) + tgt = reduce(herme.hermemul, [c]*j, np.array([1])) + res = herme.hermepow(c, j) + assert_equal(trim(res), trim(tgt), err_msg=msg) + + +class TestEvaluation: + # coefficients of 1 + 2*x + 3*x**2 + c1d = np.array([4., 2., 3.]) + c2d = np.einsum('i,j->ij', c1d, c1d) + c3d = np.einsum('i,j,k->ijk', c1d, c1d, c1d) + + # some random values in [-1, 1) + x = np.random.random((3, 5))*2 - 1 + y = 
polyval(x, [1., 2., 3.]) + + def test_hermeval(self): + #check empty input + assert_equal(herme.hermeval([], [1]).size, 0) + + #check normal input) + x = np.linspace(-1, 1) + y = [polyval(x, c) for c in Helist] + for i in range(10): + msg = f"At i={i}" + tgt = y[i] + res = herme.hermeval(x, [0]*i + [1]) + assert_almost_equal(res, tgt, err_msg=msg) + + #check that shape is preserved + for i in range(3): + dims = [2]*i + x = np.zeros(dims) + assert_equal(herme.hermeval(x, [1]).shape, dims) + assert_equal(herme.hermeval(x, [1, 0]).shape, dims) + assert_equal(herme.hermeval(x, [1, 0, 0]).shape, dims) + + def test_hermeval2d(self): + x1, x2, x3 = self.x + y1, y2, y3 = self.y + + #test exceptions + assert_raises(ValueError, herme.hermeval2d, x1, x2[:2], self.c2d) + + #test values + tgt = y1*y2 + res = herme.hermeval2d(x1, x2, self.c2d) + assert_almost_equal(res, tgt) + + #test shape + z = np.ones((2, 3)) + res = herme.hermeval2d(z, z, self.c2d) + assert_(res.shape == (2, 3)) + + def test_hermeval3d(self): + x1, x2, x3 = self.x + y1, y2, y3 = self.y + + #test exceptions + assert_raises(ValueError, herme.hermeval3d, x1, x2, x3[:2], self.c3d) + + #test values + tgt = y1*y2*y3 + res = herme.hermeval3d(x1, x2, x3, self.c3d) + assert_almost_equal(res, tgt) + + #test shape + z = np.ones((2, 3)) + res = herme.hermeval3d(z, z, z, self.c3d) + assert_(res.shape == (2, 3)) + + def test_hermegrid2d(self): + x1, x2, x3 = self.x + y1, y2, y3 = self.y + + #test values + tgt = np.einsum('i,j->ij', y1, y2) + res = herme.hermegrid2d(x1, x2, self.c2d) + assert_almost_equal(res, tgt) + + #test shape + z = np.ones((2, 3)) + res = herme.hermegrid2d(z, z, self.c2d) + assert_(res.shape == (2, 3)*2) + + def test_hermegrid3d(self): + x1, x2, x3 = self.x + y1, y2, y3 = self.y + + #test values + tgt = np.einsum('i,j,k->ijk', y1, y2, y3) + res = herme.hermegrid3d(x1, x2, x3, self.c3d) + assert_almost_equal(res, tgt) + + #test shape + z = np.ones((2, 3)) + res = herme.hermegrid3d(z, z, z, self.c3d) + assert_(res.shape == (2, 3)*3) + + +class TestIntegral: + + def test_hermeint(self): + # check exceptions + assert_raises(TypeError, herme.hermeint, [0], .5) + assert_raises(ValueError, herme.hermeint, [0], -1) + assert_raises(ValueError, herme.hermeint, [0], 1, [0, 0]) + assert_raises(ValueError, herme.hermeint, [0], lbnd=[0]) + assert_raises(ValueError, herme.hermeint, [0], scl=[0]) + assert_raises(TypeError, herme.hermeint, [0], axis=.5) + + # test integration of zero polynomial + for i in range(2, 5): + k = [0]*(i - 2) + [1] + res = herme.hermeint([0], m=i, k=k) + assert_almost_equal(res, [0, 1]) + + # check single integration with integration constant + for i in range(5): + scl = i + 1 + pol = [0]*i + [1] + tgt = [i] + [0]*i + [1/scl] + hermepol = herme.poly2herme(pol) + hermeint = herme.hermeint(hermepol, m=1, k=[i]) + res = herme.herme2poly(hermeint) + assert_almost_equal(trim(res), trim(tgt)) + + # check single integration with integration constant and lbnd + for i in range(5): + scl = i + 1 + pol = [0]*i + [1] + hermepol = herme.poly2herme(pol) + hermeint = herme.hermeint(hermepol, m=1, k=[i], lbnd=-1) + assert_almost_equal(herme.hermeval(-1, hermeint), i) + + # check single integration with integration constant and scaling + for i in range(5): + scl = i + 1 + pol = [0]*i + [1] + tgt = [i] + [0]*i + [2/scl] + hermepol = herme.poly2herme(pol) + hermeint = herme.hermeint(hermepol, m=1, k=[i], scl=2) + res = herme.herme2poly(hermeint) + assert_almost_equal(trim(res), trim(tgt)) + + # check multiple integrations with default 
k + for i in range(5): + for j in range(2, 5): + pol = [0]*i + [1] + tgt = pol[:] + for k in range(j): + tgt = herme.hermeint(tgt, m=1) + res = herme.hermeint(pol, m=j) + assert_almost_equal(trim(res), trim(tgt)) + + # check multiple integrations with defined k + for i in range(5): + for j in range(2, 5): + pol = [0]*i + [1] + tgt = pol[:] + for k in range(j): + tgt = herme.hermeint(tgt, m=1, k=[k]) + res = herme.hermeint(pol, m=j, k=list(range(j))) + assert_almost_equal(trim(res), trim(tgt)) + + # check multiple integrations with lbnd + for i in range(5): + for j in range(2, 5): + pol = [0]*i + [1] + tgt = pol[:] + for k in range(j): + tgt = herme.hermeint(tgt, m=1, k=[k], lbnd=-1) + res = herme.hermeint(pol, m=j, k=list(range(j)), lbnd=-1) + assert_almost_equal(trim(res), trim(tgt)) + + # check multiple integrations with scaling + for i in range(5): + for j in range(2, 5): + pol = [0]*i + [1] + tgt = pol[:] + for k in range(j): + tgt = herme.hermeint(tgt, m=1, k=[k], scl=2) + res = herme.hermeint(pol, m=j, k=list(range(j)), scl=2) + assert_almost_equal(trim(res), trim(tgt)) + + def test_hermeint_axis(self): + # check that axis keyword works + c2d = np.random.random((3, 4)) + + tgt = np.vstack([herme.hermeint(c) for c in c2d.T]).T + res = herme.hermeint(c2d, axis=0) + assert_almost_equal(res, tgt) + + tgt = np.vstack([herme.hermeint(c) for c in c2d]) + res = herme.hermeint(c2d, axis=1) + assert_almost_equal(res, tgt) + + tgt = np.vstack([herme.hermeint(c, k=3) for c in c2d]) + res = herme.hermeint(c2d, k=3, axis=1) + assert_almost_equal(res, tgt) + + +class TestDerivative: + + def test_hermeder(self): + # check exceptions + assert_raises(TypeError, herme.hermeder, [0], .5) + assert_raises(ValueError, herme.hermeder, [0], -1) + + # check that zeroth derivative does nothing + for i in range(5): + tgt = [0]*i + [1] + res = herme.hermeder(tgt, m=0) + assert_equal(trim(res), trim(tgt)) + + # check that derivation is the inverse of integration + for i in range(5): + for j in range(2, 5): + tgt = [0]*i + [1] + res = herme.hermeder(herme.hermeint(tgt, m=j), m=j) + assert_almost_equal(trim(res), trim(tgt)) + + # check derivation with scaling + for i in range(5): + for j in range(2, 5): + tgt = [0]*i + [1] + res = herme.hermeder( + herme.hermeint(tgt, m=j, scl=2), m=j, scl=.5) + assert_almost_equal(trim(res), trim(tgt)) + + def test_hermeder_axis(self): + # check that axis keyword works + c2d = np.random.random((3, 4)) + + tgt = np.vstack([herme.hermeder(c) for c in c2d.T]).T + res = herme.hermeder(c2d, axis=0) + assert_almost_equal(res, tgt) + + tgt = np.vstack([herme.hermeder(c) for c in c2d]) + res = herme.hermeder(c2d, axis=1) + assert_almost_equal(res, tgt) + + +class TestVander: + # some random values in [-1, 1) + x = np.random.random((3, 5))*2 - 1 + + def test_hermevander(self): + # check for 1d x + x = np.arange(3) + v = herme.hermevander(x, 3) + assert_(v.shape == (3, 4)) + for i in range(4): + coef = [0]*i + [1] + assert_almost_equal(v[..., i], herme.hermeval(x, coef)) + + # check for 2d x + x = np.array([[1, 2], [3, 4], [5, 6]]) + v = herme.hermevander(x, 3) + assert_(v.shape == (3, 2, 4)) + for i in range(4): + coef = [0]*i + [1] + assert_almost_equal(v[..., i], herme.hermeval(x, coef)) + + def test_hermevander2d(self): + # also tests hermeval2d for non-square coefficient array + x1, x2, x3 = self.x + c = np.random.random((2, 3)) + van = herme.hermevander2d(x1, x2, [1, 2]) + tgt = herme.hermeval2d(x1, x2, c) + res = np.dot(van, c.flat) + assert_almost_equal(res, tgt) + + # check shape + 
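+        # The trailing axis of the 2d Vandermonde matrix has length
+        # (deg[0] + 1)*(deg[1] + 1) = 2*3 = 6, and wrapping x1, x2 in
+        # lists adds a leading axis of length 1, hence (1, 5, 6) below.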
van = herme.hermevander2d([x1], [x2], [1, 2]) + assert_(van.shape == (1, 5, 6)) + + def test_hermevander3d(self): + # also tests hermeval3d for non-square coefficient array + x1, x2, x3 = self.x + c = np.random.random((2, 3, 4)) + van = herme.hermevander3d(x1, x2, x3, [1, 2, 3]) + tgt = herme.hermeval3d(x1, x2, x3, c) + res = np.dot(van, c.flat) + assert_almost_equal(res, tgt) + + # check shape + van = herme.hermevander3d([x1], [x2], [x3], [1, 2, 3]) + assert_(van.shape == (1, 5, 24)) + + +class TestFitting: + + def test_hermefit(self): + def f(x): + return x*(x - 1)*(x - 2) + + def f2(x): + return x**4 + x**2 + 1 + + # Test exceptions + assert_raises(ValueError, herme.hermefit, [1], [1], -1) + assert_raises(TypeError, herme.hermefit, [[1]], [1], 0) + assert_raises(TypeError, herme.hermefit, [], [1], 0) + assert_raises(TypeError, herme.hermefit, [1], [[[1]]], 0) + assert_raises(TypeError, herme.hermefit, [1, 2], [1], 0) + assert_raises(TypeError, herme.hermefit, [1], [1, 2], 0) + assert_raises(TypeError, herme.hermefit, [1], [1], 0, w=[[1]]) + assert_raises(TypeError, herme.hermefit, [1], [1], 0, w=[1, 1]) + assert_raises(ValueError, herme.hermefit, [1], [1], [-1,]) + assert_raises(ValueError, herme.hermefit, [1], [1], [2, -1, 6]) + assert_raises(TypeError, herme.hermefit, [1], [1], []) + + # Test fit + x = np.linspace(0, 2) + y = f(x) + # + coef3 = herme.hermefit(x, y, 3) + assert_equal(len(coef3), 4) + assert_almost_equal(herme.hermeval(x, coef3), y) + coef3 = herme.hermefit(x, y, [0, 1, 2, 3]) + assert_equal(len(coef3), 4) + assert_almost_equal(herme.hermeval(x, coef3), y) + # + coef4 = herme.hermefit(x, y, 4) + assert_equal(len(coef4), 5) + assert_almost_equal(herme.hermeval(x, coef4), y) + coef4 = herme.hermefit(x, y, [0, 1, 2, 3, 4]) + assert_equal(len(coef4), 5) + assert_almost_equal(herme.hermeval(x, coef4), y) + # check things still work if deg is not in strict increasing + coef4 = herme.hermefit(x, y, [2, 3, 4, 1, 0]) + assert_equal(len(coef4), 5) + assert_almost_equal(herme.hermeval(x, coef4), y) + # + coef2d = herme.hermefit(x, np.array([y, y]).T, 3) + assert_almost_equal(coef2d, np.array([coef3, coef3]).T) + coef2d = herme.hermefit(x, np.array([y, y]).T, [0, 1, 2, 3]) + assert_almost_equal(coef2d, np.array([coef3, coef3]).T) + # test weighting + w = np.zeros_like(x) + yw = y.copy() + w[1::2] = 1 + y[0::2] = 0 + wcoef3 = herme.hermefit(x, yw, 3, w=w) + assert_almost_equal(wcoef3, coef3) + wcoef3 = herme.hermefit(x, yw, [0, 1, 2, 3], w=w) + assert_almost_equal(wcoef3, coef3) + # + wcoef2d = herme.hermefit(x, np.array([yw, yw]).T, 3, w=w) + assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T) + wcoef2d = herme.hermefit(x, np.array([yw, yw]).T, [0, 1, 2, 3], w=w) + assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T) + # test scaling with complex values x points whose square + # is zero when summed. 
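+        # Here 1**2 + 1j**2 + (-1)**2 + (-1j)**2 = 1 - 1 + 1 - 1 = 0, so a
+        # naive sum of squares would give the linear column a zero norm;
+        # this exercises the complex branch of the shared _fit helper,
+        # which sums squared moduli instead.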
+ x = [1, 1j, -1, -1j] + assert_almost_equal(herme.hermefit(x, x, 1), [0, 1]) + assert_almost_equal(herme.hermefit(x, x, [0, 1]), [0, 1]) + # test fitting only even Legendre polynomials + x = np.linspace(-1, 1) + y = f2(x) + coef1 = herme.hermefit(x, y, 4) + assert_almost_equal(herme.hermeval(x, coef1), y) + coef2 = herme.hermefit(x, y, [0, 2, 4]) + assert_almost_equal(herme.hermeval(x, coef2), y) + assert_almost_equal(coef1, coef2) + + +class TestCompanion: + + def test_raises(self): + assert_raises(ValueError, herme.hermecompanion, []) + assert_raises(ValueError, herme.hermecompanion, [1]) + + def test_dimensions(self): + for i in range(1, 5): + coef = [0]*i + [1] + assert_(herme.hermecompanion(coef).shape == (i, i)) + + def test_linear_root(self): + assert_(herme.hermecompanion([1, 2])[0, 0] == -.5) + + +class TestGauss: + + def test_100(self): + x, w = herme.hermegauss(100) + + # test orthogonality. Note that the results need to be normalized, + # otherwise the huge values that can arise from fast growing + # functions like Laguerre can be very confusing. + v = herme.hermevander(x, 99) + vv = np.dot(v.T * w, v) + vd = 1/np.sqrt(vv.diagonal()) + vv = vd[:, None] * vv * vd + assert_almost_equal(vv, np.eye(100)) + + # check that the integral of 1 is correct + tgt = np.sqrt(2*np.pi) + assert_almost_equal(w.sum(), tgt) + + +class TestMisc: + + def test_hermefromroots(self): + res = herme.hermefromroots([]) + assert_almost_equal(trim(res), [1]) + for i in range(1, 5): + roots = np.cos(np.linspace(-np.pi, 0, 2*i + 1)[1::2]) + pol = herme.hermefromroots(roots) + res = herme.hermeval(roots, pol) + tgt = 0 + assert_(len(pol) == i + 1) + assert_almost_equal(herme.herme2poly(pol)[-1], 1) + assert_almost_equal(res, tgt) + + def test_hermeroots(self): + assert_almost_equal(herme.hermeroots([1]), []) + assert_almost_equal(herme.hermeroots([1, 1]), [-1]) + for i in range(2, 5): + tgt = np.linspace(-1, 1, i) + res = herme.hermeroots(herme.hermefromroots(tgt)) + assert_almost_equal(trim(res), trim(tgt)) + + def test_hermetrim(self): + coef = [2, -1, 1, 0] + + # Test exceptions + assert_raises(ValueError, herme.hermetrim, coef, -1) + + # Test results + assert_equal(herme.hermetrim(coef), coef[:-1]) + assert_equal(herme.hermetrim(coef, 1), coef[:-3]) + assert_equal(herme.hermetrim(coef, 2), [0]) + + def test_hermeline(self): + assert_equal(herme.hermeline(3, 4), [3, 4]) + + def test_herme2poly(self): + for i in range(10): + assert_almost_equal(herme.herme2poly([0]*i + [1]), Helist[i]) + + def test_poly2herme(self): + for i in range(10): + assert_almost_equal(herme.poly2herme(Helist[i]), [0]*i + [1]) + + def test_weight(self): + x = np.linspace(-5, 5, 11) + tgt = np.exp(-.5*x**2) + res = herme.hermeweight(x) + assert_almost_equal(res, tgt) diff --git a/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_printing.py b/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2a5092d7225c797b60fd8f2602f2f9276cdd74 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_printing.py @@ -0,0 +1,530 @@ +from math import nan, inf +import pytest +from numpy.core import array, arange, printoptions +import numpy.polynomial as poly +from numpy.testing import assert_equal, assert_ + +# For testing polynomial printing with object arrays +from fractions import Fraction +from decimal import Decimal + + +class TestStrUnicodeSuperSubscripts: + + @pytest.fixture(scope='class', autouse=True) + def 
use_unicode(self): + poly.set_default_printstyle('unicode') + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0·x + 3.0·x²"), + ([-1, 0, 3, -1], "-1.0 + 0.0·x + 3.0·x² - 1.0·x³"), + (arange(12), ("0.0 + 1.0·x + 2.0·x² + 3.0·x³ + 4.0·x⁴ + 5.0·x⁵ + " + "6.0·x⁶ + 7.0·x⁷ +\n8.0·x⁸ + 9.0·x⁹ + 10.0·x¹⁰ + " + "11.0·x¹¹")), + )) + def test_polynomial_str(self, inp, tgt): + res = str(poly.Polynomial(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0·T₁(x) + 3.0·T₂(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0·T₁(x) + 3.0·T₂(x) - 1.0·T₃(x)"), + (arange(12), ("0.0 + 1.0·T₁(x) + 2.0·T₂(x) + 3.0·T₃(x) + 4.0·T₄(x) + " + "5.0·T₅(x) +\n6.0·T₆(x) + 7.0·T₇(x) + 8.0·T₈(x) + " + "9.0·T₉(x) + 10.0·T₁₀(x) + 11.0·T₁₁(x)")), + )) + def test_chebyshev_str(self, inp, tgt): + res = str(poly.Chebyshev(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0·P₁(x) + 3.0·P₂(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0·P₁(x) + 3.0·P₂(x) - 1.0·P₃(x)"), + (arange(12), ("0.0 + 1.0·P₁(x) + 2.0·P₂(x) + 3.0·P₃(x) + 4.0·P₄(x) + " + "5.0·P₅(x) +\n6.0·P₆(x) + 7.0·P₇(x) + 8.0·P₈(x) + " + "9.0·P₉(x) + 10.0·P₁₀(x) + 11.0·P₁₁(x)")), + )) + def test_legendre_str(self, inp, tgt): + res = str(poly.Legendre(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0·H₁(x) + 3.0·H₂(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0·H₁(x) + 3.0·H₂(x) - 1.0·H₃(x)"), + (arange(12), ("0.0 + 1.0·H₁(x) + 2.0·H₂(x) + 3.0·H₃(x) + 4.0·H₄(x) + " + "5.0·H₅(x) +\n6.0·H₆(x) + 7.0·H₇(x) + 8.0·H₈(x) + " + "9.0·H₉(x) + 10.0·H₁₀(x) + 11.0·H₁₁(x)")), + )) + def test_hermite_str(self, inp, tgt): + res = str(poly.Hermite(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0·He₁(x) + 3.0·He₂(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0·He₁(x) + 3.0·He₂(x) - 1.0·He₃(x)"), + (arange(12), ("0.0 + 1.0·He₁(x) + 2.0·He₂(x) + 3.0·He₃(x) + " + "4.0·He₄(x) + 5.0·He₅(x) +\n6.0·He₆(x) + 7.0·He₇(x) + " + "8.0·He₈(x) + 9.0·He₉(x) + 10.0·He₁₀(x) +\n" + "11.0·He₁₁(x)")), + )) + def test_hermiteE_str(self, inp, tgt): + res = str(poly.HermiteE(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0·L₁(x) + 3.0·L₂(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0·L₁(x) + 3.0·L₂(x) - 1.0·L₃(x)"), + (arange(12), ("0.0 + 1.0·L₁(x) + 2.0·L₂(x) + 3.0·L₃(x) + 4.0·L₄(x) + " + "5.0·L₅(x) +\n6.0·L₆(x) + 7.0·L₇(x) + 8.0·L₈(x) + " + "9.0·L₉(x) + 10.0·L₁₀(x) + 11.0·L₁₁(x)")), + )) + def test_laguerre_str(self, inp, tgt): + res = str(poly.Laguerre(inp)) + assert_equal(res, tgt) + + +class TestStrAscii: + + @pytest.fixture(scope='class', autouse=True) + def use_ascii(self): + poly.set_default_printstyle('ascii') + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0 x + 3.0 x**2"), + ([-1, 0, 3, -1], "-1.0 + 0.0 x + 3.0 x**2 - 1.0 x**3"), + (arange(12), ("0.0 + 1.0 x + 2.0 x**2 + 3.0 x**3 + 4.0 x**4 + " + "5.0 x**5 + 6.0 x**6 +\n7.0 x**7 + 8.0 x**8 + " + "9.0 x**9 + 10.0 x**10 + 11.0 x**11")), + )) + def test_polynomial_str(self, inp, tgt): + res = str(poly.Polynomial(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0 T_1(x) + 3.0 T_2(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0 T_1(x) + 3.0 T_2(x) - 1.0 T_3(x)"), + (arange(12), ("0.0 + 1.0 T_1(x) + 2.0 T_2(x) + 3.0 T_3(x) + " + "4.0 T_4(x) + 5.0 T_5(x) +\n6.0 T_6(x) + 7.0 T_7(x) + " + "8.0 T_8(x) + 9.0 T_9(x) + 10.0 T_10(x) +\n" + "11.0 T_11(x)")), + )) + def 
test_chebyshev_str(self, inp, tgt): + res = str(poly.Chebyshev(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0 P_1(x) + 3.0 P_2(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0 P_1(x) + 3.0 P_2(x) - 1.0 P_3(x)"), + (arange(12), ("0.0 + 1.0 P_1(x) + 2.0 P_2(x) + 3.0 P_3(x) + " + "4.0 P_4(x) + 5.0 P_5(x) +\n6.0 P_6(x) + 7.0 P_7(x) + " + "8.0 P_8(x) + 9.0 P_9(x) + 10.0 P_10(x) +\n" + "11.0 P_11(x)")), + )) + def test_legendre_str(self, inp, tgt): + res = str(poly.Legendre(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0 H_1(x) + 3.0 H_2(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0 H_1(x) + 3.0 H_2(x) - 1.0 H_3(x)"), + (arange(12), ("0.0 + 1.0 H_1(x) + 2.0 H_2(x) + 3.0 H_3(x) + " + "4.0 H_4(x) + 5.0 H_5(x) +\n6.0 H_6(x) + 7.0 H_7(x) + " + "8.0 H_8(x) + 9.0 H_9(x) + 10.0 H_10(x) +\n" + "11.0 H_11(x)")), + )) + def test_hermite_str(self, inp, tgt): + res = str(poly.Hermite(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0 He_1(x) + 3.0 He_2(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0 He_1(x) + 3.0 He_2(x) - 1.0 He_3(x)"), + (arange(12), ("0.0 + 1.0 He_1(x) + 2.0 He_2(x) + 3.0 He_3(x) + " + "4.0 He_4(x) +\n5.0 He_5(x) + 6.0 He_6(x) + " + "7.0 He_7(x) + 8.0 He_8(x) + 9.0 He_9(x) +\n" + "10.0 He_10(x) + 11.0 He_11(x)")), + )) + def test_hermiteE_str(self, inp, tgt): + res = str(poly.HermiteE(inp)) + assert_equal(res, tgt) + + @pytest.mark.parametrize(('inp', 'tgt'), ( + ([1, 2, 3], "1.0 + 2.0 L_1(x) + 3.0 L_2(x)"), + ([-1, 0, 3, -1], "-1.0 + 0.0 L_1(x) + 3.0 L_2(x) - 1.0 L_3(x)"), + (arange(12), ("0.0 + 1.0 L_1(x) + 2.0 L_2(x) + 3.0 L_3(x) + " + "4.0 L_4(x) + 5.0 L_5(x) +\n6.0 L_6(x) + 7.0 L_7(x) + " + "8.0 L_8(x) + 9.0 L_9(x) + 10.0 L_10(x) +\n" + "11.0 L_11(x)")), + )) + def test_laguerre_str(self, inp, tgt): + res = str(poly.Laguerre(inp)) + assert_equal(res, tgt) + + +class TestLinebreaking: + + @pytest.fixture(scope='class', autouse=True) + def use_ascii(self): + poly.set_default_printstyle('ascii') + + def test_single_line_one_less(self): + # With 'ascii' style, len(str(p)) is default linewidth - 1 (i.e. 
74) + p = poly.Polynomial([12345678, 12345678, 12345678, 12345678, 123]) + assert_equal(len(str(p)), 74) + assert_equal(str(p), ( + '12345678.0 + 12345678.0 x + 12345678.0 x**2 + ' + '12345678.0 x**3 + 123.0 x**4' + )) + + def test_num_chars_is_linewidth(self): + # len(str(p)) == default linewidth == 75 + p = poly.Polynomial([12345678, 12345678, 12345678, 12345678, 1234]) + assert_equal(len(str(p)), 75) + assert_equal(str(p), ( + '12345678.0 + 12345678.0 x + 12345678.0 x**2 + ' + '12345678.0 x**3 +\n1234.0 x**4' + )) + + def test_first_linebreak_multiline_one_less_than_linewidth(self): + # Multiline str where len(first_line) + len(next_term) == lw - 1 == 74 + p = poly.Polynomial( + [12345678, 12345678, 12345678, 12345678, 1, 12345678] + ) + assert_equal(len(str(p).split('\n')[0]), 74) + assert_equal(str(p), ( + '12345678.0 + 12345678.0 x + 12345678.0 x**2 + ' + '12345678.0 x**3 + 1.0 x**4 +\n12345678.0 x**5' + )) + + def test_first_linebreak_multiline_on_linewidth(self): + # First line is one character longer than previous test + p = poly.Polynomial( + [12345678, 12345678, 12345678, 12345678.12, 1, 12345678] + ) + assert_equal(str(p), ( + '12345678.0 + 12345678.0 x + 12345678.0 x**2 + ' + '12345678.12 x**3 +\n1.0 x**4 + 12345678.0 x**5' + )) + + @pytest.mark.parametrize(('lw', 'tgt'), ( + (75, ('0.0 + 10.0 x + 200.0 x**2 + 3000.0 x**3 + 40000.0 x**4 + ' + '500000.0 x**5 +\n600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 + ' + '900.0 x**9')), + (45, ('0.0 + 10.0 x + 200.0 x**2 + 3000.0 x**3 +\n40000.0 x**4 + ' + '500000.0 x**5 +\n600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 +\n' + '900.0 x**9')), + (132, ('0.0 + 10.0 x + 200.0 x**2 + 3000.0 x**3 + 40000.0 x**4 + ' + '500000.0 x**5 + 600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 + ' + '900.0 x**9')), + )) + def test_linewidth_printoption(self, lw, tgt): + p = poly.Polynomial( + [0, 10, 200, 3000, 40000, 500000, 600000, 70000, 8000, 900] + ) + with printoptions(linewidth=lw): + assert_equal(str(p), tgt) + for line in str(p).split('\n'): + assert_(len(line) < lw) + + +def test_set_default_printoptions(): + p = poly.Polynomial([1, 2, 3]) + c = poly.Chebyshev([1, 2, 3]) + poly.set_default_printstyle('ascii') + assert_equal(str(p), "1.0 + 2.0 x + 3.0 x**2") + assert_equal(str(c), "1.0 + 2.0 T_1(x) + 3.0 T_2(x)") + poly.set_default_printstyle('unicode') + assert_equal(str(p), "1.0 + 2.0·x + 3.0·x²") + assert_equal(str(c), "1.0 + 2.0·T₁(x) + 3.0·T₂(x)") + with pytest.raises(ValueError): + poly.set_default_printstyle('invalid_input') + + +def test_complex_coefficients(): + """Test both numpy and built-in complex.""" + coefs = [0+1j, 1+1j, -2+2j, 3+0j] + # numpy complex + p1 = poly.Polynomial(coefs) + # Python complex + p2 = poly.Polynomial(array(coefs, dtype=object)) + poly.set_default_printstyle('unicode') + assert_equal(str(p1), "1j + (1+1j)·x - (2-2j)·x² + (3+0j)·x³") + assert_equal(str(p2), "1j + (1+1j)·x + (-2+2j)·x² + (3+0j)·x³") + poly.set_default_printstyle('ascii') + assert_equal(str(p1), "1j + (1+1j) x - (2-2j) x**2 + (3+0j) x**3") + assert_equal(str(p2), "1j + (1+1j) x + (-2+2j) x**2 + (3+0j) x**3") + + +@pytest.mark.parametrize(('coefs', 'tgt'), ( + (array([Fraction(1, 2), Fraction(3, 4)], dtype=object), ( + "1/2 + 3/4·x" + )), + (array([1, 2, Fraction(5, 7)], dtype=object), ( + "1 + 2·x + 5/7·x²" + )), + (array([Decimal('1.00'), Decimal('2.2'), 3], dtype=object), ( + "1.00 + 2.2·x + 3·x²" + )), +)) +def test_numeric_object_coefficients(coefs, tgt): + p = poly.Polynomial(coefs) + poly.set_default_printstyle('unicode') + assert_equal(str(p), tgt) + 
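+
+# Object-dtype coefficients bypass the float formatter: format_float in
+# numpy.polynomial.polyutils returns str(x) for anything that is not a
+# numpy floating type, which is why Fraction(1, 2) prints as "1/2" above
+# and the non-numeric objects in the next case fall back to their own
+# str() as well.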
+ +@pytest.mark.parametrize(('coefs', 'tgt'), ( + (array([1, 2, 'f'], dtype=object), '1 + 2·x + f·x²'), + (array([1, 2, [3, 4]], dtype=object), '1 + 2·x + [3, 4]·x²'), +)) +def test_nonnumeric_object_coefficients(coefs, tgt): + """ + Test coef fallback for object arrays of non-numeric coefficients. + """ + p = poly.Polynomial(coefs) + poly.set_default_printstyle('unicode') + assert_equal(str(p), tgt) + + +class TestFormat: + def test_format_unicode(self): + poly.set_default_printstyle('ascii') + p = poly.Polynomial([1, 2, 0, -1]) + assert_equal(format(p, 'unicode'), "1.0 + 2.0·x + 0.0·x² - 1.0·x³") + + def test_format_ascii(self): + poly.set_default_printstyle('unicode') + p = poly.Polynomial([1, 2, 0, -1]) + assert_equal( + format(p, 'ascii'), "1.0 + 2.0 x + 0.0 x**2 - 1.0 x**3" + ) + + def test_empty_formatstr(self): + poly.set_default_printstyle('ascii') + p = poly.Polynomial([1, 2, 3]) + assert_equal(format(p), "1.0 + 2.0 x + 3.0 x**2") + assert_equal(f"{p}", "1.0 + 2.0 x + 3.0 x**2") + + def test_bad_formatstr(self): + p = poly.Polynomial([1, 2, 0, -1]) + with pytest.raises(ValueError): + format(p, '.2f') + + +@pytest.mark.parametrize(('poly', 'tgt'), ( + (poly.Polynomial, '1.0 + 2.0·z + 3.0·z²'), + (poly.Chebyshev, '1.0 + 2.0·T₁(z) + 3.0·T₂(z)'), + (poly.Hermite, '1.0 + 2.0·H₁(z) + 3.0·H₂(z)'), + (poly.HermiteE, '1.0 + 2.0·He₁(z) + 3.0·He₂(z)'), + (poly.Laguerre, '1.0 + 2.0·L₁(z) + 3.0·L₂(z)'), + (poly.Legendre, '1.0 + 2.0·P₁(z) + 3.0·P₂(z)'), +)) +def test_symbol(poly, tgt): + p = poly([1, 2, 3], symbol='z') + assert_equal(f"{p:unicode}", tgt) + + +class TestRepr: + def test_polynomial_str(self): + res = repr(poly.Polynomial([0, 1])) + tgt = ( + "Polynomial([0., 1.], domain=[-1, 1], window=[-1, 1], " + "symbol='x')" + ) + assert_equal(res, tgt) + + def test_chebyshev_str(self): + res = repr(poly.Chebyshev([0, 1])) + tgt = ( + "Chebyshev([0., 1.], domain=[-1, 1], window=[-1, 1], " + "symbol='x')" + ) + assert_equal(res, tgt) + + def test_legendre_repr(self): + res = repr(poly.Legendre([0, 1])) + tgt = ( + "Legendre([0., 1.], domain=[-1, 1], window=[-1, 1], " + "symbol='x')" + ) + assert_equal(res, tgt) + + def test_hermite_repr(self): + res = repr(poly.Hermite([0, 1])) + tgt = ( + "Hermite([0., 1.], domain=[-1, 1], window=[-1, 1], " + "symbol='x')" + ) + assert_equal(res, tgt) + + def test_hermiteE_repr(self): + res = repr(poly.HermiteE([0, 1])) + tgt = ( + "HermiteE([0., 1.], domain=[-1, 1], window=[-1, 1], " + "symbol='x')" + ) + assert_equal(res, tgt) + + def test_laguerre_repr(self): + res = repr(poly.Laguerre([0, 1])) + tgt = ( + "Laguerre([0., 1.], domain=[0, 1], window=[0, 1], " + "symbol='x')" + ) + assert_equal(res, tgt) + + +class TestLatexRepr: + """Test the latex repr used by Jupyter""" + + def as_latex(self, obj): + # right now we ignore the formatting of scalars in our tests, since + # it makes them too verbose. 
Ideally, the formatting of scalars will + # be fixed such that tests below continue to pass + obj._repr_latex_scalar = lambda x, parens=False: str(x) + try: + return obj._repr_latex_() + finally: + del obj._repr_latex_scalar + + def test_simple_polynomial(self): + # default input + p = poly.Polynomial([1, 2, 3]) + assert_equal(self.as_latex(p), + r'$x \mapsto 1.0 + 2.0\,x + 3.0\,x^{2}$') + + # translated input + p = poly.Polynomial([1, 2, 3], domain=[-2, 0]) + assert_equal(self.as_latex(p), + r'$x \mapsto 1.0 + 2.0\,\left(1.0 + x\right) + 3.0\,\left(1.0 + x\right)^{2}$') + + # scaled input + p = poly.Polynomial([1, 2, 3], domain=[-0.5, 0.5]) + assert_equal(self.as_latex(p), + r'$x \mapsto 1.0 + 2.0\,\left(2.0x\right) + 3.0\,\left(2.0x\right)^{2}$') + + # affine input + p = poly.Polynomial([1, 2, 3], domain=[-1, 0]) + assert_equal(self.as_latex(p), + r'$x \mapsto 1.0 + 2.0\,\left(1.0 + 2.0x\right) + 3.0\,\left(1.0 + 2.0x\right)^{2}$') + + def test_basis_func(self): + p = poly.Chebyshev([1, 2, 3]) + assert_equal(self.as_latex(p), + r'$x \mapsto 1.0\,{T}_{0}(x) + 2.0\,{T}_{1}(x) + 3.0\,{T}_{2}(x)$') + # affine input - check no surplus parens are added + p = poly.Chebyshev([1, 2, 3], domain=[-1, 0]) + assert_equal(self.as_latex(p), + r'$x \mapsto 1.0\,{T}_{0}(1.0 + 2.0x) + 2.0\,{T}_{1}(1.0 + 2.0x) + 3.0\,{T}_{2}(1.0 + 2.0x)$') + + def test_multichar_basis_func(self): + p = poly.HermiteE([1, 2, 3]) + assert_equal(self.as_latex(p), + r'$x \mapsto 1.0\,{He}_{0}(x) + 2.0\,{He}_{1}(x) + 3.0\,{He}_{2}(x)$') + + def test_symbol_basic(self): + # default input + p = poly.Polynomial([1, 2, 3], symbol='z') + assert_equal(self.as_latex(p), + r'$z \mapsto 1.0 + 2.0\,z + 3.0\,z^{2}$') + + # translated input + p = poly.Polynomial([1, 2, 3], domain=[-2, 0], symbol='z') + assert_equal( + self.as_latex(p), + ( + r'$z \mapsto 1.0 + 2.0\,\left(1.0 + z\right) + 3.0\,' + r'\left(1.0 + z\right)^{2}$' + ), + ) + + # scaled input + p = poly.Polynomial([1, 2, 3], domain=[-0.5, 0.5], symbol='z') + assert_equal( + self.as_latex(p), + ( + r'$z \mapsto 1.0 + 2.0\,\left(2.0z\right) + 3.0\,' + r'\left(2.0z\right)^{2}$' + ), + ) + + # affine input + p = poly.Polynomial([1, 2, 3], domain=[-1, 0], symbol='z') + assert_equal( + self.as_latex(p), + ( + r'$z \mapsto 1.0 + 2.0\,\left(1.0 + 2.0z\right) + 3.0\,' + r'\left(1.0 + 2.0z\right)^{2}$' + ), + ) + + +SWITCH_TO_EXP = ( + '1.0 + (1.0e-01) x + (1.0e-02) x**2', + '1.2 + (1.2e-01) x + (1.2e-02) x**2', + '1.23 + 0.12 x + (1.23e-02) x**2 + (1.23e-03) x**3', + '1.235 + 0.123 x + (1.235e-02) x**2 + (1.235e-03) x**3', + '1.2346 + 0.1235 x + 0.0123 x**2 + (1.2346e-03) x**3 + (1.2346e-04) x**4', + '1.23457 + 0.12346 x + 0.01235 x**2 + (1.23457e-03) x**3 + ' + '(1.23457e-04) x**4', + '1.234568 + 0.123457 x + 0.012346 x**2 + 0.001235 x**3 + ' + '(1.234568e-04) x**4 + (1.234568e-05) x**5', + '1.2345679 + 0.1234568 x + 0.0123457 x**2 + 0.0012346 x**3 + ' + '(1.2345679e-04) x**4 + (1.2345679e-05) x**5') + +class TestPrintOptions: + """ + Test the output is properly configured via printoptions. + The exponential notation is enabled automatically when the values + are too small or too large. 
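+    The thresholds mirror ``format_float`` in
+    ``numpy.polynomial.polyutils``: exponential notation is used when
+    ``abs(x) >= 1e8`` or when ``abs(x) < 10**min(0, -(precision - 1)//2)``.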
+ """ + + @pytest.fixture(scope='class', autouse=True) + def use_ascii(self): + poly.set_default_printstyle('ascii') + + def test_str(self): + p = poly.Polynomial([1/2, 1/7, 1/7*10**8, 1/7*10**9]) + assert_equal(str(p), '0.5 + 0.14285714 x + 14285714.28571429 x**2 ' + '+ (1.42857143e+08) x**3') + + with printoptions(precision=3): + assert_equal(str(p), '0.5 + 0.143 x + 14285714.286 x**2 ' + '+ (1.429e+08) x**3') + + def test_latex(self): + p = poly.Polynomial([1/2, 1/7, 1/7*10**8, 1/7*10**9]) + assert_equal(p._repr_latex_(), + r'$x \mapsto \text{0.5} + \text{0.14285714}\,x + ' + r'\text{14285714.28571429}\,x^{2} + ' + r'\text{(1.42857143e+08)}\,x^{3}$') + + with printoptions(precision=3): + assert_equal(p._repr_latex_(), + r'$x \mapsto \text{0.5} + \text{0.143}\,x + ' + r'\text{14285714.286}\,x^{2} + \text{(1.429e+08)}\,x^{3}$') + + def test_fixed(self): + p = poly.Polynomial([1/2]) + assert_equal(str(p), '0.5') + + with printoptions(floatmode='fixed'): + assert_equal(str(p), '0.50000000') + + with printoptions(floatmode='fixed', precision=4): + assert_equal(str(p), '0.5000') + + def test_switch_to_exp(self): + for i, s in enumerate(SWITCH_TO_EXP): + with printoptions(precision=i): + p = poly.Polynomial([1.23456789*10**-i + for i in range(i//2+3)]) + assert str(p).replace('\n', ' ') == s + + def test_non_finite(self): + p = poly.Polynomial([nan, inf]) + assert str(p) == 'nan + inf x' + assert p._repr_latex_() == r'$x \mapsto \text{nan} + \text{inf}\,x$' + with printoptions(nanstr='NAN', infstr='INF'): + assert str(p) == 'NAN + INF x' + assert p._repr_latex_() == \ + r'$x \mapsto \text{NAN} + \text{INF}\,x$' diff --git a/.venv/lib/python3.11/site-packages/torchgen/__init__.py b/.venv/lib/python3.11/site-packages/torchgen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d5dbf0667a022caa07ec30bb10db5b4f83159dd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/__init__.py @@ -0,0 +1,10 @@ +"""torchgen + +This module contains codegeneration utilities for PyTorch. It is used to +build PyTorch from source, but may also be used for out-of-tree projects +that extend PyTorch. + +Note well that we provide no BC guarantees for torchgen. If you're interested +in using torchgen and want the PyTorch team to be aware, please reach out +on GitHub. 
+""" diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be0e3586e91ead00ba1e406677ab05d83c297282 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..018ec245db74861bf6368c8d5dadc202f0445f48 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e2cd21074e164c51eb16e71f9b723ba93532453 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3de0312862fe889bffc1151e1df9af9841b8e5c8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..184de843439663533bbe3f7711d91d44c48f6e91 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85e1255fc96facd485703c067384140ba70c5bdc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c98bd33a42099a310c6ea0a80bdff9fa60d55760 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4d4e7d14988cbb6b3e4331685a353cf2ba0fe00 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8259cb98ba38ccd55cf3264d36d7c7be763e29f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecd9f81469ca947b3fb00741daa792e2f399f6f3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c9f4a7b8b0c1613344339484cb82b8ece4d77e8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..868dd2d9972a6cf0276d5e09be09e673b6bfeae0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79ceb03ec6fcbd862d95bee2c865d6ee1a197b14 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0701dd47918b9981e85886cd28a1739ed4ea2665 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/aoti/__init__.py b/.venv/lib/python3.11/site-packages/torchgen/aoti/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0ed6f40a7ff4deb3e7d1737e641631ce6dc337c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..923aa0c82061b02f82aecebd8922a00307aaca82 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torchgen/aoti/fallback_ops.py b/.venv/lib/python3.11/site-packages/torchgen/aoti/fallback_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..aa88214b3672f199b2858eeb18ec2917ba3c2d0b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/aoti/fallback_ops.py @@ -0,0 +1,149 @@ +# Be extra careful when you edit this file, because it affects AOTInductor ABI compatibility. See +# https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 +# for details. +# +# The inductor_fallback_ops list is based on the fallback ops from torch/_inductor/lowering.py. +# Generally speaking, it is ok to add a new op to the list, but you need to run +# `python torchgen/gen.py --update-aoti-c-shim` in order to regenerate C shim header files. +# But it is NOT ok to remove an existing fallback op from the list, since that will break +# some existing AOTInductor-compiled models. +inductor_fallback_ops = { + "aten._adaptive_avg_pool2d_backward.default", + "aten._adaptive_avg_pool2d.default", + "aten._adaptive_avg_pool3d.default", + "aten._adaptive_avg_pool3d_backward.default", + "aten.adaptive_max_pool2d_backward.default", + "aten.adaptive_max_pool2d.default", + "aten.adaptive_max_pool3d.default", + "aten.adaptive_max_pool3d_backward.default", + "aten.addbmm.default", + "aten._addmm_activation.default", + "aten.addmm.out", + "aten.addmv.default", + "aten.angle.default", + "aten.avg_pool2d_backward.default", + "aten.avg_pool2d.default", + "aten.avg_pool3d_backward.default", + "aten.avg_pool3d.default", + "aten.bernoulli_.float", + "aten.bernoulli_.Tensor", + "aten.bmm.out", + "aten.bucketize.Tensor", + "aten.cat.default", + "aten._cdist_backward.default", + "aten._cdist_forward.default", + "aten.cholesky_inverse.default", + "aten.cholesky_solve.default", + "aten.convolution_backward.default", + "aten._cudnn_rnn.default", + "aten._cudnn_rnn_backward.default", + "aten.convolution.default", + "aten.cummax.default", + "aten.cummin.default", + "aten.cumprod.default", + "aten.cumsum.default", + "aten._efficient_attention_backward.default", + "aten._efficient_attention_forward.default", + "aten._efficientzerotensor.default", + "aten._embedding_bag.default", + "aten._embedding_bag_dense_backward.default", + "aten._embedding_bag_forward_only.default", + "aten._embedding_bag_per_sample_weights_backward.default", + "aten.exponential.default", + "aten._fft_c2c.default", + "aten._fft_r2c.default", + "aten._flash_attention_backward.default", + "aten._flash_attention_forward.default", + "aten.fractional_max_pool2d_backward.default", + "aten.fractional_max_pool2d.default", + "aten.fractional_max_pool3d.default", + "aten.fractional_max_pool3d_backward.default", + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten.gcd.default", + "aten.geqrf.default", + "aten.grid_sampler_2d_backward.default", + "aten.histc.default", + "aten.histogram.bin_ct", + "aten._histogramdd_bin_edges.default", + "aten._histogramdd_from_bin_cts.default", + "aten.index_put.default", + "aten.index_reduce.default", + "aten.index.Tensor", + "aten.kthvalue.default", + "aten.logcumsumexp.default", + "aten.lu_unpack.default", + "aten.masked_scatter.default", + "aten.masked_scatter_backward.default", + "aten.max_pool2d_with_indices_backward.default", + "aten.max_pool2d_with_indices.default", + "aten.max_pool3d_with_indices.default", + "aten.max_pool3d_with_indices_backward.default", +
"aten.max_unpool2d.default", + "aten.max_unpool3d.default", + "aten.median.default", + "aten.mm.out", + "aten.mode.default", + "aten.mul.Scalar", + "aten.mul.Tensor", + "aten.nanmedian.default", + "aten.native_dropout.default", + "aten.normal_functional.default", + "aten.nonzero.default", + "aten.ormqr.default", + "aten._pdist_backward.default", + "aten._pdist_forward.default", + "aten.polar.default", + "aten.pow.Scalar", + "aten.pow.Tensor_Scalar", + "aten.pow.Tensor_Tensor", + "aten.rand.default", + "aten.rand.generator", + "aten.randint.default", + "aten.randint.generator", + "aten.randint.low", + "aten.randint.low_out", + "aten.randn.default", + "aten.randn.generator", + "aten.randperm.default", + "aten.repeat_interleave.Tensor", + "aten.replication_pad1d_backward.default", + "aten.replication_pad2d_backward.default", + "aten.reshape.default", + "aten.resize_.default", + "aten.resize_as_.default", + "aten._scaled_dot_product_efficient_attention_backward.default", + "aten._scaled_dot_product_efficient_attention.default", + "aten._scaled_dot_product_flash_attention_backward.default", + "aten._scaled_dot_product_flash_attention.default", + "aten._scaled_dot_product_cudnn_attention_backward.default", + "aten._scaled_dot_product_cudnn_attention.default", + "aten._scaled_dot_product_flash_attention_for_cpu_backward.default", + "aten._scaled_dot_product_flash_attention_for_cpu.default", + "aten._scaled_mm.default", + "aten.scatter_reduce.two_out", + "aten.scatter.src_out", + "aten.scatter.value_out", + "aten.searchsorted.default", + "aten._segment_reduce_backward.default", + "aten.segment_reduce.default", + "aten.slice.Tensor", + "aten.soft_margin_loss_backward.default", + "aten.sort.default", + "aten.sort.stable", + "aten._sparse_coo_tensor_with_dims_and_tensors.default", + "aten._thnn_fused_lstm_cell.default", + "aten.topk.default", + "aten._to_sparse.default", + "aten.to_sparse.default", + "aten.triangular_solve.default", + "aten._trilinear.default", + "aten.uniform.default", + "aten.upsample_bicubic2d_backward.default", + "aten.upsample_linear1d_backward.default", + "aten.upsample_trilinear3d_backward.default", + "aten.view_as_complex.default", + "aten.view_as_real.default", + "aten.view.dtype", + "aten.zeros.names", +} diff --git a/.venv/lib/python3.11/site-packages/torchgen/code_template.py b/.venv/lib/python3.11/site-packages/torchgen/code_template.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb86a48064248298e1481c072ad9c3d90c19242 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/code_template.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import re +from typing import Mapping, Sequence + + +# match $identifier or ${identifier} and replace with value in env +# If this identifier is at the beginning of whitespace on a line +# and its value is a list then it is treated as +# block substitution by indenting to that depth and putting each element +# of the list on its own line +# if the identifier is on a line starting with non-whitespace and a list +# then it is comma separated ${,foo} will insert a comma before the list +# if this list is not empty and ${foo,} will insert one after. 
+ + +class CodeTemplate: + substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})" + substitution = re.compile(substitution_str, re.MULTILINE) + + pattern: str + filename: str + + @staticmethod + def from_file(filename: str) -> CodeTemplate: + with open(filename) as f: + return CodeTemplate(f.read(), filename) + + def __init__(self, pattern: str, filename: str = "") -> None: + self.pattern = pattern + self.filename = filename + + def substitute( + self, env: Mapping[str, object] | None = None, **kwargs: object + ) -> str: + if env is None: + env = {} + + def lookup(v: str) -> object: + assert env is not None + return kwargs[v] if v in kwargs else env[v] + + def indent_lines(indent: str, v: Sequence[object]) -> str: + return "".join( + [indent + l + "\n" for e in v for l in str(e).splitlines()] + ).rstrip() + + def replace(match: re.Match[str]) -> str: + indent = match.group(1) + key = match.group(2) + comma_before = "" + comma_after = "" + if key[0] == "{": + key = key[1:-1] + if key[0] == ",": + comma_before = ", " + key = key[1:] + if key[-1] == ",": + comma_after = ", " + key = key[:-1] + v = lookup(key) + if indent is not None: + if not isinstance(v, list): + v = [v] + return indent_lines(indent, v) + elif isinstance(v, list): + middle = ", ".join([str(x) for x in v]) + if len(v) == 0: + return middle + return comma_before + middle + comma_after + else: + return str(v) + + return self.substitution.sub(replace, self.pattern) + + +if __name__ == "__main__": + c = CodeTemplate( + """\ + int foo($args) { + + $bar + $bar + $a+$b + } + int commatest(int a${,stuff}) + int notest(int a${,empty,}) + """ + ) + print( + c.substitute( + args=["hi", 8], + bar=["what", 7], + a=3, + b=4, + stuff=["things...", "others"], + empty=[], + ) + ) diff --git a/.venv/lib/python3.11/site-packages/torchgen/context.py b/.venv/lib/python3.11/site-packages/torchgen/context.py new file mode 100644 index 0000000000000000000000000000000000000000..a20310498164b5930adde76be5d825cc6b36778c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/context.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import contextlib +import functools +from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union + +import torchgen.local as local +from torchgen.model import ( + BackendIndex, + DispatchKey, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, +) +from torchgen.utils import context, S, T + + +# Helper functions for defining generators on things in the model + +F = TypeVar( + "F", + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + Union[NativeFunction, NativeFunctionsGroup], + Union[NativeFunction, NativeFunctionsViewGroup], +) + +F2 = TypeVar( + "F2", + NativeFunction, + NativeFunctionsGroup, + Optional[NativeFunction], + bool, + str, +) + +F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction]) + + +@contextlib.contextmanager +def native_function_manager( + g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction, +) -> Iterator[None]: + if isinstance(g, NativeFunctionsGroup): + # By default, we associate all errors with structured native functions + # with the out variant. 
In some cases, it might be better to have + # a more specific place to hang things; if so, use + # native_function_manager again on the inside + f = g.out + elif isinstance(g, NativeFunctionsViewGroup): + # We associate errors with the view operator + f = g.view + else: + f = g + with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"): + with local.parametrize( + use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=f.part_of_structured_group, + ): + yield + + +# Given a function that operates on NativeFunction, wrap it into a new function +# that sets some appropriate context managers for that native function. +# YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound +# (you will get an error if we try to access the local variables without having +# set them). +def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]: + @functools.wraps(func) + def wrapper(f: F) -> T: + with native_function_manager(f): + return func(f) + + return wrapper + + +def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]: + @functools.wraps(func) + def wrapper(f: F, f2: F2) -> T: + # The first native_function is assumed to be the one with the appropriate context. + with native_function_manager(f): + return func(f, f2) + + return wrapper + + +def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]: + @functools.wraps(func) + def wrapper(slf: S, f: F) -> T: + with native_function_manager(f): + return func(slf, f) + + return wrapper + + +def method_with_nested_native_function( + func: Callable[[S, F3], T] +) -> Callable[[S, F3], T]: + @functools.wraps(func) + def wrapper(slf: S, f: F3) -> T: + with native_function_manager(f[0]): + return func(slf, f) + + return wrapper + + +# Convenience decorator for functions that explicitly take in a BackendIndex, +# instead of indirectly taking one in as a closure +def with_native_function_and_index( + func: Callable[[F, BackendIndex], T] +) -> Callable[[F, BackendIndex], T]: + @functools.wraps(func) + def wrapper(f: F, backend_index: BackendIndex) -> T: + with native_function_manager(f): + return func(f, backend_index) + + return wrapper + + +# Convenience decorator for functions that explicitly take in a Dict of BackendIndices +def with_native_function_and_indices( + func: Callable[[F, dict[DispatchKey, BackendIndex]], T] +) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]: + @functools.wraps(func) + def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T: + with native_function_manager(f): + return func(f, backend_indices) + + return wrapper diff --git a/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py b/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..70161216d8e7c95e194b0d89b345e0da886ef989 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py @@ -0,0 +1,48 @@ +from torchgen.api.lazy import LazyArgument, LazyIrSchema +from torchgen.api.types import OptionalCType + + +def ts_lowering_body(schema: LazyIrSchema) -> str: + # for now, we just want one IR class decl and soon after also the method defs + # and we use the functional version not out/inplace. + emplace_arguments = [] + + def get_value(arg: LazyArgument) -> str: + if isinstance(arg.lazy_type, OptionalCType): + return f"has_{arg.name} ? 
loctx->GetOutputOp(operand(i++)) : nullptr" + return "loctx->GetOutputOp(operand(i++))" + + for arg in schema.positional_args: + if arg.is_lazy_value: + emplace_arguments.append(get_value(arg)) + continue + emplace_arguments.append(f'"{arg.name}", {arg.name}') + + emplace_arguments_str = "\n ".join( + [f"arguments.emplace_back({a});" for a in emplace_arguments] + ) + emplace_kwarg_values = [ + f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values + ] + emplace_kwarg_scalars = [ + f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars + ] + emplace_kwarguments = "\n ".join( + [ + f"kwarguments.emplace_back({a});" + for a in emplace_kwarg_values + emplace_kwarg_scalars + ] + ) + return f"""\ + std::vector arguments; + std::vector kwarguments; + arguments.reserve({len(emplace_arguments)}); + kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)}); + size_t i = 0; + {emplace_arguments_str} + {emplace_kwarguments} + torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); + TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); + + return {schema.aten_name}_out; +""" diff --git a/.venv/lib/python3.11/site-packages/torchgen/gen.py b/.venv/lib/python3.11/site-packages/torchgen/gen.py new file mode 100644 index 0000000000000000000000000000000000000000..e5870a24fc668401b7e0eea155ac1cb5057d73ee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/gen.py @@ -0,0 +1,2986 @@ +from __future__ import annotations + +import argparse +import functools +import json +import os +from collections import defaultdict, namedtuple, OrderedDict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Literal, Sequence, TypeVar + +import yaml + +import torchgen.api.dispatcher as dispatcher +import torchgen.api.meta as meta +import torchgen.api.native as native +import torchgen.api.structured as structured +import torchgen.dest as dest +from torchgen.aoti.fallback_ops import inductor_fallback_ops +from torchgen.api import cpp +from torchgen.api.translate import translate +from torchgen.api.types import ( + Binding, + CppSignature, + CppSignatureGroup, + DispatcherSignature, + NamedCType, + NativeSignature, + SpecialArgName, +) +from torchgen.context import ( + method_with_native_function, + native_function_manager, + with_native_function, + with_native_function_and_indices, +) +from torchgen.gen_aoti_c_shim import ( + gen_aoti_c_shim, + gen_static_dispatch_backend_call_signature, + get_fallback_op_name, + get_header_for_aoti, +) +from torchgen.gen_functionalization_type import ( + gen_functionalization_definition, + gen_functionalization_registration, + gen_functionalization_view_inverse_declaration, + GenCompositeViewCopyKernel, +) +from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing +from torchgen.model import ( + Argument, + BackendIndex, + BackendMetadata, + BaseOperatorName, + DEFAULT_KERNEL_NAMESPACE, + DispatchKey, + FRAGMENT_NAMESPACES, + FunctionSchema, + is_cuda_dispatch_key, + is_generic_dispatch_key, + is_ufunc_dispatch_key, + is_xpu_dispatch_key, + Location, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + OperatorName, + OptionalType, + SchemaKind, + SelfArgument, + STRUCTURED_DISPATCH_KEYS, + TensorOptionsArguments, + Type, + Variant, + ViewSchemaKind, +) +from torchgen.native_function_generation import ( + add_generated_native_functions, + gen_composite_functional_kernel, + gen_composite_out_kernel, + 
pre_group_native_functions, +) +from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import ( + assert_never, + concatMap, + context, + FileManager, + make_file_manager, + mapMaybe, + NamespaceHelper, + Target, +) +from torchgen.yaml_utils import YamlDumper, YamlLoader + + +T = TypeVar("T") + +# Welcome to the ATen code generator v2! The ATen code generator is +# responsible for parsing native_functions.yaml and then generating +# various generated files (e.g., TypeDefault.cpp) based on the operators +# defined in this file. This means that the code generator knows how to +# parse function schema, and then translate this into various C++ types +# and boilerplate code. +# +# Some things to know about this file when you modify it: +# +# - This file has STRICT mypy typechecking. Typecheck it with +# `mypy --config mypy-strict.ini` in the root source directory +# +# - Most of the heavy lifting lives in external modules: +# - 'model' has the data model for native_functions.yaml. The classes +# in that file represent what you see when you look at +# a native_functions.yaml +# - 'api' has conversions for how to translate JIT schema into +# the various C++ APIs that the codegen interacts with. There +# are in fact THREE different C++ APIs: the public C++ API, +# the dispatcher API, and the legacy dispatcher API. See each +# of these respective files for more information + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# HELPER FUNCTIONS +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# A custom loader for YAML to let us also keep track of line numbers +# of each entry in the YAML file +class LineLoader(YamlLoader): + def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] + mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call] + # Add 1 so line numbering starts at 1 + mapping["__line__"] = node.start_mark.line + 1 + return mapping + + +# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices. +ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"]) + + +_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {} +_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {} + + +def parse_native_yaml_struct( + es: object, + valid_tags: set[str], + ignore_keys: set[DispatchKey] | None = None, + path: str = "", + skip_native_fns_gen: bool = False, +) -> ParsedYaml: + assert isinstance(es, list) + rs: list[NativeFunction] = [] + bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict) + for e in es: + assert isinstance(e, dict), f"expected to be dict: {e}" + assert isinstance(e.get("__line__"), int), e + loc = Location(path, e["__line__"]) + funcs = e.get("func") + assert funcs is not None, f"missed 'func' in {e}" + with context(lambda: f"in {loc}:\n {funcs}"): + func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys) + rs.append(func) + BackendIndex.grow_index(bs, m) + error_check_native_functions(rs) + # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
+ indices: dict[DispatchKey, BackendIndex] = defaultdict( + lambda: BackendIndex( + dispatch_key=DispatchKey.Undefined, + use_out_as_primary=True, + external=False, + device_guard=False, + # I'm actually not sure about this; undefined could be hit on + # empty TensorList, hypothetically that could have sizes in it + index={}, + ) + ) + if not skip_native_fns_gen: + add_generated_native_functions(rs, bs) + for k, v in bs.items(): + # All structured in-tree operators are implemented in terms of their out operator. + indices[k] = BackendIndex( + dispatch_key=k, + use_out_as_primary=True, + external=False, + # Only cuda-like devices in tree require device guards + device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k), + index=v, + ) + return ParsedYaml(rs, indices) + + +def parse_tags_yaml_struct(es: object, path: str = "") -> set[str]: + assert isinstance(es, list) + rs: set[str] = set() + for e in es: + assert isinstance(e.get("__line__"), int), e + loc = Location(path, e["__line__"]) + tags = e.get("tag") + with context(lambda: f"in {loc}:\n {tags}"): + e_i = e.copy() + name = e_i.pop("tag") + desc = e_i.pop("desc", "") + # ensure that each tag has a non-empty description + assert desc != "" + rs.add(name) + return rs + + +@functools.lru_cache(maxsize=None) +def parse_tags_yaml(path: str) -> set[str]: + global _GLOBAL_PARSE_TAGS_YAML_CACHE + if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE: + with open(path) as f: + es = yaml.load(f, Loader=LineLoader) + _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path) + + return _GLOBAL_PARSE_TAGS_YAML_CACHE[path] + + +def parse_native_yaml( + path: str, + tags_yaml_path: str, + ignore_keys: set[DispatchKey] | None = None, + *, + skip_native_fns_gen: bool = False, + loaded_yaml: object | None = None, +) -> ParsedYaml: + global _GLOBAL_PARSE_NATIVE_YAML_CACHE + if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE: + valid_tags = parse_tags_yaml(tags_yaml_path) + + # if a loaded yaml is provided, use that instead of reading from path + if loaded_yaml is None: + with open(path) as f: + es = yaml.load(f, Loader=LineLoader) + else: + es = loaded_yaml + + _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct( + es, + valid_tags, + ignore_keys, + path=path, + skip_native_fns_gen=skip_native_fns_gen, + ) + + return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] + + +# Some assertions are already performed during parsing, but those are only within a single NativeFunction. +# Assertions here are meant to be performed across NativeFunctions. +def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None: + func_map: dict[OperatorName, NativeFunction] = {} + base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list) + for f in funcs: + func_map[f.func.name] = f + base_func_map[f.func.name.name].append(f) + for f in funcs: + if f.structured_delegate is not None: + delegate_func = func_map.get(f.structured_delegate) + assert delegate_func is not None, ( + f"{f.func.name} is marked as a structured_delegate pointing to " + f"{f.structured_delegate}, but {f.structured_delegate} is missing." + ) + assert delegate_func.structured, ( + f"{f.func.name} is marked as a structured_delegate pointing to " + f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. 
" + f"Consider adding 'structured=True' to the delegated operator" + ) + # See Note [resize_ in Functionalization] + # resize_() is technically an inplace view op (and therefore needs the tag), + # but it would be overkill to add a true "view" variant of resize. + # Instead, resize_() gets special treatment in functionalization, + # and we have a resize() op that is non-aliasing + functional. + if ( + "inplace_view" in f.tags + and str(f.func.name) != "resize_" + and str(f.func.name) != "resize_as_" + and str(f.func.name.name) != "set_" + ): + base_name = f.func.name.name + assert base_name.inplace, ( + f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming " + "convention for inplace ops - the codegen expects the base name to have a trailing underscore. " + ) + out_of_place_base_name = BaseOperatorName( + base_name.base, False, base_name.dunder_method + ) + assert len(base_func_map[out_of_place_base_name]) > 0, ( + f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding " + f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. " + ) + + +def cpp_string(s: str) -> str: + """Convert a python string into a c++ string literal""" + s = s.replace("\\", "\\\\") + s = s.replace('"', '\\"') + s = s.replace("\a", "\\a") + s = s.replace("\b", "\\b") + s = s.replace("\f", "\\f") + s = s.replace("\n", "\\n") + s = s.replace("\v", "\\v") + s = s.replace("\t", "\\t") + return f'"{s}"' + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# C++ CODE GENERATION +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# Most functions in this section are curried: they consist of a function +# that takes some parameters (e.g., what is to be generated) which itself +# returns a function that actually maps NativeFunction to the code +# to be generated. This pattern makes it convenient to use map, concatMap +# and similar functional combinators. + + +def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]: + if len(backends) == 0: + return [] + else: + return [backend.dispatch_key for backend in backends] + [ + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, + DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ] + + +def get_static_dispatch_backend( + f: NativeFunction, backend_index: BackendIndex +) -> DispatchKey | None: + if f.structured_delegate is not None or backend_index.has_kernel(f): + # TODO: for ops with structured_delegate it should check the dispatch table of + # the out variant instead. For now, these structured ops all have CPU/CUDA kernels + # so we always dispatch to the `backend`, but this could be wrong when we + # migrate math/default_backend ops to use structured delegate. 
+ return backend_index.dispatch_key + elif f.has_composite_explicit_autograd_kernel: + return DispatchKey.CompositeExplicitAutograd + elif f.has_composite_explicit_autograd_non_functional_kernel: + return DispatchKey.CompositeExplicitAutogradNonFunctional + elif f.has_composite_implicit_autograd_kernel: + return DispatchKey.CompositeImplicitAutograd + elif f.has_composite_implicit_autograd_nested_tensor_kernel: + return DispatchKey.CompositeImplicitAutogradNestedTensor + return None + + +def static_dispatch_ops_header( + f: NativeFunction, backend_index: list[BackendIndex] +) -> str | None: + if backend_index is None or f.manual_kernel_registration: + return None + + output = [] + for index in backend_index: + dispatch_key = get_static_dispatch_backend(f, index) + if dispatch_key is not None: + output.append( + f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>" + ) + return "\n".join(output) + + +def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]: + return [ + f"#include <ATen/{dispatch_key}Functions.h>" + for dispatch_key in static_dispatch_keys(backends) + ] + + +# Translates arguments of `sig` to CppSignature bindings. +# Note that we have a special case for `memory_format` argument and this case is not covered by +# tools.codegen.api.translate() yet as its application is limited to static dispatch. +def translate_args( + sig: CppSignature | DispatcherSignature, + cpp_sig: CppSignature, +) -> str: + # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings + def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]: + output_bindings: list[Binding] = [] + for binding in input_bindings: + if binding.name == "memory_format": + spl_mem_format_binding = Binding( + nctype=NamedCType( + SpecialArgName.possibly_redundant_memory_format, + binding.nctype.type, + ), + name=binding.name, + default=binding.default, + argument=binding.argument, + ) + output_bindings.append(spl_mem_format_binding) + else: + output_bindings.append(binding) + return output_bindings + + src_bindings = list(sig.arguments()) + goal_bindings = list(cpp_sig.arguments()) + # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType, + # get memory_format bindings of dispatcher signature to have the same NCType as well + for arg in goal_bindings: + if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format: + src_bindings = add_spl_memory_format_binding(src_bindings) + break + exprs = translate(src_bindings, goal_bindings) + return ", ".join(a.expr for a in exprs) + + +def generate_static_dispatch_backend_call( + sig: CppSignature | DispatcherSignature, + f: NativeFunction, + backend_index: BackendIndex, +) -> str: + cpp_sig = gen_static_dispatch_backend_call_signature(sig, f) + name = cpp_sig.name() + exprs = translate_args(sig, cpp_sig) + backend_metadata = backend_index.get_kernel(f) + kernel_ns = ( + backend_metadata.cpp_namespace + if backend_metadata and backend_metadata.cpp_namespace + else DEFAULT_KERNEL_NAMESPACE + ) + ns = kernel_ns.replace("::native", "") + return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});" + + +def generate_static_dispatch_fallback_call( + sig: CppSignature | DispatcherSignature, + f: NativeFunction, + backend_indices: list[BackendIndex], +) -> str: + cpp_sigs = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + if sig.symint and f.func.has_symint(): + cpp_sig = cpp_sigs.symint_signature + else: + cpp_sig = cpp_sigs.signature + assert cpp_sig is not None + name =
cpp_sig.name() + exprs = translate_args(sig, cpp_sig) + ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "") + if f.has_composite_explicit_autograd_kernel: + return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});" + elif f.has_composite_explicit_autograd_non_functional_kernel: + return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});" + elif f.has_composite_implicit_autograd_kernel: + return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});" + elif f.has_composite_implicit_autograd_nested_tensor_kernel: + return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});" + else: + return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\ +{', '.join([str(index.dispatch_key) for index in backend_indices])} ");""" + + +def static_dispatch( + sig: CppSignature | DispatcherSignature, + f: NativeFunction, + backend_indices: list[BackendIndex], +) -> str: + """ + For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one + backend exists, fall back to determining the dispatch key from the inputs at runtime. + Arguments: + sig: A CppSignature or DispatcherSignature for this native function we want to use. + f: NativeFunction to generate static dispatch for. + backend_indices: All available backends. + Return: + C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);" + """ + if len(backend_indices) == 0 or f.manual_kernel_registration: + return "" + + keys = [ + b + for b in backend_indices + if b.has_kernel(f) + or ( + f.structured_delegate is not None + and b.dispatch_key in STRUCTURED_DISPATCH_KEYS + ) + ] + if len(keys) == 1: + return generate_static_dispatch_backend_call(sig, f, keys[0]) + elif len(keys) == 0: + return generate_static_dispatch_fallback_call(sig, f, backend_indices) + + native_tensor_args = [ + a.name + for a in sig.arguments() + if isinstance(a.argument, SelfArgument) + or isinstance(a.argument, Argument) + and a.argument.type.is_tensor_like() + ] + tensor_args = ", ".join(native_tensor_args) + tensor_opts = f.func.arguments.tensor_options + + stmts = [] + subexprs: list[str] = [] + if tensor_opts is not None: + subexprs.append( + "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))" + ) + if tensor_args != "": + subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})") + stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""") + stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);") + + dispatch_code = [] + for index in keys: + dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""") + dispatch_code.append( + f"""\t{generate_static_dispatch_backend_call(sig, f, index)};""" + ) + + fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices) + connector = "\n\t\t" + + return f""" + {connector.join(stmts)} + switch (_dk) {{ + {connector.join(dispatch_code)} + default: + {fallback} + }} + """ + + +# Generates RegisterSchema.cpp.
Depending on the selector, either +# all schemas are registered, or only some are (in the case of +# selective build) +@dataclass(frozen=True) +class RegisterSchema: + selector: SelectiveBuilder + known_tags: dict[str, int] = field(default_factory=dict) + + @method_with_native_function + def __call__(self, f: NativeFunction) -> str | None: + if not self.selector.is_native_function_selected(f): + return None + tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}" + if tags == "{}": + return f"m.def({cpp_string(str(f.func))}, {{}});\n" + maybe_tags = "" + if tags not in self.known_tags: + idx = len(self.known_tags) + self.known_tags[tags] = idx + maybe_tags = f"const std::vector<at::Tag> tags_{idx} = {tags};\n" + return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n" + + +# Generates Operators.h and Operators.cpp. +# These provide macros that, given an operator and overload name, allow users +# to access an "un-overloaded" function version of the operator. This +# is useful for extension writers who (1) want to decltype the operator +# and (2) don't want to worry about method-only operators. +@dataclass(frozen=True) +class ComputeOperators: + target: Literal[Target.DECLARATION, Target.DEFINITION] + static_dispatch_backend_indices: list[BackendIndex] + + @method_with_native_function + def __call__(self, f: NativeFunction) -> str: + sig = DispatcherSignature.from_schema(f.func) + name = f.func.name.unambiguous_name() + + if self.target is Target.DECLARATION: + # Note [The ATen Operators API] + # The ATen Operators API lives in the at::_ops namespace, and contains compile-time + # metadata about each operator + entry points into the Dispatcher. + # The C++ function, method, and redispatch API's are all implemented as wrappers + # into various bits of the structs defined here. + # + # Important characteristics about the Operators API: + # (1) It follows the Dispatcher API. + # This is kind of necessary to avoid overhead. + # For example: if it followed the C++ API, then all of the faithful C++ factory functions + # would need to wrap their arguments into TensorOptions only to unwrap them again. + # (2) Overload names are disambiguated. + # This is helpful for pytorch extenders who would like to decltype() an aten operator, + # that has overloads, e.g. decltype(at::_ops::mul_Tensor::call) + # (3) No argument defaulting is allowed. + # This is more of an implementation detail to avoid #include cycles, + # since TensorBody.h (which defines the Tensor class) needs to include this file. + # (4) manual_cpp_bindings and faithful names are not included in the API. + # This applies to stuff like __dispatch__is_complex(), and add_outf(). + # These aren't "real aten ops", they're just additional functions provided by the C++ API. + # They're implemented as wrappers in Functions.h that call into the actual operators + # defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call(). + # This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
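+ # Editor's sketch (hypothetical, for illustration only; not emitted by the
+ # upstream file as-is): for "aten::mul.Tensor", this DECLARATION branch
+ # produces roughly
+ #
+ #   struct TORCH_API mul_Tensor {
+ #     using schema = at::Tensor (const at::Tensor &, const at::Tensor &);
+ #     using ptr_schema = schema*;
+ #     // name/overload_name/schema_str string members elided
+ #     static at::Tensor call(const at::Tensor & self, const at::Tensor & other);
+ #     static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other);
+ #   };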
+ return f""" +struct TORCH_API {name} {{ + using schema = {sig.type()}; + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))}) + static {sig.defn(name="call", is_redispatching_fn=False)}; + static {sig.defn(name="redispatch", is_redispatching_fn=True)}; +}};""" + + elif self.target is Target.DEFINITION: + defns = f""" +STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}") +STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}") +STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))}) + +// aten::{f.func} +static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{ + return c10::Dispatcher::singleton() + .findSchemaOrThrow({name}::name, {name}::overload_name) + .typed<{name}::schema>(); +}} +""" + for is_redispatching_fn in [False, True]: + if is_redispatching_fn: + dispatcher_exprs_str = ", ".join( + ["dispatchKeySet"] + [a.name for a in sig.arguments()] + ) + method_base = "redispatch" + else: + dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()]) + method_base = "call" + + dispatcher_call = method_base + method_name = f"{name}::{method_base}" + + fn_body = f""" + static auto op = create_{name}_typed_handle(); + return op.{dispatcher_call}({dispatcher_exprs_str});""" + + if ( + not is_redispatching_fn + and len(self.static_dispatch_backend_indices) > 0 + ): + # call() should go through static dispatch + fn_body = static_dispatch( + sig, f, backend_indices=self.static_dispatch_backend_indices + ) + defns += f""" +// aten::{f.func} +{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{ + {fn_body} +}} +""" + return defns + else: + assert_never(self.target) + + +# Generates Functions.h, which provides the functional public C++ API, +# and the scaffolding to call into the dispatcher from these functions. +@dataclass(frozen=True) +class ComputeFunction: + @method_with_native_function + def __call__(self, f: NativeFunction) -> str | None: + sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=f.manual_cpp_binding + ) + has_symint = f.func.has_symint() + + result = "" + for sig in sig_group.signatures(): + # See Note [The ATen Operators API] + target_sig = DispatcherSignature.from_schema(f.func) + exprs = translate(sig.arguments(), target_sig.arguments()) + exprs_str = ", ".join([e.expr for e in exprs]) + + if sig.symint: + intlike_t = "c10::SymInt" + else: + intlike_t = "int64_t" + + if Variant.function in f.variants: + result += f""" +// aten::{f.func} +inline {sig.decl()} {{ + return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); +}}""" + + # The template function can be used from template situations + # where you want to switch between the symint or not version + # depending on a template argument + # + # NB: we ALWAYS generate this even for methods. But we put it in + # this header so it can take advantage of per-op headers + if has_symint: + result += f""" +namespace symint {{ + template ::value>> + {sig.decl(suppress_symint_suffix=True)} {{ + return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); + }} +}} +""" + return result + + +# Generates TensorBody.h. 
This file provides the object-oriented (method-based) +# public C++ API, and the scaffolding to call into the dispatcher from these functions. +@dataclass(frozen=True) +class ComputeTensorMethod: + target: Literal[Target.DECLARATION, Target.DEFINITION] + static_dispatch_backend_indices: list[BackendIndex] + + @method_with_native_function + def __call__(self, f: NativeFunction) -> str | None: + if Variant.method not in f.variants: + return None + + assert not f.func.is_out_fn() + assert f.func.arguments.self_arg is not None + + sig_group = CppSignatureGroup.from_native_function( + f, method=True, fallback_binding=f.manual_cpp_binding + ) + + if self.target is Target.DECLARATION: + result = "" + for sig in sig_group.signatures(): + result += f"{sig.decl()} const;\n" + return result + + if self.target is not Target.DEFINITION: + assert_never(self.target) + + result = "" + + for sig in sig_group.signatures(): + target_sig = DispatcherSignature.from_schema(f.func) + exprs = translate(sig.arguments(), target_sig.arguments(), method=True) + exprs_str = ", ".join([e.expr for e in exprs]) + + result += f""" +// aten::{f.func} +inline {sig.defn(prefix="Tensor::")} const {{ + return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); +}} +""" + + return result + + +# Generates RedispatchFunctions.h. +# This is similar to the C++ API defined in Functions.h, but provides access +# to the dispatcher's redispatch API. +@dataclass(frozen=True) +class ComputeRedispatchFunction: + @method_with_native_function + def __call__(self, f: NativeFunction) -> str | None: + # We unconditionally generate function variants of the redispatch API. + # This is mainly because we can namespace functions separately, but not methods, + sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=f.manual_cpp_binding + ) + + result = "" + for sig in sig_group.signatures(): + target_sig = DispatcherSignature.from_schema(f.func) + exprs = translate(sig.arguments(), target_sig.arguments()) + exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs]) + + result += f""" +// aten::{f.func} +inline {sig.decl(is_redispatching_fn=True)} {{ + return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str}); +}} +""" + + return result + + +# Generates ATenOpList.cpp, a runtime accessible list of all aten +# operators. +# TODO: This was historically used to help some JIT interop code +# figure out whether or not to treat aten namespace'd operators +# one way or another, we should reevaluate if this is actually needed. +@with_native_function +def compute_aten_op(f: NativeFunction) -> str: + return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},' + + +# Generates MetaFunctions.h +def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None: + if not g.structured: + return None + with native_function_manager(g.out): + name = meta.name(g) + args = structured.meta_arguments(g) + args_str = ", ".join(a.decl() for a in args) + parent_class = g.out.structured_inherits + if parent_class is None: + parent_class = "at::impl::MetaBase" + meta_return = "void" + precomputed = g.out.precomputed if g.structured else None + + if precomputed: + # Generate the template declaration with one bool parameter for each + # precomputed element. Each parameter is true if the corresponding (in + # terms of position) precomputed element has been set. 
+ precomputed_values = [*precomputed.replace.values(), precomputed.add] + precomputed_elements = [ + elem for replace_list in precomputed_values for elem in replace_list + ] + precomputed_template_parameters = [ + elem.name.upper() for elem in precomputed_elements + ] + precomputed_template_params_str = ", ".join( + f"bool {param} = false" for param in precomputed_template_parameters + ) + precompute_template_decl = f"template <{precomputed_template_params_str}>" + + # Generate a string containing declarations of all precomputed elements. + precomputed_elements_with_cpp_types = [ + structured.argument_type(elem, binds=elem.name) + for elem in precomputed_elements + ] + + precomputed_elements_decl = ";\n".join( + f"{elem.cpp_type(strip_ref=True)} {elem.name}" + for elem in precomputed_elements_with_cpp_types + ) + + # Generate "setter" methods for each precomputed element. Each method will return + # a new instance of precompute_out with the template parameter that corresponds to + # the member set by the method to true (to indicate that it has been set). + setter_methods = [] + for i, elem in enumerate(precomputed_elements): + # Generate the signature. The return type will be the same + # as the type of `this` but with the template parameter + # corresponding to the element set by this method set to true. + # The assert generated below will ensure that this template + # parameter is false on the type of `this`. + return_ty_templates = ", ".join( + precomputed_template_parameters[:i] + + ["true"] + + precomputed_template_parameters[i + 1 :] + ) + return_ty = f"precompute_out<{return_ty_templates}>" + elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type( + strip_ref=True + ) + signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)" + + # Generate an assert which checks that the + # template parameter corresponding to the precomputed + # element that is set by this method is false on the + # class corresponding to the object that `this` points to. + # This ensures that each element can be set only once. + assert_msg = f'"{elem.name} already set"' + assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});" + + # Generate the new object construction block. All state + # except the element that this method sets is copied from the + # object that `this` points to. The value for the element that + # the method sets is taken from a method parameter. + construction_stmts = [] + construction_stmts.append(f"{return_ty} ret;") + + for j, elem in enumerate(precomputed_elements): + if i == j: + construction_stmts.append(f"ret.{elem.name} = value;") + else: + construction_stmts.append( + f"ret.{elem.name} = this->{elem.name};" + ) + + construction_stmts.append("return ret;") + construction_block = "\n".join(construction_stmts) + + setter_methods.append( + f""" + {signature} {{ + {assert_stmt} + {construction_block} + }} + """ + ) + setter_methods_decl = "\n".join(setter_methods) + + # Meta should return an instance of the struct containing the precomputed elements. + meta_return_template_params = ", ".join( + ["true"] * len(precomputed_template_parameters) + ) + # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return + # type (which has a variable number of template parameters). 
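+ # Editor's sketch (hypothetical, for illustration only): for a structured op
+ # whose meta() precomputes, say, kH and kW, the pieces assembled below yield
+ # a declaration roughly like
+ #
+ #   template <bool KH = false, bool KW = false>
+ #   struct TORCH_API precompute_out {
+ #     precompute_out<true, KW> set_kH(int64_t value);
+ #     precompute_out<KH, true> set_kW(int64_t value);
+ #     int64_t kH;
+ #     int64_t kW;
+ #   };
+ #   using meta_return_ty = precompute_out<true, true>;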
+ meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;" + meta_return = "meta_return_ty" + precomputed_decl = f""" + {precompute_template_decl} + struct TORCH_API precompute_out {{ + {setter_methods_decl} + {precomputed_elements_decl}; + }};""" + else: + meta_return_typedef = "" + precomputed_decl = "" + + return f"""\ +struct TORCH_API structured_{name} : public {parent_class} {{ + {precomputed_decl} + {meta_return_typedef} + {meta_return} meta({args_str}); +}}; +""" + + +def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool: + name = str(f.func.name.name) + if name.endswith("_like") or name.startswith("new_"): + return False + if f.func.arguments.tensor_options is None: + return False + return selector.is_native_function_selected(f) + + +# Generates RegisterBackendSelect.cpp, a series of kernels which provide +# specialized computation of dispatch key for operator signatures which cannot +# be easily done automatically using templating. +@dataclass(frozen=True) +class ComputeBackendSelect: + target: Literal[Target.DEFINITION, Target.REGISTRATION] + + # Selector object to determine which operators to generate + # registration code for. + selector: SelectiveBuilder + + @method_with_native_function + def __call__(self, f: NativeFunction) -> str | None: + if not needs_backend_select(f, self.selector): + return None + + name = native.name(f.func) + # BackendSelect can go to Meta, so it must preserve symints + native_sig = NativeSignature(f.func, symint=True) + + native_tensor_args = [ + a + for a in native_sig.arguments() + if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like() + ] + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + + sig: NativeSignature | DispatcherSignature + sig = dispatcher_sig + dispatcher_exprs = dispatcher_sig.exprs() + dispatch_key = "c10::computeDispatchKey(dtype, layout, device)" + + if self.target is Target.DEFINITION: + # I don't think there's actually a good reason to generate + # these two cases differently + # The first case could probably be improved though- it calls computeDispatchKeySet(), + # which looks at TLS dispatch keys- there should not be any by the time we reach backend select. 
+ if native_tensor_args: + assert f.func.arguments.has_tensor_arg() + tensor_args = ", ".join(a.name for a in native_tensor_args) + compute_dk = f"""\ +DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args}); +DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect); +DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);""" + else: + assert not f.func.arguments.has_tensor_arg() + compute_dk = ( + f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});" + ) + return f"""\ +// aten::{f.func} +C10_ALWAYS_INLINE +{sig.defn(name)} {{ + {compute_dk} + return at::_ops::{f.func.name.unambiguous_name()}::redispatch( + _dk, {', '.join(a.expr for a in dispatcher_exprs)}); +}} +""" + elif self.target is Target.REGISTRATION: + return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));""" + else: + assert_never(self.target) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# YAML CODE GENERATION +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def format_yaml(data: object) -> str: + # Ignore alias in Dumper + YamlDumper.ignore_aliases = lambda self, data: True # type: ignore[assignment] + + # Support serializing OrderedDict + def dict_representer(dumper: Any, data: Any) -> Any: + return dumper.represent_dict(data.items()) + + YamlDumper.add_representer(OrderedDict, dict_representer) # type: ignore[no-untyped-call] + # Some yaml parsers (e.g. Haskell's) don't understand line breaks. + # width=1e9 turns off optional line breaks and improves + # the portability of the outputted yaml. + return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9) # type: ignore[no-any-return, call-overload] + + +# For some reason, some defaults we write to YAML are written as native +# YAML objects, rather than doing them uniformly as strings. This +# function detects those cases and converts them into native Python +# objects. +def pythonify_default(s: str) -> object: + if s == "true": + return True + elif s == "false": + return False + + try: + return int(s) + except ValueError: + try: + return float(s) + except ValueError: + return s + + +# What is a dynamic type? Over time, the semantic meaning of +# dynamic type has degraded to meaninglessness (in the old days, +# it captured dtype-ness of types, but that has gone away with +# the removal of TH). These days, it's mostly the same thing as +# the C++ API argument type, except that Tensor and Tensor? +# arguments simply present as Tensor. 
+# +# TODO: Get rid of dynamic_type, after getting tools/autograd +# to use the new codegen framework +def dynamic_type(t: Type) -> str: + if isinstance(t, OptionalType): + return dynamic_type(t.elem) + # Note we don't use t.is_tensor_like() here because it would + # also include Tensor[] + if str(t) == "Tensor": + return "at::Tensor" + # This is a legacy concept, so never report SymInt + return cpp.argumenttype_type( + t, mutable=False, binds="__placeholder__", symint=False + ).cpp_type() + + +def compute_method_of_yaml(variants: set[Variant]) -> list[str]: + # This is written out explicitly to ensure that Tensor and + # namespace are put into the list in the right order + method_of = ["Type"] + if Variant.method in variants: + method_of.append("Tensor") + if Variant.function in variants: + method_of.append("namespace") + return method_of + + +def compute_returns_yaml( + f: NativeFunction, +) -> tuple[list[dict[str, str]], dict[str, str]]: + # Note [name and field_name] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~ + # To understand name_to_field_name, we must first talk about this + # schema: + # + # lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR) + # + # There is something very odd about this schema: it is an out + # variant of the function (that is to say, it will convert into + # at::lstsq_out() in the C++ API), but the names of the output + # return arguments don't match the keyword argument names of + # the inputs. It TURNS OUT that in this situation, the historical + # Declarations.yaml we want to output is this (abbreviated to + # only show relevant fields): + # + # arguments: + # ... + # - field_name: solution + # name: X + # - field_name: QR + # name: qr + # ... + # + # returns: + # - field_name: solution + # name: X + # - field_name: QR + # name: qr + # + # The name of the return fields is stored in 'field_name', and the + # name of the arguments is stored in 'name'. So when we process + # arguments, we need a way to get at the corresponding return. At + # the moment, this is most conveniently done by constructing a + # mapping from name (the argument concept) to field_name (the + # return concept) while processing return arguments, since we don't + # directly maintain this correspondence in the modeling of function + # schema itself. 
+ # +# See also https://github.com/pytorch/pytorch/issues/43114 + name_to_field_name: dict[str, str] = {} + + # Compute the returns field of the YAML entry + names = cpp.return_names(f) + returns = [] + for i, (r, name) in enumerate(zip(f.func.returns, names)): + ret = { + "dynamic_type": dynamic_type(r.type), + "name": name, + # legacy, report ints + "type": cpp.return_type(r, symint=False).cpp_type(), + } + + if r.name: + # See Note [name and field_name] + ret["field_name"] = r.name + if f.func.is_out_fn(): + name_to_field_name[f.func.arguments.out[i].name] = r.name + + returns.append(ret) + + return returns, name_to_field_name + + +# arguments in yaml roughly corresponds to the public C++ API +def compute_cpp_argument_yaml( + cpp_a: Binding, + *, + schema_order: bool, + kwarg_only_set: set[str], + out_arg_set: set[str], + name_to_field_name: dict[str, str], +) -> object: + if isinstance(cpp_a.argument, TensorOptionsArguments): + arg: dict[str, object] = { + "annotation": None, + "dynamic_type": "at::TensorOptions", + "is_nullable": False, + "name": cpp_a.name, + "type": cpp_a.type, + "kwarg_only": True, + } + if cpp_a.default is not None: + arg["default"] = cpp_a.default + return arg + elif isinstance(cpp_a.argument, SelfArgument): + raise AssertionError + elif isinstance(cpp_a.argument, Argument): + return compute_argument_yaml( + cpp_a.argument, + schema_order=schema_order, + kwarg_only_set=kwarg_only_set, + out_arg_set=out_arg_set, + name_to_field_name=name_to_field_name, + ) + + +def compute_argument_yaml( + a: Argument, + *, + schema_order: bool, + kwarg_only_set: set[str], + out_arg_set: set[str], + name_to_field_name: dict[str, str], +) -> object: + arg: dict[str, object] = { + "annotation": str(a.annotation) if a.annotation else None, + "dynamic_type": dynamic_type(a.type), + "is_nullable": a.type.is_nullable(), + "name": a.name, + # legacy, report ints + "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(), + } + if a.default is not None: + arg["default"] = pythonify_default( + cpp.default_expr(a.default, a.type, symint=False) + ) + if a.name in kwarg_only_set: + arg["kwarg_only"] = True + if a.name in out_arg_set: + arg["output"] = True + arg["allocate"] = True + # See Note [name and field_name] + if a.name in name_to_field_name: + arg["field_name"] = name_to_field_name[a.name] + # Historically, booleans don't get their size recorded, because it + # is already built into the cpp type (e.g., std::array<bool, 4>) + l = a.type.is_list_like() + if l is not None and l.size is not None and str(l.elem) != "bool": + arg["size"] = l.size + return arg + + +@with_native_function +def compute_declaration_yaml(f: NativeFunction) -> object: + returns, name_to_field_name = compute_returns_yaml(f) + + # These sets are used to conveniently test if an argument is a + # kwarg-only or out argument + kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only} + out_arg_set = {a.name for a in f.func.arguments.out} + + sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + cpp_args = sig_group.signature.arguments() + arguments = [ + compute_cpp_argument_yaml( + cpp_a, + schema_order=False, + kwarg_only_set=kwarg_only_set, + out_arg_set=out_arg_set, + name_to_field_name=name_to_field_name, + ) + for cpp_a in cpp_args + ] + + schema_order_jit_arguments = list(f.func.schema_order_arguments()) + + schema_order_arguments = [ + compute_argument_yaml( + a, + schema_order=True, + kwarg_only_set=kwarg_only_set, + out_arg_set=out_arg_set,
+ name_to_field_name=name_to_field_name, + ) + for a in schema_order_jit_arguments + ] + + cpp_schema_order_types = [ + # NB: method here doesn't matter + r.type + for a in schema_order_jit_arguments + for r in cpp.argument( + a, + method=False, + cpp_no_default_args=set(), + faithful=False, + symint=False, + has_tensor_options=False, + ) + ] + + # legacy, report ints + cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type() + schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})" + + is_factory_method = ( + any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args) + and Variant.method not in f.variants + ) + + return OrderedDict( + [ + ("name", cpp.name(f.func)), + ("operator_name", str(f.func.name.name)), + ("overload_name", str(f.func.name.overload_name)), + ("manual_kernel_registration", f.manual_kernel_registration), + ( + "category_override", + f.category_override if f.category_override is not None else "", + ), + ("schema_string", f"aten::{f.func}"), + ("arguments", arguments), + ("schema_order_cpp_signature", schema_order_cpp_signature), + ("schema_order_arguments", schema_order_arguments), + ("method_of", compute_method_of_yaml(f.variants)), + ("mode", "native"), + ("python_module", "" if f.python_module is None else f.python_module), + ("returns", returns), + ("inplace", f.func.name.name.inplace), + ("is_factory_method", is_factory_method), + ("abstract", f.is_abstract), + ("device_guard", f.device_guard), + ("with_gil", False), + ("deprecated", False), + ("has_math_kernel", f.has_composite_implicit_autograd_kernel), + ] + ) + + +# See Note [Auto generated composite kernels] +def has_autogenerated_composite_kernel(f: NativeFunction) -> bool: + return (f.structured or f.structured_delegate is not None) and ( + f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace + ) + + +@with_native_function_and_indices +def compute_registration_declarations( + f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex] +) -> str: + name = dispatcher.name(f.func) + returns_type = dispatcher.returns_type( + f.func.returns + ).cpp_type_registration_declarations() + args = dispatcher.arguments(f.func) + args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args) + comment_data: dict[str, str] = { + "schema": f"aten::{f.func}", + # TODO: What exactly is the semantics of the 'dispatch' field? + "dispatch": str( + {k for k, v in backend_indices.items() if v.has_kernel(f)} + != {DispatchKey.CompositeImplicitAutograd} + and {k for k, v in backend_indices.items() if v.has_kernel(f)} + != { + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, + } + ), + "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)), + } + return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} +""" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# RUN IT ALL +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def get_custom_build_selector( + provided_op_registration_allowlist: list[str] | None, + op_selection_yaml_path: str | None, +) -> SelectiveBuilder: + assert not ( + provided_op_registration_allowlist is not None + and op_selection_yaml_path is not None + ), ( + "Both provided_op_registration_allowlist and " + + "op_selection_yaml_path can NOT be provided at the " + + "same time." 
+ ) + + op_registration_allowlist: set[str] | None = None + if provided_op_registration_allowlist is not None: + op_registration_allowlist = set(provided_op_registration_allowlist) + + if op_registration_allowlist is not None: + selector = SelectiveBuilder.from_legacy_op_registration_allow_list( + op_registration_allowlist, + True, + False, + ) + elif op_selection_yaml_path is not None: + selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path) + else: + selector = SelectiveBuilder.get_nop_selector() + + return selector + + +def get_grouped_by_view_native_functions( + native_functions: Sequence[NativeFunction], +) -> Sequence[NativeFunction | NativeFunctionsViewGroup]: + def maybe_create_view_group( + d: dict[ViewSchemaKind | SchemaKind, NativeFunction] + ) -> list[NativeFunction | NativeFunctionsViewGroup]: + funcs: list[NativeFunction | NativeFunctionsViewGroup] = [] + if ViewSchemaKind.aliasing in d: + view = d.pop(ViewSchemaKind.aliasing) + view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None) + view_copy = d.pop(SchemaKind.functional, None) + + funcs.append( + NativeFunctionsViewGroup( + view=view, + view_copy=view_copy, + view_inplace=view_inplace, + ) + ) + # Take the remaining functions that weren't part of the view group + # and emit them separately + funcs.extend(d.values()) + return funcs + + grouped_by_views: dict[ + FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction] + ] = defaultdict(dict) + for f in native_functions: + schema = f.func.view_signature() + view_kind: ViewSchemaKind = f.view_schema_kind + # We need to group up ops relevant to the same "view", consisting of: + # view op (ViewSchemaKind.aliasing) + # view_inplace op (ViewSchemaKind.aliasing_inplace) + # view_copy op (SchemaKind.functional) + if view_kind == ViewSchemaKind.non_aliasing: + kind = f.func.kind() + assert kind not in grouped_by_views[schema] + grouped_by_views[schema][kind] = f + else: + assert ( + view_kind not in grouped_by_views[schema] + ), f"{view_kind} already in {grouped_by_views[schema].keys()}" + grouped_by_views[schema][view_kind] = f + + return list(concatMap(maybe_create_view_group, grouped_by_views.values())) + + +def get_grouped_native_functions( + native_functions: Sequence[NativeFunction], +) -> Sequence[NativeFunction | NativeFunctionsGroup]: + def flatten_pre_group( + d: dict[SchemaKind, NativeFunction] + ) -> Sequence[NativeFunction | NativeFunctionsGroup]: + r = NativeFunctionsGroup.from_dict(d) + if r is None: + # Invariant: any NativeFunctions that are code-generated + # should have been grouped into NativeFunctionsGroup objects + assert not any("generated" in f.tags for f in d.values()) + return list(d.values()) + else: + return [r] + + # TODO: how come ValuesView isn't a Sequence lol + pre_grouped_native_functions = pre_group_native_functions(native_functions) + return list( + concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())) + ) + + +def get_ns_grouped_kernels( + *, + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + backend_indices: dict[DispatchKey, BackendIndex], + native_function_decl_gen: Callable[ + [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str] + ] = dest.compute_native_function_declaration, +) -> dict[str, list[str]]: + ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) + for f in grouped_native_functions: + native_function_namespaces = set() + dispatch_keys = set() + for dispatch_key, backend_idx in backend_indices.items(): + backend_metadata = 
backend_idx.get_kernel(f) + if backend_metadata: + namespace = backend_metadata.cpp_namespace + dispatch_keys.add(dispatch_key) + native_function_namespaces.add(namespace) + else: + namespace = DEFAULT_KERNEL_NAMESPACE + assert ( + len(native_function_namespaces) <= 1 + ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}" + ns_grouped_kernels[namespace].extend( + native_function_decl_gen(f, backend_idx) + ) + return ns_grouped_kernels + + +def get_native_function_declarations_from_ns_grouped_kernels( + *, + ns_grouped_kernels: dict[str, list[str]], +) -> list[str]: + declarations: list[str] = [] + newline = "\n" + for namespace, kernels in ns_grouped_kernels.items(): + ns_helper = NamespaceHelper( + namespace_str=namespace, + entity_name="", + max_level=4, + ) + # Convert to a set first to remove duplicate kernel names. Backends are + # allowed to repeat kernel names; only generate the declaration once! + ordered_kernels = list(OrderedDict.fromkeys(kernels)) + declarations.extend( + f""" +{ns_helper.prologue} +{newline.join(ordered_kernels)} +{ns_helper.epilogue} + """.split( + newline + ) + ) + return declarations + + +# Return native function declarations grouped by their namespaces. +def get_native_function_declarations( + *, + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + backend_indices: dict[DispatchKey, BackendIndex], + native_function_decl_gen: Callable[ + [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str] + ] = dest.compute_native_function_declaration, +) -> list[str]: + """ + Generate kernel declarations, in `NativeFunction(s).h`. + :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`. + :param backend_indices: kernel collections grouped by dispatch key. + :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`. + :return: a list of string, from the string with all declarations, grouped by namespaces, split by newline. + """ + + ns_grouped_kernels = get_ns_grouped_kernels( + grouped_native_functions=grouped_native_functions, + backend_indices=backend_indices, + native_function_decl_gen=native_function_decl_gen, + ) + return get_native_function_declarations_from_ns_grouped_kernels( + ns_grouped_kernels=ns_grouped_kernels + ) + + +def get_kernel_namespace( + *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex +) -> str: + backend_metadata = backend_idx.get_kernel(f) + assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, ( + f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} " + f"with dispatch key {backend_idx.dispatch_key}" + f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'." + ) + return ( + backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE + ) + + +# Return native function definitions grouped by dispatch key and custom namespace. +# Used in RegisterDispatchKey.cpp and etc. 
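+#
+# As a rough sketch of the output shape: for kernels living in a custom kernel
+# namespace, the function below emits one TORCH_LIBRARY_IMPL block per operator
+# namespace, e.g. (hypothetical operator `custom::my_op` with a CPU kernel; the
+# m.impl line stands in for whatever reg_gen actually produces):
+#
+# TORCH_LIBRARY_IMPL(custom, CPU, m) {
+# m.impl("my_op", TORCH_FN(custom::native::my_op_cpu));
+# }
+#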
+def get_native_function_definitions( + *, + fm: FileManager, + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + dispatch_key: DispatchKey, + backend_idx: BackendIndex, + selector: SelectiveBuilder, + rocm: bool, + symint: bool, + skip_dispatcher_op_registration: bool, + gen_dispatch_helpers: bool, +) -> list[str]: + definitions: list[str] = [] + ns_definitions: dict[str, list[str]] = defaultdict(list) + anonymous_definitions: dict[str, list[str]] = defaultdict(list) + registrations: dict[str, dict[str, list[str]]] = defaultdict(dict) + newline = "\n" + ns_gen = dest.RegisterDispatchKey( + backend_idx, + Target.NAMESPACED_DEFINITION, + selector, + rocm=rocm, + symint=symint, + class_method_name=None, + skip_dispatcher_op_registration=skip_dispatcher_op_registration, + ) + anonymous_gen = dest.RegisterDispatchKey( + backend_idx, + Target.ANONYMOUS_DEFINITION, + selector, + rocm=rocm, + symint=symint, + class_method_name=None, + skip_dispatcher_op_registration=skip_dispatcher_op_registration, + ) + reg_gen = dest.RegisterDispatchKey( + backend_idx, + Target.REGISTRATION, + selector, + rocm=rocm, + symint=symint, + class_method_name=None, + skip_dispatcher_op_registration=skip_dispatcher_op_registration, + ) + for f in grouped_native_functions: + kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace( + "::native", "" + ) + + ns_definitions[kernel_namespace].extend( + ns_gen(f), + ) + anonymous_definitions[kernel_namespace].extend( + anonymous_gen(f), + ) + namespace = ( + f.namespace if isinstance(f, NativeFunction) else f.functional.namespace + ) + if namespace not in registrations[kernel_namespace]: + registrations[kernel_namespace] = defaultdict(list) + registrations[kernel_namespace][namespace].extend( + reg_gen(f), + ) + + for kernel_namespace in ns_definitions: + if len(ns_definitions[kernel_namespace]) == 0: + continue + ns_helper = NamespaceHelper(namespace_str=kernel_namespace) + registration_body = "" + for namespace in registrations[kernel_namespace]: + if not registrations[kernel_namespace][namespace]: + continue + registration_body += f""" +TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ + {newline.join(registrations[kernel_namespace][namespace])} +}};""" + definitions.extend( + fm.substitute_with_template( + "RegisterDispatchDefinitions.ini", + lambda: { + "ns_prologue": ns_helper.prologue, + "ns_epilogue": ns_helper.epilogue, + "dispatch_helpers": dest.gen_registration_helpers(backend_idx) + if gen_dispatch_helpers + else [], + "dispatch_anonymous_definitions": anonymous_definitions[ + kernel_namespace + ], + "static_init_dispatch_registrations": "" + if skip_dispatcher_op_registration + else registration_body, + "deferred_dispatch_registrations": "", + "dispatch_namespace": dispatch_key.lower(), + "dispatch_namespaced_definitions": ns_definitions[kernel_namespace], + }, + ).split(newline) + ) + + return definitions + + +# Return native function declarations grouped by dispatch key and custom namespace. +# Used in CPUFunctions_inl.h and etc. 
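+#
+# The namespace rewrite inside this function maps a kernel namespace onto a
+# per-dispatch-key namespace by string substitution, e.g. for DispatchKey.CPU
+# (illustrative values):
+#
+# "at::native".replace("native", "cpu") -> "at::cpu"
+#
+# so the declarations in CPUFunctions_inl.h land in `at::cpu`.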
+def get_namespaced_declaration( + *, + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + dispatch_key: DispatchKey, + backend_idx: BackendIndex, + selector: SelectiveBuilder, + rocm: bool, + symint: bool, +) -> list[str]: + declarations: list[str] = [] + ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) + newline = "\n" + func = dest.RegisterDispatchKey( + backend_idx, + Target.NAMESPACED_DECLARATION, + selector, + rocm=rocm, + class_method_name=None, + skip_dispatcher_op_registration=False, + symint=symint, + ) + for f in grouped_native_functions: + namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace( + "native", dispatch_key.lower() + ) + + ns_grouped_kernels[namespace].extend( + func(f), + ) + + for namespace, kernels in ns_grouped_kernels.items(): + if len(kernels) == 0: + continue + ns_helper = NamespaceHelper( + namespace_str=namespace, entity_name="", max_level=3 + ) + ordered_kernels = list(OrderedDict.fromkeys(kernels)) + declarations.extend( + f""" +{ns_helper.prologue} +{newline.join(ordered_kernels)} +{ns_helper.epilogue} + """.split( + newline + ) + ) + return declarations + + +# Return native function schema registration code for aten and other namespaces. +def get_native_function_schema_registrations( + *, + native_functions: Sequence[NativeFunction], + schema_selector: SelectiveBuilder, +) -> tuple[list[str], str]: + ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list) + for native_function in native_functions: + ns_native_functions[native_function.namespace].append(native_function) + schema_registrations = "" + aten_schema_registrations = [] + custom_namespace = None + for namespace, funcs in ns_native_functions.items(): + schema_registrations_body = list( + mapMaybe(RegisterSchema(schema_selector), funcs) + ) + # NB: we have to separate aten namespace registration from other namespaces, + # because in the template we hardcoded an operator for ATen already. 
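+ # For a non-aten namespace, the branch below stamps out a
+ # TORCH_LIBRARY (or TORCH_LIBRARY_FRAGMENT) block roughly shaped like this
+ # (hypothetical `custom` namespace and schema; the exact m.def lines come
+ # from RegisterSchema):
+ #
+ # TORCH_LIBRARY_FRAGMENT(custom, m) {
+ # m.def("my_op(Tensor self) -> Tensor");
+ # };
+ #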
+ if namespace == "aten":
+ aten_schema_registrations = schema_registrations_body
+ else:
+ custom_namespace = namespace
+ tab = "\t"
+ # if the namespace is predefined, we should use a library fragment
+ # instead of a new library
+ torch_library_macro = (
+ "TORCH_LIBRARY_FRAGMENT"
+ if namespace in FRAGMENT_NAMESPACES
+ else "TORCH_LIBRARY"
+ )
+ schema_registrations += f"""
+{torch_library_macro}({custom_namespace}, m) {{
+ {tab.join(schema_registrations_body)}
+}};"""
+ return (aten_schema_registrations, schema_registrations)
+
+
+def gen_aggregated_headers(
+ *,
+ native_functions: Sequence[NativeFunction],
+ grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+ structured_native_functions: Sequence[NativeFunctionsGroup],
+ static_dispatch_idx: list[BackendIndex],
+ selector: SelectiveBuilder,
+ backend_indices: dict[DispatchKey, BackendIndex],
+ cpu_fm: FileManager,
+ cuda_fm: FileManager,
+ functions_keys: set[DispatchKey],
+ dispatch_keys: Sequence[DispatchKey],
+ rocm: bool,
+) -> None:
+ # Buck doesn't support dynamic output files, so we aggregate all operator
+ # headers into a single file
+ cpu_fm.write(
+ "NativeMetaFunctions.h",
+ lambda: {
+ "NativeMetaFunctions_includes": [],
+ "NativeMetaFunctions_declarations": list(
+ mapMaybe(compute_meta_function_declaration, structured_native_functions)
+ ),
+ },
+ )
+ method_native_functions = [
+ fn for fn in native_functions if Variant.method in fn.variants
+ ]
+ non_method_native_functions = [
+ fn for fn in native_functions if fn not in method_native_functions
+ ]
+ cpu_fm.write(
+ "MethodOperators.h",
+ lambda: {
+ "MethodOperators_includes": [],
+ "MethodOperators_declarations": list(
+ mapMaybe(
+ ComputeOperators(
+ Target.DECLARATION,
+ static_dispatch_backend_indices=static_dispatch_idx,
+ ),
+ method_native_functions,
+ )
+ ),
+ },
+ )
+ cpu_fm.write(
+ "Operators.h",
+ lambda: {
+ "Operators_includes": ["#include "],
+ "Operators_declarations": list(
+ mapMaybe(
+ ComputeOperators(
+ Target.DECLARATION,
+ static_dispatch_backend_indices=static_dispatch_idx,
+ ),
+ non_method_native_functions,
+ )
+ ),
+ },
+ )
+ cpu_fm.write(
+ "Functions.h",
+ lambda: {
+ "static_dispatch_extra_headers": static_dispatch_extra_headers(
+ static_dispatch_idx
+ ),
+ "Functions_includes": ["#include "],
+ "Functions_declarations": list(
+ mapMaybe(
+ ComputeFunction(),
+ native_functions,
+ )
+ ),
+ },
+ )
+ declarations = get_native_function_declarations(
+ grouped_native_functions=grouped_native_functions,
+ backend_indices=backend_indices,
+ )
+ cpu_fm.write(
+ "NativeFunctions.h",
+ lambda: {
+ "NativeFunctions_includes": ["#include "],
+ "NativeFunctions_declarations": declarations,
+ },
+ )
+
+ for dispatch_key in dispatch_keys:
+ fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
+ if dispatch_key in functions_keys:
+ inl_headers = f"#include "
+
+ fm.write_with_template(
+ f"{dispatch_key}Functions.h",
+ "DispatchKeyFunctions.h",
+ lambda: {
+ "dispatch_key": str(dispatch_key),
+ "inline_headers": inl_headers,
+ },
+ )
+ fm.write_with_template(
+ f"{dispatch_key}Functions_inl.h",
+ "DispatchKeyFunctions_inl.h",
+ lambda: {
+ "DispatchKeyFunctions_inl_includes": [],
+ "dispatch_namespace": dispatch_key.lower(),
+ "dispatch_namespaced_declarations": get_namespaced_declaration(
+ grouped_native_functions=grouped_native_functions,
+ dispatch_key=dispatch_key,
+ backend_idx=backend_indices[dispatch_key],
+ selector=selector,
+ rocm=rocm,
+ symint=True,
+ ),
+ },
+ )
+
+ del fm
+
+
+def 
gen_per_operator_headers( + *, + native_functions: Sequence[NativeFunction], + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + static_dispatch_idx: list[BackendIndex], + selector: SelectiveBuilder, + backend_indices: dict[DispatchKey, BackendIndex], + cpu_fm: FileManager, + cuda_fm: FileManager, + ops_fm: FileManager, + functions_keys: set[DispatchKey], + dispatch_keys: Sequence[DispatchKey], + rocm: bool, +) -> None: + # For CMake builds, split operator declarations into separate headers in + # the ATen/ops folder to split up header dependencies + functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list) + for fn in native_functions: + functions_by_root_name[fn.root_name].append(fn) + + grouped_functions_by_root_name: dict[ + str, list[NativeFunction | NativeFunctionsGroup] + ] = defaultdict(list) + for group in grouped_native_functions: + name = group.root_name + grouped_functions_by_root_name[name].append(group) + + for name, functions in functions_by_root_name.items(): + ops_fm.write_with_template( + f"{name}_ops.h", + "Operator.h", + lambda: { + "declarations": list( + mapMaybe( + ComputeOperators( + Target.DECLARATION, + static_dispatch_backend_indices=static_dispatch_idx, + ), + functions, + ) + ), + }, + ) + + ops_fm.write_with_template( + f"{name}.h", + "Function.h", + lambda: { + "static_dispatch_ops_headers": list( + mapMaybe( + lambda fn: static_dispatch_ops_header( + fn, backend_index=static_dispatch_idx + ), + functions, + ) + ), + "operator_includes": f"#include ", + "function_definitions": list( + mapMaybe( + ComputeFunction(), + functions, + ) + ), + }, + ) + + grouped_functions = grouped_functions_by_root_name.get(name, []) + structured_functions = [ + fn + for fn in grouped_functions + if isinstance(fn, NativeFunctionsGroup) and fn.structured + ] + is_structured = len(structured_functions) > 0 + + if is_structured: + ops_fm.write_with_template( + f"{name}_meta.h", + "NativeMetaFunction.h", + lambda: { + "meta_function_declarations": list( + mapMaybe( + compute_meta_function_declaration, structured_functions + ) + ), + }, + ) + declarations = get_native_function_declarations( + grouped_native_functions=grouped_functions, + backend_indices=backend_indices, + native_function_decl_gen=dest.compute_native_function_declaration, + ) + ops_fm.write_with_template( + f"{name}_native.h", + "NativeFunction.h", + lambda: { + "extra_includes": ( + f"#include " if is_structured else [] + ), + "native_function_declarations": declarations, + }, + ) + + for category, suffix in [ + ("Functions", ""), + ("Operators", "_ops"), + ("NativeMetaFunctions", "_meta"), + ("NativeFunctions", "_native"), + ]: + cpu_fm.write( + f"{category}.h", + lambda: { + f"{category}_includes": [ + f"#include " + for name in sorted(functions_by_root_name.keys()) + ], + f"{category}_declarations": [], + }, + ) + + for dispatch_key in dispatch_keys: + if dispatch_key not in functions_keys: + continue + + dispatch_namespace = dispatch_key.lower() + dispatch_names = [] + + for name, functions in functions_by_root_name.items(): + grouped_functions = grouped_functions_by_root_name.get(name, []) + declarations = list( + concatMap( + dest.RegisterDispatchKey( + backend_indices[dispatch_key], + Target.NAMESPACED_DECLARATION, + selector, + rocm=rocm, + symint=True, + class_method_name=None, + skip_dispatcher_op_registration=False, + ), + grouped_functions, + ) + ) + + if len(declarations) == 0: + continue + + dispatch_names.append(name) + ops_fm.write_with_template( + 
f"{name}_{dispatch_namespace}_dispatch.h", + "DispatchKeyFunction.h", + lambda: { + "dispatch_namespace": dispatch_namespace, + "dispatch_namespaced_declarations": declarations, + }, + ) + + fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + inl_headers = f"#include " + + fm.write_with_template( + f"{dispatch_key}Functions.h", + "DispatchKeyFunctions.h", + lambda: { + "dispatch_key": str(dispatch_key), + "inline_headers": inl_headers, + }, + ) + fm.write_with_template( + f"{dispatch_key}Functions_inl.h", + "DispatchKeyFunctions_inl.h", + lambda: { + "dispatch_namespace": dispatch_namespace, + "DispatchKeyFunctions_inl_includes": [ + f"#include " + for name in sorted(dispatch_names) + ], + "dispatch_namespaced_declarations": [], + }, + ) + del fm + + cpu_fm.write( + "MethodOperators.h", + lambda: { + "MethodOperators_includes": sorted( + f"#include " + for name, functions in functions_by_root_name.items() + if any(Variant.method in fn.variants for fn in functions) + ), + "MethodOperators_declarations": [], + }, + ) + + +def gen_headers( + *, + native_functions: Sequence[NativeFunction], + valid_tags: set[str], + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + structured_native_functions: Sequence[NativeFunctionsGroup], + static_dispatch_idx: list[BackendIndex], + selector: SelectiveBuilder, + backend_indices: dict[DispatchKey, BackendIndex], + core_fm: FileManager, + cpu_fm: FileManager, + cuda_fm: FileManager, + ops_fm: FileManager, + dispatch_keys: Sequence[DispatchKey], + functions_keys: set[DispatchKey], + rocm: bool, + per_operator_headers: bool, +) -> None: + if per_operator_headers: + gen_per_operator_headers( + native_functions=native_functions, + grouped_native_functions=grouped_native_functions, + static_dispatch_idx=static_dispatch_idx, + selector=selector, + backend_indices=backend_indices, + cpu_fm=cpu_fm, + cuda_fm=cuda_fm, + ops_fm=ops_fm, + dispatch_keys=dispatch_keys, + functions_keys=functions_keys, + rocm=rocm, + ) + else: + gen_aggregated_headers( + native_functions=native_functions, + grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, + static_dispatch_idx=static_dispatch_idx, + selector=selector, + backend_indices=backend_indices, + cpu_fm=cpu_fm, + cuda_fm=cuda_fm, + dispatch_keys=dispatch_keys, + functions_keys=functions_keys, + rocm=rocm, + ) + + core_fm.write( + "TensorBody.h", + lambda: { + "tensor_method_declarations": list( + mapMaybe( + ComputeTensorMethod( + target=Target.DECLARATION, + static_dispatch_backend_indices=static_dispatch_idx, + ), + native_functions, + ) + ), + "tensor_method_definitions": list( + mapMaybe( + ComputeTensorMethod( + target=Target.DEFINITION, + static_dispatch_backend_indices=static_dispatch_idx, + ), + native_functions, + ) + ), + }, + ) + + cpu_fm.write( + "RedispatchFunctions.h", + lambda: { + "function_redispatch_definitions": list( + mapMaybe(ComputeRedispatchFunction(), native_functions) + ), + }, + ) + + cpu_fm.write( + "RegistrationDeclarations.h", + lambda: { + "registration_declarations": [ + compute_registration_declarations(f, backend_indices) + for f in native_functions + ], + }, + ) + + cpu_fm.write( + "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions) + ) + + def gen_aten_interned_strings() -> dict[str, str]: + attrs: set[str] = set() # All function argument names + names = set() # All ATen function names + for func in native_functions: + names.add(str(func.func.name.name)) + # Some operators 
don't have a functional variant but we still create a + # symbol without the underscore + names.add(func.func.name.name.base) + + attrs.update(arg.name for arg in func.func.schema_order_arguments()) + + # These are keywords in C++, so aren't valid symbol names + # https://en.cppreference.com/w/cpp/language/operator_alternative + names -= { + "and", + "and_eq", + "bitand", + "bitor", + "compl", + "not", + "not_eq", + "or", + "or_eq", + "xor", + "xor_eq", + } + + return { + "aten_symbols": " \\\n".join( + [f"_(aten, {name})" for name in sorted(names)] + ), + "attr_symbols": " \\\n".join( + [f"_(attr, {name})" for name in sorted(attrs)] + ), + } + + core_fm.write("aten_interned_strings.h", gen_aten_interned_strings) + + def gen_tags_enum() -> dict[str, str]: + return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))} + + core_fm.write("enum_tag.h", gen_tags_enum) + + +def gen_source_files( + *, + native_functions: Sequence[NativeFunction], + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + structured_native_functions: Sequence[NativeFunctionsGroup], + view_groups: Sequence[NativeFunctionsViewGroup], + selector: SelectiveBuilder, + static_dispatch_idx: list[BackendIndex], + backend_indices: dict[DispatchKey, BackendIndex], + aoti_fm: FileManager, + core_fm: FileManager, + cpu_fm: FileManager, + cpu_vec_fm: FileManager, + cuda_fm: FileManager, + dispatch_keys: Sequence[DispatchKey], + functions_keys: set[DispatchKey], + rocm: bool, + force_schema_registration: bool, + per_operator_headers: bool, + skip_dispatcher_op_registration: bool, + update_aoti_c_shim: bool, +) -> None: + extra_cuda_headers = """\ +#include +#include +#include +#include """ + if rocm: + extra_cuda_headers = """\ +#include +#include +#include +#include """ + + for dispatch_key in dispatch_keys: + fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + + if per_operator_headers: + + def operator_headers() -> list[str]: + headers = [] + for g in grouped_native_functions: + is_registered = False + if backend_index.has_kernel(g): + is_registered = True + # The above has_kernel test on a group will only test for + # the existence of out dispatch, because that's how + # structured kernels work. But sometimes functions can be + # grouped but not be structured, and then you need to check + # each individual piece, as they may have manual dispatch + # entries. + elif isinstance(g, NativeFunctionsGroup) and any( + backend_index.has_kernel(fn) for fn in g.functions() + ): + is_registered = True + # TODO: this condition is a bit questionable + # (It has to do with the fact that structured kernels get generated kernels + # to the Meta + CompositeExplicitAutogradNonFunctional keys). 
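+ # (In other words: a structured group gets autogenerated kernels for the
+ # Meta and CompositeExplicitAutogradNonFunctional keys even when it has
+ # no explicit kernel entry, so its ops header still has to be included
+ # for those dispatch keys.)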
+ elif g.structured and dispatch_key in (
+ DispatchKey.Meta,
+ DispatchKey.CompositeExplicitAutogradNonFunctional,
+ ):
+ is_registered = True
+ if not is_registered:
+ continue
+
+ headers.append(f"#include ")
+ if (
+ dispatch_key
+ == DispatchKey.CompositeExplicitAutogradNonFunctional
+ ):
+ headers.append(f"#include ")
+ if dispatch_key in functions_keys:
+ headers.append(
+ f"#include "
+ )
+
+ return sorted(set(headers))
+
+ else:
+
+ def operator_headers() -> list[str]:
+ headers = ["#include "]
+ if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
+ headers.append("#include ")
+ if dispatch_key in functions_keys:
+ headers.append(f"#include ")
+ return headers
+
+ backend_index = backend_indices[dispatch_key]
+ ns_grouped_native_functions = defaultdict(list)
+ for grouped_native_function in grouped_native_functions:
+ namespace = (
+ grouped_native_function.namespace
+ if isinstance(grouped_native_function, NativeFunction)
+ else grouped_native_function.functional.namespace
+ )
+ ns_grouped_native_functions[namespace].append(grouped_native_function)
+
+ dispatch_namespace = str(dispatch_key).lower()
+
+ # CompositeImplicitAutogradNestedTensor does not currently use the generated
+ # helpers; compilation will fail when the `-Werror=unused-function` flag is set
+ gen_dispatch_helpers: bool = (
+ dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
+ )
+
+ dispatch_definitions = get_native_function_definitions(
+ fm=fm,
+ grouped_native_functions=grouped_native_functions,
+ dispatch_key=dispatch_key,
+ backend_idx=backend_index,
+ selector=selector,
+ rocm=rocm,
+ symint=True,
+ skip_dispatcher_op_registration=skip_dispatcher_op_registration,
+ gen_dispatch_helpers=gen_dispatch_helpers,
+ )
+ fm.write_with_template(
+ f"Register{dispatch_key}.cpp",
+ "RegisterDispatchKey.cpp",
+ lambda: {
+ "extra_cuda_headers": extra_cuda_headers
+ if is_cuda_dispatch_key(dispatch_key)
+ else "",
+ "external_backend_headers": "",
+ "dispatch_headers": dest.gen_registration_headers(
+ backend_index, per_operator_headers, rocm
+ ),
+ "ops_headers": operator_headers(),
+ "dispatch_helpers": "",
+ "dispatch_definitions": dispatch_definitions,
+ },
+ )
+
+ for g in structured_native_functions:
+ if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
+ continue
+ name = g.functional.func.name.name
+ if dispatch_key is DispatchKey.CPU:
+ assert fm is cpu_fm
+ fm.write_with_template(
+ f"UfuncCPU_{name}.cpp",
+ "UfuncCPU.cpp",
+ lambda: {
+ "meta_declaration": compute_meta_function_declaration(g),
+ "native_declaration": dest.compute_native_function_declaration(
+ g, backend_indices[dispatch_key]
+ ),
+ "native_definitions": dest.compute_ufunc_cpu(g),
+ },
+ )
+ cpu_vec_fm.write_with_template(
+ f"UfuncCPUKernel_{name}.cpp",
+ "UfuncCPUKernel.cpp",
+ lambda: {
+ "name": name,
+ "native_definitions": dest.compute_ufunc_cpu_kernel(g),
+ },
+ )
+ elif dispatch_key is DispatchKey.CUDA:
+ cuda_headers = "#include "
+ if rocm:
+ cuda_headers = "#include "
+ fm.write_with_template(
+ f"UfuncCUDA_{name}.cu",
+ "UfuncCUDA.cu",
+ lambda: {
+ "name": name,
+ "cuda_headers": cuda_headers,
+ "meta_declaration": compute_meta_function_declaration(g),
+ "native_declaration": dest.compute_native_function_declaration(
+ g, backend_indices[dispatch_key]
+ ),
+ "native_definitions": dest.compute_ufunc_cuda(g),
+ },
+ )
+ else:
+ raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
+
+ structured_func_group_dict = {}
+ for func_group in structured_native_functions:
+ for func in func_group.functions():
+ if func.structured_delegate is not None:
+ structured_func_group_dict[func.structured_delegate] = func_group
+ break
+
+ if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
+ fallbacks = {}
+ for func in native_functions:
+ op_name = get_fallback_op_name(func)
+ if op_name in inductor_fallback_ops:
+ fallbacks[op_name] = func
+ fallback_native_functions = tuple(
+ value for _, value in sorted(fallbacks.items())
+ )
+
+ # header files were checked in for ABI-compatibility checking
+ header_file_name = f"c_shim_{dispatch_key.lower()}.h"
+ new_header = gen_aoti_c_shim(
+ fallback_native_functions,
+ structured_func_group_dict,
+ dispatch_key,
+ backend_indices,
+ header=True,
+ includes="",
+ )
+ if update_aoti_c_shim:
+ aoti_fm.write(
+ header_file_name,
+ lambda: new_header,
+ )
+ else:
+ try:
+ with open(
+ os.path.join(aoti_fm.install_dir, header_file_name)
+ ) as old_file:
+ old_header = old_file.read()
+ assert (
+ old_header == new_header
+ ), """
+
+WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
+indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
+This is allowed only in a limited number of situations:
+
+1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
+If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
+C shim header files.
+
+2. You added a new default argument to an existing fallback op. This is clearly a BC-breaking
+change in AOTInductor land. In this case, you need to keep a manual copy of that existing
+fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
+number of that fallback op in the newly generated C shim files, and update the cpp wrapper
+codegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance.
+ + """ + except FileNotFoundError: + print( + f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found" + ) + + # cpp files are always generated on-the-fly + def headers_for_aoti() -> str: + headers = [] + for func in fallback_native_functions: + header = get_header_for_aoti( + func, structured_func_group_dict, dispatch_key, backend_indices + ) + if header is not None: + headers.append(header) + return "\n".join(sorted(set(headers))) + + extra_headers = ( + extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else "" + ) + + aoti_fm.write( + f"c_shim_{dispatch_key.lower()}.cpp", + lambda: gen_aoti_c_shim( + fallback_native_functions, + structured_func_group_dict, + dispatch_key, + backend_indices, + header=False, + includes=headers_for_aoti() + "\n" + extra_headers, + ), + ) + + del fm + + # BackendSelect is generated specially + def gen_backend_select() -> dict[str, list[str]]: + relevant_fns = [ + fn for fn in native_functions if needs_backend_select(fn, selector) + ] + return { + "ops_headers": [ + f"#include " for fn in relevant_fns + ], + "backend_select_method_definitions": list( + mapMaybe( + ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns + ) + ), + "backend_select_function_registrations": list( + mapMaybe( + ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns + ) + ), + } + + cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select) + + schema_selector = selector + if force_schema_registration: + schema_selector = SelectiveBuilder.get_nop_selector() + + ( + aten_schema_registrations, + schema_registrations, + ) = get_native_function_schema_registrations( + native_functions=native_functions, schema_selector=schema_selector + ) + cpu_fm.write( + "RegisterSchema.cpp", + lambda: { + "aten_schema_registrations": [] + if skip_dispatcher_op_registration + else aten_schema_registrations, + "schema_registrations": [] + if skip_dispatcher_op_registration + else schema_registrations, + }, + ) + + def key_func( + fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + ) -> str: + return fn.root_name + + cpu_fm.write_sharded( + "Operators.cpp", + native_functions, + key_fn=key_func, + env_callable=lambda fn: { + "operator_headers": [f"#include "], + "definitions": [ + ComputeOperators( + Target.DEFINITION, + static_dispatch_backend_indices=static_dispatch_idx, + )(fn) + ], + }, + base_env={ + "static_dispatch_extra_headers": static_dispatch_extra_headers( + static_dispatch_idx + ), + }, + num_shards=5, + sharded_keys={ + "operator_headers", + "definitions", + "static_dispatch_extra_headers", + }, + ) + + cpu_fm.write("Functions.cpp", dict) + + core_fm.write("TensorMethods.cpp", dict) + + core_fm.write( + "ATenOpList.cpp", + lambda: { + "aten_ops": list(mapMaybe(compute_aten_op, native_functions)), + }, + ) + + def functionalization_env_callable( + g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + ) -> dict[str, list[str]]: + def gen_op_headers( + g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + ) -> list[str]: + if isinstance(g, NativeFunctionsViewGroup): + # view ops always get a functionalization kernel + headers = [ + f"#include ", + f"#include ", + ] + if g.view_copy is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + elif isinstance(g, NativeFunctionsGroup): + headers = [ + f"#include ", + f"#include ", + f"#include ", + f"#include ", + ] + if g.inplace is not None: + headers += [ + f"#include ", + f"#include ", + ] + if g.mutable is not None: + headers 
+= [ + f"#include ", + f"#include ", + ] + return headers + else: + return [ + f"#include ", + f"#include ", + ] + + return { + "ops_headers": gen_op_headers(g), + "func_definitions": gen_functionalization_definition( + selector, + g, + ), + "func_registrations": gen_functionalization_registration( + selector, + g, + backend_indices[DispatchKey.CompositeImplicitAutograd], + ), + } + + all_groups: list[ + NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup + ] = list(structured_native_functions) + list( + view_groups # type: ignore[assignment, arg-type, operator] + ) + # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly. + # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because: + # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic) + # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped. + # Although this could go away long-term if we add a dedicated dispatch key for decompositions. + structured_map: dict[OperatorName, NativeFunction] = { + f.func.name: f + for f in concatMap(lambda g: list(g.functions()), structured_native_functions) + } + view_map: dict[OperatorName, NativeFunction] = { + f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups) + } + for f in native_functions: + if f.func.name not in structured_map and f.func.name not in view_map: + all_groups.append(f) + + cpu_fm.write_sharded( + "RegisterFunctionalization.cpp", + all_groups, + key_fn=key_func, + env_callable=functionalization_env_callable, + num_shards=4, + sharded_keys={ + "ops_headers", + "func_definitions", + "func_registrations", + "func_add_back_views_definitions", + "func_add_back_views_registrations", + }, + ) + + cpu_fm.write( + "FunctionalInverses.h", + lambda: { + "view_inverse_declarations": list( + mapMaybe( + lambda g: gen_functionalization_view_inverse_declaration( + selector, g + ), + view_groups, + ) + ) + }, + ) + + # Note [view_copy NativeFunctions] + # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd + # needs to have a corresponding non-aliasing {view}_copy variant. + # Backends that use functionalization and don't know how to handle aliasing ops + # are expected to implement kernels for these {view}_copy kernels instead. + # The code for {view}_copy operators in core is pretty boilerplate-heavy however, + # so we codegen the following: + # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator. + # These are never explicitly invoked by the functionalization pass, + # but they could theoretically be called from user code (I added these kernels for completeness, + # since the ops are part of the public API). + # (2) A derivative formula for every {view}_copy operator + # {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts, + # so rather than stamping all of the entries out in derivatives.yaml, + # we codegen them in. + # This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry. 
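+ #
+ # Concretely (abridged, illustrative schemas):
+ #
+ # expand(Tensor(a) self, ...) -> Tensor(a) # aliasing view op
+ # expand_copy(Tensor self, ...) -> Tensor # non-aliasing counterpart
+ #
+ # The CompositeViewCopyKernels.cpp file written below supplies the
+ # CompositeExplicitAutogradNonFunctional kernels for those {view}_copy ops.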
+ cpu_fm.write( + "CompositeViewCopyKernels.cpp", + lambda: { + "ops_headers": [ + "\n".join( + f"#include \n" + # NB: this include is important as it ensures we + # set the visibility on generated view_copy kernels + # correctly + f"#include " + for f in ( + [g.view] if g.view_copy is None else [g.view, g.view_copy] + ) + ) + for g in view_groups + ] + + [ + "\n".join( + f"#include \n" + # NB: this include is also important for correct visibility + f"#include " + for f in [g.inplace, g.mutable, g.functional] + if f is not None and "generated" not in f.tags + ) + for g in structured_native_functions + ], + "CompositeViewCopyKernel_Definitions": list( + mapMaybe( + GenCompositeViewCopyKernel( + backend_indices[ + DispatchKey.CompositeExplicitAutogradNonFunctional + ] + ), + view_groups, + ) + ), + "GeneratedCompositeFunctional_Definitions": list( + mapMaybe( + gen_composite_functional_kernel, + structured_native_functions, + ) + ), + "GeneratedCompositeOut_Definitions": list( + mapMaybe( + gen_composite_out_kernel, + structured_native_functions, + ) + ), + }, + ) + + +def gen_declarations_yaml( + cpu_fm: FileManager, native_functions: Sequence[NativeFunction] +) -> None: + cpu_fm.write( + "Declarations.yaml", + lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]), + ) + + +def get_torchgen_root() -> Path: + """ + If you're depending on torchgen out-of-tree, you can use the root to figure + out the path to native_functions.yaml + """ + return Path(__file__).parent.resolve() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate ATen source files") + parser.add_argument( + "-s", + "--source-path", + help="path to source directory for ATen", + default="aten/src/ATen", + ) + parser.add_argument( + "-o", + "--output-dependencies", + help="output a list of dependencies into the given file and exit", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="run without writing any files (still updates outputs)", + ) + parser.add_argument( + "--per-operator-headers", + action="store_true", + help="generate separate headers per operator in ATen/ops", + ) + parser.add_argument( + "-d", + "--install-dir", + "--install_dir", + help="output directory", + default="build/aten/src/ATen", + ) + parser.add_argument( + "--aoti-install-dir", + "--aoti_install_dir", + help="output directory for AOTInductor shim", + default="torch/csrc/inductor/aoti_torch/generated", + ) + parser.add_argument( + "--rocm", + action="store_true", + help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly", + ) + parser.add_argument( + "--mps", + action="store_true", + help="Generate MPS registration code when set", + ) + # TODO: --op-registration-whitelist will be removed when all call-sites + # for gen.py are moved over to using the operator YAML file for mobile + # custom build. + parser.add_argument( + "--op-registration-whitelist", + "--op_registration_whitelist", + nargs="*", + help="filter op registrations by the whitelist (if set); " + "each item is `namespace`::`operator name` without overload name; " + "e.g.: aten::empty aten::conv2d ...", + ) + parser.add_argument( + "--op-selection-yaml-path", + "--op_selection_yaml_path", + help="Provide a path to the operator selection (for custom build) YAML " + "that contains the information about the set of selected operators " + "and their categories (training, ...). Each operator is either a " + "full operator name with overload or just a bare operator name. 
" + "The operator names also contain the namespace prefix (e.g. aten::)", + ) + parser.add_argument( + "--backend-whitelist", + "--backend_whitelist", + nargs="*", + help="filter dispatch backend by the whitelist (if set), " + "e.g.: CPU CUDA QuantizedCPU ...", + ) + parser.add_argument( + "--static-dispatch-backend", + "--static_dispatch_backend", + nargs="*", + help="generate static dispatch code for the specific backend (if set)", + ) + parser.add_argument( + "--skip-dispatcher-op-registration", + "--skip_dispatcher_op_registration", + action="store_true", + help="Avoid registering operators into the dispatcher.", + ) + parser.add_argument( + "--force-schema-registration", + "--force_schema_registration", + action="store_true", + help="force it to generate schema-only registrations for all ops, including" + "those that are not listed on --op-registration-whitelist", + ) + parser.add_argument( + "--generate", + type=str, + nargs="*", + choices=["headers", "sources", "declarations_yaml"], + default=["headers", "sources", "declarations_yaml"], + help="Generate only a subset of files", + ) + parser.add_argument( + "--update-aoti-c-shim", + action="store_true", + help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. " + "WARNING: Do not use this unless you are sure what you are doing!!!", + ) + + options = parser.parse_args() + + selector = get_custom_build_selector( + options.op_registration_whitelist, + options.op_selection_yaml_path, + ) + + native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml") + tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml") + + from torchgen.model import dispatch_keys + + # TODO: stop generating CUDA kernels for non-CUDA builds + ignore_keys = set() + if not options.mps: + ignore_keys.add(DispatchKey.MPS) + + if DispatchKey.MPS in dispatch_keys: + del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] + + parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys) + valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path] + native_functions, backend_indices = ( + parsed_yaml.native_functions, + parsed_yaml.backend_indices, + ) + + grouped_native_functions = get_grouped_native_functions(native_functions) + + structured_native_functions = [ + g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup) + ] + native_functions_with_view_groups = get_grouped_by_view_native_functions( + native_functions + ) + view_groups = [ + g + for g in native_functions_with_view_groups + if isinstance(g, NativeFunctionsViewGroup) + ] + + # NB: It is mandatory to NOT use os.path.join here, as the install directory + # will eventually be ingested by cmake, which does not respect Windows style + # path slashes. If you switch this to use os.path.join, you'll get an error + # like: + # + # Syntax error in cmake code when parsing string + # + # C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h + # + # Invalid character escape '\c'. 
+ core_install_dir = f"{options.install_dir}/core" + Path(core_install_dir).mkdir(parents=True, exist_ok=True) + ops_install_dir = f"{options.install_dir}/ops" + Path(ops_install_dir).mkdir(parents=True, exist_ok=True) + aoti_install_dir = f"{options.aoti_install_dir}" + Path(aoti_install_dir).mkdir(parents=True, exist_ok=True) + + core_fm = make_file_manager(options=options, install_dir=core_install_dir) + cpu_fm = make_file_manager(options=options) + cpu_vec_fm = make_file_manager(options=options) + cuda_fm = make_file_manager(options=options) + ops_fm = make_file_manager(options=options, install_dir=ops_install_dir) + aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir) + + # Only a limited set of dispatch keys get CPUFunctions.h headers generated + # for them; this is the set + functions_keys = { + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, + DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, + DispatchKey.Meta, + } + if options.mps: + functions_keys.add(DispatchKey.MPS) + + if options.backend_whitelist: + dispatch_keys = [ + k + for k in dispatch_keys + if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist + ] + + static_dispatch_idx: list[BackendIndex] = [] + if options.static_dispatch_backend: + static_dispatch_idx = [ + backend_indices[DispatchKey.parse(key)] + for key in options.static_dispatch_backend + ] + for key in options.static_dispatch_backend: + dp_key = DispatchKey.parse(key) + if dp_key not in functions_keys: + functions_keys.add(dp_key) + + if "sources" in options.generate: + gen_source_files( + native_functions=native_functions, + grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, + view_groups=view_groups, + selector=selector, + static_dispatch_idx=static_dispatch_idx, + backend_indices=backend_indices, + aoti_fm=aoti_fm, + core_fm=core_fm, + cpu_fm=cpu_fm, + cpu_vec_fm=cpu_vec_fm, + cuda_fm=cuda_fm, + dispatch_keys=dispatch_keys, + functions_keys=functions_keys, + rocm=options.rocm, + force_schema_registration=options.force_schema_registration, + per_operator_headers=options.per_operator_headers, + skip_dispatcher_op_registration=options.skip_dispatcher_op_registration, + update_aoti_c_shim=options.update_aoti_c_shim, + ) + + if "headers" in options.generate: + gen_headers( + native_functions=native_functions, + valid_tags=valid_tags, + grouped_native_functions=grouped_native_functions, + structured_native_functions=structured_native_functions, + static_dispatch_idx=static_dispatch_idx, + selector=selector, + backend_indices=backend_indices, + core_fm=core_fm, + cpu_fm=cpu_fm, + cuda_fm=cuda_fm, + ops_fm=ops_fm, + dispatch_keys=dispatch_keys, + functions_keys=functions_keys, + rocm=options.rocm, + per_operator_headers=options.per_operator_headers, + ) + + if "declarations_yaml" in options.generate: + gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm) + + if options.output_dependencies: + depfile_path = Path(options.output_dependencies).resolve() + depfile_name = depfile_path.name + depfile_stem = depfile_path.stem + + for fm, prefix in [ + (cpu_fm, ""), + (cpu_vec_fm, "cpu_vec_"), + (core_fm, "core_"), + (cuda_fm, "cuda_"), + (ops_fm, "ops_"), + ]: + varname = prefix + depfile_stem + path = depfile_path.parent / (prefix + depfile_name) + fm.write_outputs(varname, str(path)) + + +if __name__ == "__main__": + main() diff --git 
a/.venv/lib/python3.11/site-packages/torchgen/gen_aoti_c_shim.py b/.venv/lib/python3.11/site-packages/torchgen/gen_aoti_c_shim.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba12f88bdd9d0620edd6e405186f417aafdbe90 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/gen_aoti_c_shim.py @@ -0,0 +1,486 @@ +from __future__ import annotations + +import textwrap +from dataclasses import dataclass +from typing import Sequence + +from torchgen.api.types import DispatcherSignature +from torchgen.api.types.signatures import CppSignature, CppSignatureGroup +from torchgen.context import method_with_native_function +from torchgen.model import ( + Argument, + BackendIndex, + BaseTy, + BaseType, + DispatchKey, + FunctionSchema, + ListType, + NativeFunction, + NativeFunctionsGroup, + OperatorName, + OptionalType, + Type, +) +from torchgen.utils import mapMaybe + + +base_type_to_c_type = { + BaseTy.Tensor: "AtenTensorHandle", + BaseTy.bool: "int32_t", # Use int to pass bool + BaseTy.int: "int64_t", + BaseTy.SymInt: "int64_t", # Inductor-generated code won't see a SymInt + BaseTy.Scalar: "double", # Use double to pass both integer and floating point + BaseTy.float: "double", # TODO: how about other floating point types? + BaseTy.str: "const char*", + BaseTy.DeviceIndex: "int32_t", + BaseTy.Layout: "int32_t", # Represent enum as int + BaseTy.MemoryFormat: "int32_t", # Represent enum as int + BaseTy.ScalarType: "int32_t", # Represent enum as int + BaseTy.Generator: "AtenGeneratorHandle", +} + +base_type_to_aten_type = { + BaseTy.Tensor: "at::Tensor", + BaseTy.bool: "bool", + BaseTy.int: "int64_t", + BaseTy.SymInt: "c10::SymInt", + BaseTy.Scalar: "c10::Scalar", + BaseTy.float: "double", + BaseTy.str: "c10::string_view", + BaseTy.DeviceIndex: "c10::DeviceIndex", + BaseTy.Layout: "c10::Layout", + BaseTy.MemoryFormat: "c10::MemoryFormat", + BaseTy.ScalarType: "c10::ScalarType", + BaseTy.Generator: "at::Generator", +} + +base_type_to_callsite_expr = { + BaseTy.Tensor: "*tensor_handle_to_tensor_pointer", + BaseTy.bool: "", + BaseTy.int: "", + BaseTy.SymInt: "", + BaseTy.Scalar: "", + BaseTy.float: "", + BaseTy.str: "", + BaseTy.DeviceIndex: "static_cast", + BaseTy.Layout: "static_cast", + BaseTy.MemoryFormat: "static_cast", + BaseTy.ScalarType: "static_cast", + BaseTy.Generator: "*generator_handle_to_generator_pointer", +} + + +# convert args to C types, names in declarations, and expressions in function bodies +def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]: # type: ignore[return] + if isinstance(typ, BaseType): + if typ.name in base_type_to_c_type: + return ( + [base_type_to_c_type[typ.name]], + [name], + [base_type_to_aten_type[typ.name]], + [ + f"{base_type_to_callsite_expr[typ.name]}({name})" + if base_type_to_callsite_expr[typ.name] + else name + ], + ) + elif typ.name == BaseTy.Device: + return ( + ["int32_t", "int32_t"], + [name, name + "_index_"], + ["c10::Device"], + [ + f"c10::Device(static_cast({name}), static_cast({name}_index_))" + ], + ) + else: + # TODO: BaseTy.Dimname, etc. 
+ raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
+ elif isinstance(typ, OptionalType):
+ c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
+ typ.elem, name
+ )
+ j = 0 # index for names
+ new_aten_types = []
+ new_callsite_exprs = []
+ for aten_type in aten_types:
+ # Use pointer to denote optional type
+ c_types[j] = c_types[j] + "*"
+ if aten_type.startswith("c10::ArrayRef<"):
+ # ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
+ new_aten_types.append(f"::std::optional<{aten_type}>")
+ base_type = aten_type[len("c10::ArrayRef<") : -1]
+ new_callsite_exprs.append(
+ f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
+ )
+ j += 2
+ elif aten_type == "c10::Device":
+ # Device is passed as device_type + device_index
+ new_aten_types.append("::std::optional")
+ new_callsite_exprs.append(
+ f"pointer_to_optional_device({names[j]}, {names[j+1]})"
+ )
+ j += 2
+ else:
+ new_aten_types.append(f"::std::optional<{aten_type}>")
+ new_callsite_exprs.append(
+ f"pointer_to_optional<{aten_type}>({names[j]})"
+ )
+ j += 1
+
+ return (
+ c_types,
+ names,
+ new_aten_types,
+ new_callsite_exprs,
+ )
+ elif isinstance(typ, ListType):
+ # Need to explicitly pass the list as pointer + length
+ c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
+ assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)
+
+ # The list content should never be modified
+ c_types[0] = f"const {c_types[0]}*"
+ c_types.append("int64_t")
+ name = names[0]
+ names.append(name + "_len_")
+
+ atype = aten_types[0]
+ callsite_exprs = []
+ if atype == "bool":
+ # no converter from std::vector to c10::ArrayRef
+ # construct std::array instead
+ assert typ.size is not None
+ callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
+ elif atype == "::std::optional":
+ # convert from std::vector<::std::optional> to c10::List<::std::optional>
+ callsite_exprs.append(
+ f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
+ )
+ else:
+ callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
+
+ aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
+ return (
+ c_types,
+ names,
+ aten_types,
+ callsite_exprs,
+ )
+
+
+def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
+ return [typ + " " + name for typ, name in zip(types, names)]
+
+
+# Generate argument declarations and callsite expressions
+def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
+ types = []
+ new_names = []
+ callsite_exprs = []
+ for arg in flat_arguments:
+ new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
+ arg.type, arg.name
+ )
+ types.extend(new_types)
+ new_names.extend(names)
+ callsite_exprs.extend(new_callsite_exprs)
+ return zip_type_and_name(types, new_names), callsite_exprs
+
+
+# Return values are passed out as pointer arguments because all the C shim functions
+# are expected to return AOTITorchError.
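+# For instance, a fallback op returning a single Tensor surfaces roughly as
+# (illustrative signature, not verbatim generated code):
+#
+# AOTITorchError aoti_torch_cpu_foo(AtenTensorHandle self, AtenTensorHandle* ret0);
+#
+# where `ret0` receives the new tensor handle and the C return value carries
+# only the error code.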
+# Generate returns as declarations and callsite expressions +def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]: + types = [] + names = [] + for idx, ret in enumerate(schema.returns): + names.append(f"ret{idx}") + if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type: + types.append(base_type_to_c_type[ret.type.name] + "*") + else: + raise NotImplementedError( + f"TODO: add support for return type {repr(ret.type)}" + ) + + def convert_return(typ: BaseType, val: str) -> str: + if typ.name == BaseTy.Tensor: + return f"new_tensor_handle(std::move({val}));" + elif typ.name == BaseTy.SymInt: + return f"{val}.expect_int()" + elif typ.name == BaseTy.Scalar: + return f"{val}.toDouble()" + else: + return val + + ret_pointer_can_be_null = False + unambiguous_name = schema.name.unambiguous_name() + for name in [ + "_scaled_dot_product_flash_attention", + "_scaled_dot_product_efficient_attention", + "_scaled_dot_product_cudnn_attention", + "convolution_backward", + ]: + if name in unambiguous_name: + ret_pointer_can_be_null = True + break + + callsite_exprs: list[str] = [] + for idx, ret in enumerate(schema.returns): + tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)" + assert isinstance(ret.type, BaseType) + rval = convert_return(ret.type, tmp) + if ret_pointer_can_be_null: + callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}") + else: + callsite_exprs.append(f"*{names[idx]} = {rval};") + + return zip_type_and_name(types, names), callsite_exprs + + +# gen.py generates header first and then src, so caching the result here to avoid duplicate work +declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {} + + +def gen_declaration_and_definition( + schema: FunctionSchema, device: str, backend_call: str +) -> tuple[str, str]: + func_name = schema.name.unambiguous_name() + + global declaration_definition_cache + if (func_name, device, backend_call) in declaration_definition_cache: + return declaration_definition_cache[(func_name, device, backend_call)] + + if schema.is_out_fn(): + # out_variant has out arguments in the front, and it's ok to ignore return values + # because C shim functions only return AOTITorchError + args, callsite_exprs = gen_arguments( + [*schema.arguments.out, *schema.arguments.flat_non_out] + ) + ret_assignments: list[str] = [] + else: + args, callsite_exprs = gen_arguments(schema.arguments.flat_all) + # ignore return values for inplace ops + ret_declarations, ret_assignments = ( + ([], []) if schema.name.name.inplace else gen_returns(schema) + ) + args.extend(ret_declarations) + + declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})" + + tmp_result = "auto tmp_result = " if ret_assignments else "" + ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else "" + definition = f""" +{declaration} {{ + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{ + {tmp_result}{backend_call}( +{textwrap.indent(', '.join(callsite_exprs), " ")} + );{textwrap.indent(ret_assignments_str, " ")} + }}); +}} +""" + declaration_definition_cache[(func_name, device, backend_call)] = ( + declaration, + definition, + ) + return declaration, definition + + +def gen_static_dispatch_backend_call_signature( + sig: CppSignature | DispatcherSignature, + f: NativeFunction, +) -> CppSignature: + sig = DispatcherSignature.from_schema(f.func) + cpp_sigs = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + if sig.symint and 
f.func.has_symint(): + cpp_sig = cpp_sigs.symint_signature + else: + cpp_sig = cpp_sigs.signature + assert cpp_sig is not None + return cpp_sig + + +def gen_static_dispatch_backend_call( + f: NativeFunction, + backend_index: BackendIndex, +) -> str: + sig = DispatcherSignature.from_schema(f.func) + cpp_sig = gen_static_dispatch_backend_call_signature(sig, f) + return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}" + + +def get_backend_index_for_aoti( + func: NativeFunction, + func_group_mapping: dict[OperatorName, NativeFunctionsGroup], + dispatch_key: DispatchKey, + backend_indices: dict[DispatchKey, BackendIndex], +) -> BackendIndex | None: + backend_index = None + if backend_indices[dispatch_key].has_kernel(func) or ( + func.structured_delegate is not None + and func.structured_delegate in func_group_mapping + and backend_indices[dispatch_key].has_kernel( + func_group_mapping[func.structured_delegate] + ) + ): + backend_index = backend_indices[dispatch_key] + elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func): + # We need to create C shim wrappers for CompositeExplicitAutograd kernels + backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd] + elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel( + func + ): + # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels + backend_index = backend_indices[ + DispatchKey.CompositeExplicitAutogradNonFunctional + ] + elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func): + backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd] + + return backend_index + + +def get_header_for_aoti( + func: NativeFunction, + func_group_mapping: dict[OperatorName, NativeFunctionsGroup], + dispatch_key: DispatchKey, + backend_indices: dict[DispatchKey, BackendIndex], +) -> str | None: + backend_index = get_backend_index_for_aoti( + func, func_group_mapping, dispatch_key, backend_indices + ) + return ( + None + if backend_index is None + else f"#include " + ) + + +def get_fallback_op_name(func: NativeFunction) -> str: + return ( + f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}" + if func.func.name.overload_name + else f"{func.namespace}.{func.func.name.name}.default" + ) + + +def gen_c_shim( + func: NativeFunction, + func_group_mapping: dict[OperatorName, NativeFunctionsGroup], + dispatch_key: DispatchKey, + backend_indices: dict[DispatchKey, BackendIndex], + header: bool, +) -> str | None: + backend_index = get_backend_index_for_aoti( + func, func_group_mapping, dispatch_key, backend_indices + ) + if backend_index is None: + return None + + schema = func.func + device = dispatch_key.lower() + backend_call = gen_static_dispatch_backend_call( + func, + backend_index, + ) + + try: + if header: + declaration, _ = gen_declaration_and_definition( + schema, device, backend_call + ) + return f"AOTI_TORCH_EXPORT {declaration};" + else: + _, definition = gen_declaration_and_definition(schema, device, backend_call) + return definition + + except NotImplementedError: + return None + + +@dataclass(frozen=True) +class ShimGenerator: + func_group_mapping: dict[OperatorName, NativeFunctionsGroup] + dispatch_key: DispatchKey + backend_indices: dict[DispatchKey, BackendIndex] + header: bool # True to generate .h and False to generate .cpp + + @method_with_native_function + def __call__( + self, + func: NativeFunction, + ) -> str | None: + result = gen_c_shim( + func, + self.func_group_mapping, + 
self.dispatch_key, + self.backend_indices, + self.header, + ) + return result + + +def gen_aoti_c_shim( + native_functions: Sequence[NativeFunction], + func_group_mapping: dict[OperatorName, NativeFunctionsGroup], + dispatch_key: DispatchKey, + backend_indices: dict[DispatchKey, BackendIndex], + header: bool, + includes: str = "", +) -> str: + body = "\n".join( + list( + mapMaybe( + ShimGenerator( + func_group_mapping, dispatch_key, backend_indices, header + ), + native_functions, + ) + ) + ) + device = dispatch_key.lower() + + warning = """ +// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND. +// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details""" + + if header: + return f""" +{warning} + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" {{ +#endif + +{body} + +#ifdef __cplusplus +}} // extern "C" +#endif +""" + + else: + return f""" +{warning} + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#include +#include +#else +{includes} +#endif + +using namespace torch::aot_inductor; + +{body}""" diff --git a/.venv/lib/python3.11/site-packages/torchgen/gen_backend_stubs.py b/.venv/lib/python3.11/site-packages/torchgen/gen_backend_stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..92a897a330f90377ae014e09a633cbc999cd9474 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/gen_backend_stubs.py @@ -0,0 +1,611 @@ +from __future__ import annotations + +import argparse +import os +import re +from collections import Counter, defaultdict, namedtuple +from pathlib import Path +from typing import Sequence + +import yaml + +import torchgen.api.dispatcher as dispatcher +import torchgen.dest as dest +from torchgen.api.types import DispatcherSignature +from torchgen.code_template import CodeTemplate +from torchgen.context import native_function_manager +from torchgen.gen import get_grouped_native_functions, parse_native_yaml +from torchgen.model import ( + BackendIndex, + BackendMetadata, + DispatchKey, + NativeFunction, + NativeFunctionsGroup, + OperatorName, +) +from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Target +from torchgen.yaml_utils import YamlLoader + + +# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key. 
+# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping) +ParsedExternalYaml = namedtuple( + "ParsedExternalYaml", + ["backend_key", "autograd_key", "class_name", "cpp_namespace", "backend_indices"], +) + + +def parse_backend_yaml( + backend_yaml_path: str, + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + backend_indices: dict[DispatchKey, BackendIndex], +) -> ParsedExternalYaml: + native_functions_map: dict[OperatorName, NativeFunction] = { + f.func.name: f + for f in concatMap( + lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), + grouped_native_functions, + ) + } + + with open(backend_yaml_path) as f: + yaml_values = yaml.load(f, Loader=YamlLoader) + assert isinstance(yaml_values, dict) + + valid_keys = [ + "backend", + "class_name", + "cpp_namespace", + "extra_headers", + "supported", + "autograd", + "full_codegen", + "non_native", + "ir_gen", + "symint", + ] + + backend = yaml_values.pop("backend", None) + assert backend is not None, 'You must provide a value for "backend"' + + class_name = yaml_values.pop("class_name", None) + + cpp_namespace = yaml_values.pop("cpp_namespace", None) + assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"' + + # Mostly just defaulting to false to stick with LazyTensor convention. + use_out_as_primary = yaml_values.pop("use_out_as_primary", False) + assert isinstance( + use_out_as_primary, bool + ), f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}" + + use_device_guard = yaml_values.pop("device_guard", False) + assert isinstance( + use_device_guard, bool + ), f"You must provide either True or False for device_guard. Provided: {use_device_guard}" + + supported = yaml_values.pop("supported", []) + if supported is None: + supported = [] # Allow an empty list of supported ops + assert isinstance( + supported, list + ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})' + + symint = yaml_values.pop("symint", []) + if symint is None: + symint = [] # Allow an empty list of symint ops + assert isinstance( + symint, list + ), f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})' + symint_set = set(symint) + + supported_autograd = yaml_values.pop("autograd", []) + assert isinstance( + supported_autograd, list + ), f'expected "autograd" to be a list, but got: {supported_autograd}' + + # full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py + full_codegen = yaml_values.pop("full_codegen", []) + supported.extend(full_codegen) + + # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py + yaml_values.pop("non_native", {}) + + # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py + yaml_values.pop("ir_gen", {}) + + assert ( + len(yaml_values.keys()) == 0 + ), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. 
\ +Only the following keys are supported: {", ".join(valid_keys)}' + + def create_backend_index( + backend_ops: list[str], + symint_ops: set[str], + dispatch_key: DispatchKey, + *, + use_out_as_primary: bool, + use_device_guard: bool, + ) -> BackendIndex: + metadata: dict[OperatorName, BackendMetadata] = {} + for op in backend_ops: + op_name = OperatorName.parse(op) + assert ( + op_name in native_functions_map + ), f"Found an invalid operator name: {op_name}" + # See Note [External Backends Follow Dispatcher API] + kernel_name = dispatcher.name(native_functions_map[op_name].func) + if op in symint_ops: + kernel_name += "_symint" + # TODO: allow structured external backends later. + m = BackendMetadata( + kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace + ) + metadata[op_name] = m + return BackendIndex( + dispatch_key=dispatch_key, + use_out_as_primary=use_out_as_primary, + external=True, + device_guard=use_device_guard, + index=metadata, + ) + + backend_key: DispatchKey | None = None + if len(supported) > 0: + with context( + lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.' + ): + backend_key = DispatchKey.parse(backend) + + backend_idx = create_backend_index( + supported, + symint_set, + backend_key, + use_out_as_primary=use_out_as_primary, + use_device_guard=use_device_guard, + ) + assert backend_key not in backend_indices + backend_indices[backend_key] = backend_idx + + autograd_key: DispatchKey | None = None + if len(supported_autograd) > 0: + with context( + lambda: f'The "autograd" key was specified, which indicates that you would like to override \ +the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.' + ): + autograd_key = DispatchKey.parse(f"Autograd{backend}") + + autograd_idx = create_backend_index( + supported_autograd, + symint_set, + autograd_key, + use_out_as_primary=use_out_as_primary, + use_device_guard=use_device_guard, + ) + assert autograd_key not in backend_indices + backend_indices[autograd_key] = autograd_idx + + for g in grouped_native_functions: + if isinstance(g, NativeFunction): + forward_kernels = ( + [] + if backend_key is None + else [ + m + for m in [backend_indices[backend_key].get_kernel(g)] + if m is not None + ] + ) + backward_kernels = ( + [] + if autograd_key is None + else [ + m + for m in [backend_indices[autograd_key].get_kernel(g)] + if m is not None + ] + ) + else: + forward_kernels = ( + [] + if backend_key is None + else [ + m + for m in [ + backend_indices[backend_key].get_kernel(f) + for f in g.functions() + ] + if m is not None + ] + ) + backward_kernels = ( + [] + if autograd_key is None + else [ + m + for m in [ + backend_indices[autograd_key].get_kernel(f) + for f in g.functions() + ] + if m is not None + ] + ) + + forward_kernels = [f for f in forward_kernels if f is not None] + backward_kernels = [f for f in backward_kernels if f is not None] + assert ( + len(forward_kernels) == 0 or len(backward_kernels) == 0 + ), f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \ +autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! \ +{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".' 
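For orientation before the parsed tuple is returned just below: a minimal external-backend yaml of the shape this parser accepts might look like the sketch that follows. The backend and operator names here are illustrative, not taken from a real backend; a real file must use a valid `DispatchKey` for `backend` and operator names that exist in `native_functions.yaml`. Given this input, `parse_backend_yaml` builds one `BackendIndex` for `XLA` from `supported` and one for `AutogradXLA` from `autograd`.

```python
# Illustrative input for parse_backend_yaml (a sketch, not a real backend's yaml).
example_backend_yaml = """\
backend: XLA
cpp_namespace: torch_xla
supported:
- abs
- add.Tensor
autograd:
- max_pool2d_with_indices
symint:
- empty.memory_format
"""
```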
+ + return ParsedExternalYaml( + backend_key, autograd_key, class_name, cpp_namespace, backend_indices + ) + + +def error_on_missing_kernels( + native_functions: Sequence[NativeFunction], + backend_indices: dict[DispatchKey, BackendIndex], + backend_key: DispatchKey, + autograd_key: DispatchKey | None, + class_name: str, + kernel_defn_file_path: str, + full_codegen: list[OperatorName] | None = None, +) -> None: + try: + with open(kernel_defn_file_path) as f: + backend_defns = f.read() + except OSError as e: + raise AssertionError( + f"Unable to read from the specified impl_path file: {kernel_defn_file_path}" + ) from e + + if full_codegen is None: + full_codegen = [] + + indices = [backend_indices[backend_key].index] + ( + [] if autograd_key is None else [backend_indices[autograd_key].index] + ) + # Quick mapping from each OperatorName used by the external backend + # to its backend kernel name + expected_backend_op_names: dict[OperatorName, str] = dict( + list( + concatMap( + lambda index: [ + (op_name, metadata.kernel) for op_name, metadata in index.items() + ], + indices, + ) + ) + ) + expected_backend_native_funcs: list[NativeFunction] = [ + f + for f in native_functions + if f.func.name in expected_backend_op_names.keys() + and f.func.name not in full_codegen + ] + expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict( + list + ) + for native_f in expected_backend_native_funcs: + expected_backend_kernel_name_counts[ + expected_backend_op_names[native_f.func.name] + ].append(native_f) + + # This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented. + # It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel + # here, then we get a nicer error message. If we miss it, you get a linker error. + kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\(" + actual_backend_kernel_name_counts = Counter( + # A bit unwieldy (this could probably be moved into regex), + # but we don't want to include kernel names that come from function calls, + # like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)". + # Easy check is to ignore any lines with colons before the class name. + [ + y + for (x, y) in re.findall(kernel_defn_regex, backend_defns) + if not x.endswith(":") + ] + ) + + missing_kernels_err_msg = "" + for expected_name, funcs in expected_backend_kernel_name_counts.items(): + expected_overload_count = len(funcs) + actual_overload_count = actual_backend_kernel_name_counts[expected_name] + if expected_overload_count != actual_overload_count: + + def create_decl(f: NativeFunction) -> str: + with native_function_manager(f): + return DispatcherSignature.from_schema(f.func).decl() + + expected_schemas_str = "\n".join([create_decl(f) for f in funcs]) + missing_kernels_err_msg += f""" +{class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name, +but expected {expected_overload_count} kernel(s). 
The expected function schemas for the missing operator are: +{expected_schemas_str} + +""" + assert missing_kernels_err_msg == "", missing_kernels_err_msg + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate backend stub files") + parser.add_argument( + "-s", + "--source-yaml", + "--source_yaml", + help="path to source yaml file containing operator external definitions", + ) + parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory") + parser.add_argument( + "--dry-run", "--dry_run", type=bool, default=False, help="output directory" + ) + parser.add_argument( + "--impl-path", + "--impl_path", + type=str, + default=None, + help="path to the source C++ file containing kernel definitions", + ) + options = parser.parse_args() + + run(options.source_yaml, options.output_dir, options.dry_run, options.impl_path) + + +def gen_dispatchkey_nativefunc_headers( + fm: FileManager, + class_name: str, + cpp_namespace: str, + backend_indices: dict[DispatchKey, BackendIndex], + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + backend_dispatch_key: DispatchKey, + autograd_dispatch_key: DispatchKey | None, + backend_name: str = "", +) -> None: + assert class_name is not None + generated_comment = ( + "Autogenerated file by gen_backend_stubs.py. Do not edit directly!" + ) + + # Convert to a set first to remove duplicate kernel names. + # Backends are allowed to repeat kernel names; only generate the declaration once! + # Sort for deterministic output. + backend_declarations = sorted( + set( + concatMap( + lambda f: dest.compute_native_function_declaration( + f, backend_indices[backend_dispatch_key] + ), + grouped_native_functions, + ) + ) + ) + autograd_declarations = sorted( + set( + concatMap( + lambda f: [] + if autograd_dispatch_key is None + else dest.compute_native_function_declaration( + f, backend_indices[autograd_dispatch_key] + ), + grouped_native_functions, + ) + ) + ) + + ns_helper = NamespaceHelper(cpp_namespace) + fm.write_with_template( + f"{backend_dispatch_key}NativeFunctions.h", + "DispatchKeyNativeFunctions.h", + lambda: { + "generated_comment": generated_comment, + "namespace_prologue": ns_helper.prologue, + "class_name": class_name, + "namespace_epilogue": ns_helper.epilogue, + "dispatch_declarations": backend_declarations + autograd_declarations, + "BackendName": backend_name, + "DispatchKey": backend_dispatch_key, + }, + ) + + +def gen_dispatcher_registrations( + fm: FileManager, + output_dir: str, + class_name: str, + backend_indices: dict[DispatchKey, BackendIndex], + grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], + backend_dispatch_key: DispatchKey, + dispatch_key: DispatchKey, + selector: SelectiveBuilder, + # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends + build_in_tree: bool = False, + per_operator_headers: bool = False, + backend_name: str = "", + eager_registration: bool = True, +) -> None: + headers = [ + f"{output_dir}/{backend_dispatch_key}NativeFunctions.h", + ] + if build_in_tree: + external_backend_headers_str = "\n".join(f"#include <{h}>" for h in headers) + else: + external_backend_headers_str = "\n".join(f'#include "{h}"' for h in headers) + + assert class_name is not None + backend_index = backend_indices[dispatch_key] + + dispatch_registrations_body = list( + concatMap( + dest.RegisterDispatchKey( + backend_index, + Target.REGISTRATION, + selector, + rocm=False, + symint=True, + 
class_method_name=f"{class_name}", + skip_dispatcher_op_registration=False, + ), + grouped_native_functions, + ) + ) + newline = "\n" + ns_helper = NamespaceHelper(namespace_str="at") + deferred_dispatch_registrations = "" + static_init_dispatch_registrations = "" + if eager_registration: + static_template = CodeTemplate( + """\ +TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { + $dispatch_registrations_body +};""" + ) + static_init_dispatch_registrations = static_template.substitute( + dispatch_key=dispatch_key, + dispatch_registrations_body=dispatch_registrations_body, + ) + else: + deferred_template = CodeTemplate( + """\ +TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions(); +TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() { + static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key); + $dispatch_registrations_body +}""" + ) + deferred_dispatch_registrations = deferred_template.substitute( + backend_name=backend_name, + dispatch_key=dispatch_key, + dispatch_registrations_body=dispatch_registrations_body, + ) + + fm.write_with_template( + f"Register{dispatch_key}.cpp", + "RegisterDispatchKey.cpp", + lambda: { + "extra_cuda_headers": "", + "external_backend_headers": external_backend_headers_str, + "ops_headers": "#include " + if not per_operator_headers + else "", + "DispatchKey": dispatch_key, + "dispatch_namespace": dispatch_key.lower(), + "dispatch_headers": dest.gen_registration_headers( + backend_index, per_operator_headers=per_operator_headers, rocm=False + ), + "dispatch_definitions": fm.substitute_with_template( + "RegisterDispatchDefinitions.ini", + lambda: { + "ns_prologue": ns_helper.prologue, + "ns_epilogue": ns_helper.epilogue, + "static_init_dispatch_registrations": static_init_dispatch_registrations, + "deferred_dispatch_registrations": deferred_dispatch_registrations, + "dispatch_helpers": dest.gen_registration_helpers(backend_index), + "dispatch_namespace": dispatch_key.lower(), + "dispatch_namespaced_definitions": "", + "dispatch_anonymous_definitions": list( + concatMap( + dest.RegisterDispatchKey( + backend_index, + Target.ANONYMOUS_DEFINITION, + selector, + rocm=False, + symint=True, + class_method_name=f"{class_name}", + skip_dispatcher_op_registration=False, + ), + grouped_native_functions, + ) + ), + }, + ).split(newline), + }, + ) + + +def run( + source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None +) -> None: + # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py + pytorch_root = Path(__file__).parent.parent.absolute() + template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates") + + def make_file_manager(install_dir: str) -> FileManager: + return FileManager( + install_dir=install_dir, template_dir=template_dir, dry_run=dry_run + ) + + fm = make_file_manager(output_dir) + + native_yaml_path = os.path.join( + pytorch_root, "aten/src/ATen/native/native_functions.yaml" + ) + tags_yaml_path = os.path.join(pytorch_root, "aten/src/ATen/native/tags.yaml") + parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path) + native_functions, backend_indices = ( + parsed_yaml.native_functions, + parsed_yaml.backend_indices, + ) + grouped_native_functions = get_grouped_native_functions(native_functions) + parsed_backend_yaml = parse_backend_yaml( + source_yaml, grouped_native_functions, backend_indices + ) + backend_key = parsed_backend_yaml.backend_key + autograd_key = parsed_backend_yaml.autograd_key + cpp_namespace = parsed_backend_yaml.cpp_namespace + class_name = 
parsed_backend_yaml.class_name + backend_indices = parsed_backend_yaml.backend_indices + + selector = SelectiveBuilder.get_nop_selector() + + if backend_key is None: + # This could be useful if a backend wants to quickly set up a noop yaml file but doesn't have any kernels ready yet. + return + + if class_name is None: + # class_name is an optional argument to backend yaml file. + # if specified it allows an external backend to override + # the name of the class that all generated kernel definitions live under. + # if not specified, its value is given as native_function_class_name. + class_name = backend_indices[backend_key].native_function_class_name() + assert class_name is not None + + if impl_path is not None: + error_on_missing_kernels( + native_functions, + backend_indices, + backend_key, + autograd_key, + class_name, + impl_path, + ) + + gen_dispatchkey_nativefunc_headers( + fm, + class_name, + cpp_namespace, + backend_indices, + grouped_native_functions, + backend_key, + autograd_key, + ) + + for dispatch_key in ( + [backend_key] if autograd_key is None else [backend_key, autograd_key] + ): + gen_dispatcher_registrations( + fm, + output_dir, + class_name, + backend_indices, + grouped_native_functions, + backend_key, + dispatch_key, + selector, + ) + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/torchgen/gen_executorch.py b/.venv/lib/python3.11/site-packages/torchgen/gen_executorch.py new file mode 100644 index 0000000000000000000000000000000000000000..353302c7cd4a1638dc5b2f665fc61475e0479c6b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/gen_executorch.py @@ -0,0 +1,998 @@ +from __future__ import annotations + +import argparse +import os +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING + +import yaml + +# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices. +from torchgen import dest +from torchgen.api import cpp as aten_cpp +from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType +from torchgen.context import ( + method_with_native_function, + method_with_nested_native_function, + with_native_function_and_index, +) +from torchgen.executorch.api import et_cpp +from torchgen.executorch.api.custom_ops import ( + ComputeNativeFunctionStub, + gen_custom_ops_registration, +) +from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature +from torchgen.executorch.api.unboxing import Unboxing +from torchgen.executorch.model import ETKernelIndex, ETKernelKey, ETParsedYaml +from torchgen.executorch.parse import ET_FIELDS, parse_et_yaml, parse_et_yaml_struct +from torchgen.gen import ( + get_custom_build_selector, + get_native_function_declarations, + get_native_function_declarations_from_ns_grouped_kernels, + get_native_function_schema_registrations, + LineLoader, + parse_native_yaml, +) +from torchgen.model import ( + BackendIndex, + BackendMetadata, + DEFAULT_KERNEL_NAMESPACE, + DispatchKey, + FunctionSchema, + Location, + NativeFunction, + NativeFunctionsGroup, + OperatorName, + Variant, +) +from torchgen.utils import ( + context, + FileManager, + make_file_manager, + mapMaybe, + NamespaceHelper, +) + + +if TYPE_CHECKING: + from torchgen.selective_build.selector import SelectiveBuilder + + +def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str: + """ + A wrapper function to basically get `sig.decl(include_context=True)`. 
+ For ATen kernel, the codegen has no idea about ET contextArg, so we + use this wrapper to add it. + """ + if isinstance(sig, ExecutorchCppSignature): + return sig.decl() + + returns_type = aten_cpp.returns_type(sig.func.returns).cpp_type() + cpp_args = [a.decl() for a in sig.arguments()] + cpp_args_str = ", ".join([contextArg.decl()] + cpp_args) + sig_decl = f"{returns_type} {sig.name()}({cpp_args_str})" + return sig_decl + + +def static_dispatch( + sig: CppSignature | ExecutorchCppSignature, + f: NativeFunction, + backend_indices: list[BackendIndex], +) -> str: + """ + For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one + native function exists, error out. A simplified version of register_dispatch_key.py + Arguments: + sig: A CppSignature for this native function we want to use. + f: NativeFunction to generate static dispatch. + backend_indices: All available backends. + Return: + C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);" + """ + if len(backend_indices) == 0 or f.manual_kernel_registration: + return "" + + backends = [b for b in backend_indices if b.has_kernel(f)] + static_block = None + if len(backends) == 1: + backend_metadata = backends[0].get_kernel(f) + if backend_metadata: + args = ", ".join(a.name for a in sig.arguments()) + # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch. + static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});" + else: + static_block = f""" +ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}."); + """ + return f""" +// {f.namespace}::{f.func} +TORCH_API inline {_sig_decl_wrapper(sig)} {{ + {static_block} +}} +""" + + +# Generates Functions.h, which provides the functional public C++ API, +# and the scaffolding to call into the dispatcher from these functions. +@dataclass(frozen=True) +class ComputeFunction: + static_dispatch_backend_indices: list[BackendIndex] + + selector: SelectiveBuilder + + use_aten_lib: bool + + is_custom_op: Callable[[NativeFunction], bool] + + @method_with_native_function + def __call__(self, f: NativeFunction) -> str | None: + is_method_variant = False + if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"): + return None + + if Variant.function not in f.variants and Variant.method in f.variants: + is_method_variant = True + + # only valid remaining case is only function is in f.variants + elif not (Variant.function in f.variants and Variant.method not in f.variants): + raise Exception( # noqa: TRY002 + f"Can't handle native function {f.func} with the following variant specification {f.variants}." 
+ ) + + sig: CppSignature | ExecutorchCppSignature = ( + CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=f.manual_cpp_binding + ).most_faithful_signature() + if self.use_aten_lib + else ExecutorchCppSignature.from_native_function(f) + ) + if self.use_aten_lib and not self.is_custom_op(f): + comma = ", " + + if is_method_variant: + return f""" +// {f.namespace}::{f.func} +TORCH_API inline {_sig_decl_wrapper(sig)} {{ + return {sig.arguments()[0].name}.{sig.name()}({comma.join(e.name for e in sig.arguments()[1:])}); +}} +""" + else: + return f""" +// {f.namespace}::{f.func} +TORCH_API inline {_sig_decl_wrapper(sig)} {{ + return at::{sig.name()}({comma.join(e.name for e in sig.arguments())}); +}} +""" + + else: + return static_dispatch( + sig, + f, + backend_indices=self.static_dispatch_backend_indices, + ) + + +# Generates RegisterCodegenUnboxedKernels.cpp. +@dataclass(frozen=True) +class ComputeCodegenUnboxedKernels: + selector: SelectiveBuilder + + use_aten_lib: bool + + @method_with_nested_native_function + def __call__( + self, + unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]], + ) -> str: + f: NativeFunction = unbox_kernel_entry[0] + kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0] + kernel_meta: BackendMetadata = unbox_kernel_entry[1][1] + + op_name = f"{f.namespace}::{f.func.name}" + if not self.selector.is_root_operator(op_name): + return "" + + if not isinstance(kernel_key, list): + kernel_key = [kernel_key] + used_kernel_keys = self.selector.et_get_selected_kernels( + op_name, [k.to_native_string() for k in kernel_key] + ) + if not used_kernel_keys: + return "" + sig: CppSignature | ExecutorchCppSignature + argument_type_gen: Callable[..., NamedCType] + return_type_gen: Callable[..., CType] + if self.use_aten_lib: + sig = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=f.manual_cpp_binding + ).most_faithful_signature() + argument_type_gen = aten_cpp.argumenttype_type + return_type_gen = aten_cpp.returns_type + arguments = sig.arguments() + kernel_call = f"torch::executor::{f.namespace}::{sig.name()}" + else: + sig = ExecutorchCppSignature.from_native_function(f) + argument_type_gen = et_cpp.argumenttype_type + return_type_gen = et_cpp.returns_type + arguments = sig.arguments(include_context=False) + kernel_call = f"{kernel_meta.cpp_namespace}::{kernel_meta.kernel}" + # parse arguments into C++ code + binding_list, code_list = Unboxing( + argument_type_gen=argument_type_gen + ).convert_arguments(arguments) + + # for each C++ argument, generate the conversion code + code_connector = "\n\t" + arg_connector = ", " + + args_str = f"{arg_connector.join(e.name for e in binding_list)}" + event_tracer_output_logging = "" + output_ids = [] + + if len(f.func.returns) == 0: + if len(f.func.arguments.out) == 0: + raise Exception( # noqa: TRY002 + f"Can't handle native function {f.func} with no returns and no out yet." 
+ ) + out = f.func.arguments.out[0] + return_assignment = f"""stack[{len(binding_list)}] = &{out.name};""" + ret_prefix = "" + output_ids = [len(binding_list)] + else: + if len(f.func.arguments.out) == 0: + return_assignment = ( + f"""*stack[{len(binding_list)}] = EValue(result_);""" + ) + ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = " + output_ids = [len(binding_list)] + else: + return_assignment = "" + ret_prefix = "" + output_ids = [ + len(binding_list) - (i + 1) + for i in reversed(range(len(f.func.arguments.out))) + ] + + for output_id in output_ids: + event_tracer_output_logging += ( + f"internal::event_tracer_log_evalue(" + f"context.internal_event_tracer(), " + f"*stack[{output_id}]);\n" + ) + + newline = "\n " + return "\n".join( + [ + f""" +Kernel( + "{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != 'default' else ''} + []({contextArg.defn()}, EValue** stack) {{ + {code_connector.join(code_list)} + + internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_{f.func.name}"); + EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}"); + {ret_prefix}{kernel_call}(context, {args_str}); + {event_tracer_output_logging} + {return_assignment} + }} +), +""" + for k in used_kernel_keys + ] + ) + + +def gen_unboxing( + *, + native_functions: Sequence[NativeFunction], + cpu_fm: FileManager, + selector: SelectiveBuilder, + use_aten_lib: bool, + kernel_index: ETKernelIndex, + manual_registration: bool, +) -> None: + # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata)) + def key_func( + item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]] + ) -> str: + return item[0].root_name + ":" + item[1][0].to_native_string() + + items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [ + (native_function, (kernel_key, metadata)) + for native_function in native_functions + for kernel_key, metadata in kernel_index.get_kernels(native_function).items() + ] + + header = ["Functions.h" if use_aten_lib else "NativeFunctions.h"] + filename = ( + "RegisterKernels.cpp" + if manual_registration + else "RegisterCodegenUnboxedKernels.cpp" + ) + cpu_fm.write_sharded( + filename, + items, + key_fn=key_func, + env_callable=lambda unbox_kernel_entry: { + "unboxed_kernels": [ + ComputeCodegenUnboxedKernels(selector, use_aten_lib)(unbox_kernel_entry) + ], + "fn_header": header + if unbox_kernel_entry == items[0] + else [], # Only write header once + }, + num_shards=1, + sharded_keys={"unboxed_kernels", "fn_header"}, + ) + + +@with_native_function_and_index # type: ignore[arg-type] +def compute_native_function_declaration( + g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex +) -> list[str]: + assert isinstance(g, NativeFunction) + sig = ExecutorchCppSignature.from_native_function(f=g) + metadata_list = kernel_index.get_kernels(g).values() + if metadata_list is None: + return [] + + # for kernels in lean mode, we declare two versions, one with context and one without. + # In the end we will cleanup the unused one. 
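Before the helper below, one detail from `ComputeCodegenUnboxedKernels` above is worth verifying by hand: the stack-slot bookkeeping for out-variant ops. A small self-contained sketch of that arithmetic:

```python
# Sketch of the out-variant stack-slot arithmetic used above: with N unboxed
# bindings and K out arguments, the out tensors occupy the last K stack
# slots, in declaration order.
def out_variant_output_ids(num_bindings: int, num_out: int) -> list[int]:
    return [num_bindings - (i + 1) for i in reversed(range(num_out))]

assert out_variant_output_ids(5, 2) == [3, 4]  # 5 bindings, 2 out args
assert out_variant_output_ids(3, 1) == [2]     # single out arg: last slot
```

The non-out case instead writes the single result to slot `len(binding_list)`, matching the `return_assignment` strings built above.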
+ def gen_decl(metadata: BackendMetadata, include_context: bool) -> str: + return f"{sig.decl(name=metadata.kernel, include_context=include_context)};" + + return [ + gen_decl(metadata, include_context) + for include_context in [False, True] + for metadata in metadata_list + ] + + +def gen_functions_declarations( + *, + native_functions: Sequence[NativeFunction], + kernel_index: ETKernelIndex, + selector: SelectiveBuilder, + use_aten_lib: bool, + custom_ops_native_functions: Sequence[NativeFunction] | None = None, +) -> str: + """ + Generates namespace separated C++ function API inline declaration/definitions. + Native functions are grouped by namespaces and the generated code is wrapped inside + namespace blocks. + + E.g., for `custom_1::foo.out` in yaml file we will generate a C++ API as a symbol + in `torch::executor::custom_1::foo_out`. This way we avoid symbol conflict when + the other `custom_2::foo.out` is available. + """ + + # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet. + # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex. + + backend_index = kernel_index._to_backend_index() + + ns_grouped_functions = defaultdict(list) + for native_function in native_functions: + ns_grouped_functions[native_function.namespace].append(native_function) + functions_declarations = "" + newline = "\n" + for namespace in ns_grouped_functions: + ns_helper = NamespaceHelper( + namespace_str=namespace, + entity_name="", + max_level=3, + ) + declarations = list( + mapMaybe( + ComputeFunction( + static_dispatch_backend_indices=[backend_index], + selector=selector, + use_aten_lib=use_aten_lib, + is_custom_op=lambda f: custom_ops_native_functions is not None + and f in custom_ops_native_functions, + ), + ns_grouped_functions[namespace], + ) + ) + functions_declarations += f""" +{ns_helper.prologue} +{newline.join(declarations)} +{ns_helper.epilogue} + """ + return functions_declarations + + +def get_ns_grouped_kernels( + *, + native_functions: Sequence[NativeFunction], + kernel_index: ETKernelIndex, + native_function_decl_gen: Callable[ + [ + NativeFunctionsGroup | NativeFunction, + ETKernelIndex, + ], + list[str], + ], +) -> dict[str, list[str]]: + ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) + for f in native_functions: + native_function_namespaces = set() + op_kernels = kernel_index.get_kernels(f) + for backend_metadata in op_kernels.values(): + if backend_metadata: + namespace = backend_metadata.cpp_namespace + native_function_namespaces.add(namespace) + else: + namespace = DEFAULT_KERNEL_NAMESPACE + assert ( + len(native_function_namespaces) <= 1 + ), f"Codegen only supports one namespace per operator, got {native_function_namespaces}" + ns_grouped_kernels[namespace].extend( + native_function_decl_gen(f, kernel_index) + ) + return ns_grouped_kernels + + +def gen_headers( + *, + native_functions: Sequence[NativeFunction], + gen_custom_ops_header: bool, + custom_ops_native_functions: Sequence[NativeFunction], + selector: SelectiveBuilder, + kernel_index: ETKernelIndex, + cpu_fm: FileManager, + use_aten_lib: bool, +) -> None: + """Generate headers. + + Args: + native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops. + gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h + custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops. 
+ kernel_index (ETKernelIndex): kernel collection + cpu_fm (FileManager): file manager manages output stream + use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types. + """ + aten_headers = ["#include "] + backend_indices = {DispatchKey.CPU: kernel_index._to_backend_index()} + if gen_custom_ops_header: + cpu_fm.write_with_template( + "CustomOpsNativeFunctions.h", + "NativeFunctions.h", + lambda: { + "nativeFunctions_declarations": get_native_function_declarations( + grouped_native_functions=custom_ops_native_functions, + backend_indices=backend_indices, + native_function_decl_gen=dest.compute_native_function_declaration, + ), + "headers": [ + "#include ", + "#include ", + ], + }, + ) + aten_headers.append('#include "CustomOpsNativeFunctions.h"') + cpu_fm.write( + "Functions.h", + lambda: { + "static_dispatch_extra_headers": aten_headers + if use_aten_lib + else ['#include "NativeFunctions.h"'], + "Functions_declarations": gen_functions_declarations( + native_functions=native_functions, + kernel_index=kernel_index, + selector=selector, + use_aten_lib=use_aten_lib, + custom_ops_native_functions=custom_ops_native_functions, + ), + }, + ) + cpu_fm.write( + "RegisterKernels.h", + lambda: { + "generated_comment": "@" + "generated by torchgen/gen_executorch.py", + }, + ) + headers = { + "headers": [ + "#include // at::Tensor etc.", + "#include ", + ], + } + if use_aten_lib: + headers["headers"].append("#include // TORCH_API") + cpu_fm.write( + "NativeFunctions.h", + lambda: dict( + { + "nativeFunctions_declarations": get_native_function_declarations( + grouped_native_functions=native_functions, + backend_indices=backend_indices, + native_function_decl_gen=dest.compute_native_function_declaration, + ), + }, + **headers, + ), + ) + else: + ns_grouped_kernels = get_ns_grouped_kernels( + native_functions=native_functions, + kernel_index=kernel_index, + native_function_decl_gen=compute_native_function_declaration, # type: ignore[arg-type] + ) + cpu_fm.write( + "NativeFunctions.h", + lambda: dict( + { + "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels( + ns_grouped_kernels=ns_grouped_kernels, + ), + }, + **headers, + ), + ) + + +def gen_custom_ops( + *, + native_functions: Sequence[NativeFunction], + selector: SelectiveBuilder, + kernel_index: ETKernelIndex, + cpu_fm: FileManager, + rocm: bool, +) -> None: + dispatch_key = DispatchKey.CPU + ( + anonymous_definition, + static_init_dispatch_registrations, + ) = gen_custom_ops_registration( + native_functions=native_functions, + selector=selector, + kernel_index=kernel_index, + rocm=rocm, + ) + cpu_fm.write_with_template( + f"Register{dispatch_key}CustomOps.cpp", + "RegisterDispatchKeyCustomOps.cpp", + lambda: { + "ops_headers": '#include "CustomOpsNativeFunctions.h"', + "DispatchKey": dispatch_key, + "dispatch_namespace": dispatch_key.lower(), + "dispatch_namespaced_definitions": "", + "dispatch_anonymous_definitions": anonymous_definition, + "static_init_dispatch_registrations": static_init_dispatch_registrations, + }, + ) + cpu_fm.write_with_template( + f"Register{dispatch_key}Stub.cpp", + "RegisterDispatchKeyCustomOps.cpp", + lambda: { + "ops_headers": "", + "DispatchKey": dispatch_key, + "dispatch_namespace": dispatch_key.lower(), + "dispatch_namespaced_definitions": "", + "dispatch_anonymous_definitions": list( + mapMaybe(ComputeNativeFunctionStub(), native_functions) + ), + "static_init_dispatch_registrations": static_init_dispatch_registrations, + }, + ) + + ( + 
aten_schema_registrations, + schema_registrations, + ) = get_native_function_schema_registrations( + native_functions=native_functions, + schema_selector=selector, + ) + cpu_fm.write( + "RegisterSchema.cpp", + lambda: { + "schema_registrations": schema_registrations, + "aten_schema_registrations": aten_schema_registrations, + }, + ) + + +def translate_native_yaml( + tags_yaml_path: str, + aten_yaml_path: str, + native_yaml_path: str | None, + use_aten_lib: bool, + out_file: TextIO, +) -> None: + """Translates Executorch DSL dialect to use the same syntax as + native_functions.yaml. The major difference is that Executorch DSL dialect + supports "op" key, where it refers to the operator name in native_functions.yaml. + + For example, a functions.yaml may have the following entry: + + - op: add.out + ... + + It needs to be translated to the following: + + - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + ... + + We go in aten_yaml_path and find the operator schema for "add.out" and add it + to the original functions.yaml. We also add required field "variants", where for + Executorch it will always be "function". + + For ATen mode we don't have to do the translation because native_yaml_path is + the same as native_functions.yaml. + + Args: + tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing. + It is not optional. + aten_yaml_path: Path to ATen operator yaml file native_functions.yaml. + native_yaml_path: Path to a functions.yaml file to parse. + If the path does not exist in the filesystem, it is treated as an + empty file. If `custom_ops_yaml_path` exists, the contents of that + file are appended to the yaml input to be parsed. + use_aten_lib: We use this flag to determine if we want to generate native + functions. In ATen mode we should generate out= variants. + out_file: The IO object that we are writing into. 
+ Returns: + None + """ + if use_aten_lib: + with open(aten_yaml_path) as aten_yaml: + out_file.writelines(aten_yaml.readlines()) + return + + native_functions, persisted_fields = parse_et_yaml( + aten_yaml_path, + tags_yaml_path, + None, + skip_native_fns_gen=False, + ) + + func_to_scoped_name: dict[FunctionSchema, str] = { + f.func: f"{f.namespace}::{f.func.name}" for f in native_functions + } + op_to_scoped_name: dict[OperatorName, str] = { + func.name: name for func, name in func_to_scoped_name.items() + } + + schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()} + kernel_persist_dict: dict[str, dict[str, Any]] = { + op_to_scoped_name[op]: v for op, v in persisted_fields.items() + } + + if ( + not native_yaml_path + or not os.path.exists(native_yaml_path) + or os.stat(native_yaml_path).st_size == 0 + ): + return + with open(native_yaml_path) as native_yaml: + native_es = yaml.load(native_yaml, Loader=LineLoader) + if not native_es: + return + for e in native_es: + assert isinstance(e.get("__line__"), int), e + loc = Location(native_yaml_path, e.pop("__line__")) + with context(lambda: f"in {loc}:\n "): + if "variants" not in e: + e["variants"] = "function" + if "func" in e: + continue + assert isinstance(e.get("op"), str), e + opname = e.pop("op") + if "::" not in opname: + opname = "aten::" + opname + assert opname in schema_dict + e["func"] = schema_dict.get(opname) + + # Write out persisted kernel information + if opname in kernel_persist_dict: + for k, v in kernel_persist_dict[opname].items(): + e[k] = v + + yaml.dump(native_es, out_file, width=1000) + + +def parse_yaml( + path: str | None, + tags_yaml_path: str, + function_filter: Callable[[NativeFunction], bool], + skip_native_fns_gen: bool = False, +) -> tuple[ + list[NativeFunction], + dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex, +]: + if path and os.path.exists(path) and os.stat(path).st_size > 0: + with open(path) as f: + es = yaml.load(f, Loader=LineLoader) + + # Check for kernel index structure + kernel_index = ( + parse_et_yaml_struct(es) if any("kernels" in e for e in es) else None + ) + + # Remove ET specific fields from entries for BC compatibility + for entry in es: + for field in ET_FIELDS: + entry.pop(field, None) + + parsed_yaml = parse_native_yaml( + path, + tags_yaml_path, + None, + skip_native_fns_gen=skip_native_fns_gen, + loaded_yaml=es, + ) + native_functions = list(filter(function_filter, parsed_yaml.native_functions)) + op_names = [f.func.name for f in native_functions] + + # (1) Return ETKernelIndex if kernel index is present + if kernel_index is not None: + filtered_index = { + op_name: kernel_mapping + for op_name, kernel_mapping in kernel_index.index.items() + if op_name in op_names + } + return native_functions, ETKernelIndex(index=filtered_index) + + # (2) Return BackendIndices if kernel index is absent + def map_index( + m: dict[OperatorName, BackendMetadata] + ) -> dict[OperatorName, BackendMetadata]: + return {op: m[op] for op in m if op in op_names} + + backend_indices = { + k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items() + } + + return native_functions, backend_indices + else: + return [], {} + + +def parse_yaml_files( + tags_yaml_path: str, + aten_yaml_path: str, + native_yaml_path: str | None, + custom_ops_yaml_path: str | None, + selector: SelectiveBuilder, + use_aten_lib: bool, +) -> tuple[ETParsedYaml, ETParsedYaml | None]: + """Parses functions.yaml and custom_ops.yaml files. 
+ + Args: + tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing. + It is not optional. + aten_yaml_path: Path to ATen operator yaml file native_functions.yaml. + native_yaml_path: Path to a functions.yaml file to parse. + If the path does not exist in the filesystem, it is treated as an + empty file. If `custom_ops_yaml_path` exists, the contents of that + file are appended to the yaml input to be parsed. + custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If + the path does not exist in the filesystem, it is ignored. + selector: For selective build. + use_aten_lib: We use this flag to determine if we want to generate native + functions. In ATen mode we should generate out= variants. + Returns: + A tuple with two elements: + [0]: The parsed results of concatenating the contents of + `native_yaml_path` and `custom_ops_yaml_path`. + [1]: The parsed results of the contents of `custom_ops_yaml_path`, if + present. If not present, None. + """ + import tempfile + + # only include selected ops, this is because we want to avoid + def function_filter(f: NativeFunction) -> bool: + return selector.is_native_function_selected(f) + + with tempfile.TemporaryDirectory() as tmpdirname: + translated_yaml_path = os.path.join(tmpdirname, "translated.yaml") + with open(translated_yaml_path, "w") as translated: + translate_native_yaml( + tags_yaml_path, + aten_yaml_path, + native_yaml_path, + use_aten_lib, + translated, + ) + + translated_functions, translated_indices = parse_yaml( + translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib + ) + custom_ops_functions, custom_ops_indices = parse_yaml( + custom_ops_yaml_path, tags_yaml_path, function_filter, True + ) + + # Convert BackendIndices to ETKernelIndex + if not isinstance(translated_indices, ETKernelIndex): + translated_indices = ETKernelIndex.from_backend_indices(translated_indices) + if not isinstance(custom_ops_indices, ETKernelIndex): + custom_ops_indices = ETKernelIndex.from_backend_indices(custom_ops_indices) + + combined_functions = translated_functions + custom_ops_functions + combined_kernel_index = ETKernelIndex.merge_indices( + translated_indices, custom_ops_indices + ) + combined_yaml = ETParsedYaml(combined_functions, combined_kernel_index) + custom_ops_parsed_yaml = ETParsedYaml(custom_ops_functions, custom_ops_indices) + + return combined_yaml, custom_ops_parsed_yaml + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate operator source files") + # Although we don't refer to --source-path directly, make_file_manager() + # expects it to point to a directory that contains a templates/ subdirectory + # containing the file templates. + parser.add_argument( + "-s", + "--source-path", + help="path to source directory for kernel templates", + ) + parser.add_argument( + "--functions-yaml-path", + "--functions_yaml_path", + help="path to the functions.yaml file to use. Optional, but at least " + "one of --functions-yaml-path and --custom-ops-yaml-path must be " + "specified.", + ) + parser.add_argument( + "--custom-ops-yaml-path", + "--custom_ops_yaml_path", + help="path to the custom_ops.yaml file to use. Optional, but at least " + "one of --functions-yaml-path and --custom-ops-yaml-path must be " + "specified.", + ) + parser.add_argument( + "--aten-yaml-path", + "--aten_yaml_path", + help="path to native_functions.yaml file.", + ) + # Note that make_file_manager() also looks at --install-dir. 
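As context for the yaml-related flags defined here, the rewrite performed by `translate_native_yaml` earlier can be pictured as the before/after pair below. The `func` schema is the one quoted in that function's docstring; the `kernels` block is an illustrative Executorch-dialect persisted field, not taken from a real kernel registry.

```python
# Before: Executorch DSL entry keyed by "op" (kernel_name is illustrative).
before = """\
- op: add.out
  kernels:
  - arg_meta: null
    kernel_name: torch::executor::add_out
"""

# After: the schema is looked up in native_functions.yaml and inlined as
# "func", and the required "variants: function" field is added; persisted
# kernel fields are carried over unchanged.
after = """\
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  variants: function
  kernels:
  - arg_meta: null
    kernel_name: torch::executor::add_out
"""
```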
+ parser.add_argument( + "-d", + "--install-dir", + "--install_dir", + help="output directory", + default="build/generated", + ) + parser.add_argument( + "-o", + "--output-dependencies", + help="output a list of dependencies into the given file and exit", + ) + # Although we don't refer to --dry-run directly, make_file_manager() looks + # for it. + parser.add_argument( + "--dry-run", + action="store_true", + help="run without writing any files (still updates outputs)", + ) + parser.add_argument( + "--static-dispatch-backend", + "--static_dispatch_backend", + nargs="*", + help="generate static dispatch code for the specific backend (if set)", + ) + parser.add_argument( + "--op-registration-whitelist", + "--op_registration_whitelist", + nargs="*", + help="filter op registrations by the whitelist (if set); " + "each item is `namespace`::`operator name` without overload name; " + "e.g.: aten::empty aten::conv2d ...", + ) + parser.add_argument( + "--op-selection-yaml-path", + "--op_selection_yaml_path", + help="Provide a path to the operator selection (for custom build) YAML " + "that contains the information about the set of selected operators " + "and their categories (training, ...). Each operator is either a " + "full operator name with overload or just a bare operator name. " + "The operator names also contain the namespace prefix (e.g. aten::)", + ) + parser.add_argument( + "--tags-path", + help="Path to tags.yaml. Required by yaml parsing in codegen system.", + ) + parser.add_argument( + "--rocm", + action="store_true", + help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly", + ) + parser.add_argument( + "--use-aten-lib", + "--use_aten_lib", + action="store_true", + help="a boolean flag to indicate whether we use ATen kernels or not, in the future this flag will be per " + "operator", + ) + parser.add_argument( + "--manual_registration", + "--manual-registration", + action="store_true", + help="a boolean flag to indicate whether we want to manually call" + "register_kernels() or rely on static init. ", + ) + parser.add_argument( + "--generate", + type=str, + nargs="*", + choices=["headers", "sources"], + default=["headers", "sources"], + help="Generate only a subset of files", + ) + options = parser.parse_args() + assert options.tags_path, "tags.yaml is required by codegen yaml parsing." + + selector = get_custom_build_selector( + options.op_registration_whitelist, + options.op_selection_yaml_path, + ) + + parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( + aten_yaml_path=options.aten_yaml_path, + tags_yaml_path=options.tags_path, + native_yaml_path=options.functions_yaml_path, + custom_ops_yaml_path=options.custom_ops_yaml_path, + selector=selector, + use_aten_lib=options.use_aten_lib, + ) + native_functions, kernel_index = ( + parsed_yaml.native_functions, + parsed_yaml.kernel_index, + ) + custom_ops_native_functions = ( + custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else [] + ) + + cpu_fm = make_file_manager(options=options) + + if "headers" in options.generate: + # generate CustomOpsNativeFunctions.h when custom_ops.yaml is present, to match the build system. 
+        gen_headers(
+            native_functions=native_functions,
+            gen_custom_ops_header=options.custom_ops_yaml_path,
+            custom_ops_native_functions=custom_ops_native_functions,
+            selector=selector,
+            kernel_index=kernel_index,
+            cpu_fm=cpu_fm,
+            use_aten_lib=options.use_aten_lib,
+        )
+
+    if "sources" in options.generate:
+        gen_unboxing(
+            native_functions=native_functions,
+            cpu_fm=cpu_fm,
+            selector=selector,
+            use_aten_lib=options.use_aten_lib,
+            kernel_index=kernel_index,
+            manual_registration=options.manual_registration,
+        )
+        if custom_ops_native_functions:
+            gen_custom_ops(
+                native_functions=custom_ops_native_functions,
+                selector=selector,
+                kernel_index=kernel_index,
+                cpu_fm=cpu_fm,
+                rocm=options.rocm,
+            )
+
+    if options.output_dependencies:
+        depfile_path = Path(options.output_dependencies).resolve()
+        depfile_name = depfile_path.name
+        depfile_stem = depfile_path.stem
+
+        for fm, prefix in [
+            (cpu_fm, ""),
+        ]:
+            varname = prefix + depfile_stem
+            path = depfile_path.parent / (prefix + depfile_name)
+            fm.write_outputs(varname, str(path))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.venv/lib/python3.11/site-packages/torchgen/gen_functionalization_type.py b/.venv/lib/python3.11/site-packages/torchgen/gen_functionalization_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbc9459eb5e6499ad74fd754be1cfde137c5ea65
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchgen/gen_functionalization_type.py
@@ -0,0 +1,882 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable, TYPE_CHECKING
+
+from torchgen.api import cpp, dispatcher
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+    BaseCType,
+    Binding,
+    CType,
+    DispatcherSignature,
+    FunctionalizationLambda,
+    iTensorListRefT,
+    NativeSignature,
+    OptionalCType,
+    optionalSymIntArrayRefT,
+    symIntArrayRefT,
+    SymIntT,
+    tensorListT,
+    tensorT,
+    VectorCType,
+    ViewInverseSignature,
+)
+from torchgen.context import (
+    method_with_native_function,
+    native_function_manager,
+    with_native_function,
+    with_native_function_and,
+)
+from torchgen.model import (
+    Argument,
+    BackendIndex,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    NativeFunction,
+    NativeFunctionsGroup,
+    NativeFunctionsViewGroup,
+    Return,
+    SchemaKind,
+    SelfArgument,
+    TensorOptionsArguments,
+)
+from torchgen.native_function_generation import (
+    INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
+    MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
+    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
+)
+from torchgen.utils import dataclass_repr
+
+
+if TYPE_CHECKING:
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
+# Note: [Mutable Ops Not Using Functionalization]
+# Ops in this list currently do not work with functionalization and should be fixed.
+MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
+    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+    + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+    + INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+    + [
+        # It will be BC-breaking, but we should fix their schemas.
+        # should be inplace?
+        "record_stream",
+        # See Note [resize_ in Functionalization]
+        "resize_",
+        "resize_as_",
+        # This function is used for testing purposes only.
+        "_fill_mem_eff_dropout_mask_",
+    ]
+)
+
+# This file contains codegen that relates to the functionalization pass.
+# It includes:
+# - gen_functionalization_definition
+#     Generates dispatcher kernel definitions for the functionalization pass.
+# - gen_functionalization_registration
+#     Generates dispatcher kernel registrations for the functionalization pass.
+# - gen_functionalization_view_inverse_declaration
+#     Generates a declaration for an "inverse view", for every view op
+#     that is needed in functionalization. We manually implement their definitions.
+# - gen_composite_view_copy_kernel
+#     Generates view_copy() composite kernels for all view_copy operators.
+
+
+# Generates the body of the default composite C++ kernel for a {view}_copy NativeFunction
+# See Note [view_copy NativeFunctions]
+@dataclass(frozen=True)
+class GenCompositeViewCopyKernel:
+    backend_index: BackendIndex
+
+    @method_with_native_function
+    def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
+        if g.view_copy is None:
+            return None
+        elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
+            # If the view_copy doesn't match the standard naming scheme of <op>_copy,
+            # assume it already exists and doesn't need to be generated.
+            # Example: slice_inverse() with the copy variant named slice_scatter()
+            # instead of slice_inverse_copy()
+            return None
+
+        metadata = self.backend_index.get_kernel(g.view_copy)
+        assert metadata is not None
+
+        # We can make view_copy work in more cases by using reshape()
+        # when a normal view call would ordinarily fail.
+        # This also makes LTC more efficient, because backends don't need to include
+        # clone() calls in their graph (which is normally needed by reshape).
+        if str(g.view_copy.func.name) == "view_copy":
+            assert metadata.kernel == "view_copy_symint"
+            return """\
+at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
+  c10::SymDimVector shape = infer_size_dv(size, self.sym_numel());
+  if (!at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape).has_value()) {
+    return self.reshape_symint(size);
+  } else {
+    auto output = at::_ops::view::call(self, size);
+    return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
+  }
+}
+"""
+        # view_copy is a native signature, since we're generating an at::native:: kernel
+        # Functionalization always operates on symints though
+        view_copy_sig = NativeSignature(
+            g.view_copy.func, symint=metadata.supports_symint()
+        )
+
+        # view is a dispatcher signature, since we're calling into the at::_ops API
+        view_sig = DispatcherSignature(g.view.func)
+
+        view_api_name = g.view.func.name.unambiguous_name()
+        exprs = ", ".join(
+            [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
+        )
+
+        # view ops today always return either a Tensor or a list of Tensors
+        assert len(g.view.func.returns) == 1
+        assert g.view.func.returns[0].type == BaseType(
+            BaseTy.Tensor
+        ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)
+
+        if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
+            return_cloned_output = """\
+      return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);"""
+        else:
+            # If the return type is a list, we need to clone each tensor in the list.
+            return_cloned_output = f"""\
+      {view_copy_sig.returns_type().cpp_type()} out_clone;
+      for (const auto i : c10::irange(output.size())) {{
+        out_clone.push_back(output[i].clone(/*memory_format=*/at::MemoryFormat::Contiguous));
+      }}
+      return out_clone;"""
+
+        # The default generated composite kernel for {view}_copy() operators just clones
+        # the input tensor, and runs the underlying view on the clone.
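+        # Illustrative sketch of the kernel this emits (the op and argument
+        # names are hypothetical; the real ones come from the parsed schema):
+        #
+        #   at::Tensor diagonal_copy(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2) {
+        #     auto output = at::_ops::diagonal::call(self, offset, dim1, dim2);
+        #     return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
+        #   }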
+        return f"""
+{view_copy_sig.defn(name=metadata.kernel)} {{
+  auto output = at::_ops::{view_api_name}::call({exprs});
+  {return_cloned_output}
+}}
+"""
+
+
+def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
+    assert len(rets) == len(names)
+    if len(rets) == 0:
+        return ""
+    elif len(rets) == 1:
+        return f"return {names[0]};"
+    else:
+        return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
+
+
+def modifies_arguments(f: NativeFunction) -> bool:
+    return any(
+        a.annotation is not None and a.annotation.is_write
+        for a in f.func.arguments.flat_all
+    )
+
+
+def wrapper_name(func: FunctionSchema) -> str:
+    if func.name.overload_name:
+        return f"{cpp.name(func)}_{func.name.overload_name}"
+    else:
+        return cpp.name(func)
+
+
+def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
+    return isinstance(a, SelfArgument) or (
+        isinstance(a, Argument) and a.type.is_tensor_like()
+    )
+
+
+# We need to wrap / unwrap various arguments from the op in the functionalization kernels.
+# Some op schemas include non-owning types though (like TensorList),
+# and when we unwrap them we expect to get out an owning type!
+# We also return a lambda that tells you how to convert the non-owning type argument into the owning type.
+def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
+    if t == BaseCType(tensorListT):
+        return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
+    if t == BaseCType(iTensorListRefT):
+        return VectorCType(BaseCType(tensorT)), lambda x: f"{{{x}.begin(), {x}.end()}}"
+    # There are technically other non-owning types out there (like IntArrayRef),
+    # but functionalization only actually cares about the ones involving tensors.
+    return t, lambda x: x
+
+
+# unwraps all tensor-like arguments, returning:
+# (1) a string containing all of the logic that does the unwrapping
+# (2) a context, to be used by translate(), with all of the relevant bindings.
+def unwrap_tensor_args(
+    sig: DispatcherSignature, *, is_view_op: bool
+) -> tuple[str, list[Binding]]:
+    context: list[Binding] = []
+    unwrapped_tensor_args: list[str] = []
+    for arg in sig.arguments():
+        if is_tensor_like(arg.argument):
+            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
+            unwrapped_name = f"{arg.name}_"
+            # For most ops, the functionalization pass needs to sync any pending updates on the input tensors
+            # before calling the operator, since otherwise the operator will act on stale data.
+            # For view ops though, we can continue to defer syncing until the tensor is used by
+            # a non-view operator.
+            maybe_sync_input = (
+                "" if is_view_op else f"at::functionalization::impl::sync({arg.name});"
+            )
+            unwrapped_type, conversion_fn = get_owning_type(
+                arg.nctype.remove_const_ref().type
+            )
+            unwrapped_tensor_args.append(
+                f"""
+      {unwrapped_type.cpp_type()} {unwrapped_name};
+      if (at::functionalization::impl::isFunctionalTensor({arg.name})) {{
+        {maybe_sync_input}
+        {unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});
+      }} else {{
+        {unwrapped_name} = {conversion_fn(arg.name)};
+      }}"""
+            )
+            context.append(arg.with_name(unwrapped_name))
+        else:
+            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
+            context.append(arg)
+    unwrap_tensor_args_str = "\n      ".join(unwrapped_tensor_args)
+    return unwrap_tensor_args_str, context
+
+
+# converts all tensor-like arguments to meta tensors, which are used to compute stride info.
+# Returns:
+# (1) a string containing all of the logic that does the conversions.
+# (2) a context, to be used by translate(), with all of the relevant bindings.
+def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
+    context: list[Binding] = []
+    unwrapped_tensor_args: list[str] = []
+    for arg in sig.arguments():
+        if is_tensor_like(arg.argument):
+            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
+            a_ = arg.name
+            unwrapped_name = f"{arg.name}_meta"
+            unwrapped_tensor_args.append(f"auto {unwrapped_name} = to_meta({a_});")
+            context.append(arg.with_name(unwrapped_name))
+        else:
+            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
+            context.append(arg)
+    unwrap_tensor_args_str = "\n    ".join(unwrapped_tensor_args)
+    return unwrap_tensor_args_str, context
+
+
+# The functionalization codegen currently expects view op schemas to have this form:
+# foo(Tensor(a), ...) -> Tensor(a) (e.g. transpose)
+# foo(Tensor(a!), ...) -> Tensor(a!) (e.g. transpose_)
+def assert_view_op_properties(func: FunctionSchema) -> None:
+    def is_alias(a: Argument) -> bool:
+        return a.annotation is not None
+
+    args = func.arguments.flat_non_out
+    # The first argument is a tensor with alias semantics (annotations)
+    assert len(args) > 0 and args[0].type == BaseType(
+        BaseTy.Tensor
+    ), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
+but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
+    # No other arguments have aliasing semantics
+    assert is_alias(args[0]) and not any(
+        is_alias(a) for a in args[1:]
+    ), """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
+View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint."""
+
+
+# One-liner expression for checking whether an expression `expr` of type `type`
+# has any symbolic values.
+def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
+    if type == BaseCType(SymIntT):
+        return f"{expr}.is_symbolic()"
+
+    if isinstance(type, OptionalCType):
+        innerexpr = f"(*{expr})"
+        return f"{expr}.has_value() ? {emit_expr_has_symbolic_values(innerexpr, type.elem)} : false"
+
+    if type == BaseCType(optionalSymIntArrayRefT):
+        return emit_expr_has_symbolic_values(
+            expr, OptionalCType(BaseCType(symIntArrayRefT))
+        )
+
+    if type in (BaseCType(symIntArrayRefT), VectorCType(BaseCType(SymIntT))):
+        argname = "arg"
+        lambda_check = emit_expr_has_symbolic_values(argname, BaseCType(SymIntT))
+        return (
+            "std::any_of("
+            f"{expr}.begin(), {expr}.end(), "
+            f"[=](auto& {argname}) {{ return {lambda_check}; }})"
+        )
+
+    raise ValueError(
+        "unsupported type for has_symbolic_values check. "
+        "It should be a SymInt or a collection of those. "
+        f"Got: {type.cpp_type()}"
+    )
+
+
+# Detects whether any of the SymInt arguments are, in fact, symbolic values.
+# This is used in the constructor of ViewMeta.
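+# For illustration (the argument name `size` is hypothetical): given a
+# `SymIntArrayRef size` argument, emit_expr_has_symbolic_values above produces
+# the C++ expression
+#   std::any_of(size.begin(), size.end(), [=](auto& arg) { return arg.is_symbolic(); })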
+def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
+    name = "has_symbolic_inputs"
+    statements = [
+        f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
+        for binding in sig.arguments()
+        if (
+            isinstance(binding.argument, Argument)
+            and binding.argument.type.is_symint_like()
+        )
+    ]
+    body = "\n      ".join(statements)
+    return (
+        name,
+        f"""
+      bool {name} = false;
+      {body}""",
+    )
+
+
+# Generates the Functionalization kernel for:
+# - ops that create aliases (e.g. transpose())
+# - ops that are views AND mutations (e.g. transpose_())
+def emit_view_functionalization_body(
+    g: NativeFunctionsViewGroup, *, view_inplace: bool
+) -> str:
+    if view_inplace:
+        # This op is both an inplace op AND a view op.
+        # See Note [Functionalization Pass - Inplace View Ops] for details.
+        # I currently have the view meta call into the out-of-place variant of the view, to avoid
+        # having to define an extra ~20 inplace {view}_inverse_ functions.
+        # Most view ops don't have both variants in a NativeFunctionsGroup,
+        # because we don't define out= variants for view ops.
+        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
+        # with the same name but the trailing underscore removed.
+        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
+        assert g.view_inplace is not None
+        f = g.view_inplace
+    else:
+        f = g.view
+
+    assert g.view_copy is not None
+    with native_function_manager(f):
+        call_sig = DispatcherSignature.from_schema(g.view_copy.func)
+
+        # the "view_copy" op name that the functionalization kernels need to call
+        api_name = g.view_copy.func.name.unambiguous_name()
+        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors);
+        # "no-op"ing in this context is just redispatching to the original op.
+        noop_api_name = f.func.name.unambiguous_name()
+
+        dispatcher_sig = DispatcherSignature.from_schema(f.func)
+        assert_view_op_properties(f.func)
+        view_tensor_name = dispatcher_sig.arguments()[0].name
+
+        return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()
+
+        unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
+            dispatcher_sig, is_view_op=True
+        )
+        view_redispatch_args = [
+            e.expr
+            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
+        ]
+
+        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
+        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
+
+        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
+        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
+        meta_call_args = [
+            e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
+        ]
+
+        (
+            symbolic_inputs_varname,
+            symbolic_inputs_check,
+        ) = emit_has_symbolic_inputs(call_sig)
+
+        if "inplace_view" in f.tags:
+            # See Note [Functionalization Pass - Inplace View Ops] for more details
+            return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
+      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
+        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
+        {unwrap_tensor_args_str}
+        at::AutoDispatchSkipFunctionalize guard;
+        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
+      }}
+      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
+      auto inverse_return_mode = (
+          reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
+                        : at::functionalization::InverseReturnMode::NeverView
+      );
+      {symbolic_inputs_check}
+      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
+        {forward_lambda.decl()} {{
+          if (reapply_views) {{
+            return {forward_lambda.inner_call(reapply_views=True)}
+          }} else {{
+            return {forward_lambda.inner_call(reapply_views=False)}
+          }}
+        }},
+        {reverse_lambda.decl()} {{
+          return {reverse_lambda.inner_call()}
+        }},
+        /*has_symbolic_inputs=*/{symbolic_inputs_varname}
+      );
+      auto compute_reference_meta =
+        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
+        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
+      {return_type} reference_tensor_output;
+      if (compute_reference_meta) {{
+        {meta_conversion_str}
+        at::AutoDispatchSkipFunctionalize func_guard;
+        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
+        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
+      }}
+      // This function adds the above view meta to the current tensor and replays them off the base,
+      // mutating the size/stride info of the current FunctionalTensorWrapper.
+      // Because of this, we need to make sure to run the reference shape function above,
+      // BEFORE doing this (otherwise we'll end up running the reference function using the wrong sizes/strides)
+      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
+      // See Note [Propagating strides in the functionalization pass]
+      // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
+      // on a reference implementation here (instead of relying on the output from the forward lambda
+      // having the correct stride info)
+      if (compute_reference_meta) {{
+        at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
+      }}
+      return {view_tensor_name};
+    }}
+"""
+
+    else:
+        is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
+        return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
+      {unwrap_tensor_args_str}
+      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
+        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
+        at::AutoDispatchSkipFunctionalize guard;
+        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
+      }}
+      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
+      auto inverse_return_mode = (
+          reapply_views ?
at::functionalization::InverseReturnMode::ViewOrScatterInverse + : at::functionalization::InverseReturnMode::NeverView + ); + auto compute_reference_meta = + {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) || + {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit); + {return_type} reference_tensor_output; + if (compute_reference_meta) {{ + {meta_conversion_str} + at::AutoDispatchSkipFunctionalize func_guard; + c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch); + reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)}); + }} + {return_type} tmp_output; + {{ + at::AutoDispatchSkipFunctionalize guard; + if (reapply_views) {{ + tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)}); + }} else {{ + tmp_output = at::_ops::{api_name}::call({', '.join(view_redispatch_args)}); + }} + }} + {symbolic_inputs_check} + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + {forward_lambda.decl()} {{ + if (reapply_views) {{ + return {forward_lambda.inner_call(reapply_views=True)} + }} else {{ + return {forward_lambda.inner_call(reapply_views=False)} + }} + }}, + {reverse_lambda.decl()} {{ + return {reverse_lambda.inner_call()} + }}, + /*has_symbolic_inputs=*/{symbolic_inputs_varname}, + /*is_multi_output=*/{str(is_multi_output_view).lower()}, + /*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()} + ); + auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta); + // See Note [Propagating strides in the functionalization pass] + if (compute_reference_meta) {{ + at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output); + }} + return out; + }} +""" + + +def maybe_create_output(f: NativeFunction, var_name: str) -> str: + if len(f.func.returns) == 0: + return "" + return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type() + return f"{return_type} {var_name} = " + + +# Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function, +# this returns two lists of names, consisting of: +# - the names of returns corresponding to the original (mutable) inputs of the outer function +# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function +def get_mutable_redispatch_return_names( + f: NativeFunction, inner_return_var: str +) -> tuple[list[str], list[str]]: + aliased_returns = [] + non_aliased_returns = [] + for i, name in enumerate(f.func.aliased_return_names()): + if name is not None: + aliased_returns.append(name) + else: + non_aliased_returns.append( + inner_return_var + if len(f.func.returns) == 1 + else f"std::get<{i}>({inner_return_var})" + ) + return aliased_returns, non_aliased_returns + + +# When functionalization "no-op's" and redispatches on a mutable operator, we need to take care so that: +# - For fresh outputs, we return the result of the redispatch (without wrapping outputs) +# - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped) +def return_from_mutable_noop_redispatch( + f: NativeFunction, inner_return_var: str +) -> str: + aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var) + # Just get all of the return names, and immediately return them + return return_str(f.func.returns, aliased + non_aliased) + + +def wrap_propagate_mutations_and_return( + f: 
NativeFunction, functional_op: NativeFunction, inner_return_var: str
+) -> str:
+    mutable_arg_names = f.func.arguments.mutable_arg_names()
+    (
+        aliased_outer_rets,
+        non_aliased_outer_rets,
+    ) = get_mutable_redispatch_return_names(f, inner_return_var)
+    _, non_aliased_inner_rets = get_mutable_redispatch_return_names(
+        functional_op, inner_return_var
+    )
+    # The outer function may have a mix of aliased and non-aliased outputs,
+    # but the inner functional op that we're transforming to should only have non-aliased outputs.
+    assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(
+        non_aliased_inner_rets
+    )
+
+    # First, take all of the newly created outputs from the inner call and wrap them into functional tensors
+    updates = []
+    non_aliased_wrapped_ret_names = []
+    for i, inner_ret in enumerate(
+        non_aliased_inner_rets[: len(non_aliased_outer_rets)]
+    ):
+        ret_name = f"output_{i}"
+        updates.append(
+            f"""\
+  auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});"""
+        )
+        non_aliased_wrapped_ret_names.append(ret_name)
+
+    # Next, take all of the mutated outputs from the inner call corresponding to mutated inputs,
+    # and propagate the mutations
+    for outer_arg, inner_ret in zip(
+        mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :]
+    ):
+        updates.append(
+            f"""\
+  auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
+  at::functionalization::impl::replace_({outer_arg}, {inner_ret});
+  at::functionalization::impl::commit_update({outer_arg});
+  at::functionalization::impl::sync({outer_arg});
+  auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
+  at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
+        )
+
+    # Finally, we return:
+    # - Any mutable arguments that are also returned
+    # - Any immutable returns that were created by wrapping the outputs from the inner call
+    returns_str = return_str(
+        f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names
+    )
+    updates_str = "\n".join(updates)
+    return f"""\
+{updates_str}
+    {returns_str}"""
+
+
+# Generates the Functionalization kernel for:
+# - mutation ops (inplace and out= ops)
+@with_native_function_and
+def emit_inplace_functionalization_body(
+    f: NativeFunction, g: NativeFunctionsGroup
+) -> str:
+    # mutation case
+    assert modifies_arguments(f)
+
+    dispatcher_sig = DispatcherSignature.from_schema(f.func)
+
+    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
+        dispatcher_sig, is_view_op=False
+    )
+
+    mutated_names = [
+        a.name
+        for a in f.func.arguments.flat_all
+        if a.type.is_tensor_like() and a.annotation is not None
+    ]
+    non_mutated_names = [
+        a.name
+        for a in f.func.arguments.flat_all
+        if a.type.is_tensor_like() and a.annotation is None
+    ]
+    non_mutated_tensor_names = [
+        a.name
+        for a in f.func.arguments.flat_all
+        if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
+    ]
+    # all mutable inputs must be functional tensors in order to participate in functionalization
+    check_all_mutated_args_are_functional = " && ".join(
+        ["true"]
+        + [
+            f"at::functionalization::impl::isFunctionalTensor({a})"
+            for a in mutated_names
+        ]
+    )
+    check_any_non_mutated_args_are_functional = " || ".join(
+        ["false"]
+        + [
+            f"at::functionalization::impl::isFunctionalTensor({a})"
+            for a in non_mutated_names
+        ]
+    )
+
+    check_any_non_mutated_tensors_are_xla = " || ".join(
+        ["false"]
+        + [
+            f"{a}.device().type() == c10::DeviceType::XLA"
+            for a in non_mutated_tensor_names
+        ]
+    )
+    # These are used in the cases where we don't functionalize and redispatch to the inplace op
+    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
+    # case 2: we hit an inplace op but our inputs are not functional tensors (in which case our kernel just no-ops)
+    inplace_exprs = [
+        e.expr
+        for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
+    ]
+
+    # call the out-of-place variant of the op
+    return_type = (
+        dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
+    )
+    functional_sig = DispatcherSignature.from_schema(g.functional.func)
+    functional_exprs = [
+        e.expr
+        for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
+    ]
+
+    if f.func.is_out_fn():
+        mutable_input_post_processing = "\n".join(
+            [
+                f"""
+      at::functionalization::impl::replace_(
+        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
+      at::functionalization::impl::commit_update({a.name});"""
+                for (i, a) in enumerate(f.func.arguments.out)
+                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
+            ]
+        )
+    else:
+        mutable_input_post_processing = "\n".join(
+            [
+                f"""
+      at::functionalization::impl::replace_({a.name}, tmp_output);
+      at::functionalization::impl::commit_update({a.name});"""
+                for a in f.func.arguments.flat_all
+                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
+            ]
+        )
+
+    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
+    # We don't want to run the inplace meta func for ops like .set_(), because:
+    # (1) they're unnecessary: inplace meta checks are only useful for ops like add_(),
+    #     where broadcasting will work for the out-of-place case but should fail on the inplace call
+    # (2) they'll also fail without adding extra infra: we'd need to convert the input storage argument
+    #     into a meta storage
+    any_storage_args = any(
+        a.type == BaseType(BaseTy.Storage) for a in f.func.arguments.flat_all
+    )
+
+    return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
+      if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
+        // Before converting the mutable op to its functional variant, run meta tensors through the original op.
+        // This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
+        // (We can only do this for inplace ops today though, because they technically all support meta tensors).
+        {meta_conversion_str}
+        at::AutoDispatchSkipFunctionalize func_guard;
+        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
+        at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
+      }}
+      {unwrap_tensor_args_str}
+      if (!({check_all_mutated_args_are_functional})) {{
+        // We want to disable this check if there are any XLA tensors.
+        // cpu_tensor.copy_(xla_tensor) is valid code.
+        if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
+          // case 1: trying to mutate a non functional tensor with a functional tensor is an error
+          TORCH_INTERNAL_ASSERT(false,
+            "mutating a non-functional tensor with a functional tensor is not allowed.",
+            " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
+        }} else {{
+          // case 2: arguments are not functional tensors, so we no-op and redispatch.
+          at::AutoDispatchSkipFunctionalize guard;
+          {maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
+          {return_from_mutable_noop_redispatch(f, 'tmp_output')}
+        }}
+      }} else {{
+        {return_type} tmp_output;
+        {{
+          at::AutoDispatchSkipFunctionalize guard;
+          tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
+        }}
+        {wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
+      }}
+    }}"""
+
+
+# The below functions generate RegisterFunctionalization.cpp
+# These files provide the kernels that run the functionalization pass, which can be opted into
+# per backend (e.g. XLA or Vulkan), or as a composable transform (functionalize() in functorch).
+
+
+# See Note [Functionalization Pass: View Inverses].
+def gen_functionalization_view_inverse_declaration(
+    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
+) -> str | None:
+    # For every (non-composite) view op, we need a corresponding "inverse view" function.
+    # This generates the declarations so we get a good compiler error when someone adds a new view.
+    @with_native_function
+    def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
+        if g.view.has_composite_implicit_autograd_kernel:
+            return None
+        view_inverse_sig = ViewInverseSignature(g)
+        return view_inverse_sig.decl()
+
+    return emit_decl_helper(g)
+
+
+def gen_functionalization_registration(
+    selector: SelectiveBuilder,
+    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+    composite_implicit_autograd_index: BackendIndex,
+) -> list[str]:
+    @with_native_function
+    def emit_registration_helper(f: NativeFunction) -> str:
+        assert not f.has_composite_implicit_autograd_kernel
+        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
+        return f'm.impl("{f.func.name}", {registration_str});'
+
+    # Don't generate kernels in mobile build
+    if not selector.include_all_operators:
+        return []
+
+    if isinstance(g, NativeFunctionsViewGroup):
+        # functionalization needs to register kernels for view + view_inplace ops
+        # See Note [Functionalization <> torch.Tensor constructor]
+        if str(g.view.func.name) == "lift_fresh":
+            return []
+        view_str = []
+        if not g.view.has_composite_implicit_autograd_kernel:
+            view_str.append(emit_registration_helper(g.view))
+        if (
+            g.view_inplace is not None
+            and not g.view_inplace.has_composite_implicit_autograd_kernel
+        ):
+            assert g.view_inplace.is_view_op
+            view_str.append(emit_registration_helper(g.view_inplace))
+        return view_str
+
+    elif isinstance(g, NativeFunctionsGroup):
+        # Gets a hand-written functionalization kernel
+        if g.inplace is not None and str(g.inplace.func.name) == "set_.source_Tensor":
+            fns = []
+        else:
+            fns = list(g.functions())
+    else:
+        if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
+            return []
+        fns = [g]
+
+    registrations = []
+    for f in fns:
+        if f.has_composite_implicit_autograd_kernel:
+            continue
+        if str(f.func.name) == "lift":
+            # See Note [Functionalization <> torch.Tensor constructor]
+            return []
+        if str(f.func.name) == "resize_":
+            # See Note [resize_ in Functionalization]
+            return []
+        if str(f.func.name.name) != "set_":
+            assert not f.is_view_op
+        # functionalization needs to generate and register kernels for inplace ops.
+        # We *also* need to directly register CompositeImplicitAutograd kernels
+        # so that they decompose properly before functionalization.
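+        # Illustrative (hypothetical op): for a mutable op such as add_.Tensor,
+        # emit_registration_helper above emits roughly
+        #   m.impl("add_.Tensor", TORCH_FN(functionalization::add__Tensor));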
+ if modifies_arguments(f): + registrations.append(emit_registration_helper(f)) + return registrations + + +def gen_functionalization_definition( + selector: SelectiveBuilder, + # Note: Ideally this code should never have to look at NativeFunction + # (and instead only need to operate on grouped NativeFunctions). + # The only reason currently is because we need to emit direct dispatch registrations + # For CompositeImplicitAutograd operators, which are potentially ungrouped. + g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, +) -> list[str]: + # Don't generate kernels in mobile build + if not selector.include_all_operators: + return [] + + if isinstance(g, NativeFunctionsViewGroup): + # Case 1: emit view -> view_copy kernels for the functionalization pass + view_defs = [] + if not g.composite: + # invariant: NativeFunctionsViewGroup's always have a view_copy operator + # if the view is not composite (implicit autograd) + assert g.view_copy is not None, dataclass_repr(g, indent=1) + view_defs.append(emit_view_functionalization_body(g, view_inplace=False)) + if g.view_inplace is not None: + view_defs.append(emit_view_functionalization_body(g, view_inplace=True)) + return view_defs + elif isinstance(g, NativeFunction): + # Invariant: all mutable operators that we need to handle in functionalization + # should have been properly grouped up. + # TODO: The below ops all have "problematic" schemas that prevent them from + # getting functionalized. Instead of bending over backwards to get things to work, + # I think we should either: + # (1) fix their schemas (BC-breaking) + # (2) hand-write their functionalization kernels + if ( + str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION + and str(g.func.name.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION + ): + assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g) + return [] + else: + # Case 2: emit inplace -> out-of-place kernels for the functionalization pass + mutation_defs = [] + mutation_defs.append(emit_inplace_functionalization_body(g.out, g)) + if g.inplace is not None: + mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g)) + if g.mutable is not None: + mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g)) + return mutation_defs + return [] diff --git a/.venv/lib/python3.11/site-packages/torchgen/gen_lazy_tensor.py b/.venv/lib/python3.11/site-packages/torchgen/gen_lazy_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..884f645cc4b5b13e9a9c05b6e0e7dcc19bb59ec9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/gen_lazy_tensor.py @@ -0,0 +1,581 @@ +from __future__ import annotations + +import argparse +import os +from collections import namedtuple +from pathlib import Path +from typing import Any, Callable, Iterable, Iterator, Sequence + +import yaml + +import torchgen.dest as dest +from torchgen.api.lazy import setValueT +from torchgen.api.types import BaseCppType +from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR +from torchgen.gen import get_grouped_native_functions, parse_native_yaml +from torchgen.gen_backend_stubs import ( + error_on_missing_kernels, + gen_dispatcher_registrations, + gen_dispatchkey_nativefunc_headers, + parse_backend_yaml, +) +from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName +from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import FileManager, NamespaceHelper +from torchgen.yaml_utils 
import YamlLoader
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                        Lazy Tensor Codegen
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+# Overview
+# ~~~~~~~~
+#
+# This codegen script builds on existing data models and helpers used
+# by all ATen backends, and adds new functionality specific to lazy
+# tensor backends.
+#
+# Inputs:
+# - <backend>_native_functions.yaml: controls which operators are
+#   supported by the backend.
+#
+# Outputs:
+# (for all backends)
+# <DispatchKey>Ir.h defines Lazy IR classes to be constructed during tracing
+# - opt-in: also generate 'lowering' methods for the TorchScript backend only
+# <DispatchKey>NativeFunctions.cpp defines implementations of native functions which perform lazy tracing
+# - opt-in: 'full_codegen' section of backend yaml; 'supported' section omits these implementations
+# <DispatchKey>NativeFunctions.h declares implementations of native functions for both 'supported' and 'full_codegen'
+# ops
+#
+# Register<DispatchKey>.cpp registers all op implementations with the dispatcher
+# RegisterAutograd<DispatchKey>.cpp registers all autograd implementations with the dispatcher
+#
+# Validation Helpers:
+# - Shape Inference: errs if any ops in backend yaml require shape inference not provided by meta kernels or
+#   implementations in torch/csrc/lazy/core/shape_inference.*
+# - native function impls: errs if any 'supported' ops do not have an implementation defined in the backend
+#   (non-codegen) implementation file
+#
+#
+# About the Data Model
+# ~~~~~~~~~~~~~~~~~~~~
+#
+# Modeled after ATen codegen, the first step is to parse yaml and build a data model for the operators
+# we care about. In this case, the <backend>_native_functions yaml defines a subset of the core operators
+# (defined in more detail in the main native_functions.yaml), which will be supported by your backend.
+# Backends can list ops in two categories:
+# - `supported` ops require hand-implementations but still get codegenned declarations and registrations
+# - `full_codegen` ops get implementations (and IR classes) generated too
+#
+# Each native function is modeled as an object with a schema, and each schema has objects representing their
+# arguments. Much of the codegen is manipulation of the arguments and their types. For example, lazy tensor
+# backends need to transform 'at::Tensor' arguments into 'lazy::Value' objects, as well as replacing reference
+# types (stringref) with actual string objects, and this is done by manipulating the data model objects.
+# - see api/lazy.py for the lazy data model
+#
+# Once the data model is set up, the rest of this script processes a number of templates for the output CPP files
+# and fills in the template values using helpers in `dest/lazy_ir.py` and `dest/lazy_ts_lowering.py`. These
+# helpers mostly iterate over functions and their arguments, outputting different C++ snippets.
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
+# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping, full_codegen).
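+# Illustrative backend yaml (structure inferred from the overview above and the
+# parsing code below; the backend name, namespace, and op names are hypothetical):
+#
+#   backend: XLA
+#   cpp_namespace: torch_xla
+#   supported:
+#   - empty.memory_format
+#   full_codegen:
+#   - abs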
+ParsedExternalYaml = namedtuple(
+    "ParsedExternalYaml",
+    ["backend_key", "autograd_key", "cpp_namespace", "backend_indices", "full_codegen"],
+)
+
+
+def parse_native_functions_keys(
+    backend_yaml_path: str,
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
+    with open(backend_yaml_path) as f:
+        yaml_values = yaml.load(f, Loader=YamlLoader)
+    assert isinstance(yaml_values, dict)
+
+    full_codegen = yaml_values.pop("full_codegen", [])
+    non_native = yaml_values.pop("non_native", [])
+    ir_gen = yaml_values.pop("ir_gen", [])
+    assert isinstance(full_codegen, list)
+    assert isinstance(non_native, list)
+    assert isinstance(ir_gen, list)
+    full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
+    ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
+    return full_codegen_opnames, non_native, ir_gen_opnames
+
+
+def validate_shape_inference_header(
+    shape_inference_hdr: str, expected_shape_infr_decls: list[str]
+) -> None:
+    try:
+        with open(shape_inference_hdr) as f:
+            shape_infr_decls = f.read()
+            shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
+    except OSError as e:
+        raise AssertionError(
+            f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
+        ) from e
+
+    # TODO(whc) add a check for shape inference functions that have meta kernels implemented and should be retired.
+
+    missing_decls = [
+        decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
+    ]
+    if missing_decls:
+        raise Exception(  # noqa: TRY002
+            f"""Missing shape inference function.\n
+Please declare this function in {shape_inference_hdr}\n
+and implement it in the corresponding shape_inference.cpp file.\n
+{os.linesep.join(missing_decls)}"""
+        )
+
+
+# Some helper functions for the codegen.
+def get_ltc_helper_fns() -> str:
+    return """\
+at::Tensor to_meta(const at::Tensor& tensor) {
+  // undefined tensors can't be converted to the meta device, since they don't have sizes/strides
+  if (!tensor.defined()) return tensor;
+  auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
+/*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()), \
+/*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
+  // needs to handle wrapped numbers, so dtype promotion works properly.
+  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
+    out.unsafeGetTensorImpl()->set_wrapped_number(true);
+  }
+  return out;
+}
+std::optional<at::Tensor> to_meta(const std::optional<at::Tensor>& tensor) {
+  if (tensor.has_value()) {
+    return to_meta(*tensor);
+  }
+  return std::nullopt;
+}
+
+std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
+  std::vector<at::Tensor> outs;
+  outs.reserve(t_list.size());
+  for (const auto& tensor : t_list) {
+    outs.push_back(to_meta(tensor));
+  }
+  return outs;
+}
+"""
+
+
+class default_args:
+    node_base: str = "Node"
+    node_base_hdr: str | None = None
+    shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
+    tensor_class: str = "torch::lazy::LazyTensor"
+    tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
+    lazy_ir_generator: type[GenLazyIR] = GenLazyIR
+    native_func_definition_generator: type[
+        GenLazyNativeFuncDefinition
+    ] = GenLazyNativeFuncDefinition
+    backend_name: str = "TorchScript"
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
+    parser.add_argument(
+        "-s",
+        "--source-yaml",
+        "--source_yaml",
+        help="path to source yaml file containing operator external definitions",
+    )
+    parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
+    parser.add_argument(
+        "--dry-run",
+        "--dry_run",
+        type=bool,
+        default=False,
+        help="run without writing any files",
+    )
+    parser.add_argument(
+        "--impl-path",
+        "--impl_path",
+        type=str,
+        default=None,
+        help="path to the source C++ file containing kernel definitions",
+    )
+    parser.add_argument(
+        "--gen-ts-lowerings",
+        "--gen_ts_lowerings",
+        action="store_true",
+        help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
+    )
+    parser.add_argument(
+        "--node-base",
+        "--node_base",
+        type=str,
+        default=default_args.node_base,
+        help="Name of backend specific custom Lazy IR Node base class",
+    )
+    parser.add_argument(
+        "--node-base-hdr",
+        "--node_base_hdr",
+        type=str,
+        default=default_args.node_base_hdr,
+        help="Path to header file defining custom Lazy IR Node base class",
+    )
+    parser.add_argument(
+        "--shape-inference-hdr",
+        "--shape_inference_hdr",
+        type=str,
+        default=default_args.shape_inference_hdr,
+        help="Path to header file defining custom Lazy shape inference functions",
+    )
+    parser.add_argument(
+        "--tensor-class",
+        "--tensor_class",
+        type=str,
+        default=default_args.tensor_class,
+        help="Name of backend specific custom Lazy Tensor class",
+    )
+    parser.add_argument(
+        "--tensor-class-hdr",
+        "--tensor_class_hdr",
+        type=str,
+        default=default_args.tensor_class_hdr,
+        help="Path to header file defining custom Lazy Tensor class",
+    )
+    parser.add_argument(
+        "--backend-name",
+        "--backend_name",
+        type=str,
+        default=default_args.backend_name,
+        help="Name of the backend to generate",
+    )
+    options = parser.parse_args()
+
+    # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_lazy_tensor.py
+    torch_root = Path(__file__).parent.parent.parent.absolute()
+    aten_path = str(torch_root / "aten" / "src" / "ATen")
+    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
+    if options.gen_ts_lowerings:
+        lazy_ir_generator = GenTSLazyIR
+    native_func_definition_generator: type[
+        GenLazyNativeFuncDefinition
+    ] = default_args.native_func_definition_generator
+
+    run_gen_lazy_tensor(
+        aten_path,
+        options.source_yaml,
+        options.output_dir,
+        options.dry_run,
+        options.impl_path,
+        options.node_base,
+        options.node_base_hdr,
+        options.tensor_class,
+        options.tensor_class_hdr,
options.shape_inference_hdr, + lazy_ir_generator, + native_func_definition_generator, + options.backend_name, + ) + + +def run_gen_lazy_tensor( + aten_path: str, + source_yaml: str, + output_dir: str, + dry_run: bool, + impl_path: str | None, + node_base: str = default_args.node_base, + node_base_hdr: str | None = default_args.node_base_hdr, + tensor_class: str = default_args.tensor_class, + tensor_class_hdr: str = default_args.tensor_class_hdr, + shape_inference_hdr: str = default_args.shape_inference_hdr, + lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator, + native_func_definition_generator: type[ + GenLazyNativeFuncDefinition + ] = default_args.native_func_definition_generator, + # build_in_tree is true for TS backend and affects include paths + build_in_tree: bool = False, + # per_operator_headers changes whether ATen/Functions.h or individual operator headers are used + # it must match how ATen was built + per_operator_headers: bool = False, + backend_name: str = default_args.backend_name, + gen_forced_fallback_code: bool = False, + use_lazy_shape: bool = True, + # the following arguments are temporary customization points for xla backend migration. + # do not rely on them otherwise, they should be removed once migration is complete + backend_namespace: str = "torch::lazy", + get_tensorlist: str = "GetTensorList", + get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber", + try_get_tensor: str = "TryGetLtcTensor", + metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")', + create_tensor: str = "LazyTensor::Create", + create_from_first_tensor: bool = False, + create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor", + tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors", + lazy_value_class: str = "torch::lazy::Value", + lazy_tensor_ptr: str = "LazyTensorPtr", + get_device_fn: str = "torch::lazy::GetBackendDevice", +) -> None: + lv_tokens = lazy_value_class.split("::") + lv_class = lv_tokens[-1] + lv_ns = "::".join(lv_tokens[:-1]) + setValueT(BaseCppType(lv_ns, lv_class)) + template_dir = os.path.join(aten_path, "templates") + + def make_file_manager(install_dir: str) -> FileManager: + return FileManager( + install_dir=install_dir, template_dir=template_dir, dry_run=dry_run + ) + + fm = make_file_manager(output_dir) + + native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml") + tags_yaml_path = os.path.join(aten_path, "native/tags.yaml") + parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path) + native_functions, backend_indices = ( + parsed_yaml.native_functions, + parsed_yaml.backend_indices, + ) + grouped_native_functions = get_grouped_native_functions(native_functions) + + def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str: + """ + We sort the native function because of the note in concat_map_codegen. + TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly. 
+        """
+        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
+        return str(func.name.name)
+
+    grouped_native_functions = sorted(
+        grouped_native_functions, key=sort_native_function
+    )
+
+    parsed_backend_yaml = parse_backend_yaml(
+        source_yaml, grouped_native_functions, backend_indices
+    )
+    backend_key = parsed_backend_yaml.backend_key
+    autograd_key = parsed_backend_yaml.autograd_key
+    cpp_namespace = parsed_backend_yaml.cpp_namespace
+    backend_indices = parsed_backend_yaml.backend_indices
+    # the following 3 keys are all processed differently:
+    # for full_codegen, we generate IR, kernels, etc.
+    # for ir_gen, we generate only IR
+    # non_native is used to register kernels not declared in
+    # native_functions.yaml
+    full_codegen, non_native, ir_gen = parse_native_functions_keys(
+        source_yaml, grouped_native_functions
+    )
+
+    def concat_map_codegen(
+        func: Callable[[NativeFunction], Sequence[str]],
+        xs: Iterable[NativeFunctionsGroup | NativeFunction],
+        ops_list: list[OperatorName] = full_codegen,
+    ) -> Iterator[str]:
+        """
+        We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape
+        inferences; additional entries for the inplace variant are code-genned only for the native functions.
+        """
+
+        for x in xs:
+            fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
+            for f in fs:
+                if f.func.name in ops_list:
+                    yield from func(f)
+
+    selector = SelectiveBuilder.get_nop_selector()
+
+    assert backend_key is not None
+    class_name = backend_indices[backend_key].native_function_class_name()
+
+    if impl_path is not None:
+        error_on_missing_kernels(
+            native_functions,
+            backend_indices,
+            backend_key,
+            autograd_key,
+            class_name,
+            impl_path,
+            full_codegen,
+        )
+
+    """ Validate Shape Inference Definitions
+
+    Generated lazy native functions all perform shape inference, by first using a meta:: kernel
+    if available for that op, and otherwise using a 'compute_shape_{op}' function instead. The generator
+    knows the call signature for compute_shape_{op} because it matches the nativefunction (and meta::) signature,
+    so it just has to check whether the op is structured and generate a call for one or the other. It's up to the dev
+    to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
+    the expected signature which can be copy-pasted into shape_inference.h.
+
+    compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
+    to structured kernels.
+
+    See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
+    """
+    if shape_inference_hdr is not None:
+        expected_shape_infr_decls = list(
+            concat_map_codegen(
+                dest.GenLazyShapeInferenceDefinition(
+                    backend_indices[backend_key], tensor_class
+                ),
+                grouped_native_functions,
+            )
+        )
+
+        validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
+    assert class_name is not None
+
+    # Generate nativefunction declarations
+    # Note, eager registrations is set to False for the lazy TS backend as another LTC backend
+    # may want to register their own lazy kernels instead of registering the TS ones.
+    # The registration will lazily happen when init_ts_backend is called.
+ gen_dispatchkey_nativefunc_headers( + fm, + class_name, + cpp_namespace, + backend_indices, + grouped_native_functions, + backend_key, + autograd_key, + backend_name, + ) + + # Generate Dispatcher registrations which hook up the nativefunctions + for dispatch_key in ( + [backend_key] if autograd_key is None else [backend_key, autograd_key] + ): + gen_dispatcher_registrations( + fm, + output_dir, + class_name, + backend_indices, + grouped_native_functions, + backend_key, + dispatch_key, + selector, + build_in_tree=build_in_tree, + per_operator_headers=per_operator_headers, + backend_name=backend_name, + eager_registration=False, + ) + + # Generate native function impls that build IR nodes + ns_helper = NamespaceHelper(cpp_namespace) + fm.write_with_template( + f"{backend_key}NativeFunctions.cpp", + "DispatchKeyNativeFunctions.cpp", + lambda: { + "includes": [ + f"#include <{path}>" + for path in [ + tensor_class_hdr, + shape_inference_hdr, + "ATen/Functions.h", + "ATen/native/TensorConversions.h", + "ATen/NativeFunctions.h", + "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h", + "ATen/MetaFunctions.h", + "ATen/Operators.h", + "ATen/native/CPUFallback.h", + "torch/csrc/lazy/core/ir_builder.h", + "torch/csrc/lazy/core/lazy_graph_executor.h", + "torch/csrc/lazy/core/metrics.h", + "torch/csrc/lazy/core/shape.h", + f"{output_dir}/{backend_key}NativeFunctions.h", + f"{output_dir}/LazyIr.h", + ] + + ( + ["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"] + if gen_forced_fallback_code + else [] + ) + ], + "helper_fns": get_ltc_helper_fns(), + "native_functions_include": "", + "namespace_prologue": ns_helper.prologue, + "namespace_epilogue": ns_helper.epilogue, + "native_function_definitions": list( + concat_map_codegen( + native_func_definition_generator( + f"{backend_key}NativeFunctions", + backend_indices[backend_key], + tensor_class, + gen_forced_fallback_code, + backend_namespace, + get_tensorlist, + get_tensor_or_wrap_number, + try_get_tensor, + metrics_counter, + create_tensor, + create_from_first_tensor, + create_aten_from_ltc_tensor, + tuple_aten_from_ltc_tensors, + lazy_tensor_ptr, + get_device_fn, + ), + grouped_native_functions, + ) + ), + }, + ) + # Generate IR node classes + lazy_ir_obj = lazy_ir_generator( + backend_indices[backend_key], backend_name, node_base, use_lazy_shape + ) + + fm.write_with_template( + "LazyIr.h", + "LazyIr.h", + lambda: { + "lazy_ir_sysinc": [ + f"#include <{path}>" + for path in [ + "ATen/core/Formatting.h", + "c10/core/ScalarType.h", + "torch/csrc/lazy/core/hash.h", + "torch/csrc/lazy/core/ir.h", + "torch/csrc/lazy/core/shape.h", + "optional", + "vector", + ] + ], + "lazy_ir_inc": [f'#include "{node_base_hdr}"'] + if node_base_hdr is not None + else [], + "ir_declarations": list( + concat_map_codegen( + lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen + ) + ), + "namespace_prologue": ns_helper.prologue, + "namespace_epilogue": ns_helper.epilogue, + }, + ) + + # Generate Non Native IR Node classes + fm.write_with_template( + "LazyNonNativeIr.h", + "LazyNonNativeIr.h", + lambda: { + "lazy_non_native_ir_inc": [ + f"#include <{path}>" + for path in [ + "torch/csrc/lazy/core/ir.h", + "torch/csrc/lazy/core/ir_builder.h", + "torch/csrc/lazy/core/internal_ops/ltc_ops.h", + "torch/csrc/lazy/core/shape_inference.h", + ] + + ([node_base_hdr] if node_base_hdr else []) + if path + ], + "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes( + non_native, lazy_ir_obj + ), + "namespace_prologue": ns_helper.prologue, + "namespace_epilogue": 
ns_helper.epilogue,
+        },
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.venv/lib/python3.11/site-packages/torchgen/gen_schema_utils.py b/.venv/lib/python3.11/site-packages/torchgen/gen_schema_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..975fbee6df989e550ebbd8b7de61c0eb8c547318
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchgen/gen_schema_utils.py
@@ -0,0 +1,97 @@
+from typing import Any, Optional, Tuple, Union
+
+from torchgen.model import (
+    Annotation,
+    Argument,
+    Arguments,
+    BaseOperatorName,
+    BaseTy,
+    BaseType,
+    CustomClassType,
+    FunctionSchema,
+    ListType,
+    OperatorName,
+    Return,
+)
+
+
+# Note: These aren't actually used in torchgen; they're some utilities for generating a schema
+# from real arguments. For example, this is used to generate HigherOrderOperators' schemas, since
+# their schemas can vary for different instances of the same HOP.
+
+
+class TypeGen:
+    convert_to_base_ty = {
+        int: BaseTy.int,
+        float: BaseTy.float,
+        str: BaseTy.str,
+        bool: BaseTy.bool,
+    }
+
+    @staticmethod
+    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
+        import torch
+
+        if isinstance(obj, torch.fx.GraphModule):
+            return BaseType(BaseTy.GraphModule)
+        elif isinstance(obj, torch.Tensor):
+            return BaseType(BaseTy.Tensor)
+        elif isinstance(obj, torch.SymInt):
+            return BaseType(BaseTy.SymInt)
+        elif isinstance(obj, torch.SymBool):
+            return BaseType(BaseTy.SymBool)
+        elif isinstance(obj, torch.ScriptObject):
+            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
+        elif isinstance(obj, (list, tuple)):
+            assert len(obj) > 0
+            all_base_tys = [TypeGen.from_example(x) for x in obj]
+            if len(set(all_base_tys)) > 1:
+                raise RuntimeError(
+                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
+                    "Consider unpacking the argument and giving proper names to the elements if possible "
+                    "instead of using *args."
+                )
+            return ListType(all_base_tys[0], len(obj))
+        tp = type(obj)
+        if tp not in TypeGen.convert_to_base_ty:
+            raise RuntimeError(f"unsupported type {tp}")
+        return BaseType(TypeGen.convert_to_base_ty[tp])
+
+
+class ReturnGen:
+    @staticmethod
+    def from_example(
+        name: Optional[str], obj: Any, annotation: Optional[Annotation]
+    ) -> Return:
+        return Return(name, TypeGen.from_example(obj), annotation)
+
+
+class ArgumentGen:
+    @staticmethod
+    def from_example(
+        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
+    ) -> Argument:
+        return Argument(
+            name, TypeGen.from_example(obj), default=default, annotation=annotation
+        )
+
+
+class FunctionSchemaGen:
+    @staticmethod
+    def from_example(
+        op_name: str,
+        example_inputs: Tuple[Tuple[str, Any], ...],
+        example_outputs: Tuple[Any, ...],
+    ) -> FunctionSchema:
+        args = []
+        for name, inp in example_inputs:
+            args.append(ArgumentGen.from_example(name, inp, None, None))
+        # ignore the annotations and other attributes for now; we could add more when needed.
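+        # Illustrative usage (the op name and example values are hypothetical):
+        #   FunctionSchemaGen.from_example(
+        #       "my_hop", (("x", torch.randn(3)),), (torch.randn(3),)
+        #   )
+        # yields a FunctionSchema that prints roughly as: my_hop(Tensor x) -> Tensor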
+        arguments = Arguments(
+            tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
+        )
+        returns = tuple(
+            ReturnGen.from_example(None, out, None) for out in example_outputs
+        )
+        op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
+        return FunctionSchema(op_name, arguments, returns)
diff --git a/.venv/lib/python3.11/site-packages/torchgen/gen_vmap_plumbing.py b/.venv/lib/python3.11/site-packages/torchgen/gen_vmap_plumbing.py
new file mode 100644
index 0000000000000000000000000000000000000000..af9af6454eb03317d282b189395df85a2f0fec9e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchgen/gen_vmap_plumbing.py
@@ -0,0 +1,271 @@
+from __future__ import annotations
+
+import textwrap
+from dataclasses import dataclass
+from typing import Sequence
+
+from torchgen.api.translate import translate
+from torchgen.api.types import DispatcherSignature
+from torchgen.context import method_with_native_function
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    NativeFunction,
+    OptionalType,
+    Return,
+    SchemaKind,
+    Type,
+)
+from torchgen.utils import mapMaybe
+
+
+def is_tensor(typ: Type) -> bool:
+    return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
+
+
+def is_optional_tensor(typ: Type) -> bool:
+    return isinstance(typ, OptionalType) and is_tensor(typ.elem)
+
+
+def is_tensor_list(typ: Type) -> bool:
+    return isinstance(typ, ListType) and is_tensor(typ.elem)
+
+
+def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
+    result = f"""\
+    auto [{name}_value, {name}_bdim] = unwrapTensorAtLevel({name}, {cur_level_var});"""
+    return textwrap.dedent(result).split("\n")
+
+
+def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
+    result = f"""\
+    std::optional<Tensor> {name}_value;
+    std::optional<int64_t> {name}_bdim;
+    if ({name}) {{
+        std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
+    }}"""
+    return textwrap.dedent(result).split("\n")
+
+
+def gen_unwraps(
+    flat_arguments: Sequence[Argument], cur_level_var: str
+) -> tuple[str, list[str]]:
+    arg_names = [a.name for a in flat_arguments]
+    arg_types = [a.type for a in flat_arguments]
+
+    tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
+    optional_tensors = [
+        name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)
+    ]
+
+    unwraps = []
+    for tensor in tensors:
+        unwraps += unwrap_tensor(tensor, cur_level_var)
+
+    for opt_tensor in optional_tensors:
+        unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
+    unwrap_code = "\n".join(unwraps)
+
+    unwrapped_arg_list = []
+    for arg in arg_names:
+        if arg in tensors or arg in optional_tensors:
+            unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"]
+        else:
+            unwrapped_arg_list.append(arg)
+    return unwrap_code, unwrapped_arg_list
+
+
+def gen_case_where_all_bdims_are_none(
+    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
+) -> str:
+    conditions = []
+    flat_args = schema.arguments.flat_all
+    for arg in flat_args:
+        if not arg.type.is_tensor_like():
+            continue
+        conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})")
+
+    sig = DispatcherSignature.from_schema(schema)
+    translated_args = ", ".join(
+        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
+    )
+    return f"""\
+if ({' && '.join(conditions)}) {{
+  return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
+}}"""
+
+
+def gen_returns(
+    returns: tuple[Return, ...], cur_level_var: str, results_var: str
+) -> str:
+    idx = 0
+    wrapped_returns = []
+    for ret in returns:
+        if is_tensor(ret.type):
+            wrapped_returns.append(
+                f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
+            )
+            idx += 2
+        elif is_tensor_list(ret.type):
+            wrapped_returns.append(
+                f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
+            )
+            idx += 2
+        else:
+            wrapped_returns.append(f"std::get<{idx}>({results_var})")
+            idx += 1
+    if len(wrapped_returns) == 1:
+        result = f"return {wrapped_returns[0]};"
+    else:
+        result = f'return std::make_tuple({", ".join(wrapped_returns)});'
+    return result
+
+
+def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
+    return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)
+
+
+def is_mutated_arg(argument: Argument) -> bool:
+    return argument.annotation is not None and argument.annotation.is_write
+
+
+def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
+    # Assumptions:
+    # - only one argument is being modified in-place
+    # - the argument that is being modified in-place is the first argument
+    # - all returns are either Tensor, tuple of Tensor, or TensorList
+    schema = native_function.func
+    sig = DispatcherSignature.from_schema(schema)
+    returns = schema.returns
+
+    # Check assumptions. If these are invalid, we return None
+    # and punt the work of handling them to the future.
+    assert schema.kind() == SchemaKind.inplace
+    if not is_mutated_arg(schema.arguments.flat_all[0]):
+        return None
+    if len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) != 1:
+        return None
+
+    # Only support cases where all returns are Tensors or vector<Tensor>
+    if len(returns) == 0:
+        return None
+    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
+        return None
+    if not accepts_at_least_one_tensor_input(schema):
+        return None
+
+    cur_level_var = "cur_level"
+
+    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
+
+    return f"""\
+template <typename batch_rule_t, batch_rule_t batch_rule>
+{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
+  int64_t {cur_level_var} = maybe_layer->layerId();
+{textwrap.indent(bdims_all_none_case, "  ")}
+{textwrap.indent(unwraps, "  ")}
+  batch_rule({', '.join(unwrapped_arg_list)});
+  return {schema.arguments.flat_all[0].name};
+}}"""
+
+
+def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
+    schema = native_function.func
+    sig = DispatcherSignature.from_schema(schema)
+    cur_level_var = "cur_level"
+
+    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
+
+    return f"""\
+template <typename batch_rule_t, batch_rule_t batch_rule>
+{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
+  int64_t {cur_level_var} = maybe_layer->layerId();
+{textwrap.indent(bdims_all_none_case, "  ")}
+{textwrap.indent(unwraps, "  ")}
+  batch_rule({', '.join(unwrapped_arg_list)});
+}}"""
+
+
+def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
+    schema = native_function.func
+    sig = DispatcherSignature.from_schema(schema)
+    returns = schema.returns
+
+    # Only support cases where all returns are Tensors or vector<Tensor>
+    if not accepts_at_least_one_tensor_input(schema):
+        return None
+    if len(returns) == 0:
+        return gen_vmap_plumbing_no_returns(native_function)
+    return_symint_overrides = [
+        "_scaled_dot_product_flash_attention",
+        "_scaled_dot_product_cudnn_attention",
+    ]
+    if (
+        not all(ret.type.is_tensor_like() for ret in returns)
+        and schema.name.unambiguous_name() not in return_symint_overrides
+    ):
+        return None
+    # in-place views need special handling
+    if "inplace_view" in native_function.tags:
+        return None
+
+    if schema.kind() == SchemaKind.inplace:
+        return gen_vmap_inplace_plumbing(native_function)
+
+    # Don't support these (mutable, out, scratch)
+    if schema.kind() != SchemaKind.functional:
+        return None
+
+    results_var = "results"
+    cur_level_var = "cur_level"
+
+    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
+
+    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
+    return f"""\
+template <typename batch_rule_t, batch_rule_t batch_rule>
+{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
+  int64_t {cur_level_var} = maybe_layer->layerId();
+{textwrap.indent(bdims_all_none_case, "  ")}
+{textwrap.indent(unwraps, "  ")}
+  auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
+  {wrapped_returns}
+}}"""
+
+
+@dataclass(frozen=True)
+class ComputeBatchRulePlumbing:
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str | None:
+        result = gen_vmap_plumbing(f)
+        return result
+
+
+def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
+    body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
+    return f"""
+#pragma once
+#include <ATen/Operators.h>
+#include <ATen/functorch/PlumbingHelper.h>
+
+namespace at {{ namespace functorch {{
+
+{body}
+
+}}}} // namespace at::functorch
+"""
diff --git a/.venv/lib/python3.11/site-packages/torchgen/local.py b/.venv/lib/python3.11/site-packages/torchgen/local.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c687c3a7991807e2722c481c9b199e3b4fdf7ae
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchgen/local.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import threading
+from contextlib import contextmanager
+from typing import Iterator
+
+
+# Simple dynamic scoping implementation. The name "parametrize" comes
+# from Racket.
+#
+# WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
+# why you need to add a toggle to the global behavior of code
+# generation. The parameters here should really only be used
+# for "temporary" situations, where we need to temporarily change
+# the codegen in some cases because we cannot conveniently update
+# all call sites, and are slated to be eliminated once all call
+# sites are eliminated. If you don't have a plan for how to get there,
+# DON'T add a new entry here.
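+#
+# A minimal usage sketch (the context manager and getters are defined below;
+# the call site shown here is hypothetical):
+#
+#     from torchgen import local
+#
+#     with local.parametrize(
+#         use_const_ref_for_mutable_tensors=False,
+#         use_ilistref_for_tensor_lists=False,
+#     ):
+#         local.use_const_ref_for_mutable_tensors()  # -> False
+#     # Outside the block the getters assert, since nothing is set.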
+
+
+class Locals(threading.local):
+    use_const_ref_for_mutable_tensors: bool | None = None
+    use_ilistref_for_tensor_lists: bool | None = None
+
+
+_locals = Locals()
+
+
+def use_const_ref_for_mutable_tensors() -> bool:
+    assert _locals.use_const_ref_for_mutable_tensors is not None, (
+        "need to initialize local.use_const_ref_for_mutable_tensors with "
+        "local.parametrize"
+    )
+    return _locals.use_const_ref_for_mutable_tensors
+
+
+def use_ilistref_for_tensor_lists() -> bool:
+    assert _locals.use_ilistref_for_tensor_lists is not None, (
+        "need to initialize local.use_ilistref_for_tensor_lists with "
+        "local.parametrize"
+    )
+    return _locals.use_ilistref_for_tensor_lists
+
+
+@contextmanager
+def parametrize(
+    *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
+) -> Iterator[None]:
+    old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
+    old_use_ilistref_for_tensor_lists = _locals.use_ilistref_for_tensor_lists
+    try:
+        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
+        _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
+        yield
+    finally:
+        _locals.use_const_ref_for_mutable_tensors = (
+            old_use_const_ref_for_mutable_tensors
+        )
+        _locals.use_ilistref_for_tensor_lists = old_use_ilistref_for_tensor_lists
diff --git a/.venv/lib/python3.11/site-packages/torchgen/model.py b/.venv/lib/python3.11/site-packages/torchgen/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..956949343101ad831832a466e48378defcdbbcc8
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchgen/model.py
@@ -0,0 +1,2851 @@
+from __future__ import annotations
+
+import dataclasses
+import itertools
+import re
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import Callable, Iterator, Sequence
+
+from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                           DATA MODEL
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# Some general principles for our data model.
+#
+# - Stop using C++ data types as the internal data representation
+#   format. Instead, the internal data structures are centered
+#   around JIT schema representation. This avoids a big problem
+#   with the old codegen, where we read in all the types from
+#   native_functions.yaml and then immediately had to retranslate
+#   them into C++ types.
+#
+# - More semantic data representation. Instead of representing
+#   everything as dicts and strings, we define dataclasses for
+#   every interesting entity the code generation has to deal with.
+#   These dataclasses have strong semantic invariants: for example,
+#   we generally require them to roundtrip losslessly into the
+#   form they were parsed from. These structures are immutable
+#   and you're expected to populate information once during
+#   construction.
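+#
+# For instance, the lossless-roundtrip invariant means that parsing and then
+# printing a schema string is the identity (illustrative example):
+#
+#     s = "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
+#     assert str(FunctionSchema.parse(s)) == s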
+ + +# Represent a source location; used for better error reporting +@dataclass(frozen=True) +class Location: + file: str + line: int + + def __str__(self) -> str: + return f"{self.file}:{self.line}" + + +# Valid values of the 'variants' field in native_functions.yaml +class Variant(Enum): + function = auto() + method = auto() + + +# Default kernel namespace +DEFAULT_KERNEL_NAMESPACE = "at::native" + +# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h +BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split() +FUNCTIONALITY_KEYS = [ + "", + "Quantized", + "Sparse", + "SparseCsr", + "NestedTensor", + "Autograd", +] + +# This list guards dispatches that can be used in derivatives.yaml +# For now we omit AutogradFunctionality and AutogradOther +AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [ + "Autograd" + component for component in BACKEND_COMPONENTS +] + +FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"} + + +# This doesn't have to be in sync with the header, it only needs to contain +# entries that we actually use in the codegen or want pyi entries for +class DispatchKey(Enum): + Undefined = 0 + CatchAll = Undefined + + FPGA = auto() + MAIA = auto() + Vulkan = auto() + Metal = auto() + MKLDNN = auto() + OpenGL = auto() + OpenCL = auto() + IDEEP = auto() + CustomRNGKeyId = auto() + MkldnnCPU = auto() + Sparse = auto() + SparseCsr = auto() + NestedTensor = auto() + Dense = auto() + + PythonTLSSnapshot = auto() + PreDispatch = auto() + PythonDispatcher = auto() + Python = auto() + FuncTorchDynamicLayerBackMode = auto() + ZeroTensor = auto() + Conjugate = auto() + Negative = auto() + BackendSelect = auto() + Named = auto() + AutogradOther = auto() + AutogradFunctionality = auto() + AutogradNestedTensor = auto() + Tracer = auto() + Autocast = auto() + AutocastCPU = auto() + AutocastCUDA = auto() + Batched = auto() + VmapMode = auto() + FuncTorchGradWrapper = auto() + FuncTorchBatched = auto() + BatchedNestedTensor = auto() + FuncTorchVmapMode = auto() + FuncTorchDynamicLayerFrontMode = auto() + Functionalize = auto() + TESTING_ONLY_GenericWrapper = auto() + TESTING_ONLY_GenericMode = auto() + + ADInplaceOrView = auto() + Autograd = auto() + CompositeImplicitAutograd = auto() + CompositeImplicitAutogradNestedTensor = auto() + CompositeExplicitAutograd = auto() + CompositeExplicitAutogradNonFunctional = auto() + FuncTorchBatchedDecomposition = auto() + + # BEGIN autogenerated + CPU = auto() + CUDA = auto() + HIP = auto() + XLA = auto() + MTIA = auto() + MPS = auto() + IPU = auto() + XPU = auto() + HPU = auto() + VE = auto() + Lazy = auto() + Meta = auto() + PrivateUse1 = auto() + PrivateUse2 = auto() + PrivateUse3 = auto() + QuantizedCPU = auto() + QuantizedCUDA = auto() + QuantizedHIP = auto() + QuantizedXLA = auto() + QuantizedMTIA = auto() + QuantizedMPS = auto() + QuantizedIPU = auto() + QuantizedXPU = auto() + QuantizedHPU = auto() + QuantizedVE = auto() + QuantizedLazy = auto() + QuantizedMeta = auto() + QuantizedPrivateUse1 = auto() + QuantizedPrivateUse2 = auto() + QuantizedPrivateUse3 = auto() + SparseCPU = auto() + SparseCUDA = auto() + SparseHIP = auto() + SparseXLA = auto() + SparseMTIA = auto() + SparseMPS = auto() + SparseIPU = auto() + SparseXPU = auto() + SparseHPU = auto() + SparseVE = auto() + SparseLazy = auto() + SparseMeta = auto() + SparsePrivateUse1 = auto() + SparsePrivateUse2 = auto() + SparsePrivateUse3 = auto() + SparseCsrCPU = auto() + SparseCsrCUDA = auto() + SparseCsrHIP = 
auto()
+    SparseCsrXLA = auto()
+    SparseCsrMTIA = auto()
+    SparseCsrMPS = auto()
+    SparseCsrIPU = auto()
+    SparseCsrXPU = auto()
+    SparseCsrHPU = auto()
+    SparseCsrVE = auto()
+    SparseCsrLazy = auto()
+    SparseCsrMeta = auto()
+    SparseCsrPrivateUse1 = auto()
+    SparseCsrPrivateUse2 = auto()
+    SparseCsrPrivateUse3 = auto()
+    NestedTensorCPU = auto()
+    NestedTensorCUDA = auto()
+    NestedTensorHIP = auto()
+    NestedTensorXLA = auto()
+    NestedTensorMTIA = auto()
+    NestedTensorMPS = auto()
+    NestedTensorIPU = auto()
+    NestedTensorXPU = auto()
+    NestedTensorHPU = auto()
+    NestedTensorVE = auto()
+    NestedTensorLazy = auto()
+    NestedTensorMeta = auto()
+    NestedTensorPrivateUse1 = auto()
+    NestedTensorPrivateUse2 = auto()
+    NestedTensorPrivateUse3 = auto()
+    AutogradCPU = auto()
+    AutogradCUDA = auto()
+    AutogradHIP = auto()
+    AutogradXLA = auto()
+    AutogradMTIA = auto()
+    AutogradMPS = auto()
+    AutogradIPU = auto()
+    AutogradXPU = auto()
+    AutogradHPU = auto()
+    AutogradVE = auto()
+    AutogradLazy = auto()
+    AutogradMeta = auto()
+    AutogradPrivateUse1 = auto()
+    AutogradPrivateUse2 = auto()
+    AutogradPrivateUse3 = auto()
+    # END autogenerated
+
+    def __str__(self) -> str:
+        return self.name
+
+    def lower(self) -> str:
+        return str(self).lower()
+
+    @staticmethod
+    def parse(value: str) -> DispatchKey:
+        for k, v in DispatchKey.__members__.items():
+            if k == value:
+                return v
+        raise AssertionError(f"unknown dispatch key {value}")
+
+
+class _TorchDispatchModeKey(Enum):
+    FAKE = auto()
+    PROXY = auto()
+    FUNCTIONAL = auto()
+
+
+def codegen_per_backend_entries() -> str:
+    r = []
+    for fk in FUNCTIONALITY_KEYS:
+        for bc in BACKEND_COMPONENTS:
+            r.append(f"    {fk}{bc} = auto()")
+    return "\n".join(r)
+
+
+for fk in FUNCTIONALITY_KEYS:
+    for bc in BACKEND_COMPONENTS:
+        if not hasattr(DispatchKey, fk + bc):
+            r = codegen_per_backend_entries()
+            print(r)
+            raise RuntimeError(
+                f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}"
+            )
+
+
+STRUCTURED_DISPATCH_KEYS = {
+    DispatchKey.MPS,
+    DispatchKey.CUDA,
+    DispatchKey.CPU,
+    DispatchKey.XPU,
+}
+UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}
+
+# Set of supported dispatch keys
+dispatch_keys = [
+    DispatchKey.CPU,
+    DispatchKey.SparseCPU,
+    DispatchKey.SparseCsrCPU,
+    DispatchKey.MkldnnCPU,
+    DispatchKey.CUDA,
+    DispatchKey.MPS,
+    DispatchKey.XPU,
+    DispatchKey.SparseCUDA,
+    DispatchKey.SparseCsrCUDA,
+    DispatchKey.QuantizedCPU,
+    DispatchKey.QuantizedCUDA,
+    DispatchKey.CompositeImplicitAutograd,
+    DispatchKey.CompositeImplicitAutogradNestedTensor,
+    DispatchKey.CompositeExplicitAutograd,
+    DispatchKey.CompositeExplicitAutogradNonFunctional,
+    DispatchKey.NestedTensorCPU,
+    DispatchKey.NestedTensorCUDA,
+    # Meta is a magic key: it is automatically generated for structured
+    # kernels
+    DispatchKey.Meta,
+    DispatchKey.SparseMeta,
+    DispatchKey.SparseCsrMeta,
+    DispatchKey.QuantizedMeta,
+    DispatchKey.NestedTensorMeta,
+    DispatchKey.ZeroTensor,
+]
+
+
+# Dispatch keys that "support all backends". These codegen slightly differently
+# than backend-specific keys.
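+#
+# For example, with the predicates defined below (illustrative):
+#
+#     is_generic_dispatch_key(DispatchKey.CompositeExplicitAutograd)  # True
+#     is_generic_dispatch_key(DispatchKey.CPU)                        # False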
+def is_generic_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in {
+        DispatchKey.CompositeExplicitAutograd,
+        DispatchKey.CompositeExplicitAutogradNonFunctional,
+        DispatchKey.CompositeImplicitAutograd,
+        DispatchKey.CompositeImplicitAutogradNestedTensor,
+    }
+
+
+# CUDA specific dispatch keys
+def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in {
+        DispatchKey.CUDA,
+        DispatchKey.QuantizedCUDA,
+        DispatchKey.SparseCUDA,
+        DispatchKey.SparseCsrCUDA,
+        DispatchKey.NestedTensorCUDA,
+        DispatchKey.AutogradCUDA,
+    }
+
+
+# XPU specific dispatch keys
+def is_xpu_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in {
+        DispatchKey.XPU,
+        DispatchKey.QuantizedXPU,
+        DispatchKey.SparseXPU,
+        DispatchKey.SparseCsrXPU,
+        DispatchKey.NestedTensorXPU,
+        DispatchKey.AutogradXPU,
+    }
+
+
+# Structured kernel generation is only supported for certain key types;
+# otherwise use old-style
+def is_structured_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in STRUCTURED_DISPATCH_KEYS
+
+
+def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
+    # For now, ufunc dispatch keys coincide with structured keys
+    return dk in UFUNC_DISPATCH_KEYS
+
+
+# This is oddly named ScalarType and not DType for symmetry with C++
+class ScalarType(Enum):
+    Byte = auto()
+    Char = auto()
+    Short = auto()
+    Int = auto()
+    Long = auto()
+    Half = auto()
+    Float = auto()
+    Double = auto()
+    ComplexHalf = auto()
+    ComplexFloat = auto()
+    ComplexDouble = auto()
+    Bool = auto()
+    BFloat16 = auto()
+    Float8_e5m2 = auto()
+    Float8_e5m2fnuz = auto()
+    Float8_e4m3fn = auto()
+    Float8_e4m3fnuz = auto()
+
+    def __str__(self) -> str:
+        return self.name
+
+    @staticmethod
+    def maybe_parse(value: str) -> ScalarType | None:
+        for k, v in ScalarType.__members__.items():
+            if k == value:
+                return v
+        return None
+
+    @staticmethod
+    def parse(value: str) -> ScalarType:
+        mb_r = ScalarType.maybe_parse(value)
+        assert mb_r is not None, f"unknown dtype {value}"
+        return mb_r
+
+    @staticmethod
+    def parse_set(values: str) -> OrderedSet[ScalarType]:
+        dtypes: OrderedSet[ScalarType] = OrderedSet()
+        for value in values.split(", "):
+            if value in DTYPE_CLASSES:
+                dtypes.update(DTYPE_CLASSES[value])
+            else:
+                dtypes.add(ScalarType.parse(value))
+        return dtypes
+
+
+DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
+# NB: Integral doesn't include boolean
+DTYPE_CLASSES["Integral"] = OrderedSet(
+    [
+        ScalarType.Byte,
+        ScalarType.Char,
+        ScalarType.Int,
+        ScalarType.Long,
+        ScalarType.Short,
+    ]
+)
+# NB: Floating doesn't include low precision types
+DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
+DTYPE_CLASSES["Complex"] = OrderedSet(
+    [ScalarType.ComplexFloat, ScalarType.ComplexDouble]
+)
+DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
+DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
+DTYPE_CLASSES["FloatingAndComplex"] = (
+    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
+)
+
+
+# Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
+# NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc how
+# to process it. Most logic will ignore keys it doesn't understand, so your
+# new key will get silently ignored until you hook in logic to deal with it.
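+#
+# As an illustration, an entry of roughly this shape (using the user-facing
+# keys; dtype classes expand via ScalarType.parse_set) appears on mul.out in
+# native_functions.yaml:
+#
+#     ufunc_inner_loop:
+#       Generic: mul (AllAndComplex, BFloat16, Half)
+#       ScalarOnly: mul (Bool)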
+class UfuncKey(Enum):
+    # These are low level keys that represent exactly one particular
+    # instantiation of the kernel produced by codegen
+    CUDAFunctor = auto()
+    CUDAFunctorOnOther = auto()
+    CUDAFunctorOnSelf = auto()
+
+    CPUScalar = auto()
+    CPUVector = auto()
+
+    # These are the ones users will usually specify, and
+    # implicitly "fill in" the low level keys
+    ScalarOnly = auto()  # CUDA*, CPUScalar
+    Generic = auto()  # CUDA*, CPU*
+
+    def __str__(self) -> str:
+        return self.name
+
+    @staticmethod
+    def parse(value: str) -> UfuncKey:
+        for k, v in UfuncKey.__members__.items():
+            if k == value:
+                return v
+        raise AssertionError(f"unknown ufunc key {value}")
+
+
+class DeviceCheckType(Enum):
+    NoCheck = 0
+    ExactSame = 1
+
+
+class ViewSchemaKind(Enum):
+    aliasing = auto()
+    aliasing_inplace = auto()
+    non_aliasing = auto()
+
+
+# The basic input to the code generation is native_functions.yaml.
+# The name "native", BTW, comes from the distinction between native
+# functions and legacy TH functions. The legacy TH functions are gone,
+# but the "native" descriptor has stuck.
+#
+# NativeFunction models a single entry in native_functions.yaml. Its
+# fields roughly correspond to what you would see in the YAML itself,
+# but after canonicalization and parsing has occurred.
+#
+# You can see some of the overall design patterns for how we set up
+# dataclasses in this class, but we will defer a complete discussion
+# of this to FunctionSchema.
+@dataclass(frozen=True)
+class NativeFunction:
+    # The namespace for this operator. For example, if we have "at::add"
+    # then the namespace would be "at". This enables ops to be registered
+    # through the same DSL with a custom namespace. If not specified, the
+    # default namespace would be "at".
+    namespace: str
+
+    # The function schema of the operator in question. This schema
+    # has been parsed; see FunctionSchema for more about its structure.
+    # (This type is quoted as we are forward referencing a type
+    # defined later in the file. I opted for this ordering of the
+    # classes for expository clarity.)
+    func: FunctionSchema
+
+    # Whether or not to generate mutable tensor arguments like regular
+    # ones
+    use_const_ref_for_mutable_tensors: bool
+
+    # Whether or not to omit automatic generation of a DeviceGuard
+    device_guard: bool
+
+    # How to emit automatic generation of device check
+    device_check: DeviceCheckType
+
+    # What python module to put the function in
+    python_module: str | None
+
+    # TODO: figure out what this does
+    category_override: str | None
+
+    # If no variants are specified in native_functions.yaml, this is
+    # assumed to be {'function'}.
+    variants: set[Variant]
+
+    # Whether or not we should skip generating registrations for
+    # this kernel. This is a bit of a double-edged sword, as manual
+    # registrations don't participate in codegen-based selective build!
+    manual_kernel_registration: bool
+
+    # Whether or not to skip generating TensorMethod/Functions bindings
+    # for this kernel. Technically, this doesn't actually skip generating
+    # the binding; instead, the binding gets generated to __dispatch_{funcname}
+    # so you can make use of the normal binding if you need it.
+    manual_cpp_binding: bool
+
+    # The location in the YAML file where this native function entry was
+    # defined. This is for conveniently reporting error messages!
+    loc: Location
+
+    # A list of operators that are expected to be auto-generated for this NativeFunction.
+    # Note: This list isn't actually directly used by the codegen to generate anything.
+    # Instead, the codegen figures out which operators to generate purely based on the
+    # function schema, and uses the autogen declarations to error-check.
+    # We expect every NativeFunction that gets auto-generated to be explicitly called out
+    # in native_functions.yaml.
+    autogen: list[OperatorName]
+
+    # If non-empty, this kernel is subject to ufunc codegen.
+    # Sorted by ufunc_key
+    ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
+
+    # Whether or not this out function is a "structured kernel". Structured
+    # kernels are defined a little differently from normal kernels; in
+    # particular, their shape checking logic is defined separately from
+    # the kernel. Only out functions can be structured; other functions
+    # delegate to the out function using the structured_delegate keyword.
+    # Every structured kernel must have at least an out and a functional
+    # variant.
+    structured: bool
+
+    # Whether or not this non-out function is a structured kernel, defined
+    # in terms of the out kernel referenced by the string here.
+    structured_delegate: OperatorName | None
+
+    # Only valid for structured kernels. Specifies an alternative to
+    # inherit from when defining the meta class for the structured
+    # operator. This will usually be TensorIteratorBase. This also
+    # changes the semantics of set_output to call the parent class.
+    structured_inherits: str | None
+
+    # Structured kernels can declare elements as "precomputed". These elements
+    # are returned by the meta function in one struct and passed to the impl
+    # function in lieu of certain kernel arguments that these precomputed
+    # elements supersede. Information about the names and types of these
+    # precomputed elements and how they correspond to kernel arguments is stored
+    # in this member, if applicable.
+    precomputed: Precompute | None
+
+    # Argument names whose default should be excluded from the C++ interface.
+    # Intended for resolving overload ambiguities between signatures.
+    cpp_no_default_args: set[str]
+
+    # Note [Abstract ATen methods]
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # An abstract ATen method is one whose dispatch differs between
+    # types. These are implemented in derived types (with a
+    # standard (throwing) definition in Type). A concrete ATen
+    # method is one which has the same dispatch for all types;
+    # we just implement it in the base Type. This is exposed
+    # in Declarations.yaml via a field named 'abstract'.
+    is_abstract: bool
+
+    # Whether or not the NativeFunction contains a backend-agnostic kernel
+    has_composite_implicit_autograd_kernel: bool
+    has_composite_implicit_autograd_nested_tensor_kernel: bool
+    has_composite_explicit_autograd_kernel: bool
+    has_composite_explicit_autograd_non_functional_kernel: bool
+
+    # Tags are used to describe semantic information about (groups of) operators
+    # that isn't easily inferable directly from the operator's schema.
+    tags: set[str]
+
+    # NB: The benefit of defining a dataclass is that we automatically get
+    # a constructor defined for all the fields we specify. No need
+    # to explicitly write it out.
+
+    # We parse both the NativeFunction and backend-specific information about it,
+    # which is stored in a corresponding BackendIndex.
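+    #
+    # For orientation, from_yaml (below) consumes entries of roughly this shape
+    # (illustrative; the fields mirror the dataclass attributes above):
+    #
+    #     - func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    #       variants: function, method
+    #       structured_delegate: add.out
+    #       tags: pointwise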
+ @staticmethod + def from_yaml( + ei: dict[str, object], + loc: Location, + valid_tags: set[str], + ignore_keys: set[DispatchKey] | None = None, + ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]: + """ + Parse a NativeFunction from a dictionary as directly parsed + from native_functions.yaml + """ + e = ei.copy() + + funcs = e.pop("func") + assert isinstance(funcs, str), f"not a str: {funcs}" + # only support one level of namespace. E.g., aten::add + namespace_helper = NamespaceHelper.from_namespaced_entity( + namespaced_entity=funcs, max_level=1 + ) + namespace = namespace_helper.get_cpp_namespace(default="aten") + func = FunctionSchema.parse(namespace_helper.entity_name) + + cpp_no_default_args_list = e.pop("cpp_no_default_args", []) + assert isinstance(cpp_no_default_args_list, list) + cpp_no_default_args = set(cpp_no_default_args_list) + + use_const_ref_for_mutable_tensors = e.pop( + "use_const_ref_for_mutable_tensors", False + ) + assert isinstance(use_const_ref_for_mutable_tensors, bool) + + variants_s = e.pop("variants", "function") + assert isinstance(variants_s, str) + variants: set[Variant] = set() + for v in variants_s.split(", "): + if v == "function": + variants.add(Variant.function) + elif v == "method": + variants.add(Variant.method) + else: + raise AssertionError(f"illegal variant {v}") + + manual_kernel_registration = e.pop("manual_kernel_registration", False) + assert isinstance( + manual_kernel_registration, bool + ), f"not a bool: {manual_kernel_registration}" + + manual_cpp_binding = e.pop("manual_cpp_binding", False) + assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}" + + device_guard = e.pop("device_guard", True) + assert isinstance(device_guard, bool), f"not a bool: {device_guard}" + + device_check_s = e.pop("device_check", None) + assert device_check_s is None or isinstance( + device_check_s, str + ), f"not a str: {device_check_s}" + assert ( + device_check_s is None or device_check_s in DeviceCheckType.__members__ + ), f"illegal device_check: {device_check_s}" + device_check: DeviceCheckType + if device_check_s is None: + device_check = DeviceCheckType.ExactSame + else: + device_check = DeviceCheckType[device_check_s] + + structured = e.pop("structured", False) + assert isinstance(structured, bool), f"not a bool: {structured}" + + structured_delegate_s = e.pop("structured_delegate", None) + assert structured_delegate_s is None or isinstance( + structured_delegate_s, str + ), f"not a str: {structured_delegate_s}" + assert structured_delegate_s is None or "::" not in structured_delegate_s, ( + "namespace is not supported in structured delegate," + " using the same namespace as the native function" + ) + structured_delegate: OperatorName | None = None + if structured_delegate_s is not None: + structured_delegate = OperatorName.parse(structured_delegate_s) + + structured_inherits = e.pop("structured_inherits", None) + assert structured_inherits is None or isinstance( + structured_inherits, str + ), f"not a str: {structured_inherits}" + assert structured_inherits is None or "::" not in structured_inherits, ( + "namespace is not supported in structured inherits," + " using the same namespace as the native function" + ) + + python_module = e.pop("python_module", None) + assert python_module is None or isinstance( + python_module, str + ), f"not a str: {python_module}" + assert ( + python_module is None or Variant.method not in variants + ), "functions in modules cannot be methods" + + category_override = 
e.pop("category_override", None) + assert category_override is None or isinstance( + category_override, str + ), f"not a str: {category_override}" + + precomputed_dict = e.pop("precomputed", None) + assert precomputed_dict is None or structured is True + precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None + + tags_inp = e.pop("tags", []) + if isinstance(tags_inp, str): + tags_inp = [tags_inp] + assert isinstance(tags_inp, list) + + # All aten ops generated by torchgen receive the pt2_compliant tag. + if namespace == "aten" and "pt2_compliant_tag" in valid_tags: + tags_inp.append("pt2_compliant_tag") + + tags: set[str] = set() + for t in tags_inp: + assert len(valid_tags) > 0 + # TODO: verify that the tag is valid and has an entry in tags.yaml + if t in valid_tags: + tags.add(t) + else: + raise AssertionError(f"illegal tag {t}") + + from torchgen.api import cpp + + raw_dispatch = e.pop("dispatch", None) + assert raw_dispatch is None or isinstance(raw_dispatch, dict), e + dispatch: dict[DispatchKey, BackendMetadata] = {} + num_dispatch_keys: int = 0 + if raw_dispatch is not None: + assert not manual_kernel_registration, ( + "cannot specify both manual_kernel_registration and dispatch; with " + "manual registration, dispatch has no effect!" + ) + redundant_composite_implicit_autograd = False + for ks, v in raw_dispatch.items(): + if ks == "__line__": + continue # not worth tracking line numbers for dispatch entries + assert isinstance( + ks, str + ), f"illegal dispatch key '{ks}' in {raw_dispatch}" + assert isinstance( + v, str + ), f"illegal dispatch value '{v}' in {raw_dispatch}" + for k in ks.split(","): + dispatch_key = DispatchKey.parse(k.strip()) + num_dispatch_keys += 1 + + if ignore_keys and dispatch_key in ignore_keys: + continue + assert dispatch_key in dispatch_keys, ( + f"Dispatch key {dispatch_key} of kernel {v} " + "is not a supported dispatch key." + ) + # We only allow at most 3 levels of namespace for kernels. + # We will append "native" to a custom kernel namespace. + namespace_helper = NamespaceHelper.from_namespaced_entity( + v, max_level=3 + ) + kernel_namespace = namespace_helper.get_cpp_namespace(default="at") + # Why is 'structured' included? External backends (e.g. + # XLA) opt into which ops are structured independently + # of which in-tree ops are structured + dispatch[dispatch_key] = BackendMetadata( + kernel=namespace_helper.entity_name, + structured=structured + and is_structured_dispatch_key(dispatch_key), + cpp_namespace=(kernel_namespace + "::native"), + ) + if ( + dispatch_key is DispatchKey.CompositeImplicitAutograd + and v == cpp.name(func) + ): + redundant_composite_implicit_autograd = True + + # We count the number of dispatch keys which have not been ignored to prevent a dispatch table + # in which all backend keys are ignored but necessarily kept, remaining compositeimplicit, + # from being treated as redundant. 
+ assert not ( + num_dispatch_keys == 1 and redundant_composite_implicit_autograd + ), ( + "unnecessary dispatch table for this function; just delete the dispatch " + "key entirely" + ) + # if a function is a structured delegate, deleting the dispatch + # table is NOT semantics preserving + assert ( + structured_delegate + or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} + or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint() + or num_dispatch_keys != 1 + ), ( + f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} " + f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected " + "name, then delete the dispatch table" + ) + elif not structured and structured_delegate is None: + name = str(func.name.name) + assert not ( + name.startswith("new_") + or name.endswith("_like") + # TODO: maybe it's better to test the return + or ( + func.arguments.tensor_options + and not func.arguments.has_tensor_arg() + ) + ), ( + f"expected {name} to have a CompositeExplicitAutograd " + "dispatch entry, but there was no dispatch table. Factory functions " + "should not have implicit dispatch as they should not be decomposed " + "for __torch_dispatch__" + ) + dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata( + cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE + ) + + composites_in_dispatch = [ + d + for d in dispatch + if d == DispatchKey.CompositeExplicitAutograd + or d == DispatchKey.CompositeExplicitAutogradNonFunctional + or d == DispatchKey.CompositeImplicitAutograd + or d == DispatchKey.CompositeImplicitAutogradNestedTensor + ] + + assert len(composites_in_dispatch) <= 1 or ( + len(composites_in_dispatch) == 2 + and ( + DispatchKey.CompositeExplicitAutogradNonFunctional + not in composites_in_dispatch + ) + and ( + DispatchKey.CompositeImplicitAutogradNestedTensor + in composites_in_dispatch + ) + ), ( + "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, " + "or CompositeImplicitAutograd on a single kernel; each " + "strictly subsumes the other. 
If you wanted to provide an explicit autograd "
+            "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
+        )
+
+        autogen_str = e.pop("autogen", "")
+        assert isinstance(autogen_str, str)
+        autogen = (
+            []
+            if autogen_str == ""
+            else [OperatorName.parse(x) for x in autogen_str.split(", ")]
+        )
+
+        raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
+        ufunc_inner_loop = {}
+        if isinstance(raw_ufunc_inner_loop, str):
+            ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
+                raw_ufunc_inner_loop, UfuncKey.Generic
+            )
+        elif isinstance(raw_ufunc_inner_loop, dict):
+            for k, vo in raw_ufunc_inner_loop.items():
+                if k == "__line__":
+                    continue
+                assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
+                assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {vo}"
+                ufunc_key = UfuncKey.parse(k)
+                ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
+        else:
+            raise AssertionError(
+                f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
+            )
+        # Program the BackendIndex for the implicit dispatch entry from ufunc
+        if ufunc_inner_loop:
+            assert structured, "ufunc must be structured"
+
+            # Delay import ufunc here to avoid circular import issue
+            # See: https://github.com/pytorch/pytorch/issues/81294
+            import torchgen.api.ufunc as ufunc
+
+            for dispatch_key in UFUNC_DISPATCH_KEYS:
+                assert (
+                    dispatch_key not in dispatch
+                ), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
+                dispatch[dispatch_key] = BackendMetadata(
+                    kernel=ufunc.schema_kernel_name(func, dispatch_key),
+                    structured=True,
+                    cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
+                )
+
+        if structured_delegate:
+            # Structured functions MUST have a dispatch table
+            is_abstract = True
+        else:
+            is_abstract = (
+                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
+                and dispatch.keys()
+                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
+                and dispatch.keys()
+                != {
+                    DispatchKey.CompositeImplicitAutograd,
+                    DispatchKey.CompositeImplicitAutogradNestedTensor,
+                }
+            )
+
+        has_composite_implicit_autograd_kernel = (
+            DispatchKey.CompositeImplicitAutograd in dispatch
+        )
+        has_composite_implicit_autograd_nested_tensor_kernel = (
+            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch
+        )
+        has_composite_explicit_autograd_kernel = (
+            DispatchKey.CompositeExplicitAutograd in dispatch
+        )
+        has_composite_explicit_autograd_non_functional_kernel = (
+            DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch
+        )
+
+        # We aren't going to store dispatch metadata inline in NativeFunctions;
+        # instead it is separately indexed by backend (so other backends can
+        # add more dispatch entries after the fact). Reindex the individual
+        # metadata by OperatorName!
+ backend_metadata = {k: {func.name: v} for k, v in dispatch.items()} + + # don't care if it exists or not; make it easier to use this function + # with other yaml parsers that aren't setting __line__ in the dict + e.pop("__line__", None) + assert not e, f"leftover entries: {e}" + + # Asserts that we can't do in post_init, because they rely on backend-specific info + if structured_delegate is not None: + for key in STRUCTURED_DISPATCH_KEYS: + assert key not in dispatch, ( + f"if structured_delegate, then must not have {key} in dispatch dictionary " + "(it is delegated!)" + ) + + return ( + NativeFunction( + func=func, + use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors, + variants=variants, + structured=structured, + structured_delegate=structured_delegate, + structured_inherits=structured_inherits, + precomputed=precomputed, + autogen=autogen, + ufunc_inner_loop=ufunc_inner_loop, + manual_kernel_registration=manual_kernel_registration, + manual_cpp_binding=manual_cpp_binding, + python_module=python_module, + category_override=category_override, + device_guard=device_guard, + device_check=device_check, + loc=loc, + cpp_no_default_args=cpp_no_default_args, + is_abstract=is_abstract, + has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel, + has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel, + has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel, + has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel, + tags=tags, + namespace=namespace, + ), + backend_metadata, + ) + + def validate_unstructured(self) -> None: + # TODO: probably better to accumulate these errors and report them all + # at once + assert not self.structured, ( + "This function is structured, but there was " + "no valid functional variant of it." + ) + assert self.structured_delegate, ( + "This function delegates to another structured out function, " + "but no valid function was found (the delegate may not exist, or it has the wrong type)" + ) + + # __post_init__ functions in dataclasses can be used to do extra + # validation after construction. + # + # Notice that we don't do any type validation here. In fact, we + # rely exclusively on mypy to check if you've done types correctly! + # Validation is for nontrivial invariants that cannot be (conveniently) + # encoded in the type system. + def __post_init__(self) -> None: + if self.func.arguments.out: + assert self.variants == {Variant.function}, ( + "Native functions with out arguments MUST " + "be declared with only function variant; e.g., variants: function; " + "otherwise you will tickle a Python argument binding bug " + "(which usually manifests itself as the result variable being undefined.)" + ) + if self.structured: + assert self.func.kind() == SchemaKind.out, ( + "Put structured field on the out= " + "variant of a function; did you mean structured_delegate?" + ) + assert ( + self.device_guard + ), "device_guard: False is not respected by structured kernels" + if self.structured_delegate: + assert self.func.kind() != SchemaKind.out, ( + "structured_delegate field not allowed " + "on out= functions; did you mean structured?" 
+            )
+            assert (
+                self.device_guard
+            ), "device_guard: False is not respected by structured kernels"
+        # Technically, with the asserts above, this assert is impossible to
+        # trigger
+        assert not (
+            self.structured and self.structured_delegate
+        ), "Cannot have both structured and structured_delegate on function"
+        defaulted_arguments = {
+            a.name for a in self.func.schema_order_arguments() if a.default is not None
+        }
+        invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
+        assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
+        if self.structured_inherits is not None:
+            assert (
+                self.structured
+            ), "structured_inherits must also imply structured: True"
+        if str(self.func.name).startswith("_foreach"):
+            assert self.device_check == DeviceCheckType.NoCheck, (
+                "foreach kernels fall back to the slow path when tensors are on different devices, "
+                "device_check not allowed to be enabled"
+            )
+
+        # NB: if your function accidentally has rand/dropout/... in its name
+        # but is not actually random, feel free to amend this to special-case it
+        if (
+            "rand" in str(self.func.name)
+            or (
+                (
+                    "dropout" in str(self.func.name)
+                    or any(
+                        "dropout" in arg.name for arg in self.func.arguments.flat_all
+                    )
+                )
+                # Backwards of dropout is typically deterministic
+                and "backward" not in str(self.func.name)
+                and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
+            )
+            or self.func.arguments.has_generator_arg()
+        ):
+            assert "nondeterministic_seeded" in self.tags, str(self.func.name)
+
+    @property
+    def has_composite_kernel(self) -> bool:
+        return (
+            self.has_composite_implicit_autograd_kernel
+            or self.has_composite_explicit_autograd_kernel
+            or self.has_composite_explicit_autograd_non_functional_kernel
+        ) or (
+            self.has_composite_implicit_autograd_kernel
+            and self.has_composite_implicit_autograd_nested_tensor_kernel
+        )
+
+    @property
+    def is_view_op(self) -> bool:
+        rets = self.func.returns
+        is_non_mutating_view = len(rets) > 0 and any(
+            r.annotation is not None and not r.annotation.is_write for r in rets
+        )
+        # See Note [resize_ in Functionalization] for more details
+        is_inplace_view = (
+            "inplace_view" in self.tags
+            and str(self.func.name) != "resize_"
+            and str(self.func.name) != "resize_as_"
+        )
+        is_wildcard_view = any(
+            inp.annotation is not None and "*" in inp.annotation.alias_set_after
+            for inp in self.func.schema_order_arguments()
+        )
+        return is_non_mutating_view or is_inplace_view or is_wildcard_view
+
+    @property
+    def view_schema_kind(self) -> ViewSchemaKind:
+        if self.is_view_op and self.func.name.name.inplace:
+            assert "inplace_view" in self.tags
+            return ViewSchemaKind.aliasing_inplace
+        if self.is_view_op:
+            return ViewSchemaKind.aliasing
+        else:
+            return ViewSchemaKind.non_aliasing
+
+    @property
+    def root_name(self) -> str:
+        return self.func.name.name.base
+
+    @property
+    def part_of_structured_group(self) -> bool:
+        return self.structured or self.structured_delegate is not None
+
+
+class SchemaKind(Enum):
+    functional = auto()
+    inplace = auto()
+    out = auto()
+    mutable = auto()
+    scratch = auto()
+
+
+# A structured kernel is guaranteed to have a functional and out variant, and
+# optionally an inplace variant.
+#
+# NB: we create NativeFunctionsGroup *even if* the function is not
+# actually annotated structured. Test the structured boolean to see if it
+# actually is structured or not.
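+#
+# For example (illustrative), the "add" family groups the schemas:
+#
+#     functional: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+#     inplace:    add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+#     out:        add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)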
+@dataclass(frozen=True) +class NativeFunctionsGroup: + functional: NativeFunction + inplace: NativeFunction | None + mutable: NativeFunction | None + out: NativeFunction + + @property + def structured(self) -> bool: + # Whether or not the operator has a meta() function. This information is backend-agnostic. + return self.out.structured + + def __post_init__(self) -> None: + test_sig: FunctionSchema = self.functional.func.signature() + for f in self.functions(): + if test_sig != f.func.signature(): + raise AssertionError( + "NativeFunctionsGroup constructed from two NativeFunctions " + f"that don't have matching signatures: {test_sig} != {f.func.signature()}" + ) + + if self.structured != f.part_of_structured_group: + raise AssertionError( + "NativeFunctionsGroup constructed from structured and unstructured " + f"functions: {self.out.func.name} and {f.func.name}" + ) + assert self.functional.func.kind() == SchemaKind.functional + assert self.out.func.kind() == SchemaKind.out + assert self.functional.namespace == self.out.namespace + if self.inplace is not None: + assert self.inplace.func.kind() == SchemaKind.inplace + assert self.inplace.namespace == self.functional.namespace + + if self.mutable is not None: + assert self.mutable.func.kind() == SchemaKind.mutable + assert self.mutable.namespace == self.functional.namespace + # See Note [Overload Ambiguity With Functional Variants] + assert self.functional.func.name.name.functional_overload + + if self.structured: + # For now, structured composite kernels are not supported (need some + # design work to figure out how to make the composite case work) + assert ( + not self.out.has_composite_implicit_autograd_kernel + and not self.out.has_composite_implicit_autograd_nested_tensor_kernel + ) + + assert self.functional.structured_delegate == self.out.func.name, ( + f"{self.functional.func.name} delegates to {self.functional.structured_delegate} " + f"but its actual delegate is {self.out.func.name}" + ) + if self.inplace is not None: + assert self.inplace.structured_delegate == self.out.func.name + + generated_fns = sorted( + [str(f.func.name) for f in self.functions() if "generated" in f.tags] + ) + generated_fns_str = ", ".join(str(x) for x in generated_fns) + expected_generated_fns: set[str] = set() + for f in self.functions(): + expected_generated_fns.update(str(op) for op in f.autogen) + expected_generated_fns_str = ", ".join( + str(x) for x in sorted(expected_generated_fns) + ) + if len(expected_generated_fns) == 0 and len(generated_fns) > 0: + raise RuntimeError( + f"The codegen expects to be able to generate '{generated_fns_str}'." + " In order to generate them however, we expect them to be called out explicitly in the yaml." + f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}" + ) + if expected_generated_fns_str != generated_fns_str: + raise RuntimeError( + f"The codegen expects to be able to generate '{generated_fns_str}'." + f" To do so, it expects a line: 'autogen: {generated_fns_str}'." 
+ f" Instead, it found 'autogen: {expected_generated_fns_str}'" + ) + + def signature(self) -> FunctionSchema: + return self.out.func.signature() + + def functions(self) -> Iterator[NativeFunction]: + yield self.functional + yield self.out + if self.inplace is not None: + yield self.inplace + if self.mutable is not None: + yield self.mutable + + @property + def root_name(self) -> str: + return self.functional.root_name + + @staticmethod + def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None: + assert d + if len(d) == 1: + return None + d = dict(d) # non-destructive updates please + functional = d.pop(SchemaKind.functional, None) + inplace = d.pop(SchemaKind.inplace, None) + mutable = d.pop(SchemaKind.mutable, None) + out = d.pop(SchemaKind.out, None) + assert not d + assert functional is not None + # There are a few operators which only have functional/inplace variants; + # these don't count as structured for our purposes here + if out is None: + return None + # assuming all variants have the same namespace + return NativeFunctionsGroup( + functional=functional, + inplace=inplace, + mutable=mutable, + out=out, + ) + + +@dataclass(frozen=True) +class BackendMetadata: + # The name of the backend kernel, for a given operator + # for in-tree backends. These names come directly from the 'dispatch" field + # in native_functions.yaml. The dispatch entry is optional; in that + # case, that is equivalent to having written: + # + # dispatch: + # CompositeImplicitAutograd: $operator_name + kernel: str + # Whether or not the operator has a structured kernel implemented, for this particular backend. + # For in-tree backends, they all have the same value for structured- this is listed + # in native_functions.yaml. + # However, external backends like XLA can indendently toggle which ops are structured. + structured: bool + + # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE + cpp_namespace: str + + def supports_symint(self) -> bool: + return "_symint" in self.kernel + + +@dataclass(frozen=True) +class UfuncInnerLoop: + name: str + supported_dtypes: OrderedSet[ScalarType] + # key is stored here because it affects the semantics of name, + # so its helpful to have them together for further processing + ufunc_key: UfuncKey + + @staticmethod + def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop: + name, supported_dtypes_str = value.split(" ", 1) + assert supported_dtypes_str[0] == "(" + assert supported_dtypes_str[-1] == ")" + supported_dtypes: OrderedSet[ScalarType] = OrderedSet() + for k in supported_dtypes_str[1:-1].split(", "): + supported_dtypes |= ScalarType.parse_set(k) + return UfuncInnerLoop( + name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key + ) + + +# BackendIndex represents a backend. +# The BackendIndex encodes per-operator information that is potentially different +# for each backend. The most obvious example is the name of the kernel +# (the 'dispatch' entry in native_functions.yaml). +# However, there can be other examples of different backends having different information. +# External backends can choose to opt their kernels to be structured independently from in-tree backends, +# which means that this information isn't inherently tied to a NativeFunction- it's different per backend. +@dataclass(frozen=True) +class BackendIndex: + dispatch_key: DispatchKey + # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others. 
+    # All in-tree ops use out kernels, while XLA uses functional kernels.
+    use_out_as_primary: bool
+    # Whether the backend requires a device guard, and device checks.
+    # For in-tree backends, this is currently just CUDA/HIP
+    # For out-of-tree backends, this is currently just Intel XPU
+    device_guard: bool
+    # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
+    external: bool
+    # Other backend-specific information that is on a per-operator basis
+    index: dict[OperatorName, BackendMetadata]
+
+    @staticmethod
+    def grow_index(
+        parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
+        child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
+    ) -> None:
+        for k, v in child_index.items():
+            for op_name, metadata in v.items():
+                assert (
+                    op_name not in parent_index[k]
+                ), f"duplicate operator {op_name} for dispatch key {k}"
+                parent_index[k][op_name] = metadata
+
+    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
+        if self.use_out_as_primary:
+            return g.out
+        else:
+            return g.functional
+
+    def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
+        m = self.get_kernel(g)
+        return m is not None
+
+    def get_kernel(
+        self, g: NativeFunction | NativeFunctionsGroup
+    ) -> BackendMetadata | None:
+        if isinstance(g, NativeFunction):
+            f = g
+        elif isinstance(g, NativeFunctionsGroup):
+            f = self.primary(g)
+        else:
+            assert_never(g)
+        if f.func.name not in self.index:
+            return None
+        return self.index[f.func.name]
+
+    def native_function_class_name(self) -> str | None:
+        if self.external:
+            return f"{str(self.dispatch_key)}NativeFunctions"
+        else:
+            # TODO: This discrepancy isn't required; we could also generate
+            # a class for in-tree kernels. It'll just require carefully
+            # updating every kernel definition + callsite of every in-tree aten kernel.
+            return None
+
+
+# The function schema is undoubtedly the most important data structure
+# in all of the codegen, as it defines the type signature for operators,
+# and most of the code generation we do is type directed (e.g., look at
+# the types, decide what to do. Think about how we code-generate
+# C++ function stubs!)
+#
+# We will also see in this class the general structure for how we model
+# data in this code generation. A few notable properties to point out
+# ahead of time:
+#
+# - These dataclasses are a *lossless* representation of the strings
+#   they are parsed from. In fact, we assert that given the
+#   information stored in the dataclass, we can exactly reconstruct
+#   the string we parsed from (and assert this inside the parse
+#   definition). There are a few reasons for this:
+#
+#     - If you find that it is difficult to reconstruct the string
+#       given a dataclass, that is a clue that your data
+#       representation is wrong.
+#
+#     - It helps ensure that all relevant information is present
+#       in the dataclass, so that downstream users aren't tempted
+#       to reparse the original string to get some information
+#       that was omitted.
+#
+#     - It forces you to represent the data in-memory in the same way
+#       it is recorded textually, which makes the dataclasses easier
+#       to understand for someone who is familiar with the
+#       textual format. (As a tradeoff, it means you have to model
+#       the syntax, even when it is inconvenient. But maybe that means
+#       the syntax is bad!) If you don't understand the internal
+#       representation, go look at the printing code to see how
+#       it maps onto the surface syntax!
+# The function schema is undoubtedly the most important data structure
+# in all of the codegen, as it defines the type signature for operators,
+# and most of the code generation we do is type directed (e.g., look at
+# the types, decide what to do. Think about how we code generate
+# C++ function stubs!)
+#
+# We will also see in this class the general structure for how we model
+# data in this code generation. A few notable properties to point out
+# ahead of time:
+#
+#   - These dataclasses are a *lossless* representation of the strings
+#     they are parsed from. In fact, we assert that given the
+#     information stored in the dataclass, we can exactly reconstruct
+#     the string we parsed from (and assert this inside the parse
+#     definition). There are a few reasons for this:
+#
+#       - If you find that it is difficult to reconstruct the string
+#         given a dataclass, that is a clue that your data
+#         representation is wrong.
+#
+#       - It helps ensure that all relevant information is present
+#         in the dataclass, so that downstream users aren't tempted
+#         to reparse the original string to get some information
+#         that was omitted.
+#
+#       - It forces you to represent the data in-memory in the same way
+#         it is recorded textually, which makes the dataclasses easier
+#         to understand for someone who is familiar with the
+#         textual format. (As a tradeoff, it means you have to model
+#         the syntax, even when it is inconvenient. But maybe that means
+#         the syntax is bad!) If you don't understand the internal
+#         representation, go look at the printing code to see how
+#         it maps onto the surface syntax!
+#
+#       - It makes it easy to test the parsing code, as parsing code
+#         that is inconsistent with the stringifying code will fail early
+#         and loudly. (As a tradeoff, it makes the parsing code a bit
+#         brittle; in particular, with trivial whitespace changes you
+#         are likely to trigger an assert error.)
+#
+#         In general, try to make the __str__ code as simple as possible
+#         (even at the cost of more complex parsing logic.) Additionally,
+#         try to minimize redundancy in data representation. (Precomputed
+#         fields are OK though: they are defined as a simple function on
+#         the canonical representation in question.)
+#
+#   - These dataclasses are all frozen; once constructed their
+#     values never change. This makes it easy to tell where any
+#     given data came from: just look to the constructor. As a
+#     tradeoff, you can't easily "decorate" a schema with extra
+#     information from a post-facto analysis. We impose this
+#     restriction to make these structures more understandable.
+#
+@dataclass(frozen=True)
+class FunctionSchema:
+    # The name of the operator this function schema describes.
+    name: OperatorName
+
+    arguments: Arguments
+
+    # TODO: Need to handle collisions with argument names at some point
+    returns: tuple[Return, ...]
+
+    @property
+    def is_mutable(self) -> bool:
+        def is_write(arg: Argument) -> bool:
+            if arg.annotation is None:
+                return False
+            return arg.annotation.is_write
+
+        # Corresponds to torch._C._FunctionSchema.is_mutable
+        # See aten/src/ATen/core/function_schema.h (keep these in sync)
+        return any(is_write(a) for a in self.arguments.flat_all)
+
+    def schema_order_arguments(self) -> Iterator[Argument]:
+        return itertools.chain(
+            self.arguments.flat_positional,
+            self.arguments.flat_kwarg_only,
+            self.arguments.out,
+        )
+
+    decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
+
+    @staticmethod
+    def parse(func: str) -> FunctionSchema:
+        # We should probably get a proper parser here
+        decls = FunctionSchema.decl_re.findall(func)
+        assert len(decls) == 1, f"Invalid function schema: {func}"
+        ops, args, return_decl = decls[0]
+        name = OperatorName.parse(ops)
+        arguments = Arguments.parse(args)
+        returns = parse_returns(return_decl)
+        r = FunctionSchema(name=name, arguments=arguments, returns=returns)
+        assert str(r) == func, f"{str(r)} != {func}"
+        return r
+
+    def returns_are_aliased(self) -> bool:
+        # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
+        return any(
+            r
+            for r in self.returns
+            if r.annotation is not None and r.annotation.is_write
+        )
+
+    def __post_init__(self) -> None:
+        for arg, ret in zip(self.arguments.out, self.returns):
+            assert arg.annotation == ret.annotation, (
+                "Out arguments must have matching return Tensor; furthermore, "
+                "the ith-argument needs to correspond to the ith return"
+            )
+        # We also enforce that if you have any mutable, positional args, then they are not returned.
+        # This makes it easier to group these functions properly with their functional/out= counterparts.
+        for a in self.arguments.post_self_positional_mutable:
+            assert not any(
+                a.annotation == r.annotation for r in self.returns
+            ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
+        # Invariant: we expect out arguments to appear as keyword arguments in the schema.
+ # This means that all mutable returns should be aliased to a keyword argument + # (except for "self", which we explicitly don't treat as an out argument because of its use in methods) + # See Note [is_out_fn] + out_and_self = list(self.arguments.out) + [ + arg for arg in self.arguments.flat_positional if arg.name == "self" + ] + mutable_returns = [ + ret + for ret in self.returns + if ret.annotation is not None and ret.annotation.is_write + ] + immutable_returns = [ + ret + for ret in self.returns + if ret.annotation is None or not ret.annotation.is_write + ] + # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)", + # because: + # (1) It's more annoying to handle properly + # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple. + # Instead, we expect the (a!) argument to not be returned. + assert ( + len(mutable_returns) == 0 or len(immutable_returns) == 0 + ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}" + for ret in mutable_returns: + assert any(ret.annotation == arg.annotation for arg in out_and_self), ( + 'All mutable returns must be aliased either to a keyword argument, or to "self". ' + "Did you forget to mark an out argument as keyword-only?" + ) + if self.arguments.out: + # out= ops that return their mutable inputs are only really useful for method chaining. + # And method chaining is only really useful if the thing you're returning is a plain Tensor. + # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor, + # and all other types of out= op schemas should return void. + # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that. + if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out): + assert ( + len(self.returns) == 0 + ), "out= ops that accept tensor lists as out arguments " + "are expected to have no return type (since you can't do method chaining on them)" + else: + # mutable keyword arguments whose name has _scratch_ prefix are + # scratch tensors for memory planning and should not be returned + assert len( + [ + arg + for arg in self.arguments.out + if not arg.name.startswith("_scratch_") + ] + ) == len( + self.returns + ), "Must return as many arguments as there are out arguments, or no return at all" + + if self.name.name.inplace: + self_a = self.arguments.self_arg + assert ( + self_a + and self_a.argument.annotation + and self_a.argument.annotation.is_write + ) + if self_a.argument.type == BaseType(BaseTy.Tensor): + # All inplace ops with an ordinary `Tensor self` argument should return self, + # to allow for method chaining. + assert ( + len(self.returns) == 1 + and self.returns[0].annotation == self_a.argument.annotation + ) + else: + # You can't method chain on non-tensor self arguments though (like a List[Tensor]) + # so in all other cases we expect the return type to be none. + assert len(self.returns) == 0 + + if self.arguments.tensor_options is not None: + assert self.kind() == SchemaKind.functional, ( + "Found an operator that is not functional or out variant, but has tensor options arguments." + "This is not allowed- tensor options arguments are only allowed for factory functions." + f"schema: {str(self)}" + ) + if self.is_functional_fn(): + assert self.kind() == SchemaKind.functional, ( + "Found an operator that is not functional, but its overload contains the string 'functional'." 
+                " This is a special keyword in the codegen, please use a different overload name."
+                f" schema: {str(self)}"
+            )
+
+    def is_functional_fn(self) -> bool:
+        return "functional" in self.name.overload_name
+
+    def is_out_fn(self) -> bool:
+        # Note [is_out_fn]
+        #
+        # out functions are the variants which take an explicit out= argument
+        # to populate into. We need to know if a schema corresponds to an
+        # out function for several reasons:
+        #
+        #   - They codegen differently in C++ API
+        #       - codegen to at::add_out rather than at::add
+        #       - out argument is moved to front of C++ argument list
+        #
+        # out functions are DEFINED to be any function with a keyword-only
+        # argument that is mutable. In principle, this could lead to a
+        # false positive if you define a function that mutates a
+        # kwarg-only argument, but this isn't the "true" output of this
+        # function. A more robust definition that would work in this
+        # case would also look at:
+        #
+        #   - The output types. Out functions take in the arguments
+        #     they mutate and then return them again; this is sort
+        #     of "definitionally" what makes something an out function.
+        #     Historically, we DO check this for consistency.
+        #   - Correspondence with pure variant. An out function
+        #     should have a signature equivalent to its pure variant,
+        #     but just with extra kwargs for the output elements. This
+        #     is difficult to actually check for and historically
+        #     we only do this check in tools/
+        return bool(self.arguments.out)
+
+    def kind(self) -> SchemaKind:
+        """
+        What kind of schema is this? A functional schema is one
+        that returns a newly allocated output; an inplace schema
+        modifies the self argument inplace; an out schema writes
+        the result into an explicitly provided out argument.
+        """
+        is_out = bool(self.arguments.out)
+        is_scratch = bool(
+            [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")]
+        )
+        is_inplace = self.name.name.inplace
+        is_mutable = any(
+            a.annotation is not None and a.annotation.is_write
+            for a in self.arguments.post_self_positional
+        )
+        assert not (is_out and is_inplace)
+        # out= and inplace schemas can also have post_self_positional mutable args,
+        # but we give precedence to out= and inplace when deciding the schema kind.
+        # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
+        # to also worry about mutable post_self_positional arguments,
+        # but it seems like a much bigger lift to classify them as having a new schema kind.
+        # The number of ops that fit in this strange category is small enough that
+        # we can probably manually write code for them instead of forcing the codegen to handle them.
+        if is_inplace:
+            return SchemaKind.inplace
+        elif is_scratch:
+            assert (
+                is_out
+            ), "invariant: all scratch operators are expected to be out= operators too"
+            return SchemaKind.scratch
+        elif is_out:
+            assert (
+                not is_scratch
+            ), "We should not categorize a scratch op as an out variant. Check whether the order of the if statements is as expected!"
+            return SchemaKind.out
+        elif is_mutable:
+            return SchemaKind.mutable
+        else:
+            return SchemaKind.functional
+
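+    # (Added commentary, not part of the original file: schematic, simplified
+    # examples of the classification above.
+    #   "add(Tensor self, Tensor other) -> Tensor"                  -> functional
+    #   "add_(Tensor(a!) self, Tensor other) -> Tensor(a!)"         -> inplace
+    #   "add.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" -> out)
+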
+    # For every return:
+    # - If the return aliases an input, we return the input name
+    # - Otherwise, we return None.
+    # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
+    def aliased_return_names(self) -> list[str | None]:
+        outs: list[str | None] = []
+        for r in self.returns:
+            aliased_args = [
+                a
+                for a in self.arguments.flat_all
+                if a.annotation is not None and a.annotation == r.annotation
+            ]
+            if len(aliased_args) == 0:
+                outs.append(None)
+            elif len(aliased_args) == 1:
+                outs.append(aliased_args[0].name)
+            else:
+                aliased_names = ", ".join(a.name for a in aliased_args)
+                raise AssertionError(
+                    f"Found a return ({r.name}) that aliases multiple inputs ({aliased_names})"
+                )
+        return outs
+
+    def signature(
+        self,
+        *,
+        strip_default: bool = False,
+        strip_view_copy_name: bool = False,
+        keep_return_names: bool = False,
+    ) -> FunctionSchema:
+        """
+        Certain schemas are 'related', in that they are simply
+        inplace/out/functional versions of the same function. This method
+        factors these schemas into the "core" functional signature which
+        is equal across all versions.
+
+        Here is the normalization that happens to the schema to convert
+        it to a signature:
+        - The overload name is stripped (name is retained, since
+          it expresses semantic content about what the function does)
+        - Inplace is set False
+        - Out arguments are stripped
+        - Mutable post_self_positional args are converted to returns
+        - Mutability annotations are stripped (this is sound
+          because you cannot overload on mutability annotation)
+        - Return names are stripped since they are not overloadable and
+          some variants have return names but some do not
+        - TensorOptions are dropped
+          because out= variants of factory functions don't include them
+          (and we want to be able to pair up factory functions with their out variants)
+
+        Finally, we want to be able to pair up related "view" and their
+        corresponding "view_copy" operators. We do this by optionally
+        stripping the trailing "_copy" from the base name.
+
+        Example of a mutable op before and after:
+
+        f.func (Mutable operator):
+        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!)
zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 + + f.func (Corresponding functional operator): + _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950 + + f.func.signature() output: + _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950 + """ + + def strip_ret_annotation(r: Return) -> Return: + return Return( + name=r.name if keep_return_names else None, + type=r.type, + annotation=None, + ) + + base_name = self.name.name.base + if strip_view_copy_name: + if base_name.endswith("_copy"): + base_name = base_name.replace("_copy", "") + elif base_name.endswith("_scatter"): + base_name = base_name.replace("scatter", "inverse") + + # find mutable inputs that are not originally returned, and convert them to returns + returns_from_mutable_inputs = tuple( + # When we're grouping functions we strip the return names, + # but when we're generating the actual functional variants then we follow + # a convention for what to name the returns + Return( + name=f"{a.name}_out" if keep_return_names else None, + type=a.type, + annotation=None, + ) + for a in itertools.chain( + # Order is important here (otherwise e.g. inplace with mutable args + # and out= with mutable args won't have the same signature) + [self.arguments.self_arg.argument] + if self.arguments.self_arg is not None + else [], + self.arguments.out, + self.arguments.post_self_positional, + ) + if a.annotation is not None + and a.annotation.is_write + and not any(a.annotation == r.annotation for r in self.returns) + ) + original_returns = tuple(map(strip_ret_annotation, self.returns)) + # Ordering is important here. We expect the "mutable input" returns to come last. 
+ returns = original_returns + returns_from_mutable_inputs + + args_sig = self.arguments.signature(strip_default=strip_default) + # See Note [bernoulli.p schema] + if str(self.name) == "bernoulli.p": + args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5")) + + return FunctionSchema( + name=OperatorName( + name=BaseOperatorName( + base=base_name, + inplace=False, + dunder_method=self.name.name.dunder_method, + ), + overload_name="", # stripped + ), + arguments=args_sig, + returns=returns, + ) + + def view_signature(self) -> FunctionSchema: + return self.signature(strip_view_copy_name=True) + + def with_name(self, name: OperatorName) -> FunctionSchema: + return FunctionSchema( + name=name, + arguments=self.arguments, + returns=self.returns, + ) + + @property + def modifies_arguments(self) -> bool: + return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable] + + def has_symint(self) -> bool: + return self.arguments.has_symint_arg() + + def __str__(self) -> str: + all_arguments_str = str(self.arguments) + if len(self.returns) == 1: + returns = str(self.returns[0]) # omit parentheses + else: + returns = "(" + ", ".join(map(str, self.returns)) + ")" + return f"{self.name}({all_arguments_str}) -> {returns}" + + +# Here is the rest of the data model, described more briefly. + + +# Simplified version for what actually shows up in built-ins. +# Look at alias_info.h for expanded syntax. If you need the structure, +# you also need to make this structure recursive so it can be lined +# up with the type components too. For primitives this isn't really +# necessary +@dataclass(frozen=True) +class Annotation: + # Typically only has one element. Not actually a set so + # we can conveniently assume it is canonically ordered + alias_set: tuple[str, ...] + is_write: bool + alias_set_after: tuple[str, ...] + + @staticmethod + def parse(ann: str) -> Annotation: + # TODO: implement a proper parser if this gets more ugly + # Regex Explanation: + # Example: "a! -> a|b" + # Group #1: alias before optional '|', required. Matches the first + # character 'a' in the example + # Group #2: optional alias set after optional '|', matches empty string + # in the example + # Group #3: optional "is write" flag, matches '!' in the example. + # Group #4: optional section containing arrow, matches " -> a|b" in the + # example. + # Group #5: optional alias after set, supports wildcard, matches "a|b" + # in the example. + # Group #6: optional sub-section of alias after set, matches "|b" in the + # example. + m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann) + + assert m is not None, f"unrecognized alias annotation {ann}" + before_alias = m.group(1) + (m.group(2) if m.group(2) else "") + alias_set = tuple(before_alias.split("|")) + is_write = m.group(3) == "!" + assert not ( + is_write and len(alias_set) > 1 + ), f"alias set larger than 1 is not mutable, got {ann} instead." + after_set = tuple(m.group(5).split("|")) if m.group(5) else () + assert not ( + len(before_alias) > 1 and len(after_set) > 1 + ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead." + r = Annotation( + alias_set=alias_set, is_write=is_write, alias_set_after=after_set + ) + assert str(r) == ann, f"{r} != {ann}" + return r + + def __str__(self) -> str: + alias_set = "|".join(self.alias_set) + if self.is_write: + alias_set = f"{alias_set}!" 
+ alias_set_after = "|".join(self.alias_set_after) + if alias_set_after: + alias_set = f'{alias_set}{" -> "}{alias_set_after}' + return alias_set + + +# The base class for the type system. This is also loosely modeled +# off of jit_type.h, but we've simplified the hierarchy to focus +# in on the aspects of the type system that matter for code generation +# (for example, there's no SingleElementType subclass anymore). +# You never actually construct a Type; usually it's going to be one +# of the subclasses. If Python had ADTs this would be one! +@dataclass(frozen=True) +class Type: + @staticmethod + def parse(t: str) -> Type: + r = Type._parse(t) + assert str(r) == t, f"{r} != {t}" + return r + + @staticmethod + def _parse(t: str) -> Type: + m = re.match(r"^(.+)\?$", t) + if m is not None: + return OptionalType(Type.parse(m.group(1))) + m = re.match(r"^(.+)\[([0-9]+)?\]$", t) + if m is not None: + size = int(m.group(2)) if m.group(2) is not None else None + return ListType(elem=Type.parse(m.group(1)), size=size) + + # '__torch__.torch.classes.' is the prefix for custom class + m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t) + if m is not None: + return CustomClassType(m.group(1)) + try: + return BaseType(BaseTy[t]) + except KeyError as e: + raise RuntimeError(f"unrecognized type {t}") from e + + def __str__(self) -> str: + raise NotImplementedError + + # WARNING: These concepts are not very well-defined. For example, + # is "int?" nullable? How about "int?[]". They are defined + # so we can conveniently generate legacy Declarations.yaml but + # really we should probably just remove these at some point + + def is_base_ty_like(self, base_ty: BaseTy) -> bool: + raise NotImplementedError + + def is_tensor_like(self) -> bool: + return self.is_base_ty_like(BaseTy.Tensor) + + def is_generator_like(self) -> bool: + return self.is_base_ty_like(BaseTy.Generator) + + def is_symint_like(self) -> bool: + return self.is_base_ty_like(BaseTy.SymInt) + + def is_nullable(self) -> bool: + raise NotImplementedError + + def is_list_like(self) -> ListType | None: + raise NotImplementedError + + +# Base types are simple, atomic types with no further structure +class BaseTy(Enum): + Generator = auto() + ScalarType = auto() + Tensor = auto() + int = auto() + Dimname = auto() + DimVector = auto() + float = auto() + str = auto() + bool = auto() + Layout = auto() + Device = auto() + DeviceIndex = auto() + Scalar = auto() + MemoryFormat = auto() + QScheme = auto() + Storage = auto() + Stream = auto() + SymInt = auto() + SymBool = auto() + ConstQuantizerPtr = auto() # TODO: rename + GraphModule = auto() + + +@dataclass(frozen=True) +class BaseType(Type): + name: BaseTy + + def __str__(self) -> str: + return f"{self.name.name}" + + def is_base_ty_like(self, base_ty: BaseTy) -> bool: + return self.name == base_ty + + def is_nullable(self) -> bool: + return False + + def is_list_like(self) -> ListType | None: + return None + + def is_symint_like(self) -> bool: + return self.name == BaseTy.SymInt + + +# Optional types may be specified, or may also be validly given None +@dataclass(frozen=True) +class OptionalType(Type): + elem: Type + + def __str__(self) -> str: + return f"{self.elem}?" 
+ + def is_base_ty_like(self, base_ty: BaseTy) -> bool: + return self.elem.is_base_ty_like(base_ty) + + def is_symint_like(self) -> bool: + return self.elem.is_symint_like() + + def is_nullable(self) -> bool: + return True + + def is_list_like(self) -> ListType | None: + return self.elem.is_list_like() + + +# A type representing a PyTorch custom class +@dataclass(frozen=True) +class CustomClassType(Type): + class_name: str + + def __str__(self) -> str: + """ + Return the class name will prefix __torch__.torch.classes + """ + return f"__torch__.torch.classes.{self.class_name}" + + def is_base_ty_like(self, base_ty: BaseTy) -> bool: + return False + + def is_symint_like(self) -> bool: + return False + + def is_nullable(self) -> bool: + """ + Assume a custom class is not nullable. + """ + return False + + def is_list_like(self) -> ListType | None: + return None + + +# List types specify that we may have multiples of an element. We +# also support explicit sizes on list types, but these have +# some nontrivial semantics! (However, for C++ API purposes, explicit +# sizes are mostly erased from the type system.) +# +# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g., +# int[] elaborates differently than bool[3]! +@dataclass(frozen=True) +class ListType(Type): + elem: Type + size: int | None + + def __str__(self) -> str: + size = f"{self.size}" if self.size else "" + return f"{self.elem}[{size}]" + + def is_base_ty_like(self, base_ty: BaseTy) -> bool: + return self.elem.is_base_ty_like(base_ty) + + def is_symint_like(self) -> bool: + return self.elem.is_symint_like() + + def is_nullable(self) -> bool: + return self.elem.is_nullable() + + def is_list_like(self) -> ListType | None: + return self + + +@dataclass(frozen=True) +class Argument: + # NB: I didn't put kwarg_only as a boolean field here, unlike + # c10::Argument, so that printing works correctly + + name: str + type: Type + default: str | None + + # The semantics of the annotation field are a little strange. + # + # Alias annotations parametrize Tensors (since Tensors are the only things + # that can alias.) This motivates why I write Tensor(a!)? (and not, for + # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor, + # which may be optional (i.e., the alias annotation should bind first to + # Tensor, before the optional postfix annotation). + # + # However, despite being a property of Tensor, we (and c10::Argument) + # store the annotation at the top level of the Argument, rather than + # inside the embedded Tensor type. In the C++ version of this + # class, we then go through great lengths to mimic the type + # structure in the annotation structure so we can correlate + # annotations with types. + # + # Now, it turns out, in all applications in code generation, the + # structure of annotated types is very simple. So we just hard + # code it here. But if we ever do get anything more complex, this + # model will have to change! 
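+    #
+    # (Added commentary, not part of the original file: concretely,
+    # "Tensor(a!) self" parses to an Argument of type Tensor whose annotation
+    # has alias_set=("a",) and is_write=True, while "Tensor(a) self" is the
+    # non-mutating aliasing form.)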
+ annotation: Annotation | None + + @property + def alias_info(self) -> Annotation | None: + return self.annotation + + @staticmethod + def parse(arg: str) -> Argument: + name: str + default: str | None + assert " " in arg, f"illegal argument '{arg}'" + if "=" in arg: + assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'" + type_and_annot_and_name, default = arg.split("=") + type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1) + name_and_default = f"{name}={default}" + else: + type_and_annot, name_and_default = arg.rsplit(" ", 1) + name = name_and_default + default = None + # TODO: deduplicate annotation matching with Return + match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) + annotation: Annotation | None + if match: + # If you update this, make sure the __str__ still works too + assert match.group(2) in [ + "", + "?", + "[]", + ], "unrecognized alias analysis form with Tensor" + type_s = "Tensor" + match.group(2) + annotation = Annotation.parse(match.group(1)) + else: + type_s = type_and_annot + annotation = None + type = Type.parse(type_s) + r = Argument( + name=name, + type=type, + default=default, + annotation=annotation, + ) + assert str(r) == arg, f"{str(r)} != {arg}" + return r + + @property + def is_write(self) -> bool: + return self.annotation is not None and self.annotation.is_write + + def __str__(self) -> str: + type = f"{self.type}" + if self.annotation: + assert type in ["Tensor", "Tensor?", "Tensor[]"] + type = type.replace("Tensor", f"Tensor({self.annotation})") + if self.name is None: + return type + else: + mb_default = "" + if self.default: + mb_default = f"={self.default}" + return f"{type} {self.name}{mb_default}" + + +@dataclass(frozen=True) +class Return: + name: str | None + type: Type + annotation: Annotation | None + + @property + def alias_info(self) -> Annotation | None: + return self.annotation + + @staticmethod + def parse(arg: str) -> Return: + name: str | None + if " " in arg: + type_and_annot, name = arg.rsplit(" ", 1) + else: + type_and_annot = arg + name = None + match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) + annotation: Annotation | None + if match: + # If you update this, make sure the __str__ still works too + assert match.group(2) in [ + "", + "?", + "[]", + ], "unrecognized alias analysis form with Tensor" + type_s = "Tensor" + match.group(2) + annotation = Annotation.parse(match.group(1)) + else: + type_s = type_and_annot + annotation = None + type = Type.parse(type_s) + r = Return( + name=name, + type=type, + annotation=annotation, + ) + assert str(r) == arg, f"{str(r)} != {arg}" + return r + + @property + def is_write(self) -> bool: + return self.annotation is not None and self.annotation.is_write + + def __str__(self) -> str: + type = f"{self.type}" + if self.annotation: + assert type in ["Tensor", "Tensor?", "Tensor[]"] + type = type.replace("Tensor", f"Tensor({self.annotation})") + if self.name is None: + return type + else: + return f"{type} {self.name}" + + +# Represents the self argument for functions that may be methods +@dataclass(frozen=True) +class SelfArgument: + argument: Argument + + +# Bundle of arguments that represent a TensorOptions. 
This is mostly +# relevant for the public C++ API but we bake it into the core data +# model because other APIs often have to interact with it +@dataclass(frozen=True) +class TensorOptionsArguments: + dtype: Argument + layout: Argument + device: Argument + pin_memory: Argument + + def all(self) -> Sequence[Argument]: + return [self.dtype, self.layout, self.device, self.pin_memory] + + +@dataclass(frozen=True) +class Arguments: + # pre_self_positional is usually empty, but is notably non-empty + # for where.self, where the condition argument comes before the + # self argument + pre_self_positional: tuple[Argument, ...] + self_arg: SelfArgument | None + post_self_positional: tuple[Argument, ...] + + pre_tensor_options_kwarg_only: tuple[Argument, ...] + tensor_options: TensorOptionsArguments | None + # post_tensor_options is typically memory format, which should be + # part of tensor options but isn't right now, and is usually + # placed after the tensor options arguments + post_tensor_options_kwarg_only: tuple[Argument, ...] + + # Unlike in the previous codegen, we have factored out 'out' arguments + # in the canonical representation, removing them from kwarg + # arguments. This choice is justified by numerous downstream + # transformations which treat out arguments specially; additionally, + # you can see that canonicity is not violated! + out: tuple[Argument, ...] # these are also kwarg-only + + @property + def flat_non_out(self) -> Sequence[Argument]: + ret: list[Argument] = [] + ret.extend(self.flat_positional) + ret.extend(self.flat_kwarg_only) + return ret + + @property + def flat_positional(self) -> Sequence[Argument]: + ret: list[Argument] = [] + ret.extend(self.pre_self_positional) + if self.self_arg is not None: + ret.append(self.self_arg.argument) + ret.extend(self.post_self_positional) + return ret + + @property + def post_self_positional_mutable(self) -> Sequence[Argument]: + return [a for a in self.post_self_positional if a.is_write] + + # NB: doesn't contain out arguments + @property + def flat_kwarg_only(self) -> Sequence[Argument]: + ret: list[Argument] = [] + ret.extend(self.pre_tensor_options_kwarg_only) + if self.tensor_options is not None: + ret.extend(self.tensor_options.all()) + ret.extend(self.post_tensor_options_kwarg_only) + return ret + + @property + def flat_all(self) -> Sequence[Argument]: + ret: list[Argument] = [] + ret.extend(self.flat_positional) + ret.extend(self.flat_kwarg_only) + ret.extend(self.out) + return ret + + @property + def non_out( + self, + ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]: + ret: list[Argument | SelfArgument | TensorOptionsArguments] = [] + ret.extend(self.positional) + ret.extend(self.kwarg_only) + return ret + + @property + def positional(self) -> Sequence[Argument | SelfArgument]: + ret: list[Argument | SelfArgument] = [] + ret.extend(self.pre_self_positional) + if self.self_arg is not None: + ret.append(self.self_arg) + ret.extend(self.post_self_positional) + return ret + + @property + def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]: + ret: list[Argument | TensorOptionsArguments] = [] + ret.extend(self.pre_tensor_options_kwarg_only) + if self.tensor_options is not None: + ret.append(self.tensor_options) + ret.extend(self.post_tensor_options_kwarg_only) + return ret + + @property + def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]: + ret: list[Argument | SelfArgument | TensorOptionsArguments] = [] + ret.extend(self.positional) + ret.extend(self.kwarg_only) + 
        ret.extend(self.out)
+        return ret
+
+    def mutable_arg_names(self) -> list[str]:
+        return [
+            a.name
+            for a in self.flat_all
+            if a.annotation is not None and a.annotation.is_write
+        ]
+
+    def has_tensor_arg(self) -> bool:
+        return any(a.type.is_tensor_like() for a in self.flat_non_out)
+
+    def has_symint_arg(self) -> bool:
+        return any(a.type.is_symint_like() for a in self.flat_non_out)
+
+    def has_generator_arg(self) -> bool:
+        return any(a.type.is_generator_like() for a in self.flat_non_out)
+
+    def signature(self, *, strip_default: bool = False) -> Arguments:
+        # dataclasses.replace could be used here, but it is less
+        # type safe so for now I've opted to type everything out
+        def strip_arg_annotation(a: Argument) -> Argument:
+            return Argument(
+                name=a.name,
+                type=a.type,
+                default=a.default if not strip_default else None,
+                annotation=None,
+            )
+
+        return Arguments(
+            pre_self_positional=tuple(
+                map(strip_arg_annotation, self.pre_self_positional)
+            ),
+            self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
+            if self.self_arg is not None
+            else None,
+            post_self_positional=tuple(
+                map(strip_arg_annotation, self.post_self_positional)
+            ),
+            # Since TensorOptions are dropped, the post_tensor_options_kwargs are
+            # converted to pre_tensor_options_kwargs
+            pre_tensor_options_kwarg_only=tuple(
+                map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
+            )
+            + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
+            # TensorOptions are dropped in signature,
+            # so we can pair factory functions with their out= variants.
+            tensor_options=None,
+            post_tensor_options_kwarg_only=(),
+            # out arguments are dropped in signature
+            out=(),
+        )
+
+    def remove_self_annotation(self) -> Arguments:
+        assert self.self_arg is not None
+        return dataclasses.replace(
+            self,
+            self_arg=SelfArgument(
+                dataclasses.replace(self.self_arg.argument, annotation=None)
+            ),
+        )
+
+    def with_out_args(self, outs: list[Argument]) -> Arguments:
+        assert len(self.out) == 0
+        return dataclasses.replace(
+            self,
+            out=tuple(outs),
+        )
+
+    @staticmethod
+    def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
+        positional: list[Argument] = []
+        kwarg_only: list[Argument] = []
+        out: list[Argument] = []
+        arguments_acc = positional
+
+        # TODO: Use a real parser here; this will get bamboozled
+        # by signatures that contain things like std::array<bool, 2> (note the space)
+        for arg in args.split(", "):
+            if not arg:
+                continue
+            if arg == "*":
+                assert (
+                    arguments_acc is positional
+                ), "invalid syntax: kwarg-only specifier * can only occur once"
+                arguments_acc = kwarg_only
+                continue
+            parg = Argument.parse(arg)
+            # Currently, we rely directly on the invariant that there are NO
+            # kwarg-only mutating arguments. If you want to relax this,
+            # we will need a more semantic way of matching that takes
+            # into account return arguments. In that case, you will have
+            # to manage out computation a level up, in FunctionSchema. See Note
+            # [is_out_fn]
+            if parg.annotation is not None and parg.annotation.is_write:
+                if arguments_acc is positional:
+                    pass  # do nothing
+                elif arguments_acc is kwarg_only:
+                    arguments_acc = out
+            else:
+                assert arguments_acc is not out
+            arguments_acc.append(parg)
+
+        return positional, kwarg_only, out
+
+    @staticmethod
+    def parse(args: str) -> Arguments:
+        """
+        Input: 'int x, int y, int z'
+        """
+
+        # We do this in two phases. First we parse into three
+        # main categories: positional, kwarg_only, out.
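+        # (Added commentary, not part of the original file: e.g. for
+        # "Tensor self, Tensor other, *, Tensor(a!) out", phase one yields
+        # positional=[self, other], kwarg_only=[], out=[out].)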
+ # Then, we reparse positional and kwarg_only to separate + # out the self argument and tensor options arguments. + + positional, kwarg_only, out = Arguments._preparse(args) + + # Split self argument + self_ix = None + for i, a in enumerate(positional): + if a.name == "self": + self_ix = i + break + pre_self_positional: list[Argument] + self_arg: SelfArgument | None + post_self_positional: list[Argument] + if self_ix is not None: + pre_self_positional = positional[:self_ix] + self_arg = SelfArgument(positional[self_ix]) + post_self_positional = positional[self_ix + 1 :] + else: + pre_self_positional = [] + self_arg = None + post_self_positional = positional + + # Group tensor options arguments + pre_tensor_options_kwarg_only: list[Argument] = [] + tensor_options: TensorOptionsArguments | None = None + post_tensor_options_kwarg_only: list[Argument] = [] + kwarg_only_acc = pre_tensor_options_kwarg_only + + def pred(name: str, ty: Type) -> Callable[[Argument], bool]: + return lambda a: a.name == name and a.type in [ty, OptionalType(ty)] + + predicates = [ # order matters + pred("dtype", Type.parse("ScalarType")), + pred("layout", Type.parse("Layout")), + pred("device", Type.parse("Device")), + pred("pin_memory", Type.parse("bool")), + ] + + i = 0 + while i < len(kwarg_only): + # If there is enough space... + if i <= len(kwarg_only) - len(predicates): + # And the next len(predicates) arguments look like TensorOptions arguments + if all( + p(a) + for p, a in zip(predicates, kwarg_only[i : i + len(predicates)]) + ): + assert kwarg_only_acc is pre_tensor_options_kwarg_only + # Group them together as one argument + tensor_options = TensorOptionsArguments( + dtype=kwarg_only[i], + layout=kwarg_only[i + 1], + device=kwarg_only[i + 2], + pin_memory=kwarg_only[i + 3], + ) + i += len(predicates) + kwarg_only_acc = post_tensor_options_kwarg_only + continue + kwarg_only_acc.append(kwarg_only[i]) + i += 1 + + return Arguments( + pre_self_positional=tuple(pre_self_positional), + self_arg=self_arg, + post_self_positional=tuple(post_self_positional), + pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only), + tensor_options=tensor_options, + post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only), + out=tuple(out), + ) + + def __str__(self) -> str: + all_arguments: list[str] = [] + all_arguments.extend(map(str, self.flat_positional)) + if self.flat_kwarg_only or self.out: + all_arguments.append("*") + all_arguments.extend(map(str, self.flat_kwarg_only)) + all_arguments.extend(map(str, self.out)) + return ", ".join(all_arguments) + + def __post_init__(self) -> None: + # TODO: These invariants are weirdly asymmetric? + # TODO: Fancier types? + if self.self_arg is None: + assert not self.pre_self_positional + if self.tensor_options is None: + assert not self.post_tensor_options_kwarg_only + + # We don't allow any of the following to have argument annotations, + # to keep things simple. + mutable_pre_self_positionals = [ + a + for a in self.pre_self_positional + if a.annotation is not None and a.annotation.is_write + ] + assert ( + len(mutable_pre_self_positionals) == 0 + ), "mutable pre_self_positional arguments are not currently supported in the schema" + + +# Names that validly are __iXXX__ indicating inplace operations. 
+# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods +# NB: PyTorch hasn't actually implemented all of these +AUGMENTED_ASSIGNMENT_NAMES = [ + "add", + "sub", + "mul", + "div", + "mod", + "pow", + "lshift", + "rshift", + "and", + "xor", + "or", +] + + +# A BaseOperatorName is what we think of the operator name, without +# the overload name. Unusually, we don't represent this as just a +# string; instead, we directly represent a few important semantic +# bits of information we derive from the string: namely whether +# or not it's inplace (add_) and whether or not it's a double-underscore +# method (__add__) +@dataclass(frozen=True) +class BaseOperatorName: + base: str + inplace: bool + dunder_method: bool + # Note [Overload Ambiguity With Functional Variants] + # A handful of operators have both a "mutable" and a "functional" variant. + # (native_batch_norm is a good example, although this isn't the case today). + # For those operators, the mutable and functional variant take in the same set of + # arguments, but have different alias annotations. + # this makes it ambiguous when you try to resolve an OverloadPacket into an overload, + # given a set of input arguments. + # + # So instead of making the "functional" variant in this case a real overload, e.g: + # native_batch_norm (mutable variant) + # native_batch_norm.functional (functional variant) + # we make it a new base operator, + # native_batch_norm_functional (functional variant) + # + # In an ideal world, we would probably invert this so the operators were: + # native_batch_norm.mutable (mutable variant) + # native_batch_norm (functional variant) + # + # Doing that is BC-breaking though, so we're stuck with the above modeling. + functional_overload: bool = False + + @staticmethod + def parse(op: str) -> BaseOperatorName: + assert op != "" + assert not op.endswith("_out"), ( + "_out suffix is reserved and not permitted for operator names; " + "did you mean to specify an out overload name instead?" + ) + m = re.match(r"^__([^_]+)__$", op) + if m is not None: + dunder_method = True + base = m.group(1) + if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES): + inplace = True + base = base[1:] + else: + inplace = False + # temporary, this is not intrinsically true but + # has been historically true for dunder methods + # we support (but, if we ever got, say, __int__, this would + # be wrong!) + assert base[0] != "i" + else: + dunder_method = False + base = op + if base[-1] == "_": + inplace = True + base = base[:-1] + else: + inplace = False + + # See Note [Overload Ambiguity With Functional Variants] + functional_suffix = "_functional" + if base.endswith(functional_suffix): + functional_overload = True + base = base[: -len(functional_suffix)] + # This seems complicated and unnecessary, so banning dunder methods + # for now on ops that have a functional + mutable variant (like native_batch_norm). + assert not dunder_method and not inplace + else: + functional_overload = False + + r = BaseOperatorName( + base=base, + inplace=inplace, + dunder_method=dunder_method, + functional_overload=functional_overload, + ) + assert str(r) == op, f"{str(r)} != {op}" + return r + + def __str__(self) -> str: + if self.dunder_method: + i = "i" if self.inplace else "" + return f"__{i}{self.base}__" + else: + i = ( + "_" + if self.inplace + else "_functional" + if self.functional_overload + else "" + ) + return f"{self.base}{i}" + + +# Operator name is the base operator name along with the (typically not +# user visible) overload string. 
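+# (Added commentary, not part of the original file: e.g.
+# OperatorName.parse("add.Tensor") yields base name "add" with
+# overload_name="Tensor", while BaseOperatorName.parse("add_") sets
+# inplace=True; str() round-trips both back to the original text.)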
+@dataclass(frozen=True) +class OperatorName: + name: BaseOperatorName + overload_name: str + + @staticmethod + def parse(op_name: str) -> OperatorName: + if "." in op_name: + name, overload_name = op_name.split(".", 1) + else: + name = op_name + overload_name = "" + r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name) + assert str(r) == op_name, f"{str(r)} != {op_name}" + return r + + def __str__(self) -> str: + if self.overload_name: + return f"{self.name}.{self.overload_name}" + else: + return f"{self.name}" + + # NB: This must be synchronized with the naming scheme in + # aten/src/ATen/templates/Operators.h + # Given a function schema "aten::op.overload(...)", + # If there is no overload name, this returns f"{op}" + # If there is an overload name, this returns f"{op}_{overload}" + def unambiguous_name(self) -> str: + if self.overload_name: + return f"{self.name}_{self.overload_name}" + else: + return f"{self.name}" + + def remove_inplace(self) -> OperatorName: + return OperatorName( + name=BaseOperatorName( + base=self.name.base, + inplace=False, + dunder_method=self.name.dunder_method, + ), + overload_name=self.overload_name, + ) + + def with_overload(self, overload: str) -> OperatorName: + return OperatorName( + name=BaseOperatorName( + base=self.name.base, + inplace=False, + dunder_method=self.name.dunder_method, + ), + overload_name=overload, + ) + + +def gets_generated_out_inplace_wrapper( + f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex +) -> bool: + return ( + f.func.kind() is not SchemaKind.functional + and not b.has_kernel(f) + and b.has_kernel(g.functional) + ) + + +# NativeFunction objects that are views (f.is_view_op returns True) +# are added into a `NativeFunctionsViewGroup`, which we can use to +# easily access the generated (optional) view_copy NativeFunction. +# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup. +# See Note [Codegen'd {view}_copy Operators] +# +# One property of this representation is that in order for a view-like op to be part of +# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist. +# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op, +# but don't have corresponding aliasing `narrow.out` op. +# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup. +@dataclass(frozen=True) +class NativeFunctionsViewGroup: + view: NativeFunction + # Note: the {view}_copy operator is optional because we currently don't generate copy variants + # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views + # (we already get them "for free" through decomposition) + view_copy: NativeFunction | None + # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant. + view_inplace: NativeFunction | None + + def __post_init__(self) -> None: + assert self.view.is_view_op + if self.view_copy is None: + assert not gets_generated_view_copy(self.view), ( + f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs." + " The codegen expects you to add a corresponding operator to native_functions.yaml:" + f" {get_view_copy_name(self.view)!s}." + " See Note [view_copy NativeFunctions] for details." 
+ ) + else: + assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter")) + assert self.view.func.signature() == self.view_copy.func.signature( + strip_view_copy_name=True, + ) + assert "view_copy" in self.view_copy.tags, ( + f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects" + " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml." + " See Note [view_copy NativeFunction] for details." + ) + if self.view_inplace is not None: + assert self.view.func.signature() == self.view_inplace.func.signature() + + if self.view.has_composite_implicit_autograd_kernel: + if self.view_inplace is not None: + assert self.view_inplace.has_composite_implicit_autograd_kernel, ( + f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" + " both have CompositeImplicitAutograd kernels, or both not have composite kernels." + ) + if self.view.has_composite_implicit_autograd_nested_tensor_kernel: + if self.view_inplace is not None: + assert ( + self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel + ), ( + f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" + " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels." + ) + + def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]: + yield self.view + if self.view_inplace is not None: + yield self.view_inplace + if self.view_copy is not None and include_copy: + yield self.view_copy + + @property + def root_name(self) -> str: + return self.view.root_name + + @property + def composite(self) -> bool: + # We currently assert that the "group" is consistent. + # If the view op is composite, then its view_inplace op is too. + return self.view.has_composite_implicit_autograd_kernel + + +def gets_generated_view_copy(f: NativeFunction) -> bool: + # Only aliasing (view) operators get a copy variant. + if not f.is_view_op: + return False + # We don't need to bother generating copy variants for CompositeImplicitAutograd ops, + # because we can let them decompose into base view ops. + if f.has_composite_implicit_autograd_kernel: + return False + # We also don't need to generate copy variants for inplace views. + if "inplace_view" in f.tags: + return False + # Assume ops ending in _inverse have manually-defined copy variants + # (e.g. slice_inverse() has the copy variant slice_scatter()). + # We -could- probably generate these as well, but the codegen will be + # slightly different, and hand-writing these few kernels keeps codegen + # complexity lower. + if f.func.name.name.base.endswith("_inverse"): + return False + return True + + +# Given a NativeFunction that corresponds to a view op, +# returns the OperatorName of the corresponding "copy" variant of the op. +def get_view_copy_name(f: NativeFunction) -> OperatorName: + # Right now, when asking for a view op's corresponding "view_copy" name + # we assert for sanity that the op is allowed to have a generated view_copy variant. + # (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op). + # However, narrow_copy() already exists as an op directly in native_functions.yaml. + # I'm hardcoding narrow_copy here for now to maintain the assert, + # But we could also just get rid of the assert. 
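+    # (Added commentary, not part of the original file: e.g. for the view op
+    # `slice.Tensor`, this function would produce the OperatorName
+    # `slice_copy.Tensor`.)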
+    list_of_ops_with_explicit_view_copy_operators = ["narrow"]
+    if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators:
+        assert gets_generated_view_copy(f)
+
+    base_name = f"{f.func.name.name.base}_copy"
+    view_copy_name = OperatorName(
+        name=BaseOperatorName(
+            base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method
+        ),
+        overload_name=f.func.name.overload_name,
+    )
+    return view_copy_name
+
+
+# Helper functions for parsing argument lists (both inputs and returns)
+
+
+def parse_returns(return_decl: str) -> tuple[Return, ...]:
+    """
+    Input: '()'
+    Output: []
+    """
+    if return_decl == "()":
+        return ()
+    if return_decl[0] == "(" and return_decl[-1] == ")":
+        return_decl = return_decl[1:-1]
+    return tuple(Return.parse(arg) for arg in return_decl.split(", "))
+
+
+# A Precompute instance consists of a map from kernel argument name
+# to the list of Argument instances that should replace that
+# kernel argument in the impl function.
+@dataclass(frozen=True)
+class Precompute:
+    # A map from kernel argument name -> a list of precomputed
+    # elements that replaces/supersedes it.
+    replace: dict[str, list[Argument]]
+    # List of precomputed args added without replacement
+    add: list[Argument]
+
+    @staticmethod
+    def parse(src: object) -> Precompute:
+        assert isinstance(src, list)
+
+        # src is a list of strings of the format:
+        #   {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
+        #   [{add decl}[, {add decl}, ...]]
+        # The last line is optional and contains the precomputed parameters that are
+        # added without replacement.
+        # The other lines are parsed to get the names of which precomputed elements
+        # should replace which kernel arguments.
+        add_args = []
+        if " -> " not in src[-1]:
+            add_list = src[-1].split(",")
+            add_args = [Argument.parse(name.strip()) for name in add_list]
+            src = src[:-1]
+
+        replace = {}
+        for raw_replace_item in src:
+            assert isinstance(raw_replace_item, str)
+            assert " -> " in raw_replace_item, (
+                "precomputed parameters without replacement"
+                " are allowed only in the last line"
+            )
+
+            arg, with_list_raw = raw_replace_item.split(" -> ")
+            assert (
+                " " not in arg
+            ), f"illegal kernel param name '{arg}' in precomputed parameters"
+            with_list = with_list_raw.split(",")
+            with_list_args = [Argument.parse(name.strip()) for name in with_list]
+            replace[arg] = with_list_args
+
+        r = Precompute(replace=replace, add=add_args)
+        assert r.to_list() == src, "r.to_list() != src"
+        return r
+
+    def __post_init__(self) -> None:
+        # the template parameters are all uppercase, so if a precomputed
+        # argument name were also all uppercase the two would be ambiguous
+        for a in self.add:
+            assert a.name.upper() != a.name
+        for args in self.replace.values():
+            for a in args:
+                assert a.name.upper() != a.name
+
+    def to_list(self) -> list[str]:
+        replace_list = []
+        for kernel_param, replacement_params in self.replace.items():
+            replacements = ", ".join(str(param) for param in replacement_params)
+            replace_list.append(f"{kernel_param} -> {replacements}")
+
+        return replace_list
diff --git a/.venv/lib/python3.11/site-packages/torchgen/native_function_generation.py b/.venv/lib/python3.11/site-packages/torchgen/native_function_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44efab68426df661fa673173b69622dae666f82
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchgen/native_function_generation.py
@@ -0,0 +1,646 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from typing import Sequence
+
+import
torchgen.api.dispatcher as dispatcher +from torchgen.api.translate import translate +from torchgen.api.types import Binding, DispatcherSignature, Expr +from torchgen.context import with_native_function +from torchgen.model import ( + Annotation, + Argument, + BackendIndex, + BackendMetadata, + BaseOperatorName, + BaseTy, + BaseType, + DEFAULT_KERNEL_NAMESPACE, + DeviceCheckType, + DispatchKey, + FunctionSchema, + NativeFunction, + NativeFunctionsGroup, + OperatorName, + Return, + SchemaKind, + Variant, +) +from torchgen.utils import concatMap + + +# See Note: [Out ops with functional variants that don't get grouped properly] +OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ + # This has a functional variant, but it's currently marked private. + # This function should be marked private as well (*_backward ops aren't exposed to python anyway). + "adaptive_avg_pool3d_backward.grad_input", + # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly. + # Maybe we can kill this operator in favor of convolution_backward? + "_slow_conv2d_backward.grad_input", +] + + +# See Note: [Mutable ops that cannot get an out variant] +MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ + # should be out=? + "_cummax_helper", + # should be out=? + "_cummin_helper", +] + +# All of these operators don't have any tensor like returns +FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ + "_assert_async", # no return + "_assert_async.msg", # no return + "_cslt_sparse_mm_search", # returns an int + "_assert_scalar", # no return + "_dimI", # returns an int + "_dimV", # returns an int + "_has_same_storage_numel", # returns a boolean + "_linalg_check_errors", # no return + "_local_scalar_dense", # returns a Scalar + "_nested_tensor_from_mask_left_aligned", # returns a boolean + "_nnz", # returns an int + "_use_cudnn_ctc_loss", # returns a boolean + "_use_cudnn_ctc_loss.Tensor", # returns a boolean + "_validate_compressed_sparse_indices", # no return + "allclose", # returns a boolean + "dense_dim", # returns an int + "equal", # returns a boolean + "is_coalesced", # returns an boolean + "is_pinned", # returns a boolean + "is_same_size", # returns a boolean + "is_set_to", # returns a boolean + "q_per_channel_axis", # returns an int + "q_scale", # returns a float + "q_zero_point", # returns an int + "qscheme", # returns a QScheme + "record_stream", # no return + "sparse_dim", # returns an int + "sym_constrain_range", # no return + "sym_constrain_range_for_size", # no return + "_nested_tensor_storage_offsets", # returns a vector of ints + "_chunk_grad_outputs_efficient_attention", # returns a bool + "_fused_sdp_choice", # returns an int + "_print", # no return + "_sink_tokens", # no return + "_nested_get_ragged_idx", # returns an int +] + +INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ + # polygamma and polygamma.out both exist, but have a + # pre-self arg (while polygamma_ does not) + # We should either fix this schema so it can be grouped properly, + # or allow the codegen to generate new functional/out= NativeFunctions for this op + # (which would require changing its overload name to prevent overload ambiguity). + "polygamma_" +] + + +# Groups "similar" NativeFunctions together +# example add.Tensor, add_.Tensor, add.out +# "similar" NativeFunctions are all expected to have an identical `signature()`, +# But have differing SchemaKinds. 
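+# (Added commentary, not part of the original file: a minimal usage sketch.
+# If native_functions contains add.Tensor, add_.Tensor and add.out, they all
+# share one signature(), so the result maps that FunctionSchema to
+# {SchemaKind.functional: add.Tensor, SchemaKind.inplace: add_.Tensor,
+#  SchemaKind.out: add.out}.)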
+def pre_group_native_functions( + native_functions: Sequence[NativeFunction], +) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]: + pre_grouped_native_functions: dict[ + FunctionSchema, dict[SchemaKind, NativeFunction] + ] = defaultdict(dict) + for f in native_functions: + d = pre_grouped_native_functions[f.func.signature()] + assert f.func.kind() not in d + d[f.func.kind()] = f + return pre_grouped_native_functions + + +# Returns the out variant overload name given a base function overload name +def get_expected_out_variant_overload_name(overload_name: str | None) -> str: + return "out" if not overload_name else f"{overload_name}_out" + + +# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant +# Example before: +# _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) +# Example after: +# _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) +def self_to_out_signature(func: FunctionSchema) -> FunctionSchema: + # Generating an out= schema from an inplace schema. + assert func.kind() == SchemaKind.inplace + assert func.arguments.self_arg is not None + # The new out= schema has: + # - a new out argument with the same type as "func" (but with a mutable annotation) + # - The returns (if any) now alias the out= argument instead of "func" + # - an "out" overload name + return FunctionSchema( + name=func.name.remove_inplace().with_overload( + get_expected_out_variant_overload_name(func.name.overload_name) + ), + arguments=func.arguments.remove_self_annotation().with_out_args( + [ + Argument( + name="out", + type=func.arguments.self_arg.argument.type, + default=None, + annotation=func.arguments.self_arg.argument.annotation, + ) + ] + ), + returns=func.returns, + ) + + +# Helper function: given a functional FunctionSchema, generate its corresponding out= variant +# Example before: +# _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, +# bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor +# Example after: +# _to_copy._out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, +# Tensor(a!) out) -> Tensor(a!) +def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema: + # Generating an out= schema from a functional schema. + assert func.kind() == SchemaKind.functional + + new_returns, new_out_args = generate_out_args_from_schema(func) + # The new out= schema has: + # - one or more new out argument(s) with the same type as returns (but with a mutable annotation) + # - The returns now alias the out= arguments + # - an "_out" overload name + return FunctionSchema( + name=func.name.with_overload( + get_expected_out_variant_overload_name(func.name.overload_name) + ), + arguments=func.arguments.signature().with_out_args( + new_out_args, + ), + returns=tuple(new_returns), + ) + + +# Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations. +def generate_out_args_from_schema( + func: FunctionSchema, +) -> tuple[list[Return], list[Argument]]: + # More of a sanity check - our existing restrictions on schemas should enforce that + # mutable schema kinds never return their mutable arguments. 
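+    # (Added commentary, not part of the original file: e.g. for a functional
+    # schema returning "(Tensor, Tensor)", this produces out0/out1 arguments
+    # annotated with fresh alias letters, plus returns that alias them.)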
+ assert not any( + r.annotation is not None and r.annotation.is_write for r in func.returns + ) + + tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()] + assert len(tensorlike_rets) > 0 + + used_annotations = concatMap( + lambda a: [] if a.annotation is None else a.annotation.alias_set, + func.arguments.flat_all, + ) + valid_annotations = [ + x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations + ] + + all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns) + + new_out_args: list[Argument] = [] + # The end result of new_returns is that: + # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added. + # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any). + new_returns: list[Return] = [] + for i, r in enumerate(func.returns): + if r.type.is_tensor_like(): + new_out = Argument( + name="out" if len(func.returns) == 1 else f"out{i}", + type=r.type, + default=None, + annotation=Annotation.parse(f"{valid_annotations[i]}!"), + ) + new_out_args.append(new_out) + if all_rets_are_tensors: + # The convention for out= schemas is that they only return their out arguments + # if the return is a plain Tensor (or if it's a tuple of plain Tensors) + new_ret = Return( + name=None, type=new_out.type, annotation=new_out.annotation + ) + new_returns.append(new_ret) + else: + new_returns.append(r) + return new_returns, new_out_args + + +# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant +# Example before: +# _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 +# Example after: +# _fused_moving_avg_obs_fq_helper._out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950 +def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema: + # Generating an out= schema from a mutable schema. + assert func.kind() == SchemaKind.mutable + # The new out= schema has: + # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments + # (if the argument is a tensor then we also return it for method chaining, + # otherwise we return nothing) + # - an "out" overload name + # + # Note that: + # (1) This also means that we can *only* generate an out= variant from a mutable schema + # if the mutable schema has at least one tensor-like non-aliasing return. 
+    # (2) The generated out= variant still has mutable positional arguments,
+    #     but if necessary we could probably add another out= variant that also
+    #     functionalizes the mutable arguments (a functional_out variant)
+
+    new_returns, new_out_args = generate_out_args_from_schema(func)
+
+    return FunctionSchema(
+        name=func.name.remove_inplace().with_overload(
+            get_expected_out_variant_overload_name(func.name.overload_name)
+        ),
+        arguments=func.arguments.with_out_args(new_out_args),
+        returns=tuple(new_returns),
+    )
+
+
+# This function, given a function of one SchemaKind, as well as a target SchemaKind,
+# generates a new NativeFunction with the same properties, but using the target SchemaKind.
+# We only actually generate functions for either functional or out= SchemaKinds.
+# This function returns a tuple, with:
+# - The generated NativeFunction
+# - a dictionary of `BackendIndex` objects, describing which dispatch keys
+#   we will generate kernels for, for the new NativeFunction.
+#   Details are in the function, but we only generate composite kernels (in some cases) today.
+def generate_function(
+    f: NativeFunction, k: SchemaKind
+) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
+    from torchgen.api import cpp
+
+    if k == SchemaKind.functional:
+        assert f.func.kind() != SchemaKind.functional
+        # The new "functional" NativeFunction has:
+        # - any mutable arguments have been converted into (immutable) returns.
+        #   (if a mutable argument was not also a return, it gets converted to one)
+        # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
+        #   See Note [Overload Ambiguity With Functional Variants]
+        # The default grouping logic in signature() actually already does this,
+        # so we can piggy-back off it (but we still want return names)
+        func = f.func.signature(keep_return_names=True).with_name(
+            OperatorName(
+                name=BaseOperatorName(
+                    base=f.func.name.name.base,
+                    inplace=False,
+                    dunder_method=f.func.name.name.dunder_method,
+                    # See Note [Overload Ambiguity With Functional Variants]
+                    functional_overload=f.func.kind() == SchemaKind.mutable,
+                ),
+                overload_name=f.func.name.overload_name,
+            )
+        )
+    elif k == SchemaKind.out:
+        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
+        # but at least today, there is no good reason to actually use them.
+        # We'll generate a dispatcher entry for them, but won't actually register any kernels for them.
+        if f.func.kind() == SchemaKind.inplace:
+            func = self_to_out_signature(f.func)
+        elif f.func.kind() == SchemaKind.mutable:
+            func = mutable_to_out_signature(f.func)
+        elif f.func.kind() == SchemaKind.functional:
+            func = functional_to_out_signature(f.func)
+        else:
+            raise AssertionError(
+                "We only bother generating out= functions from either inplace or mutable or functional variants"
+            )
+    else:
+        raise AssertionError(
+            "We currently only generate either functional or out= NativeFunctions"
+        )
+
+    # Generated kernel naming convention for out: <op name>_<overload name>. The reason for this is to
+    # disambiguate operators with the same name but different overload name, e.g., `randn.names_out` and
+    # `randn.generator_with_names_out`.
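+    # (A hedged illustration: `randn.names_out` would get the kernel name
+    # `randn_names_out`, reading OperatorName.unambiguous_name() as joining the
+    # base name and the overload name with an underscore.)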
+    kernel_name = (
+        func.name.unambiguous_name()
+        if func.kind() == SchemaKind.out
+        else cpp.name(func)
+    )
+    if f.func.has_symint():
+        kernel_name += "_symint"
+    backend_metadata = {
+        DispatchKey.CompositeExplicitAutograd: {
+            func.name: BackendMetadata(
+                kernel=kernel_name,
+                structured=False,
+                cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
+            )
+        }
+    }
+    tags = {"generated"} | set(
+        f.tags & {"nondeterministic_seeded", "view_copy", "pt2_compliant_tag"}
+    )
+
+    return (
+        NativeFunction(
+            func=func,
+            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
+            # These generated fn's aren't meant to be user-friendly - don't generate methods.
+            variants={Variant.function},
+            structured=False,
+            structured_delegate=None,
+            structured_inherits=None,
+            precomputed=None,
+            autogen=[],
+            ufunc_inner_loop={},
+            manual_kernel_registration=False,
+            manual_cpp_binding=False,
+            python_module=None,
+            category_override=None,
+            device_guard=False,
+            device_check=DeviceCheckType.NoCheck,
+            loc=f.loc,
+            cpp_no_default_args=set(),
+            is_abstract=f.is_abstract,
+            has_composite_implicit_autograd_kernel=False,
+            has_composite_implicit_autograd_nested_tensor_kernel=False,
+            has_composite_explicit_autograd_kernel=True,
+            has_composite_explicit_autograd_non_functional_kernel=False,
+            # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
+            # which NativeFunction objects did not come directly from native_functions.yaml.
+            tags=tags,
+            namespace=f.namespace,
+        ),
+        backend_metadata,
+    )
+
+
+# This function is responsible for adding generated NativeFunctions which don't appear
+# explicitly in the codegen.
+# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
+# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
+# (Maybe we should make a friendly API for this)
+#
+# Note: this function *mutates* its two inputs,
+# adding the new NativeFunctions / BackendMetadata to them
+def add_generated_native_functions(
+    rs: list[NativeFunction],
+    indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
+) -> None:
+    # The main code for generating new NativeFunctions
+    # First we group NativeFunctions by schema kind,
+    # then we detect which ones are missing and generate them.
+    pre_grouped_native_functions = pre_group_native_functions(rs)
+    for d in pre_grouped_native_functions.values():
+        has_functional = SchemaKind.functional in d
+        has_inplace = SchemaKind.inplace in d
+        has_mutable = SchemaKind.mutable in d
+        has_out = SchemaKind.out in d
+
+        # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
+        # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
+        #     a simple functional variant that the functionalization pass can consume.
+        # (2) If an operator has an inplace or functional but no out= variant, we generate an out=
+        #     variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
+        #     while maintaining the constraint that the out= variant is "required".
+        if has_mutable or has_inplace or has_out or has_functional:
+            # Don't bother generating function trios for native functions that bypass the dispatcher.
+            are_manual = all(f.manual_cpp_binding for f in d.values())
+            # Don't bother generating functional + out= variants for view operators
+            # set_ is technically an inplace_view, but for now it is treated
+            # as a normal inplace op in the codegen
+            has_view_ops = any(
+                f.is_view_op and str(f.func.name.name) != "set_" for f in d.values()
+            )
+            # Don't generate the other variants for CompositeImplicitAutograd operators.
+            # We could probably do this, but the main benefit of generating the function triplets
+            # is for transforms that need them, and transforms don't need to act directly
+            # on CompositeImplicitAutograd operators (since we let them decompose).
+            are_composite_implicit = all(
+                f.has_composite_implicit_autograd_kernel for f in d.values()
+            )
+            if are_manual or has_view_ops or are_composite_implicit:
+                continue
+            if has_out and len(d.values()) == 1:
+                # Note: [Out ops with functional variants that don't get grouped properly]
+                # In theory we could validly have an out= operator in native_functions.yaml
+                # that has no other variants.
+                # But today, all of the operators where that's the case actually do have
+                # functional variants, that we are just unable to pair up properly.
+                # I think banning this altogether is probably safer
+                # (you can always add a functional variant yourself if you want to add a new out= operator).
+                #
+                # We should probably fix the existing cases; this check is to prevent us from adding more over time.
+                if (
+                    str(d[SchemaKind.out].func.name)
+                    not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+                ):
+                    raise AssertionError(
+                        f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
+                    )
+                continue
+
+            # Some inplace ops that have problematic schemas (that we should fix), which prevent us
+            # from generating out= and functional variants
+            if (
+                has_inplace
+                and str(d[SchemaKind.inplace].func.name)
+                in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+            ):
+                continue
+
+            base_fn = (
+                d[SchemaKind.inplace]
+                if has_inplace
+                else d[SchemaKind.mutable]
+                if has_mutable
+                else d[SchemaKind.out]
+                if has_out
+                else d[SchemaKind.functional]
+            )
+
+            # Note: [Mutable ops that cannot get an out variant]
+            # We can only generate an out= variant if either:
+            # - the original function has tensor-like returns (since we can convert them to out kwargs)
+            # - or it's inplace (since we can convert `self` to an out kwarg)
+            # There are only two functions that don't fit these criteria today, though,
+            # and they both look like they should be fixed to be out= variants,
+            # so it feels safer to ban this schema altogether
+            base_fn_valid = base_fn.func.kind() == SchemaKind.inplace or any(
+                r.type.is_tensor_like() for r in base_fn.func.returns
+            )
+            # Note: [Loosen the assertion that all functional should have out variant]
+            # By design all functional operators should have out= variants. The needs_out check
+            # is loosening this requirement, changing it to only generate an out= variant if there's
+            # an `autogen` block in the native function; in the long run it should be removed.
+            # FIXME: Remove this after figuring out CI job failures related to min, max, mean
+            needs_out = any("out" in str(op_name) for op_name in base_fn.autogen)
+            gets_out_variant = not has_out and base_fn_valid and needs_out
+            if not has_out and not base_fn_valid:
+                if (
+                    str(base_fn.func.name)
+                    not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+                    and str(base_fn.func.name)
+                    not in FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+                ):
+                    raise AssertionError(
+                        f"""Found an operator that we could not generate an out= variant for: {str(base_fn.func)}.
+These operators don't have a tensor-like return, which makes it difficult to generate a proper out= variant. If an
+out= variant is not needed, please add the function name to the FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT list."""
+                    )
+
+            # Generate an out= variant
+            if gets_out_variant:
+                fn, metadata = generate_function(base_fn, SchemaKind.out)
+                d[SchemaKind.out] = fn
+                BackendIndex.grow_index(indices, metadata)
+                rs.append(fn)
+
+            # Generate a functional variant, but only do it if the operator got an out= variant
+            # (Functional variants are only useful if we can group up the variants,
+            # which we can only do if they have an out= variant)
+            if not has_functional and (has_out or gets_out_variant):
+                fn, metadata = generate_function(base_fn, SchemaKind.functional)
+                d[SchemaKind.functional] = fn
+                BackendIndex.grow_index(indices, metadata)
+                rs.append(fn)
+
+
+def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
+    assert len(rets) == len(names)
+    if len(rets) == 0:
+        return ""
+    elif len(rets) == 1:
+        return f"return {names[0]};"
+    else:
+        return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
+
+
+# Given a function, and the name of a variable corresponding to the output of that function,
+# gather up all of the individual returns that are not aliased
+def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
+    aliased_rets = func.aliased_return_names()
+    non_aliased_names = []
+    is_out_var_a_tuple = len(func.returns) > 1
+    for i, r in enumerate(aliased_rets):
+        if r is None:
+            non_aliased_names.append(
+                f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
+            )
+    return non_aliased_names
+
+
+# Generates functional kernels in terms of their inplace/mutable counterparts.
+# We only do this for "generated" NativeFunctions
+@with_native_function
+def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
+    # We should only be generating these for code-generated NativeFunctions
+    if "generated" not in g.functional.tags:
+        return None
+    # And we always write the kernel for a generated op in terms of a non-generated op.
+    if g.inplace is not None and "generated" not in g.inplace.tags:
+        target_f = g.inplace
+    elif g.mutable is not None and "generated" not in g.mutable.tags:
+        target_f = g.mutable
+    else:
+        # We should be guaranteed to have a valid inplace/mutable variant to call into.
+        # See Note: [Mutable Ops Not Using Functionalization]
+        raise AssertionError(str(g.functional.func))
+
+    sig = DispatcherSignature(g.functional.func)
+    target_sig = DispatcherSignature(target_f.func)
+
+    context: list[Binding | Expr] = []
+    clone_mutable_inputs = []
+    cloned_return_names = []
+    # We can't just directly pass all of the arguments from the functional op into the mutating op.
+    # We need to check for which inputs to the mutating operator are mutable,
+    # and clone those inputs first.
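+    # A hedged sketch of the C++ this emits for a hypothetical generated functional
+    # op `foo` whose non-generated counterpart is the inplace `foo_` (names are
+    # illustrative only, not produced from any real schema):
+    #
+    #   at::Tensor foo(const at::Tensor & self) {
+    #     auto self_clone = clone_arg(self);
+    #     at::_ops::foo_::call(self_clone);
+    #     return self_clone;
+    #   }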
+ for a_curr, a_tgt in zip( + dispatcher.jit_arguments(g.functional.func), + dispatcher.jit_arguments(target_f.func), + ): + if a_tgt.annotation is not None and a_tgt.annotation.is_write: + clone_mutable_inputs.append( + f"auto {a_curr.name}_clone = clone_arg({a_curr.name});" + ) + context.append( + Expr( + expr=f"{a_curr.name}_clone", + type=dispatcher.argument_type(a_curr, binds=a_curr.name), + ) + ) + # Invariant: mutable arguments on the inner mutable op are always returns on the functional op. + cloned_return_names.append(f"{a_curr.name}_clone") + else: + context.append(dispatcher.argument(a_curr)) + exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())]) + + out_name = "output" + maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else "" + inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name) + ret_str = return_str( + g.functional.func.returns, inner_return_names + cloned_return_names + ) + + clone_mutable_inputs_str = "\n".join(clone_mutable_inputs) + return f""" +{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{ + {clone_mutable_inputs_str} + {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs}); + {ret_str} +}} +""" + + +# Generates out= kernels in terms of their functional counterparts. +# We only do this for "generated" NativeFunctions +@with_native_function +def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None: + # We should only be generating these for code-generated NativeFunctions + if "generated" not in g.out.tags: + return None + # And we always write the kernel for the out= op in terms of the functional. + # Note that the functional op might have also been generated, but we don't have to + # worry about cycles, because the generated functional kernels are always implemented + # in terms of non-generated kernels (see gen_composite_functional_kernel). + + sig = DispatcherSignature(g.out.func) + target_sig = DispatcherSignature(g.functional.func) + + exprs = ", ".join( + [e.expr for e in translate(sig.arguments(), target_sig.arguments())] + ) + + copy_outs = [] + out_name = "tmp_output" + for i, out_arg in enumerate(g.out.func.arguments.out): + functional_return_name = ( + out_name + if len(g.functional.func.returns) == 1 + else f"std::get<{i}>({out_name})" + ) + copy_outs.append( + f"""\ + resize_out_helper({out_arg.name}, {functional_return_name}); + copy_arg({out_arg.name}, {functional_return_name});""" + ) + + rets = [] + # For each return arg in the calling (out=) operator, + # If it corresponds to an aliased input, return the input. + # Otherwise, return the corresponding output from calling the functional operator. 
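+    # A hedged sketch of the emitted out= kernel for a hypothetical op `foo`
+    # (names illustrative only):
+    #
+    #   at::Tensor & foo_out(const at::Tensor & self, at::Tensor & out) {
+    #     auto tmp_output = at::_ops::foo::call(self);
+    #     resize_out_helper(out, tmp_output);
+    #     copy_arg(out, tmp_output);
+    #     return out;
+    #   }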
+ for i, ret_name in enumerate(g.out.func.aliased_return_names()): + if ret_name is not None: + rets.append(ret_name) + else: + functional_return_name = ( + out_name + if len(g.functional.func.returns) == 1 + else f"std::get<{i}>({out_name})" + ) + rets.append(functional_return_name) + + copy_outs_str = "\n".join(copy_outs) + + # Kernel name needs to follow the naming convention defined in `generate_function()` + return f""" +{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{ + auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs}); + {copy_outs_str} + {return_str(g.out.func.returns, rets)} +}} +""" diff --git a/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__init__.py b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d43c79401e50661f2db8cd2d65119c73e156d062 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8aa0ad9fd0efc4cbb575fb2db7ae62108c5f53d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57ee751f17da9645ec5d361867f44ba23cdab169 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py new file mode 100644 index 0000000000000000000000000000000000000000..362ce427d508ca7885803e013ef3ac4640314a1c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import os +from enum import Enum +from operator import itemgetter +from pathlib import Path +from typing import Any + +import torch +from torch.jit.generate_bytecode import generate_upgraders_bytecode +from torchgen.code_template import CodeTemplate +from torchgen.operator_versions.gen_mobile_upgraders_constant import ( + MOBILE_UPGRADERS_HEADER_DESCRIPTION, +) + + +class ByteCode(Enum): + instructions = 1 + constants = 2 + types = 3 + operators = 4 + register_size = 5 + + +EXCLUDED_OP_SET = [ + "aten::full.names", + "aten::full.out", + "aten::full", +] + +EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"] + 
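+# The CodeTemplate fragments below are assembled bottom-up into upgrader_mobile.cpp:
+# per-upgrader instruction/constant/type/operator lists feed ONE_UPGRADER_FUNCTION,
+# which feeds ONE_UPGRADER_SRC, which feeds UPGRADER_CPP_SRC.
+# A hedged illustration of one rendered fragment (the values are made up):
+#   ONE_INSTRUCTION.substitute(operator_name="STOREN", X=1, N=7)
+#   # -> "\n    Instruction{OpCode::STOREN, 1, 7},"
+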
+ONE_INSTRUCTION = CodeTemplate(
+    """
+    Instruction{OpCode::${operator_name}, ${X}, ${N}},"""
+)
+
+INSTRUCTION_LIST = CodeTemplate(
+    """std::vector<Instruction>({
+        ${instruction_list}
+    }), // instructions list"""
+)
+
+ONE_CONSTANT = CodeTemplate(
+    """
+    c10::IValue(${constant}),"""
+)
+
+CONSTANT_LIST = CodeTemplate(
+    """std::vector<c10::IValue>({
+        ${constant_list}
+    }), // constants list"""
+)
+
+CONSTANTS_LIST_EMPTY = """std::vector<c10::IValue>(), // constants list"""
+
+ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""")
+
+TYPE_LIST = CodeTemplate(
+    """std::vector<c10::TypePtr>({
+        ${type_list}
+    }), // types list"""
+)
+
+TYPE_LIST_EMPTY = """std::vector<c10::TypePtr>(), // types list"""
+
+ONE_OPERATOR_STRING = CodeTemplate(
+    """
+    OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),"""
+)
+
+OPERATOR_STRING_LIST = CodeTemplate(
+    """
+    std::vector<OperatorString>({
+        ${operator_string_list}
+    }), // operators list"""
+)
+
+ONE_UPGRADER_FUNCTION = CodeTemplate(
+    """
+    mobile::Function::registerFunc(
+        "${upgrader_name}",
+        ${instruction_list},
+        ${constant_list},
+        ${type_list},
+        ${register_size}
+    )"""
+)
+
+ONE_UPGRADER_SRC = CodeTemplate(
+    """
+    ByteCodeFunctionWithOperator({
+        ${bytecode_function},
+        ${operator_string_list}
+    }),"""
+)
+
+
+ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate(
+    """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})"""
+)  # noqa: E501
+
+ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate(
+    """
+    {std::string("${operator_name}"),
+        std::vector<Upgrader>({
+            ${upgrader_list_in_version_map}
+        })},"""
+)
+
+
+OPERATOR_VERSION_MAP = CodeTemplate(
+    """
+const std::unordered_map<std::string, std::vector<Upgrader>>
+getOperatorVersionMapForMobile() {
+  static std::unordered_map<std::string, std::vector<Upgrader>>
+        operatorVersionMapForMobile({
+            ${operator_list_in_version_map}
+      });
+  return operatorVersionMapForMobile;
+}
+"""
+)
+
+
+UPGRADER_CPP_SRC = CodeTemplate(
+    MOBILE_UPGRADERS_HEADER_DESCRIPTION
+    + """
+#include <caffe2/serialize/versions.h>
+#include <torch/csrc/jit/mobile/upgrader_mobile.h>
+
+namespace c10 {
+TypePtr parseType(const std::string& pythonStr);
+} // namespace c10
+
+namespace torch {
+namespace jit {
+
+// clang-format off
+
+// From operator_versions_map
+${operator_version_map}
+
+const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
+  auto generate_upgrader_bytecode_list = []() {
+    std::vector<ByteCodeFunctionWithOperator> upgrader_function_list({
+      ${upgrader_bytecode}
+    });
+    for (const auto& upgrader_function : upgrader_function_list) {
+      for (const auto& op : upgrader_function.operators) {
+        upgrader_function.function.append_operator(
+            op.name,
+            op.overload_name,
+            op.num_specified_args);
+      }
+    }
+    return upgrader_function_list;
+  };
+  static std::vector<ByteCodeFunctionWithOperator> upgraderBytecodeList =
+      generate_upgrader_bytecode_list();
+  return upgraderBytecodeList;
+}
+
+// clang-format on
+
+} // namespace jit
+} // namespace torch
+"""
+)
+
+UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp"
+
+UPGRADER_ELEMENT = CodeTemplate(
+    """\
+Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}),
+"""
+)
+
+PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
+    """\
+{
+  std::string(${operator_name}),
+  std::vector<Upgrader>({${upgrader_list}});
+}
+"""
+)
+
+
+def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
+    instruction_list_part = []
+    for instruction in instruction_list_from_yaml:
+        instruction_list_part.append(
+            ONE_INSTRUCTION.substitute(
+                operator_name=instruction[0],
+                X=instruction[1],
+                N=instruction[2],
+            )
+        )
+    return INSTRUCTION_LIST.substitute(
+        instruction_list="".join(instruction_list_part).lstrip("\n")
+    )
+
+
+def construct_constants(constants_list_from_yaml: list[Any]) -> str:
+    constants_list_part = []
+    for constant_from_yaml in constants_list_from_yaml:
+        convert_constant = None
+        if isinstance(constant_from_yaml, str):
+            # Add quotes if it's a string
+            convert_constant = f'"{constant_from_yaml}"'
+        elif isinstance(constant_from_yaml, bool):
+            convert_constant = "true" if constant_from_yaml else "false"
+        elif constant_from_yaml is None:
+            convert_constant = ""
+        elif isinstance(constant_from_yaml, int):
+            convert_constant = str(constant_from_yaml)
+        else:
+            raise ValueError(
+                f"The type of {constant_from_yaml} is {type(constant_from_yaml)}. "
+                "Please add handling for it in the construct_constants function in gen_mobile_upgraders.py."
+            )
+        constants_list_part.append(ONE_CONSTANT.substitute(constant=convert_constant))
+    if len(constants_list_part) == 0:
+        return CONSTANTS_LIST_EMPTY
+    return CONSTANT_LIST.substitute(
+        constant_list="".join(constants_list_part).lstrip("\n")
+    )
+
+
+def construct_operators(operator_list_from_yaml: list[Any]) -> str:
+    operator_list_part = []
+    for operator in operator_list_from_yaml:
+        operator_list_part.append(
+            ONE_OPERATOR_STRING.substitute(
+                operator_name=operator[0],
+                overload_name=operator[1],
+                num_of_args=operator[2],
+            )
+        )
+    return OPERATOR_STRING_LIST.substitute(
+        operator_string_list="".join(operator_list_part).lstrip("\n")
+    )
+
+
+def construct_types(types_tr_list_from_yaml: list[Any]) -> str:
+    types_tr_list_part = []
+    for types_tr in types_tr_list_from_yaml:
+        types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr))
+    if len(types_tr_list_part) == 0:
+        return TYPE_LIST_EMPTY
+    return TYPE_LIST.substitute(type_list="".join(types_tr_list_part).lstrip("\n"))
+
+
+def construct_register_size(register_size_from_yaml: int) -> str:
+    if not isinstance(register_size_from_yaml, int):
+        raise ValueError(
+            f"Input register size is {register_size_from_yaml} and "
+            f"its type is {type(register_size_from_yaml)}. An int type is expected."
+ ) + return str(register_size_from_yaml) + + +def construct_version_maps( + upgrader_bytecode_function_to_index_map: dict[str, Any] +) -> str: + version_map = torch._C._get_operator_version_map() + sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] + sorted_version_map = dict(sorted_version_map_) + + operator_list_in_version_map_part = [] + for op_name in sorted_version_map: + upgraders_in_version_map_part = [] + # TODO: remove the skip after these two operators schemas are fixed + if op_name in EXCLUDED_OP_SET: + continue + upgrader_ranges = torch._C._get_upgrader_ranges(op_name) + upgrader_entries = sorted_version_map[op_name] + assert len(upgrader_ranges) == len(upgrader_entries) + for idx, upgrader_entry in enumerate(upgrader_entries): + upgrader_name = upgrader_entry.upgrader_name + bytecode_function_index = upgrader_bytecode_function_to_index_map[ + upgrader_name + ] + upgraders_in_version_map_part.append( + ONE_UPGRADER_IN_VERSION_MAP.substitute( + upgrader_min_version=upgrader_ranges[idx].min_version, + upgrader_max_version=upgrader_ranges[idx].max_version, + upgrader_name=upgrader_name, + bytecode_func_index=bytecode_function_index, + ) + ) + operator_list_in_version_map_part.append( + ONE_OPERATOR_IN_VERSION_MAP.substitute( + operator_name=op_name, + upgrader_list_in_version_map="".join(upgraders_in_version_map_part), + ) + ) + return OPERATOR_VERSION_MAP.substitute( + operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip( + "\n" + ) + ) + + +def get_upgrader_bytecode_function_to_index_map( + upgrader_dict: list[dict[str, Any]] +) -> dict[str, Any]: + upgrader_bytecode_function_to_index_map = {} + index = 0 + for upgrader_bytecode in upgrader_dict: + for upgrader_name in upgrader_bytecode.keys(): + if upgrader_name in EXCLUE_UPGRADER_SET: + continue + upgrader_bytecode_function_to_index_map[upgrader_name] = index + index += 1 + return upgrader_bytecode_function_to_index_map + + +def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None: + body_parts = [] + upgrader_bytecode_function_to_index_map = ( + get_upgrader_bytecode_function_to_index_map(upgrader_dict) + ) + version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map) + all_upgrader_src_string = [] + for upgrader_bytecode in upgrader_dict: + for upgrader_name, bytecode in upgrader_bytecode.items(): + # TODO: remove the skip after these two operators schemas are fixed + if upgrader_name in EXCLUE_UPGRADER_SET: + continue + instruction_list_str = "" + constant_list_str = "" + type_list_str = "" + register_size_str = "" + operator_list_str = "" + for table_name, contents in bytecode.items(): + element = ByteCode[table_name] + body_string = "" + if element is ByteCode.instructions: + instruction_list_str = construct_instruction(contents) + elif element is ByteCode.constants: + constant_list_str = construct_constants(contents) + elif element is ByteCode.operators: + operator_list_str = construct_operators(contents) + elif element is ByteCode.types: + type_list_str = construct_types(contents) + elif element is ByteCode.register_size: + register_size_str = construct_register_size(contents) + + one_upgrader_function_string = ONE_UPGRADER_FUNCTION.substitute( + upgrader_name=upgrader_name, + instruction_list=instruction_list_str, + constant_list=constant_list_str, + type_list=type_list_str, + register_size=register_size_str, + ) + one_upgrader_src_string = ONE_UPGRADER_SRC.substitute( + 
bytecode_function=one_upgrader_function_string.lstrip("\n"),
+                operator_string_list=operator_list_str.lstrip("\n"),
+            )
+            all_upgrader_src_string.append(one_upgrader_src_string)
+
+    upgrader_file_content = UPGRADER_CPP_SRC.substitute(
+        operator_version_map=version_map_src,
+        upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"),
+    )
+    body_parts.append(upgrader_file_content)
+    print("writing file to : ", cpp_path + "/" + UPGRADER_MOBILE_FILE_NAME)
+    with open(os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME), "wb") as out_file:
+        final_output = "".join(body_parts)
+        out_file.write(final_output.encode("utf-8"))
+
+
+def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    sorted_upgrader_list = sorted(
+        upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader))
+    )
+    return sorted_upgrader_list
+
+
+def main() -> None:
+    upgrader_list = generate_upgraders_bytecode()
+    sorted_upgrader_list = sort_upgrader(upgrader_list)
+    for up in sorted_upgrader_list:
+        print("after sort upgrader : ", next(iter(up)))
+
+    pytorch_dir = Path(__file__).resolve().parents[2]
+    upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "mobile"
+    write_cpp(str(upgrader_path), sorted_upgrader_list)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.venv/lib/python3.11/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py
new file mode 100644
index 0000000000000000000000000000000000000000..04b5ad887e54153115eeca7b6686d7c2de8dfc06
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py
@@ -0,0 +1,7 @@
+MOBILE_UPGRADERS_HEADER_DESCRIPTION = """/**
+ * @generated
+ * This is an auto-generated file. Please do not modify it by hand.
+ * To re-generate, please run: + * cd ~/pytorch && python torchgen/operator_versions/gen_mobile_upgraders.py + */ +""" diff --git a/.venv/lib/python3.11/site-packages/torchgen/selective_build/__init__.py b/.venv/lib/python3.11/site-packages/torchgen/selective_build/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9878f2be6c9ff2ba73f394715a2fdf7d71d16c90 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1bb8f5db17685f3270aa24650521dd5a171b02a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/selector.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/selector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0433c9441719a8da1847b41b0f19ab3c76b83034 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/selective_build/__pycache__/selector.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/selective_build/operator.py b/.venv/lib/python3.11/site-packages/torchgen/selective_build/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb92dfc09e28c7c98ab7230af362c363a30d621 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/selective_build/operator.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +# This class holds information about a single operator used to determine +# the outcome of a selective/custom PyTorch build that doesn't include +# registration code for all the supported operators. This is done to +# reduce the size of the generated binary so that it can be deployed in +# situations where binary size comes at a premium. +# +@dataclass(frozen=True) +class SelectiveBuildOperator: + # The name of the operator. This includes the aten::, etc... prefix + # The operator name may or may not have the overload name. If this + # operator name does not specify an overload name, the way to determine + # if this entry refers to the family of operators with this base name + # or just the operator with this name is to look at the value of the + # 'include_all_overloads' flag in this class. + name: str + + # True if this is a root operator (i.e. called directly from a + # TorchScript model, etc...). An operator is considered to be a + # root operator if it is called directly from any one of the models + # that this instance of the pytorch library was built for. Hence, it + # may not be a root operator in all of the models that are used in + # this instance of the pytorch library. + is_root_operator: bool + + # Is this operator used for on-device training? 
If True, then we need to + # use the information to generate code in VariableType_N.cpp for registration + # of training related operators. Again, this is True if this operator + # is used for training in one or more models used by this instance of the + # pytorch library. + is_used_for_training: bool + + # If True, it indicates that this operator instance (object) refers to an + # operator without the overload name and should apply to all overloads + # which have this operator name as the base name. This flag is applicable + # only for objects that have operator names without a DOT (period) character + # in them. + # + # Note: This flag is a temporary workaround to grandfather in the current + # static selective (custom) build mechanism, which largely ignores overload + # names when determining whether to select operators for registration + # purposes. + include_all_overloads: bool + + # Debug Information at the operator level + _debug_info: tuple[str, ...] | None + + @staticmethod + def from_yaml_dict( + op_name: str, op_info: dict[str, object] + ) -> SelectiveBuildOperator: + allowed_keys = { + "name", + "is_root_operator", + "is_used_for_training", + "include_all_overloads", + "debug_info", + } + + if len(set(op_info.keys()) - allowed_keys) > 0: + raise Exception( # noqa: TRY002 + "Got unexpected top level keys: {}".format( + ",".join(set(op_info.keys()) - allowed_keys), + ) + ) + + if "name" in op_info: + assert op_name == op_info["name"] + + is_root_operator = op_info.get("is_root_operator", True) + assert isinstance(is_root_operator, bool) + + is_used_for_training = op_info.get("is_used_for_training", True) + assert isinstance(is_used_for_training, bool) + + include_all_overloads = op_info.get("include_all_overloads", True) + assert isinstance(include_all_overloads, bool) + + debug_info: tuple[str, ...] | None = None + if "debug_info" in op_info: + di_list = op_info["debug_info"] + assert isinstance(di_list, list) + debug_info = tuple(str(x) for x in di_list) + + return SelectiveBuildOperator( + name=op_name, + is_root_operator=is_root_operator, + is_used_for_training=is_used_for_training, + include_all_overloads=include_all_overloads, + _debug_info=debug_info, + ) + + @staticmethod + def from_legacy_operator_name_without_overload( + name: str, + ) -> SelectiveBuildOperator: + return SelectiveBuildOperator( + name=name, + is_root_operator=True, + is_used_for_training=True, + include_all_overloads=True, + _debug_info=None, + ) + + def to_dict(self) -> dict[str, object]: + ret: dict[str, object] = { + "is_root_operator": self.is_root_operator, + "is_used_for_training": self.is_used_for_training, + "include_all_overloads": self.include_all_overloads, + } + if self._debug_info is not None: + ret["debug_info"] = self._debug_info + + return ret + + +def merge_debug_info( + lhs: tuple[str, ...] | None, + rhs: tuple[str, ...] | None, +) -> tuple[str, ...] | None: + # Ensure that when merging, each entry shows up just once. 
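+    # (For example, merging ("model_a",) with ("model_a", "model_b") yields a
+    # tuple containing "model_a" and "model_b" exactly once each; ordering is
+    # not guaranteed, because a set is used for de-duplication.)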
+ if lhs is None and rhs is None: + return None + + return tuple(set((lhs or ()) + (rhs or ()))) + + +def combine_operators( + lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator +) -> SelectiveBuildOperator: + if str(lhs.name) != str(rhs.name): + raise Exception( # noqa: TRY002 + f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead" + ) + + return SelectiveBuildOperator( + name=lhs.name, + # Consider this operator to be a root operator if it is a + # root operator in any of the models used in this instance of + # the pytorch library. + is_root_operator=lhs.is_root_operator or rhs.is_root_operator, + # Consider this operator to be a training operator if it is + # an operator used for training in any of the models used + # in this instance of the pytorch library. + is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training, + include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads, + _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info), + ) + + +def merge_operator_dicts( + lhs: dict[str, SelectiveBuildOperator], + rhs: dict[str, SelectiveBuildOperator], +) -> dict[str, SelectiveBuildOperator]: + operators: dict[str, SelectiveBuildOperator] = {} + for op_name, op in list(lhs.items()) + list(rhs.items()): + new_op = op + if op_name in operators: + new_op = combine_operators(operators[op_name], op) + + operators[op_name] = new_op + + return operators + + +def strip_operator_overload_name(op_name: str) -> str: + return op_name.split(".")[0] diff --git a/.venv/lib/python3.11/site-packages/torchgen/selective_build/selector.py b/.venv/lib/python3.11/site-packages/torchgen/selective_build/selector.py new file mode 100644 index 0000000000000000000000000000000000000000..04acc354203ade2f48dcef56fd9d9ef70c82ad1d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/selective_build/selector.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import yaml + +from torchgen.selective_build.operator import ( + merge_debug_info, + merge_operator_dicts, + SelectiveBuildOperator, + strip_operator_overload_name, +) + + +if TYPE_CHECKING: + from torchgen.model import NativeFunction + + +# A SelectiveBuilder holds information extracted from the selective build +# YAML specification. +# +# It includes information about the build's selectivity, the debug_info +# associated with this selective build (opaque string), and the set of +# operators that should be included in the build. +# +@dataclass(frozen=True) +class SelectiveBuilder: + # If true, then the build is not selective, and includes all + # operators. + include_all_operators: bool + + # Debug Information at the selective/custom build level. + _debug_info: tuple[str, ...] | None + + # A dictionary of operator -> operator metadata. + operators: dict[str, SelectiveBuildOperator] + + # A dictionary of selected kernel tags and dtypes. Typically a + # PyTorch Operator Kernel (function) may have many code paths + # that are specialized for many many Tensor dtypes, so it's not + # one per kernel function, but there could be many per kernel + # function. The tag isn't a kernel function name, but some fragment + # of the kernel function implementation itself. + kernel_metadata: dict[str, list[str]] + + # ExecuTorch only. 
A dictionary of kernel tag -> list of (list of input + # dtypes for tensor-like input args). + # This is from selective.yaml + et_kernel_metadata: dict[str, list[str]] + + # A set of all the custom torch bind classes used by the selected models + # Stored as a set internally to remove duplicates proactively, but written + # as a list to yamls + custom_classes: set[str] + + # A set of all the build features used by the selected models + # Stored as a set internally to remove duplicates proactively, but written + # as a list to yamls + build_features: set[str] + + # If true, then fragments for all dtypes for all kernel functions + # are included as well as all custom classes. This is typically set when any one of the + # operator lists is generated from a mechanism other than + # tracing based selective build. + include_all_non_op_selectives: bool + + @staticmethod + def get_nop_selector() -> SelectiveBuilder: + return SelectiveBuilder.from_yaml_dict({"include_all_operators": True}) + + @staticmethod + def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder: + valid_top_level_keys = { + "include_all_non_op_selectives", + "include_all_operators", + "debug_info", + "operators", + "kernel_metadata", + "et_kernel_metadata", + "custom_classes", + "build_features", + } + top_level_keys = set(data.keys()) + if len(top_level_keys - valid_top_level_keys) > 0: + raise Exception( # noqa: TRY002 + "Got unexpected top level keys: {}".format( + ",".join(top_level_keys - valid_top_level_keys), + ) + ) + include_all_operators = data.get("include_all_operators", False) + assert isinstance(include_all_operators, bool) + + debug_info = None + if "debug_info" in data: + di_list = data["debug_info"] + assert isinstance(di_list, list) + + debug_info = tuple(str(x) for x in di_list) + + operators = {} + operators_dict = data.get("operators", {}) + assert isinstance(operators_dict, dict) + + for k, v in operators_dict.items(): + operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v) + + kernel_metadata = {} + kernel_metadata_dict = data.get("kernel_metadata", {}) + assert isinstance(kernel_metadata_dict, dict) + + for k, v in kernel_metadata_dict.items(): + kernel_metadata[str(k)] = [str(dtype) for dtype in v] + + et_kernel_metadata = data.get("et_kernel_metadata", {}) + assert isinstance(et_kernel_metadata, dict) + + custom_classes = data.get("custom_classes", []) + assert isinstance(custom_classes, Iterable) + custom_classes = set(custom_classes) + + build_features = data.get("build_features", []) + assert isinstance(build_features, Iterable) + build_features = set(build_features) + + include_all_non_op_selectives = data.get("include_all_non_op_selectives", False) + assert isinstance(include_all_non_op_selectives, bool) + + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + et_kernel_metadata, + custom_classes, # type: ignore[arg-type] + build_features, # type: ignore[arg-type] + include_all_non_op_selectives, + ) + + @staticmethod + def from_yaml_str(config_contents: str) -> SelectiveBuilder: + contents = yaml.safe_load(config_contents) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_yaml_path(config_path: str) -> SelectiveBuilder: + with open(config_path) as f: + contents = yaml.safe_load(f) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_legacy_op_registration_allow_list( + allow_list: set[str], is_root_operator: bool, is_used_for_training: bool + ) -> SelectiveBuilder: + operators = {} + 
for op in allow_list: + operators[op] = { + "name": op, + "is_root_operator": is_root_operator, + "is_used_for_training": is_used_for_training, + "include_all_overloads": True, + } + return SelectiveBuilder.from_yaml_dict( + { + "operators": operators, + "include_all_non_op_selectives": True, + } + ) + + def is_operator_selected(self, name: str) -> bool: + if self.include_all_operators: + return True + + if name in self.operators: + return True + name = strip_operator_overload_name(name) + return name in self.operators and self.operators[name].include_all_overloads + + def is_native_function_selected(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected(op_name) + + def is_operator_selected_for_training(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + not_training_op = SelectiveBuildOperator( + name="", + is_root_operator=False, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op = not_training_op + if name in self.operators: + op = self.operators[name] + + name = strip_operator_overload_name(name) + base_op = not_training_op + if name in self.operators: + base_op = self.operators[name] + + return op.is_used_for_training or ( + base_op.include_all_overloads and base_op.is_used_for_training + ) + + def is_native_function_selected_for_training(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected_for_training(op_name) + + def is_root_operator(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + if name in self.operators: + op: SelectiveBuildOperator = self.operators[name] + return op.is_root_operator + name = strip_operator_overload_name(name) + if name not in self.operators: + return False + base_op: SelectiveBuildOperator = self.operators[name] + return base_op.include_all_overloads and base_op.is_root_operator + + def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool: + if self.include_all_operators or self.include_all_non_op_selectives: + return True + + return ( + kernel_tag in self.kernel_metadata + and dtype in self.kernel_metadata[kernel_tag] + ) + + def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]: + """ + Return a list of kernel keys that cover the used ops + """ + # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used. + if op_name not in self.et_kernel_metadata: + return kernel_key if self.include_all_operators else [] + # Otherwise, only return the specific kernel keys. 
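+        # (Kernel keys are assumed here to look like "<version>/<spec>"; the loop
+        # below ignores the version segment and compares only the part after the
+        # first "/", falling back to "default" when no specific key matches.)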
+ + result_set = set() + + for model_kernel_keys in self.et_kernel_metadata[op_name]: + key_found = False + for key in kernel_key: + # Don't compare the version for now + if ( + key != "default" + and key.split("/")[1] == model_kernel_keys.split("/")[1] + ): + result_set.add(key) + key_found = True + break + if not key_found: + if "default" not in kernel_key: + raise Exception("Missing kernel for the model") # noqa: TRY002 + else: + result_set.add("default") + + return list(result_set) + + def to_dict(self) -> dict[str, object]: + ret: dict[str, object] = { + "include_all_non_op_selectives": self.include_all_non_op_selectives, + "include_all_operators": self.include_all_operators, + } + operators = {} + for op_name, op in self.operators.items(): + operators[op_name] = op.to_dict() + ret["operators"] = operators + + if self._debug_info is not None: + ret["debug_info"] = sorted(self._debug_info) + + ret["kernel_metadata"] = { + k: sorted(v) for (k, v) in self.kernel_metadata.items() + } + + ret["et_kernel_metadata"] = self.et_kernel_metadata + + ret["custom_classes"] = sorted(self.custom_classes) + + ret["build_features"] = sorted(self.build_features) + + return ret + + +def merge_kernel_metadata( + lhs: dict[str, list[str]], + rhs: dict[str, list[str]], +) -> dict[str, list[str]]: + kernel_metadata: dict[str, list[str]] = {} + for tag_name, dtypes in list(lhs.items()) + list(rhs.items()): + dtypes_copy = set(dtypes) + if tag_name in kernel_metadata: + dtypes_copy |= set(kernel_metadata[tag_name]) + + kernel_metadata[tag_name] = list(dtypes_copy) + + return kernel_metadata + + +def merge_et_kernel_metadata( + lhs: dict[str, list[str]], + rhs: dict[str, list[str]], +) -> dict[str, list[str]]: + merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set) + for op in list(lhs.keys()) + list(rhs.keys()): + merge_et_kernel_metadata[op].update(lhs.get(op, [])) + merge_et_kernel_metadata[op].update(rhs.get(op, [])) + + return {op: sorted(val) for op, val in merge_et_kernel_metadata.items()} + + +def combine_selective_builders( + lhs: SelectiveBuilder, rhs: SelectiveBuilder +) -> SelectiveBuilder: + include_all_operators = lhs.include_all_operators or rhs.include_all_operators + debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info) + operators = merge_operator_dicts(lhs.operators, rhs.operators) + kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata) + et_kernel_metadata = merge_et_kernel_metadata( + lhs.et_kernel_metadata, rhs.et_kernel_metadata + ) + include_all_non_op_selectives = ( + lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives + ) + custom_classes = lhs.custom_classes.union(rhs.custom_classes) + build_features = lhs.build_features.union(rhs.build_features) + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + et_kernel_metadata, + custom_classes, + build_features, + include_all_non_op_selectives, + ) + + +def op_name_from_native_function(f: NativeFunction) -> str: + # This was originally read from the 'operator_name_with_overload' field in the + # declaration dict, which was the part before the first '(' in 'schema_string'. 
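+    # (For example, namespace "aten" and name "add.Tensor" yield "aten::add.Tensor".)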
+ return f"{f.namespace}::{f.func.name}" diff --git a/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__init__.py b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b892a95bdafda17a55ed32f6a71dafc5535a2725 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89054c58a61c1a531ffcff43ba94f6b83e74451d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6786c9a59064b3f339b678150af624f2cec73c4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54df420fdc4bb54a38380ad2082ff7b909a7d69d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/static_runtime/config.py b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7b541fa2c1287921613384aec2fee2cd7d4e97 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/config.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup + + +def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str: + if isinstance(g, NativeFunctionsGroup): + return str(g.functional.func.name.name.base) + else: + return str(g.view.root_name) + + +is_hand_written_ops_ = frozenset( + ( + "abs", + "add", + "addmm", + "all", + "any", + "argmin", + "bmm", + "clamp", + "clamp_min", + "cumsum", + "div", + "fmod", + "index_select", + "leaky_relu", + "linear", + "log", + "matmul", + "mul", + "narrow_copy", + "nonzero", + "pow", + "remainder", + "sigmoid", + "sign", + "sub", + "tanh", + "detach", + "expand_as", + "flatten", + "narrow", + "reshape_as", + "select", + "slice", + "softmax", + "split", + "squeeze", + "transpose", + "view", + "where", + ) +) + + +def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + name_base = func_name_base_str(g) + return name_base in is_hand_written_ops_ + + +def 
override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None: + assert index == 0 or index == 1 + if op_name == "addr": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["vec1"] = "at::rand({6})" + arg_map["vec2"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["vec1"] = "at::rand({22})" + arg_map["vec2"] = "at::rand({22})" + return + if op_name == "mv": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["vec"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["vec"] = "at::rand({22})" + return + if op_name == "addbmm": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + else: + arg_map["self"] = "at::rand({22, 22})" + return + if op_name == "cross": + if index == 0: + arg_map["self"] = "at::rand({3, 3, 3})" + arg_map["other"] = "at::rand({3, 3, 3})" + else: + arg_map["self"] = "at::rand({22, 3, 22})" + arg_map["other"] = "at::rand({22, 3, 22})" + return + if op_name == "take": + if index == 0: + arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)" + else: + arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)" + return + if op_name == "take_along_dim": + if index == 0: + arg_map["indices"] = "at::argsort(self0, 1, true)" + else: + arg_map["indices"] = "at::argsort(self1, 1, true)" + return + if op_name == "masked_select": + if index == 0: + arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5" + else: + arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5" + return + if op_name == "orgqr": + if index == 0: + arg_map["input2"] = "at::rand({6, 6})" + else: + arg_map["input2"] = "at::rand({22, 22})" + return + if op_name == "ormqr": + if index == 0: + arg_map["input2"] = "at::rand({6, 6})" + else: + arg_map["input2"] = "at::rand({22, 22})" + return + if op_name == "quantile": + if index == 0: + arg_map["q"] = "at::rand({6})" + arg_map["interpolation"] = '"linear"' + else: + arg_map["q"] = "at::rand({22})" + arg_map["interpolation"] = '"linear"' + return + if op_name == "nanquantile": + if index == 0: + arg_map["q"] = "at::rand({6})" + arg_map["interpolation"] = '"linear"' + else: + arg_map["q"] = "at::rand({22})" + arg_map["interpolation"] = '"linear"' + return + if op_name == "multi_margin_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return + if op_name == "multilabel_margin_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)" + return + if op_name == "nll_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return + if op_name == "nll_loss2d": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6, 6})" + arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22, 22, 22})" + arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return 
+ if op_name in ( + "fft_fft", + "fft_ifft", + "fft_rfft", + "fft_irfft", + "fft_hfft", + "fft_ihfft", + ): + arg_map["norm"] = '"forward"' + return + if op_name == "linalg_tensorinv": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6, 6})" + arg_map["ind"] = "2" + else: + arg_map["self"] = "at::rand({22, 22, 22, 22})" + arg_map["ind"] = "2" + return + if op_name == "addmv": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["mat"] = "at::rand({2, 2})" + arg_map["vec"] = "at::rand({2})" + else: + arg_map["self"] = "at::rand({35})" + arg_map["mat"] = "at::rand({35, 35})" + arg_map["vec"] = "at::rand({35})" + return + if op_name == "acosh": + if index == 0: + arg_map["self"] = "at::rand({2, 2, 2}) + at::ones({2, 2, 2})" + else: + arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})" + return + if op_name == "adaptive_max_pool2d_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({2, 2, 2}, at::kFloat)" + arg_map["self"] = "at::rand({2, 2, 2}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {2, 2, 2}, at::kLong)" + else: + arg_map["grad_output"] = "at::rand({3, 3, 3}, at::kFloat)" + arg_map["self"] = "at::rand({3, 3, 3}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {3, 3, 3}, at::kLong)" + return + if op_name == "adaptive_max_pool3d_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({2, 2, 2, 2}, at::kFloat)" + arg_map["self"] = "at::rand({2, 2, 2, 2}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {2, 2, 2, 2}, at::kLong)" + else: + arg_map["grad_output"] = "at::rand({3, 3, 3, 3}, at::kFloat)" + arg_map["self"] = "at::rand({3, 3, 3, 3}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {3, 3, 3, 3}, at::kLong)" + return + if op_name == "bitwise_left_shift": + if index == 0: + arg_map["self"] = "at::randint(1, 1 << 4, {6, 6, 6}, at::kInt)" + arg_map["other"] = "at::randint(1, 26, {6, 6, 6}, at::kInt)" + else: + arg_map["self"] = "at::randint(1, 1 << 4, {22, 22, 22}, at::kInt)" + arg_map["other"] = "at::randint(1, 26, {22, 22, 22}, at::kInt)" + return + if op_name == "bitwise_right_shift": + if index == 0: + arg_map["self"] = "at::randint(1 << 21, 1 << 30, {6, 6, 6}, at::kInt)" + arg_map["other"] = "at::randint(1, 22, {6, 6, 6}, at::kInt)" + else: + arg_map["self"] = "at::randint(1 << 21, 1 << 30, {22, 22, 22}, at::kInt)" + arg_map["other"] = "at::randint(1, 22, {22, 22, 22}, at::kInt)" + return + if op_name == "gather": + if index == 0: + arg_map["self"] = "at::randint(1, 100, {2,2,2}, at::kInt)" + arg_map["dim"] = "1" + arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)" + arg_map["sparse_grad"] = "false" + else: + arg_map["self"] = "at::randint(1, 100, {5,5,5}, at::kInt)" + arg_map["dim"] = "1" + arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)" + arg_map["sparse_grad"] = "false" + return + if op_name == "gelu": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6})" + arg_map["approximate"] = '"tanh"' + else: + arg_map["self"] = "at::rand({22, 22, 22})" + arg_map["approximate"] = '"tanh"' + return + if op_name == "gelu_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({6, 6, 6})" + arg_map["self"] = "at::rand({6, 6, 6})" + arg_map["approximate"] = '"tanh"' + else: + arg_map["grad_output"] = "at::rand({22, 22, 22})" + arg_map["self"] = "at::rand({22, 22, 22})" + arg_map["approximate"] = '"tanh"' + return + if op_name == "index_add": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)" + 
arg_map["source"] = "at::rand({2})" + arg_map["alpha"] = "2" + else: + arg_map["self"] = "at::rand({16})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)" + arg_map["source"] = "at::rand({16})" + arg_map["alpha"] = "2" + return + if op_name == "index_copy": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)" + arg_map["source"] = "at::rand({2})" + else: + arg_map["self"] = "at::rand({32})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)" + arg_map["source"] = "at::rand({32})" + return + if op_name == "linalg_cross": + if index == 0: + arg_map["self"] = "at::rand({6, 3, 6})" + arg_map["other"] = "at::rand({6, 3, 6})" + arg_map["dim"] = "1" + else: + arg_map["self"] = "at::rand({22, 3, 22})" + arg_map["other"] = "at::rand({22, 3, 22})" + arg_map["dim"] = "1" + return + if op_name == "nll_loss_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({})" + arg_map["self"] = "at::rand({6})" + arg_map["target"] = "at::randint(0, 5, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + arg_map["reduction"] = "1" + arg_map["ignore_index"] = "1" + arg_map["total_weight"] = "at::rand({})" + else: + arg_map["grad_output"] = "at::rand({})" + arg_map["self"] = "at::rand({36})" + arg_map["target"] = "at::randint(0, 11, {36}, torch::kInt64)" + arg_map["weight"] = "at::rand({36})" + arg_map["reduction"] = "1" + arg_map["ignore_index"] = "1" + arg_map["total_weight"] = "at::rand({})" + return + if op_name in ["scatter", "scatter_add", "_scatter_reduce"]: + if index == 0: + arg_map["self"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)" + arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)" + arg_map["src"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)" + else: + arg_map["self"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)" + arg_map["index"] = "at::randint(0, 1, {5,5,5}, torch::kInt64)" + arg_map["src"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)" + if "reduce" in arg_map: + arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"' + return + if op_name == "scatter_reduce": + arg_map["reduce"] = '"mean"' + if index == 0: + arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)" + else: + arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)" + return + if op_name == "special_zeta": + if index == 0: + arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})" + arg_map["other"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})" + else: + arg_map["self"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})" + arg_map["other"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})" + return + if op_name == "_convert_indices_from_csr_to_coo": + if index == 0: + arg_map["crow_indices"] = "torch::tensor({1}, torch::kInt32)" + arg_map["col_indices"] = "torch::tensor({0, 1, 0}, torch::kInt32)" + arg_map["out_int32"] = "false" + else: + arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)" + arg_map[ + "col_indices" + ] = "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)" + arg_map["out_int32"] = "false" + return + if op_name == "_convert_indices_from_coo_to_csr": + if index == 0: + arg_map["self"] = "at::randint(0, 3, {2}, at::kInt)" + arg_map["size"] = "10" + arg_map["out_int32"] = "false" + else: + arg_map["self"] = "at::randint(0, 3, {12}, at::kInt)" + arg_map["size"] = "24" + arg_map["out_int32"] = "false" + return + if op_name in ("diagonal", "linalg_diagonal"): 
+ arg_map["offset"] = "0" + arg_map["dim1"] = "2" + arg_map["dim2"] = "1" + return diff --git a/.venv/lib/python3.11/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7357173746748bacfc3e540ebcf37426b5455e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import argparse +import itertools +import os +from typing import Sequence, TypeVar, Union + +from libfb.py.log import set_simple_logging # type: ignore[import] + +from torchgen import gen +from torchgen.context import native_function_manager +from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup +from torchgen.static_runtime import config, generator + + +# Given a list of `grouped_native_functions` sorted by their op names, return a list of +# lists each of which groups ops that share the base name. For example, `mean` and +# `mean.dim` are grouped together by this function. + +NativeGroupT = TypeVar( + "NativeGroupT", + bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup], +) + + +def group_functions_by_op_name( + grouped_native_functions: Sequence[NativeGroupT], +) -> Sequence[Sequence[NativeGroupT]]: + if not grouped_native_functions: + return [] + groups = [] + + def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + with native_function_manager(g): + return generator.is_supported(g) + + eligible_ops = (g for g in grouped_native_functions if is_supported(g)) + groups = [ + list(group) + for k, group in ( + itertools.groupby( + eligible_ops, + key=config.func_name_base_str, + ) + ) + ] + + return groups + + +def clang_format(cpp_file_path: str) -> None: + import subprocess + + subprocess.check_call(["clang-format", "-i", cpp_file_path]) + + +def write_cpp(cpp_ops: Sequence[str], file_path: str) -> None: + code = "\n".join(cpp_ops) + generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN +// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch {{ +namespace jit {{ + +{code} + +}} // namespace jit +}} // namespace torch +""" + with open(file_path, "w") as f: + f.write(generated) + clang_format(file_path) + + +def write_test_cpp(cpp_ops: Sequence[str], file_path: str) -> None: + code = "\n".join(cpp_ops) + generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN +// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py +#include +#include +#include + +#include "test_utils.h" + +using namespace caffe2; +using namespace torch; +using namespace torch::jit; +using namespace torch::jit::test; +using c10::IValue; + +{code} + +""" + with open(file_path, "w") as f: + f.write(generated) + clang_format(file_path) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate ATen source files") + parser.add_argument( + "-s", + "--source-path", + help="path to source directory for ATen", + default="caffe2/aten/src/ATen", + ) + parser.add_argument( + "-p", + "--generated-ops-cpp-path", + 
help="path to directory to generate op dispatcher .cpp file", + default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp", + ) + parser.add_argument( + "-t", + "--generated-ops-test-cpp-path", + help="path to directory to generate op dispatcher .cpp file", + default="caffe2/benchmarks/static_runtime/test_generated_ops.cc", + ) + options = parser.parse_args() + native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml") + tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml") + parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path) + native_functions, backend_indices = ( + parsed_yaml.native_functions, + parsed_yaml.backend_indices, + ) + + op_generator = generator.GenOpDispatcher() + test_case_generator = generator.GenOpTestCase() + + native_functions_groups = [ + g + for g in gen.get_grouped_native_functions(native_functions) + if isinstance(g, NativeFunctionsGroup) + ] + + supported_functions_groups = group_functions_by_op_name(native_functions_groups) + + out_variant_op_result = [ + op_generator.out_variant(groups, backend_indices[DispatchKey.CPU]) + for groups in supported_functions_groups + ] + out_variant_test_result = [ + test_case_generator.out_variant(groups) for groups in supported_functions_groups + ] + + native_functions_view_groups = [ + g + for g in gen.get_grouped_by_view_native_functions(native_functions) + if isinstance(g, NativeFunctionsViewGroup) + ] + + supported_functions_view_groups = group_functions_by_op_name( + native_functions_view_groups + ) + + view_op_result = [ + op_generator.view(groups, backend_indices[DispatchKey.CPU]) + for groups in supported_functions_view_groups + ] + view_test_result = [ + test_case_generator.view(groups) for groups in supported_functions_view_groups + ] + + op_result = out_variant_op_result + ["\n\n"] + view_op_result + test_result = out_variant_test_result + ["\n\n"] + view_test_result + + write_cpp(op_result, options.generated_ops_cpp_path) + write_test_cpp(test_result, options.generated_ops_test_cpp_path) + + print( + "\ntotal grouped native ops: %d" + % len(gen.get_grouped_native_functions(native_functions)) + ) + + print("grouped native ops with out variant: %d" % len(native_functions_groups)) + supported_functions_num = sum(len(groups) for groups in supported_functions_groups) + print("generated functions groups with out variant: %d" % supported_functions_num) + + print("\nview grouped native ops: %d" % len(native_functions_view_groups)) + supported_view_functions_num = sum( + len(groups) for groups in supported_functions_view_groups + ) + print("generated functions view groups: %d" % supported_view_functions_num) + + print( + "\noverall generated : %d" + % (supported_functions_num + supported_view_functions_num) + ) + + +if __name__ == "__main__": + set_simple_logging(escape_newlines=False) + main() diff --git a/.venv/lib/python3.11/site-packages/torchgen/static_runtime/generator.py b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7bbb7f64d8644252cd6a92492c0c36b40d623b2f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/static_runtime/generator.py @@ -0,0 +1,809 @@ +from __future__ import annotations + +import json +import logging +import math +from typing import Sequence + +import torchgen.api.cpp as cpp +from torchgen.context import native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + BaseTy, + BaseType, + FunctionSchema, + 
NativeFunctionsGroup, + NativeFunctionsViewGroup, + OptionalType, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.static_runtime import config + + +logger: logging.Logger = logging.getLogger() + + +def has_alias( + arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments], +) -> bool: + for arg in arguments: + annotation = getattr(arg, "annotation", None) + if not annotation: + continue + alias_set = getattr(annotation, "alias_set", ()) + if alias_set: + return True + return False + + +BLOCKED_OPS = frozenset( + ( + # non cpu ops + "sparse_sampled_addmm", + "hspmm", + "linalg_svdvals", + # sparse ops + "sspaddmm", + "coalesce", + "_indices", + "indices", + "_values", + "values", + "crow_indices", + "col_indices", + # deprecated ops + "floor_divide", + "ger", + # buggy ops + "conj_physical", # P495807361 + "binary_cross_entropy", # P496394764 + "arccosh", + # uncommon ops + "cholesky", + "lu_solve", + "linalg_cholesky", + "linalg_householder_product", + "linalg_ldl_solve", + "_compute_linear_combination", + # training related ops + "_make_dual", + # cannot call directly + "_fw_primal", + # no documentation + "_index_reduce", + # TODO: these ones got added recently and need manual inspection + "_new_zeros_with_same_feature_meta", + "_conj_physical", + "binary_cross_entropy_with_logits", + "bincount", + "conv_tbc", + "copy", + "_copy_from", + "_copy_from_and_resize", + "count_nonzero", + "cudnn_affine_grid_generator", + "cudnn_affine_grid_generator_backward", + "cudnn_grid_sampler", + "diag_embed", + "embedding", + "embedding_dense_backward", + "_embedding_bag_dense_backward", + "_embedding_bag_per_sample_weights_backward", + "grid_sampler_2d", + "_grid_sampler_2d_cpu_fallback", + "grid_sampler_3d", + "isnan", + "mkldnn_linear", + "median", + "nanmedian", + "_sparse_sparse_matmul", + "batch_norm_backward_elemt", + "_euclidean_dist", + "pixel_shuffle", + "pixel_unshuffle", + "channel_shuffle", + "_reshape_nested_backward", + "relu", + "prelu", + "celu", + "slice_scatter", + "select_scatter", + "diagonal_scatter", + "sum", + "_mkldnn_transpose", + "_nested_tensor_from_mask", + "_nested_from_padded", + "_nested_tensor_size", + "_nested_from_padded_and_nested_example", + "_standard_gamma_grad", + "_dirichlet_grad", + "native_norm", + "_sparse_softmax", + "_sparse_softmax_backward_data", + "_sparse_log_softmax", + "_sparse_log_softmax_backward_data", + "zero", + "_sparse_addmm", + "sparse_mask", + "_sparse_mask_projection", + "_to_dense", + "_coalesce", + "_coalesced", + "copy_sparse_to_sparse", + "to_sparse", + "to_sparse_csr", + "to_sparse_csc", + "to_mkldnn", + "quantize_per_tensor_dynamic", + "quantize_per_channel", + "q_per_channel_scales", + "q_per_channel_zero_points", + "int_repr", + "_make_per_channel_quantized_tensor", + "set", + "lift", + "lift_fresh", + "lift_fresh_copy", + "masked_scatter", + "_masked_softmax", + "_masked_softmax_backward", + "put", + "index_reduce", + "trace", + "_cholesky_solve_helper", + "dist", + "max", + "_torch_cuda_cu_linker_symbol_op", + "glu_jvp", + "glu_backward_jvp", + "hardswish_backward", + "rrelu_with_noise_backward", + "mkldnn_adaptive_avg_pool2d_backward", + "_adaptive_avg_pool2d_backward", + "_adaptive_avg_pool3d_backward", + "isinf", + "linalg_lu_solve", + "linalg_vecdot", + "linalg_matrix_exp", + "linalg_eigvalsh", + "_test_warn_in_autograd", + "_test_autograd_multiple_dispatch_view", + "_test_autograd_multiple_dispatch_view_copy", + "_segment_reduce", + "_segment_reduce_backward", + "_fw_primal_copy", + 
"_make_dual_copy", + "view_as_real_copy", + "view_as_complex_copy", + "_conj_copy", + "_neg_view_copy", + "diagonal_copy", + "detach_copy", + "squeeze_copy", + "t_copy", + "unsqueeze_copy", + "_indices_copy", + "_values_copy", + "indices_copy", + "values_copy", + "crow_indices_copy", + "col_indices_copy", + "ccol_indices", + "ccol_indices_copy", + "row_indices", + "row_indices_copy", + "unfold_copy", + "alias_copy", + "_triton_multi_head_attention", + "special_airy_ai", + "special_bessel_j0", + "special_bessel_j1", + "special_bessel_y0", + "special_bessel_y1", + "special_chebyshev_polynomial_t", + "special_chebyshev_polynomial_u", + "special_chebyshev_polynomial_v", + "special_chebyshev_polynomial_w", + "special_hermite_polynomial_h", + "special_hermite_polynomial_he", + "special_laguerre_polynomial_l", + "special_legendre_polynomial_p", + "special_modified_bessel_i0", + "special_modified_bessel_i1", + "special_modified_bessel_k0", + "special_modified_bessel_k1", + "special_scaled_modified_bessel_k0", + "special_scaled_modified_bessel_k1", + "special_shifted_chebyshev_polynomial_t", + "special_shifted_chebyshev_polynomial_u", + "special_shifted_chebyshev_polynomial_v", + "special_shifted_chebyshev_polynomial_w", + "special_spherical_bessel_j0", + "_foobar", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", + "_nested_get_values", # no CPU backend + "_nested_get_values_copy", # no CPU backend + "_nested_view_from_jagged", # testing needs to be patched + "_nested_view_from_jagged_copy", # testing needs to be patched + "_nested_view_from_buffer", # testing needs to be patched + "_nested_view_from_buffer_copy", # testing needs to be patched + "_int_mm", # testing needs to be patched + "_to_sparse_csc", # testing needs to be patched + "_to_sparse_csr", # testing needs to be patched + "segment_reduce", # testing needs to be patched + ) +) + + +def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + base_op_name = "" + func = None + if isinstance(g, NativeFunctionsViewGroup): + base_op_name = g.view.root_name + func = g.view.func + else: + base_op_name = g.out.func.name.name.base + func = g.out.func + if config.is_hand_written(g): + logger.info("HAND WRITTEN: %s", base_op_name) + return False + if base_op_name in BLOCKED_OPS: + logger.info("BLOCKED: %s", base_op_name) + return False + for arg in func.schema_order_arguments(): + maybe_method = ivalue_type_conversion_method(arg.type) + if not maybe_method: + # Type converting is unsupported yet. + logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func) + return False + + if isinstance(g, NativeFunctionsViewGroup): + # TODO: stop doing type tests by converting to C++ and then testing + # the string, just test the dang thing directly + if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type(): + # Returns a non-Tensor value. + logger.info("NON-TENSOR RET TYPE: %s", str(func)) + return False + return True + + # For out variant ops, we need to check the arguments of its functional func. + for arg in g.functional.func.schema_order_arguments(): + maybe_method = ivalue_type_conversion_method(arg.type) + if not maybe_method: + # Type converting is unsupported yet. + logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func) + return False + + if not g.structured: + # In case of unstructured op, we check if it has out variant implementation. + # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last + # parameter. 
+        if (
+            not hasattr(g, "out")
+            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
+            or not str(func.name).endswith(".out")
+        ):
+            return False
+    # TODO: stop type testing by converting to C++
+    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
+        logger.info("NON_TENSOR RET TYPE: %s", func)
+        return False
+    if has_alias(func.arguments.non_out):
+        # This op may create an alias of inputs.
+        logger.info("INPUTS ALIAS: %s", base_op_name)
+        return False
+    return True
+
+
+def ivalue_type_conversion_method(
+    arg_type: BaseType | OptionalType | Type,
+) -> tuple[bool, str] | None:
+    """
+    Return the method call expression of `c10::ivalue` to convert its contained value to
+    the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
+    this function returns ".toTensor()", so that it can be appended to the ivalue's
+    variable name to get the value of the expected type.
+    """
+    type_conversion_methods = {
+        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
+        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
+        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
+        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
+        BaseTy.ScalarType: (
+            (False, "toScalarType()"),
+            (False, "toOptional<at::ScalarType>()"),
+        ),
+        BaseTy.str: (
+            (False, "toStringView()"),
+            (False, "toOptional<c10::string_view>()"),
+        ),
+    }
+
+    base_ty_object = None
+    if isinstance(arg_type, BaseType):
+        base_ty_object = arg_type.name
+    elif isinstance(arg_type, OptionalType):
+        if not isinstance(arg_type.elem, BaseType):
+            # ListType is currently unsupported.
+            return None
+        base_ty_object = arg_type.elem.name
+    else:
+        return None
+
+    if base_ty_object not in type_conversion_methods:
+        return None
+    methods = type_conversion_methods[base_ty_object]
+    if isinstance(arg_type, BaseType):
+        return methods[0]
+    return methods[1]
+
+
+should_use_int_tensor_ops_ = frozenset(
+    (
+        "bitwise_not",
+        "bitwise_and",
+        "bitwise_or",
+        "bitwise_xor",
+        "bitwise_left_shift",
+        "bitwise_right_shift",
+        "gcd",
+        "lcm",
+        "scatter",
+        "gather",
+        "_convert_indices_from_coo_to_csr",
+        "_convert_indices_from_csr_to_coo",
+    )
+)
+should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))
+
+
+def should_use_int_tensor(op_name: str) -> bool:
+    return op_name in should_use_int_tensor_ops_
+
+
+def should_use_complex_tensor(op_name: str) -> bool:
+    return op_name in should_use_complex_tensor_ops_
+
+
+test_tensor_dim_ops_1_ = frozenset(
+    (
+        "addmv",
+        "index_add",
+        "_convert_indices_from_coo_to_csr",
+        "_convert_indices_from_csr_to_coo",
+        "nll_loss_backward",
+        "dot",
+        "vdot",
+        "outer",
+        "ger",
+    )
+)
+test_tensor_dim_ops_2_ = frozenset(
+    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
+)
+
+
+def test_tensor_dim(op_name: str) -> int:
+    if op_name in test_tensor_dim_ops_1_:
+        return 1
+    if op_name in test_tensor_dim_ops_2_:
+        return 2
+    return 3
+
+
+test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
+test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
+
+
+def test_tensor_shape(op_name: str) -> str:
+    if op_name in test_tensor_shape_json:
+        return test_tensor_shape_json[op_name]
+    else:
+        return ""
+
+
+def test_value_expression(
+    arg_type: BaseType | OptionalType | Type, index: int, op_name: str
+) -> str:
+    tensor_size_ex = test_tensor_shape(op_name)
+    if tensor_size_ex == "":
+        num_tensors = 16 if index == 0 else 64
+        num_dim = test_tensor_dim(op_name)
+        size_per_dim = math.ceil(num_tensors / float(num_dim))
+        size_per_dim += size_per_dim % 2
+        tensor_size_ex = "{{{}}}".format(",".join([f"{size_per_dim}"] * num_dim))
+    if should_use_int_tensor(op_name):
+        tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
+    elif should_use_complex_tensor(op_name):
+        tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
+    else:
+        tensor_expression = f"at::rand({tensor_size_ex})"
+
+    value_expressions = {
+        BaseTy.Tensor: tensor_expression,
+        BaseTy.int: "1",
+        BaseTy.bool: "false",
+        BaseTy.Scalar: "2",
+        BaseTy.ScalarType: "at::ScalarType::Float",
+        BaseTy.str: '"floor"',
+    }
+
+    base_ty_object = None
+    if isinstance(arg_type, BaseType):
+        base_ty_object = arg_type.name
+    else:
+        assert isinstance(arg_type, OptionalType) and isinstance(
+            arg_type.elem, BaseType
+        )
+        base_ty_object = arg_type.elem.name
+    assert base_ty_object in value_expressions, "unexpected type"
+    value_expression = value_expressions[base_ty_object]
+    return value_expression
+
+
+def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
+    assert not schema.is_out_fn()
+    schema_name = schema.name.name.base
+    arg_map = {}
+    for arg in schema.schema_order_arguments():
+        test_value_exp = test_value_expression(arg.type, index, schema_name)
+        arg_map[arg.name] = test_value_exp
+    config.override_test_values(arg_map, schema_name, index)
+    arg_populations = []
+    for arg_name, arg_value in arg_map.items():
+        arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
+    return ";\n  ".join(arg_populations) + ";"
+
+
+def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
+    assert not schema.is_out_fn()
+    return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
+
+
+generate_test_ir_arguments_base_ty_to_type_str_ = {
+    BaseTy.Tensor: "Tensor",
+    BaseTy.int: "int",
+    BaseTy.float: "float",
+    BaseTy.str: "str",
+    BaseTy.Scalar: "int",
+    BaseTy.ScalarType: "int",
+    BaseTy.bool: "bool",
+}
+
+
+def generate_test_ir_arguments(
+    schema: FunctionSchema,
+) -> list[tuple[str, str | None]]:
+    def ir_argument(arg: Argument) -> tuple[str, str | None]:
+        t = arg.type
+        add_optional = False
+        if isinstance(t, OptionalType):
+            t = t.elem
+            add_optional = True
+        assert isinstance(t, BaseType)
+        type_str = None
+        if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
+            type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
+        if type_str and add_optional:
+            type_str = f"{type_str}?"
+ return ("%" + arg.name, type_str) + + return [ir_argument(arg) for arg in schema.schema_order_arguments()] + + +def generate_arg_extraction(schema: FunctionSchema) -> str: + arg_populations = [] + for i, arg in enumerate(schema.schema_order_arguments()): + maybe_method = ivalue_type_conversion_method(arg.type) + assert maybe_method + is_reference, type_conversion_method = maybe_method + reference = "&" if is_reference else "" + arg_populations.append( + f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}" + ) + return ";\n ".join(arg_populations) + ";" + + +def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: + kernel = backend_index.get_kernel(g.functional) + if g.structured or kernel is None: + return cpp.name(g.functional.func) + return kernel.kernel + + +def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: + kernel = backend_index.get_kernel(g.out) + if g.structured or kernel is None: + return cpp.name(g.out.func) + return kernel.kernel + + +def generate_non_out_variant_call( + g: NativeFunctionsGroup, backend_index: BackendIndex +) -> str: + schema = g.functional.func + assert not schema.is_out_fn() + kernel_name = get_kernel_name(g, backend_index) + arg_names = (arg.name for arg in schema.schema_order_arguments()) + namespace_name = "cpu" if g.structured else "native" + return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})' + + +def generate_call_to_view_ops( + g: NativeFunctionsViewGroup, backend_index: BackendIndex +) -> str: + schema = g.view.func + kernel_name = cpp.name(schema) + kernel = backend_index.get_kernel(g.view) + if kernel: + kernel_name = kernel.kernel + arg_names = (arg.name for arg in schema.schema_order_arguments()) + namespace_name = "native" + return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})' + + +def generate_out_variant_call( + g: NativeFunctionsGroup, backend_index: BackendIndex +) -> str: + schema = g.out.func + assert schema.is_out_fn() + arg_names = [] + kernel_name = get_out_kernel_name(g, backend_index) + if g.structured: + # structured op starts with the output tensor argument. 
+ arg_names = [out_arg.name for out_arg in schema.arguments.out] + else: + arg_names = [] + for arg in schema.arguments.non_out: + if isinstance(arg, SelfArgument): + arg_names.append(arg.argument.name) + else: + assert isinstance(arg, Argument) + arg_names.append(arg.name) + if not g.structured: + assert len(schema.arguments.out) == 1 + arg_names.append(schema.arguments.out[0].name) + cpp_arg_names = ",".join(arg_names) + namespace_name = "cpu" if g.structured else "native" + return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})" + + +no_memory_resize_ops = frozenset( + ( + "isin.Scalar_Tensor", + "index_add", + "dot", + "vdot", + "nuclear_norm", + "histc", + "l1_loss", + "multi_margin_loss", + "multilabel_margin_loss", + "nll_loss", + "nll_loss2d", + "prod", + ) +) + + +def should_check_resize(schema: FunctionSchema) -> bool: + schema_str = str(schema) + type_variant_op_name = schema_str[: schema_str.find("(")] + return type_variant_op_name not in no_memory_resize_ops + + +def op_name_from_group(g: NativeFunctionsGroup) -> str: + return g.functional.func.name.name.base + + +class GenOpDispatcher: + def out_variant( + self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex + ) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsGroup) + generated_type_variant = self.out_variant_op_generator(g, backend_index) + generated_type_variants.append(generated_type_variant) + op_name = op_name_from_group(groups[0]) + body = "\n".join(generated_type_variants) + generated = f""" +REGISTER_OPERATOR_FUNCTOR( + aten::{op_name}, + aten_{op_name}, + [](Node* n) -> SROperator {{ + {body} + LogAndDumpSchema(n); + return nullptr; + }}); +""" + return generated + + def view( + self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex + ) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsViewGroup) + generated_type_variant = self.view_op_generator(g, backend_index) + generated_type_variants.append(generated_type_variant) + op_name = config.func_name_base_str(groups[0]) + body = "\n".join(generated_type_variants) + generated = f""" +REGISTER_NATIVE_OPERATOR_FUNCTOR( + aten::{op_name}, + aten_{op_name}, + [](Node* n) -> SROperator {{ + {body} + LogAndDumpSchema(n); + return nullptr; + }}); +""" + return generated + + def out_variant_op_generator( + self, g: NativeFunctionsGroup, backend_index: BackendIndex + ) -> str: + functional = g.functional + schema = str(functional.func) + populated_argument = generate_arg_extraction(g.functional.func) + functional_variant_call = generate_non_out_variant_call(g, backend_index) + assert len(g.out.func.arguments.out) == 1 + out_variable_name = str(g.out.func.arguments.out[0].name) + out_variant_call = generate_out_variant_call(g, backend_index) + generated = f""" + if (n->matches(torch::schema("aten::{schema}"))) {{ + return [](ProcessedNode* p_node) {{ + {populated_argument} + if (p_node->Output(0).isNone()) {{ + p_node->Output(0) = {functional_variant_call}; + return; + }} + auto& {out_variable_name} = p_node->Output(0).toTensor(); + fastResizeToZero({out_variable_name}); + {out_variant_call}; + }}; + }}""" + return generated + + def view_op_generator( + self, g: NativeFunctionsViewGroup, backend_index: BackendIndex + ) -> str: + schema = str(g.view.func) + 
populated_argument = generate_arg_extraction(g.view.func)
+        functional_variant_call = generate_call_to_view_ops(g, backend_index)
+        generated = f"""
+    if (n->matches(torch::schema("aten::{schema}"))) {{
+      return [](ProcessedNode* p_node) {{
+        {populated_argument}
+        p_node->Output(0) = {functional_variant_call};
+      }};
+    }}"""
+        return generated
+
+
+class GenOpTestCase:
+    def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
+        if not groups:
+            return ""
+        generated_type_variants = []
+        for g in groups:
+            with native_function_manager(g):
+                assert is_supported(g)
+                assert isinstance(g, NativeFunctionsGroup)
+                generated_type_variant = self.out_variant_op_test_case_generator(g)
+                generated_type_variants.append(generated_type_variant)
+        return "\n".join(generated_type_variants)
+
+    def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
+        if not groups:
+            return ""
+        generated_type_variants = []
+        for g in groups:
+            with native_function_manager(g):
+                assert is_supported(g)
+                assert isinstance(g, NativeFunctionsViewGroup)
+                generated_type_variant = self.view_op_test_case_generator(g)
+                generated_type_variants.append(generated_type_variant)
+        return "\n".join(generated_type_variants)
+
+    def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
+        schema = g.functional.func
+        schema_str = str(schema)
+        assert schema_str.find("(") > 0
+        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
+        op_name = op_name_from_group(g)
+        assert type_variant_op_name.startswith(op_name)
+
+        arg_types = generate_test_ir_arguments(schema)
+        arg_declarations = ", ".join(
+            (
+                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
+                for arg_name, arg_type in arg_types
+            )
+        )
+        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
+        assert (
+            len(schema.returns) == 1
+            and isinstance(schema.returns[0].type, BaseType)
+            and schema.returns[0].type.name is BaseTy.Tensor
+        )
+        test_value_definitions = generate_test_value_definitions(schema, 0)
+        test_value_names = generate_test_value_names(schema, 0)
+        test_value_definitions2 = generate_test_value_definitions(schema, 1)
+        test_value_names2 = generate_test_value_names(schema, 1)
+        check_resize = "true" if should_check_resize(schema) else "false"
+        generated = f"""
+TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
+  const std::string script = R"IR(
+    graph({arg_declarations}):
+        %bias: None = prim::Constant()
+        %ret = aten::{op_name}({arg_names})
+        %cloned = aten::clone(%ret, %bias)
+        return (%cloned)
+  )IR";
+
+  {test_value_definitions}
+  std::vector<IValue> args{{{test_value_names}}};
+  testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
+
+  {test_value_definitions2}
+  std::vector<IValue> args2{{{test_value_names2}}};
+  testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
+
+}}
+"""
+        return generated
+
+    def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
+        schema = g.view.func
+        schema_str = str(schema)
+        assert schema_str.find("(") > 0
+        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
+        op_name = g.view.root_name
+        assert type_variant_op_name.startswith(op_name)
+
+        arg_types = generate_test_ir_arguments(schema)
+        arg_declarations = ", ".join(
+            (
+                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
+                for arg_name, arg_type in arg_types
+            )
+        )
+        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
+        assert (
+            len(schema.returns) == 1
+            and isinstance(schema.returns[0].type, BaseType)
+            and schema.returns[0].type.name is BaseTy.Tensor
+        )
+        test_value_definitions = generate_test_value_definitions(schema, 0)
+        test_value_names = generate_test_value_names(schema, 0)
+        generated = f"""
+TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
+  const std::string script = R"IR(
+    graph({arg_declarations}):
+        %bias: None = prim::Constant()
+        %ret = aten::{op_name}({arg_names})
+        %cloned = aten::clone(%ret, %bias)
+        return (%cloned)
+  )IR";
+
+  {test_value_definitions}
+  std::vector<IValue> args{{{test_value_names}}};
+  testStaticRuntime(script, args);
+}}
+"""
+
+        return generated
diff --git a/.venv/lib/python3.11/site-packages/torchgen/utils.py b/.venv/lib/python3.11/site-packages/torchgen/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d83a27dc9e76b4a7684708dfb2a7fd012c49630
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchgen/utils.py
@@ -0,0 +1,519 @@
+from __future__ import annotations
+
+import contextlib
+import functools
+import hashlib
+import os
+import re
+import sys
+import textwrap
+from dataclasses import fields, is_dataclass
+from enum import auto, Enum
+from pathlib import Path
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Iterator,
+    Literal,
+    NoReturn,
+    Sequence,
+    TYPE_CHECKING,
+    TypeVar,
+)
+from typing_extensions import Self
+
+from torchgen.code_template import CodeTemplate
+
+
+if TYPE_CHECKING:
+    from argparse import Namespace
+
+
+REPO_ROOT = Path(__file__).absolute().parent.parent
+
+
+# Many of these functions share logic for defining both the definition
+# and declaration (for example, the function signature is the same), so
+# we organize them into one function that takes a Target to say which
+# code we want.
+#
+# This is an OPEN enum (we may add more cases to it in the future), so be sure
+# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
+# what targets are valid for your use.
+class Target(Enum):
+    # top level namespace (not including at)
+    DEFINITION = auto()
+    DECLARATION = auto()
+    # TORCH_LIBRARY(...) { ... }
+    REGISTRATION = auto()
+    # namespace { ... }
+    ANONYMOUS_DEFINITION = auto()
+    # namespace cpu { ... }
+    NAMESPACED_DEFINITION = auto()
+    NAMESPACED_DECLARATION = auto()
+
+
+# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
+# occurrence of a parameter in the derivative formula
+IDENT_REGEX = r"(^|\W){}($|\W)"
+
+
+# TODO: Use a real parser here; this will get bamboozled
+def split_name_params(schema: str) -> tuple[str, list[str]]:
+    m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
+    if m is None:
+        raise RuntimeError(f"Unsupported function schema: {schema}")
+    name, _, params = m.groups()
+    return name, params.split(", ")
+
+
+T = TypeVar("T")
+S = TypeVar("S")
+
+# These two functions purposely return generators in analogy to map()
+# so that you don't mix up when you need to list() them
+
+
+# Map over function that may return None; omit Nones from output sequence
+def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
+    for x in xs:
+        r = func(x)
+        if r is not None:
+            yield r
+
+
+# Map over function that returns sequences and cat them all together
+def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
+    for x in xs:
+        yield from func(x)
+
+
+# Conveniently add error context to exceptions raised.
Lets us +# easily say that an error occurred while processing a specific +# context. +@contextlib.contextmanager +def context(msg_fn: Callable[[], str]) -> Iterator[None]: + try: + yield + except Exception as e: + # TODO: this does the wrong thing with KeyError + msg = msg_fn() + msg = textwrap.indent(msg, " ") + msg = f"{e.args[0]}\n{msg}" if e.args else msg + e.args = (msg,) + e.args[1:] + raise + + +# A little trick from https://github.com/python/mypy/issues/6366 +# for getting mypy to do exhaustiveness checking +# TODO: put this somewhere else, maybe +def assert_never(x: NoReturn) -> NoReturn: + raise AssertionError(f"Unhandled type: {type(x).__name__}") + + +@functools.lru_cache(maxsize=None) +def _read_template(template_fn: str) -> CodeTemplate: + return CodeTemplate.from_file(template_fn) + + +# String hash that's stable across different executions, unlike builtin hash +def string_stable_hash(s: str) -> int: + sha1 = hashlib.sha1(s.encode("latin1")).digest() + return int.from_bytes(sha1, byteorder="little") + + +# A small abstraction for writing out generated files and keeping track +# of what files have been written (so you can write out a list of output +# files) +class FileManager: + install_dir: str + template_dir: str + dry_run: bool + filenames: set[str] + + def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None: + self.install_dir = install_dir + self.template_dir = template_dir + self.filenames = set() + self.dry_run = dry_run + + def _write_if_changed(self, filename: str, contents: str) -> None: + old_contents: str | None + try: + with open(filename) as f: + old_contents = f.read() + except OSError: + old_contents = None + if contents != old_contents: + # Create output directory if it doesn't exist + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "w") as f: + f.write(contents) + + # Read from template file and replace pattern with callable (type could be dict or str). 
+    def substitute_with_template(
+        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
+    ) -> str:
+        template_path = os.path.join(self.template_dir, template_fn)
+        env = env_callable()
+        if isinstance(env, dict):
+            if "generated_comment" not in env:
+                generator_default = REPO_ROOT / "torchgen" / "gen.py"
+                try:
+                    generator = Path(
+                        sys.modules["__main__"].__file__ or generator_default
+                    ).absolute()
+                except (KeyError, AttributeError):
+                    generator = generator_default.absolute()
+
+                try:
+                    generator_path = generator.relative_to(REPO_ROOT).as_posix()
+                except ValueError:
+                    generator_path = generator.name
+
+                env = {
+                    **env,  # copy the original dict instead of mutating it
+                    "generated_comment": (
+                        "@" + f"generated by {generator_path} from {template_fn}"
+                    ),
+                }
+            template = _read_template(template_path)
+            return template.substitute(env)
+        elif isinstance(env, str):
+            return env
+        else:
+            assert_never(env)
+
+    def write_with_template(
+        self,
+        filename: str,
+        template_fn: str,
+        env_callable: Callable[[], str | dict[str, Any]],
+    ) -> None:
+        filename = f"{self.install_dir}/{filename}"
+        assert filename not in self.filenames, f"duplicate file write {filename}"
+        self.filenames.add(filename)
+        if not self.dry_run:
+            substitute_out = self.substitute_with_template(
+                template_fn=template_fn,
+                env_callable=env_callable,
+            )
+            self._write_if_changed(filename=filename, contents=substitute_out)
+
+    def write(
+        self,
+        filename: str,
+        env_callable: Callable[[], str | dict[str, Any]],
+    ) -> None:
+        self.write_with_template(filename, filename, env_callable)
+
+    def write_sharded(
+        self,
+        filename: str,
+        items: Iterable[T],
+        *,
+        key_fn: Callable[[T], str],
+        env_callable: Callable[[T], dict[str, list[str]]],
+        num_shards: int,
+        base_env: dict[str, Any] | None = None,
+        sharded_keys: set[str],
+    ) -> None:
+        everything: dict[str, Any] = {"shard_id": "Everything"}
+        shards: list[dict[str, Any]] = [
+            {"shard_id": f"_{i}"} for i in range(num_shards)
+        ]
+        all_shards = [everything] + shards
+
+        if base_env is not None:
+            for shard in all_shards:
+                shard.update(base_env)
+
+        for key in sharded_keys:
+            for shard in all_shards:
+                if key in shard:
+                    assert isinstance(
+                        shard[key], list
+                    ), "sharded keys in base_env must be a list"
+                    shard[key] = shard[key].copy()
+                else:
+                    shard[key] = []
+
+        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
+            for k, v in from_.items():
+                assert k in sharded_keys, f"undeclared sharded key {k}"
+                into[k] += v
+
+        if self.dry_run:
+            # Dry runs don't write any templates, so incomplete environments are fine
+            items = ()
+
+        for item in items:
+            key = key_fn(item)
+            sid = string_stable_hash(key) % num_shards
+            env = env_callable(item)
+
+            merge_env(shards[sid], env)
+            merge_env(everything, env)
+
+        dot_pos = filename.rfind(".")
+        if dot_pos == -1:
+            dot_pos = len(filename)
+        base_filename = filename[:dot_pos]
+        extension = filename[dot_pos:]
+
+        for shard in all_shards:
+            shard_id = shard["shard_id"]
+            self.write_with_template(
+                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
+            )
+
+        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
+        self.filenames.discard(
+            f"{self.install_dir}/{base_filename}Everything{extension}"
+        )
+
+    def write_outputs(self, variable_name: str, filename: str) -> None:
+        """Write a file containing the list of all outputs which are
+        generated by this script."""
+        content = "set({}\n    {})".format(
+            variable_name,
+            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
+        )
+        self._write_if_changed(filename, content)
+
+    def template_dir_for_comments(self) -> str:
+        """
+        This needs to be deterministic. The template dir is an absolute path
+        that varies across builds. So, just use the path relative to this file,
+        which will point to the codegen source but will be stable.
+        """
+        return os.path.relpath(self.template_dir, os.path.dirname(__file__))
+
+
+# Helper function to generate file manager
+def make_file_manager(
+    options: Namespace, install_dir: str | None = None
+) -> FileManager:
+    template_dir = os.path.join(options.source_path, "templates")
+    install_dir = install_dir if install_dir else options.install_dir
+    return FileManager(
+        install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
+    )
+
+
+# Helper function to create a pretty representation for dataclasses
+def dataclass_repr(
+    obj: Any,
+    indent: int = 0,
+    width: int = 80,
+) -> str:
+    # built-in pprint module supports dataclasses from python 3.10
+    if sys.version_info >= (3, 10):
+        from pprint import pformat
+
+        return pformat(obj, indent, width)
+
+    return _pformat(obj, indent=indent, width=width)
+
+
+def _pformat(
+    obj: Any,
+    indent: int,
+    width: int,
+    curr_indent: int = 0,
+) -> str:
+    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
+
+    class_name = obj.__class__.__name__
+    # update current indentation level with class name
+    curr_indent += len(class_name) + 1
+
+    fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
+
+    fields_str = []
+    for name, attr in fields_list:
+        # update the current indent level with the field name
+        # dict, list, set and tuple also add indent as done in pprint
+        _curr_indent = curr_indent + len(name) + 1
+        if is_dataclass(attr):
+            str_repr = _pformat(attr, indent, width, _curr_indent)
+        elif isinstance(attr, dict):
+            str_repr = _format_dict(attr, indent, width, _curr_indent)
+        elif isinstance(attr, (list, set, tuple)):
+            str_repr = _format_list(attr, indent, width, _curr_indent)
+        else:
+            str_repr = repr(attr)
+
+        fields_str.append(f"{name}={str_repr}")
+
+    indent_str = curr_indent * " "
+    body = f",\n{indent_str}".join(fields_str)
+    return f"{class_name}({body})"
+
+
+def _format_dict(
+    attr: dict[Any, Any],
+    indent: int,
+    width: int,
+    curr_indent: int,
+) -> str:
+    curr_indent += indent + 3
+    dict_repr = []
+    for k, v in attr.items():
+        k_repr = repr(k)
+        v_str = (
+            _pformat(v, indent, width, curr_indent + len(k_repr))
+            if is_dataclass(v)
+            else repr(v)
+        )
+        dict_repr.append(f"{k_repr}: {v_str}")
+
+    return _format(dict_repr, indent, width, curr_indent, "{", "}")
+
+
+def _format_list(
+    attr: list[Any] | set[Any] | tuple[Any, ...],
+    indent: int,
+    width: int,
+    curr_indent: int,
+) -> str:
+    curr_indent += indent + 1
+    list_repr = [
+        _pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
+        for l in attr
+    ]
+    start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
+    return _format(list_repr, indent, width, curr_indent, start, end)
+
+
+def _format(
+    fields_str: list[str],
+    indent: int,
+    width: int,
+    curr_indent: int,
+    start: str,
+    end: str,
+) -> str:
+    delimiter, curr_indent_str = "", ""
+    # if it exceeds the max width then we place one element per line
+    if len(repr(fields_str)) >= width:
+        delimiter = "\n"
+        curr_indent_str = " " * curr_indent
+
+    indent_str = " " * indent
+    body = f", {delimiter}{curr_indent_str}".join(fields_str)
+    return f"{start}{indent_str}{body}{end}"
+
+
+class NamespaceHelper:
+    """A helper for constructing the namespace open and close strings for a nested set of namespaces.
+
+    e.g. for namespace_str torch::lazy,
+
+    prologue:
+    namespace torch {
+    namespace lazy {
+
+    epilogue:
+    } // namespace lazy
+    } // namespace torch
+    """
+
+    def __init__(
+        self, namespace_str: str, entity_name: str = "", max_level: int = 2
+    ) -> None:
+        # cpp_namespace can be a colon joined string such as torch::lazy
+        cpp_namespaces = namespace_str.split("::")
+        assert (
+            len(cpp_namespaces) <= max_level
+        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
+        self.cpp_namespace_ = namespace_str
+        self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
+        self.epilogue_ = "\n".join(
+            [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
+        )
+        self.namespaces_ = cpp_namespaces
+        self.entity_name_ = entity_name
+
+    @staticmethod
+    def from_namespaced_entity(
+        namespaced_entity: str, max_level: int = 2
+    ) -> NamespaceHelper:
+        """
+        Generate a helper from nested namespaces plus the class/function name. E.g.: "torch::lazy::add"
+        """
+        names = namespaced_entity.split("::")
+        entity_name = names[-1]
+        namespace_str = "::".join(names[:-1])
+        return NamespaceHelper(
+            namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
+        )
+
+    @property
+    def prologue(self) -> str:
+        return self.prologue_
+
+    @property
+    def epilogue(self) -> str:
+        return self.epilogue_
+
+    @property
+    def entity_name(self) -> str:
+        return self.entity_name_
+
+    # Only allow certain level of namespaces
+    def get_cpp_namespace(self, default: str = "") -> str:
+        """
+        Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
+        Return default if namespace string is empty.
+ """ + return self.cpp_namespace_ if self.cpp_namespace_ else default + + +class OrderedSet(Generic[T]): + storage: dict[T, Literal[None]] + + def __init__(self, iterable: Iterable[T] | None = None) -> None: + if iterable is None: + self.storage = {} + else: + self.storage = dict.fromkeys(iterable) + + def __contains__(self, item: T) -> bool: + return item in self.storage + + def __iter__(self) -> Iterator[T]: + return iter(self.storage.keys()) + + def update(self, items: OrderedSet[T]) -> None: + self.storage.update(items.storage) + + def add(self, item: T) -> None: + self.storage[item] = None + + def copy(self) -> OrderedSet[T]: + ret: OrderedSet[T] = OrderedSet() + ret.storage = self.storage.copy() + return ret + + @staticmethod + def union(*args: OrderedSet[T]) -> OrderedSet[T]: + ret = args[0].copy() + for s in args[1:]: + ret.update(s) + return ret + + def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]: + return OrderedSet.union(self, other) + + def __ior__(self, other: OrderedSet[T]) -> Self: + self.update(other) + return self + + def __eq__(self, other: object) -> bool: + if isinstance(other, OrderedSet): + return self.storage == other.storage + else: + return set(self.storage.keys()) == other diff --git a/.venv/lib/python3.11/site-packages/torchgen/yaml_utils.py b/.venv/lib/python3.11/site-packages/torchgen/yaml_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0278af84bf633a85a857b6bd7798dd64cb8259dc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/yaml_utils.py @@ -0,0 +1,26 @@ +# Safely load fast C Yaml loader/dumper if they are available +try: + from yaml import CSafeLoader as Loader +except ImportError: + from yaml import SafeLoader as Loader # type: ignore[assignment, misc] + +try: + from yaml import CSafeDumper as Dumper +except ImportError: + from yaml import SafeDumper as Dumper # type: ignore[assignment, misc] +YamlDumper = Dumper + + +# A custom loader for YAML that errors on duplicate keys. +# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165 +class YamlLoader(Loader): + def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] + mapping = [] + for key_node, value_node in node.value: + key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call] + assert ( + key not in mapping + ), f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}" + mapping.append(key) + mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call] + return mapping
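For context on how gen_static_runtime_ops.py groups ops: group_functions_by_op_name relies on itertools.groupby over a name-sorted sequence, so ops that share a base name (e.g. mean and mean.dim) land in the same bucket. A minimal standalone sketch of the same idea, using plain strings in place of NativeFunctionsGroup objects:

    import itertools

    names = ["mean", "mean.dim", "mm", "mv"]  # already sorted by base name
    groups = [
        list(group)
        for _, group in itertools.groupby(names, key=lambda s: s.split(".")[0])
    ]
    print(groups)  # [['mean', 'mean.dim'], ['mm'], ['mv']]

Note that groupby only merges adjacent items with equal keys, which is why the input must arrive sorted by base name.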
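FileManager.write_sharded assigns each item to a shard via string_stable_hash(key) % num_shards; because the hash is SHA-1 based rather than Python's builtin hash, the assignment is reproducible across interpreter runs. A small sketch, assuming torchgen is importable from this virtualenv:

    from torchgen.utils import string_stable_hash

    num_shards = 4
    for key in ("aten::add", "aten::mul"):
        # Same shard id on every run, unlike builtin hash() under PYTHONHASHSEED
        print(key, string_stable_hash(key) % num_shards)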
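A minimal usage sketch for NamespaceHelper, with the expected output read off the constructor logic above (again assuming torchgen is importable):

    from torchgen.utils import NamespaceHelper

    ns = NamespaceHelper.from_namespaced_entity("torch::lazy::add")
    print(ns.entity_name)  # add
    print(ns.prologue)     # namespace torch {
                           # namespace lazy {
    print(ns.epilogue)     # } // namespace lazy
                           # } // namespace torch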
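OrderedSet stores its elements as dict keys, so iteration preserves first-insertion order while membership, union, and equality behave like a set. A short sketch of the behavior implied by the class above:

    from torchgen.utils import OrderedSet

    a = OrderedSet(["mean", "mean.dim", "add"])
    b = OrderedSet(["add", "mul"])
    merged = a | b
    print(list(merged))  # ['mean', 'mean.dim', 'add', 'mul']
    print(merged == {"mul", "add", "mean", "mean.dim"})  # True (plain-set equality)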
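YamlLoader exists to reject duplicate mapping keys, which PyYAML's SafeLoader silently accepts (the last occurrence wins). A sketch contrasting the two, assuming PyYAML is installed:

    import yaml  # PyYAML
    from torchgen.yaml_utils import YamlLoader

    print(yaml.load("a: 1\nb: 2", Loader=YamlLoader))        # {'a': 1, 'b': 2}
    print(yaml.load("a: 1\na: 2", Loader=yaml.SafeLoader))   # {'a': 2}, silently
    # yaml.load("a: 1\na: 2", Loader=YamlLoader) trips the duplicate-key assert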