koichi12 commited on
Commit
587c1a9
·
verified ·
1 Parent(s): e90af2b

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/numpy/polynomial/__init__.py +185 -0
  4. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/_polybase.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/chebyshev.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite_e.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/laguerre.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/legendre.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polynomial.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polyutils.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/setup.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.py +1206 -0
  15. .venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.pyi +71 -0
  16. .venv/lib/python3.11/site-packages/numpy/polynomial/chebyshev.py +2082 -0
  17. .venv/lib/python3.11/site-packages/numpy/polynomial/hermite.pyi +46 -0
  18. .venv/lib/python3.11/site-packages/numpy/polynomial/polynomial.py +1542 -0
  19. .venv/lib/python3.11/site-packages/numpy/polynomial/polyutils.py +789 -0
  20. .venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_chebyshev.py +619 -0
  21. .venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_hermite_e.py +556 -0
  22. .venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_printing.py +530 -0
  23. .venv/lib/python3.11/site-packages/torchgen/__init__.py +10 -0
  24. .venv/lib/python3.11/site-packages/torchgen/__pycache__/__init__.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torchgen/__pycache__/utils.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torchgen/aoti/__init__.py +0 -0
  39. .venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torchgen/aoti/fallback_ops.py +149 -0
  42. .venv/lib/python3.11/site-packages/torchgen/code_template.py +99 -0
  43. .venv/lib/python3.11/site-packages/torchgen/context.py +130 -0
  44. .venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py +48 -0
  45. .venv/lib/python3.11/site-packages/torchgen/gen.py +0 -0
  46. .venv/lib/python3.11/site-packages/torchgen/gen_aoti_c_shim.py +486 -0
  47. .venv/lib/python3.11/site-packages/torchgen/gen_backend_stubs.py +611 -0
  48. .venv/lib/python3.11/site-packages/torchgen/gen_executorch.py +998 -0
  49. .venv/lib/python3.11/site-packages/torchgen/gen_functionalization_type.py +882 -0
  50. .venv/lib/python3.11/site-packages/torchgen/gen_lazy_tensor.py +581 -0
.gitattributes CHANGED
@@ -397,3 +397,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
397
  .venv/lib/python3.11/site-packages/mistral_common/data/tokenizer.model.v1 filter=lfs diff=lfs merge=lfs -text
398
  .venv/lib/python3.11/site-packages/mistral_common/data/mistral_instruct_tokenizer_240216.model.v2 filter=lfs diff=lfs merge=lfs -text
399
  .venv/lib/python3.11/site-packages/numpy/lib/tests/__pycache__/test_io.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
397
  .venv/lib/python3.11/site-packages/mistral_common/data/tokenizer.model.v1 filter=lfs diff=lfs merge=lfs -text
398
  .venv/lib/python3.11/site-packages/mistral_common/data/mistral_instruct_tokenizer_240216.model.v2 filter=lfs diff=lfs merge=lfs -text
399
  .venv/lib/python3.11/site-packages/numpy/lib/tests/__pycache__/test_io.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
400
+ .venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07faac212a7a262c6ea1fffc03378750fe6bb57a142cd64278e8827f652c7424
3
+ size 390546
.venv/lib/python3.11/site-packages/numpy/polynomial/__init__.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A sub-package for efficiently dealing with polynomials.
3
+
4
+ Within the documentation for this sub-package, a "finite power series,"
5
+ i.e., a polynomial (also referred to simply as a "series") is represented
6
+ by a 1-D numpy array of the polynomial's coefficients, ordered from lowest
7
+ order term to highest. For example, array([1,2,3]) represents
8
+ ``P_0 + 2*P_1 + 3*P_2``, where P_n is the n-th order basis polynomial
9
+ applicable to the specific module in question, e.g., `polynomial` (which
10
+ "wraps" the "standard" basis) or `chebyshev`. For optimal performance,
11
+ all operations on polynomials, including evaluation at an argument, are
12
+ implemented as operations on the coefficients. Additional (module-specific)
13
+ information can be found in the docstring for the module of interest.
14
+
15
+ This package provides *convenience classes* for each of six different kinds
16
+ of polynomials:
17
+
18
+ ======================== ================
19
+ **Name** **Provides**
20
+ ======================== ================
21
+ `~polynomial.Polynomial` Power series
22
+ `~chebyshev.Chebyshev` Chebyshev series
23
+ `~legendre.Legendre` Legendre series
24
+ `~laguerre.Laguerre` Laguerre series
25
+ `~hermite.Hermite` Hermite series
26
+ `~hermite_e.HermiteE` HermiteE series
27
+ ======================== ================
28
+
29
+ These *convenience classes* provide a consistent interface for creating,
30
+ manipulating, and fitting data with polynomials of different bases.
31
+ The convenience classes are the preferred interface for the `~numpy.polynomial`
32
+ package, and are available from the ``numpy.polynomial`` namespace.
33
+ This eliminates the need to navigate to the corresponding submodules, e.g.
34
+ ``np.polynomial.Polynomial`` or ``np.polynomial.Chebyshev`` instead of
35
+ ``np.polynomial.polynomial.Polynomial`` or
36
+ ``np.polynomial.chebyshev.Chebyshev``, respectively.
37
+ The classes provide a more consistent and concise interface than the
38
+ type-specific functions defined in the submodules for each type of polynomial.
39
+ For example, to fit a Chebyshev polynomial with degree ``1`` to data given
40
+ by arrays ``xdata`` and ``ydata``, the
41
+ `~chebyshev.Chebyshev.fit` class method::
42
+
43
+ >>> from numpy.polynomial import Chebyshev
44
+ >>> c = Chebyshev.fit(xdata, ydata, deg=1)
45
+
46
+ is preferred over the `chebyshev.chebfit` function from the
47
+ ``np.polynomial.chebyshev`` module::
48
+
49
+ >>> from numpy.polynomial.chebyshev import chebfit
50
+ >>> c = chebfit(xdata, ydata, deg=1)
51
+
52
+ See :doc:`routines.polynomials.classes` for more details.
53
+
54
+ Convenience Classes
55
+ ===================
56
+
57
+ The following lists the various constants and methods common to all of
58
+ the classes representing the various kinds of polynomials. In the following,
59
+ the term ``Poly`` represents any one of the convenience classes (e.g.
60
+ `~polynomial.Polynomial`, `~chebyshev.Chebyshev`, `~hermite.Hermite`, etc.)
61
+ while the lowercase ``p`` represents an **instance** of a polynomial class.
62
+
63
+ Constants
64
+ ---------
65
+
66
+ - ``Poly.domain`` -- Default domain
67
+ - ``Poly.window`` -- Default window
68
+ - ``Poly.basis_name`` -- String used to represent the basis
69
+ - ``Poly.maxpower`` -- Maximum value ``n`` such that ``p**n`` is allowed
70
+ - ``Poly.nickname`` -- String used in printing
71
+
72
+ Creation
73
+ --------
74
+
75
+ Methods for creating polynomial instances.
76
+
77
+ - ``Poly.basis(degree)`` -- Basis polynomial of given degree
78
+ - ``Poly.identity()`` -- ``p`` where ``p(x) = x`` for all ``x``
79
+ - ``Poly.fit(x, y, deg)`` -- ``p`` of degree ``deg`` with coefficients
80
+ determined by the least-squares fit to the data ``x``, ``y``
81
+ - ``Poly.fromroots(roots)`` -- ``p`` with specified roots
82
+ - ``p.copy()`` -- Create a copy of ``p``
83
+
84
+ Conversion
85
+ ----------
86
+
87
+ Methods for converting a polynomial instance of one kind to another.
88
+
89
+ - ``p.cast(Poly)`` -- Convert ``p`` to instance of kind ``Poly``
90
+ - ``p.convert(Poly)`` -- Convert ``p`` to instance of kind ``Poly`` or map
91
+ between ``domain`` and ``window``
92
+
93
+ Calculus
94
+ --------
95
+ - ``p.deriv()`` -- Take the derivative of ``p``
96
+ - ``p.integ()`` -- Integrate ``p``
97
+
98
+ Validation
99
+ ----------
100
+ - ``Poly.has_samecoef(p1, p2)`` -- Check if coefficients match
101
+ - ``Poly.has_samedomain(p1, p2)`` -- Check if domains match
102
+ - ``Poly.has_sametype(p1, p2)`` -- Check if types match
103
+ - ``Poly.has_samewindow(p1, p2)`` -- Check if windows match
104
+
105
+ Misc
106
+ ----
107
+ - ``p.linspace()`` -- Return ``x, p(x)`` at equally-spaced points in ``domain``
108
+ - ``p.mapparms()`` -- Return the parameters for the linear mapping between
109
+ ``domain`` and ``window``.
110
+ - ``p.roots()`` -- Return the roots of `p`.
111
+ - ``p.trim()`` -- Remove trailing coefficients.
112
+ - ``p.cutdeg(degree)`` -- Truncate p to given degree
113
+ - ``p.truncate(size)`` -- Truncate p to given size
114
+
115
+ """
116
+ from .polynomial import Polynomial
117
+ from .chebyshev import Chebyshev
118
+ from .legendre import Legendre
119
+ from .hermite import Hermite
120
+ from .hermite_e import HermiteE
121
+ from .laguerre import Laguerre
122
+
123
+ __all__ = [
124
+ "set_default_printstyle",
125
+ "polynomial", "Polynomial",
126
+ "chebyshev", "Chebyshev",
127
+ "legendre", "Legendre",
128
+ "hermite", "Hermite",
129
+ "hermite_e", "HermiteE",
130
+ "laguerre", "Laguerre",
131
+ ]
132
+
133
+
134
+ def set_default_printstyle(style):
135
+ """
136
+ Set the default format for the string representation of polynomials.
137
+
138
+ Values for ``style`` must be valid inputs to ``__format__``, i.e. 'ascii'
139
+ or 'unicode'.
140
+
141
+ Parameters
142
+ ----------
143
+ style : str
144
+ Format string for default printing style. Must be either 'ascii' or
145
+ 'unicode'.
146
+
147
+ Notes
148
+ -----
149
+ The default format depends on the platform: 'unicode' is used on
150
+ Unix-based systems and 'ascii' on Windows. This determination is based on
151
+ default font support for the unicode superscript and subscript ranges.
152
+
153
+ Examples
154
+ --------
155
+ >>> p = np.polynomial.Polynomial([1, 2, 3])
156
+ >>> c = np.polynomial.Chebyshev([1, 2, 3])
157
+ >>> np.polynomial.set_default_printstyle('unicode')
158
+ >>> print(p)
159
+ 1.0 + 2.0·x + 3.0·x²
160
+ >>> print(c)
161
+ 1.0 + 2.0·T₁(x) + 3.0·T₂(x)
162
+ >>> np.polynomial.set_default_printstyle('ascii')
163
+ >>> print(p)
164
+ 1.0 + 2.0 x + 3.0 x**2
165
+ >>> print(c)
166
+ 1.0 + 2.0 T_1(x) + 3.0 T_2(x)
167
+ >>> # Formatting supersedes all class/package-level defaults
168
+ >>> print(f"{p:unicode}")
169
+ 1.0 + 2.0·x + 3.0·x²
170
+ """
171
+ if style not in ('unicode', 'ascii'):
172
+ raise ValueError(
173
+ f"Unsupported format string '{style}'. Valid options are 'ascii' "
174
+ f"and 'unicode'"
175
+ )
176
+ _use_unicode = True
177
+ if style == 'ascii':
178
+ _use_unicode = False
179
+ from ._polybase import ABCPolyBase
180
+ ABCPolyBase._use_unicode = _use_unicode
181
+
182
+
183
+ from numpy._pytesttester import PytestTester
184
+ test = PytestTester(__name__)
185
+ del PytestTester
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (7.19 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/_polybase.cpython-311.pyc ADDED
Binary file (50 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/chebyshev.cpython-311.pyc ADDED
Binary file (74.6 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite.cpython-311.pyc ADDED
Binary file (62.1 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/hermite_e.cpython-311.pyc ADDED
Binary file (61.9 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/laguerre.cpython-311.pyc ADDED
Binary file (59.8 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/legendre.cpython-311.pyc ADDED
Binary file (60.4 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polynomial.cpython-311.pyc ADDED
Binary file (56.4 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/polyutils.cpython-311.pyc ADDED
Binary file (32.8 kB). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/__pycache__/setup.cpython-311.pyc ADDED
Binary file (844 Bytes). View file
 
.venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.py ADDED
@@ -0,0 +1,1206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Abstract base class for the various polynomial Classes.
3
+
4
+ The ABCPolyBase class provides the methods needed to implement the common API
5
+ for the various polynomial classes. It operates as a mixin, but uses the
6
+ abc module from the stdlib, hence it is only available for Python >= 2.6.
7
+
8
+ """
9
+ import os
10
+ import abc
11
+ import numbers
12
+
13
+ import numpy as np
14
+ from . import polyutils as pu
15
+
16
+ __all__ = ['ABCPolyBase']
17
+
18
+ class ABCPolyBase(abc.ABC):
19
+ """An abstract base class for immutable series classes.
20
+
21
+ ABCPolyBase provides the standard Python numerical methods
22
+ '+', '-', '*', '//', '%', 'divmod', '**', and '()' along with the
23
+ methods listed below.
24
+
25
+ .. versionadded:: 1.9.0
26
+
27
+ Parameters
28
+ ----------
29
+ coef : array_like
30
+ Series coefficients in order of increasing degree, i.e.,
31
+ ``(1, 2, 3)`` gives ``1*P_0(x) + 2*P_1(x) + 3*P_2(x)``, where
32
+ ``P_i`` is the basis polynomials of degree ``i``.
33
+ domain : (2,) array_like, optional
34
+ Domain to use. The interval ``[domain[0], domain[1]]`` is mapped
35
+ to the interval ``[window[0], window[1]]`` by shifting and scaling.
36
+ The default value is the derived class domain.
37
+ window : (2,) array_like, optional
38
+ Window, see domain for its use. The default value is the
39
+ derived class window.
40
+ symbol : str, optional
41
+ Symbol used to represent the independent variable in string
42
+ representations of the polynomial expression, e.g. for printing.
43
+ The symbol must be a valid Python identifier. Default value is 'x'.
44
+
45
+ .. versionadded:: 1.24
46
+
47
+ Attributes
48
+ ----------
49
+ coef : (N,) ndarray
50
+ Series coefficients in order of increasing degree.
51
+ domain : (2,) ndarray
52
+ Domain that is mapped to window.
53
+ window : (2,) ndarray
54
+ Window that domain is mapped to.
55
+ symbol : str
56
+ Symbol representing the independent variable.
57
+
58
+ Class Attributes
59
+ ----------------
60
+ maxpower : int
61
+ Maximum power allowed, i.e., the largest number ``n`` such that
62
+ ``p(x)**n`` is allowed. This is to limit runaway polynomial size.
63
+ domain : (2,) ndarray
64
+ Default domain of the class.
65
+ window : (2,) ndarray
66
+ Default window of the class.
67
+
68
+ """
69
+
70
+ # Not hashable
71
+ __hash__ = None
72
+
73
+ # Opt out of numpy ufuncs and Python ops with ndarray subclasses.
74
+ __array_ufunc__ = None
75
+
76
+ # Limit runaway size. T_n^m has degree n*m
77
+ maxpower = 100
78
+
79
+ # Unicode character mappings for improved __str__
80
+ _superscript_mapping = str.maketrans({
81
+ "0": "⁰",
82
+ "1": "¹",
83
+ "2": "²",
84
+ "3": "³",
85
+ "4": "⁴",
86
+ "5": "⁵",
87
+ "6": "⁶",
88
+ "7": "⁷",
89
+ "8": "⁸",
90
+ "9": "⁹"
91
+ })
92
+ _subscript_mapping = str.maketrans({
93
+ "0": "₀",
94
+ "1": "₁",
95
+ "2": "₂",
96
+ "3": "₃",
97
+ "4": "₄",
98
+ "5": "₅",
99
+ "6": "₆",
100
+ "7": "₇",
101
+ "8": "₈",
102
+ "9": "₉"
103
+ })
104
+ # Some fonts don't support full unicode character ranges necessary for
105
+ # the full set of superscripts and subscripts, including common/default
106
+ # fonts in Windows shells/terminals. Therefore, default to ascii-only
107
+ # printing on windows.
108
+ _use_unicode = not os.name == 'nt'
109
+
110
+ @property
111
+ def symbol(self):
112
+ return self._symbol
113
+
114
+ @property
115
+ @abc.abstractmethod
116
+ def domain(self):
117
+ pass
118
+
119
+ @property
120
+ @abc.abstractmethod
121
+ def window(self):
122
+ pass
123
+
124
+ @property
125
+ @abc.abstractmethod
126
+ def basis_name(self):
127
+ pass
128
+
129
+ @staticmethod
130
+ @abc.abstractmethod
131
+ def _add(c1, c2):
132
+ pass
133
+
134
+ @staticmethod
135
+ @abc.abstractmethod
136
+ def _sub(c1, c2):
137
+ pass
138
+
139
+ @staticmethod
140
+ @abc.abstractmethod
141
+ def _mul(c1, c2):
142
+ pass
143
+
144
+ @staticmethod
145
+ @abc.abstractmethod
146
+ def _div(c1, c2):
147
+ pass
148
+
149
+ @staticmethod
150
+ @abc.abstractmethod
151
+ def _pow(c, pow, maxpower=None):
152
+ pass
153
+
154
+ @staticmethod
155
+ @abc.abstractmethod
156
+ def _val(x, c):
157
+ pass
158
+
159
+ @staticmethod
160
+ @abc.abstractmethod
161
+ def _int(c, m, k, lbnd, scl):
162
+ pass
163
+
164
+ @staticmethod
165
+ @abc.abstractmethod
166
+ def _der(c, m, scl):
167
+ pass
168
+
169
+ @staticmethod
170
+ @abc.abstractmethod
171
+ def _fit(x, y, deg, rcond, full):
172
+ pass
173
+
174
+ @staticmethod
175
+ @abc.abstractmethod
176
+ def _line(off, scl):
177
+ pass
178
+
179
+ @staticmethod
180
+ @abc.abstractmethod
181
+ def _roots(c):
182
+ pass
183
+
184
+ @staticmethod
185
+ @abc.abstractmethod
186
+ def _fromroots(r):
187
+ pass
188
+
189
+ def has_samecoef(self, other):
190
+ """Check if coefficients match.
191
+
192
+ .. versionadded:: 1.6.0
193
+
194
+ Parameters
195
+ ----------
196
+ other : class instance
197
+ The other class must have the ``coef`` attribute.
198
+
199
+ Returns
200
+ -------
201
+ bool : boolean
202
+ True if the coefficients are the same, False otherwise.
203
+
204
+ """
205
+ if len(self.coef) != len(other.coef):
206
+ return False
207
+ elif not np.all(self.coef == other.coef):
208
+ return False
209
+ else:
210
+ return True
211
+
212
+ def has_samedomain(self, other):
213
+ """Check if domains match.
214
+
215
+ .. versionadded:: 1.6.0
216
+
217
+ Parameters
218
+ ----------
219
+ other : class instance
220
+ The other class must have the ``domain`` attribute.
221
+
222
+ Returns
223
+ -------
224
+ bool : boolean
225
+ True if the domains are the same, False otherwise.
226
+
227
+ """
228
+ return np.all(self.domain == other.domain)
229
+
230
+ def has_samewindow(self, other):
231
+ """Check if windows match.
232
+
233
+ .. versionadded:: 1.6.0
234
+
235
+ Parameters
236
+ ----------
237
+ other : class instance
238
+ The other class must have the ``window`` attribute.
239
+
240
+ Returns
241
+ -------
242
+ bool : boolean
243
+ True if the windows are the same, False otherwise.
244
+
245
+ """
246
+ return np.all(self.window == other.window)
247
+
248
+ def has_sametype(self, other):
249
+ """Check if types match.
250
+
251
+ .. versionadded:: 1.7.0
252
+
253
+ Parameters
254
+ ----------
255
+ other : object
256
+ Class instance.
257
+
258
+ Returns
259
+ -------
260
+ bool : boolean
261
+ True if other is same class as self
262
+
263
+ """
264
+ return isinstance(other, self.__class__)
265
+
266
+ def _get_coefficients(self, other):
267
+ """Interpret other as polynomial coefficients.
268
+
269
+ The `other` argument is checked to see if it is of the same
270
+ class as self with identical domain and window. If so,
271
+ return its coefficients, otherwise return `other`.
272
+
273
+ .. versionadded:: 1.9.0
274
+
275
+ Parameters
276
+ ----------
277
+ other : anything
278
+ Object to be checked.
279
+
280
+ Returns
281
+ -------
282
+ coef
283
+ The coefficients of`other` if it is a compatible instance,
284
+ of ABCPolyBase, otherwise `other`.
285
+
286
+ Raises
287
+ ------
288
+ TypeError
289
+ When `other` is an incompatible instance of ABCPolyBase.
290
+
291
+ """
292
+ if isinstance(other, ABCPolyBase):
293
+ if not isinstance(other, self.__class__):
294
+ raise TypeError("Polynomial types differ")
295
+ elif not np.all(self.domain == other.domain):
296
+ raise TypeError("Domains differ")
297
+ elif not np.all(self.window == other.window):
298
+ raise TypeError("Windows differ")
299
+ elif self.symbol != other.symbol:
300
+ raise ValueError("Polynomial symbols differ")
301
+ return other.coef
302
+ return other
303
+
304
+ def __init__(self, coef, domain=None, window=None, symbol='x'):
305
+ [coef] = pu.as_series([coef], trim=False)
306
+ self.coef = coef
307
+
308
+ if domain is not None:
309
+ [domain] = pu.as_series([domain], trim=False)
310
+ if len(domain) != 2:
311
+ raise ValueError("Domain has wrong number of elements.")
312
+ self.domain = domain
313
+
314
+ if window is not None:
315
+ [window] = pu.as_series([window], trim=False)
316
+ if len(window) != 2:
317
+ raise ValueError("Window has wrong number of elements.")
318
+ self.window = window
319
+
320
+ # Validation for symbol
321
+ try:
322
+ if not symbol.isidentifier():
323
+ raise ValueError(
324
+ "Symbol string must be a valid Python identifier"
325
+ )
326
+ # If a user passes in something other than a string, the above
327
+ # results in an AttributeError. Catch this and raise a more
328
+ # informative exception
329
+ except AttributeError:
330
+ raise TypeError("Symbol must be a non-empty string")
331
+
332
+ self._symbol = symbol
333
+
334
+ def __repr__(self):
335
+ coef = repr(self.coef)[6:-1]
336
+ domain = repr(self.domain)[6:-1]
337
+ window = repr(self.window)[6:-1]
338
+ name = self.__class__.__name__
339
+ return (f"{name}({coef}, domain={domain}, window={window}, "
340
+ f"symbol='{self.symbol}')")
341
+
342
+ def __format__(self, fmt_str):
343
+ if fmt_str == '':
344
+ return self.__str__()
345
+ if fmt_str not in ('ascii', 'unicode'):
346
+ raise ValueError(
347
+ f"Unsupported format string '{fmt_str}' passed to "
348
+ f"{self.__class__}.__format__. Valid options are "
349
+ f"'ascii' and 'unicode'"
350
+ )
351
+ if fmt_str == 'ascii':
352
+ return self._generate_string(self._str_term_ascii)
353
+ return self._generate_string(self._str_term_unicode)
354
+
355
+ def __str__(self):
356
+ if self._use_unicode:
357
+ return self._generate_string(self._str_term_unicode)
358
+ return self._generate_string(self._str_term_ascii)
359
+
360
+ def _generate_string(self, term_method):
361
+ """
362
+ Generate the full string representation of the polynomial, using
363
+ ``term_method`` to generate each polynomial term.
364
+ """
365
+ # Get configuration for line breaks
366
+ linewidth = np.get_printoptions().get('linewidth', 75)
367
+ if linewidth < 1:
368
+ linewidth = 1
369
+ out = pu.format_float(self.coef[0])
370
+ for i, coef in enumerate(self.coef[1:]):
371
+ out += " "
372
+ power = str(i + 1)
373
+ # Polynomial coefficient
374
+ # The coefficient array can be an object array with elements that
375
+ # will raise a TypeError with >= 0 (e.g. strings or Python
376
+ # complex). In this case, represent the coefficient as-is.
377
+ try:
378
+ if coef >= 0:
379
+ next_term = f"+ " + pu.format_float(coef, parens=True)
380
+ else:
381
+ next_term = f"- " + pu.format_float(-coef, parens=True)
382
+ except TypeError:
383
+ next_term = f"+ {coef}"
384
+ # Polynomial term
385
+ next_term += term_method(power, self.symbol)
386
+ # Length of the current line with next term added
387
+ line_len = len(out.split('\n')[-1]) + len(next_term)
388
+ # If not the last term in the polynomial, it will be two
389
+ # characters longer due to the +/- with the next term
390
+ if i < len(self.coef[1:]) - 1:
391
+ line_len += 2
392
+ # Handle linebreaking
393
+ if line_len >= linewidth:
394
+ next_term = next_term.replace(" ", "\n", 1)
395
+ out += next_term
396
+ return out
397
+
398
+ @classmethod
399
+ def _str_term_unicode(cls, i, arg_str):
400
+ """
401
+ String representation of single polynomial term using unicode
402
+ characters for superscripts and subscripts.
403
+ """
404
+ if cls.basis_name is None:
405
+ raise NotImplementedError(
406
+ "Subclasses must define either a basis_name, or override "
407
+ "_str_term_unicode(cls, i, arg_str)"
408
+ )
409
+ return (f"·{cls.basis_name}{i.translate(cls._subscript_mapping)}"
410
+ f"({arg_str})")
411
+
412
+ @classmethod
413
+ def _str_term_ascii(cls, i, arg_str):
414
+ """
415
+ String representation of a single polynomial term using ** and _ to
416
+ represent superscripts and subscripts, respectively.
417
+ """
418
+ if cls.basis_name is None:
419
+ raise NotImplementedError(
420
+ "Subclasses must define either a basis_name, or override "
421
+ "_str_term_ascii(cls, i, arg_str)"
422
+ )
423
+ return f" {cls.basis_name}_{i}({arg_str})"
424
+
425
+ @classmethod
426
+ def _repr_latex_term(cls, i, arg_str, needs_parens):
427
+ if cls.basis_name is None:
428
+ raise NotImplementedError(
429
+ "Subclasses must define either a basis name, or override "
430
+ "_repr_latex_term(i, arg_str, needs_parens)")
431
+ # since we always add parens, we don't care if the expression needs them
432
+ return f"{{{cls.basis_name}}}_{{{i}}}({arg_str})"
433
+
434
+ @staticmethod
435
+ def _repr_latex_scalar(x, parens=False):
436
+ # TODO: we're stuck with disabling math formatting until we handle
437
+ # exponents in this function
438
+ return r'\text{{{}}}'.format(pu.format_float(x, parens=parens))
439
+
440
+ def _repr_latex_(self):
441
+ # get the scaled argument string to the basis functions
442
+ off, scale = self.mapparms()
443
+ if off == 0 and scale == 1:
444
+ term = self.symbol
445
+ needs_parens = False
446
+ elif scale == 1:
447
+ term = f"{self._repr_latex_scalar(off)} + {self.symbol}"
448
+ needs_parens = True
449
+ elif off == 0:
450
+ term = f"{self._repr_latex_scalar(scale)}{self.symbol}"
451
+ needs_parens = True
452
+ else:
453
+ term = (
454
+ f"{self._repr_latex_scalar(off)} + "
455
+ f"{self._repr_latex_scalar(scale)}{self.symbol}"
456
+ )
457
+ needs_parens = True
458
+
459
+ mute = r"\color{{LightGray}}{{{}}}".format
460
+
461
+ parts = []
462
+ for i, c in enumerate(self.coef):
463
+ # prevent duplication of + and - signs
464
+ if i == 0:
465
+ coef_str = f"{self._repr_latex_scalar(c)}"
466
+ elif not isinstance(c, numbers.Real):
467
+ coef_str = f" + ({self._repr_latex_scalar(c)})"
468
+ elif not np.signbit(c):
469
+ coef_str = f" + {self._repr_latex_scalar(c, parens=True)}"
470
+ else:
471
+ coef_str = f" - {self._repr_latex_scalar(-c, parens=True)}"
472
+
473
+ # produce the string for the term
474
+ term_str = self._repr_latex_term(i, term, needs_parens)
475
+ if term_str == '1':
476
+ part = coef_str
477
+ else:
478
+ part = rf"{coef_str}\,{term_str}"
479
+
480
+ if c == 0:
481
+ part = mute(part)
482
+
483
+ parts.append(part)
484
+
485
+ if parts:
486
+ body = ''.join(parts)
487
+ else:
488
+ # in case somehow there are no coefficients at all
489
+ body = '0'
490
+
491
+ return rf"${self.symbol} \mapsto {body}$"
492
+
493
+
494
+
495
+ # Pickle and copy
496
+
497
+ def __getstate__(self):
498
+ ret = self.__dict__.copy()
499
+ ret['coef'] = self.coef.copy()
500
+ ret['domain'] = self.domain.copy()
501
+ ret['window'] = self.window.copy()
502
+ ret['symbol'] = self.symbol
503
+ return ret
504
+
505
+ def __setstate__(self, dict):
506
+ self.__dict__ = dict
507
+
508
+ # Call
509
+
510
+ def __call__(self, arg):
511
+ off, scl = pu.mapparms(self.domain, self.window)
512
+ arg = off + scl*arg
513
+ return self._val(arg, self.coef)
514
+
515
+ def __iter__(self):
516
+ return iter(self.coef)
517
+
518
+ def __len__(self):
519
+ return len(self.coef)
520
+
521
+ # Numeric properties.
522
+
523
+ def __neg__(self):
524
+ return self.__class__(
525
+ -self.coef, self.domain, self.window, self.symbol
526
+ )
527
+
528
+ def __pos__(self):
529
+ return self
530
+
531
+ def __add__(self, other):
532
+ othercoef = self._get_coefficients(other)
533
+ try:
534
+ coef = self._add(self.coef, othercoef)
535
+ except Exception:
536
+ return NotImplemented
537
+ return self.__class__(coef, self.domain, self.window, self.symbol)
538
+
539
+ def __sub__(self, other):
540
+ othercoef = self._get_coefficients(other)
541
+ try:
542
+ coef = self._sub(self.coef, othercoef)
543
+ except Exception:
544
+ return NotImplemented
545
+ return self.__class__(coef, self.domain, self.window, self.symbol)
546
+
547
+ def __mul__(self, other):
548
+ othercoef = self._get_coefficients(other)
549
+ try:
550
+ coef = self._mul(self.coef, othercoef)
551
+ except Exception:
552
+ return NotImplemented
553
+ return self.__class__(coef, self.domain, self.window, self.symbol)
554
+
555
+ def __truediv__(self, other):
556
+ # there is no true divide if the rhs is not a Number, although it
557
+ # could return the first n elements of an infinite series.
558
+ # It is hard to see where n would come from, though.
559
+ if not isinstance(other, numbers.Number) or isinstance(other, bool):
560
+ raise TypeError(
561
+ f"unsupported types for true division: "
562
+ f"'{type(self)}', '{type(other)}'"
563
+ )
564
+ return self.__floordiv__(other)
565
+
566
+ def __floordiv__(self, other):
567
+ res = self.__divmod__(other)
568
+ if res is NotImplemented:
569
+ return res
570
+ return res[0]
571
+
572
+ def __mod__(self, other):
573
+ res = self.__divmod__(other)
574
+ if res is NotImplemented:
575
+ return res
576
+ return res[1]
577
+
578
+ def __divmod__(self, other):
579
+ othercoef = self._get_coefficients(other)
580
+ try:
581
+ quo, rem = self._div(self.coef, othercoef)
582
+ except ZeroDivisionError:
583
+ raise
584
+ except Exception:
585
+ return NotImplemented
586
+ quo = self.__class__(quo, self.domain, self.window, self.symbol)
587
+ rem = self.__class__(rem, self.domain, self.window, self.symbol)
588
+ return quo, rem
589
+
590
+ def __pow__(self, other):
591
+ coef = self._pow(self.coef, other, maxpower=self.maxpower)
592
+ res = self.__class__(coef, self.domain, self.window, self.symbol)
593
+ return res
594
+
595
+ def __radd__(self, other):
596
+ try:
597
+ coef = self._add(other, self.coef)
598
+ except Exception:
599
+ return NotImplemented
600
+ return self.__class__(coef, self.domain, self.window, self.symbol)
601
+
602
+ def __rsub__(self, other):
603
+ try:
604
+ coef = self._sub(other, self.coef)
605
+ except Exception:
606
+ return NotImplemented
607
+ return self.__class__(coef, self.domain, self.window, self.symbol)
608
+
609
+ def __rmul__(self, other):
610
+ try:
611
+ coef = self._mul(other, self.coef)
612
+ except Exception:
613
+ return NotImplemented
614
+ return self.__class__(coef, self.domain, self.window, self.symbol)
615
+
616
+ def __rdiv__(self, other):
617
+ # set to __floordiv__ /.
618
+ return self.__rfloordiv__(other)
619
+
620
+ def __rtruediv__(self, other):
621
+ # An instance of ABCPolyBase is not considered a
622
+ # Number.
623
+ return NotImplemented
624
+
625
+ def __rfloordiv__(self, other):
626
+ res = self.__rdivmod__(other)
627
+ if res is NotImplemented:
628
+ return res
629
+ return res[0]
630
+
631
+ def __rmod__(self, other):
632
+ res = self.__rdivmod__(other)
633
+ if res is NotImplemented:
634
+ return res
635
+ return res[1]
636
+
637
+ def __rdivmod__(self, other):
638
+ try:
639
+ quo, rem = self._div(other, self.coef)
640
+ except ZeroDivisionError:
641
+ raise
642
+ except Exception:
643
+ return NotImplemented
644
+ quo = self.__class__(quo, self.domain, self.window, self.symbol)
645
+ rem = self.__class__(rem, self.domain, self.window, self.symbol)
646
+ return quo, rem
647
+
648
+ def __eq__(self, other):
649
+ res = (isinstance(other, self.__class__) and
650
+ np.all(self.domain == other.domain) and
651
+ np.all(self.window == other.window) and
652
+ (self.coef.shape == other.coef.shape) and
653
+ np.all(self.coef == other.coef) and
654
+ (self.symbol == other.symbol))
655
+ return res
656
+
657
+ def __ne__(self, other):
658
+ return not self.__eq__(other)
659
+
660
+ #
661
+ # Extra methods.
662
+ #
663
+
664
+ def copy(self):
665
+ """Return a copy.
666
+
667
+ Returns
668
+ -------
669
+ new_series : series
670
+ Copy of self.
671
+
672
+ """
673
+ return self.__class__(self.coef, self.domain, self.window, self.symbol)
674
+
675
+ def degree(self):
676
+ """The degree of the series.
677
+
678
+ .. versionadded:: 1.5.0
679
+
680
+ Returns
681
+ -------
682
+ degree : int
683
+ Degree of the series, one less than the number of coefficients.
684
+
685
+ Examples
686
+ --------
687
+
688
+ Create a polynomial object for ``1 + 7*x + 4*x**2``:
689
+
690
+ >>> poly = np.polynomial.Polynomial([1, 7, 4])
691
+ >>> print(poly)
692
+ 1.0 + 7.0·x + 4.0·x²
693
+ >>> poly.degree()
694
+ 2
695
+
696
+ Note that this method does not check for non-zero coefficients.
697
+ You must trim the polynomial to remove any trailing zeroes:
698
+
699
+ >>> poly = np.polynomial.Polynomial([1, 7, 0])
700
+ >>> print(poly)
701
+ 1.0 + 7.0·x + 0.0·x²
702
+ >>> poly.degree()
703
+ 2
704
+ >>> poly.trim().degree()
705
+ 1
706
+
707
+ """
708
+ return len(self) - 1
709
+
710
+ def cutdeg(self, deg):
711
+ """Truncate series to the given degree.
712
+
713
+ Reduce the degree of the series to `deg` by discarding the
714
+ high order terms. If `deg` is greater than the current degree a
715
+ copy of the current series is returned. This can be useful in least
716
+ squares where the coefficients of the high degree terms may be very
717
+ small.
718
+
719
+ .. versionadded:: 1.5.0
720
+
721
+ Parameters
722
+ ----------
723
+ deg : non-negative int
724
+ The series is reduced to degree `deg` by discarding the high
725
+ order terms. The value of `deg` must be a non-negative integer.
726
+
727
+ Returns
728
+ -------
729
+ new_series : series
730
+ New instance of series with reduced degree.
731
+
732
+ """
733
+ return self.truncate(deg + 1)
734
+
735
+ def trim(self, tol=0):
736
+ """Remove trailing coefficients
737
+
738
+ Remove trailing coefficients until a coefficient is reached whose
739
+ absolute value greater than `tol` or the beginning of the series is
740
+ reached. If all the coefficients would be removed the series is set
741
+ to ``[0]``. A new series instance is returned with the new
742
+ coefficients. The current instance remains unchanged.
743
+
744
+ Parameters
745
+ ----------
746
+ tol : non-negative number.
747
+ All trailing coefficients less than `tol` will be removed.
748
+
749
+ Returns
750
+ -------
751
+ new_series : series
752
+ New instance of series with trimmed coefficients.
753
+
754
+ """
755
+ coef = pu.trimcoef(self.coef, tol)
756
+ return self.__class__(coef, self.domain, self.window, self.symbol)
757
+
758
+ def truncate(self, size):
759
+ """Truncate series to length `size`.
760
+
761
+ Reduce the series to length `size` by discarding the high
762
+ degree terms. The value of `size` must be a positive integer. This
763
+ can be useful in least squares where the coefficients of the
764
+ high degree terms may be very small.
765
+
766
+ Parameters
767
+ ----------
768
+ size : positive int
769
+ The series is reduced to length `size` by discarding the high
770
+ degree terms. The value of `size` must be a positive integer.
771
+
772
+ Returns
773
+ -------
774
+ new_series : series
775
+ New instance of series with truncated coefficients.
776
+
777
+ """
778
+ isize = int(size)
779
+ if isize != size or isize < 1:
780
+ raise ValueError("size must be a positive integer")
781
+ if isize >= len(self.coef):
782
+ coef = self.coef
783
+ else:
784
+ coef = self.coef[:isize]
785
+ return self.__class__(coef, self.domain, self.window, self.symbol)
786
+
787
+ def convert(self, domain=None, kind=None, window=None):
788
+ """Convert series to a different kind and/or domain and/or window.
789
+
790
+ Parameters
791
+ ----------
792
+ domain : array_like, optional
793
+ The domain of the converted series. If the value is None,
794
+ the default domain of `kind` is used.
795
+ kind : class, optional
796
+ The polynomial series type class to which the current instance
797
+ should be converted. If kind is None, then the class of the
798
+ current instance is used.
799
+ window : array_like, optional
800
+ The window of the converted series. If the value is None,
801
+ the default window of `kind` is used.
802
+
803
+ Returns
804
+ -------
805
+ new_series : series
806
+ The returned class can be of different type than the current
807
+ instance and/or have a different domain and/or different
808
+ window.
809
+
810
+ Notes
811
+ -----
812
+ Conversion between domains and class types can result in
813
+ numerically ill defined series.
814
+
815
+ """
816
+ if kind is None:
817
+ kind = self.__class__
818
+ if domain is None:
819
+ domain = kind.domain
820
+ if window is None:
821
+ window = kind.window
822
+ return self(kind.identity(domain, window=window, symbol=self.symbol))
823
+
824
+ def mapparms(self):
825
+ """Return the mapping parameters.
826
+
827
+ The returned values define a linear map ``off + scl*x`` that is
828
+ applied to the input arguments before the series is evaluated. The
829
+ map depends on the ``domain`` and ``window``; if the current
830
+ ``domain`` is equal to the ``window`` the resulting map is the
831
+ identity. If the coefficients of the series instance are to be
832
+ used by themselves outside this class, then the linear function
833
+ must be substituted for the ``x`` in the standard representation of
834
+ the base polynomials.
835
+
836
+ Returns
837
+ -------
838
+ off, scl : float or complex
839
+ The mapping function is defined by ``off + scl*x``.
840
+
841
+ Notes
842
+ -----
843
+ If the current domain is the interval ``[l1, r1]`` and the window
844
+ is ``[l2, r2]``, then the linear mapping function ``L`` is
845
+ defined by the equations::
846
+
847
+ L(l1) = l2
848
+ L(r1) = r2
849
+
850
+ """
851
+ return pu.mapparms(self.domain, self.window)
852
+
853
+ def integ(self, m=1, k=[], lbnd=None):
854
+ """Integrate.
855
+
856
+ Return a series instance that is the definite integral of the
857
+ current series.
858
+
859
+ Parameters
860
+ ----------
861
+ m : non-negative int
862
+ The number of integrations to perform.
863
+ k : array_like
864
+ Integration constants. The first constant is applied to the
865
+ first integration, the second to the second, and so on. The
866
+ list of values must less than or equal to `m` in length and any
867
+ missing values are set to zero.
868
+ lbnd : Scalar
869
+ The lower bound of the definite integral.
870
+
871
+ Returns
872
+ -------
873
+ new_series : series
874
+ A new series representing the integral. The domain is the same
875
+ as the domain of the integrated series.
876
+
877
+ """
878
+ off, scl = self.mapparms()
879
+ if lbnd is None:
880
+ lbnd = 0
881
+ else:
882
+ lbnd = off + scl*lbnd
883
+ coef = self._int(self.coef, m, k, lbnd, 1./scl)
884
+ return self.__class__(coef, self.domain, self.window, self.symbol)
885
+
886
+ def deriv(self, m=1):
887
+ """Differentiate.
888
+
889
+ Return a series instance of that is the derivative of the current
890
+ series.
891
+
892
+ Parameters
893
+ ----------
894
+ m : non-negative int
895
+ Find the derivative of order `m`.
896
+
897
+ Returns
898
+ -------
899
+ new_series : series
900
+ A new series representing the derivative. The domain is the same
901
+ as the domain of the differentiated series.
902
+
903
+ """
904
+ off, scl = self.mapparms()
905
+ coef = self._der(self.coef, m, scl)
906
+ return self.__class__(coef, self.domain, self.window, self.symbol)
907
+
908
+ def roots(self):
909
+ """Return the roots of the series polynomial.
910
+
911
+ Compute the roots for the series. Note that the accuracy of the
912
+ roots decreases the further outside the `domain` they lie.
913
+
914
+ Returns
915
+ -------
916
+ roots : ndarray
917
+ Array containing the roots of the series.
918
+
919
+ """
920
+ roots = self._roots(self.coef)
921
+ return pu.mapdomain(roots, self.window, self.domain)
922
+
923
+ def linspace(self, n=100, domain=None):
924
+ """Return x, y values at equally spaced points in domain.
925
+
926
+ Returns the x, y values at `n` linearly spaced points across the
927
+ domain. Here y is the value of the polynomial at the points x. By
928
+ default the domain is the same as that of the series instance.
929
+ This method is intended mostly as a plotting aid.
930
+
931
+ .. versionadded:: 1.5.0
932
+
933
+ Parameters
934
+ ----------
935
+ n : int, optional
936
+ Number of point pairs to return. The default value is 100.
937
+ domain : {None, array_like}, optional
938
+ If not None, the specified domain is used instead of that of
939
+ the calling instance. It should be of the form ``[beg,end]``.
940
+ The default is None which case the class domain is used.
941
+
942
+ Returns
943
+ -------
944
+ x, y : ndarray
945
+ x is equal to linspace(self.domain[0], self.domain[1], n) and
946
+ y is the series evaluated at element of x.
947
+
948
+ """
949
+ if domain is None:
950
+ domain = self.domain
951
+ x = np.linspace(domain[0], domain[1], n)
952
+ y = self(x)
953
+ return x, y
954
+
955
+ @classmethod
956
+ def fit(cls, x, y, deg, domain=None, rcond=None, full=False, w=None,
957
+ window=None, symbol='x'):
958
+ """Least squares fit to data.
959
+
960
+ Return a series instance that is the least squares fit to the data
961
+ `y` sampled at `x`. The domain of the returned instance can be
962
+ specified and this will often result in a superior fit with less
963
+ chance of ill conditioning.
964
+
965
+ Parameters
966
+ ----------
967
+ x : array_like, shape (M,)
968
+ x-coordinates of the M sample points ``(x[i], y[i])``.
969
+ y : array_like, shape (M,)
970
+ y-coordinates of the M sample points ``(x[i], y[i])``.
971
+ deg : int or 1-D array_like
972
+ Degree(s) of the fitting polynomials. If `deg` is a single integer
973
+ all terms up to and including the `deg`'th term are included in the
974
+ fit. For NumPy versions >= 1.11.0 a list of integers specifying the
975
+ degrees of the terms to include may be used instead.
976
+ domain : {None, [beg, end], []}, optional
977
+ Domain to use for the returned series. If ``None``,
978
+ then a minimal domain that covers the points `x` is chosen. If
979
+ ``[]`` the class domain is used. The default value was the
980
+ class domain in NumPy 1.4 and ``None`` in later versions.
981
+ The ``[]`` option was added in numpy 1.5.0.
982
+ rcond : float, optional
983
+ Relative condition number of the fit. Singular values smaller
984
+ than this relative to the largest singular value will be
985
+ ignored. The default value is len(x)*eps, where eps is the
986
+ relative precision of the float type, about 2e-16 in most
987
+ cases.
988
+ full : bool, optional
989
+ Switch determining nature of return value. When it is False
990
+ (the default) just the coefficients are returned, when True
991
+ diagnostic information from the singular value decomposition is
992
+ also returned.
993
+ w : array_like, shape (M,), optional
994
+ Weights. If not None, the weight ``w[i]`` applies to the unsquared
995
+ residual ``y[i] - y_hat[i]`` at ``x[i]``. Ideally the weights are
996
+ chosen so that the errors of the products ``w[i]*y[i]`` all have
997
+ the same variance. When using inverse-variance weighting, use
998
+ ``w[i] = 1/sigma(y[i])``. The default value is None.
999
+
1000
+ .. versionadded:: 1.5.0
1001
+ window : {[beg, end]}, optional
1002
+ Window to use for the returned series. The default
1003
+ value is the default class domain
1004
+
1005
+ .. versionadded:: 1.6.0
1006
+ symbol : str, optional
1007
+ Symbol representing the independent variable. Default is 'x'.
1008
+
1009
+ Returns
1010
+ -------
1011
+ new_series : series
1012
+ A series that represents the least squares fit to the data and
1013
+ has the domain and window specified in the call. If the
1014
+ coefficients for the unscaled and unshifted basis polynomials are
1015
+ of interest, do ``new_series.convert().coef``.
1016
+
1017
+ [resid, rank, sv, rcond] : list
1018
+ These values are only returned if ``full == True``
1019
+
1020
+ - resid -- sum of squared residuals of the least squares fit
1021
+ - rank -- the numerical rank of the scaled Vandermonde matrix
1022
+ - sv -- singular values of the scaled Vandermonde matrix
1023
+ - rcond -- value of `rcond`.
1024
+
1025
+ For more details, see `linalg.lstsq`.
1026
+
1027
+ """
1028
+ if domain is None:
1029
+ domain = pu.getdomain(x)
1030
+ elif type(domain) is list and len(domain) == 0:
1031
+ domain = cls.domain
1032
+
1033
+ if window is None:
1034
+ window = cls.window
1035
+
1036
+ xnew = pu.mapdomain(x, domain, window)
1037
+ res = cls._fit(xnew, y, deg, w=w, rcond=rcond, full=full)
1038
+ if full:
1039
+ [coef, status] = res
1040
+ return (
1041
+ cls(coef, domain=domain, window=window, symbol=symbol), status
1042
+ )
1043
+ else:
1044
+ coef = res
1045
+ return cls(coef, domain=domain, window=window, symbol=symbol)
1046
+
1047
+ @classmethod
1048
+ def fromroots(cls, roots, domain=[], window=None, symbol='x'):
1049
+ """Return series instance that has the specified roots.
1050
+
1051
+ Returns a series representing the product
1052
+ ``(x - r[0])*(x - r[1])*...*(x - r[n-1])``, where ``r`` is a
1053
+ list of roots.
1054
+
1055
+ Parameters
1056
+ ----------
1057
+ roots : array_like
1058
+ List of roots.
1059
+ domain : {[], None, array_like}, optional
1060
+ Domain for the resulting series. If None the domain is the
1061
+ interval from the smallest root to the largest. If [] the
1062
+ domain is the class domain. The default is [].
1063
+ window : {None, array_like}, optional
1064
+ Window for the returned series. If None the class window is
1065
+ used. The default is None.
1066
+ symbol : str, optional
1067
+ Symbol representing the independent variable. Default is 'x'.
1068
+
1069
+ Returns
1070
+ -------
1071
+ new_series : series
1072
+ Series with the specified roots.
1073
+
1074
+ """
1075
+ [roots] = pu.as_series([roots], trim=False)
1076
+ if domain is None:
1077
+ domain = pu.getdomain(roots)
1078
+ elif type(domain) is list and len(domain) == 0:
1079
+ domain = cls.domain
1080
+
1081
+ if window is None:
1082
+ window = cls.window
1083
+
1084
+ deg = len(roots)
1085
+ off, scl = pu.mapparms(domain, window)
1086
+ rnew = off + scl*roots
1087
+ coef = cls._fromroots(rnew) / scl**deg
1088
+ return cls(coef, domain=domain, window=window, symbol=symbol)
1089
+
1090
+ @classmethod
1091
+ def identity(cls, domain=None, window=None, symbol='x'):
1092
+ """Identity function.
1093
+
1094
+ If ``p`` is the returned series, then ``p(x) == x`` for all
1095
+ values of x.
1096
+
1097
+ Parameters
1098
+ ----------
1099
+ domain : {None, array_like}, optional
1100
+ If given, the array must be of the form ``[beg, end]``, where
1101
+ ``beg`` and ``end`` are the endpoints of the domain. If None is
1102
+ given then the class domain is used. The default is None.
1103
+ window : {None, array_like}, optional
1104
+ If given, the resulting array must be if the form
1105
+ ``[beg, end]``, where ``beg`` and ``end`` are the endpoints of
1106
+ the window. If None is given then the class window is used. The
1107
+ default is None.
1108
+ symbol : str, optional
1109
+ Symbol representing the independent variable. Default is 'x'.
1110
+
1111
+ Returns
1112
+ -------
1113
+ new_series : series
1114
+ Series of representing the identity.
1115
+
1116
+ """
1117
+ if domain is None:
1118
+ domain = cls.domain
1119
+ if window is None:
1120
+ window = cls.window
1121
+ off, scl = pu.mapparms(window, domain)
1122
+ coef = cls._line(off, scl)
1123
+ return cls(coef, domain, window, symbol)
1124
+
1125
+ @classmethod
1126
+ def basis(cls, deg, domain=None, window=None, symbol='x'):
1127
+ """Series basis polynomial of degree `deg`.
1128
+
1129
+ Returns the series representing the basis polynomial of degree `deg`.
1130
+
1131
+ .. versionadded:: 1.7.0
1132
+
1133
+ Parameters
1134
+ ----------
1135
+ deg : int
1136
+ Degree of the basis polynomial for the series. Must be >= 0.
1137
+ domain : {None, array_like}, optional
1138
+ If given, the array must be of the form ``[beg, end]``, where
1139
+ ``beg`` and ``end`` are the endpoints of the domain. If None is
1140
+ given then the class domain is used. The default is None.
1141
+ window : {None, array_like}, optional
1142
+ If given, the resulting array must be if the form
1143
+ ``[beg, end]``, where ``beg`` and ``end`` are the endpoints of
1144
+ the window. If None is given then the class window is used. The
1145
+ default is None.
1146
+ symbol : str, optional
1147
+ Symbol representing the independent variable. Default is 'x'.
1148
+
1149
+ Returns
1150
+ -------
1151
+ new_series : series
1152
+ A series with the coefficient of the `deg` term set to one and
1153
+ all others zero.
1154
+
1155
+ """
1156
+ if domain is None:
1157
+ domain = cls.domain
1158
+ if window is None:
1159
+ window = cls.window
1160
+ ideg = int(deg)
1161
+
1162
+ if ideg != deg or ideg < 0:
1163
+ raise ValueError("deg must be non-negative integer")
1164
+ return cls([0]*ideg + [1], domain, window, symbol)
1165
+
1166
+ @classmethod
1167
+ def cast(cls, series, domain=None, window=None):
1168
+ """Convert series to series of this class.
1169
+
1170
+ The `series` is expected to be an instance of some polynomial
1171
+ series of one of the types supported by by the numpy.polynomial
1172
+ module, but could be some other class that supports the convert
1173
+ method.
1174
+
1175
+ .. versionadded:: 1.7.0
1176
+
1177
+ Parameters
1178
+ ----------
1179
+ series : series
1180
+ The series instance to be converted.
1181
+ domain : {None, array_like}, optional
1182
+ If given, the array must be of the form ``[beg, end]``, where
1183
+ ``beg`` and ``end`` are the endpoints of the domain. If None is
1184
+ given then the class domain is used. The default is None.
1185
+ window : {None, array_like}, optional
1186
+ If given, the resulting array must be if the form
1187
+ ``[beg, end]``, where ``beg`` and ``end`` are the endpoints of
1188
+ the window. If None is given then the class window is used. The
1189
+ default is None.
1190
+
1191
+ Returns
1192
+ -------
1193
+ new_series : series
1194
+ A series of the same kind as the calling class and equal to
1195
+ `series` when evaluated.
1196
+
1197
+ See Also
1198
+ --------
1199
+ convert : similar instance method
1200
+
1201
+ """
1202
+ if domain is None:
1203
+ domain = cls.domain
1204
+ if window is None:
1205
+ window = cls.window
1206
+ return series.convert(domain, cls, window)
.venv/lib/python3.11/site-packages/numpy/polynomial/_polybase.pyi ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Any, ClassVar
3
+
4
+ __all__: list[str]
5
+
6
+ class ABCPolyBase(abc.ABC):
7
+ __hash__: ClassVar[None] # type: ignore[assignment]
8
+ __array_ufunc__: ClassVar[None]
9
+ maxpower: ClassVar[int]
10
+ coef: Any
11
+ @property
12
+ def symbol(self) -> str: ...
13
+ @property
14
+ @abc.abstractmethod
15
+ def domain(self): ...
16
+ @property
17
+ @abc.abstractmethod
18
+ def window(self): ...
19
+ @property
20
+ @abc.abstractmethod
21
+ def basis_name(self): ...
22
+ def has_samecoef(self, other): ...
23
+ def has_samedomain(self, other): ...
24
+ def has_samewindow(self, other): ...
25
+ def has_sametype(self, other): ...
26
+ def __init__(self, coef, domain=..., window=..., symbol: str = ...) -> None: ...
27
+ def __format__(self, fmt_str): ...
28
+ def __call__(self, arg): ...
29
+ def __iter__(self): ...
30
+ def __len__(self): ...
31
+ def __neg__(self): ...
32
+ def __pos__(self): ...
33
+ def __add__(self, other): ...
34
+ def __sub__(self, other): ...
35
+ def __mul__(self, other): ...
36
+ def __truediv__(self, other): ...
37
+ def __floordiv__(self, other): ...
38
+ def __mod__(self, other): ...
39
+ def __divmod__(self, other): ...
40
+ def __pow__(self, other): ...
41
+ def __radd__(self, other): ...
42
+ def __rsub__(self, other): ...
43
+ def __rmul__(self, other): ...
44
+ def __rdiv__(self, other): ...
45
+ def __rtruediv__(self, other): ...
46
+ def __rfloordiv__(self, other): ...
47
+ def __rmod__(self, other): ...
48
+ def __rdivmod__(self, other): ...
49
+ def __eq__(self, other): ...
50
+ def __ne__(self, other): ...
51
+ def copy(self): ...
52
+ def degree(self): ...
53
+ def cutdeg(self, deg): ...
54
+ def trim(self, tol=...): ...
55
+ def truncate(self, size): ...
56
+ def convert(self, domain=..., kind=..., window=...): ...
57
+ def mapparms(self): ...
58
+ def integ(self, m=..., k = ..., lbnd=...): ...
59
+ def deriv(self, m=...): ...
60
+ def roots(self): ...
61
+ def linspace(self, n=..., domain=...): ...
62
+ @classmethod
63
+ def fit(cls, x, y, deg, domain=..., rcond=..., full=..., w=..., window=...): ...
64
+ @classmethod
65
+ def fromroots(cls, roots, domain = ..., window=...): ...
66
+ @classmethod
67
+ def identity(cls, domain=..., window=...): ...
68
+ @classmethod
69
+ def basis(cls, deg, domain=..., window=...): ...
70
+ @classmethod
71
+ def cast(cls, series, domain=..., window=...): ...
.venv/lib/python3.11/site-packages/numpy/polynomial/chebyshev.py ADDED
@@ -0,0 +1,2082 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ====================================================
3
+ Chebyshev Series (:mod:`numpy.polynomial.chebyshev`)
4
+ ====================================================
5
+
6
+ This module provides a number of objects (mostly functions) useful for
7
+ dealing with Chebyshev series, including a `Chebyshev` class that
8
+ encapsulates the usual arithmetic operations. (General information
9
+ on how this module represents and works with such polynomials is in the
10
+ docstring for its "parent" sub-package, `numpy.polynomial`).
11
+
12
+ Classes
13
+ -------
14
+
15
+ .. autosummary::
16
+ :toctree: generated/
17
+
18
+ Chebyshev
19
+
20
+
21
+ Constants
22
+ ---------
23
+
24
+ .. autosummary::
25
+ :toctree: generated/
26
+
27
+ chebdomain
28
+ chebzero
29
+ chebone
30
+ chebx
31
+
32
+ Arithmetic
33
+ ----------
34
+
35
+ .. autosummary::
36
+ :toctree: generated/
37
+
38
+ chebadd
39
+ chebsub
40
+ chebmulx
41
+ chebmul
42
+ chebdiv
43
+ chebpow
44
+ chebval
45
+ chebval2d
46
+ chebval3d
47
+ chebgrid2d
48
+ chebgrid3d
49
+
50
+ Calculus
51
+ --------
52
+
53
+ .. autosummary::
54
+ :toctree: generated/
55
+
56
+ chebder
57
+ chebint
58
+
59
+ Misc Functions
60
+ --------------
61
+
62
+ .. autosummary::
63
+ :toctree: generated/
64
+
65
+ chebfromroots
66
+ chebroots
67
+ chebvander
68
+ chebvander2d
69
+ chebvander3d
70
+ chebgauss
71
+ chebweight
72
+ chebcompanion
73
+ chebfit
74
+ chebpts1
75
+ chebpts2
76
+ chebtrim
77
+ chebline
78
+ cheb2poly
79
+ poly2cheb
80
+ chebinterpolate
81
+
82
+ See also
83
+ --------
84
+ `numpy.polynomial`
85
+
86
+ Notes
87
+ -----
88
+ The implementations of multiplication, division, integration, and
89
+ differentiation use the algebraic identities [1]_:
90
+
91
+ .. math::
92
+ T_n(x) = \\frac{z^n + z^{-n}}{2} \\\\
93
+ z\\frac{dx}{dz} = \\frac{z - z^{-1}}{2}.
94
+
95
+ where
96
+
97
+ .. math:: x = \\frac{z + z^{-1}}{2}.
98
+
99
+ These identities allow a Chebyshev series to be expressed as a finite,
100
+ symmetric Laurent series. In this module, this sort of Laurent series
101
+ is referred to as a "z-series."
102
+
103
+ References
104
+ ----------
105
+ .. [1] A. T. Benjamin, et al., "Combinatorial Trigonometry with Chebyshev
106
+ Polynomials," *Journal of Statistical Planning and Inference 14*, 2008
107
+ (https://web.archive.org/web/20080221202153/https://www.math.hmc.edu/~benjamin/papers/CombTrig.pdf, pg. 4)
108
+
109
+ """
110
+ import numpy as np
111
+ import numpy.linalg as la
112
+ from numpy.core.multiarray import normalize_axis_index
113
+
114
+ from . import polyutils as pu
115
+ from ._polybase import ABCPolyBase
116
+
117
+ __all__ = [
118
+ 'chebzero', 'chebone', 'chebx', 'chebdomain', 'chebline', 'chebadd',
119
+ 'chebsub', 'chebmulx', 'chebmul', 'chebdiv', 'chebpow', 'chebval',
120
+ 'chebder', 'chebint', 'cheb2poly', 'poly2cheb', 'chebfromroots',
121
+ 'chebvander', 'chebfit', 'chebtrim', 'chebroots', 'chebpts1',
122
+ 'chebpts2', 'Chebyshev', 'chebval2d', 'chebval3d', 'chebgrid2d',
123
+ 'chebgrid3d', 'chebvander2d', 'chebvander3d', 'chebcompanion',
124
+ 'chebgauss', 'chebweight', 'chebinterpolate']
125
+
126
+ chebtrim = pu.trimcoef
127
+
128
+ #
129
+ # A collection of functions for manipulating z-series. These are private
130
+ # functions and do minimal error checking.
131
+ #
132
+
133
+ def _cseries_to_zseries(c):
134
+ """Convert Chebyshev series to z-series.
135
+
136
+ Convert a Chebyshev series to the equivalent z-series. The result is
137
+ never an empty array. The dtype of the return is the same as that of
138
+ the input. No checks are run on the arguments as this routine is for
139
+ internal use.
140
+
141
+ Parameters
142
+ ----------
143
+ c : 1-D ndarray
144
+ Chebyshev coefficients, ordered from low to high
145
+
146
+ Returns
147
+ -------
148
+ zs : 1-D ndarray
149
+ Odd length symmetric z-series, ordered from low to high.
150
+
151
+ """
152
+ n = c.size
153
+ zs = np.zeros(2*n-1, dtype=c.dtype)
154
+ zs[n-1:] = c/2
155
+ return zs + zs[::-1]
156
+
157
+
158
+ def _zseries_to_cseries(zs):
159
+ """Convert z-series to a Chebyshev series.
160
+
161
+ Convert a z series to the equivalent Chebyshev series. The result is
162
+ never an empty array. The dtype of the return is the same as that of
163
+ the input. No checks are run on the arguments as this routine is for
164
+ internal use.
165
+
166
+ Parameters
167
+ ----------
168
+ zs : 1-D ndarray
169
+ Odd length symmetric z-series, ordered from low to high.
170
+
171
+ Returns
172
+ -------
173
+ c : 1-D ndarray
174
+ Chebyshev coefficients, ordered from low to high.
175
+
176
+ """
177
+ n = (zs.size + 1)//2
178
+ c = zs[n-1:].copy()
179
+ c[1:n] *= 2
180
+ return c
181
+
182
+
183
+ def _zseries_mul(z1, z2):
184
+ """Multiply two z-series.
185
+
186
+ Multiply two z-series to produce a z-series.
187
+
188
+ Parameters
189
+ ----------
190
+ z1, z2 : 1-D ndarray
191
+ The arrays must be 1-D but this is not checked.
192
+
193
+ Returns
194
+ -------
195
+ product : 1-D ndarray
196
+ The product z-series.
197
+
198
+ Notes
199
+ -----
200
+ This is simply convolution. If symmetric/anti-symmetric z-series are
201
+ denoted by S/A then the following rules apply:
202
+
203
+ S*S, A*A -> S
204
+ S*A, A*S -> A
205
+
206
+ """
207
+ return np.convolve(z1, z2)
208
+
209
+
210
def _zseries_div(z1, z2):
    """Divide the first z-series by the second.

    Divide `z1` by `z2` and return the quotient and remainder as z-series.
    Warning: this implementation only applies when both z1 and z2 have the
    same symmetry, which is sufficient for present purposes.

    Parameters
    ----------
    z1, z2 : 1-D ndarray
        The arrays must be 1-D and have the same symmetry, but this is not
        checked.

    Returns
    -------
    (quotient, remainder) : 1-D ndarrays
        Quotient and remainder as z-series.

    Notes
    -----
    This is not the same as polynomial division on account of the desired
    form of the remainder. If symmetric/anti-symmetric z-series are denoted
    by S/A then the following rules apply:

        S/S -> S,S
        A/A -> S,A

    The restriction to types of the same symmetry could be fixed but seems
    like unneeded generality. There is no natural form for the remainder in
    the case where there is no symmetry.

    """
    # Copies so the elimination below does not mutate the caller's arrays.
    z1 = z1.copy()
    z2 = z2.copy()
    lc1 = len(z1)
    lc2 = len(z2)
    if lc2 == 1:
        # Scalar divisor: elementwise division, zero remainder.
        z1 /= z2
        return z1, z1[:1]*0
    elif lc1 < lc2:
        # Divisor is longer than dividend: zero quotient, remainder is z1.
        return z1[:1]*0, z1
    else:
        dlen = lc1 - lc2
        scl = z2[0]
        # Normalize the divisor so its leading coefficient is one;
        # compensated by the `quo /= scl` after the loop.
        z2 /= scl
        quo = np.empty(dlen + 1, dtype=z1.dtype)
        i = 0
        j = dlen
        # Eliminate coefficients from both ends toward the middle; by the
        # shared symmetry of the operands, quo[i] and quo[dlen - i] agree,
        # so each pass fixes two quotient entries at once.
        while i < j:
            r = z1[i]
            quo[i] = z1[i]
            quo[dlen - i] = r
            tmp = r*z2
            z1[i:i+lc2] -= tmp
            z1[j:j+lc2] -= tmp
            i += 1
            j -= 1
        # Middle quotient coefficient (reached when i == j) is handled once.
        r = z1[i]
        quo[i] = r
        tmp = r*z2
        z1[i:i+lc2] -= tmp
        # Undo the divisor normalization.
        quo /= scl
        # What is left of z1, trimmed at both ends, is the remainder.
        rem = z1[i+1:i-1+lc2].copy()
        return quo, rem
275
+
276
+
277
def _zseries_der(zs):
    """Differentiate a z-series.

    The derivative is with respect to x, not z. This is achieved using the
    chain rule and the value of dx/dz given in the module notes.

    Parameters
    ----------
    zs : z-series
        The z-series to differentiate. Note: scaled in place by the
        ``zs *= ...`` below, so the caller's array is modified.

    Returns
    -------
    derivative : z-series
        The derivative

    Notes
    -----
    The zseries for x (ns) has been multiplied by two in order to avoid
    using floats that are incompatible with Decimal and likely other
    specialized scalar types. This scaling has been compensated by
    multiplying the value of zs by two also so that the two cancels in the
    division.

    """
    n = len(zs)//2
    # ns is twice the z-series of z*dx/dz = (z - z^{-1})/2 (module notes).
    ns = np.array([-1, 0, 1], dtype=zs.dtype)
    # d/dz: multiply each coefficient by its (symmetric) exponent; the
    # extra factor of two cancels the doubling of ns in the division.
    zs *= np.arange(-n, n+1)*2
    # Chain rule: divide out z*dx/dz; remainder r is discarded.
    d, r = _zseries_div(zs, ns)
    return d
307
+
308
+
309
+ def _zseries_int(zs):
310
+ """Integrate a z-series.
311
+
312
+ The integral is with respect to x, not z. This is achieved by a change
313
+ of variable using dx/dz given in the module notes.
314
+
315
+ Parameters
316
+ ----------
317
+ zs : z-series
318
+ The z-series to integrate
319
+
320
+ Returns
321
+ -------
322
+ integral : z-series
323
+ The indefinite integral
324
+
325
+ Notes
326
+ -----
327
+ The zseries for x (ns) has been multiplied by two in order to avoid
328
+ using floats that are incompatible with Decimal and likely other
329
+ specialized scalar types. This scaling has been compensated by
330
+ dividing the resulting zs by two.
331
+
332
+ """
333
+ n = 1 + len(zs)//2
334
+ ns = np.array([-1, 0, 1], dtype=zs.dtype)
335
+ zs = _zseries_mul(zs, ns)
336
+ div = np.arange(-n, n+1)*2
337
+ zs[:n] /= div[:n]
338
+ zs[n+1:] /= div[n+1:]
339
+ zs[n] = 0
340
+ return zs
341
+
342
+ #
343
+ # Chebyshev series functions
344
+ #
345
+
346
+
347
def poly2cheb(pol):
    """
    Convert a polynomial to a Chebyshev series.

    Convert an array of coefficients of a polynomial relative to the
    "standard" basis, ordered from lowest degree to highest, into the
    coefficients of the equivalent Chebyshev series, also ordered from
    lowest to highest degree.

    Parameters
    ----------
    pol : array_like
        1-D array containing the polynomial coefficients

    Returns
    -------
    c : ndarray
        1-D array containing the coefficients of the equivalent Chebyshev
        series.

    See Also
    --------
    cheb2poly

    Notes
    -----
    The easy way to do conversions between polynomial basis sets
    is to use the convert method of a class instance.

    Examples
    --------
    >>> from numpy import polynomial as P
    >>> p = P.Polynomial(range(4))
    >>> p
    Polynomial([0.,  1.,  2.,  3.], domain=[-1,  1], window=[-1,  1])
    >>> c = p.convert(kind=P.Chebyshev)
    >>> c
    Chebyshev([1.  , 3.25, 1.  , 0.75], domain=[-1.,  1.], window=[-1.,  1.])
    >>> P.chebyshev.poly2cheb(range(4))
    array([1.  , 3.25, 1.  , 0.75])

    """
    [pol] = pu.as_series([pol])
    # Horner's scheme carried out in the Chebyshev basis: starting from
    # the highest-degree coefficient, repeatedly multiply the running
    # series by x and add the next-lower polynomial coefficient.
    result = 0
    for coef in pol[::-1]:
        result = chebadd(chebmulx(result), coef)
    return result
395
+
396
+
397
def cheb2poly(c):
    """
    Convert a Chebyshev series to a polynomial.

    Convert an array representing the coefficients of a Chebyshev series,
    ordered from lowest degree to highest, to an array of the coefficients
    of the equivalent polynomial (relative to the "standard" basis) ordered
    from lowest to highest degree.

    Parameters
    ----------
    c : array_like
        1-D array containing the Chebyshev series coefficients, ordered
        from lowest order term to highest.

    Returns
    -------
    pol : ndarray
        1-D array containing the coefficients of the equivalent polynomial
        (relative to the "standard" basis) ordered from lowest order term
        to highest.

    See Also
    --------
    poly2cheb

    Notes
    -----
    The easy way to do conversions between polynomial basis sets
    is to use the convert method of a class instance.

    Examples
    --------
    >>> from numpy import polynomial as P
    >>> c = P.Chebyshev(range(4))
    >>> c
    Chebyshev([0., 1., 2., 3.], domain=[-1,  1], window=[-1,  1])
    >>> p = c.convert(kind=P.Polynomial)
    >>> p
    Polynomial([-2., -8.,  4., 12.], domain=[-1.,  1.], window=[-1.,  1.])
    >>> P.chebyshev.cheb2poly(range(4))
    array([-2.,  -8.,   4.,  12.])

    """
    # Imported here to avoid a circular import with .polynomial.
    from .polynomial import polyadd, polysub, polymulx

    [c] = pu.as_series([c])
    n = len(c)
    if n < 3:
        # T_0 = 1 and T_1 = x, so degree < 2 series are already
        # polynomial coefficients.
        return c
    else:
        c0 = c[-2]
        c1 = c[-1]
        # i is the current degree of c1
        for i in range(n - 1, 1, -1):
            # Run the Chebyshev recurrence T_i = 2*x*T_{i-1} - T_{i-2}
            # backwards, carrying polynomial coefficient arrays in c0/c1.
            tmp = c0
            c0 = polysub(c[i - 2], c1)
            c1 = polyadd(tmp, polymulx(c1)*2)
        return polyadd(c0, polymulx(c1))
456
+
457
+
458
#
# These constant arrays are of integer type so as to be compatible
# with the widest range of other types, such as Decimal.
#

# Chebyshev default domain.
chebdomain = np.array([-1, 1])

# Chebyshev coefficients representing zero.
chebzero = np.array([0])

# Chebyshev coefficients representing one.
chebone = np.array([1])

# Chebyshev coefficients representing the identity x.
chebx = np.array([0, 1])
474
+
475
+
476
def chebline(off, scl):
    """
    Chebyshev series whose graph is a straight line.

    Parameters
    ----------
    off, scl : scalars
        Intercept and slope of the line ``off + scl*x``.

    Returns
    -------
    y : ndarray
        This module's representation of the Chebyshev series for
        ``off + scl*x``.

    See Also
    --------
    numpy.polynomial.polynomial.polyline
    numpy.polynomial.legendre.legline
    numpy.polynomial.laguerre.lagline
    numpy.polynomial.hermite.hermline
    numpy.polynomial.hermite_e.hermeline

    Examples
    --------
    >>> import numpy.polynomial.chebyshev as C
    >>> C.chebline(3,2)
    array([3, 2])
    >>> C.chebval(-3, C.chebline(3,2)) # should be -3
    -3.0

    """
    # T_0 = 1 and T_1 = x, so the line's Chebyshev coefficients coincide
    # with (off, scl); a zero slope is trimmed to keep the series minimal.
    return np.array([off, scl]) if scl != 0 else np.array([off])
512
+
513
+
514
def chebfromroots(roots):
    """
    Generate a Chebyshev series with given roots.

    The function returns the coefficients of the polynomial

    .. math:: p(x) = (x - r_0) * (x - r_1) * ... * (x - r_n),

    in Chebyshev form, where the `r_n` are the roots specified in `roots`.
    If a zero has multiplicity n, then it must appear in `roots` n times.
    For instance, if 2 is a root of multiplicity three and 3 is a root of
    multiplicity 2, then `roots` looks something like [2, 2, 2, 3, 3]. The
    roots can appear in any order.

    If the returned coefficients are `c`, then

    .. math:: p(x) = c_0 + c_1 * T_1(x) + ... + c_n * T_n(x)

    The coefficient of the last term is not generally 1 for monic
    polynomials in Chebyshev form.

    Parameters
    ----------
    roots : array_like
        Sequence containing the roots.

    Returns
    -------
    out : ndarray
        1-D array of coefficients.  If all roots are real then `out` is a
        real array, if some of the roots are complex, then `out` is complex
        even if all the coefficients in the result are real (see Examples
        below).

    See Also
    --------
    numpy.polynomial.polynomial.polyfromroots
    numpy.polynomial.legendre.legfromroots
    numpy.polynomial.laguerre.lagfromroots
    numpy.polynomial.hermite.hermfromroots
    numpy.polynomial.hermite_e.hermefromroots

    Examples
    --------
    >>> import numpy.polynomial.chebyshev as C
    >>> C.chebfromroots((-1,0,1)) # x^3 - x relative to the standard basis
    array([ 0.  , -0.25,  0.  ,  0.25])
    >>> j = complex(0,1)
    >>> C.chebfromroots((-j,j)) # x^2 + 1 relative to the standard basis
    array([1.5+0.j, 0. +0.j, 0.5+0.j])

    """
    # The generic helper builds up the product of (x - r) linear factors
    # using this module's line constructor and multiplication.
    return pu._fromroots(chebline, chebmul, roots)
567
+
568
+
569
def chebadd(c1, c2):
    """
    Add one Chebyshev series to another.

    Returns the sum of two Chebyshev series `c1` + `c2`.  The arguments
    are sequences of coefficients ordered from lowest order term to
    highest, i.e., [1,2,3] represents the series ``T_0 + 2*T_1 + 3*T_2``.

    Parameters
    ----------
    c1, c2 : array_like
        1-D arrays of Chebyshev series coefficients ordered from low to
        high.

    Returns
    -------
    out : ndarray
        Array representing the Chebyshev series of their sum.

    See Also
    --------
    chebsub, chebmulx, chebmul, chebdiv, chebpow

    Notes
    -----
    Unlike multiplication, division, etc., the sum of two Chebyshev series
    is a Chebyshev series (without having to "reproject" the result onto
    the basis set) so addition, just like that of "standard" polynomials,
    is simply "component-wise."

    Examples
    --------
    >>> from numpy.polynomial import chebyshev as C
    >>> c1 = (1,2,3)
    >>> c2 = (3,2,1)
    >>> C.chebadd(c1,c2)
    array([4., 4., 4.])

    """
    # Addition is basis-independent: delegate to the generic helper, which
    # pads the shorter coefficient sequence and adds elementwise.
    return pu._add(c1, c2)
609
+
610
+
611
def chebsub(c1, c2):
    """
    Subtract one Chebyshev series from another.

    Returns the difference of two Chebyshev series `c1` - `c2`.  The
    sequences of coefficients are from lowest order term to highest, i.e.,
    [1,2,3] represents the series ``T_0 + 2*T_1 + 3*T_2``.

    Parameters
    ----------
    c1, c2 : array_like
        1-D arrays of Chebyshev series coefficients ordered from low to
        high.

    Returns
    -------
    out : ndarray
        Of Chebyshev series coefficients representing their difference.

    See Also
    --------
    chebadd, chebmulx, chebmul, chebdiv, chebpow

    Notes
    -----
    Unlike multiplication, division, etc., the difference of two Chebyshev
    series is a Chebyshev series (without having to "reproject" the result
    onto the basis set) so subtraction, just like that of "standard"
    polynomials, is simply "component-wise."

    Examples
    --------
    >>> from numpy.polynomial import chebyshev as C
    >>> c1 = (1,2,3)
    >>> c2 = (3,2,1)
    >>> C.chebsub(c1,c2)
    array([-2.,  0.,  2.])
    >>> C.chebsub(c2,c1) # -C.chebsub(c1,c2)
    array([ 2.,  0., -2.])

    """
    # Subtraction is basis-independent: delegate to the generic helper.
    return pu._sub(c1, c2)
653
+
654
+
655
def chebmulx(c):
    """Multiply a Chebyshev series by the independent variable x.

    Parameters
    ----------
    c : array_like
        1-D array of Chebyshev series coefficients ordered from low to
        high.

    Returns
    -------
    out : ndarray
        Chebyshev coefficients of ``x`` times the input series.

    Notes
    -----
    .. versionadded:: 1.5.0

    Examples
    --------
    >>> from numpy.polynomial import chebyshev as C
    >>> C.chebmulx([1,2,3])
    array([1. , 2.5, 1. , 1.5])

    """
    # Trim/validate the input.
    [c] = pu.as_series([c])
    n = len(c)
    # The zero series maps to itself.
    if n == 1 and c[0] == 0:
        return c

    out = np.empty(n + 1, dtype=c.dtype)
    out[0] = c[0]*0
    out[1] = c[0]
    if n > 1:
        # For j >= 1 each coefficient is split equally between the
        # degrees just above and just below it (x*T_j falls halfway
        # between T_{j+1} and T_{j-1}).
        halved = c[1:]/2
        out[2:] = halved
        out[:-2] += halved
    return out
699
+
700
+
701
def chebmul(c1, c2):
    """
    Multiply one Chebyshev series by another.

    Returns the product of two Chebyshev series `c1` * `c2`.  Both
    arguments are 1-D sequences of coefficients, from lowest order "term"
    to highest, e.g., [1,2,3] represents the series ``T_0 + 2*T_1 + 3*T_2``.

    Parameters
    ----------
    c1, c2 : array_like
        1-D arrays of Chebyshev series coefficients ordered from low to
        high.

    Returns
    -------
    out : ndarray
        Of Chebyshev series coefficients representing their product.

    See Also
    --------
    chebadd, chebsub, chebmulx, chebdiv, chebpow

    Notes
    -----
    The (polynomial) product of two C-series contains terms that are not
    in the Chebyshev basis, so the product must be "reprojected" onto said
    basis set; this typically produces "unintuitive" (but correct)
    coefficients — see the Examples section below.

    Examples
    --------
    >>> from numpy.polynomial import chebyshev as C
    >>> c1 = (1,2,3)
    >>> c2 = (3,2,1)
    >>> C.chebmul(c1,c2) # multiplication requires "reprojection"
    array([  6.5,  12. ,  12. ,   4. ,   1.5])

    """
    # Trim/validate the inputs, then multiply in the z-series domain,
    # where the product is a plain convolution.
    [c1, c2] = pu.as_series([c1, c2])
    product = _zseries_mul(_cseries_to_zseries(c1), _cseries_to_zseries(c2))
    return pu.trimseq(_zseries_to_cseries(product))
748
+
749
+
750
def chebdiv(c1, c2):
    """
    Divide one Chebyshev series by another.

    Returns the quotient-with-remainder of two Chebyshev series
    `c1` / `c2`.  The arguments are sequences of coefficients from lowest
    order "term" to highest, e.g., [1,2,3] represents the series
    ``T_0 + 2*T_1 + 3*T_2``.

    Parameters
    ----------
    c1, c2 : array_like
        1-D arrays of Chebyshev series coefficients ordered from low to
        high.

    Returns
    -------
    [quo, rem] : ndarrays
        Of Chebyshev series coefficients representing the quotient and
        remainder.

    Raises
    ------
    ZeroDivisionError
        If the leading coefficient of the (trimmed) divisor is zero.

    See Also
    --------
    chebadd, chebsub, chebmulx, chebmul, chebpow

    Notes
    -----
    The quotient and remainder of a (polynomial) C-series division contain
    terms outside the Chebyshev basis, so both are "reprojected" onto said
    basis set, typically producing "unintuitive" (but correct) results;
    see the Examples section below.

    Examples
    --------
    >>> from numpy.polynomial import chebyshev as C
    >>> c1 = (1,2,3)
    >>> c2 = (3,2,1)
    >>> C.chebdiv(c1,c2) # quotient "intuitive," remainder not
    (array([3.]), array([-8., -4.]))
    >>> c2 = (0,1,2,3)
    >>> C.chebdiv(c2,c1) # neither "intuitive"
    (array([0., 2.]), array([-2., -4.]))

    """
    [c1, c2] = pu.as_series([c1, c2])
    if c2[-1] == 0:
        raise ZeroDivisionError()

    # note: this is more efficient than `pu._div(chebmul, c1, c2)`
    n1 = len(c1)
    n2 = len(c2)
    if n1 < n2:
        # Divisor has higher degree: zero quotient, remainder is c1.
        return c1[:1]*0, c1
    if n2 == 1:
        # Scalar divisor: elementwise division, zero remainder.
        return c1/c2[-1], c1[:1]*0
    # General case: divide in the z-series domain and convert back.
    quo, rem = _zseries_div(_cseries_to_zseries(c1), _cseries_to_zseries(c2))
    return pu.trimseq(_zseries_to_cseries(quo)), pu.trimseq(_zseries_to_cseries(rem))
815
+
816
+
817
def chebpow(c, pow, maxpower=16):
    """Raise a Chebyshev series to a power.

    Returns the Chebyshev series `c` raised to the power `pow`.  The
    argument `c` is a sequence of coefficients ordered from low to high,
    i.e., [1,2,3] is the series ``T_0 + 2*T_1 + 3*T_2.``

    Parameters
    ----------
    c : array_like
        1-D array of Chebyshev series coefficients ordered from low to
        high.
    pow : integer
        Power to which the series will be raised
    maxpower : integer, optional
        Maximum power allowed.  This is mainly to limit growth of the
        series to unmanageable size.  Default is 16

    Returns
    -------
    coef : ndarray
        Chebyshev series of power.

    Raises
    ------
    ValueError
        If `pow` is negative or non-integral, or exceeds `maxpower`.

    See Also
    --------
    chebadd, chebsub, chebmulx, chebmul, chebdiv

    Examples
    --------
    >>> from numpy.polynomial import chebyshev as C
    >>> C.chebpow([1, 2, 3, 4], 2)
    array([15.5, 22. , 16. , ..., 12.5, 12. ,  8. ])

    """
    # note: this is more efficient than `pu._pow(chebmul, c1, c2)`, as it
    # avoids converting between z and c series repeatedly

    [c] = pu.as_series([c])
    power = int(pow)
    if power != pow or power < 0:
        raise ValueError("Power must be a non-negative integer.")
    if maxpower is not None and power > maxpower:
        raise ValueError("Power is too large")
    if power == 0:
        return np.array([1], dtype=c.dtype)
    if power == 1:
        return c
    # Repeated convolution in the z-series domain.  Binary powering would
    # be asymptotically faster, but `maxpower` keeps exponents small.
    zs = _cseries_to_zseries(c)
    acc = zs
    for _ in range(2, power + 1):
        acc = np.convolve(acc, zs)
    return _zseries_to_cseries(acc)
873
+
874
+
875
def chebder(c, m=1, scl=1, axis=0):
    """
    Differentiate a Chebyshev series.

    Returns the Chebyshev series coefficients `c` differentiated `m` times
    along `axis`.  At each iteration the result is multiplied by `scl` (the
    scaling factor is for use in a linear change of variable).  The argument
    `c` is an array of coefficients from low to high degree along each
    axis, e.g., [1,2,3] represents the series ``1*T_0 + 2*T_1 + 3*T_2``
    while [[1,2],[1,2]] represents ``1*T_0(x)*T_0(y) + 1*T_1(x)*T_0(y) +
    2*T_0(x)*T_1(y) + 2*T_1(x)*T_1(y)`` if axis=0 is ``x`` and axis=1 is
    ``y``.

    Parameters
    ----------
    c : array_like
        Array of Chebyshev series coefficients. If c is multidimensional
        the different axis correspond to different variables with the
        degree in each axis given by the corresponding index.
    m : int, optional
        Number of derivatives taken, must be non-negative. (Default: 1)
    scl : scalar, optional
        Each differentiation is multiplied by `scl`.  The end result is
        multiplication by ``scl**m``.  This is for use in a linear change of
        variable. (Default: 1)
    axis : int, optional
        Axis over which the derivative is taken. (Default: 0).

        .. versionadded:: 1.7.0

    Returns
    -------
    der : ndarray
        Chebyshev series of the derivative.

    Raises
    ------
    ValueError
        If `m` is negative.

    See Also
    --------
    chebint

    Notes
    -----
    In general, the result of differentiating a C-series needs to be
    "reprojected" onto the C-series basis set.  Thus, typically, the
    result of this function is "unintuitive," albeit correct; see Examples
    section below.

    Examples
    --------
    >>> from numpy.polynomial import chebyshev as C
    >>> c = (1,2,3,4)
    >>> C.chebder(c)
    array([14., 12., 24.])
    >>> C.chebder(c,3)
    array([96.])
    >>> C.chebder(c,scl=-1)
    array([-14., -12., -24.])
    >>> C.chebder(c,2,-1)
    array([12., 96.])

    """
    c = np.array(c, ndmin=1, copy=True)
    # Promote boolean/integer coefficient arrays to double so the division
    # in the recurrence below produces exact fractional values.
    if c.dtype.char in '?bBhHiIlLqQpP':
        c = c.astype(np.double)
    cnt = pu._deprecate_as_int(m, "the order of derivation")
    iaxis = pu._deprecate_as_int(axis, "the axis")
    if cnt < 0:
        raise ValueError("The order of derivation must be non-negative")
    iaxis = normalize_axis_index(iaxis, c.ndim)

    if cnt == 0:
        return c

    # Work along axis 0 and move the result back at the end.
    c = np.moveaxis(c, iaxis, 0)
    n = len(c)
    if cnt >= n:
        # Differentiating at least as many times as there are coefficients
        # annihilates the series; keep the trailing shape and dtype.
        c = c[:1]*0
    else:
        for i in range(cnt):
            n = n - 1
            c *= scl
            der = np.empty((n,) + c.shape[1:], dtype=c.dtype)
            # Backward recurrence: emit the degree j-1 coefficient of the
            # derivative and fold the rest of j*c[j] back into c[j-2]
            # (the degrees 0 and 1 are finalized after the loop).
            for j in range(n, 2, -1):
                der[j - 1] = (2*j)*c[j]
                c[j - 2] += (j*c[j])/(j - 2)
            if n > 1:
                der[1] = 4*c[2]
            der[0] = c[1]
            c = der
    c = np.moveaxis(c, 0, iaxis)
    return c
965
+
966
+
967
def chebint(c, m=1, k=[], lbnd=0, scl=1, axis=0):
    """
    Integrate a Chebyshev series.

    Returns the Chebyshev series coefficients `c` integrated `m` times from
    `lbnd` along `axis`.  At each iteration the resulting series is
    **multiplied** by `scl` and an integration constant, `k`, is added.
    The scaling factor is for use in a linear change of variable.  ("Buyer
    beware": note that, depending on what one is doing, one may want `scl`
    to be the reciprocal of what one might expect; for more information,
    see the Notes section below.)  The argument `c` is an array of
    coefficients from low to high degree along each axis, e.g., [1,2,3]
    represents the series ``T_0 + 2*T_1 + 3*T_2`` while [[1,2],[1,2]]
    represents ``1*T_0(x)*T_0(y) + 1*T_1(x)*T_0(y) + 2*T_0(x)*T_1(y) +
    2*T_1(x)*T_1(y)`` if axis=0 is ``x`` and axis=1 is ``y``.

    Parameters
    ----------
    c : array_like
        Array of Chebyshev series coefficients. If c is multidimensional
        the different axis correspond to different variables with the
        degree in each axis given by the corresponding index.
    m : int, optional
        Order of integration, must be positive. (Default: 1)
    k : {[], list, scalar}, optional
        Integration constant(s).  The value of the first integral at zero
        is the first value in the list, the value of the second integral
        at zero is the second value, etc.  If ``k == []`` (the default),
        all constants are set to zero.  If ``m == 1``, a single scalar can
        be given instead of a list.
    lbnd : scalar, optional
        The lower bound of the integral. (Default: 0)
    scl : scalar, optional
        Following each integration the result is *multiplied* by `scl`
        before the integration constant is added. (Default: 1)
    axis : int, optional
        Axis over which the integral is taken. (Default: 0).

        .. versionadded:: 1.7.0

    Returns
    -------
    S : ndarray
        C-series coefficients of the integral.

    Raises
    ------
    ValueError
        If ``m < 1``, ``len(k) > m``, ``np.ndim(lbnd) != 0``, or
        ``np.ndim(scl) != 0``.

    See Also
    --------
    chebder

    Notes
    -----
    Note that the result of each integration is *multiplied* by `scl`.
    Why is this important to note?  Say one is making a linear change of
    variable :math:`u = ax + b` in an integral relative to `x`.  Then
    :math:`dx = du/a`, so one will need to set `scl` equal to
    :math:`1/a`- perhaps not what one would have first thought.

    Also note that, in general, the result of integrating a C-series needs
    to be "reprojected" onto the C-series basis set.  Thus, typically,
    the result of this function is "unintuitive," albeit correct; see
    Examples section below.

    Examples
    --------
    >>> from numpy.polynomial import chebyshev as C
    >>> c = (1,2,3)
    >>> C.chebint(c)
    array([ 0.5, -0.5,  0.5,  0.5])
    >>> C.chebint(c,3)
    array([ 0.03125   , -0.1875    ,  0.04166667, -0.05208333,  0.01041667, # may vary
        0.00625   ])
    >>> C.chebint(c, k=3)
    array([ 3.5, -0.5,  0.5,  0.5])
    >>> C.chebint(c,lbnd=-2)
    array([ 8.5, -0.5,  0.5,  0.5])
    >>> C.chebint(c,scl=-2)
    array([-1.,  1., -1., -1.])

    """
    c = np.array(c, ndmin=1, copy=True)
    # Promote boolean/integer coefficient arrays to double so the
    # divisions in the recurrence below are representable.
    if c.dtype.char in '?bBhHiIlLqQpP':
        c = c.astype(np.double)
    # A bare scalar constant is promoted to a one-element list.
    if not np.iterable(k):
        k = [k]
    cnt = pu._deprecate_as_int(m, "the order of integration")
    iaxis = pu._deprecate_as_int(axis, "the axis")
    if cnt < 0:
        raise ValueError("The order of integration must be non-negative")
    if len(k) > cnt:
        raise ValueError("Too many integration constants")
    if np.ndim(lbnd) != 0:
        raise ValueError("lbnd must be a scalar.")
    if np.ndim(scl) != 0:
        raise ValueError("scl must be a scalar.")
    iaxis = normalize_axis_index(iaxis, c.ndim)

    if cnt == 0:
        return c

    # Work along axis 0; pad `k` with zeros for unspecified constants.
    c = np.moveaxis(c, iaxis, 0)
    k = list(k) + [0]*(cnt - len(k))
    for i in range(cnt):
        n = len(c)
        c *= scl
        if n == 1 and np.all(c[0] == 0):
            # The zero series integrates to just the constant of
            # integration.
            c[0] += k[i]
        else:
            tmp = np.empty((n + 1,) + c.shape[1:], dtype=c.dtype)
            tmp[0] = c[0]*0
            tmp[1] = c[0]
            if n > 1:
                tmp[2] = c[1]/4
            # Antiderivative recurrence for j >= 2: the integral of T_j
            # contributes to both the degree j+1 and degree j-1 terms.
            for j in range(2, n):
                tmp[j + 1] = c[j]/(2*(j + 1))
                tmp[j - 1] -= c[j]/(2*(j - 1))
            # Fix the constant term so the integral equals k[i] at lbnd.
            tmp[0] += k[i] - chebval(lbnd, tmp)
            c = tmp
    c = np.moveaxis(c, 0, iaxis)
    return c
1092
+
1093
+
1094
def chebval(x, c, tensor=True):
    """
    Evaluate a Chebyshev series at points x.

    If `c` is of length ``n + 1``, this function returns the value:

    .. math:: p(x) = c_0 * T_0(x) + c_1 * T_1(x) + ... + c_n * T_n(x)

    The parameter `x` is converted to an array only if it is a tuple or a
    list, otherwise it is treated as a scalar.  In either case, either `x`
    or its elements must support multiplication and addition both with
    themselves and with the elements of `c`.

    If `c` is a 1-D array, then ``p(x)`` will have the same shape as `x`.
    If `c` is multidimensional, then the shape of the result depends on
    the value of `tensor`: ``c.shape[1:] + x.shape`` when `tensor` is
    true, ``c.shape[1:]`` when it is false.

    Trailing zeros in the coefficients will be used in the evaluation, so
    they should be avoided if efficiency is a concern.

    Parameters
    ----------
    x : array_like, compatible object
        If `x` is a list or tuple, it is converted to an ndarray, otherwise
        it is left unchanged and treated as a scalar.  In either case, `x`
        or its elements must support addition and multiplication with
        themselves and with the elements of `c`.
    c : array_like
        Array of coefficients ordered so that the coefficients for terms of
        degree n are contained in ``c[n]``.  If `c` is multidimensional the
        remaining indices enumerate multiple polynomials.  In the two
        dimensional case the coefficients may be thought of as stored in
        the columns of `c`.
    tensor : boolean, optional
        If True, the shape of the coefficient array is extended with ones
        on the right, one for each dimension of `x`, so that every column
        of coefficients in `c` is evaluated for every element of `x`.  If
        False, `x` is broadcast over the columns of `c` for the
        evaluation.  The default value is True.

        .. versionadded:: 1.7.0

    Returns
    -------
    values : ndarray, algebra_like
        The shape of the return value is described above.

    See Also
    --------
    chebval2d, chebgrid2d, chebval3d, chebgrid3d

    Notes
    -----
    The evaluation uses Clenshaw recursion, aka synthetic division.

    """
    c = np.array(c, ndmin=1, copy=True)
    # Promote boolean/integer coefficients to double for the arithmetic.
    if c.dtype.char in '?bBhHiIlLqQpP':
        c = c.astype(np.double)
    if isinstance(x, (tuple, list)):
        x = np.asarray(x)
    if isinstance(x, np.ndarray) and tensor:
        # Append singleton axes so every column of c broadcasts over x.
        c = c.reshape(c.shape + (1,)*x.ndim)

    nterms = len(c)
    if nterms == 1:
        b0, b1 = c[0], 0
    elif nterms == 2:
        b0, b1 = c[0], c[1]
    else:
        # Clenshaw recursion, built on T_{k+1} = 2*x*T_k - T_{k-1}.
        x2 = 2*x
        b0, b1 = c[-2], c[-1]
        for k in range(3, nterms + 1):
            b0, b1 = c[-k] - b1, b0 + b1*x2
    return b0 + b1*x
1176
+
1177
+
1178
def chebval2d(x, y, c):
    """
    Evaluate a 2-D Chebyshev series at points (x, y).

    This function returns the values

    .. math:: p(x,y) = \\sum_{i,j} c_{i,j} * T_i(x) * T_j(y)

    The parameters `x` and `y` are converted to arrays only if they are
    tuples or lists, otherwise they are treated as scalars, and they must
    have the same shape after conversion.  In either case, `x` and `y` or
    their elements must support multiplication and addition both with
    themselves and with the elements of `c`.

    If `c` is a 1-D array a one is implicitly appended to its shape to
    make it 2-D.  The shape of the result will be ``c.shape[2:] +
    x.shape``.

    Parameters
    ----------
    x, y : array_like, compatible objects
        The two dimensional series is evaluated at the points ``(x, y)``,
        where `x` and `y` must have the same shape.
    c : array_like
        Array of coefficients ordered so that the coefficient of the term
        of multi-degree i,j is contained in ``c[i,j]``.  If `c` has
        dimension greater than 2 the remaining indices enumerate multiple
        sets of coefficients.

    Returns
    -------
    values : ndarray, compatible object
        The values of the two dimensional Chebyshev series at points
        formed from pairs of corresponding values from `x` and `y`.

    See Also
    --------
    chebval, chebgrid2d, chebval3d, chebgrid3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Delegate to the shared n-D evaluator; chebval handles each axis.
    return pu._valnd(chebval, c, x, y)
1225
+
1226
+
1227
def chebgrid2d(x, y, c):
    """
    Evaluate a 2-D Chebyshev series on the Cartesian product of x and y.

    This function returns the values

    .. math:: p(a,b) = \\sum_{i,j} c_{i,j} * T_i(a) * T_j(b),

    where the points ``(a, b)`` consist of all pairs formed by taking `a`
    from `x` and `b` from `y`.  The resulting points form a grid with `x`
    in the first dimension and `y` in the second.

    The parameters `x` and `y` are converted to arrays only if they are
    tuples or lists, otherwise they are treated as scalars.  In either
    case, `x` and `y` or their elements must support multiplication and
    addition both with themselves and with the elements of `c`.

    If `c` has fewer than two dimensions, ones are implicitly appended to
    its shape to make it 2-D.  The shape of the result will be
    ``c.shape[2:] + x.shape + y.shape``.

    Parameters
    ----------
    x, y : array_like, compatible objects
        The two dimensional series is evaluated at the points in the
        Cartesian product of `x` and `y`.
    c : array_like
        Array of coefficients ordered so that the coefficient of the term
        of multi-degree i,j is contained in ``c[i,j]``.  If `c` has
        dimension greater than two the remaining indices enumerate
        multiple sets of coefficients.

    Returns
    -------
    values : ndarray, compatible object
        The values of the two dimensional Chebyshev series at points in
        the Cartesian product of `x` and `y`.

    See Also
    --------
    chebval, chebval2d, chebval3d, chebgrid3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Grid (outer-product) evaluation is handled by the shared helper.
    return pu._gridnd(chebval, c, x, y)
1278
+
1279
+
1280
def chebval3d(x, y, z, c):
    """
    Evaluate a 3-D Chebyshev series at points (x, y, z).

    This function returns the values

    .. math:: p(x,y,z) = \\sum_{i,j,k} c_{i,j,k} * T_i(x) * T_j(y) * T_k(z)

    The parameters `x`, `y`, and `z` are converted to arrays only if they
    are tuples or lists, otherwise they are treated as scalars, and they
    must have the same shape after conversion.  In either case, `x`, `y`,
    and `z` or their elements must support multiplication and addition
    both with themselves and with the elements of `c`.

    If `c` has fewer than 3 dimensions, ones are implicitly appended to
    its shape to make it 3-D.  The shape of the result will be
    ``c.shape[3:] + x.shape``.

    Parameters
    ----------
    x, y, z : array_like, compatible object
        The three dimensional series is evaluated at the points
        ``(x, y, z)``, where `x`, `y`, and `z` must have the same shape.
    c : array_like
        Array of coefficients ordered so that the coefficient of the term
        of multi-degree i,j,k is contained in ``c[i,j,k]``.  If `c` has
        dimension greater than 3 the remaining indices enumerate multiple
        sets of coefficients.

    Returns
    -------
    values : ndarray, compatible object
        The values of the multidimensional polynomial on points formed
        with triples of corresponding values from `x`, `y`, and `z`.

    See Also
    --------
    chebval, chebval2d, chebgrid2d, chebgrid3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Delegate to the shared n-D evaluator; chebval handles each axis.
    return pu._valnd(chebval, c, x, y, z)
1329
+
1330
+
1331
def chebgrid3d(x, y, z, c):
    """
    Evaluate a 3-D Chebyshev series on the Cartesian product of x, y, and z.

    This function returns the values

    .. math:: p(a,b,c) = \\sum_{i,j,k} c_{i,j,k} * T_i(a) * T_j(b) * T_k(c)

    where the points ``(a, b, c)`` consist of all triples formed by taking
    `a` from `x`, `b` from `y`, and `c` from `z`.  The resulting points
    form a grid with `x` in the first dimension, `y` in the second, and
    `z` in the third.

    The parameters `x`, `y`, and `z` are converted to arrays only if they
    are tuples or lists, otherwise they are treated as scalars.  In either
    case, `x`, `y`, and `z` or their elements must support multiplication
    and addition both with themselves and with the elements of `c`.

    If `c` has fewer than three dimensions, ones are implicitly appended
    to its shape to make it 3-D.  The shape of the result will be
    ``c.shape[3:] + x.shape + y.shape + z.shape``.

    Parameters
    ----------
    x, y, z : array_like, compatible objects
        The three dimensional series is evaluated at the points in the
        Cartesian product of `x`, `y`, and `z`.
    c : array_like
        Array of coefficients ordered so that the coefficients for terms
        of degree i,j are contained in ``c[i,j]``.  If `c` has dimension
        greater than two the remaining indices enumerate multiple sets of
        coefficients.

    Returns
    -------
    values : ndarray, compatible object
        The values of the series at points in the Cartesian product of
        `x`, `y`, and `z`.

    See Also
    --------
    chebval, chebval2d, chebgrid2d, chebval3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Grid (outer-product) evaluation is handled by the shared helper.
    return pu._gridnd(chebval, c, x, y, z)
1385
+
1386
+
1387
def chebvander(x, deg):
    """Pseudo-Vandermonde matrix of given degree.

    Returns the pseudo-Vandermonde matrix of degree `deg` and sample
    points `x`.  The pseudo-Vandermonde matrix is defined by

    .. math:: V[..., i] = T_i(x),

    where ``0 <= i <= deg``.  The leading indices of `V` index the
    elements of `x` and the last index is the degree of the Chebyshev
    polynomial.

    If `c` is a 1-D array of coefficients of length ``n + 1`` and `V` is
    the matrix ``V = chebvander(x, n)``, then ``np.dot(V, c)`` and
    ``chebval(x, c)`` are the same up to roundoff.  This equivalence is
    useful both for least squares fitting and for the evaluation of a
    large number of Chebyshev series of the same degree and sample points.

    Parameters
    ----------
    x : array_like
        Array of points.  The dtype is converted to float64 or complex128
        depending on whether any of the elements are complex.  If `x` is
        scalar it is converted to a 1-D array.
    deg : int
        Degree of the resulting matrix.

    Returns
    -------
    vander : ndarray
        The pseudo Vandermonde matrix.  The shape of the returned matrix
        is ``x.shape + (deg + 1,)``, where the last index is the degree of
        the corresponding Chebyshev polynomial.  The dtype will be the
        same as the converted `x`.

    """
    ideg = pu._deprecate_as_int(deg, "deg")
    if ideg < 0:
        raise ValueError("deg must be non-negative")

    # ``+ 0.0`` forces a float/complex dtype without an explicit astype.
    x = np.array(x, copy=False, ndmin=1) + 0.0
    mat = np.empty((ideg + 1,) + x.shape, dtype=x.dtype)
    # Forward recursion T_{k} = 2*x*T_{k-1} - T_{k-2} fills the rows.
    mat[0] = x*0 + 1
    if ideg > 0:
        x2 = 2*x
        mat[1] = x
        for k in range(2, ideg + 1):
            mat[k] = mat[k-1]*x2 - mat[k-2]
    # Degree axis goes last so that V @ c evaluates the series.
    return np.moveaxis(mat, 0, -1)
1438
+
1439
+
1440
def chebvander2d(x, y, deg):
    """Pseudo-Vandermonde matrix of given degrees.

    Returns the pseudo-Vandermonde matrix of degrees `deg` and sample
    points ``(x, y)``.  The pseudo-Vandermonde matrix is defined by

    .. math:: V[..., (deg[1] + 1)*i + j] = T_i(x) * T_j(y),

    where ``0 <= i <= deg[0]`` and ``0 <= j <= deg[1]``.  The leading
    indices of `V` index the points ``(x, y)`` and the last index encodes
    the degrees of the Chebyshev polynomials.

    If ``V = chebvander2d(x, y, [xdeg, ydeg])``, then the columns of `V`
    correspond to the elements of a 2-D coefficient array `c` of shape
    (xdeg + 1, ydeg + 1) in the order

    .. math:: c_{00}, c_{01}, c_{02} ... , c_{10}, c_{11}, c_{12} ...

    and ``np.dot(V, c.flat)`` and ``chebval2d(x, y, c)`` will be the same
    up to roundoff.  This equivalence is useful both for least squares
    fitting and for the evaluation of a large number of 2-D Chebyshev
    series of the same degrees and sample points.

    Parameters
    ----------
    x, y : array_like
        Arrays of point coordinates, all of the same shape.  The dtypes
        will be converted to either float64 or complex128 depending on
        whether any of the elements are complex.  Scalars are converted to
        1-D arrays.
    deg : list of ints
        List of maximum degrees of the form [x_deg, y_deg].

    Returns
    -------
    vander2d : ndarray
        The shape of the returned matrix is ``x.shape + (order,)``, where
        :math:`order = (deg[0]+1)*(deg[1]+1)`.  The dtype will be the same
        as the converted `x` and `y`.

    See Also
    --------
    chebvander, chebvander3d, chebval2d, chebval3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # One 1-D Vandermonde builder per axis, flattened into a single
    # column index by the shared helper.
    return pu._vander_nd_flat((chebvander, chebvander), (x, y), deg)
1491
+
1492
+
1493
def chebvander3d(x, y, z, deg):
    """Pseudo-Vandermonde matrix of given degrees.

    Returns the pseudo-Vandermonde matrix of degrees `deg` and sample
    points ``(x, y, z)``.  If `l, m, n` are the given degrees in
    `x, y, z`, then the pseudo-Vandermonde matrix is defined by

    .. math:: V[..., (m+1)(n+1)i + (n+1)j + k] = T_i(x)*T_j(y)*T_k(z),

    where ``0 <= i <= l``, ``0 <= j <= m``, and ``0 <= k <= n``.  The
    leading indices of `V` index the points ``(x, y, z)`` and the last
    index encodes the degrees of the Chebyshev polynomials.

    If ``V = chebvander3d(x, y, z, [xdeg, ydeg, zdeg])``, then the columns
    of `V` correspond to the elements of a 3-D coefficient array `c` of
    shape (xdeg + 1, ydeg + 1, zdeg + 1) in the order

    .. math:: c_{000}, c_{001}, c_{002},... , c_{010}, c_{011}, c_{012},...

    and ``np.dot(V, c.flat)`` and ``chebval3d(x, y, z, c)`` will be the
    same up to roundoff.  This equivalence is useful both for least
    squares fitting and for the evaluation of a large number of 3-D
    Chebyshev series of the same degrees and sample points.

    Parameters
    ----------
    x, y, z : array_like
        Arrays of point coordinates, all of the same shape.  The dtypes
        will be converted to either float64 or complex128 depending on
        whether any of the elements are complex.  Scalars are converted to
        1-D arrays.
    deg : list of ints
        List of maximum degrees of the form [x_deg, y_deg, z_deg].

    Returns
    -------
    vander3d : ndarray
        The shape of the returned matrix is ``x.shape + (order,)``, where
        :math:`order = (deg[0]+1)*(deg[1]+1)*(deg[2]+1)`.  The dtype will
        be the same as the converted `x`, `y`, and `z`.

    See Also
    --------
    chebvander, chebvander2d, chebval2d, chebval3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # One 1-D Vandermonde builder per axis, flattened into a single
    # column index by the shared helper.
    return pu._vander_nd_flat((chebvander, chebvander, chebvander), (x, y, z), deg)
1545
+
1546
+
1547
def chebfit(x, y, deg, rcond=None, full=False, w=None):
    """
    Least squares fit of Chebyshev series to data.

    Return the coefficients of a Chebyshev series of degree `deg` that is
    the least squares fit to the data values `y` given at points `x`.  If
    `y` is 1-D the returned coefficients will also be 1-D.  If `y` is 2-D
    multiple fits are done, one for each column of `y`, and the resulting
    coefficients are stored in the corresponding columns of a 2-D return.
    The fitted polynomial(s) are in the form

    .. math:: p(x) = c_0 + c_1 * T_1(x) + ... + c_n * T_n(x),

    where `n` is `deg`.

    Parameters
    ----------
    x : array_like, shape (M,)
        x-coordinates of the M sample points ``(x[i], y[i])``.
    y : array_like, shape (M,) or (M, K)
        y-coordinates of the sample points.  Several data sets of sample
        points sharing the same x-coordinates can be fitted at once by
        passing in a 2D-array that contains one dataset per column.
    deg : int or 1-D array_like
        Degree(s) of the fitting polynomials.  If `deg` is a single
        integer, all terms up to and including the `deg`'th term are
        included in the fit.  For NumPy versions >= 1.11.0 a list of
        integers specifying the degrees of the terms to include may be
        used instead.
    rcond : float, optional
        Relative condition number of the fit.  Singular values smaller
        than this relative to the largest singular value will be ignored.
        The default value is ``len(x)*eps``, where eps is the relative
        precision of the float type, about 2e-16 in most cases.
    full : bool, optional
        Switch determining nature of return value.  When it is False (the
        default) just the coefficients are returned, when True diagnostic
        information from the singular value decomposition is also
        returned.
    w : array_like, shape (`M`,), optional
        Weights.  If not None, the weight ``w[i]`` applies to the
        unsquared residual ``y[i] - y_hat[i]`` at ``x[i]``.  Ideally the
        weights are chosen so that the errors of the products
        ``w[i]*y[i]`` all have the same variance.  When using
        inverse-variance weighting, use ``w[i] = 1/sigma(y[i])``.  The
        default value is None.

        .. versionadded:: 1.5.0

    Returns
    -------
    coef : ndarray, shape (deg + 1,) or (deg + 1, K)
        Chebyshev coefficients ordered from low to high.  If `y` was 2-D,
        the coefficients for the data in column k of `y` are in column
        `k`.  (Note: the length of this axis is determined by `deg`, not
        by the number of sample points M.)

    [residuals, rank, singular_values, rcond] : list
        These values are only returned if ``full == True``

        - residuals -- sum of squared residuals of the least squares fit
        - rank -- the numerical rank of the scaled Vandermonde matrix
        - singular_values -- singular values of the scaled Vandermonde matrix
        - rcond -- value of `rcond`.

        For more details, see `numpy.linalg.lstsq`.

    Warns
    -----
    RankWarning
        The rank of the coefficient matrix in the least-squares fit is
        deficient.  The warning is only raised if ``full == False``.  The
        warnings can be turned off by

        >>> import warnings
        >>> warnings.simplefilter('ignore', np.RankWarning)

    See Also
    --------
    numpy.polynomial.polynomial.polyfit
    numpy.polynomial.legendre.legfit
    numpy.polynomial.laguerre.lagfit
    numpy.polynomial.hermite.hermfit
    numpy.polynomial.hermite_e.hermefit
    chebval : Evaluates a Chebyshev series.
    chebvander : Vandermonde matrix of Chebyshev series.
    chebweight : Chebyshev weight function.
    numpy.linalg.lstsq : Computes a least-squares fit from the matrix.
    scipy.interpolate.UnivariateSpline : Computes spline fits.

    Notes
    -----
    The solution is the coefficients of the Chebyshev series `p` that
    minimizes the sum of the weighted squared errors

    .. math:: E = \\sum_j w_j^2 * |y_j - p(x_j)|^2,

    where :math:`w_j` are the weights.  This problem is solved by setting
    up as the (typically) overdetermined matrix equation

    .. math:: V(x) * c = w * y,

    where `V` is the weighted pseudo Vandermonde matrix of `x`, `c` are
    the coefficients to be solved for, `w` are the weights, and `y` are
    the observed values.  This equation is then solved using the singular
    value decomposition of `V`.

    If some of the singular values of `V` are so small that they are
    neglected, then a `RankWarning` will be issued.  This means that the
    coefficient values may be poorly determined.  Using a lower order fit
    will usually get rid of the warning.  The `rcond` parameter can also
    be set to a value smaller than its default, but the resulting fit may
    be spurious and have large contributions from roundoff error.

    Fits using Chebyshev series are usually better conditioned than fits
    using power series, but much can depend on the distribution of the
    sample points and the smoothness of the data.  If the quality of the
    fit is inadequate splines may be a good alternative.

    References
    ----------
    .. [1] Wikipedia, "Curve fitting",
           https://en.wikipedia.org/wiki/Curve_fitting

    """
    # All of the heavy lifting (weighting, scaling, lstsq, rank warning)
    # is shared across the polynomial bases via pu._fit.
    return pu._fit(chebvander, x, y, deg, rcond, full, w)
1672
+
1673
+
1674
def chebcompanion(c):
    """Return the scaled companion matrix of c.

    The basis polynomials are scaled so that the companion matrix is
    symmetric when `c` is a Chebyshev basis polynomial.  This provides
    better eigenvalue estimates than the unscaled case and for basis
    polynomials the eigenvalues are guaranteed to be real if
    `numpy.linalg.eigvalsh` is used to obtain them.

    Parameters
    ----------
    c : array_like
        1-D array of Chebyshev series coefficients ordered from low to
        high degree.

    Returns
    -------
    mat : ndarray
        Scaled companion matrix of dimensions (deg, deg).

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # c is a trimmed copy
    [c] = pu.as_series([c])
    if len(c) < 2:
        raise ValueError('Series must have maximum degree of at least 1.')
    if len(c) == 2:
        # Degree 1: the single root is available in closed form.
        return np.array([[-c[0]/c[1]]])

    n = len(c) - 1
    mat = np.zeros((n, n), dtype=c.dtype)
    # Per-basis scaling that symmetrizes the matrix: 1 for T_0,
    # sqrt(1/2) for every higher-degree basis polynomial.
    scl = np.full(n, np.sqrt(.5))
    scl[0] = 1.
    # Fill both off-diagonals with the same entries via flat views.
    offdiag = np.full(n - 1, .5)
    offdiag[0] = np.sqrt(.5)
    mat.reshape(-1)[1::n+1] = offdiag
    mat.reshape(-1)[n::n+1] = offdiag
    # Fold the (scaled) monic-normalization terms into the last column.
    mat[:, -1] -= (c[:-1]/c[-1])*(scl/scl[-1])*.5
    return mat
1717
+
1718
+
1719
def chebroots(c):
    """
    Compute the roots of a Chebyshev series.

    Return the roots (a.k.a. "zeros") of the polynomial

    .. math:: p(x) = \\sum_i c[i] * T_i(x).

    Parameters
    ----------
    c : 1-D array_like
        1-D array of coefficients.

    Returns
    -------
    out : ndarray
        Array of the roots of the series.  If all the roots are real,
        then `out` is also real, otherwise it is complex.

    See Also
    --------
    numpy.polynomial.polynomial.polyroots
    numpy.polynomial.legendre.legroots
    numpy.polynomial.laguerre.lagroots
    numpy.polynomial.hermite.hermroots
    numpy.polynomial.hermite_e.hermeroots

    Notes
    -----
    The root estimates are obtained as the eigenvalues of the companion
    matrix.  Roots far from the origin of the complex plane may have large
    errors due to the numerical instability of the series for such values.
    Roots with multiplicity greater than 1 will also show larger errors as
    the value of the series near such points is relatively insensitive to
    errors in the roots.  Isolated roots near the origin can be improved
    by a few iterations of Newton's method.

    The Chebyshev series basis polynomials aren't powers of `x` so the
    results of this function may seem unintuitive.

    Examples
    --------
    >>> import numpy.polynomial.chebyshev as cheb
    >>> cheb.chebroots((-1, 1,-1, 1)) # T3 - T2 + T1 - T0 has real roots
    array([ -5.00000000e-01,   2.60860684e-17,   1.00000000e+00]) # may vary

    """
    # c is a trimmed copy
    [c] = pu.as_series([c])
    # Constant (or empty) series: no roots.
    if len(c) < 2:
        return np.array([], dtype=c.dtype)
    # Degree 1: closed-form root.
    if len(c) == 2:
        return np.array([-c[0]/c[1]])

    # Rotating the companion matrix reduces the eigenvalue error.
    rotated = chebcompanion(c)[::-1, ::-1]
    roots = la.eigvals(rotated)
    roots.sort()
    return roots
1778
+
1779
+
1780
def chebinterpolate(func, deg, args=()):
    """Interpolate a function at the Chebyshev points of the first kind.

    Returns the Chebyshev series that interpolates `func` at the Chebyshev
    points of the first kind in the interval [-1, 1]. The interpolating
    series tends to a minmax approximation to `func` with increasing `deg`
    if the function is continuous in the interval.

    .. versionadded:: 1.14.0

    Parameters
    ----------
    func : function
        The function to be approximated. It must be a function of a single
        variable of the form ``f(x, a, b, c...)``, where ``a, b, c...`` are
        extra arguments passed in the `args` parameter.
    deg : int
        Degree of the interpolating polynomial
    args : tuple, optional
        Extra arguments to be used in the function call. Default is no extra
        arguments.

    Returns
    -------
    coef : ndarray, shape (deg + 1,)
        Chebyshev coefficients of the interpolating series ordered from low to
        high.

    Examples
    --------
    >>> import numpy.polynomial.chebyshev as C
    >>> C.chebinterpolate(lambda x: np.tanh(x) + 0.5, 8)
    array([  5.00000000e-01,   8.11675684e-01,  -9.86864911e-17,
            -5.42457905e-02,  -2.71387850e-16,   4.51658839e-03,
             2.46716228e-17,  -3.79694221e-04,  -3.26899002e-16])

    Notes
    -----

    The Chebyshev polynomials used in the interpolation are orthogonal when
    sampled at the Chebyshev points of the first kind. If it is desired to
    constrain some of the coefficients they can simply be set to the desired
    value after the interpolation, no new interpolation or fit is needed. This
    is especially useful if it is known apriori that some of coefficients are
    zero. For instance, if the function is even then the coefficients of the
    terms of odd degree in the result can be set to zero.

    """
    deg = np.asarray(deg)

    # check arguments.
    if deg.ndim > 0 or deg.dtype.kind not in 'iu' or deg.size == 0:
        raise TypeError("deg must be an int")
    if deg < 0:
        raise ValueError("expected deg >= 0")

    order = deg + 1
    # Sample func at the first-kind Chebyshev points, where the basis
    # polynomials are discretely orthogonal.
    xcheb = chebpts1(order)
    yfunc = func(xcheb, *args)
    m = chebvander(xcheb, deg)
    # Discrete inner products with each basis polynomial...
    c = np.dot(m.T, yfunc)
    # ...normalized by the discrete norms: order for T_0, order/2 above.
    c[0] /= order
    c[1:] /= 0.5*order

    return c
1845
+
1846
+
1847
def chebgauss(deg):
    """
    Gauss-Chebyshev quadrature.

    Computes the sample points and weights for Gauss-Chebyshev quadrature.
    These sample points and weights will correctly integrate polynomials
    of degree :math:`2*deg - 1` or less over the interval :math:`[-1, 1]`
    with the weight function :math:`f(x) = 1/\\sqrt{1 - x^2}`.

    Parameters
    ----------
    deg : int
        Number of sample points and weights.  It must be >= 1.

    Returns
    -------
    x : ndarray
        1-D ndarray containing the sample points.
    y : ndarray
        1-D ndarray containing the weights.

    Notes
    -----

    .. versionadded:: 1.7.0

    The results have only been tested up to degree 100, higher degrees may
    be problematic.  For Gauss-Chebyshev there are closed form solutions
    for the sample points and weights.  If n = `deg`, then

    .. math:: x_i = \\cos(\\pi (2 i - 1) / (2 n))

    .. math:: w_i = \\pi / n

    """
    ideg = pu._deprecate_as_int(deg, "deg")
    if ideg <= 0:
        raise ValueError("deg must be a positive integer")

    # Closed-form nodes and equal weights -- no eigenvalue computation
    # is needed for this quadrature rule.
    nodes = np.cos(np.pi * np.arange(1, 2*ideg, 2) / (2.0*ideg))
    weights = np.ones(ideg)*(np.pi/ideg)

    return nodes, weights
1890
+
1891
+
1892
def chebweight(x):
    """
    The weight function of the Chebyshev polynomials.

    The weight function is :math:`1/\\sqrt{1 - x^2}` and the interval of
    integration is :math:`[-1, 1]`.  The Chebyshev polynomials are
    orthogonal, but not normalized, with respect to this weight function.

    Parameters
    ----------
    x : array_like
        Values at which the weight function will be computed.

    Returns
    -------
    w : ndarray
        The weight function at `x`.

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Evaluated as 1/(sqrt(1+x)*sqrt(1-x)) rather than 1/sqrt(1-x*x);
    # both singular endpoints are handled by their own factor.
    return 1./(np.sqrt(1. + x) * np.sqrt(1. - x))
1918
+
1919
+
1920
def chebpts1(npts):
    """
    Chebyshev points of the first kind.

    The Chebyshev points of the first kind are the points ``cos(x)``,
    where ``x = [pi*(k + .5)/npts for k in range(npts)]``.

    Parameters
    ----------
    npts : int
        Number of sample points desired.

    Returns
    -------
    pts : ndarray
        The Chebyshev points of the first kind.

    See Also
    --------
    chebpts2

    Notes
    -----

    .. versionadded:: 1.5.0

    """
    count = int(npts)
    if count != npts:
        raise ValueError("npts must be integer")
    if count < 1:
        raise ValueError("npts must be >= 1")

    # Symmetric angles straddling zero; taking sin of these yields the
    # first-kind points sorted in increasing order.
    angles = 0.5 * np.pi / count * np.arange(-count+1, count+1, 2)
    return np.sin(angles)
1955
+
1956
+
1957
def chebpts2(npts):
    """
    Chebyshev points of the second kind.

    The Chebyshev points of the second kind are the points ``cos(x)``,
    where ``x = [pi*k/(npts - 1) for k in range(npts)]`` sorted in
    ascending order.

    Parameters
    ----------
    npts : int
        Number of sample points desired.

    Returns
    -------
    pts : ndarray
        The Chebyshev points of the second kind.

    Notes
    -----

    .. versionadded:: 1.5.0

    """
    count = int(npts)
    if count != npts:
        raise ValueError("npts must be integer")
    if count < 2:
        raise ValueError("npts must be >= 2")

    # Angles from -pi to 0 give cos values already in ascending order.
    angles = np.linspace(-np.pi, 0, count)
    return np.cos(angles)
1989
+
1990
+
1991
+ #
1992
+ # Chebyshev series class
1993
+ #
1994
+
1995
class Chebyshev(ABCPolyBase):
    """A Chebyshev series class.

    The Chebyshev class provides the standard Python numerical methods
    '+', '-', '*', '//', '%', 'divmod', '**', and '()' as well as the
    methods listed below.

    Parameters
    ----------
    coef : array_like
        Chebyshev coefficients in order of increasing degree, i.e.,
        ``(1, 2, 3)`` gives ``1*T_0(x) + 2*T_1(x) + 3*T_2(x)``.
    domain : (2,) array_like, optional
        Domain to use. The interval ``[domain[0], domain[1]]`` is mapped
        to the interval ``[window[0], window[1]]`` by shifting and scaling.
        The default value is [-1, 1].
    window : (2,) array_like, optional
        Window, see `domain` for its use. The default value is [-1, 1].

        .. versionadded:: 1.6.0
    symbol : str, optional
        Symbol used to represent the independent variable in string
        representations of the polynomial expression, e.g. for printing.
        The symbol must be a valid Python identifier. Default value is 'x'.

        .. versionadded:: 1.24

    """
    # Virtual Functions: ABCPolyBase dispatches all arithmetic and
    # analysis through these basis-specific module functions.
    _add = staticmethod(chebadd)
    _sub = staticmethod(chebsub)
    _mul = staticmethod(chebmul)
    _div = staticmethod(chebdiv)
    _pow = staticmethod(chebpow)
    _val = staticmethod(chebval)
    _int = staticmethod(chebint)
    _der = staticmethod(chebder)
    _fit = staticmethod(chebfit)
    _line = staticmethod(chebline)
    _roots = staticmethod(chebroots)
    _fromroots = staticmethod(chebfromroots)

    @classmethod
    def interpolate(cls, func, deg, domain=None, args=()):
        """Interpolate a function at the Chebyshev points of the first kind.

        Returns the series that interpolates `func` at the Chebyshev points of
        the first kind scaled and shifted to the `domain`. The resulting series
        tends to a minmax approximation of `func` when the function is
        continuous in the domain.

        .. versionadded:: 1.14.0

        Parameters
        ----------
        func : function
            The function to be interpolated. It must be a function of a single
            variable of the form ``f(x, a, b, c...)``, where ``a, b, c...`` are
            extra arguments passed in the `args` parameter.
        deg : int
            Degree of the interpolating polynomial.
        domain : {None, [beg, end]}, optional
            Domain over which `func` is interpolated. The default is None, in
            which case the domain is [-1, 1].
        args : tuple, optional
            Extra arguments to be used in the function call. Default is no
            extra arguments.

        Returns
        -------
        polynomial : Chebyshev instance
            Interpolating Chebyshev instance.

        Notes
        -----
        See `numpy.polynomial.chebyshev.chebinterpolate` for more details.

        """
        if domain is None:
            domain = cls.domain
        # Interpolate in the window, pre-mapping the sample points back
        # into the requested domain before calling func.
        xfunc = lambda x: func(pu.mapdomain(x, cls.window, domain), *args)
        coef = chebinterpolate(xfunc, deg)
        return cls(coef, domain=domain)

    # Virtual properties: default domain and window are both [-1, 1].
    domain = np.array(chebdomain)
    window = np.array(chebdomain)
    basis_name = 'T'
.venv/lib/python3.11/site-packages/numpy/polynomial/hermite.pyi ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Type stubs for numpy.polynomial.hermite (the physicists' Hermite series).
# Signatures mirror the runtime module; bodies are intentionally elided.
from typing import Any

# NOTE(review): `float_` is the pre-NumPy-2 alias of float64 — this stub
# presumably targets NumPy 1.x; verify against the installed version.
from numpy import ndarray, dtype, int_, float_
from numpy.polynomial._polybase import ABCPolyBase
from numpy.polynomial.polyutils import trimcoef

__all__: list[str]

# Trimming small trailing coefficients is shared across all bases.
hermtrim = trimcoef

# Conversions between the power basis and the Hermite basis.
def poly2herm(pol): ...
def herm2poly(c): ...

# Module-level constants: default domain and the series for 0, 1 and x.
hermdomain: ndarray[Any, dtype[int_]]
hermzero: ndarray[Any, dtype[int_]]
hermone: ndarray[Any, dtype[int_]]
hermx: ndarray[Any, dtype[float_]]

# Series construction, arithmetic, calculus, evaluation, Vandermonde
# matrices, fitting and Gauss quadrature — see the runtime docstrings.
def hermline(off, scl): ...
def hermfromroots(roots): ...
def hermadd(c1, c2): ...
def hermsub(c1, c2): ...
def hermmulx(c): ...
def hermmul(c1, c2): ...
def hermdiv(c1, c2): ...
def hermpow(c, pow, maxpower=...): ...
def hermder(c, m=..., scl=..., axis=...): ...
def hermint(c, m=..., k = ..., lbnd=..., scl=..., axis=...): ...
def hermval(x, c, tensor=...): ...
def hermval2d(x, y, c): ...
def hermgrid2d(x, y, c): ...
def hermval3d(x, y, z, c): ...
def hermgrid3d(x, y, z, c): ...
def hermvander(x, deg): ...
def hermvander2d(x, y, deg): ...
def hermvander3d(x, y, z, deg): ...
def hermfit(x, y, deg, rcond=..., full=..., w=...): ...
def hermcompanion(c): ...
def hermroots(c): ...
def hermgauss(deg): ...
def hermweight(x): ...

class Hermite(ABCPolyBase):
    domain: Any
    window: Any
    basis_name: Any
.venv/lib/python3.11/site-packages/numpy/polynomial/polynomial.py ADDED
@@ -0,0 +1,1542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ =================================================
3
+ Power Series (:mod:`numpy.polynomial.polynomial`)
4
+ =================================================
5
+
6
+ This module provides a number of objects (mostly functions) useful for
7
+ dealing with polynomials, including a `Polynomial` class that
8
+ encapsulates the usual arithmetic operations. (General information
9
+ on how this module represents and works with polynomial objects is in
10
+ the docstring for its "parent" sub-package, `numpy.polynomial`).
11
+
12
+ Classes
13
+ -------
14
+ .. autosummary::
15
+ :toctree: generated/
16
+
17
+ Polynomial
18
+
19
+ Constants
20
+ ---------
21
+ .. autosummary::
22
+ :toctree: generated/
23
+
24
+ polydomain
25
+ polyzero
26
+ polyone
27
+ polyx
28
+
29
+ Arithmetic
30
+ ----------
31
+ .. autosummary::
32
+ :toctree: generated/
33
+
34
+ polyadd
35
+ polysub
36
+ polymulx
37
+ polymul
38
+ polydiv
39
+ polypow
40
+ polyval
41
+ polyval2d
42
+ polyval3d
43
+ polygrid2d
44
+ polygrid3d
45
+
46
+ Calculus
47
+ --------
48
+ .. autosummary::
49
+ :toctree: generated/
50
+
51
+ polyder
52
+ polyint
53
+
54
+ Misc Functions
55
+ --------------
56
+ .. autosummary::
57
+ :toctree: generated/
58
+
59
+ polyfromroots
60
+ polyroots
61
+ polyvalfromroots
62
+ polyvander
63
+ polyvander2d
64
+ polyvander3d
65
+ polycompanion
66
+ polyfit
67
+ polytrim
68
+ polyline
69
+
70
+ See Also
71
+ --------
72
+ `numpy.polynomial`
73
+
74
+ """
75
+ __all__ = [
76
+ 'polyzero', 'polyone', 'polyx', 'polydomain', 'polyline', 'polyadd',
77
+ 'polysub', 'polymulx', 'polymul', 'polydiv', 'polypow', 'polyval',
78
+ 'polyvalfromroots', 'polyder', 'polyint', 'polyfromroots', 'polyvander',
79
+ 'polyfit', 'polytrim', 'polyroots', 'Polynomial', 'polyval2d', 'polyval3d',
80
+ 'polygrid2d', 'polygrid3d', 'polyvander2d', 'polyvander3d']
81
+
82
+ import numpy as np
83
+ import numpy.linalg as la
84
+ from numpy.core.multiarray import normalize_axis_index
85
+
86
+ from . import polyutils as pu
87
+ from ._polybase import ABCPolyBase
88
+
89
+ polytrim = pu.trimcoef
90
+
91
+ #
92
+ # These are constant arrays are of integer type so as to be compatible
93
+ # with the widest range of other types, such as Decimal.
94
+ #
95
+
96
+ # Polynomial default domain.
97
+ polydomain = np.array([-1, 1])
98
+
99
+ # Polynomial coefficients representing zero.
100
+ polyzero = np.array([0])
101
+
102
+ # Polynomial coefficients representing one.
103
+ polyone = np.array([1])
104
+
105
+ # Polynomial coefficients representing the identity x.
106
+ polyx = np.array([0, 1])
107
+
108
+ #
109
+ # Polynomial series functions
110
+ #
111
+
112
+
113
def polyline(off, scl):
    """
    Returns an array representing a linear polynomial.

    Parameters
    ----------
    off, scl : scalars
        The "y-intercept" and "slope" of the line, respectively.

    Returns
    -------
    y : ndarray
        This module's representation of the linear polynomial ``off +
        scl*x``.

    See Also
    --------
    numpy.polynomial.chebyshev.chebline
    numpy.polynomial.legendre.legline
    numpy.polynomial.laguerre.lagline
    numpy.polynomial.hermite.hermline
    numpy.polynomial.hermite_e.hermeline

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> P.polyline(1,-1)
    array([ 1, -1])
    >>> P.polyval(1, P.polyline(1,-1)) # should be 0
    0.0

    """
    # A zero slope degenerates to the constant polynomial ``off``; keeping
    # the result length-1 matches this module's trimmed representation.
    coeffs = [off] if scl == 0 else [off, scl]
    return np.array(coeffs)
149
+
150
+
151
def polyfromroots(roots):
    """
    Generate a monic polynomial with given roots.

    Return the coefficients of the polynomial

    .. math:: p(x) = (x - r_0) * (x - r_1) * ... * (x - r_n),

    where the ``r_n`` are the roots specified in `roots`. If a zero has
    multiplicity n, then it must appear in `roots` n times. For instance,
    if 2 is a root of multiplicity three and 3 is a root of multiplicity 2,
    then `roots` looks something like [2, 2, 2, 3, 3]. The roots can appear
    in any order.

    If the returned coefficients are `c`, then

    .. math:: p(x) = c_0 + c_1 * x + ... + x^n

    The coefficient of the last term is 1 for monic polynomials in this
    form.

    Parameters
    ----------
    roots : array_like
        Sequence containing the roots.

    Returns
    -------
    out : ndarray
        1-D array of the polynomial's coefficients If all the roots are
        real, then `out` is also real, otherwise it is complex.  (see
        Examples below).

    See Also
    --------
    numpy.polynomial.chebyshev.chebfromroots
    numpy.polynomial.legendre.legfromroots
    numpy.polynomial.laguerre.lagfromroots
    numpy.polynomial.hermite.hermfromroots
    numpy.polynomial.hermite_e.hermefromroots

    Notes
    -----
    The coefficients are determined by multiplying together linear factors
    of the form ``(x - r_i)``, i.e.

    .. math:: p(x) = (x - r_0) (x - r_1) ... (x - r_n)

    where ``n == len(roots) - 1``; note that this implies that ``1`` is always
    returned for :math:`a_n`.

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> P.polyfromroots((-1,0,1)) # x(x - 1)(x + 1) = x^3 - x
    array([ 0., -1.,  0.,  1.])
    >>> j = complex(0,1)
    >>> P.polyfromroots((-j,j)) # complex returned, though values are real
    array([1.+0.j,  0.+0.j,  1.+0.j])

    """
    # Delegate to the shared basis-agnostic helper, wired with this module's
    # linear-factor constructor and multiplication primitive.
    return pu._fromroots(polyline, polymul, roots)
213
+
214
+
215
def polyadd(c1, c2):
    """
    Add one polynomial to another.

    Returns the sum of two polynomials `c1` + `c2`.  The arguments are
    sequences of coefficients from lowest order term to highest, i.e.,
    [1,2,3] represents the polynomial ``1 + 2*x + 3*x**2``.

    Parameters
    ----------
    c1, c2 : array_like
        1-D arrays of polynomial coefficients ordered from low to high.

    Returns
    -------
    out : ndarray
        The coefficient array representing their sum.

    See Also
    --------
    polysub, polymulx, polymul, polydiv, polypow

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> c1 = (1,2,3)
    >>> c2 = (3,2,1)
    >>> sum = P.polyadd(c1,c2); sum
    array([4.,  4.,  4.])
    >>> P.polyval(2, sum) # 4 + 4(2) + 4(2**2)
    28.0

    """
    # In the power basis, addition is plain componentwise addition; the
    # shared helper also trims and type-promotes the inputs.
    return pu._add(c1, c2)
249
+
250
+
251
def polysub(c1, c2):
    """
    Subtract one polynomial from another.

    Returns the difference of two polynomials `c1` - `c2`.  The arguments
    are sequences of coefficients from lowest order term to highest, i.e.,
    [1,2,3] represents the polynomial ``1 + 2*x + 3*x**2``.

    Parameters
    ----------
    c1, c2 : array_like
        1-D arrays of polynomial coefficients ordered from low to
        high.

    Returns
    -------
    out : ndarray
        Of coefficients representing their difference.

    See Also
    --------
    polyadd, polymulx, polymul, polydiv, polypow

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> c1 = (1,2,3)
    >>> c2 = (3,2,1)
    >>> P.polysub(c1,c2)
    array([-2.,  0.,  2.])
    >>> P.polysub(c2,c1) # -P.polysub(c1,c2)
    array([ 2.,  0., -2.])

    """
    # Componentwise subtraction via the shared basis-agnostic helper.
    return pu._sub(c1, c2)
286
+
287
+
288
def polymulx(c):
    """Multiply a polynomial by x.

    Multiply the polynomial `c` by x, where x is the independent
    variable.

    Parameters
    ----------
    c : array_like
        1-D array of polynomial coefficients ordered from low to
        high.

    Returns
    -------
    out : ndarray
        Array representing the result of the multiplication.

    See Also
    --------
    polyadd, polysub, polymul, polydiv, polypow

    Notes
    -----

    .. versionadded:: 1.5.0

    """
    # Normalize the input to a trimmed 1-D coefficient array.
    [c] = pu.as_series([c])
    # The zero series maps to itself; returning early keeps it length 1.
    if len(c) == 1 and c[0] == 0:
        return c

    # Multiplying by x shifts every coefficient up one degree and inserts
    # a zero constant term (written as ``c[0]*0`` to respect the dtype,
    # e.g. object arrays of Decimal).
    shifted = np.empty(len(c) + 1, dtype=c.dtype)
    shifted[0] = c[0]*0
    shifted[1:] = c
    return shifted
326
+
327
+
328
def polymul(c1, c2):
    """
    Multiply one polynomial by another.

    Returns the product of two polynomials `c1` * `c2`.  The arguments are
    sequences of coefficients, from lowest order term to highest, e.g.,
    [1,2,3] represents the polynomial ``1 + 2*x + 3*x**2.``

    Parameters
    ----------
    c1, c2 : array_like
        1-D arrays of coefficients representing a polynomial, relative to the
        "standard" basis, and ordered from lowest order term to highest.

    Returns
    -------
    out : ndarray
        Of the coefficients of their product.

    See Also
    --------
    polyadd, polysub, polymulx, polydiv, polypow

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> c1 = (1,2,3)
    >>> c2 = (3,2,1)
    >>> P.polymul(c1,c2)
    array([  3.,   8.,  14.,   8.,   3.])

    """
    # Trim both inputs to series form; in the power basis, polynomial
    # multiplication is exactly discrete convolution of the coefficients.
    [c1, c2] = pu.as_series([c1, c2])
    product = np.convolve(c1, c2)
    # Drop any trailing zero coefficients from the result.
    return pu.trimseq(product)
364
+
365
+
366
def polydiv(c1, c2):
    """
    Divide one polynomial by another.

    Returns the quotient-with-remainder of two polynomials `c1` / `c2`.
    The arguments are sequences of coefficients, from lowest order term
    to highest, e.g., [1,2,3] represents ``1 + 2*x + 3*x**2``.

    Parameters
    ----------
    c1, c2 : array_like
        1-D arrays of polynomial coefficients ordered from low to high.

    Returns
    -------
    [quo, rem] : ndarrays
        Of coefficient series representing the quotient and remainder.

    See Also
    --------
    polyadd, polysub, polymulx, polymul, polypow

    Raises
    ------
    ZeroDivisionError
        If the leading coefficient of `c2` is zero (i.e. `c2` is the
        zero series).

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> c1 = (1,2,3)
    >>> c2 = (3,2,1)
    >>> P.polydiv(c1,c2)
    (array([3.]), array([-8., -4.]))
    >>> P.polydiv(c2,c1)
    (array([ 0.33333333]), array([ 2.66666667,  1.33333333])) # may vary

    """
    # c1, c2 are trimmed copies
    [c1, c2] = pu.as_series([c1, c2])
    if c2[-1] == 0:
        raise ZeroDivisionError()

    # note: this is more efficient than `pu._div(polymul, c1, c2)`
    lc1 = len(c1)
    lc2 = len(c2)
    if lc1 < lc2:
        # Degree of numerator below degree of divisor: quotient is zero,
        # remainder is the numerator itself.
        return c1[:1]*0, c1
    elif lc2 == 1:
        # Dividing by a constant: scale every coefficient, no remainder.
        return c1/c2[-1], c1[:1]*0
    else:
        # Synthetic (long) division performed in place on the copy `c1`.
        # The divisor is normalized to monic form (scaled by `scl`); each
        # pass subtracts the appropriate multiple of the divisor, working
        # from the highest-degree coefficient downward. Statement order
        # matters: `c1` is mutated as the loop descends.
        dlen = lc1 - lc2
        scl = c2[-1]
        c2 = c2[:-1]/scl
        i = dlen
        j = lc1 - 1
        while i >= 0:
            c1[i:j] -= c2*c1[j]
            i -= 1
            j -= 1
        # After the loop, c1[j+1:] holds the (scaled) quotient and
        # c1[:j+1] the remainder.
        return c1[j+1:]/scl, pu.trimseq(c1[:j+1])
422
+
423
+
424
def polypow(c, pow, maxpower=None):
    """Raise a polynomial to a power.

    Returns the polynomial `c` raised to the power `pow`. The argument
    `c` is a sequence of coefficients ordered from low to high. i.e.,
    [1,2,3] is the series  ``1 + 2*x + 3*x**2.``

    Parameters
    ----------
    c : array_like
        1-D array of array of series coefficients ordered from low to
        high degree.
    pow : integer
        Power to which the series will be raised
    maxpower : integer, optional
        Maximum power allowed. This is mainly to limit growth of the series
        to unmanageable size. Default is 16

    Returns
    -------
    coef : ndarray
        Power series of power.

    See Also
    --------
    polyadd, polysub, polymulx, polymul, polydiv

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> P.polypow([1,2,3], 2)
    array([ 1.,  4., 10., 12.,  9.])

    """
    # note: this is more efficient than `pu._pow(polymul, c1, c2)`, as it
    # avoids calling `as_series` repeatedly
    # (np.convolve is polynomial multiplication in the power basis).
    return pu._pow(np.convolve, c, pow, maxpower)
461
+
462
+
463
def polyder(c, m=1, scl=1, axis=0):
    """
    Differentiate a polynomial.

    Returns the polynomial coefficients `c` differentiated `m` times along
    `axis`.  At each iteration the result is multiplied by `scl` (the
    scaling factor is for use in a linear change of variable).  The
    argument `c` is an array of coefficients from low to high degree along
    each axis, e.g., [1,2,3] represents the polynomial ``1 + 2*x + 3*x**2``
    while [[1,2],[1,2]] represents ``1 + 1*x + 2*y + 2*x*y`` if axis=0 is
    ``x`` and axis=1 is ``y``.

    Parameters
    ----------
    c : array_like
        Array of polynomial coefficients. If c is multidimensional the
        different axis correspond to different variables with the degree
        in each axis given by the corresponding index.
    m : int, optional
        Number of derivatives taken, must be non-negative. (Default: 1)
    scl : scalar, optional
        Each differentiation is multiplied by `scl`.  The end result is
        multiplication by ``scl**m``.  This is for use in a linear change
        of variable. (Default: 1)
    axis : int, optional
        Axis over which the derivative is taken. (Default: 0).

        .. versionadded:: 1.7.0

    Returns
    -------
    der : ndarray
        Polynomial coefficients of the derivative.

    See Also
    --------
    polyint

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> c = (1,2,3,4)  # 1 + 2x + 3x**2 + 4x**3
    >>> P.polyder(c)  # (d/dx)(c)  = 2 + 6x + 12x**2
    array([  2.,   6.,  12.])
    >>> P.polyder(c,3)  # (d**3/dx**3)(c) = 24
    array([24.])
    >>> P.polyder(c,scl=-1)  # (d/d(-x))(c) = -2 - 6x - 12x**2
    array([ -2.,  -6., -12.])
    >>> P.polyder(c,2,-1)  # (d**2/d(-x)**2)(c) = 6 + 24x
    array([  6.,  24.])

    """
    c = np.array(c, ndmin=1, copy=True)
    if c.dtype.char in '?bBhHiIlLqQpP':
        # astype fails with NA
        # (adding 0.0 promotes bool/int dtypes to float)
        c = c + 0.0
    cdt = c.dtype
    cnt = pu._deprecate_as_int(m, "the order of derivation")
    iaxis = pu._deprecate_as_int(axis, "the axis")
    if cnt < 0:
        raise ValueError("The order of derivation must be non-negative")
    iaxis = normalize_axis_index(iaxis, c.ndim)

    if cnt == 0:
        return c

    # Bring the differentiation axis to the front so the loops below can
    # index coefficients by degree directly.
    c = np.moveaxis(c, iaxis, 0)
    n = len(c)
    if cnt >= n:
        # Differentiating past the degree annihilates the polynomial.
        c = c[:1]*0
    else:
        for i in range(cnt):
            n = n - 1
            # NOTE: `c *= scl` mutates in place; order relative to the
            # derivative computation below is significant.
            c *= scl
            der = np.empty((n,) + c.shape[1:], dtype=cdt)
            for j in range(n, 0, -1):
                # d/dx of c[j]*x**j contributes j*c[j] to degree j-1.
                der[j - 1] = j*c[j]
            c = der
    c = np.moveaxis(c, 0, iaxis)
    return c
543
+
544
+
545
def polyint(c, m=1, k=[], lbnd=0, scl=1, axis=0):
    # NOTE(review): the mutable default ``k=[]`` is safe here only because
    # `k` is never mutated (it is rebound via ``list(k) + ...`` below).
    """
    Integrate a polynomial.

    Returns the polynomial coefficients `c` integrated `m` times from
    `lbnd` along `axis`.  At each iteration the resulting series is
    **multiplied** by `scl` and an integration constant, `k`, is added.
    The scaling factor is for use in a linear change of variable.  ("Buyer
    beware": note that, depending on what one is doing, one may want `scl`
    to be the reciprocal of what one might expect; for more information,
    see the Notes section below.) The argument `c` is an array of
    coefficients, from low to high degree along each axis, e.g., [1,2,3]
    represents the polynomial ``1 + 2*x + 3*x**2`` while [[1,2],[1,2]]
    represents ``1 + 1*x + 2*y + 2*x*y`` if axis=0 is ``x`` and axis=1 is
    ``y``.

    Parameters
    ----------
    c : array_like
        1-D array of polynomial coefficients, ordered from low to high.
    m : int, optional
        Order of integration, must be positive. (Default: 1)
    k : {[], list, scalar}, optional
        Integration constant(s).  The value of the first integral at zero
        is the first value in the list, the value of the second integral
        at zero is the second value, etc.  If ``k == []`` (the default),
        all constants are set to zero.  If ``m == 1``, a single scalar can
        be given instead of a list.
    lbnd : scalar, optional
        The lower bound of the integral. (Default: 0)
    scl : scalar, optional
        Following each integration the result is *multiplied* by `scl`
        before the integration constant is added. (Default: 1)
    axis : int, optional
        Axis over which the integral is taken. (Default: 0).

        .. versionadded:: 1.7.0

    Returns
    -------
    S : ndarray
        Coefficient array of the integral.

    Raises
    ------
    ValueError
        If ``m < 1``, ``len(k) > m``, ``np.ndim(lbnd) != 0``, or
        ``np.ndim(scl) != 0``.

    See Also
    --------
    polyder

    Notes
    -----
    Note that the result of each integration is *multiplied* by `scl`.  Why
    is this important to note?  Say one is making a linear change of
    variable :math:`u = ax + b` in an integral relative to `x`. Then
    :math:`dx = du/a`, so one will need to set `scl` equal to
    :math:`1/a` - perhaps not what one would have first thought.

    Examples
    --------
    >>> from numpy.polynomial import polynomial as P
    >>> c = (1,2,3)
    >>> P.polyint(c)  # should return array([0, 1, 1, 1])
    array([0.,  1.,  1.,  1.])
    >>> P.polyint(c,3)  # should return array([0, 0, 0, 1/6, 1/12, 1/20])
     array([ 0.        ,  0.        ,  0.        ,  0.16666667,  0.08333333, # may vary
             0.05      ])
    >>> P.polyint(c,k=3)  # should return array([3, 1, 1, 1])
    array([3.,  1.,  1.,  1.])
    >>> P.polyint(c,lbnd=-2)  # should return array([6, 1, 1, 1])
    array([6.,  1.,  1.,  1.])
    >>> P.polyint(c,scl=-2)  # should return array([0, -2, -2, -2])
    array([ 0., -2., -2., -2.])

    """
    c = np.array(c, ndmin=1, copy=True)
    if c.dtype.char in '?bBhHiIlLqQpP':
        # astype doesn't preserve mask attribute.
        c = c + 0.0
    cdt = c.dtype
    if not np.iterable(k):
        # A scalar constant is promoted to a one-element list.
        k = [k]
    cnt = pu._deprecate_as_int(m, "the order of integration")
    iaxis = pu._deprecate_as_int(axis, "the axis")
    if cnt < 0:
        raise ValueError("The order of integration must be non-negative")
    if len(k) > cnt:
        raise ValueError("Too many integration constants")
    if np.ndim(lbnd) != 0:
        raise ValueError("lbnd must be a scalar.")
    if np.ndim(scl) != 0:
        raise ValueError("scl must be a scalar.")
    iaxis = normalize_axis_index(iaxis, c.ndim)

    if cnt == 0:
        return c

    # Pad the constants with zeros so there is one per integration pass.
    k = list(k) + [0]*(cnt - len(k))
    c = np.moveaxis(c, iaxis, 0)
    for i in range(cnt):
        n = len(c)
        # In-place scale before building the antiderivative; order matters.
        c *= scl
        if n == 1 and np.all(c[0] == 0):
            # Integrating the zero series just adds the constant.
            c[0] += k[i]
        else:
            tmp = np.empty((n + 1,) + c.shape[1:], dtype=cdt)
            tmp[0] = c[0]*0
            tmp[1] = c[0]
            for j in range(1, n):
                # Antiderivative of c[j]*x**j is c[j]/(j+1) * x**(j+1).
                tmp[j + 1] = c[j]/(j + 1)
            # Choose the constant so the integral is k[i] at x = lbnd.
            tmp[0] += k[i] - polyval(lbnd, tmp)
            c = tmp
    c = np.moveaxis(c, 0, iaxis)
    return c
662
+
663
+
664
def polyval(x, c, tensor=True):
    """
    Evaluate a polynomial at points x.

    If `c` is of length `n + 1`, this function returns the value

    .. math:: p(x) = c_0 + c_1 * x + ... + c_n * x^n

    The parameter `x` is converted to an array only if it is a tuple or a
    list, otherwise it is treated as a scalar. In either case, either `x`
    or its elements must support multiplication and addition both with
    themselves and with the elements of `c`.

    If `c` is a 1-D array, then `p(x)` will have the same shape as `x`.  If
    `c` is multidimensional, then the shape of the result depends on the
    value of `tensor`. If `tensor` is true the shape will be c.shape[1:] +
    x.shape. If `tensor` is false the shape will be c.shape[1:]. Note that
    scalars have shape (,).

    Trailing zeros in the coefficients will be used in the evaluation, so
    they should be avoided if efficiency is a concern.

    Parameters
    ----------
    x : array_like, compatible object
        If `x` is a list or tuple, it is converted to an ndarray, otherwise
        it is left unchanged and treated as a scalar. In either case, `x`
        or its elements must support addition and multiplication with
        with themselves and with the elements of `c`.
    c : array_like
        Array of coefficients ordered so that the coefficients for terms of
        degree n are contained in c[n]. If `c` is multidimensional the
        remaining indices enumerate multiple polynomials. In the two
        dimensional case the coefficients may be thought of as stored in
        the columns of `c`.
    tensor : boolean, optional
        If True, the shape of the coefficient array is extended with ones
        on the right, one for each dimension of `x`. Scalars have dimension 0
        for this action. The result is that every column of coefficients in
        `c` is evaluated for every element of `x`. If False, `x` is broadcast
        over the columns of `c` for the evaluation.  This keyword is useful
        when `c` is multidimensional. The default value is True.

        .. versionadded:: 1.7.0

    Returns
    -------
    values : ndarray, compatible object
        The shape of the returned array is described above.

    See Also
    --------
    polyval2d, polygrid2d, polyval3d, polygrid3d

    Notes
    -----
    The evaluation uses Horner's method.

    Examples
    --------
    >>> from numpy.polynomial.polynomial import polyval
    >>> polyval(1, [1,2,3])
    6.0
    >>> a = np.arange(4).reshape(2,2)
    >>> a
    array([[0, 1],
           [2, 3]])
    >>> polyval(a, [1,2,3])
    array([[ 1.,   6.],
           [17.,  34.]])
    >>> coef = np.arange(4).reshape(2,2) # multidimensional coefficients
    >>> coef
    array([[0, 1],
           [2, 3]])
    >>> polyval([1,2], coef, tensor=True)
    array([[2.,  4.],
           [4.,  7.]])
    >>> polyval([1,2], coef, tensor=False)
    array([2.,  7.])

    """
    # NOTE(review): ``copy=False`` here follows the NumPy 1.x meaning
    # ("copy only if needed"); under NumPy 2.x semantics it would raise
    # when a copy is required — this file is vendored against 1.x.
    c = np.array(c, ndmin=1, copy=False)
    if c.dtype.char in '?bBhHiIlLqQpP':
        # astype fails with NA
        c = c + 0.0
    if isinstance(x, (tuple, list)):
        x = np.asarray(x)
    if isinstance(x, np.ndarray) and tensor:
        # Append one axis per dimension of x so each coefficient column is
        # evaluated at every element of x.
        c = c.reshape(c.shape + (1,)*x.ndim)

    # Horner's method, starting from the highest-degree coefficient.
    # ``+ x*0`` forces result dtype/shape promotion with x.
    c0 = c[-1] + x*0
    for i in range(2, len(c) + 1):
        c0 = c[-i] + c0*x
    return c0
758
+
759
+
760
def polyvalfromroots(x, r, tensor=True):
    """
    Evaluate a polynomial specified by its roots at points x.

    If `r` is of length `N`, this function returns the value

    .. math:: p(x) = \\prod_{n=1}^{N} (x - r_n)

    The parameter `x` is converted to an array only if it is a tuple or a
    list, otherwise it is treated as a scalar. In either case, either `x`
    or its elements must support multiplication and addition both with
    themselves and with the elements of `r`.

    If `r` is a 1-D array, then `p(x)` will have the same shape as `x`.  If `r`
    is multidimensional, then the shape of the result depends on the value of
    `tensor`. If `tensor` is ``True`` the shape will be r.shape[1:] + x.shape;
    that is, each polynomial is evaluated at every value of `x`. If `tensor` is
    ``False``, the shape will be r.shape[1:]; that is, each polynomial is
    evaluated only for the corresponding broadcast value of `x`. Note that
    scalars have shape (,).

    .. versionadded:: 1.12

    Parameters
    ----------
    x : array_like, compatible object
        If `x` is a list or tuple, it is converted to an ndarray, otherwise
        it is left unchanged and treated as a scalar. In either case, `x`
        or its elements must support addition and multiplication with
        with themselves and with the elements of `r`.
    r : array_like
        Array of roots. If `r` is multidimensional the first index is the
        root index, while the remaining indices enumerate multiple
        polynomials. For instance, in the two dimensional case the roots
        of each polynomial may be thought of as stored in the columns of `r`.
    tensor : boolean, optional
        If True, the shape of the roots array is extended with ones on the
        right, one for each dimension of `x`. Scalars have dimension 0 for this
        action. The result is that every column of coefficients in `r` is
        evaluated for every element of `x`. If False, `x` is broadcast over the
        columns of `r` for the evaluation.  This keyword is useful when `r` is
        multidimensional. The default value is True.

    Returns
    -------
    values : ndarray, compatible object
        The shape of the returned array is described above.

    See Also
    --------
    polyroots, polyfromroots, polyval

    Examples
    --------
    >>> from numpy.polynomial.polynomial import polyvalfromroots
    >>> polyvalfromroots(1, [1,2,3])
    0.0
    >>> a = np.arange(4).reshape(2,2)
    >>> a
    array([[0, 1],
           [2, 3]])
    >>> polyvalfromroots(a, [-1, 0, 1])
    array([[-0.,  0.],
           [ 6., 24.]])
    >>> r = np.arange(-2, 2).reshape(2,2) # multidimensional coefficients
    >>> r # each column of r defines one polynomial
    array([[-2, -1],
           [ 0,  1]])
    >>> b = [-2, 1]
    >>> polyvalfromroots(b, r, tensor=True)
    array([[-0.,  3.],
           [ 3., 0.]])
    >>> polyvalfromroots(b, r, tensor=False)
    array([-0.,  0.])
    """
    # NOTE(review): ``copy=False`` follows NumPy 1.x semantics ("copy only
    # if needed"); see polyval for the same caveat.
    r = np.array(r, ndmin=1, copy=False)
    if r.dtype.char in '?bBhHiIlLqQpP':
        # Promote bool/integer roots to float for the arithmetic below.
        r = r.astype(np.double)
    if isinstance(x, (tuple, list)):
        x = np.asarray(x)
    if isinstance(x, np.ndarray):
        if tensor:
            # Append one axis per dimension of x so every polynomial
            # (column of r) is evaluated at every element of x.
            r = r.reshape(r.shape + (1,)*x.ndim)
        elif x.ndim >= r.ndim:
            raise ValueError("x.ndim must be < r.ndim when tensor == False")
    # p(x) is the product of the (x - r_n) factors along the root axis.
    return np.prod(x - r, axis=0)
846
+
847
+
848
def polyval2d(x, y, c):
    """
    Evaluate a 2-D polynomial at points (x, y).

    This function returns the value

    .. math:: p(x,y) = \\sum_{i,j} c_{i,j} * x^i * y^j

    The parameters `x` and `y` are converted to arrays only if they are
    tuples or a lists, otherwise they are treated as a scalars and they
    must have the same shape after conversion. In either case, either `x`
    and `y` or their elements must support multiplication and addition both
    with themselves and with the elements of `c`.

    If `c` has fewer than two dimensions, ones are implicitly appended to
    its shape to make it 2-D. The shape of the result will be c.shape[2:] +
    x.shape.

    Parameters
    ----------
    x, y : array_like, compatible objects
        The two dimensional series is evaluated at the points `(x, y)`,
        where `x` and `y` must have the same shape. If `x` or `y` is a list
        or tuple, it is first converted to an ndarray, otherwise it is left
        unchanged and, if it isn't an ndarray, it is treated as a scalar.
    c : array_like
        Array of coefficients ordered so that the coefficient of the term
        of multi-degree i,j is contained in `c[i,j]`. If `c` has
        dimension greater than two the remaining indices enumerate multiple
        sets of coefficients.

    Returns
    -------
    values : ndarray, compatible object
        The values of the two dimensional polynomial at points formed with
        pairs of corresponding values from `x` and `y`.

    See Also
    --------
    polyval, polygrid2d, polyval3d, polygrid3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Shared N-D evaluation helper, fixed to this module's 1-D evaluator;
    # it validates that x and y are shape-compatible.
    return pu._valnd(polyval, c, x, y)
896
+
897
+
898
def polygrid2d(x, y, c):
    """
    Evaluate a 2-D polynomial on the Cartesian product of x and y.

    This function returns the values:

    .. math:: p(a,b) = \\sum_{i,j} c_{i,j} * a^i * b^j

    where the points `(a, b)` consist of all pairs formed by taking
    `a` from `x` and `b` from `y`.  The resulting points form a grid
    with `x` in the first dimension and `y` in the second.

    `x` and `y` are converted to arrays only if they are tuples or
    lists; otherwise they are treated as scalars.  In either case,
    `x` and `y` (or their elements) must support multiplication and
    addition both with themselves and with the elements of `c`.

    If `c` has fewer than two dimensions, ones are implicitly appended
    to its shape to make it 2-D.  The shape of the result will be
    c.shape[2:] + x.shape + y.shape.

    Parameters
    ----------
    x, y : array_like, compatible objects
        The two dimensional series is evaluated at the points in the
        Cartesian product of `x` and `y`.  If `x` or `y` is a list or
        tuple, it is first converted to an ndarray; otherwise it is
        left unchanged and, if it isn't an ndarray, it is treated as a
        scalar.
    c : array_like
        Array of coefficients ordered so that the coefficients for
        terms of degree i,j are contained in ``c[i,j]``.  If `c` has
        dimension greater than two the remaining indices enumerate
        multiple sets of coefficients.

    Returns
    -------
    values : ndarray, compatible object
        The values of the two dimensional polynomial at points in the
        Cartesian product of `x` and `y`.

    See Also
    --------
    polyval, polyval2d, polyval3d, polygrid3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # The shared grid evaluator applies polyval once per axis,
    # building up the outer-product grid as it goes.
    return pu._gridnd(polyval, c, x, y)
949
+
950
+
951
def polyval3d(x, y, z, c):
    """
    Evaluate a 3-D polynomial at points (x, y, z).

    This function returns the values:

    .. math:: p(x,y,z) = \\sum_{i,j,k} c_{i,j,k} * x^i * y^j * z^k

    `x`, `y`, and `z` are converted to arrays only if they are tuples
    or lists; otherwise they are treated as scalars, and after any
    conversion they must have the same shape.  In either case, `x`,
    `y`, and `z` (or their elements) must support multiplication and
    addition both with themselves and with the elements of `c`.

    If `c` has fewer than 3 dimensions, ones are implicitly appended
    to its shape to make it 3-D.  The shape of the result will be
    c.shape[3:] + x.shape.

    Parameters
    ----------
    x, y, z : array_like, compatible object
        The three dimensional series is evaluated at the points
        `(x, y, z)`, where `x`, `y`, and `z` must have the same shape.
        If any of `x`, `y`, or `z` is a list or tuple, it is first
        converted to an ndarray; otherwise it is left unchanged and
        if it isn't an ndarray it is treated as a scalar.
    c : array_like
        Array of coefficients ordered so that the coefficient of the
        term of multi-degree i,j,k is contained in ``c[i,j,k]``.  If
        `c` has dimension greater than 3 the remaining indices
        enumerate multiple sets of coefficients.

    Returns
    -------
    values : ndarray, compatible object
        The values of the multidimensional polynomial on points formed
        with triples of corresponding values from `x`, `y`, and `z`.

    See Also
    --------
    polyval, polyval2d, polygrid2d, polygrid3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Same machinery as polyval2d, with one extra axis of points.
    return pu._valnd(polyval, c, x, y, z)
1000
+
1001
+
1002
def polygrid3d(x, y, z, c):
    """
    Evaluate a 3-D polynomial on the Cartesian product of x, y and z.

    This function returns the values:

    .. math:: p(a,b,c) = \\sum_{i,j,k} c_{i,j,k} * a^i * b^j * c^k

    where the points `(a, b, c)` consist of all triples formed by taking
    `a` from `x`, `b` from `y`, and `c` from `z`. The resulting points form
    a grid with `x` in the first dimension, `y` in the second, and `z` in
    the third.

    The parameters `x`, `y`, and `z` are converted to arrays only if they
    are tuples or a lists, otherwise they are treated as a scalars. In
    either case, either `x`, `y`, and `z` or their elements must support
    multiplication and addition both with themselves and with the elements
    of `c`.

    If `c` has fewer than three dimensions, ones are implicitly appended to
    its shape to make it 3-D. The shape of the result will be c.shape[3:] +
    x.shape + y.shape + z.shape.

    Parameters
    ----------
    x, y, z : array_like, compatible objects
        The three dimensional series is evaluated at the points in the
        Cartesian product of `x`, `y`, and `z`. If `x`,`y`, or `z` is a
        list or tuple, it is first converted to an ndarray, otherwise it is
        left unchanged and, if it isn't an ndarray, it is treated as a
        scalar.
    c : array_like
        Array of coefficients ordered so that the coefficient of the term
        of multi-degree i,j,k is contained in ``c[i,j,k]``. If `c` has
        dimension greater than three the remaining indices enumerate
        multiple sets of coefficients.

    Returns
    -------
    values : ndarray, compatible object
        The values of the three dimensional polynomial at points in the
        Cartesian product of `x`, `y`, and `z`.

    See Also
    --------
    polyval, polyval2d, polygrid2d, polyval3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Docstring fixed: the `c` parameter and Returns sections previously
    # described the 2-D case (copy-paste from polygrid2d). Code unchanged.
    return pu._gridnd(polyval, c, x, y, z)
1056
+
1057
+
1058
def polyvander(x, deg):
    """Vandermonde matrix of given degree.

    Returns the Vandermonde matrix of degree `deg` and sample points
    `x`. The Vandermonde matrix is defined by

    .. math:: V[..., i] = x^i,

    where `0 <= i <= deg`. The leading indices of `V` index the
    elements of `x` and the last index is the power of `x`.

    If `c` is a 1-D array of coefficients of length `n + 1` and `V` is
    the matrix ``V = polyvander(x, n)``, then ``np.dot(V, c)`` and
    ``polyval(x, c)`` are the same up to roundoff.  This equivalence is
    useful both for least squares fitting and for the evaluation of a
    large number of polynomials of the same degree and sample points.

    Parameters
    ----------
    x : array_like
        Array of points. The dtype is converted to float64 or
        complex128 depending on whether any of the elements are
        complex. If `x` is scalar it is converted to a 1-D array.
    deg : int
        Degree of the resulting matrix.

    Returns
    -------
    vander : ndarray.
        The Vandermonde matrix. The shape of the returned matrix is
        ``x.shape + (deg + 1,)``, where the last index is the power of
        `x`. The dtype will be the same as the converted `x`.

    See Also
    --------
    polyvander2d, polyvander3d

    """
    degree = pu._deprecate_as_int(deg, "deg")
    if degree < 0:
        raise ValueError("deg must be non-negative")

    # Adding 0.0 promotes integer/boolean input to float (or keeps
    # complex input complex) without changing the values.
    x = np.array(x, copy=False, ndmin=1) + 0.0
    vander = np.empty((degree + 1,) + x.shape, dtype=x.dtype)
    # x*0 + 1 gives an array of ones with the shape and dtype of x.
    vander[0] = x*0 + 1
    if degree > 0:
        vander[1] = x
        # Each row is the previous one multiplied by x: vander[k] = x**k.
        for k in range(2, degree + 1):
            vander[k] = vander[k - 1]*x
    # Powers were accumulated along axis 0; move them to the last axis.
    return np.moveaxis(vander, 0, -1)
1110
+
1111
+
1112
def polyvander2d(x, y, deg):
    """Pseudo-Vandermonde matrix of given degrees.

    Returns the pseudo-Vandermonde matrix of degrees `deg` and sample
    points `(x, y)`. The pseudo-Vandermonde matrix is defined by

    .. math:: V[..., (deg[1] + 1)*i + j] = x^i * y^j,

    where `0 <= i <= deg[0]` and `0 <= j <= deg[1]`. The leading indices of
    `V` index the points `(x, y)` and the last index encodes the powers of
    `x` and `y`.

    If ``V = polyvander2d(x, y, [xdeg, ydeg])``, then the columns of `V`
    correspond to the elements of a 2-D coefficient array `c` of shape
    (xdeg + 1, ydeg + 1) in the order

    .. math:: c_{00}, c_{01}, c_{02} ... , c_{10}, c_{11}, c_{12} ...

    and ``np.dot(V, c.flat)`` and ``polyval2d(x, y, c)`` will be the same
    up to roundoff. This equivalence is useful both for least squares
    fitting and for the evaluation of a large number of 2-D polynomials
    of the same degrees and sample points.

    Parameters
    ----------
    x, y : array_like
        Arrays of point coordinates, all of the same shape. The dtypes
        will be converted to either float64 or complex128 depending on
        whether any of the elements are complex. Scalars are converted to
        1-D arrays.
    deg : list of ints
        List of maximum degrees of the form [x_deg, y_deg].

    Returns
    -------
    vander2d : ndarray
        The shape of the returned matrix is ``x.shape + (order,)``, where
        :math:`order = (deg[0]+1)*(deg[1]+1)`. The dtype will be the same
        as the converted `x` and `y`.

    See Also
    --------
    polyvander, polyvander3d, polyval2d, polyval3d

    """
    # Docstring fixed: the order formula previously read
    # ``(deg[0]+1)*(deg([1]+1)`` (unbalanced, mistyped). Code unchanged.
    return pu._vander_nd_flat((polyvander, polyvander), (x, y), deg)
1158
+
1159
+
1160
def polyvander3d(x, y, z, deg):
    """Pseudo-Vandermonde matrix of given degrees.

    Returns the pseudo-Vandermonde matrix of degrees `deg` and sample
    points `(x, y, z)`. If `l, m, n` are the given degrees in `x, y, z`,
    then The pseudo-Vandermonde matrix is defined by

    .. math:: V[..., (m+1)(n+1)i + (n+1)j + k] = x^i * y^j * z^k,

    where `0 <= i <= l`, `0 <= j <= m`, and `0 <= j <= n`.  The leading
    indices of `V` index the points `(x, y, z)` and the last index encodes
    the powers of `x`, `y`, and `z`.

    If ``V = polyvander3d(x, y, z, [xdeg, ydeg, zdeg])``, then the columns
    of `V` correspond to the elements of a 3-D coefficient array `c` of
    shape (xdeg + 1, ydeg + 1, zdeg + 1) in the order

    .. math:: c_{000}, c_{001}, c_{002},... , c_{010}, c_{011}, c_{012},...

    and ``np.dot(V, c.flat)`` and ``polyval3d(x, y, z, c)`` will be the
    same up to roundoff. This equivalence is useful both for least squares
    fitting and for the evaluation of a large number of 3-D polynomials
    of the same degrees and sample points.

    Parameters
    ----------
    x, y, z : array_like
        Arrays of point coordinates, all of the same shape. The dtypes will
        be converted to either float64 or complex128 depending on whether
        any of the elements are complex. Scalars are converted to 1-D
        arrays.
    deg : list of ints
        List of maximum degrees of the form [x_deg, y_deg, z_deg].

    Returns
    -------
    vander3d : ndarray
        The shape of the returned matrix is ``x.shape + (order,)``, where
        :math:`order = (deg[0]+1)*(deg[1]+1)*(deg[2]+1)`. The dtype will
        be the same as the converted `x`, `y`, and `z`.

    See Also
    --------
    polyvander, polyvander2d, polyval2d, polyval3d

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # Docstring fixed: the order formula previously read
    # ``(deg[0]+1)*(deg([1]+1)*(deg[2]+1)`` (unbalanced, mistyped), and
    # the See Also list referenced polyvander3d itself instead of
    # polyvander2d. Code unchanged.
    return pu._vander_nd_flat((polyvander, polyvander, polyvander), (x, y, z), deg)
1212
+
1213
+
1214
def polyfit(x, y, deg, rcond=None, full=False, w=None):
    """
    Least-squares fit of a polynomial to data.

    Return the coefficients of a polynomial of degree `deg` that is the
    least squares fit to the data values `y` given at points `x`.  If
    `y` is 1-D the returned coefficients will also be 1-D.  If `y` is
    2-D multiple fits are done, one for each column of `y`, and the
    resulting coefficients are stored in the corresponding columns of a
    2-D return.  The fitted polynomial(s) are in the form

    .. math:: p(x) = c_0 + c_1 * x + ... + c_n * x^n,

    where `n` is `deg`.

    Parameters
    ----------
    x : array_like, shape (`M`,)
        x-coordinates of the `M` sample (data) points ``(x[i], y[i])``.
    y : array_like, shape (`M`,) or (`M`, `K`)
        y-coordinates of the sample points.  Several sets of sample
        points sharing the same x-coordinates can be (independently)
        fit with one call to `polyfit` by passing in for `y` a 2-D
        array that contains one data set per column.
    deg : int or 1-D array_like
        Degree(s) of the fitting polynomials.  If `deg` is a single
        integer all terms up to and including the `deg`'th term are
        included in the fit.  For NumPy versions >= 1.11.0 a list of
        integers specifying the degrees of the terms to include may be
        used instead.
    rcond : float, optional
        Relative condition number of the fit.  Singular values smaller
        than `rcond`, relative to the largest singular value, will be
        ignored.  The default value is ``len(x)*eps``, where `eps` is
        the relative precision of the platform's float type, about
        2e-16 in most cases.
    full : bool, optional
        Switch determining the nature of the return value.  When
        ``False`` (the default) just the coefficients are returned;
        when ``True``, diagnostic information from the singular value
        decomposition (used to solve the fit's matrix equation) is
        also returned.
    w : array_like, shape (`M`,), optional
        Weights.  If not None, the weight ``w[i]`` applies to the
        unsquared residual ``y[i] - y_hat[i]`` at ``x[i]``.  Ideally
        the weights are chosen so that the errors of the products
        ``w[i]*y[i]`` all have the same variance.  When using
        inverse-variance weighting, use ``w[i] = 1/sigma(y[i])``.  The
        default value is None.

        .. versionadded:: 1.5.0

    Returns
    -------
    coef : ndarray, shape (`deg` + 1,) or (`deg` + 1, `K`)
        Polynomial coefficients ordered from low to high.  If `y` was
        2-D, the coefficients in column `k` of `coef` represent the
        polynomial fit to the data in `y`'s `k`-th column.

    [residuals, rank, singular_values, rcond] : list
        These values are only returned if ``full == True``

        - residuals -- sum of squared residuals of the least squares fit
        - rank -- the numerical rank of the scaled Vandermonde matrix
        - singular_values -- singular values of the scaled Vandermonde matrix
        - rcond -- value of `rcond`.

        For more details, see `numpy.linalg.lstsq`.

    Raises
    ------
    RankWarning
        Raised if the matrix in the least-squares fit is rank
        deficient.  The warning is only raised if ``full == False``.
        The warnings can be turned off by:

        >>> import warnings
        >>> warnings.simplefilter('ignore', np.RankWarning)

    See Also
    --------
    numpy.polynomial.chebyshev.chebfit
    numpy.polynomial.legendre.legfit
    numpy.polynomial.laguerre.lagfit
    numpy.polynomial.hermite.hermfit
    numpy.polynomial.hermite_e.hermefit
    polyval : Evaluates a polynomial.
    polyvander : Vandermonde matrix for powers.
    numpy.linalg.lstsq : Computes a least-squares fit from the matrix.
    scipy.interpolate.UnivariateSpline : Computes spline fits.

    Notes
    -----
    The solution is the coefficients of the polynomial `p` that
    minimizes the sum of the weighted squared errors

    .. math:: E = \\sum_j w_j^2 * |y_j - p(x_j)|^2,

    where the :math:`w_j` are the weights.  This problem is solved by
    setting up the (typically) over-determined matrix equation

    .. math:: V(x) * c = w * y,

    where `V` is the weighted pseudo Vandermonde matrix of `x`, `c`
    are the coefficients to be solved for, `w` are the weights, and
    `y` are the observed values.  This equation is then solved using
    the singular value decomposition of `V`.

    If some of the singular values of `V` are so small that they are
    neglected (and `full` == ``False``), a `RankWarning` will be
    raised.  This means that the coefficient values may be poorly
    determined.  Fitting to a lower order polynomial will usually get
    rid of the warning.  The `rcond` parameter can also be set to a
    value smaller than its default, but the resulting fit may be
    spurious and have large contributions from roundoff error.

    Polynomial fits using double precision tend to "fail" at about
    (polynomial) degree 20.  Fits using Chebyshev or Legendre series
    are generally better conditioned, but much can still depend on the
    distribution of the sample points and the smoothness of the data.
    If the quality of the fit is inadequate, splines may be a good
    alternative.

    Examples
    --------
    >>> np.random.seed(123)
    >>> from numpy.polynomial import polynomial as P
    >>> x = np.linspace(-1,1,51)  # x "data": [-1, -0.96, ..., 0.96, 1]
    >>> y = x**3 - x + np.random.randn(len(x))  # x^3 - x + Gaussian noise
    >>> c, stats = P.polyfit(x,y,3,full=True)
    >>> c  # c[0], c[2] should be approx. 0, c[1] approx. -1, c[3] approx. 1
    array([ 0.01909725, -1.30598256, -0.00577963,  1.02644286])  # may vary

    """
    # The basis-independent machinery (scaling, lstsq, rank warnings)
    # lives in polyutils._fit; the power-series behavior comes entirely
    # from supplying polyvander as the basis matrix builder.
    return pu._fit(polyvander, x, y, deg, rcond, full, w)
1363
+
1364
+
1365
def polycompanion(c):
    """
    Return the companion matrix of c.

    The companion matrix for power series cannot be made symmetric by
    scaling the basis, so this function differs from those for the
    orthogonal polynomials.

    Parameters
    ----------
    c : array_like
        1-D array of polynomial coefficients ordered from low to high
        degree.

    Returns
    -------
    mat : ndarray
        Companion matrix of dimensions (deg, deg).

    Notes
    -----

    .. versionadded:: 1.7.0

    """
    # c is a trimmed copy
    [c] = pu.as_series([c])
    if len(c) < 2:
        raise ValueError('Series must have maximum degree of at least 1.')
    if len(c) == 2:
        # Degree 1: the companion "matrix" is 1x1 and holds the root.
        return np.array([[-c[0]/c[1]]])

    n = len(c) - 1
    mat = np.zeros((n, n), dtype=c.dtype)
    # Place ones on the subdiagonal by striding through the flat view.
    mat.reshape(-1)[n::n + 1] = 1
    # Last column carries the (negated, monic-normalized) coefficients.
    mat[:, -1] -= c[:-1]/c[-1]
    return mat
1403
+
1404
+
1405
def polyroots(c):
    """
    Compute the roots of a polynomial.

    Return the roots (a.k.a. "zeros") of the polynomial

    .. math:: p(x) = \\sum_i c[i] * x^i.

    Parameters
    ----------
    c : 1-D array_like
        1-D array of polynomial coefficients.

    Returns
    -------
    out : ndarray
        Array of the roots of the polynomial. If all the roots are real,
        then `out` is also real, otherwise it is complex.

    See Also
    --------
    numpy.polynomial.chebyshev.chebroots
    numpy.polynomial.legendre.legroots
    numpy.polynomial.laguerre.lagroots
    numpy.polynomial.hermite.hermroots
    numpy.polynomial.hermite_e.hermeroots

    Notes
    -----
    The root estimates are obtained as the eigenvalues of the companion
    matrix, Roots far from the origin of the complex plane may have large
    errors due to the numerical instability of the power series for such
    values. Roots with multiplicity greater than 1 will also show larger
    errors as the value of the series near such points is relatively
    insensitive to errors in the roots. Isolated roots near the origin can
    be improved by a few iterations of Newton's method.

    Examples
    --------
    >>> import numpy.polynomial.polynomial as poly
    >>> poly.polyroots(poly.polyfromroots((-1,0,1)))
    array([-1.,  0.,  1.])
    >>> poly.polyroots(poly.polyfromroots((-1,0,1))).dtype
    dtype('float64')
    >>> j = complex(0,1)
    >>> poly.polyroots(poly.polyfromroots((-j,0,j)))
    array([  0.00000000e+00+0.j,   0.00000000e+00+1.j,   2.77555756e-17-1.j]) # may vary

    """
    # c is a trimmed copy
    [c] = pu.as_series([c])
    if len(c) <= 1:
        # A constant has no roots.
        return np.array([], dtype=c.dtype)
    if len(c) == 2:
        # Linear case solved directly.
        return np.array([-c[0]/c[1]])

    # Reversing both axes of the companion matrix reduces error.
    companion = polycompanion(c)[::-1, ::-1]
    roots = la.eigvals(companion)
    roots.sort()
    return roots
1466
+
1467
+
1468
+ #
1469
+ # polynomial class
1470
+ #
1471
+
1472
class Polynomial(ABCPolyBase):
    """A power series class.

    The Polynomial class provides the standard Python numerical methods
    '+', '-', '*', '//', '%', 'divmod', '**', and '()' as well as the
    attributes and methods listed in the `ABCPolyBase` documentation.

    Parameters
    ----------
    coef : array_like
        Polynomial coefficients in order of increasing degree, i.e.,
        ``(1, 2, 3)`` give ``1 + 2*x + 3*x**2``.
    domain : (2,) array_like, optional
        Domain to use. The interval ``[domain[0], domain[1]]`` is mapped
        to the interval ``[window[0], window[1]]`` by shifting and scaling.
        The default value is [-1, 1].
    window : (2,) array_like, optional
        Window, see `domain` for its use. The default value is [-1, 1].

        .. versionadded:: 1.6.0
    symbol : str, optional
        Symbol used to represent the independent variable in string
        representations of the polynomial expression, e.g. for printing.
        The symbol must be a valid Python identifier. Default value is 'x'.

        .. versionadded:: 1.24

    """
    # Virtual Functions
    # ABCPolyBase implements the public API in terms of these
    # basis-specific hooks; for the power basis they are simply the
    # module-level poly* functions.
    _add = staticmethod(polyadd)
    _sub = staticmethod(polysub)
    _mul = staticmethod(polymul)
    _div = staticmethod(polydiv)
    _pow = staticmethod(polypow)
    _val = staticmethod(polyval)
    _int = staticmethod(polyint)
    _der = staticmethod(polyder)
    _fit = staticmethod(polyfit)
    _line = staticmethod(polyline)
    _roots = staticmethod(polyroots)
    _fromroots = staticmethod(polyfromroots)

    # Virtual properties
    # Both default to polydomain; for the power basis the default
    # domain and window coincide.
    domain = np.array(polydomain)
    window = np.array(polydomain)
    # No basis name: terms are printed as bare powers of the symbol
    # (e.g. "x**2") rather than as named basis functions.
    basis_name = None

    @classmethod
    def _str_term_unicode(cls, i, arg_str):
        # `i` is the exponent already formatted as a string; the first
        # power is rendered without a superscript.
        if i == '1':
            return f"·{arg_str}"
        else:
            return f"·{arg_str}{i.translate(cls._superscript_mapping)}"

    @staticmethod
    def _str_term_ascii(i, arg_str):
        # ASCII rendering of a single term: plain "x" for the first
        # power, "x**i" otherwise. `i` is a string here too.
        if i == '1':
            return f" {arg_str}"
        else:
            return f" {arg_str}**{i}"

    @staticmethod
    def _repr_latex_term(i, arg_str, needs_parens):
        # LaTeX rendering of the basis term x^i; here `i` is compared
        # against integers (0 and 1 need no exponent).
        if needs_parens:
            arg_str = rf"\left({arg_str}\right)"
        if i == 0:
            return '1'
        elif i == 1:
            return arg_str
        else:
            return f"{arg_str}^{{{i}}}"
.venv/lib/python3.11/site-packages/numpy/polynomial/polyutils.py ADDED
@@ -0,0 +1,789 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility classes and functions for the polynomial modules.
3
+
4
+ This module provides: error and warning objects; a polynomial base class;
5
+ and some routines used in both the `polynomial` and `chebyshev` modules.
6
+
7
+ Warning objects
8
+ ---------------
9
+
10
+ .. autosummary::
11
+ :toctree: generated/
12
+
13
+ RankWarning raised in least-squares fit for rank-deficient matrix.
14
+
15
+ Functions
16
+ ---------
17
+
18
+ .. autosummary::
19
+ :toctree: generated/
20
+
21
+ as_series convert list of array_likes into 1-D arrays of common type.
22
+ trimseq remove trailing zeros.
23
+ trimcoef remove small trailing coefficients.
24
+ getdomain return the domain appropriate for a given set of abscissae.
25
+ mapdomain maps points between domains.
26
+ mapparms parameters of the linear map between domains.
27
+
28
+ """
29
+ import operator
30
+ import functools
31
+ import warnings
32
+
33
+ import numpy as np
34
+
35
+ from numpy.core.multiarray import dragon4_positional, dragon4_scientific
36
+ from numpy.core.umath import absolute
37
+
38
+ __all__ = [
39
+ 'RankWarning', 'as_series', 'trimseq',
40
+ 'trimcoef', 'getdomain', 'mapdomain', 'mapparms',
41
+ 'format_float']
42
+
43
+ #
44
+ # Warnings and Exceptions
45
+ #
46
+
47
class RankWarning(UserWarning):
    """Issued by the least-squares fitting routines (e.g. `polyfit`)
    when the design matrix is rank deficient.

    The docstring previously attributed this warning to ``chebfit``
    only, but it is the shared warning for all the polynomial fitting
    functions in this package.
    """
    pass
50
+
51
+ #
52
+ # Helper functions to convert inputs to 1-D arrays
53
+ #
54
def trimseq(seq):
    """Remove trailing zeros from a sequence of series coefficients.

    Only exact trailing zeros are removed; "small" nonzero values are
    kept (see `trimcoef` for tolerance-based trimming).

    Parameters
    ----------
    seq : sequence
        Sequence of Poly series coefficients.

    Returns
    -------
    series : sequence
        Subsequence with trailing zeros removed. If every element is
        zero, a length-one subsequence containing the first element is
        returned; an empty input is returned unchanged. The returned
        sequence may or may not be a view.

    Notes
    -----
    Do not lose the type info if the sequence contains unknown objects.

    """
    # Empty input is passed through (the docstring previously claimed
    # this case fails, but it is handled explicitly here).
    if len(seq) == 0:
        return seq
    else:
        # Scan from the high-order end for the last nonzero element;
        # if none is found, i ends at 0 and a single element is kept.
        for i in range(len(seq) - 1, -1, -1):
            if seq[i] != 0:
                break
        return seq[:i+1]
82
+
83
+
84
def as_series(alist, trim=True):
    """
    Return argument as a list of 1-d arrays.

    The returned list contains array(s) of dtype double, complex double, or
    object. A 1-d argument of shape ``(N,)`` is parsed into ``N`` arrays of
    size one; a 2-d argument of shape ``(M,N)`` is parsed into ``M`` arrays
    of size ``N`` (i.e., is "parsed by row"); and a higher dimensional array
    raises a Value Error if it is not first reshaped into either a 1-d or 2-d
    array.

    Parameters
    ----------
    alist : array_like
        A 1- or 2-d array_like
    trim : boolean, optional
        When True, trailing zeros are removed from the inputs.
        When False, the inputs are passed through intact.

    Returns
    -------
    [a1, a2,...] : list of 1-D arrays
        A copy of the input data as a list of 1-d arrays.

    Raises
    ------
    ValueError
        Raised when `as_series` cannot convert its input to 1-d arrays, or at
        least one of the resulting arrays is empty.

    Examples
    --------
    >>> from numpy.polynomial import polyutils as pu
    >>> a = np.arange(4)
    >>> pu.as_series(a)
    [array([0.]), array([1.]), array([2.]), array([3.])]
    >>> b = np.arange(6).reshape((2,3))
    >>> pu.as_series(b)
    [array([0., 1., 2.]), array([3., 4., 5.])]

    >>> pu.as_series((1, np.arange(3), np.arange(2, dtype=np.float16)))
    [array([1.]), array([0., 1., 2.]), array([0., 1.])]

    >>> pu.as_series([2, [1.1, 0.]])
    [array([2.]), array([1.1])]

    >>> pu.as_series([2, [1.1, 0.]], trim=False)
    [array([2.]), array([1.1, 0. ])]

    """
    # ndmin=1 promotes scalars to 1-D; dtype conversion is deferred
    # until a common type across all inputs is determined below.
    arrays = [np.array(a, ndmin=1, copy=False) for a in alist]
    if min([a.size for a in arrays]) == 0:
        raise ValueError("Coefficient array is empty")
    if any(a.ndim != 1 for a in arrays):
        raise ValueError("Coefficient array is not 1-d")
    if trim:
        arrays = [trimseq(a) for a in arrays]

    if any(a.dtype == np.dtype(object) for a in arrays):
        # If any input has object dtype, all outputs get object dtype so
        # arbitrary coefficient objects survive the round trip.
        ret = []
        for a in arrays:
            if a.dtype != np.dtype(object):
                # Element-wise copy into a fresh object array; keeps the
                # original numeric values boxed as Python objects.
                tmp = np.empty(len(a), dtype=np.dtype(object))
                tmp[:] = a[:]
                ret.append(tmp)
            else:
                ret.append(a.copy())
    else:
        try:
            # Find the smallest float/complex type all inputs cast to.
            dtype = np.common_type(*arrays)
        except Exception as e:
            raise ValueError("Coefficient arrays have no common type") from e
        ret = [np.array(a, copy=True, dtype=dtype) for a in arrays]
    return ret
158
+
159
+
160
def trimcoef(c, tol=0):
    """
    Remove "small" "trailing" coefficients from a polynomial.

    "Small" means "small in absolute value" and is controlled by the
    parameter `tol`; "trailing" means highest order coefficient(s), e.g.,
    in ``[0, 1, 1, 0, 0]`` (which represents ``0 + x + x**2 + 0*x**3 +
    0*x**4``) both the 3-rd and 4-th order coefficients would be
    "trimmed."

    Parameters
    ----------
    c : array_like
        1-d array of coefficients, ordered from lowest order to highest.
    tol : number, optional
        Trailing (i.e., highest order) elements with absolute value less
        than or equal to `tol` (default value is zero) are removed.

    Returns
    -------
    trimmed : ndarray
        1-d array with trailing zeros removed. If the resulting series
        would be empty, a series containing a single zero is returned.

    Raises
    ------
    ValueError
        If `tol` < 0

    See Also
    --------
    trimseq

    Examples
    --------
    >>> from numpy.polynomial import polyutils as pu
    >>> pu.trimcoef((0,0,3,0,5,0,0))
    array([0., 0., 3., 0., 5.])
    >>> pu.trimcoef((0,0,1e-3,0,1e-5,0,0),1e-3)  # item == tol is trimmed
    array([0.])
    >>> i = complex(0,1)  # works for complex
    >>> pu.trimcoef((3e-4,1e-3*(1-i),5e-4,2e-5*(1+i)), 1e-3)
    array([0.0003+0.j  , 0.001 -0.001j])

    """
    if tol < 0:
        raise ValueError("tol must be non-negative")

    [c] = as_series([c])
    # Indices of coefficients whose magnitude exceeds the tolerance.
    [keep] = np.nonzero(np.abs(c) > tol)
    if len(keep) == 0:
        # Everything trimmed: return a single zero of the same dtype.
        return c[:1]*0
    # Keep everything up to and including the last significant term.
    return c[:keep[-1] + 1].copy()
213
+
214
def getdomain(x):
    """
    Return a domain suitable for given abscissae.

    Find a domain suitable for a polynomial or Chebyshev series
    defined at the values supplied.

    Parameters
    ----------
    x : array_like
        1-d array of abscissae whose domain will be determined.

    Returns
    -------
    domain : ndarray
        1-d array containing two values. For complex input, the two
        points are the lower-left and upper-right corners of the
        smallest axis-aligned rectangle in the complex plane containing
        the points `x`. For real input, they are the endpoints of the
        smallest interval containing the points `x`.

    See Also
    --------
    mapparms, mapdomain
    """
    [x] = as_series([x], trim=False)
    if x.dtype.char in np.typecodes['Complex']:
        # bounding rectangle in the complex plane
        re_lo, re_hi = x.real.min(), x.real.max()
        im_lo, im_hi = x.imag.min(), x.imag.max()
        return np.array((complex(re_lo, im_lo), complex(re_hi, im_hi)))
    # real case: smallest enclosing interval
    return np.array((x.min(), x.max()))
259
+
260
def mapparms(old, new):
    """
    Linear map parameters between domains.

    Return the parameters of the linear map ``offset + scale*x`` that
    maps `old` to `new` such that ``old[i] -> new[i]``, ``i = 0, 1``.

    Parameters
    ----------
    old, new : array_like
        Domains; each must convert to a 1-d array containing precisely
        two values.

    Returns
    -------
    offset, scale : scalars
        The map ``L(x) = offset + scale*x`` takes the first domain onto
        the second.

    See Also
    --------
    getdomain, mapdomain

    Notes
    -----
    Also works for complex numbers, and can therefore map any line in
    the complex plane to any other line therein.
    """
    old_span = old[1] - old[0]
    new_span = new[1] - new[0]
    # solve L(old[0]) = new[0], L(old[1]) = new[1] for offset and scale
    offset = (old[1]*new[0] - old[0]*new[1])/old_span
    scale = new_span/old_span
    return offset, scale
306
+
307
def mapdomain(x, old, new):
    """
    Apply linear map to input points.

    The linear map ``offset + scale*x`` that takes the domain `old` onto
    the domain `new` is applied to the points `x`.

    Parameters
    ----------
    x : array_like
        Points to be mapped. If `x` is a subtype of ndarray the subtype
        will be preserved.
    old, new : array_like
        The two domains that determine the map. Each must convert to a
        1-d array containing precisely two values.

    Returns
    -------
    x_out : ndarray
        Array of points of the same shape as `x`, after application of
        the linear map between the two domains.

    See Also
    --------
    getdomain, mapparms

    Notes
    -----
    Effectively computes ``new[0] + m*(x - old[0])`` with
    ``m = (new[1] - new[0])/(old[1] - old[0])``. Also works for complex
    numbers, so any line in the complex plane can be mapped to any
    other line therein.
    """
    pts = np.asanyarray(x)
    offset, scale = mapparms(old, new)
    return offset + scale*pts
373
+
374
+
375
+ def _nth_slice(i, ndim):
376
+ sl = [np.newaxis] * ndim
377
+ sl[i] = slice(None)
378
+ return tuple(sl)
379
+
380
+
381
def _vander_nd(vander_fs, points, degrees):
    r"""
    A generalization of the Vandermonde matrix for N dimensions.

    The result is built by combining 1d Vandermonde matrices: for each
    dimension ``k``, ``vander_fs[k](points[k], degrees[k])`` contributes
    a factor whose degree axis is placed in an independent trailing axis
    of the output, and the factors are multiplied together with
    broadcasting.

    Parameters
    ----------
    vander_fs : Sequence[function(array_like, int) -> ndarray]
        The 1d vander function to use for each axis, such as
        ``polyvander``.
    points : Sequence[array_like]
        Arrays of point coordinates, all of the same shape. The dtypes
        will be converted to either float64 or complex128 depending on
        whether any of the elements are complex. Scalars are converted
        to 1-D arrays. Must be the same length as `vander_fs`.
    degrees : Sequence[int]
        The maximum degree (inclusive) to use for each axis. Must be
        the same length as `vander_fs`.

    Returns
    -------
    vander_nd : ndarray
        An array of shape
        ``points[0].shape + tuple(d + 1 for d in degrees)``.
    """
    n_dims = len(vander_fs)
    if n_dims != len(points):
        raise ValueError(
            f"Expected {n_dims} dimensions of sample points, got {len(points)}")
    if n_dims != len(degrees):
        raise ValueError(
            f"Expected {n_dims} dimensions of degrees, got {len(degrees)}")
    if n_dims == 0:
        raise ValueError("Unable to guess a dtype or shape when no points are given")

    # convert to the same shape and type
    points = tuple(np.array(tuple(points), copy=False) + 0.0)

    # Multiply the per-dimension Vandermonde matrices together, with the
    # degree axis of each placed in its own trailing axis of the output.
    product = None
    for k in range(n_dims):
        factor = vander_fs[k](points[k], degrees[k])
        factor = factor[(...,) + _nth_slice(k, n_dims)]
        product = factor if product is None else product * factor
    return product
448
+
449
+
450
def _vander_nd_flat(vander_fs, points, degrees):
    """
    Like `_vander_nd`, but flattens the last ``len(degrees)`` axes into
    a single axis.

    Used to implement the public ``<type>vander<n>d`` functions.
    """
    full = _vander_nd(vander_fs, points, degrees)
    leading = full.shape[:-len(degrees)]
    return full.reshape(leading + (-1,))
458
+
459
+
460
def _fromroots(line_f, mul_f, roots):
    """
    Helper function used to implement the ``<type>fromroots`` functions.

    Builds the product of linear factors ``(x - r)`` over all roots,
    pairing factors at each step so the intermediate series stay
    balanced in degree.

    Parameters
    ----------
    line_f : function(float, float) -> ndarray
        The ``<type>line`` function, such as ``polyline``.
    mul_f : function(array_like, array_like) -> ndarray
        The ``<type>mul`` function, such as ``polymul``.
    roots
        See the ``<type>fromroots`` functions for more detail.
    """
    if len(roots) == 0:
        return np.ones(1)

    [roots] = as_series([roots], trim=False)
    roots.sort()
    factors = [line_f(-r, 1) for r in roots]
    # Repeatedly halve the list by pairwise multiplication; a leftover
    # odd factor is folded into the first product.
    while len(factors) > 1:
        half, odd = divmod(len(factors), 2)
        paired = [mul_f(factors[k], factors[k + half]) for k in range(half)]
        if odd:
            paired[0] = mul_f(paired[0], factors[-1])
        factors = paired
    return factors[0]
488
+
489
+
490
+ def _valnd(val_f, c, *args):
491
+ """
492
+ Helper function used to implement the ``<type>val<n>d`` functions.
493
+
494
+ Parameters
495
+ ----------
496
+ val_f : function(array_like, array_like, tensor: bool) -> array_like
497
+ The ``<type>val`` function, such as ``polyval``
498
+ c, args
499
+ See the ``<type>val<n>d`` functions for more detail
500
+ """
501
+ args = [np.asanyarray(a) for a in args]
502
+ shape0 = args[0].shape
503
+ if not all((a.shape == shape0 for a in args[1:])):
504
+ if len(args) == 3:
505
+ raise ValueError('x, y, z are incompatible')
506
+ elif len(args) == 2:
507
+ raise ValueError('x, y are incompatible')
508
+ else:
509
+ raise ValueError('ordinates are incompatible')
510
+ it = iter(args)
511
+ x0 = next(it)
512
+
513
+ # use tensor on only the first
514
+ c = val_f(x0, c)
515
+ for xi in it:
516
+ c = val_f(xi, c, tensor=False)
517
+ return c
518
+
519
+
520
+ def _gridnd(val_f, c, *args):
521
+ """
522
+ Helper function used to implement the ``<type>grid<n>d`` functions.
523
+
524
+ Parameters
525
+ ----------
526
+ val_f : function(array_like, array_like, tensor: bool) -> array_like
527
+ The ``<type>val`` function, such as ``polyval``
528
+ c, args
529
+ See the ``<type>grid<n>d`` functions for more detail
530
+ """
531
+ for xi in args:
532
+ c = val_f(xi, c)
533
+ return c
534
+
535
+
536
def _div(mul_f, c1, c2):
    """
    Helper function used to implement the ``<type>div`` functions.

    Implementation uses repeated subtraction of c2 multiplied by the nth
    basis. For some polynomial types, a more efficient approach may be
    possible.

    Parameters
    ----------
    mul_f : function(array_like, array_like) -> array_like
        The ``<type>mul`` function, such as ``polymul``.
    c1, c2
        See the ``<type>div`` functions for more detail.
    """
    # c1, c2 are trimmed copies
    [c1, c2] = as_series([c1, c2])
    if c2[-1] == 0:
        raise ZeroDivisionError()

    n1, n2 = len(c1), len(c2)
    if n1 < n2:
        # divisor has higher degree: quotient zero, remainder is c1
        return c1[:1]*0, c1
    if n2 == 1:
        # scalar divisor
        return c1/c2[-1], c1[:1]*0

    quo = np.empty(n1 - n2 + 1, dtype=c1.dtype)
    rem = c1
    # eliminate the leading coefficient of the remainder at each step
    for k in range(n1 - n2, -1, -1):
        shifted = mul_f([0]*k + [1], c2)
        coef = rem[-1]/shifted[-1]
        rem = rem[:-1] - coef*shifted[:-1]
        quo[k] = coef
    return quo, trimseq(rem)
570
+
571
+
572
def _add(c1, c2):
    """ Helper function used to implement the ``<type>add`` functions. """
    # c1, c2 are trimmed copies, safe to modify in place
    [c1, c2] = as_series([c1, c2])
    longer, shorter = (c1, c2) if len(c1) > len(c2) else (c2, c1)
    longer[:shorter.size] += shorter
    return trimseq(longer)
583
+
584
+
585
def _sub(c1, c2):
    """ Helper function used to implement the ``<type>sub`` functions. """
    # c1, c2 are trimmed copies, safe to modify in place
    [c1, c2] = as_series([c1, c2])
    if len(c1) > len(c2):
        c1[:c2.size] -= c2
        result = c1
    else:
        # negate the longer (or equal-length) operand, then add c1
        result = -c2
        result[:c1.size] += c1
    return trimseq(result)
597
+
598
+
599
def _fit(vander_f, x, y, deg, rcond=None, full=False, w=None):
    """
    Helper function used to implement the ``<type>fit`` functions.

    Performs a (optionally weighted) least-squares fit of a series to
    data, shared by all polynomial types.

    Parameters
    ----------
    vander_f : function(array_like, int) -> ndarray
        The 1d vander function, such as ``polyvander``.
    x, y, deg, rcond, full, w
        See the ``<type>fit`` functions for more detail.
    """
    # `+ 0.0` promotes integer input to float without copying float input
    x = np.asarray(x) + 0.0
    y = np.asarray(y) + 0.0
    deg = np.asarray(deg)

    # check arguments.
    if deg.ndim > 1 or deg.dtype.kind not in 'iu' or deg.size == 0:
        raise TypeError("deg must be an int or non-empty 1-D array of int")
    if deg.min() < 0:
        raise ValueError("expected deg >= 0")
    if x.ndim != 1:
        raise TypeError("expected 1D vector for x")
    if x.size == 0:
        raise TypeError("expected non-empty vector for x")
    if y.ndim < 1 or y.ndim > 2:
        raise TypeError("expected 1D or 2D array for y")
    if len(x) != len(y):
        raise TypeError("expected x and y to have same length")

    if deg.ndim == 0:
        # fit all degrees 0..deg
        lmax = deg
        order = lmax + 1
        van = vander_f(x, lmax)
    else:
        # fit only the requested degrees; select the matching columns of
        # the full Vandermonde matrix
        deg = np.sort(deg)
        lmax = deg[-1]
        order = len(deg)
        van = vander_f(x, lmax)[:, deg]

    # set up the least squares matrices in transposed form
    lhs = van.T
    rhs = y.T
    if w is not None:
        w = np.asarray(w) + 0.0
        if w.ndim != 1:
            raise TypeError("expected 1D vector for w")
        if len(x) != len(w):
            raise TypeError("expected x and w to have same length")
        # apply weights. Don't use inplace operations as they
        # can cause problems with NA.
        lhs = lhs * w
        rhs = rhs * w

    # set rcond
    if rcond is None:
        rcond = len(x)*np.finfo(x.dtype).eps

    # Determine the norms of the design matrix columns.
    if issubclass(lhs.dtype.type, np.complexfloating):
        scl = np.sqrt((np.square(lhs.real) + np.square(lhs.imag)).sum(1))
    else:
        scl = np.sqrt(np.square(lhs).sum(1))
    # avoid division by zero for all-zero columns
    scl[scl == 0] = 1

    # Solve the least squares problem (columns pre-scaled to unit norm
    # to improve conditioning, then the scaling is undone on c).
    c, resids, rank, s = np.linalg.lstsq(lhs.T/scl, rhs.T, rcond)
    c = (c.T/scl).T

    # Expand c to include non-fitted coefficients which are set to zero
    if deg.ndim > 0:
        if c.ndim == 2:
            cc = np.zeros((lmax+1, c.shape[1]), dtype=c.dtype)
        else:
            cc = np.zeros(lmax+1, dtype=c.dtype)
        cc[deg] = c
        c = cc

    # warn on rank reduction
    if rank != order and not full:
        msg = "The fit may be poorly conditioned"
        warnings.warn(msg, RankWarning, stacklevel=2)

    if full:
        return c, [resids, rank, s, rcond]
    else:
        return c
685
+
686
+
687
def _pow(mul_f, c, pow, maxpower):
    """
    Helper function used to implement the ``<type>pow`` functions.

    Parameters
    ----------
    mul_f : function(array_like, array_like) -> ndarray
        The ``<type>mul`` function, such as ``polymul``.
    c : array_like
        1-D array of series coefficients.
    pow, maxpower
        See the ``<type>pow`` functions for more detail.
    """
    # c is a trimmed copy
    [c] = as_series([c])
    power = int(pow)
    if power != pow or power < 0:
        raise ValueError("Power must be a non-negative integer.")
    if maxpower is not None and power > maxpower:
        raise ValueError("Power is too large")
    if power == 0:
        return np.array([1], dtype=c.dtype)
    if power == 1:
        return c
    # This can be made more efficient by using powers of two
    # in the usual way.
    prd = c
    for _ in range(power - 1):
        prd = mul_f(prd, c)
    return prd
718
+
719
+
720
+ def _deprecate_as_int(x, desc):
721
+ """
722
+ Like `operator.index`, but emits a deprecation warning when passed a float
723
+
724
+ Parameters
725
+ ----------
726
+ x : int-like, or float with integral value
727
+ Value to interpret as an integer
728
+ desc : str
729
+ description to include in any error message
730
+
731
+ Raises
732
+ ------
733
+ TypeError : if x is a non-integral float or non-numeric
734
+ DeprecationWarning : if x is an integral float
735
+ """
736
+ try:
737
+ return operator.index(x)
738
+ except TypeError as e:
739
+ # Numpy 1.17.0, 2019-03-11
740
+ try:
741
+ ix = int(x)
742
+ except TypeError:
743
+ pass
744
+ else:
745
+ if ix == x:
746
+ warnings.warn(
747
+ f"In future, this will raise TypeError, as {desc} will "
748
+ "need to be an integer not just an integral float.",
749
+ DeprecationWarning,
750
+ stacklevel=3
751
+ )
752
+ return ix
753
+
754
+ raise TypeError(f"{desc} must be an integer") from e
755
+
756
+
757
def format_float(x, parens=False):
    """Format a floating-point scalar honoring the current numpy
    printoptions (precision, sign, floatmode, nan/inf strings).

    If `parens` is True, a scientific-notation result is wrapped in
    parentheses. Non-float input is returned via ``str``.
    """
    # non-floats are rendered verbatim
    if not np.issubdtype(type(x), np.floating):
        return str(x)

    opts = np.get_printoptions()

    if np.isnan(x):
        return opts['nanstr']
    elif np.isinf(x):
        return opts['infstr']

    exp_format = False
    if x != 0:
        # NOTE(review): `absolute` is expected to be a file-level import
        # (numpy's absolute ufunc) — not visible in this chunk; confirm.
        a = absolute(x)
        # use scientific notation for very large or very small magnitudes
        if a >= 1.e8 or a < 10**min(0, -(opts['precision']-1)//2):
            exp_format = True

    # dragon4 trim/unique settings: 'fixed' floatmode keeps all digits
    trim, unique = '0', True
    if opts['floatmode'] == 'fixed':
        trim, unique = 'k', False

    if exp_format:
        s = dragon4_scientific(x, precision=opts['precision'],
                               unique=unique, trim=trim,
                               sign=opts['sign'] == '+')
        if parens:
            s = '(' + s + ')'
    else:
        s = dragon4_positional(x, precision=opts['precision'],
                               fractional=True,
                               unique=unique, trim=trim,
                               sign=opts['sign'] == '+')
    return s
.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_chebyshev.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for chebyshev module.
2
+
3
+ """
4
+ from functools import reduce
5
+
6
+ import numpy as np
7
+ import numpy.polynomial.chebyshev as cheb
8
+ from numpy.polynomial.polynomial import polyval
9
+ from numpy.testing import (
10
+ assert_almost_equal, assert_raises, assert_equal, assert_,
11
+ )
12
+
13
+
14
def trim(x):
    # Shorthand used throughout: drop trailing Chebyshev coefficients
    # smaller than 1e-6 in magnitude before comparing.
    return cheb.chebtrim(x, tol=1e-6)
16
+
17
# Chebyshev polynomials T_0 .. T_9 expressed in the power (monomial)
# basis, coefficients ordered from lowest degree to highest.
T0 = [1]
T1 = [0, 1]
T2 = [-1, 0, 2]
T3 = [0, -3, 0, 4]
T4 = [1, 0, -8, 0, 8]
T5 = [0, 5, 0, -20, 0, 16]
T6 = [-1, 0, 18, 0, -48, 0, 32]
T7 = [0, -7, 0, 56, 0, -112, 0, 64]
T8 = [1, 0, -32, 0, 160, 0, -256, 0, 128]
T9 = [0, 9, 0, -120, 0, 432, 0, -576, 0, 256]

Tlist = [T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]
29
+
30
+
31
class TestPrivate:
    """Tests for the private c-series <-> z-series conversion helpers."""

    def test__cseries_to_zseries(self):
        # A c-series [2, 1, ..., 1] maps to a symmetric z-series with the
        # constant term centered and the other coefficients halved.
        for i in range(5):
            inp = np.array([2] + [1]*i, np.double)
            tgt = np.array([.5]*i + [2] + [.5]*i, np.double)
            res = cheb._cseries_to_zseries(inp)
            assert_equal(res, tgt)

    def test__zseries_to_cseries(self):
        # Inverse of the conversion above.
        for i in range(5):
            inp = np.array([.5]*i + [2] + [.5]*i, np.double)
            tgt = np.array([2] + [1]*i, np.double)
            res = cheb._zseries_to_cseries(inp)
            assert_equal(res, tgt)
46
+
47
+
48
class TestConstants:
    """Check the module-level constant series exported by `chebyshev`."""

    def test_chebdomain(self):
        assert_equal(cheb.chebdomain, [-1, 1])

    def test_chebzero(self):
        assert_equal(cheb.chebzero, [0])

    def test_chebone(self):
        assert_equal(cheb.chebone, [1])

    def test_chebx(self):
        # the identity x expressed as a Chebyshev series: T_1
        assert_equal(cheb.chebx, [0, 1])
61
+
62
+
63
class TestArithmetic:
    """Arithmetic on Chebyshev series: add, sub, mulx, mul, div, pow."""

    def test_chebadd(self):
        # T_i + T_j: a one at index i and a one at index j.
        for i in range(5):
            for j in range(5):
                msg = f"At i={i}, j={j}"
                tgt = np.zeros(max(i, j) + 1)
                tgt[i] += 1
                tgt[j] += 1
                res = cheb.chebadd([0]*i + [1], [0]*j + [1])
                assert_equal(trim(res), trim(tgt), err_msg=msg)

    def test_chebsub(self):
        # T_i - T_j: plus one at index i, minus one at index j.
        for i in range(5):
            for j in range(5):
                msg = f"At i={i}, j={j}"
                tgt = np.zeros(max(i, j) + 1)
                tgt[i] += 1
                tgt[j] -= 1
                res = cheb.chebsub([0]*i + [1], [0]*j + [1])
                assert_equal(trim(res), trim(tgt), err_msg=msg)

    def test_chebmulx(self):
        assert_equal(cheb.chebmulx([0]), [0])
        assert_equal(cheb.chebmulx([1]), [0, 1])
        # recurrence: x*T_i = (T_{i-1} + T_{i+1})/2
        for i in range(1, 5):
            ser = [0]*i + [1]
            tgt = [0]*(i - 1) + [.5, 0, .5]
            assert_equal(cheb.chebmulx(ser), tgt)

    def test_chebmul(self):
        # product identity: T_i*T_j = (T_{i+j} + T_{|i-j|})/2
        for i in range(5):
            for j in range(5):
                msg = f"At i={i}, j={j}"
                tgt = np.zeros(i + j + 1)
                tgt[i + j] += .5
                tgt[abs(i - j)] += .5
                res = cheb.chebmul([0]*i + [1], [0]*j + [1])
                assert_equal(trim(res), trim(tgt), err_msg=msg)

    def test_chebdiv(self):
        # division must satisfy tgt == quo*ci + rem.
        for i in range(5):
            for j in range(5):
                msg = f"At i={i}, j={j}"
                ci = [0]*i + [1]
                cj = [0]*j + [1]
                tgt = cheb.chebadd(ci, cj)
                quo, rem = cheb.chebdiv(tgt, ci)
                res = cheb.chebadd(cheb.chebmul(quo, ci), rem)
                assert_equal(trim(res), trim(tgt), err_msg=msg)

    def test_chebpow(self):
        # chebpow(c, j) must agree with j-fold repeated chebmul.
        for i in range(5):
            for j in range(5):
                msg = f"At i={i}, j={j}"
                c = np.arange(i + 1)
                tgt = reduce(cheb.chebmul, [c]*j, np.array([1]))
                res = cheb.chebpow(c, j)
                assert_equal(trim(res), trim(tgt), err_msg=msg)
122
+
123
+
124
class TestEvaluation:
    """Evaluation of Chebyshev series in 1, 2 and 3 dimensions."""

    # Chebyshev coefficients of 1 + 2*x + 3*x**2
    c1d = np.array([2.5, 2., 1.5])
    # separable 2d/3d coefficient tensors built from c1d
    c2d = np.einsum('i,j->ij', c1d, c1d)
    c3d = np.einsum('i,j,k->ijk', c1d, c1d, c1d)

    # some random values in [-1, 1)
    x = np.random.random((3, 5))*2 - 1
    # reference values of 1 + 2*x + 3*x**2 at x, via the power basis
    y = polyval(x, [1., 2., 3.])

    def test_chebval(self):
        #check empty input
        assert_equal(cheb.chebval([], [1]).size, 0)

        #check normal input)
        x = np.linspace(-1, 1)
        y = [polyval(x, c) for c in Tlist]
        for i in range(10):
            msg = f"At i={i}"
            tgt = y[i]
            res = cheb.chebval(x, [0]*i + [1])
            assert_almost_equal(res, tgt, err_msg=msg)

        #check that shape is preserved
        for i in range(3):
            dims = [2]*i
            x = np.zeros(dims)
            assert_equal(cheb.chebval(x, [1]).shape, dims)
            assert_equal(cheb.chebval(x, [1, 0]).shape, dims)
            assert_equal(cheb.chebval(x, [1, 0, 0]).shape, dims)

    def test_chebval2d(self):
        x1, x2, x3 = self.x
        y1, y2, y3 = self.y

        #test exceptions
        assert_raises(ValueError, cheb.chebval2d, x1, x2[:2], self.c2d)

        #test values: separable coefficients factor the evaluation
        tgt = y1*y2
        res = cheb.chebval2d(x1, x2, self.c2d)
        assert_almost_equal(res, tgt)

        #test shape
        z = np.ones((2, 3))
        res = cheb.chebval2d(z, z, self.c2d)
        assert_(res.shape == (2, 3))

    def test_chebval3d(self):
        x1, x2, x3 = self.x
        y1, y2, y3 = self.y

        #test exceptions
        assert_raises(ValueError, cheb.chebval3d, x1, x2, x3[:2], self.c3d)

        #test values
        tgt = y1*y2*y3
        res = cheb.chebval3d(x1, x2, x3, self.c3d)
        assert_almost_equal(res, tgt)

        #test shape
        z = np.ones((2, 3))
        res = cheb.chebval3d(z, z, z, self.c3d)
        assert_(res.shape == (2, 3))

    def test_chebgrid2d(self):
        x1, x2, x3 = self.x
        y1, y2, y3 = self.y

        #test values: grid evaluation is the outer product of 1d values
        tgt = np.einsum('i,j->ij', y1, y2)
        res = cheb.chebgrid2d(x1, x2, self.c2d)
        assert_almost_equal(res, tgt)

        #test shape
        z = np.ones((2, 3))
        res = cheb.chebgrid2d(z, z, self.c2d)
        assert_(res.shape == (2, 3)*2)

    def test_chebgrid3d(self):
        x1, x2, x3 = self.x
        y1, y2, y3 = self.y

        #test values
        tgt = np.einsum('i,j,k->ijk', y1, y2, y3)
        res = cheb.chebgrid3d(x1, x2, x3, self.c3d)
        assert_almost_equal(res, tgt)

        #test shape
        z = np.ones((2, 3))
        res = cheb.chebgrid3d(z, z, z, self.c3d)
        assert_(res.shape == (2, 3)*3)
216
+
217
+
218
class TestIntegral:
    """Tests for `chebint`: argument checking, constants, lbnd, scl, axis."""

    def test_chebint(self):
        # check exceptions
        assert_raises(TypeError, cheb.chebint, [0], .5)
        assert_raises(ValueError, cheb.chebint, [0], -1)
        assert_raises(ValueError, cheb.chebint, [0], 1, [0, 0])
        assert_raises(ValueError, cheb.chebint, [0], lbnd=[0])
        assert_raises(ValueError, cheb.chebint, [0], scl=[0])
        assert_raises(TypeError, cheb.chebint, [0], axis=.5)

        # test integration of zero polynomial
        for i in range(2, 5):
            k = [0]*(i - 2) + [1]
            res = cheb.chebint([0], m=i, k=k)
            assert_almost_equal(res, [0, 1])

        # check single integration with integration constant
        for i in range(5):
            scl = i + 1
            pol = [0]*i + [1]
            tgt = [i] + [0]*i + [1/scl]
            chebpol = cheb.poly2cheb(pol)
            chebint = cheb.chebint(chebpol, m=1, k=[i])
            res = cheb.cheb2poly(chebint)
            assert_almost_equal(trim(res), trim(tgt))

        # check single integration with integration constant and lbnd
        for i in range(5):
            scl = i + 1
            pol = [0]*i + [1]
            chebpol = cheb.poly2cheb(pol)
            chebint = cheb.chebint(chebpol, m=1, k=[i], lbnd=-1)
            # the antiderivative must take the value k at the lower bound
            assert_almost_equal(cheb.chebval(-1, chebint), i)

        # check single integration with integration constant and scaling
        for i in range(5):
            scl = i + 1
            pol = [0]*i + [1]
            tgt = [i] + [0]*i + [2/scl]
            chebpol = cheb.poly2cheb(pol)
            chebint = cheb.chebint(chebpol, m=1, k=[i], scl=2)
            res = cheb.cheb2poly(chebint)
            assert_almost_equal(trim(res), trim(tgt))

        # check multiple integrations with default k
        for i in range(5):
            for j in range(2, 5):
                pol = [0]*i + [1]
                tgt = pol[:]
                for k in range(j):
                    tgt = cheb.chebint(tgt, m=1)
                res = cheb.chebint(pol, m=j)
                assert_almost_equal(trim(res), trim(tgt))

        # check multiple integrations with defined k
        for i in range(5):
            for j in range(2, 5):
                pol = [0]*i + [1]
                tgt = pol[:]
                for k in range(j):
                    tgt = cheb.chebint(tgt, m=1, k=[k])
                res = cheb.chebint(pol, m=j, k=list(range(j)))
                assert_almost_equal(trim(res), trim(tgt))

        # check multiple integrations with lbnd
        for i in range(5):
            for j in range(2, 5):
                pol = [0]*i + [1]
                tgt = pol[:]
                for k in range(j):
                    tgt = cheb.chebint(tgt, m=1, k=[k], lbnd=-1)
                res = cheb.chebint(pol, m=j, k=list(range(j)), lbnd=-1)
                assert_almost_equal(trim(res), trim(tgt))

        # check multiple integrations with scaling
        for i in range(5):
            for j in range(2, 5):
                pol = [0]*i + [1]
                tgt = pol[:]
                for k in range(j):
                    tgt = cheb.chebint(tgt, m=1, k=[k], scl=2)
                res = cheb.chebint(pol, m=j, k=list(range(j)), scl=2)
                assert_almost_equal(trim(res), trim(tgt))

    def test_chebint_axis(self):
        # check that axis keyword works
        c2d = np.random.random((3, 4))

        tgt = np.vstack([cheb.chebint(c) for c in c2d.T]).T
        res = cheb.chebint(c2d, axis=0)
        assert_almost_equal(res, tgt)

        tgt = np.vstack([cheb.chebint(c) for c in c2d])
        res = cheb.chebint(c2d, axis=1)
        assert_almost_equal(res, tgt)

        tgt = np.vstack([cheb.chebint(c, k=3) for c in c2d])
        res = cheb.chebint(c2d, k=3, axis=1)
        assert_almost_equal(res, tgt)
318
+
319
+
320
class TestDerivative:
    """Tests for `chebder`: argument checking, inverse of chebint, axis."""

    def test_chebder(self):
        # check exceptions
        assert_raises(TypeError, cheb.chebder, [0], .5)
        assert_raises(ValueError, cheb.chebder, [0], -1)

        # check that zeroth derivative does nothing
        for i in range(5):
            tgt = [0]*i + [1]
            res = cheb.chebder(tgt, m=0)
            assert_equal(trim(res), trim(tgt))

        # check that derivation is the inverse of integration
        for i in range(5):
            for j in range(2, 5):
                tgt = [0]*i + [1]
                res = cheb.chebder(cheb.chebint(tgt, m=j), m=j)
                assert_almost_equal(trim(res), trim(tgt))

        # check derivation with scaling: scl=.5 undoes integration scl=2
        for i in range(5):
            for j in range(2, 5):
                tgt = [0]*i + [1]
                res = cheb.chebder(cheb.chebint(tgt, m=j, scl=2), m=j, scl=.5)
                assert_almost_equal(trim(res), trim(tgt))

    def test_chebder_axis(self):
        # check that axis keyword works
        c2d = np.random.random((3, 4))

        tgt = np.vstack([cheb.chebder(c) for c in c2d.T]).T
        res = cheb.chebder(c2d, axis=0)
        assert_almost_equal(res, tgt)

        tgt = np.vstack([cheb.chebder(c) for c in c2d])
        res = cheb.chebder(c2d, axis=1)
        assert_almost_equal(res, tgt)
359
+
360
class TestVander:
    """Tests for the Chebyshev pseudo-Vandermonde matrices (1d/2d/3d)."""

    # some random values in [-1, 1)
    x = np.random.random((3, 5))*2 - 1

    def test_chebvander(self):
        # check for 1d x: column i must equal T_i evaluated at x
        x = np.arange(3)
        v = cheb.chebvander(x, 3)
        assert_(v.shape == (3, 4))
        for i in range(4):
            coef = [0]*i + [1]
            assert_almost_equal(v[..., i], cheb.chebval(x, coef))

        # check for 2d x
        x = np.array([[1, 2], [3, 4], [5, 6]])
        v = cheb.chebvander(x, 3)
        assert_(v.shape == (3, 2, 4))
        for i in range(4):
            coef = [0]*i + [1]
            assert_almost_equal(v[..., i], cheb.chebval(x, coef))

    def test_chebvander2d(self):
        # also tests chebval2d for non-square coefficient array
        x1, x2, x3 = self.x
        c = np.random.random((2, 3))
        van = cheb.chebvander2d(x1, x2, [1, 2])
        tgt = cheb.chebval2d(x1, x2, c)
        # van @ flattened coefficients reproduces the direct evaluation
        res = np.dot(van, c.flat)
        assert_almost_equal(res, tgt)

        # check shape
        van = cheb.chebvander2d([x1], [x2], [1, 2])
        assert_(van.shape == (1, 5, 6))

    def test_chebvander3d(self):
        # also tests chebval3d for non-square coefficient array
        x1, x2, x3 = self.x
        c = np.random.random((2, 3, 4))
        van = cheb.chebvander3d(x1, x2, x3, [1, 2, 3])
        tgt = cheb.chebval3d(x1, x2, x3, c)
        res = np.dot(van, c.flat)
        assert_almost_equal(res, tgt)

        # check shape
        van = cheb.chebvander3d([x1], [x2], [x3], [1, 2, 3])
        assert_(van.shape == (1, 5, 24))
406
+
407
+
408
+ class TestFitting:
409
+
410
+ def test_chebfit(self):
411
+ def f(x):
412
+ return x*(x - 1)*(x - 2)
413
+
414
+ def f2(x):
415
+ return x**4 + x**2 + 1
416
+
417
+ # Test exceptions
418
+ assert_raises(ValueError, cheb.chebfit, [1], [1], -1)
419
+ assert_raises(TypeError, cheb.chebfit, [[1]], [1], 0)
420
+ assert_raises(TypeError, cheb.chebfit, [], [1], 0)
421
+ assert_raises(TypeError, cheb.chebfit, [1], [[[1]]], 0)
422
+ assert_raises(TypeError, cheb.chebfit, [1, 2], [1], 0)
423
+ assert_raises(TypeError, cheb.chebfit, [1], [1, 2], 0)
424
+ assert_raises(TypeError, cheb.chebfit, [1], [1], 0, w=[[1]])
425
+ assert_raises(TypeError, cheb.chebfit, [1], [1], 0, w=[1, 1])
426
+ assert_raises(ValueError, cheb.chebfit, [1], [1], [-1,])
427
+ assert_raises(ValueError, cheb.chebfit, [1], [1], [2, -1, 6])
428
+ assert_raises(TypeError, cheb.chebfit, [1], [1], [])
429
+
430
+ # Test fit
431
+ x = np.linspace(0, 2)
432
+ y = f(x)
433
+ #
434
+ coef3 = cheb.chebfit(x, y, 3)
435
+ assert_equal(len(coef3), 4)
436
+ assert_almost_equal(cheb.chebval(x, coef3), y)
437
+ coef3 = cheb.chebfit(x, y, [0, 1, 2, 3])
438
+ assert_equal(len(coef3), 4)
439
+ assert_almost_equal(cheb.chebval(x, coef3), y)
440
+ #
441
+ coef4 = cheb.chebfit(x, y, 4)
442
+ assert_equal(len(coef4), 5)
443
+ assert_almost_equal(cheb.chebval(x, coef4), y)
444
+ coef4 = cheb.chebfit(x, y, [0, 1, 2, 3, 4])
445
+ assert_equal(len(coef4), 5)
446
+ assert_almost_equal(cheb.chebval(x, coef4), y)
447
+ # check things still work if deg is not in strict increasing
448
+ coef4 = cheb.chebfit(x, y, [2, 3, 4, 1, 0])
449
+ assert_equal(len(coef4), 5)
450
+ assert_almost_equal(cheb.chebval(x, coef4), y)
451
+ #
452
+ coef2d = cheb.chebfit(x, np.array([y, y]).T, 3)
453
+ assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
454
+ coef2d = cheb.chebfit(x, np.array([y, y]).T, [0, 1, 2, 3])
455
+ assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
456
+ # test weighting
457
+ w = np.zeros_like(x)
458
+ yw = y.copy()
459
+ w[1::2] = 1
460
+ y[0::2] = 0
461
+ wcoef3 = cheb.chebfit(x, yw, 3, w=w)
462
+ assert_almost_equal(wcoef3, coef3)
463
+ wcoef3 = cheb.chebfit(x, yw, [0, 1, 2, 3], w=w)
464
+ assert_almost_equal(wcoef3, coef3)
465
+ #
466
+ wcoef2d = cheb.chebfit(x, np.array([yw, yw]).T, 3, w=w)
467
+ assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
468
+ wcoef2d = cheb.chebfit(x, np.array([yw, yw]).T, [0, 1, 2, 3], w=w)
469
+ assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
470
+ # test scaling with complex values x points whose square
471
+ # is zero when summed.
472
+ x = [1, 1j, -1, -1j]
473
+ assert_almost_equal(cheb.chebfit(x, x, 1), [0, 1])
474
+ assert_almost_equal(cheb.chebfit(x, x, [0, 1]), [0, 1])
475
+ # test fitting only even polynomials
476
+ x = np.linspace(-1, 1)
477
+ y = f2(x)
478
+ coef1 = cheb.chebfit(x, y, 4)
479
+ assert_almost_equal(cheb.chebval(x, coef1), y)
480
+ coef2 = cheb.chebfit(x, y, [0, 2, 4])
481
+ assert_almost_equal(cheb.chebval(x, coef2), y)
482
+ assert_almost_equal(coef1, coef2)
483
+
484
+
485
+ class TestInterpolate:
486
+
487
+ def f(self, x):
488
+ return x * (x - 1) * (x - 2)
489
+
490
+ def test_raises(self):
491
+ assert_raises(ValueError, cheb.chebinterpolate, self.f, -1)
492
+ assert_raises(TypeError, cheb.chebinterpolate, self.f, 10.)
493
+
494
+ def test_dimensions(self):
495
+ for deg in range(1, 5):
496
+ assert_(cheb.chebinterpolate(self.f, deg).shape == (deg + 1,))
497
+
498
+ def test_approximation(self):
499
+
500
+ def powx(x, p):
501
+ return x**p
502
+
503
+ x = np.linspace(-1, 1, 10)
504
+ for deg in range(0, 10):
505
+ for p in range(0, deg + 1):
506
+ c = cheb.chebinterpolate(powx, deg, (p,))
507
+ assert_almost_equal(cheb.chebval(x, c), powx(x, p), decimal=12)
508
+
509
+
510
+ class TestCompanion:
511
+
512
+ def test_raises(self):
513
+ assert_raises(ValueError, cheb.chebcompanion, [])
514
+ assert_raises(ValueError, cheb.chebcompanion, [1])
515
+
516
+ def test_dimensions(self):
517
+ for i in range(1, 5):
518
+ coef = [0]*i + [1]
519
+ assert_(cheb.chebcompanion(coef).shape == (i, i))
520
+
521
+ def test_linear_root(self):
522
+ assert_(cheb.chebcompanion([1, 2])[0, 0] == -.5)
523
+
524
+
525
+ class TestGauss:
526
+
527
+ def test_100(self):
528
+ x, w = cheb.chebgauss(100)
529
+
530
+ # test orthogonality. Note that the results need to be normalized,
531
+ # otherwise the huge values that can arise from fast growing
532
+ # functions like Laguerre can be very confusing.
533
+ v = cheb.chebvander(x, 99)
534
+ vv = np.dot(v.T * w, v)
535
+ vd = 1/np.sqrt(vv.diagonal())
536
+ vv = vd[:, None] * vv * vd
537
+ assert_almost_equal(vv, np.eye(100))
538
+
539
+ # check that the integral of 1 is correct
540
+ tgt = np.pi
541
+ assert_almost_equal(w.sum(), tgt)
542
+
543
+
544
+ class TestMisc:
545
+
546
+ def test_chebfromroots(self):
547
+ res = cheb.chebfromroots([])
548
+ assert_almost_equal(trim(res), [1])
549
+ for i in range(1, 5):
550
+ roots = np.cos(np.linspace(-np.pi, 0, 2*i + 1)[1::2])
551
+ tgt = [0]*i + [1]
552
+ res = cheb.chebfromroots(roots)*2**(i-1)
553
+ assert_almost_equal(trim(res), trim(tgt))
554
+
555
+ def test_chebroots(self):
556
+ assert_almost_equal(cheb.chebroots([1]), [])
557
+ assert_almost_equal(cheb.chebroots([1, 2]), [-.5])
558
+ for i in range(2, 5):
559
+ tgt = np.linspace(-1, 1, i)
560
+ res = cheb.chebroots(cheb.chebfromroots(tgt))
561
+ assert_almost_equal(trim(res), trim(tgt))
562
+
563
+ def test_chebtrim(self):
564
+ coef = [2, -1, 1, 0]
565
+
566
+ # Test exceptions
567
+ assert_raises(ValueError, cheb.chebtrim, coef, -1)
568
+
569
+ # Test results
570
+ assert_equal(cheb.chebtrim(coef), coef[:-1])
571
+ assert_equal(cheb.chebtrim(coef, 1), coef[:-3])
572
+ assert_equal(cheb.chebtrim(coef, 2), [0])
573
+
574
+ def test_chebline(self):
575
+ assert_equal(cheb.chebline(3, 4), [3, 4])
576
+
577
+ def test_cheb2poly(self):
578
+ for i in range(10):
579
+ assert_almost_equal(cheb.cheb2poly([0]*i + [1]), Tlist[i])
580
+
581
+ def test_poly2cheb(self):
582
+ for i in range(10):
583
+ assert_almost_equal(cheb.poly2cheb(Tlist[i]), [0]*i + [1])
584
+
585
+ def test_weight(self):
586
+ x = np.linspace(-1, 1, 11)[1:-1]
587
+ tgt = 1./(np.sqrt(1 + x) * np.sqrt(1 - x))
588
+ res = cheb.chebweight(x)
589
+ assert_almost_equal(res, tgt)
590
+
591
+ def test_chebpts1(self):
592
+ #test exceptions
593
+ assert_raises(ValueError, cheb.chebpts1, 1.5)
594
+ assert_raises(ValueError, cheb.chebpts1, 0)
595
+
596
+ #test points
597
+ tgt = [0]
598
+ assert_almost_equal(cheb.chebpts1(1), tgt)
599
+ tgt = [-0.70710678118654746, 0.70710678118654746]
600
+ assert_almost_equal(cheb.chebpts1(2), tgt)
601
+ tgt = [-0.86602540378443871, 0, 0.86602540378443871]
602
+ assert_almost_equal(cheb.chebpts1(3), tgt)
603
+ tgt = [-0.9238795325, -0.3826834323, 0.3826834323, 0.9238795325]
604
+ assert_almost_equal(cheb.chebpts1(4), tgt)
605
+
606
+ def test_chebpts2(self):
607
+ #test exceptions
608
+ assert_raises(ValueError, cheb.chebpts2, 1.5)
609
+ assert_raises(ValueError, cheb.chebpts2, 1)
610
+
611
+ #test points
612
+ tgt = [-1, 1]
613
+ assert_almost_equal(cheb.chebpts2(2), tgt)
614
+ tgt = [-1, 0, 1]
615
+ assert_almost_equal(cheb.chebpts2(3), tgt)
616
+ tgt = [-1, -0.5, .5, 1]
617
+ assert_almost_equal(cheb.chebpts2(4), tgt)
618
+ tgt = [-1.0, -0.707106781187, 0, 0.707106781187, 1.0]
619
+ assert_almost_equal(cheb.chebpts2(5), tgt)
.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_hermite_e.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for hermite_e module.
2
+
3
+ """
4
+ from functools import reduce
5
+
6
+ import numpy as np
7
+ import numpy.polynomial.hermite_e as herme
8
+ from numpy.polynomial.polynomial import polyval
9
+ from numpy.testing import (
10
+ assert_almost_equal, assert_raises, assert_equal, assert_,
11
+ )
12
+
13
+ He0 = np.array([1])
14
+ He1 = np.array([0, 1])
15
+ He2 = np.array([-1, 0, 1])
16
+ He3 = np.array([0, -3, 0, 1])
17
+ He4 = np.array([3, 0, -6, 0, 1])
18
+ He5 = np.array([0, 15, 0, -10, 0, 1])
19
+ He6 = np.array([-15, 0, 45, 0, -15, 0, 1])
20
+ He7 = np.array([0, -105, 0, 105, 0, -21, 0, 1])
21
+ He8 = np.array([105, 0, -420, 0, 210, 0, -28, 0, 1])
22
+ He9 = np.array([0, 945, 0, -1260, 0, 378, 0, -36, 0, 1])
23
+
24
+ Helist = [He0, He1, He2, He3, He4, He5, He6, He7, He8, He9]
25
+
26
+
27
+ def trim(x):
28
+ return herme.hermetrim(x, tol=1e-6)
29
+
30
+
31
+ class TestConstants:
32
+
33
+ def test_hermedomain(self):
34
+ assert_equal(herme.hermedomain, [-1, 1])
35
+
36
+ def test_hermezero(self):
37
+ assert_equal(herme.hermezero, [0])
38
+
39
+ def test_hermeone(self):
40
+ assert_equal(herme.hermeone, [1])
41
+
42
+ def test_hermex(self):
43
+ assert_equal(herme.hermex, [0, 1])
44
+
45
+
46
+ class TestArithmetic:
47
+ x = np.linspace(-3, 3, 100)
48
+
49
+ def test_hermeadd(self):
50
+ for i in range(5):
51
+ for j in range(5):
52
+ msg = f"At i={i}, j={j}"
53
+ tgt = np.zeros(max(i, j) + 1)
54
+ tgt[i] += 1
55
+ tgt[j] += 1
56
+ res = herme.hermeadd([0]*i + [1], [0]*j + [1])
57
+ assert_equal(trim(res), trim(tgt), err_msg=msg)
58
+
59
+ def test_hermesub(self):
60
+ for i in range(5):
61
+ for j in range(5):
62
+ msg = f"At i={i}, j={j}"
63
+ tgt = np.zeros(max(i, j) + 1)
64
+ tgt[i] += 1
65
+ tgt[j] -= 1
66
+ res = herme.hermesub([0]*i + [1], [0]*j + [1])
67
+ assert_equal(trim(res), trim(tgt), err_msg=msg)
68
+
69
+ def test_hermemulx(self):
70
+ assert_equal(herme.hermemulx([0]), [0])
71
+ assert_equal(herme.hermemulx([1]), [0, 1])
72
+ for i in range(1, 5):
73
+ ser = [0]*i + [1]
74
+ tgt = [0]*(i - 1) + [i, 0, 1]
75
+ assert_equal(herme.hermemulx(ser), tgt)
76
+
77
+ def test_hermemul(self):
78
+ # check values of result
79
+ for i in range(5):
80
+ pol1 = [0]*i + [1]
81
+ val1 = herme.hermeval(self.x, pol1)
82
+ for j in range(5):
83
+ msg = f"At i={i}, j={j}"
84
+ pol2 = [0]*j + [1]
85
+ val2 = herme.hermeval(self.x, pol2)
86
+ pol3 = herme.hermemul(pol1, pol2)
87
+ val3 = herme.hermeval(self.x, pol3)
88
+ assert_(len(pol3) == i + j + 1, msg)
89
+ assert_almost_equal(val3, val1*val2, err_msg=msg)
90
+
91
+ def test_hermediv(self):
92
+ for i in range(5):
93
+ for j in range(5):
94
+ msg = f"At i={i}, j={j}"
95
+ ci = [0]*i + [1]
96
+ cj = [0]*j + [1]
97
+ tgt = herme.hermeadd(ci, cj)
98
+ quo, rem = herme.hermediv(tgt, ci)
99
+ res = herme.hermeadd(herme.hermemul(quo, ci), rem)
100
+ assert_equal(trim(res), trim(tgt), err_msg=msg)
101
+
102
+ def test_hermepow(self):
103
+ for i in range(5):
104
+ for j in range(5):
105
+ msg = f"At i={i}, j={j}"
106
+ c = np.arange(i + 1)
107
+ tgt = reduce(herme.hermemul, [c]*j, np.array([1]))
108
+ res = herme.hermepow(c, j)
109
+ assert_equal(trim(res), trim(tgt), err_msg=msg)
110
+
111
+
112
+ class TestEvaluation:
113
+ # coefficients of 1 + 2*x + 3*x**2
114
+ c1d = np.array([4., 2., 3.])
115
+ c2d = np.einsum('i,j->ij', c1d, c1d)
116
+ c3d = np.einsum('i,j,k->ijk', c1d, c1d, c1d)
117
+
118
+ # some random values in [-1, 1)
119
+ x = np.random.random((3, 5))*2 - 1
120
+ y = polyval(x, [1., 2., 3.])
121
+
122
+ def test_hermeval(self):
123
+ #check empty input
124
+ assert_equal(herme.hermeval([], [1]).size, 0)
125
+
126
+ #check normal input)
127
+ x = np.linspace(-1, 1)
128
+ y = [polyval(x, c) for c in Helist]
129
+ for i in range(10):
130
+ msg = f"At i={i}"
131
+ tgt = y[i]
132
+ res = herme.hermeval(x, [0]*i + [1])
133
+ assert_almost_equal(res, tgt, err_msg=msg)
134
+
135
+ #check that shape is preserved
136
+ for i in range(3):
137
+ dims = [2]*i
138
+ x = np.zeros(dims)
139
+ assert_equal(herme.hermeval(x, [1]).shape, dims)
140
+ assert_equal(herme.hermeval(x, [1, 0]).shape, dims)
141
+ assert_equal(herme.hermeval(x, [1, 0, 0]).shape, dims)
142
+
143
+ def test_hermeval2d(self):
144
+ x1, x2, x3 = self.x
145
+ y1, y2, y3 = self.y
146
+
147
+ #test exceptions
148
+ assert_raises(ValueError, herme.hermeval2d, x1, x2[:2], self.c2d)
149
+
150
+ #test values
151
+ tgt = y1*y2
152
+ res = herme.hermeval2d(x1, x2, self.c2d)
153
+ assert_almost_equal(res, tgt)
154
+
155
+ #test shape
156
+ z = np.ones((2, 3))
157
+ res = herme.hermeval2d(z, z, self.c2d)
158
+ assert_(res.shape == (2, 3))
159
+
160
+ def test_hermeval3d(self):
161
+ x1, x2, x3 = self.x
162
+ y1, y2, y3 = self.y
163
+
164
+ #test exceptions
165
+ assert_raises(ValueError, herme.hermeval3d, x1, x2, x3[:2], self.c3d)
166
+
167
+ #test values
168
+ tgt = y1*y2*y3
169
+ res = herme.hermeval3d(x1, x2, x3, self.c3d)
170
+ assert_almost_equal(res, tgt)
171
+
172
+ #test shape
173
+ z = np.ones((2, 3))
174
+ res = herme.hermeval3d(z, z, z, self.c3d)
175
+ assert_(res.shape == (2, 3))
176
+
177
+ def test_hermegrid2d(self):
178
+ x1, x2, x3 = self.x
179
+ y1, y2, y3 = self.y
180
+
181
+ #test values
182
+ tgt = np.einsum('i,j->ij', y1, y2)
183
+ res = herme.hermegrid2d(x1, x2, self.c2d)
184
+ assert_almost_equal(res, tgt)
185
+
186
+ #test shape
187
+ z = np.ones((2, 3))
188
+ res = herme.hermegrid2d(z, z, self.c2d)
189
+ assert_(res.shape == (2, 3)*2)
190
+
191
+ def test_hermegrid3d(self):
192
+ x1, x2, x3 = self.x
193
+ y1, y2, y3 = self.y
194
+
195
+ #test values
196
+ tgt = np.einsum('i,j,k->ijk', y1, y2, y3)
197
+ res = herme.hermegrid3d(x1, x2, x3, self.c3d)
198
+ assert_almost_equal(res, tgt)
199
+
200
+ #test shape
201
+ z = np.ones((2, 3))
202
+ res = herme.hermegrid3d(z, z, z, self.c3d)
203
+ assert_(res.shape == (2, 3)*3)
204
+
205
+
206
+ class TestIntegral:
207
+
208
+ def test_hermeint(self):
209
+ # check exceptions
210
+ assert_raises(TypeError, herme.hermeint, [0], .5)
211
+ assert_raises(ValueError, herme.hermeint, [0], -1)
212
+ assert_raises(ValueError, herme.hermeint, [0], 1, [0, 0])
213
+ assert_raises(ValueError, herme.hermeint, [0], lbnd=[0])
214
+ assert_raises(ValueError, herme.hermeint, [0], scl=[0])
215
+ assert_raises(TypeError, herme.hermeint, [0], axis=.5)
216
+
217
+ # test integration of zero polynomial
218
+ for i in range(2, 5):
219
+ k = [0]*(i - 2) + [1]
220
+ res = herme.hermeint([0], m=i, k=k)
221
+ assert_almost_equal(res, [0, 1])
222
+
223
+ # check single integration with integration constant
224
+ for i in range(5):
225
+ scl = i + 1
226
+ pol = [0]*i + [1]
227
+ tgt = [i] + [0]*i + [1/scl]
228
+ hermepol = herme.poly2herme(pol)
229
+ hermeint = herme.hermeint(hermepol, m=1, k=[i])
230
+ res = herme.herme2poly(hermeint)
231
+ assert_almost_equal(trim(res), trim(tgt))
232
+
233
+ # check single integration with integration constant and lbnd
234
+ for i in range(5):
235
+ scl = i + 1
236
+ pol = [0]*i + [1]
237
+ hermepol = herme.poly2herme(pol)
238
+ hermeint = herme.hermeint(hermepol, m=1, k=[i], lbnd=-1)
239
+ assert_almost_equal(herme.hermeval(-1, hermeint), i)
240
+
241
+ # check single integration with integration constant and scaling
242
+ for i in range(5):
243
+ scl = i + 1
244
+ pol = [0]*i + [1]
245
+ tgt = [i] + [0]*i + [2/scl]
246
+ hermepol = herme.poly2herme(pol)
247
+ hermeint = herme.hermeint(hermepol, m=1, k=[i], scl=2)
248
+ res = herme.herme2poly(hermeint)
249
+ assert_almost_equal(trim(res), trim(tgt))
250
+
251
+ # check multiple integrations with default k
252
+ for i in range(5):
253
+ for j in range(2, 5):
254
+ pol = [0]*i + [1]
255
+ tgt = pol[:]
256
+ for k in range(j):
257
+ tgt = herme.hermeint(tgt, m=1)
258
+ res = herme.hermeint(pol, m=j)
259
+ assert_almost_equal(trim(res), trim(tgt))
260
+
261
+ # check multiple integrations with defined k
262
+ for i in range(5):
263
+ for j in range(2, 5):
264
+ pol = [0]*i + [1]
265
+ tgt = pol[:]
266
+ for k in range(j):
267
+ tgt = herme.hermeint(tgt, m=1, k=[k])
268
+ res = herme.hermeint(pol, m=j, k=list(range(j)))
269
+ assert_almost_equal(trim(res), trim(tgt))
270
+
271
+ # check multiple integrations with lbnd
272
+ for i in range(5):
273
+ for j in range(2, 5):
274
+ pol = [0]*i + [1]
275
+ tgt = pol[:]
276
+ for k in range(j):
277
+ tgt = herme.hermeint(tgt, m=1, k=[k], lbnd=-1)
278
+ res = herme.hermeint(pol, m=j, k=list(range(j)), lbnd=-1)
279
+ assert_almost_equal(trim(res), trim(tgt))
280
+
281
+ # check multiple integrations with scaling
282
+ for i in range(5):
283
+ for j in range(2, 5):
284
+ pol = [0]*i + [1]
285
+ tgt = pol[:]
286
+ for k in range(j):
287
+ tgt = herme.hermeint(tgt, m=1, k=[k], scl=2)
288
+ res = herme.hermeint(pol, m=j, k=list(range(j)), scl=2)
289
+ assert_almost_equal(trim(res), trim(tgt))
290
+
291
+ def test_hermeint_axis(self):
292
+ # check that axis keyword works
293
+ c2d = np.random.random((3, 4))
294
+
295
+ tgt = np.vstack([herme.hermeint(c) for c in c2d.T]).T
296
+ res = herme.hermeint(c2d, axis=0)
297
+ assert_almost_equal(res, tgt)
298
+
299
+ tgt = np.vstack([herme.hermeint(c) for c in c2d])
300
+ res = herme.hermeint(c2d, axis=1)
301
+ assert_almost_equal(res, tgt)
302
+
303
+ tgt = np.vstack([herme.hermeint(c, k=3) for c in c2d])
304
+ res = herme.hermeint(c2d, k=3, axis=1)
305
+ assert_almost_equal(res, tgt)
306
+
307
+
308
+ class TestDerivative:
309
+
310
+ def test_hermeder(self):
311
+ # check exceptions
312
+ assert_raises(TypeError, herme.hermeder, [0], .5)
313
+ assert_raises(ValueError, herme.hermeder, [0], -1)
314
+
315
+ # check that zeroth derivative does nothing
316
+ for i in range(5):
317
+ tgt = [0]*i + [1]
318
+ res = herme.hermeder(tgt, m=0)
319
+ assert_equal(trim(res), trim(tgt))
320
+
321
+ # check that derivation is the inverse of integration
322
+ for i in range(5):
323
+ for j in range(2, 5):
324
+ tgt = [0]*i + [1]
325
+ res = herme.hermeder(herme.hermeint(tgt, m=j), m=j)
326
+ assert_almost_equal(trim(res), trim(tgt))
327
+
328
+ # check derivation with scaling
329
+ for i in range(5):
330
+ for j in range(2, 5):
331
+ tgt = [0]*i + [1]
332
+ res = herme.hermeder(
333
+ herme.hermeint(tgt, m=j, scl=2), m=j, scl=.5)
334
+ assert_almost_equal(trim(res), trim(tgt))
335
+
336
+ def test_hermeder_axis(self):
337
+ # check that axis keyword works
338
+ c2d = np.random.random((3, 4))
339
+
340
+ tgt = np.vstack([herme.hermeder(c) for c in c2d.T]).T
341
+ res = herme.hermeder(c2d, axis=0)
342
+ assert_almost_equal(res, tgt)
343
+
344
+ tgt = np.vstack([herme.hermeder(c) for c in c2d])
345
+ res = herme.hermeder(c2d, axis=1)
346
+ assert_almost_equal(res, tgt)
347
+
348
+
349
+ class TestVander:
350
+ # some random values in [-1, 1)
351
+ x = np.random.random((3, 5))*2 - 1
352
+
353
+ def test_hermevander(self):
354
+ # check for 1d x
355
+ x = np.arange(3)
356
+ v = herme.hermevander(x, 3)
357
+ assert_(v.shape == (3, 4))
358
+ for i in range(4):
359
+ coef = [0]*i + [1]
360
+ assert_almost_equal(v[..., i], herme.hermeval(x, coef))
361
+
362
+ # check for 2d x
363
+ x = np.array([[1, 2], [3, 4], [5, 6]])
364
+ v = herme.hermevander(x, 3)
365
+ assert_(v.shape == (3, 2, 4))
366
+ for i in range(4):
367
+ coef = [0]*i + [1]
368
+ assert_almost_equal(v[..., i], herme.hermeval(x, coef))
369
+
370
+ def test_hermevander2d(self):
371
+ # also tests hermeval2d for non-square coefficient array
372
+ x1, x2, x3 = self.x
373
+ c = np.random.random((2, 3))
374
+ van = herme.hermevander2d(x1, x2, [1, 2])
375
+ tgt = herme.hermeval2d(x1, x2, c)
376
+ res = np.dot(van, c.flat)
377
+ assert_almost_equal(res, tgt)
378
+
379
+ # check shape
380
+ van = herme.hermevander2d([x1], [x2], [1, 2])
381
+ assert_(van.shape == (1, 5, 6))
382
+
383
+ def test_hermevander3d(self):
384
+ # also tests hermeval3d for non-square coefficient array
385
+ x1, x2, x3 = self.x
386
+ c = np.random.random((2, 3, 4))
387
+ van = herme.hermevander3d(x1, x2, x3, [1, 2, 3])
388
+ tgt = herme.hermeval3d(x1, x2, x3, c)
389
+ res = np.dot(van, c.flat)
390
+ assert_almost_equal(res, tgt)
391
+
392
+ # check shape
393
+ van = herme.hermevander3d([x1], [x2], [x3], [1, 2, 3])
394
+ assert_(van.shape == (1, 5, 24))
395
+
396
+
397
+ class TestFitting:
398
+
399
+ def test_hermefit(self):
400
+ def f(x):
401
+ return x*(x - 1)*(x - 2)
402
+
403
+ def f2(x):
404
+ return x**4 + x**2 + 1
405
+
406
+ # Test exceptions
407
+ assert_raises(ValueError, herme.hermefit, [1], [1], -1)
408
+ assert_raises(TypeError, herme.hermefit, [[1]], [1], 0)
409
+ assert_raises(TypeError, herme.hermefit, [], [1], 0)
410
+ assert_raises(TypeError, herme.hermefit, [1], [[[1]]], 0)
411
+ assert_raises(TypeError, herme.hermefit, [1, 2], [1], 0)
412
+ assert_raises(TypeError, herme.hermefit, [1], [1, 2], 0)
413
+ assert_raises(TypeError, herme.hermefit, [1], [1], 0, w=[[1]])
414
+ assert_raises(TypeError, herme.hermefit, [1], [1], 0, w=[1, 1])
415
+ assert_raises(ValueError, herme.hermefit, [1], [1], [-1,])
416
+ assert_raises(ValueError, herme.hermefit, [1], [1], [2, -1, 6])
417
+ assert_raises(TypeError, herme.hermefit, [1], [1], [])
418
+
419
+ # Test fit
420
+ x = np.linspace(0, 2)
421
+ y = f(x)
422
+ #
423
+ coef3 = herme.hermefit(x, y, 3)
424
+ assert_equal(len(coef3), 4)
425
+ assert_almost_equal(herme.hermeval(x, coef3), y)
426
+ coef3 = herme.hermefit(x, y, [0, 1, 2, 3])
427
+ assert_equal(len(coef3), 4)
428
+ assert_almost_equal(herme.hermeval(x, coef3), y)
429
+ #
430
+ coef4 = herme.hermefit(x, y, 4)
431
+ assert_equal(len(coef4), 5)
432
+ assert_almost_equal(herme.hermeval(x, coef4), y)
433
+ coef4 = herme.hermefit(x, y, [0, 1, 2, 3, 4])
434
+ assert_equal(len(coef4), 5)
435
+ assert_almost_equal(herme.hermeval(x, coef4), y)
436
+ # check things still work if deg is not in strict increasing
437
+ coef4 = herme.hermefit(x, y, [2, 3, 4, 1, 0])
438
+ assert_equal(len(coef4), 5)
439
+ assert_almost_equal(herme.hermeval(x, coef4), y)
440
+ #
441
+ coef2d = herme.hermefit(x, np.array([y, y]).T, 3)
442
+ assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
443
+ coef2d = herme.hermefit(x, np.array([y, y]).T, [0, 1, 2, 3])
444
+ assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
445
+ # test weighting
446
+ w = np.zeros_like(x)
447
+ yw = y.copy()
448
+ w[1::2] = 1
449
+ y[0::2] = 0
450
+ wcoef3 = herme.hermefit(x, yw, 3, w=w)
451
+ assert_almost_equal(wcoef3, coef3)
452
+ wcoef3 = herme.hermefit(x, yw, [0, 1, 2, 3], w=w)
453
+ assert_almost_equal(wcoef3, coef3)
454
+ #
455
+ wcoef2d = herme.hermefit(x, np.array([yw, yw]).T, 3, w=w)
456
+ assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
457
+ wcoef2d = herme.hermefit(x, np.array([yw, yw]).T, [0, 1, 2, 3], w=w)
458
+ assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
459
+ # test scaling with complex values x points whose square
460
+ # is zero when summed.
461
+ x = [1, 1j, -1, -1j]
462
+ assert_almost_equal(herme.hermefit(x, x, 1), [0, 1])
463
+ assert_almost_equal(herme.hermefit(x, x, [0, 1]), [0, 1])
464
+ # test fitting only even Legendre polynomials
465
+ x = np.linspace(-1, 1)
466
+ y = f2(x)
467
+ coef1 = herme.hermefit(x, y, 4)
468
+ assert_almost_equal(herme.hermeval(x, coef1), y)
469
+ coef2 = herme.hermefit(x, y, [0, 2, 4])
470
+ assert_almost_equal(herme.hermeval(x, coef2), y)
471
+ assert_almost_equal(coef1, coef2)
472
+
473
+
474
+ class TestCompanion:
475
+
476
+ def test_raises(self):
477
+ assert_raises(ValueError, herme.hermecompanion, [])
478
+ assert_raises(ValueError, herme.hermecompanion, [1])
479
+
480
+ def test_dimensions(self):
481
+ for i in range(1, 5):
482
+ coef = [0]*i + [1]
483
+ assert_(herme.hermecompanion(coef).shape == (i, i))
484
+
485
+ def test_linear_root(self):
486
+ assert_(herme.hermecompanion([1, 2])[0, 0] == -.5)
487
+
488
+
489
+ class TestGauss:
490
+
491
+ def test_100(self):
492
+ x, w = herme.hermegauss(100)
493
+
494
+ # test orthogonality. Note that the results need to be normalized,
495
+ # otherwise the huge values that can arise from fast growing
496
+ # functions like Laguerre can be very confusing.
497
+ v = herme.hermevander(x, 99)
498
+ vv = np.dot(v.T * w, v)
499
+ vd = 1/np.sqrt(vv.diagonal())
500
+ vv = vd[:, None] * vv * vd
501
+ assert_almost_equal(vv, np.eye(100))
502
+
503
+ # check that the integral of 1 is correct
504
+ tgt = np.sqrt(2*np.pi)
505
+ assert_almost_equal(w.sum(), tgt)
506
+
507
+
508
+ class TestMisc:
509
+
510
+ def test_hermefromroots(self):
511
+ res = herme.hermefromroots([])
512
+ assert_almost_equal(trim(res), [1])
513
+ for i in range(1, 5):
514
+ roots = np.cos(np.linspace(-np.pi, 0, 2*i + 1)[1::2])
515
+ pol = herme.hermefromroots(roots)
516
+ res = herme.hermeval(roots, pol)
517
+ tgt = 0
518
+ assert_(len(pol) == i + 1)
519
+ assert_almost_equal(herme.herme2poly(pol)[-1], 1)
520
+ assert_almost_equal(res, tgt)
521
+
522
+ def test_hermeroots(self):
523
+ assert_almost_equal(herme.hermeroots([1]), [])
524
+ assert_almost_equal(herme.hermeroots([1, 1]), [-1])
525
+ for i in range(2, 5):
526
+ tgt = np.linspace(-1, 1, i)
527
+ res = herme.hermeroots(herme.hermefromroots(tgt))
528
+ assert_almost_equal(trim(res), trim(tgt))
529
+
530
+ def test_hermetrim(self):
531
+ coef = [2, -1, 1, 0]
532
+
533
+ # Test exceptions
534
+ assert_raises(ValueError, herme.hermetrim, coef, -1)
535
+
536
+ # Test results
537
+ assert_equal(herme.hermetrim(coef), coef[:-1])
538
+ assert_equal(herme.hermetrim(coef, 1), coef[:-3])
539
+ assert_equal(herme.hermetrim(coef, 2), [0])
540
+
541
+ def test_hermeline(self):
542
+ assert_equal(herme.hermeline(3, 4), [3, 4])
543
+
544
+ def test_herme2poly(self):
545
+ for i in range(10):
546
+ assert_almost_equal(herme.herme2poly([0]*i + [1]), Helist[i])
547
+
548
+ def test_poly2herme(self):
549
+ for i in range(10):
550
+ assert_almost_equal(herme.poly2herme(Helist[i]), [0]*i + [1])
551
+
552
+ def test_weight(self):
553
+ x = np.linspace(-5, 5, 11)
554
+ tgt = np.exp(-.5*x**2)
555
+ res = herme.hermeweight(x)
556
+ assert_almost_equal(res, tgt)
.venv/lib/python3.11/site-packages/numpy/polynomial/tests/test_printing.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import nan, inf
2
+ import pytest
3
+ from numpy.core import array, arange, printoptions
4
+ import numpy.polynomial as poly
5
+ from numpy.testing import assert_equal, assert_
6
+
7
+ # For testing polynomial printing with object arrays
8
+ from fractions import Fraction
9
+ from decimal import Decimal
10
+
11
+
12
+ class TestStrUnicodeSuperSubscripts:
13
+
14
+ @pytest.fixture(scope='class', autouse=True)
15
+ def use_unicode(self):
16
+ poly.set_default_printstyle('unicode')
17
+
18
+ @pytest.mark.parametrize(('inp', 'tgt'), (
19
+ ([1, 2, 3], "1.0 + 2.0·x + 3.0·x²"),
20
+ ([-1, 0, 3, -1], "-1.0 + 0.0·x + 3.0·x² - 1.0·x³"),
21
+ (arange(12), ("0.0 + 1.0·x + 2.0·x² + 3.0·x³ + 4.0·x⁴ + 5.0·x⁵ + "
22
+ "6.0·x⁶ + 7.0·x⁷ +\n8.0·x⁸ + 9.0·x⁹ + 10.0·x¹⁰ + "
23
+ "11.0·x¹¹")),
24
+ ))
25
+ def test_polynomial_str(self, inp, tgt):
26
+ res = str(poly.Polynomial(inp))
27
+ assert_equal(res, tgt)
28
+
29
+ @pytest.mark.parametrize(('inp', 'tgt'), (
30
+ ([1, 2, 3], "1.0 + 2.0·T₁(x) + 3.0·T₂(x)"),
31
+ ([-1, 0, 3, -1], "-1.0 + 0.0·T₁(x) + 3.0·T₂(x) - 1.0·T₃(x)"),
32
+ (arange(12), ("0.0 + 1.0·T₁(x) + 2.0·T₂(x) + 3.0·T₃(x) + 4.0·T₄(x) + "
33
+ "5.0·T₅(x) +\n6.0·T₆(x) + 7.0·T₇(x) + 8.0·T₈(x) + "
34
+ "9.0·T₉(x) + 10.0·T₁₀(x) + 11.0·T₁₁(x)")),
35
+ ))
36
+ def test_chebyshev_str(self, inp, tgt):
37
+ res = str(poly.Chebyshev(inp))
38
+ assert_equal(res, tgt)
39
+
40
+ @pytest.mark.parametrize(('inp', 'tgt'), (
41
+ ([1, 2, 3], "1.0 + 2.0·P₁(x) + 3.0·P₂(x)"),
42
+ ([-1, 0, 3, -1], "-1.0 + 0.0·P₁(x) + 3.0·P₂(x) - 1.0·P₃(x)"),
43
+ (arange(12), ("0.0 + 1.0·P₁(x) + 2.0·P₂(x) + 3.0·P₃(x) + 4.0·P₄(x) + "
44
+ "5.0·P₅(x) +\n6.0·P₆(x) + 7.0·P₇(x) + 8.0·P₈(x) + "
45
+ "9.0·P₉(x) + 10.0·P₁₀(x) + 11.0·P₁₁(x)")),
46
+ ))
47
+ def test_legendre_str(self, inp, tgt):
48
+ res = str(poly.Legendre(inp))
49
+ assert_equal(res, tgt)
50
+
51
+ @pytest.mark.parametrize(('inp', 'tgt'), (
52
+ ([1, 2, 3], "1.0 + 2.0·H₁(x) + 3.0·H₂(x)"),
53
+ ([-1, 0, 3, -1], "-1.0 + 0.0·H₁(x) + 3.0·H₂(x) - 1.0·H₃(x)"),
54
+ (arange(12), ("0.0 + 1.0·H₁(x) + 2.0·H₂(x) + 3.0·H₃(x) + 4.0·H₄(x) + "
55
+ "5.0·H₅(x) +\n6.0·H₆(x) + 7.0·H₇(x) + 8.0·H₈(x) + "
56
+ "9.0·H₉(x) + 10.0·H₁₀(x) + 11.0·H₁₁(x)")),
57
+ ))
58
+ def test_hermite_str(self, inp, tgt):
59
+ res = str(poly.Hermite(inp))
60
+ assert_equal(res, tgt)
61
+
62
+ @pytest.mark.parametrize(('inp', 'tgt'), (
63
+ ([1, 2, 3], "1.0 + 2.0·He₁(x) + 3.0·He₂(x)"),
64
+ ([-1, 0, 3, -1], "-1.0 + 0.0·He₁(x) + 3.0·He₂(x) - 1.0·He₃(x)"),
65
+ (arange(12), ("0.0 + 1.0·He₁(x) + 2.0·He₂(x) + 3.0·He₃(x) + "
66
+ "4.0·He₄(x) + 5.0·He₅(x) +\n6.0·He₆(x) + 7.0·He₇(x) + "
67
+ "8.0·He₈(x) + 9.0·He₉(x) + 10.0·He₁₀(x) +\n"
68
+ "11.0·He₁₁(x)")),
69
+ ))
70
+ def test_hermiteE_str(self, inp, tgt):
71
+ res = str(poly.HermiteE(inp))
72
+ assert_equal(res, tgt)
73
+
74
+ @pytest.mark.parametrize(('inp', 'tgt'), (
75
+ ([1, 2, 3], "1.0 + 2.0·L₁(x) + 3.0·L₂(x)"),
76
+ ([-1, 0, 3, -1], "-1.0 + 0.0·L₁(x) + 3.0·L₂(x) - 1.0·L₃(x)"),
77
+ (arange(12), ("0.0 + 1.0·L₁(x) + 2.0·L₂(x) + 3.0·L₃(x) + 4.0·L₄(x) + "
78
+ "5.0·L₅(x) +\n6.0·L₆(x) + 7.0·L₇(x) + 8.0·L₈(x) + "
79
+ "9.0·L₉(x) + 10.0·L₁₀(x) + 11.0·L₁₁(x)")),
80
+ ))
81
+ def test_laguerre_str(self, inp, tgt):
82
+ res = str(poly.Laguerre(inp))
83
+ assert_equal(res, tgt)
84
+
85
+
86
+ class TestStrAscii:
87
+
88
+ @pytest.fixture(scope='class', autouse=True)
89
+ def use_ascii(self):
90
+ poly.set_default_printstyle('ascii')
91
+
92
+ @pytest.mark.parametrize(('inp', 'tgt'), (
93
+ ([1, 2, 3], "1.0 + 2.0 x + 3.0 x**2"),
94
+ ([-1, 0, 3, -1], "-1.0 + 0.0 x + 3.0 x**2 - 1.0 x**3"),
95
+ (arange(12), ("0.0 + 1.0 x + 2.0 x**2 + 3.0 x**3 + 4.0 x**4 + "
96
+ "5.0 x**5 + 6.0 x**6 +\n7.0 x**7 + 8.0 x**8 + "
97
+ "9.0 x**9 + 10.0 x**10 + 11.0 x**11")),
98
+ ))
99
+ def test_polynomial_str(self, inp, tgt):
100
+ res = str(poly.Polynomial(inp))
101
+ assert_equal(res, tgt)
102
+
103
+ @pytest.mark.parametrize(('inp', 'tgt'), (
104
+ ([1, 2, 3], "1.0 + 2.0 T_1(x) + 3.0 T_2(x)"),
105
+ ([-1, 0, 3, -1], "-1.0 + 0.0 T_1(x) + 3.0 T_2(x) - 1.0 T_3(x)"),
106
+ (arange(12), ("0.0 + 1.0 T_1(x) + 2.0 T_2(x) + 3.0 T_3(x) + "
107
+ "4.0 T_4(x) + 5.0 T_5(x) +\n6.0 T_6(x) + 7.0 T_7(x) + "
108
+ "8.0 T_8(x) + 9.0 T_9(x) + 10.0 T_10(x) +\n"
109
+ "11.0 T_11(x)")),
110
+ ))
111
+ def test_chebyshev_str(self, inp, tgt):
112
+ res = str(poly.Chebyshev(inp))
113
+ assert_equal(res, tgt)
114
+
115
+ @pytest.mark.parametrize(('inp', 'tgt'), (
116
+ ([1, 2, 3], "1.0 + 2.0 P_1(x) + 3.0 P_2(x)"),
117
+ ([-1, 0, 3, -1], "-1.0 + 0.0 P_1(x) + 3.0 P_2(x) - 1.0 P_3(x)"),
118
+ (arange(12), ("0.0 + 1.0 P_1(x) + 2.0 P_2(x) + 3.0 P_3(x) + "
119
+ "4.0 P_4(x) + 5.0 P_5(x) +\n6.0 P_6(x) + 7.0 P_7(x) + "
120
+ "8.0 P_8(x) + 9.0 P_9(x) + 10.0 P_10(x) +\n"
121
+ "11.0 P_11(x)")),
122
+ ))
123
+ def test_legendre_str(self, inp, tgt):
124
+ res = str(poly.Legendre(inp))
125
+ assert_equal(res, tgt)
126
+
127
+ @pytest.mark.parametrize(('inp', 'tgt'), (
128
+ ([1, 2, 3], "1.0 + 2.0 H_1(x) + 3.0 H_2(x)"),
129
+ ([-1, 0, 3, -1], "-1.0 + 0.0 H_1(x) + 3.0 H_2(x) - 1.0 H_3(x)"),
130
+ (arange(12), ("0.0 + 1.0 H_1(x) + 2.0 H_2(x) + 3.0 H_3(x) + "
131
+ "4.0 H_4(x) + 5.0 H_5(x) +\n6.0 H_6(x) + 7.0 H_7(x) + "
132
+ "8.0 H_8(x) + 9.0 H_9(x) + 10.0 H_10(x) +\n"
133
+ "11.0 H_11(x)")),
134
+ ))
135
+ def test_hermite_str(self, inp, tgt):
136
+ res = str(poly.Hermite(inp))
137
+ assert_equal(res, tgt)
138
+
139
+ @pytest.mark.parametrize(('inp', 'tgt'), (
140
+ ([1, 2, 3], "1.0 + 2.0 He_1(x) + 3.0 He_2(x)"),
141
+ ([-1, 0, 3, -1], "-1.0 + 0.0 He_1(x) + 3.0 He_2(x) - 1.0 He_3(x)"),
142
+ (arange(12), ("0.0 + 1.0 He_1(x) + 2.0 He_2(x) + 3.0 He_3(x) + "
143
+ "4.0 He_4(x) +\n5.0 He_5(x) + 6.0 He_6(x) + "
144
+ "7.0 He_7(x) + 8.0 He_8(x) + 9.0 He_9(x) +\n"
145
+ "10.0 He_10(x) + 11.0 He_11(x)")),
146
+ ))
147
+ def test_hermiteE_str(self, inp, tgt):
148
+ res = str(poly.HermiteE(inp))
149
+ assert_equal(res, tgt)
150
+
151
+ @pytest.mark.parametrize(('inp', 'tgt'), (
152
+ ([1, 2, 3], "1.0 + 2.0 L_1(x) + 3.0 L_2(x)"),
153
+ ([-1, 0, 3, -1], "-1.0 + 0.0 L_1(x) + 3.0 L_2(x) - 1.0 L_3(x)"),
154
+ (arange(12), ("0.0 + 1.0 L_1(x) + 2.0 L_2(x) + 3.0 L_3(x) + "
155
+ "4.0 L_4(x) + 5.0 L_5(x) +\n6.0 L_6(x) + 7.0 L_7(x) + "
156
+ "8.0 L_8(x) + 9.0 L_9(x) + 10.0 L_10(x) +\n"
157
+ "11.0 L_11(x)")),
158
+ ))
159
+ def test_laguerre_str(self, inp, tgt):
160
+ res = str(poly.Laguerre(inp))
161
+ assert_equal(res, tgt)
162
+
163
+
164
+ class TestLinebreaking:
165
+
166
+ @pytest.fixture(scope='class', autouse=True)
167
+ def use_ascii(self):
168
+ poly.set_default_printstyle('ascii')
169
+
170
+ def test_single_line_one_less(self):
171
+ # With 'ascii' style, len(str(p)) is default linewidth - 1 (i.e. 74)
172
+ p = poly.Polynomial([12345678, 12345678, 12345678, 12345678, 123])
173
+ assert_equal(len(str(p)), 74)
174
+ assert_equal(str(p), (
175
+ '12345678.0 + 12345678.0 x + 12345678.0 x**2 + '
176
+ '12345678.0 x**3 + 123.0 x**4'
177
+ ))
178
+
179
+ def test_num_chars_is_linewidth(self):
180
+ # len(str(p)) == default linewidth == 75
181
+ p = poly.Polynomial([12345678, 12345678, 12345678, 12345678, 1234])
182
+ assert_equal(len(str(p)), 75)
183
+ assert_equal(str(p), (
184
+ '12345678.0 + 12345678.0 x + 12345678.0 x**2 + '
185
+ '12345678.0 x**3 +\n1234.0 x**4'
186
+ ))
187
+
188
+ def test_first_linebreak_multiline_one_less_than_linewidth(self):
189
+ # Multiline str where len(first_line) + len(next_term) == lw - 1 == 74
190
+ p = poly.Polynomial(
191
+ [12345678, 12345678, 12345678, 12345678, 1, 12345678]
192
+ )
193
+ assert_equal(len(str(p).split('\n')[0]), 74)
194
+ assert_equal(str(p), (
195
+ '12345678.0 + 12345678.0 x + 12345678.0 x**2 + '
196
+ '12345678.0 x**3 + 1.0 x**4 +\n12345678.0 x**5'
197
+ ))
198
+
199
+ def test_first_linebreak_multiline_on_linewidth(self):
200
+ # First line is one character longer than previous test
201
+ p = poly.Polynomial(
202
+ [12345678, 12345678, 12345678, 12345678.12, 1, 12345678]
203
+ )
204
+ assert_equal(str(p), (
205
+ '12345678.0 + 12345678.0 x + 12345678.0 x**2 + '
206
+ '12345678.12 x**3 +\n1.0 x**4 + 12345678.0 x**5'
207
+ ))
208
+
209
+ @pytest.mark.parametrize(('lw', 'tgt'), (
210
+ (75, ('0.0 + 10.0 x + 200.0 x**2 + 3000.0 x**3 + 40000.0 x**4 + '
211
+ '500000.0 x**5 +\n600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 + '
212
+ '900.0 x**9')),
213
+ (45, ('0.0 + 10.0 x + 200.0 x**2 + 3000.0 x**3 +\n40000.0 x**4 + '
214
+ '500000.0 x**5 +\n600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 +\n'
215
+ '900.0 x**9')),
216
+ (132, ('0.0 + 10.0 x + 200.0 x**2 + 3000.0 x**3 + 40000.0 x**4 + '
217
+ '500000.0 x**5 + 600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 + '
218
+ '900.0 x**9')),
219
+ ))
220
+ def test_linewidth_printoption(self, lw, tgt):
221
+ p = poly.Polynomial(
222
+ [0, 10, 200, 3000, 40000, 500000, 600000, 70000, 8000, 900]
223
+ )
224
+ with printoptions(linewidth=lw):
225
+ assert_equal(str(p), tgt)
226
+ for line in str(p).split('\n'):
227
+ assert_(len(line) < lw)
228
+
229
+
230
+ def test_set_default_printoptions():
231
+ p = poly.Polynomial([1, 2, 3])
232
+ c = poly.Chebyshev([1, 2, 3])
233
+ poly.set_default_printstyle('ascii')
234
+ assert_equal(str(p), "1.0 + 2.0 x + 3.0 x**2")
235
+ assert_equal(str(c), "1.0 + 2.0 T_1(x) + 3.0 T_2(x)")
236
+ poly.set_default_printstyle('unicode')
237
+ assert_equal(str(p), "1.0 + 2.0·x + 3.0·x²")
238
+ assert_equal(str(c), "1.0 + 2.0·T₁(x) + 3.0·T₂(x)")
239
+ with pytest.raises(ValueError):
240
+ poly.set_default_printstyle('invalid_input')
241
+
242
+
243
+ def test_complex_coefficients():
244
+ """Test both numpy and built-in complex."""
245
+ coefs = [0+1j, 1+1j, -2+2j, 3+0j]
246
+ # numpy complex
247
+ p1 = poly.Polynomial(coefs)
248
+ # Python complex
249
+ p2 = poly.Polynomial(array(coefs, dtype=object))
250
+ poly.set_default_printstyle('unicode')
251
+ assert_equal(str(p1), "1j + (1+1j)·x - (2-2j)·x² + (3+0j)·x³")
252
+ assert_equal(str(p2), "1j + (1+1j)·x + (-2+2j)·x² + (3+0j)·x³")
253
+ poly.set_default_printstyle('ascii')
254
+ assert_equal(str(p1), "1j + (1+1j) x - (2-2j) x**2 + (3+0j) x**3")
255
+ assert_equal(str(p2), "1j + (1+1j) x + (-2+2j) x**2 + (3+0j) x**3")
256
+
257
+
258
+ @pytest.mark.parametrize(('coefs', 'tgt'), (
259
+ (array([Fraction(1, 2), Fraction(3, 4)], dtype=object), (
260
+ "1/2 + 3/4·x"
261
+ )),
262
+ (array([1, 2, Fraction(5, 7)], dtype=object), (
263
+ "1 + 2·x + 5/7·x²"
264
+ )),
265
+ (array([Decimal('1.00'), Decimal('2.2'), 3], dtype=object), (
266
+ "1.00 + 2.2·x + 3·x²"
267
+ )),
268
+ ))
269
+ def test_numeric_object_coefficients(coefs, tgt):
270
+ p = poly.Polynomial(coefs)
271
+ poly.set_default_printstyle('unicode')
272
+ assert_equal(str(p), tgt)
273
+
274
+
275
+ @pytest.mark.parametrize(('coefs', 'tgt'), (
276
+ (array([1, 2, 'f'], dtype=object), '1 + 2·x + f·x²'),
277
+ (array([1, 2, [3, 4]], dtype=object), '1 + 2·x + [3, 4]·x²'),
278
+ ))
279
+ def test_nonnumeric_object_coefficients(coefs, tgt):
280
+ """
281
+ Test coef fallback for object arrays of non-numeric coefficients.
282
+ """
283
+ p = poly.Polynomial(coefs)
284
+ poly.set_default_printstyle('unicode')
285
+ assert_equal(str(p), tgt)
286
+
287
+
288
+ class TestFormat:
289
+ def test_format_unicode(self):
290
+ poly.set_default_printstyle('ascii')
291
+ p = poly.Polynomial([1, 2, 0, -1])
292
+ assert_equal(format(p, 'unicode'), "1.0 + 2.0·x + 0.0·x² - 1.0·x³")
293
+
294
+ def test_format_ascii(self):
295
+ poly.set_default_printstyle('unicode')
296
+ p = poly.Polynomial([1, 2, 0, -1])
297
+ assert_equal(
298
+ format(p, 'ascii'), "1.0 + 2.0 x + 0.0 x**2 - 1.0 x**3"
299
+ )
300
+
301
+ def test_empty_formatstr(self):
302
+ poly.set_default_printstyle('ascii')
303
+ p = poly.Polynomial([1, 2, 3])
304
+ assert_equal(format(p), "1.0 + 2.0 x + 3.0 x**2")
305
+ assert_equal(f"{p}", "1.0 + 2.0 x + 3.0 x**2")
306
+
307
+ def test_bad_formatstr(self):
308
+ p = poly.Polynomial([1, 2, 0, -1])
309
+ with pytest.raises(ValueError):
310
+ format(p, '.2f')
311
+
312
+
313
+ @pytest.mark.parametrize(('poly', 'tgt'), (
314
+ (poly.Polynomial, '1.0 + 2.0·z + 3.0·z²'),
315
+ (poly.Chebyshev, '1.0 + 2.0·T₁(z) + 3.0·T₂(z)'),
316
+ (poly.Hermite, '1.0 + 2.0·H₁(z) + 3.0·H₂(z)'),
317
+ (poly.HermiteE, '1.0 + 2.0·He₁(z) + 3.0·He₂(z)'),
318
+ (poly.Laguerre, '1.0 + 2.0·L₁(z) + 3.0·L₂(z)'),
319
+ (poly.Legendre, '1.0 + 2.0·P₁(z) + 3.0·P₂(z)'),
320
+ ))
321
+ def test_symbol(poly, tgt):
322
+ p = poly([1, 2, 3], symbol='z')
323
+ assert_equal(f"{p:unicode}", tgt)
324
+
325
+
326
+ class TestRepr:
327
+ def test_polynomial_str(self):
328
+ res = repr(poly.Polynomial([0, 1]))
329
+ tgt = (
330
+ "Polynomial([0., 1.], domain=[-1, 1], window=[-1, 1], "
331
+ "symbol='x')"
332
+ )
333
+ assert_equal(res, tgt)
334
+
335
+ def test_chebyshev_str(self):
336
+ res = repr(poly.Chebyshev([0, 1]))
337
+ tgt = (
338
+ "Chebyshev([0., 1.], domain=[-1, 1], window=[-1, 1], "
339
+ "symbol='x')"
340
+ )
341
+ assert_equal(res, tgt)
342
+
343
+ def test_legendre_repr(self):
344
+ res = repr(poly.Legendre([0, 1]))
345
+ tgt = (
346
+ "Legendre([0., 1.], domain=[-1, 1], window=[-1, 1], "
347
+ "symbol='x')"
348
+ )
349
+ assert_equal(res, tgt)
350
+
351
+ def test_hermite_repr(self):
352
+ res = repr(poly.Hermite([0, 1]))
353
+ tgt = (
354
+ "Hermite([0., 1.], domain=[-1, 1], window=[-1, 1], "
355
+ "symbol='x')"
356
+ )
357
+ assert_equal(res, tgt)
358
+
359
+ def test_hermiteE_repr(self):
360
+ res = repr(poly.HermiteE([0, 1]))
361
+ tgt = (
362
+ "HermiteE([0., 1.], domain=[-1, 1], window=[-1, 1], "
363
+ "symbol='x')"
364
+ )
365
+ assert_equal(res, tgt)
366
+
367
+ def test_laguerre_repr(self):
368
+ res = repr(poly.Laguerre([0, 1]))
369
+ tgt = (
370
+ "Laguerre([0., 1.], domain=[0, 1], window=[0, 1], "
371
+ "symbol='x')"
372
+ )
373
+ assert_equal(res, tgt)
374
+
375
+
376
+ class TestLatexRepr:
377
+ """Test the latex repr used by Jupyter"""
378
+
379
+ def as_latex(self, obj):
380
+ # right now we ignore the formatting of scalars in our tests, since
381
+ # it makes them too verbose. Ideally, the formatting of scalars will
382
+ # be fixed such that tests below continue to pass
383
+ obj._repr_latex_scalar = lambda x, parens=False: str(x)
384
+ try:
385
+ return obj._repr_latex_()
386
+ finally:
387
+ del obj._repr_latex_scalar
388
+
389
+ def test_simple_polynomial(self):
390
+ # default input
391
+ p = poly.Polynomial([1, 2, 3])
392
+ assert_equal(self.as_latex(p),
393
+ r'$x \mapsto 1.0 + 2.0\,x + 3.0\,x^{2}$')
394
+
395
+ # translated input
396
+ p = poly.Polynomial([1, 2, 3], domain=[-2, 0])
397
+ assert_equal(self.as_latex(p),
398
+ r'$x \mapsto 1.0 + 2.0\,\left(1.0 + x\right) + 3.0\,\left(1.0 + x\right)^{2}$')
399
+
400
+ # scaled input
401
+ p = poly.Polynomial([1, 2, 3], domain=[-0.5, 0.5])
402
+ assert_equal(self.as_latex(p),
403
+ r'$x \mapsto 1.0 + 2.0\,\left(2.0x\right) + 3.0\,\left(2.0x\right)^{2}$')
404
+
405
+ # affine input
406
+ p = poly.Polynomial([1, 2, 3], domain=[-1, 0])
407
+ assert_equal(self.as_latex(p),
408
+ r'$x \mapsto 1.0 + 2.0\,\left(1.0 + 2.0x\right) + 3.0\,\left(1.0 + 2.0x\right)^{2}$')
409
+
410
+ def test_basis_func(self):
411
+ p = poly.Chebyshev([1, 2, 3])
412
+ assert_equal(self.as_latex(p),
413
+ r'$x \mapsto 1.0\,{T}_{0}(x) + 2.0\,{T}_{1}(x) + 3.0\,{T}_{2}(x)$')
414
+ # affine input - check no surplus parens are added
415
+ p = poly.Chebyshev([1, 2, 3], domain=[-1, 0])
416
+ assert_equal(self.as_latex(p),
417
+ r'$x \mapsto 1.0\,{T}_{0}(1.0 + 2.0x) + 2.0\,{T}_{1}(1.0 + 2.0x) + 3.0\,{T}_{2}(1.0 + 2.0x)$')
418
+
419
+ def test_multichar_basis_func(self):
420
+ p = poly.HermiteE([1, 2, 3])
421
+ assert_equal(self.as_latex(p),
422
+ r'$x \mapsto 1.0\,{He}_{0}(x) + 2.0\,{He}_{1}(x) + 3.0\,{He}_{2}(x)$')
423
+
424
+ def test_symbol_basic(self):
425
+ # default input
426
+ p = poly.Polynomial([1, 2, 3], symbol='z')
427
+ assert_equal(self.as_latex(p),
428
+ r'$z \mapsto 1.0 + 2.0\,z + 3.0\,z^{2}$')
429
+
430
+ # translated input
431
+ p = poly.Polynomial([1, 2, 3], domain=[-2, 0], symbol='z')
432
+ assert_equal(
433
+ self.as_latex(p),
434
+ (
435
+ r'$z \mapsto 1.0 + 2.0\,\left(1.0 + z\right) + 3.0\,'
436
+ r'\left(1.0 + z\right)^{2}$'
437
+ ),
438
+ )
439
+
440
+ # scaled input
441
+ p = poly.Polynomial([1, 2, 3], domain=[-0.5, 0.5], symbol='z')
442
+ assert_equal(
443
+ self.as_latex(p),
444
+ (
445
+ r'$z \mapsto 1.0 + 2.0\,\left(2.0z\right) + 3.0\,'
446
+ r'\left(2.0z\right)^{2}$'
447
+ ),
448
+ )
449
+
450
+ # affine input
451
+ p = poly.Polynomial([1, 2, 3], domain=[-1, 0], symbol='z')
452
+ assert_equal(
453
+ self.as_latex(p),
454
+ (
455
+ r'$z \mapsto 1.0 + 2.0\,\left(1.0 + 2.0z\right) + 3.0\,'
456
+ r'\left(1.0 + 2.0z\right)^{2}$'
457
+ ),
458
+ )
459
+
460
+
461
+ SWITCH_TO_EXP = (
462
+ '1.0 + (1.0e-01) x + (1.0e-02) x**2',
463
+ '1.2 + (1.2e-01) x + (1.2e-02) x**2',
464
+ '1.23 + 0.12 x + (1.23e-02) x**2 + (1.23e-03) x**3',
465
+ '1.235 + 0.123 x + (1.235e-02) x**2 + (1.235e-03) x**3',
466
+ '1.2346 + 0.1235 x + 0.0123 x**2 + (1.2346e-03) x**3 + (1.2346e-04) x**4',
467
+ '1.23457 + 0.12346 x + 0.01235 x**2 + (1.23457e-03) x**3 + '
468
+ '(1.23457e-04) x**4',
469
+ '1.234568 + 0.123457 x + 0.012346 x**2 + 0.001235 x**3 + '
470
+ '(1.234568e-04) x**4 + (1.234568e-05) x**5',
471
+ '1.2345679 + 0.1234568 x + 0.0123457 x**2 + 0.0012346 x**3 + '
472
+ '(1.2345679e-04) x**4 + (1.2345679e-05) x**5')
473
+
474
+ class TestPrintOptions:
475
+ """
476
+ Test the output is properly configured via printoptions.
477
+ The exponential notation is enabled automatically when the values
478
+ are too small or too large.
479
+ """
480
+
481
+ @pytest.fixture(scope='class', autouse=True)
482
+ def use_ascii(self):
483
+ poly.set_default_printstyle('ascii')
484
+
485
+ def test_str(self):
486
+ p = poly.Polynomial([1/2, 1/7, 1/7*10**8, 1/7*10**9])
487
+ assert_equal(str(p), '0.5 + 0.14285714 x + 14285714.28571429 x**2 '
488
+ '+ (1.42857143e+08) x**3')
489
+
490
+ with printoptions(precision=3):
491
+ assert_equal(str(p), '0.5 + 0.143 x + 14285714.286 x**2 '
492
+ '+ (1.429e+08) x**3')
493
+
494
+ def test_latex(self):
495
+ p = poly.Polynomial([1/2, 1/7, 1/7*10**8, 1/7*10**9])
496
+ assert_equal(p._repr_latex_(),
497
+ r'$x \mapsto \text{0.5} + \text{0.14285714}\,x + '
498
+ r'\text{14285714.28571429}\,x^{2} + '
499
+ r'\text{(1.42857143e+08)}\,x^{3}$')
500
+
501
+ with printoptions(precision=3):
502
+ assert_equal(p._repr_latex_(),
503
+ r'$x \mapsto \text{0.5} + \text{0.143}\,x + '
504
+ r'\text{14285714.286}\,x^{2} + \text{(1.429e+08)}\,x^{3}$')
505
+
506
+ def test_fixed(self):
507
+ p = poly.Polynomial([1/2])
508
+ assert_equal(str(p), '0.5')
509
+
510
+ with printoptions(floatmode='fixed'):
511
+ assert_equal(str(p), '0.50000000')
512
+
513
+ with printoptions(floatmode='fixed', precision=4):
514
+ assert_equal(str(p), '0.5000')
515
+
516
+ def test_switch_to_exp(self):
517
+ for i, s in enumerate(SWITCH_TO_EXP):
518
+ with printoptions(precision=i):
519
+ p = poly.Polynomial([1.23456789*10**-i
520
+ for i in range(i//2+3)])
521
+ assert str(p).replace('\n', ' ') == s
522
+
523
+ def test_non_finite(self):
524
+ p = poly.Polynomial([nan, inf])
525
+ assert str(p) == 'nan + inf x'
526
+ assert p._repr_latex_() == r'$x \mapsto \text{nan} + \text{inf}\,x$'
527
+ with printoptions(nanstr='NAN', infstr='INF'):
528
+ assert str(p) == 'NAN + INF x'
529
+ assert p._repr_latex_() == \
530
+ r'$x \mapsto \text{NAN} + \text{INF}\,x$'
.venv/lib/python3.11/site-packages/torchgen/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """torchgen
2
+
3
+ This module contains codegeneration utilities for PyTorch. It is used to
4
+ build PyTorch from source, but may also be used for out-of-tree projects
5
+ that extend PyTorch.
6
+
7
+ Note well that we provide no BC guarantees for torchgen. If you're interested
8
+ in using torchgen and want the PyTorch team to be aware, please reach out
9
+ on GitHub.
10
+ """
.venv/lib/python3.11/site-packages/torchgen/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (547 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/code_template.cpython-311.pyc ADDED
Binary file (4.96 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/context.cpython-311.pyc ADDED
Binary file (7.25 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-311.pyc ADDED
Binary file (19 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-311.pyc ADDED
Binary file (26.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_executorch.cpython-311.pyc ADDED
Binary file (45.8 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-311.pyc ADDED
Binary file (44.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-311.pyc ADDED
Binary file (20.7 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-311.pyc ADDED
Binary file (6.35 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-311.pyc ADDED
Binary file (15.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/local.cpython-311.pyc ADDED
Binary file (2.15 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/native_function_generation.cpython-311.pyc ADDED
Binary file (25.1 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/utils.cpython-311.pyc ADDED
Binary file (25.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/__pycache__/yaml_utils.cpython-311.pyc ADDED
Binary file (1.63 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/aoti/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (186 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-311.pyc ADDED
Binary file (4.49 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/aoti/fallback_ops.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Be extra careful when you edit this file, because it affects AOTInductor ABI compatbility. See
2
+ # https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436
3
+ # for details.
4
+ #
5
+ # The inductor_fallback_ops list is based on the fallback ops from torch/_inductor/lowering.py.
6
+ # Generally speaking, it is ok to add a new op to the list, but you need to run
7
+ # `python torchgen/gen.py --update-aoti-c-shim` in order to regenerate C shim header files.
8
+ # But it is NOT ok to remove an existing fallback op from the list, since that will break
9
+ # some existing AOTInductor-compiled models.
10
+ inductor_fallback_ops = {
11
+ "aten._adaptive_avg_pool2d_backward.default",
12
+ "aten._adaptive_avg_pool2d.default",
13
+ "aten._adaptive_avg_pool3d.default",
14
+ "aten._adaptive_avg_pool3d_backward.default",
15
+ "aten.adaptive_max_pool2d_backward.default",
16
+ "aten.adaptive_max_pool2d.default",
17
+ "aten.adaptive_max_pool3d.default",
18
+ "aten.adaptive_max_pool3d_backward.default",
19
+ "aten.addbmm.default",
20
+ "aten._addmm_activation.default",
21
+ "aten.addmm.out",
22
+ "aten.addmv.default",
23
+ "aten.angle.default",
24
+ "aten.avg_pool2d_backward.default",
25
+ "aten.avg_pool2d.default",
26
+ "aten.avg_pool3d_backward.default",
27
+ "aten.avg_pool3d.default",
28
+ "aten.bernoulli_.float",
29
+ "aten.bernoulli_.Tensor",
30
+ "aten.bmm.out",
31
+ "aten.bucketize.Tensor",
32
+ "aten.cat.default",
33
+ "aten._cdist_backward.default",
34
+ "aten._cdist_forward.default",
35
+ "aten.cholesky_inverse.default",
36
+ "aten.cholesky_solve.default",
37
+ "aten.convolution_backward.default",
38
+ "aten._cudnn_rnn.default",
39
+ "aten._cudnn_rnn_backward.default",
40
+ "aten.convolution.default",
41
+ "aten.cummax.default",
42
+ "aten.cummin.default",
43
+ "aten.cumprod.default",
44
+ "aten.cumsum.default",
45
+ "aten._efficient_attention_backward.default",
46
+ "aten._efficient_attention_forward.default",
47
+ "aten._efficientzerotensor.default",
48
+ "aten._embedding_bag.default",
49
+ "aten._embedding_bag_dense_backward.default",
50
+ "aten._embedding_bag_forward_only.default",
51
+ "aten._embedding_bag_per_sample_weights_backward.default",
52
+ "aten.exponential.default",
53
+ "aten._fft_c2c.default",
54
+ "aten._fft_r2c.default",
55
+ "aten._flash_attention_backward.default",
56
+ "aten._flash_attention_forward.default",
57
+ "aten.fractional_max_pool2d_backward.default",
58
+ "aten.fractional_max_pool2d.default",
59
+ "aten.fractional_max_pool3d.default",
60
+ "aten.fractional_max_pool3d_backward.default",
61
+ "aten._fused_moving_avg_obs_fq_helper.default",
62
+ "aten._fused_moving_avg_obs_fq_helper_functional.default",
63
+ "aten.gcd.default",
64
+ "aten.geqrf.default",
65
+ "aten.grid_sampler_2d_backward.default",
66
+ "aten.histc.default",
67
+ "aten.histogram.bin_ct",
68
+ "aten._histogramdd_bin_edges.default",
69
+ "aten._histogramdd_from_bin_cts.default",
70
+ "aten.index_put.default",
71
+ "aten.index_reduce.default",
72
+ "aten.index.Tensor",
73
+ "aten.kthvalue.default",
74
+ "aten.logcumsumexp.default",
75
+ "aten.lu_unpack.default",
76
+ "aten.masked_scatter.default",
77
+ "aten.masked_scatter_backward.default",
78
+ "aten.max_pool2d_with_indices_backward.default",
79
+ "aten.max_pool2d_with_indices.default",
80
+ "aten.max_pool3d_with_indices.default",
81
+ "aten.max_pool3d_with_indices_backward.default",
82
+ "aten.max_unpool2d.default",
83
+ "aten.max_unpool3d.default",
84
+ "aten.median.default",
85
+ "aten.mm.out",
86
+ "aten.mode.default",
87
+ "aten.mul.Scalar",
88
+ "aten.mul.Tensor",
89
+ "aten.nanmedian.default",
90
+ "aten.native_dropout.default",
91
+ "aten.normal_functional.default",
92
+ "aten.nonzero.default",
93
+ "aten.ormqr.default",
94
+ "aten._pdist_backward.default",
95
+ "aten._pdist_forward.default",
96
+ "aten.polar.default",
97
+ "aten.pow.Scalar",
98
+ "aten.pow.Tensor_Scalar",
99
+ "aten.pow.Tensor_Tensor",
100
+ "aten.rand.default",
101
+ "aten.rand.generator",
102
+ "aten.randint.default",
103
+ "aten.randint.generator",
104
+ "aten.randint.low",
105
+ "aten.randint.low_out",
106
+ "aten.randn.default",
107
+ "aten.randn.generator",
108
+ "aten.randperm.default",
109
+ "aten.repeat_interleave.Tensor",
110
+ "aten.replication_pad1d_backward.default",
111
+ "aten.replication_pad2d_backward.default",
112
+ "aten.reshape.default",
113
+ "aten.resize_.default",
114
+ "aten.resize_as_.default",
115
+ "aten._scaled_dot_product_efficient_attention_backward.default",
116
+ "aten._scaled_dot_product_efficient_attention.default",
117
+ "aten._scaled_dot_product_flash_attention_backward.default",
118
+ "aten._scaled_dot_product_flash_attention.default",
119
+ "aten._scaled_dot_product_cudnn_attention_backward.default",
120
+ "aten._scaled_dot_product_cudnn_attention.default",
121
+ "aten._scaled_dot_product_flash_attention_for_cpu_backward.default",
122
+ "aten._scaled_dot_product_flash_attention_for_cpu.default",
123
+ "aten._scaled_mm.default",
124
+ "aten.scatter_reduce.two_out",
125
+ "aten.scatter.src_out",
126
+ "aten.scatter.value_out",
127
+ "aten.searchsorted.default",
128
+ "aten._segment_reduce_backward.default",
129
+ "aten.segment_reduce.default",
130
+ "aten.slice.Tensor",
131
+ "aten.soft_margin_loss_backward.default",
132
+ "aten.sort.default",
133
+ "aten.sort.stable",
134
+ "aten._sparse_coo_tensor_with_dims_and_tensors.default",
135
+ "aten._thnn_fused_lstm_cell.default",
136
+ "aten.topk.default",
137
+ "aten._to_sparse.default",
138
+ "aten.to_sparse.default",
139
+ "aten.triangular_solve.default",
140
+ "aten._trilinear.default",
141
+ "aten.uniform.default",
142
+ "aten.upsample_bicubic2d_backward.default",
143
+ "aten.upsample_linear1d_backward.default",
144
+ "aten.upsample_trilinear3d_backward.default",
145
+ "aten.view_as_complex.default",
146
+ "aten.view_as_real.default",
147
+ "aten.view.dtype",
148
+ "aten.zeros.names",
149
+ }
.venv/lib/python3.11/site-packages/torchgen/code_template.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from typing import Mapping, Sequence
5
+
6
+
7
+ # match $identifier or ${identifier} and replace with value in env
8
+ # If this identifier is at the beginning of whitespace on a line
9
+ # and its value is a list then it is treated as
10
+ # block substitution by indenting to that depth and putting each element
11
+ # of the list on its own line
12
+ # if the identifier is on a line starting with non-whitespace and a list
13
+ # then it is comma separated ${,foo} will insert a comma before the list
14
+ # if this list is not empty and ${foo,} will insert one after.
15
+
16
+
17
+ class CodeTemplate:
18
+ substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
19
+ substitution = re.compile(substitution_str, re.MULTILINE)
20
+
21
+ pattern: str
22
+ filename: str
23
+
24
+ @staticmethod
25
+ def from_file(filename: str) -> CodeTemplate:
26
+ with open(filename) as f:
27
+ return CodeTemplate(f.read(), filename)
28
+
29
+ def __init__(self, pattern: str, filename: str = "") -> None:
30
+ self.pattern = pattern
31
+ self.filename = filename
32
+
33
+ def substitute(
34
+ self, env: Mapping[str, object] | None = None, **kwargs: object
35
+ ) -> str:
36
+ if env is None:
37
+ env = {}
38
+
39
+ def lookup(v: str) -> object:
40
+ assert env is not None
41
+ return kwargs[v] if v in kwargs else env[v]
42
+
43
+ def indent_lines(indent: str, v: Sequence[object]) -> str:
44
+ return "".join(
45
+ [indent + l + "\n" for e in v for l in str(e).splitlines()]
46
+ ).rstrip()
47
+
48
+ def replace(match: re.Match[str]) -> str:
49
+ indent = match.group(1)
50
+ key = match.group(2)
51
+ comma_before = ""
52
+ comma_after = ""
53
+ if key[0] == "{":
54
+ key = key[1:-1]
55
+ if key[0] == ",":
56
+ comma_before = ", "
57
+ key = key[1:]
58
+ if key[-1] == ",":
59
+ comma_after = ", "
60
+ key = key[:-1]
61
+ v = lookup(key)
62
+ if indent is not None:
63
+ if not isinstance(v, list):
64
+ v = [v]
65
+ return indent_lines(indent, v)
66
+ elif isinstance(v, list):
67
+ middle = ", ".join([str(x) for x in v])
68
+ if len(v) == 0:
69
+ return middle
70
+ return comma_before + middle + comma_after
71
+ else:
72
+ return str(v)
73
+
74
+ return self.substitution.sub(replace, self.pattern)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ c = CodeTemplate(
79
+ """\
80
+ int foo($args) {
81
+
82
+ $bar
83
+ $bar
84
+ $a+$b
85
+ }
86
+ int commatest(int a${,stuff})
87
+ int notest(int a${,empty,})
88
+ """
89
+ )
90
+ print(
91
+ c.substitute(
92
+ args=["hi", 8],
93
+ bar=["what", 7],
94
+ a=3,
95
+ b=4,
96
+ stuff=["things...", "others"],
97
+ empty=[],
98
+ )
99
+ )
.venv/lib/python3.11/site-packages/torchgen/context.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import functools
5
+ from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
6
+
7
+ import torchgen.local as local
8
+ from torchgen.model import (
9
+ BackendIndex,
10
+ DispatchKey,
11
+ NativeFunction,
12
+ NativeFunctionsGroup,
13
+ NativeFunctionsViewGroup,
14
+ )
15
+ from torchgen.utils import context, S, T
16
+
17
+
18
+ # Helper functions for defining generators on things in the model
19
+
20
+ F = TypeVar(
21
+ "F",
22
+ NativeFunction,
23
+ NativeFunctionsGroup,
24
+ NativeFunctionsViewGroup,
25
+ Union[NativeFunction, NativeFunctionsGroup],
26
+ Union[NativeFunction, NativeFunctionsViewGroup],
27
+ )
28
+
29
+ F2 = TypeVar(
30
+ "F2",
31
+ NativeFunction,
32
+ NativeFunctionsGroup,
33
+ Optional[NativeFunction],
34
+ bool,
35
+ str,
36
+ )
37
+
38
+ F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
39
+
40
+
41
+ @contextlib.contextmanager
42
+ def native_function_manager(
43
+ g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
44
+ ) -> Iterator[None]:
45
+ if isinstance(g, NativeFunctionsGroup):
46
+ # By default, we associate all errors with structured native functions
47
+ # with the out variant. In some cases, it might be better to have
48
+ # a more specific place to hang things; if so, use
49
+ # native_function_manager again on the inside
50
+ f = g.out
51
+ elif isinstance(g, NativeFunctionsViewGroup):
52
+ # We associate errors with the view operator
53
+ f = g.view
54
+ else:
55
+ f = g
56
+ with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"):
57
+ with local.parametrize(
58
+ use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
59
+ use_ilistref_for_tensor_lists=f.part_of_structured_group,
60
+ ):
61
+ yield
62
+
63
+
64
+ # Given a function that operates on NativeFunction, wrap it into a new function
65
+ # that sets some appropriate context managers for that native function.
66
+ # YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
67
+ # (you will get an error if we try to access the local variables without having
68
+ # set them).
69
+ def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
70
+ @functools.wraps(func)
71
+ def wrapper(f: F) -> T:
72
+ with native_function_manager(f):
73
+ return func(f)
74
+
75
+ return wrapper
76
+
77
+
78
+ def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
79
+ @functools.wraps(func)
80
+ def wrapper(f: F, f2: F2) -> T:
81
+ # The first native_function is assumed to be the one with the appropriate context.
82
+ with native_function_manager(f):
83
+ return func(f, f2)
84
+
85
+ return wrapper
86
+
87
+
88
+ def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
89
+ @functools.wraps(func)
90
+ def wrapper(slf: S, f: F) -> T:
91
+ with native_function_manager(f):
92
+ return func(slf, f)
93
+
94
+ return wrapper
95
+
96
+
97
+ def method_with_nested_native_function(
98
+ func: Callable[[S, F3], T]
99
+ ) -> Callable[[S, F3], T]:
100
+ @functools.wraps(func)
101
+ def wrapper(slf: S, f: F3) -> T:
102
+ with native_function_manager(f[0]):
103
+ return func(slf, f)
104
+
105
+ return wrapper
106
+
107
+
108
+ # Convenience decorator for functions that explicitly take in a BackendIndex,
109
+ # instead of indirectly taking one in as a closure
110
+ def with_native_function_and_index(
111
+ func: Callable[[F, BackendIndex], T]
112
+ ) -> Callable[[F, BackendIndex], T]:
113
+ @functools.wraps(func)
114
+ def wrapper(f: F, backend_index: BackendIndex) -> T:
115
+ with native_function_manager(f):
116
+ return func(f, backend_index)
117
+
118
+ return wrapper
119
+
120
+
121
+ # Convenience decorator for functions that explicitly take in a Dict of BackendIndices
122
+ def with_native_function_and_indices(
123
+ func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
124
+ ) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
125
+ @functools.wraps(func)
126
+ def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
127
+ with native_function_manager(f):
128
+ return func(f, backend_indices)
129
+
130
+ return wrapper
.venv/lib/python3.11/site-packages/torchgen/dest/lazy_ts_lowering.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchgen.api.lazy import LazyArgument, LazyIrSchema
2
+ from torchgen.api.types import OptionalCType
3
+
4
+
5
def ts_lowering_body(schema: LazyIrSchema) -> str:
    """Generate the C++ body of a TorchScript lowering for one lazy op.

    The emitted code packs the op's positional and keyword arguments into
    ``torch::jit::NamedValue`` vectors and dispatches via
    ``torch::lazy::LowerTSBuiltin``.
    """
    # for now, we just want one IR class decl and soon after also the method defs
    # and we use the functional version not out/inplace.
    emplace_arguments = []

    def get_value(arg: LazyArgument) -> str:
        # Optional lazy values are guarded by a has_<name> flag; the operand
        # counter i is only consumed when the value is present.
        if isinstance(arg.lazy_type, OptionalCType):
            return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
        return "loctx->GetOutputOp(operand(i++))"

    for arg in schema.positional_args:
        if arg.is_lazy_value:
            emplace_arguments.append(get_value(arg))
            continue
        # Non-lazy positional args are emplaced as named values.
        emplace_arguments.append(f'"{arg.name}", {arg.name}')

    emplace_arguments_str = "\n    ".join(
        [f"arguments.emplace_back({a});" for a in emplace_arguments]
    )
    emplace_kwarg_values = [
        f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
    ]
    emplace_kwarg_scalars = [
        f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
    ]
    emplace_kwarguments = "\n    ".join(
        [
            f"kwarguments.emplace_back({a});"
            for a in emplace_kwarg_values + emplace_kwarg_scalars
        ]
    )
    return f"""\
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve({len(emplace_arguments)});
    kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
    size_t i = 0;
    {emplace_arguments_str}
    {emplace_kwarguments}
    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

    return {schema.aten_name}_out;
"""
.venv/lib/python3.11/site-packages/torchgen/gen.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torchgen/gen_aoti_c_shim.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import textwrap
4
+ from dataclasses import dataclass
5
+ from typing import Sequence
6
+
7
+ from torchgen.api.types import DispatcherSignature
8
+ from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
9
+ from torchgen.context import method_with_native_function
10
+ from torchgen.model import (
11
+ Argument,
12
+ BackendIndex,
13
+ BaseTy,
14
+ BaseType,
15
+ DispatchKey,
16
+ FunctionSchema,
17
+ ListType,
18
+ NativeFunction,
19
+ NativeFunctionsGroup,
20
+ OperatorName,
21
+ OptionalType,
22
+ Type,
23
+ )
24
+ from torchgen.utils import mapMaybe
25
+
26
+
27
# Mapping from torchgen BaseTy to the C type used for it in the AOTI C shim ABI.
base_type_to_c_type = {
    BaseTy.Tensor: "AtenTensorHandle",
    BaseTy.bool: "int32_t",  # Use int to pass bool
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "int64_t",  # Inductor-generated code won't see a SymInt
    BaseTy.Scalar: "double",  # Use double to pass both integer and floating point
    BaseTy.float: "double",  # TODO: how about other floating point types?
    BaseTy.str: "const char*",
    BaseTy.DeviceIndex: "int32_t",
    BaseTy.Layout: "int32_t",  # Represent enum as int
    BaseTy.MemoryFormat: "int32_t",  # Represent enum as int
    BaseTy.ScalarType: "int32_t",  # Represent enum as int
    BaseTy.Generator: "AtenGeneratorHandle",
}

# Mapping from BaseTy to the ATen-level C++ type the shim converts back to.
base_type_to_aten_type = {
    BaseTy.Tensor: "at::Tensor",
    BaseTy.bool: "bool",
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "c10::SymInt",
    BaseTy.Scalar: "c10::Scalar",
    BaseTy.float: "double",
    BaseTy.str: "c10::string_view",
    BaseTy.DeviceIndex: "c10::DeviceIndex",
    BaseTy.Layout: "c10::Layout",
    BaseTy.MemoryFormat: "c10::MemoryFormat",
    BaseTy.ScalarType: "c10::ScalarType",
    BaseTy.Generator: "at::Generator",
}

# Conversion wrapper applied at the callsite when rebuilding the ATen value
# from the C argument; an empty string means the value is passed through as-is.
base_type_to_callsite_expr = {
    BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
    BaseTy.bool: "",
    BaseTy.int: "",
    BaseTy.SymInt: "",
    BaseTy.Scalar: "",
    BaseTy.float: "",
    BaseTy.str: "",
    BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
    BaseTy.Layout: "static_cast<c10::Layout>",
    BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
    BaseTy.ScalarType: "static_cast<c10::ScalarType>",
    BaseTy.Generator: "*generator_handle_to_generator_pointer",
}
71
+
72
+
73
# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(
    typ: Type, name: str
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Lower one schema argument to the AOTI C shim ABI.

    A single schema argument may expand into several C parameters (a list is
    passed as pointer + length; a Device as type + index), so four parallel
    lists are returned:
      1. C types of the shim parameters,
      2. their parameter names,
      3. the ATen-level types they reconstruct into,
      4. callsite expressions that rebuild the ATen value from the C params.

    Raises:
        NotImplementedError: for argument types the shim does not support;
            gen_c_shim catches this and skips the operator.
    """
    if isinstance(typ, BaseType):
        if typ.name in base_type_to_c_type:
            return (
                [base_type_to_c_type[typ.name]],
                [name],
                [base_type_to_aten_type[typ.name]],
                [
                    f"{base_type_to_callsite_expr[typ.name]}({name})"
                    if base_type_to_callsite_expr[typ.name]
                    else name
                ],
            )
        elif typ.name == BaseTy.Device:
            # Device is flattened into two ints: (device_type, device_index).
            return (
                ["int32_t", "int32_t"],
                [name, name + "_index_"],
                ["c10::Device"],
                [
                    f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
                ],
            )
        else:
            # TODO: BaseTy.Dimname, etc.
            raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
            typ.elem, name
        )
        j = 0  # index for names
        new_aten_types = []
        new_callsite_exprs = []
        for aten_type in aten_types:
            # Use pointer to denote optional type
            c_types[j] = c_types[j] + "*"
            if aten_type.startswith("c10::ArrayRef<"):
                # ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
                new_aten_types.append(f"::std::optional<{aten_type}>")
                base_type = aten_type[len("c10::ArrayRef<") : -1]
                new_callsite_exprs.append(
                    f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
                )
                j += 2
            elif aten_type == "c10::Device":
                # Device is passed as device_type + device_index
                new_aten_types.append("::std::optional<c10::Device>")
                new_callsite_exprs.append(
                    f"pointer_to_optional_device({names[j]}, {names[j+1]})"
                )
                j += 2
            else:
                new_aten_types.append(f"::std::optional<{aten_type}>")
                new_callsite_exprs.append(
                    f"pointer_to_optional<{aten_type}>({names[j]})"
                )
                j += 1

        return (
            c_types,
            names,
            new_aten_types,
            new_callsite_exprs,
        )
    elif isinstance(typ, ListType):
        # Need to explictly pass the list as pointer + length
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)

        # The list content should never be modified
        c_types[0] = f"const {c_types[0]}*"
        c_types.append("int64_t")
        name = names[0]
        names.append(name + "_len_")

        atype = aten_types[0]
        callsite_exprs = []
        if atype == "bool":
            # no converter from std::vector<bool> to c10::ArrayRef<bool>
            # construct std::array<bool, N> instead
            assert typ.size is not None
            callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
        elif atype == "::std::optional<at::Tensor>":
            # convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
            callsite_exprs.append(
                f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
            )
        else:
            callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")

        aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
        return (
            c_types,
            names,
            aten_types,
            callsite_exprs,
        )
    # Fix: previously any other Type subclass fell off the end and implicitly
    # returned None, which surfaced as an opaque TypeError at the tuple-unpacking
    # callsite (and was papered over with `type: ignore[return]`). Raise the
    # same NotImplementedError the other unsupported branches use, so
    # gen_c_shim's except-clause can skip the operator gracefully.
    raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
170
+
171
+
172
def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
    """Pair each C type with its parameter name as a "type name" string."""
    return [f"{typ} {name}" for typ, name in zip(types, names)]
174
+
175
+
176
# Generate argument declarations and callsite expressions
def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
    """Lower a flat argument list to ("type name" declarations, callsite exprs).

    Each schema argument may expand into multiple C parameters, so the
    per-argument results are concatenated in order.
    """
    all_types: list[str] = []
    all_names: list[str] = []
    all_exprs: list[str] = []
    for argument in flat_arguments:
        arg_types, arg_names, _, arg_exprs = convert_arg_type_and_name(
            argument.type, argument.name
        )
        all_types += arg_types
        all_names += arg_names
        all_exprs += arg_exprs
    return zip_type_and_name(all_types, all_names), all_exprs
189
+
190
+
191
# Return values are passed out as pointer arguments because all the C shim functions
# are expected to return AOTITorchError.
# Generate returns as declarations and callsite expressions
def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
    """Lower the schema's returns to out-pointer parameters.

    Returns a pair: ("<c-type>* retN" declarations, C++ statements that write
    the ATen result(s) through those pointers).

    Raises:
        NotImplementedError: for return types without a base-type C mapping.
    """
    types = []
    names = []
    for idx, ret in enumerate(schema.returns):
        names.append(f"ret{idx}")
        if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
            types.append(base_type_to_c_type[ret.type.name] + "*")
        else:
            raise NotImplementedError(
                f"TODO: add support for return type {repr(ret.type)}"
            )

    def convert_return(typ: BaseType, val: str) -> str:
        # Convert one ATen return value into its C ABI representation.
        if typ.name == BaseTy.Tensor:
            return f"new_tensor_handle(std::move({val}));"
        elif typ.name == BaseTy.SymInt:
            return f"{val}.expect_int()"
        elif typ.name == BaseTy.Scalar:
            return f"{val}.toDouble()"
        else:
            return val

    ret_pointer_can_be_null = False
    unambiguous_name = schema.name.unambiguous_name()
    # These ops may legitimately produce fewer outputs than declared, so their
    # out-pointers are null-checked before being written.
    for name in [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_efficient_attention",
        "_scaled_dot_product_cudnn_attention",
        "convolution_backward",
    ]:
        if name in unambiguous_name:
            ret_pointer_can_be_null = True
            break

    callsite_exprs: list[str] = []
    for idx, ret in enumerate(schema.returns):
        # Single return comes back directly; multiple returns come back as a tuple.
        tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
        assert isinstance(ret.type, BaseType)
        rval = convert_return(ret.type, tmp)
        if ret_pointer_can_be_null:
            callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
        else:
            callsite_exprs.append(f"*{names[idx]} = {rval};")

    return zip_type_and_name(types, names), callsite_exprs
239
+
240
+
241
# gen.py generates header first and then src, so caching the result here to avoid duplicate work
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


def gen_declaration_and_definition(
    schema: FunctionSchema, device: str, backend_call: str
) -> tuple[str, str]:
    """Build the C shim (declaration, definition) strings for one op schema.

    Results are memoized in declaration_definition_cache keyed on
    (function name, device, backend call) because gen.py requests the header
    and the source for the same op in separate passes.

    Raises:
        NotImplementedError: propagated from argument/return lowering.
    """
    func_name = schema.name.unambiguous_name()

    global declaration_definition_cache
    if (func_name, device, backend_call) in declaration_definition_cache:
        return declaration_definition_cache[(func_name, device, backend_call)]

    if schema.is_out_fn():
        # out_variant has out arguments in the front, and it's ok to ignore return values
        # because C shim functions only return AOTITorchError
        args, callsite_exprs = gen_arguments(
            [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
        ret_assignments: list[str] = []
    else:
        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
        # ignore return values for inplace ops
        ret_declarations, ret_assignments = (
            ([], []) if schema.name.name.inplace else gen_returns(schema)
        )
        args.extend(ret_declarations)

    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"

    tmp_result = "auto tmp_result = " if ret_assignments else ""
    ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
    definition = f"""
{declaration} {{
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
        {tmp_result}{backend_call}(
{textwrap.indent(', '.join(callsite_exprs), "            ")}
        );{textwrap.indent(ret_assignments_str, "        ")}
    }});
}}
"""
    declaration_definition_cache[(func_name, device, backend_call)] = (
        declaration,
        definition,
    )
    return declaration, definition
287
+
288
+
289
def gen_static_dispatch_backend_call_signature(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
) -> CppSignature:
    """Choose the C++ API signature used for a static dispatch call to *f*.

    NOTE(review): the incoming *sig* is immediately overwritten with a
    DispatcherSignature rebuilt from f.func — only its symint flag is
    effectively consulted below; presumably intentional upstream, but worth
    confirming against callers.
    """
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    # Prefer the symint signature only when the schema actually contains SymInts.
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    return cpp_sig
303
+
304
+
305
def gen_static_dispatch_backend_call(
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    """Return the fully-qualified C++ call target for a static dispatch to *f*,
    e.g. "at::cpu::add_out"."""
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
312
+
313
+
314
def get_backend_index_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> BackendIndex | None:
    """Find the BackendIndex whose kernel should back *func*'s C shim.

    Preference order: the requested dispatch key (either a direct kernel or one
    reachable via the function's structured delegate group), then
    CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, and
    finally CompositeImplicitAutograd. Returns None when no kernel is found.
    """
    backend_index = None
    if backend_indices[dispatch_key].has_kernel(func) or (
        func.structured_delegate is not None
        and func.structured_delegate in func_group_mapping
        and backend_indices[dispatch_key].has_kernel(
            func_group_mapping[func.structured_delegate]
        )
    ):
        backend_index = backend_indices[dispatch_key]
    elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
        # We need to create C shim wrappers for CompositeExplicitAutograd kernels
        backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
    elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
        func
    ):
        # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
        backend_index = backend_indices[
            DispatchKey.CompositeExplicitAutogradNonFunctional
        ]
    elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
        backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]

    return backend_index
343
+
344
+
345
def get_header_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> str | None:
    """Return the per-operator dispatch-header #include line for *func*, or
    None when no backend kernel backs it under *dispatch_key*."""
    index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    if index is None:
        return None
    key = index.dispatch_key.lower()
    return f"#include <ATen/ops/{func.root_name}_{key}_dispatch.h>"
359
+
360
+
361
def get_fallback_op_name(func: NativeFunction) -> str:
    """Return the fallback operator name "<namespace>.<name>.<overload>",
    substituting "default" when the operator has no overload name."""
    overload = func.func.name.overload_name
    suffix = overload if overload else "default"
    return f"{func.namespace}.{func.func.name.name}.{suffix}"
367
+
368
+
369
def gen_c_shim(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
) -> str | None:
    """Render the C shim declaration (header=True) or definition for *func*.

    Returns None when the function has no usable backend kernel, or when its
    argument/return types are not yet supported by the shim lowering.
    """
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    if backend_index is None:
        return None

    schema = func.func
    device = dispatch_key.lower()
    backend_call = gen_static_dispatch_backend_call(
        func,
        backend_index,
    )

    try:
        if header:
            declaration, _ = gen_declaration_and_definition(
                schema, device, backend_call
            )
            return f"AOTI_TORCH_EXPORT {declaration};"
        else:
            _, definition = gen_declaration_and_definition(schema, device, backend_call)
            return definition

    except NotImplementedError:
        # Unsupported argument/return type: skip this op rather than fail codegen.
        return None
401
+
402
+
403
@dataclass(frozen=True)
class ShimGenerator:
    """Callable that renders the C shim declaration (.h) or definition (.cpp)
    for a single NativeFunction, returning None when no shim applies; used
    with mapMaybe over the full native-function list."""

    func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
    dispatch_key: DispatchKey
    backend_indices: dict[DispatchKey, BackendIndex]
    header: bool  # True to generate .h and False to generate .cpp

    @method_with_native_function
    def __call__(
        self,
        func: NativeFunction,
    ) -> str | None:
        result = gen_c_shim(
            func,
            self.func_group_mapping,
            self.dispatch_key,
            self.backend_indices,
            self.header,
        )
        return result
423
+
424
+
425
def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
    includes: str = "",
) -> str:
    """Assemble the complete C shim header (.h) or source (.cpp) file contents
    for one dispatch key, wrapping the per-op shims generated by ShimGenerator
    in the appropriate boilerplate (extern "C" guards for the header, backend
    includes for the source)."""
    body = "\n".join(
        list(
            mapMaybe(
                ShimGenerator(
                    func_group_mapping, dispatch_key, backend_indices, header
                ),
                native_functions,
            )
        )
    )
    device = dispatch_key.lower()

    warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""

    if header:
        return f"""
{warning}

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {{
#endif

{body}

#ifdef __cplusplus
}} // extern "C"
#endif
"""

    else:
        return f"""
{warning}

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif

using namespace torch::aot_inductor;

{body}"""
.venv/lib/python3.11/site-packages/torchgen/gen_backend_stubs.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import os
5
+ import re
6
+ from collections import Counter, defaultdict, namedtuple
7
+ from pathlib import Path
8
+ from typing import Sequence
9
+
10
+ import yaml
11
+
12
+ import torchgen.api.dispatcher as dispatcher
13
+ import torchgen.dest as dest
14
+ from torchgen.api.types import DispatcherSignature
15
+ from torchgen.code_template import CodeTemplate
16
+ from torchgen.context import native_function_manager
17
+ from torchgen.gen import get_grouped_native_functions, parse_native_yaml
18
+ from torchgen.model import (
19
+ BackendIndex,
20
+ BackendMetadata,
21
+ DispatchKey,
22
+ NativeFunction,
23
+ NativeFunctionsGroup,
24
+ OperatorName,
25
+ )
26
+ from torchgen.selective_build.selector import SelectiveBuilder
27
+ from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Target
28
+ from torchgen.yaml_utils import YamlLoader
29
+
30
+
31
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
# (class_name may be None when the yaml does not specify one.)
ParsedExternalYaml = namedtuple(
    "ParsedExternalYaml",
    ["backend_key", "autograd_key", "class_name", "cpp_namespace", "backend_indices"],
)
37
+
38
+
39
def parse_backend_yaml(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
) -> ParsedExternalYaml:
    """Parse an external backend's yaml file and register its kernels.

    Mutates *backend_indices* in place by adding a BackendIndex for the
    backend's dispatch key (and its Autograd<Backend> key, when "autograd"
    entries are present), and returns a ParsedExternalYaml with the parsed
    keys, class name, and cpp namespace.

    Raises:
        AssertionError: for malformed yaml (missing/invalid keys, unknown
            operator names, or ops mixed between "supported" and "autograd").
    """
    native_functions_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(
            lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
            grouped_native_functions,
        )
    }

    with open(backend_yaml_path) as f:
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(yaml_values, dict)

    valid_keys = [
        "backend",
        "class_name",
        "cpp_namespace",
        "extra_headers",
        "supported",
        "autograd",
        "full_codegen",
        "non_native",
        "ir_gen",
        "symint",
    ]

    backend = yaml_values.pop("backend", None)
    assert backend is not None, 'You must provide a value for "backend"'

    class_name = yaml_values.pop("class_name", None)

    cpp_namespace = yaml_values.pop("cpp_namespace", None)
    assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"'

    # Mostly just defaulting to false to stick with LazyTensor convention.
    use_out_as_primary = yaml_values.pop("use_out_as_primary", False)
    assert isinstance(
        use_out_as_primary, bool
    ), f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}"

    use_device_guard = yaml_values.pop("device_guard", False)
    assert isinstance(
        use_device_guard, bool
    ), f"You must provide either True or False for device_guard. Provided: {use_device_guard}"

    supported = yaml_values.pop("supported", [])
    if supported is None:
        supported = []  # Allow an empty list of supported ops
    assert isinstance(
        supported, list
    ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'

    symint = yaml_values.pop("symint", [])
    if symint is None:
        symint = []  # Allow an empty list of symint ops
    assert isinstance(
        symint, list
    ), f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})'
    symint_set = set(symint)

    supported_autograd = yaml_values.pop("autograd", [])
    assert isinstance(
        supported_autograd, list
    ), f'expected "autograd" to be a list, but got: {supported_autograd}'

    # full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
    full_codegen = yaml_values.pop("full_codegen", [])
    supported.extend(full_codegen)

    # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
    yaml_values.pop("non_native", {})

    # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
    yaml_values.pop("ir_gen", {})

    # Every known key has been popped above, so anything left is a typo or
    # an unsupported option.
    assert (
        len(yaml_values.keys()) == 0
    ), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
Only the following keys are supported: {", ".join(valid_keys)}'

    def create_backend_index(
        backend_ops: list[str],
        symint_ops: set[str],
        dispatch_key: DispatchKey,
        *,
        use_out_as_primary: bool,
        use_device_guard: bool,
    ) -> BackendIndex:
        # Build the OperatorName -> BackendMetadata mapping for one dispatch key.
        metadata: dict[OperatorName, BackendMetadata] = {}
        for op in backend_ops:
            op_name = OperatorName.parse(op)
            assert (
                op_name in native_functions_map
            ), f"Found an invalid operator name: {op_name}"
            # See Note [External Backends Follow Dispatcher API]
            kernel_name = dispatcher.name(native_functions_map[op_name].func)
            if op in symint_ops:
                kernel_name += "_symint"
            # TODO: allow structured external backends later.
            m = BackendMetadata(
                kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
            )
            metadata[op_name] = m
        return BackendIndex(
            dispatch_key=dispatch_key,
            use_out_as_primary=use_out_as_primary,
            external=True,
            device_guard=use_device_guard,
            index=metadata,
        )

    backend_key: DispatchKey | None = None
    if len(supported) > 0:
        with context(
            lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
        ):
            backend_key = DispatchKey.parse(backend)

        backend_idx = create_backend_index(
            supported,
            symint_set,
            backend_key,
            use_out_as_primary=use_out_as_primary,
            use_device_guard=use_device_guard,
        )
        assert backend_key not in backend_indices
        backend_indices[backend_key] = backend_idx

    autograd_key: DispatchKey | None = None
    if len(supported_autograd) > 0:
        with context(
            lambda: f'The "autograd" key was specified, which indicates that you would like to override \
the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'
        ):
            autograd_key = DispatchKey.parse(f"Autograd{backend}")

        autograd_idx = create_backend_index(
            supported_autograd,
            symint_set,
            autograd_key,
            use_out_as_primary=use_out_as_primary,
            use_device_guard=use_device_guard,
        )
        assert autograd_key not in backend_indices
        backend_indices[autograd_key] = autograd_idx

    # Validate that no op has variants split between the backend key and the
    # autograd key — mixing the two is unsupported.
    for g in grouped_native_functions:
        if isinstance(g, NativeFunction):
            forward_kernels = (
                []
                if backend_key is None
                else [
                    m
                    for m in [backend_indices[backend_key].get_kernel(g)]
                    if m is not None
                ]
            )
            backward_kernels = (
                []
                if autograd_key is None
                else [
                    m
                    for m in [backend_indices[autograd_key].get_kernel(g)]
                    if m is not None
                ]
            )
        else:
            forward_kernels = (
                []
                if backend_key is None
                else [
                    m
                    for m in [
                        backend_indices[backend_key].get_kernel(f)
                        for f in g.functions()
                    ]
                    if m is not None
                ]
            )
            backward_kernels = (
                []
                if autograd_key is None
                else [
                    m
                    for m in [
                        backend_indices[autograd_key].get_kernel(f)
                        for f in g.functions()
                    ]
                    if m is not None
                ]
            )

        forward_kernels = [f for f in forward_kernels if f is not None]
        backward_kernels = [f for f in backward_kernels if f is not None]
        assert (
            len(forward_kernels) == 0 or len(backward_kernels) == 0
        ), f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \
autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! \
{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".'

    return ParsedExternalYaml(
        backend_key, autograd_key, class_name, cpp_namespace, backend_indices
    )
246
+
247
+
248
def error_on_missing_kernels(
    native_functions: Sequence[NativeFunction],
    backend_indices: dict[DispatchKey, BackendIndex],
    backend_key: DispatchKey,
    autograd_key: DispatchKey | None,
    class_name: str,
    kernel_defn_file_path: str,
    full_codegen: list[OperatorName] | None = None,
) -> None:
    """Assert that every kernel listed in the backend yaml has a definition in
    the C++ file at *kernel_defn_file_path*.

    The check is purely textual (a regex scan of the file), so it can miss
    cases; its only job is to produce a friendlier message than a linker error.

    Raises:
        AssertionError: when the file cannot be read, or when the number of
            definitions found for a kernel name does not match the number of
            overloads expected.
    """
    try:
        with open(kernel_defn_file_path) as f:
            backend_defns = f.read()
    except OSError as e:
        raise AssertionError(
            f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
        ) from e

    if full_codegen is None:
        full_codegen = []

    indices = [backend_indices[backend_key].index] + (
        [] if autograd_key is None else [backend_indices[autograd_key].index]
    )
    # Quick mapping from each OperatorName used by the external backend
    # to its backend kernel name
    expected_backend_op_names: dict[OperatorName, str] = dict(
        list(
            concatMap(
                lambda index: [
                    (op_name, metadata.kernel) for op_name, metadata in index.items()
                ],
                indices,
            )
        )
    )
    # Fully-codegen'd ops don't need hand-written kernels, so they're excluded.
    expected_backend_native_funcs: list[NativeFunction] = [
        f
        for f in native_functions
        if f.func.name in expected_backend_op_names.keys()
        and f.func.name not in full_codegen
    ]
    expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict(
        list
    )
    for native_f in expected_backend_native_funcs:
        expected_backend_kernel_name_counts[
            expected_backend_op_names[native_f.func.name]
        ].append(native_f)

    # This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
    # It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
    # here, then we get a nicer error message. If we miss it, you get a linker error.
    kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\("
    actual_backend_kernel_name_counts = Counter(
        # A bit unwieldy (this could probably be moved into regex),
        # but we don't want to include kernel names that come from function calls,
        # like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)".
        # Easy check is to ignore any lines with colons before the class name.
        [
            y
            for (x, y) in re.findall(kernel_defn_regex, backend_defns)
            if not x.endswith(":")
        ]
    )

    missing_kernels_err_msg = ""
    for expected_name, funcs in expected_backend_kernel_name_counts.items():
        expected_overload_count = len(funcs)
        actual_overload_count = actual_backend_kernel_name_counts[expected_name]
        if expected_overload_count != actual_overload_count:

            def create_decl(f: NativeFunction) -> str:
                with native_function_manager(f):
                    return DispatcherSignature.from_schema(f.func).decl()

            expected_schemas_str = "\n".join([create_decl(f) for f in funcs])
            missing_kernels_err_msg += f"""
{class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name,
but expected {expected_overload_count} kernel(s). The expected function schemas for the missing operator are:
{expected_schemas_str}

"""
    assert missing_kernels_err_msg == "", missing_kernels_err_msg
331
+
332
+
333
+ def main() -> None:
334
+ parser = argparse.ArgumentParser(description="Generate backend stub files")
335
+ parser.add_argument(
336
+ "-s",
337
+ "--source-yaml",
338
+ "--source_yaml",
339
+ help="path to source yaml file containing operator external definitions",
340
+ )
341
+ parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
342
+ parser.add_argument(
343
+ "--dry-run", "--dry_run", type=bool, default=False, help="output directory"
344
+ )
345
+ parser.add_argument(
346
+ "--impl-path",
347
+ "--impl_path",
348
+ type=str,
349
+ default=None,
350
+ help="path to the source C++ file containing kernel definitions",
351
+ )
352
+ options = parser.parse_args()
353
+
354
+ run(options.source_yaml, options.output_dir, options.dry_run, options.impl_path)
355
+
356
+
357
def gen_dispatchkey_nativefunc_headers(
    fm: FileManager,
    class_name: str,
    cpp_namespace: str,
    backend_indices: dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_dispatch_key: DispatchKey,
    autograd_dispatch_key: DispatchKey | None,
    backend_name: str = "",
) -> None:
    """Write the ``<DispatchKey>NativeFunctions.h`` header for an external backend.

    The header declares one native function per unique kernel name registered
    for the backend dispatch key (plus the autograd key, when present).
    """
    assert class_name is not None
    generated_comment = (
        "Autogenerated file by gen_backend_stubs.py. Do not edit directly!"
    )

    def unique_sorted_declarations(key: DispatchKey | None) -> list[str]:
        # Backends are allowed to bind several operators to one kernel name;
        # deduplicate through a set so each declaration is emitted once, then
        # sort for deterministic output.
        if key is None:
            return []
        index = backend_indices[key]
        decls: set[str] = set()
        for group in grouped_native_functions:
            decls.update(dest.compute_native_function_declaration(group, index))
        return sorted(decls)

    backend_declarations = unique_sorted_declarations(backend_dispatch_key)
    autograd_declarations = unique_sorted_declarations(autograd_dispatch_key)

    ns_helper = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f"{backend_dispatch_key}NativeFunctions.h",
        "DispatchKeyNativeFunctions.h",
        lambda: {
            "generated_comment": generated_comment,
            "namespace_prologue": ns_helper.prologue,
            "class_name": class_name,
            "namespace_epilogue": ns_helper.epilogue,
            "dispatch_declarations": backend_declarations + autograd_declarations,
            "BackendName": backend_name,
            "DispatchKey": backend_dispatch_key,
        },
    )
412
+
413
+
414
def gen_dispatcher_registrations(
    fm: FileManager,
    output_dir: str,
    class_name: str,
    backend_indices: dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_dispatch_key: DispatchKey,
    dispatch_key: DispatchKey,
    selector: SelectiveBuilder,
    # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
    build_in_tree: bool = False,
    per_operator_headers: bool = False,
    backend_name: str = "",
    eager_registration: bool = True,
) -> None:
    """Write ``Register<DispatchKey>.cpp``, which registers the backend's
    kernels with the PyTorch dispatcher.

    With ``eager_registration`` the registrations run at static-init time via
    ``TORCH_LIBRARY_IMPL``; otherwise a ``Register...NativeFunctions()``
    function is emitted for the backend to call explicitly at load time.
    """
    headers = [
        f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",
    ]
    # In-tree builds use angle-bracket includes; external backends use quotes.
    if build_in_tree:
        external_backend_headers_str = "\n".join(f"#include <{h}>" for h in headers)
    else:
        external_backend_headers_str = "\n".join(f'#include "{h}"' for h in headers)

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    # One m.impl(...) registration line per grouped native function.
    dispatch_registrations_body = list(
        concatMap(
            dest.RegisterDispatchKey(
                backend_index,
                Target.REGISTRATION,
                selector,
                rocm=False,
                symint=True,
                class_method_name=f"{class_name}",
                skip_dispatcher_op_registration=False,
            ),
            grouped_native_functions,
        )
    )
    newline = "\n"
    ns_helper = NamespaceHelper(namespace_str="at")
    deferred_dispatch_registrations = ""
    static_init_dispatch_registrations = ""
    if eager_registration:
        static_template = CodeTemplate(
            """\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
    $dispatch_registrations_body
};"""
        )
        static_init_dispatch_registrations = static_template.substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )
    else:
        deferred_template = CodeTemplate(
            """\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions();
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
    $dispatch_registrations_body
}"""
        )
        deferred_dispatch_registrations = deferred_template.substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )

    fm.write_with_template(
        f"Register{dispatch_key}.cpp",
        "RegisterDispatchKey.cpp",
        lambda: {
            "extra_cuda_headers": "",
            "external_backend_headers": external_backend_headers_str,
            "ops_headers": "#include <ATen/Functions.h>"
            if not per_operator_headers
            else "",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_headers": dest.gen_registration_headers(
                backend_index, per_operator_headers=per_operator_headers, rocm=False
            ),
            # The kernel-definition section is itself produced from a nested
            # template, then re-split into lines for the outer template.
            "dispatch_definitions": fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "static_init_dispatch_registrations": static_init_dispatch_registrations,
                    "deferred_dispatch_registrations": deferred_dispatch_registrations,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_index),
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": "",
                    "dispatch_anonymous_definitions": list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_index,
                                Target.ANONYMOUS_DEFINITION,
                                selector,
                                rocm=False,
                                symint=True,
                                class_method_name=f"{class_name}",
                                skip_dispatcher_op_registration=False,
                            ),
                            grouped_native_functions,
                        )
                    ),
                },
            ).split(newline),
        },
    )
526
+
527
+
528
def run(
    source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
) -> None:
    """Generate all backend stub files for the backend described by ``source_yaml``.

    Parses ``native_functions.yaml`` plus the backend's own yaml, then writes
    the native-functions header and one dispatcher-registration file per
    dispatch key (backend, and autograd when declared).  When ``impl_path`` is
    given, also verifies that every declared kernel has a definition there.
    """
    # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
    pytorch_root = Path(__file__).parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(
            install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
        )

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(
        pytorch_root, "aten/src/ATen/native/native_functions.yaml"
    )
    tags_yaml_path = os.path.join(pytorch_root, "aten/src/ATen/native/tags.yaml")
    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )
    grouped_native_functions = get_grouped_native_functions(native_functions)
    parsed_backend_yaml = parse_backend_yaml(
        source_yaml, grouped_native_functions, backend_indices
    )
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    class_name = parsed_backend_yaml.class_name
    # NOTE: the backend yaml's indices supersede the ones parsed from
    # native_functions.yaml above.
    backend_indices = parsed_backend_yaml.backend_indices

    selector = SelectiveBuilder.get_nop_selector()

    if backend_key is None:
        # This could be useful if a backend wants to quickly set up a noop yaml file but doesn't have any kernels ready yet.
        return

    if class_name is None:
        # class_name is an optional argument to backend yaml file.
        # if specified it allows an external backend to override
        # the name of the class that all generated kernel definitions live under.
        # if not specified, its value is given as native_function_class_name.
        class_name = backend_indices[backend_key].native_function_class_name()
    assert class_name is not None

    if impl_path is not None:
        error_on_missing_kernels(
            native_functions,
            backend_indices,
            backend_key,
            autograd_key,
            class_name,
            impl_path,
        )

    gen_dispatchkey_nativefunc_headers(
        fm,
        class_name,
        cpp_namespace,
        backend_indices,
        grouped_native_functions,
        backend_key,
        autograd_key,
    )

    for dispatch_key in (
        [backend_key] if autograd_key is None else [backend_key, autograd_key]
    ):
        gen_dispatcher_registrations(
            fm,
            output_dir,
            class_name,
            backend_indices,
            grouped_native_functions,
            backend_key,
            dispatch_key,
            selector,
        )
608
+
609
+
610
+ if __name__ == "__main__":
611
+ main()
.venv/lib/python3.11/site-packages/torchgen/gen_executorch.py ADDED
@@ -0,0 +1,998 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import os
5
+ from collections import defaultdict
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
9
+
10
+ import yaml
11
+
12
+ # Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
13
+ from torchgen import dest
14
+ from torchgen.api import cpp as aten_cpp
15
+ from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType
16
+ from torchgen.context import (
17
+ method_with_native_function,
18
+ method_with_nested_native_function,
19
+ with_native_function_and_index,
20
+ )
21
+ from torchgen.executorch.api import et_cpp
22
+ from torchgen.executorch.api.custom_ops import (
23
+ ComputeNativeFunctionStub,
24
+ gen_custom_ops_registration,
25
+ )
26
+ from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature
27
+ from torchgen.executorch.api.unboxing import Unboxing
28
+ from torchgen.executorch.model import ETKernelIndex, ETKernelKey, ETParsedYaml
29
+ from torchgen.executorch.parse import ET_FIELDS, parse_et_yaml, parse_et_yaml_struct
30
+ from torchgen.gen import (
31
+ get_custom_build_selector,
32
+ get_native_function_declarations,
33
+ get_native_function_declarations_from_ns_grouped_kernels,
34
+ get_native_function_schema_registrations,
35
+ LineLoader,
36
+ parse_native_yaml,
37
+ )
38
+ from torchgen.model import (
39
+ BackendIndex,
40
+ BackendMetadata,
41
+ DEFAULT_KERNEL_NAMESPACE,
42
+ DispatchKey,
43
+ FunctionSchema,
44
+ Location,
45
+ NativeFunction,
46
+ NativeFunctionsGroup,
47
+ OperatorName,
48
+ Variant,
49
+ )
50
+ from torchgen.utils import (
51
+ context,
52
+ FileManager,
53
+ make_file_manager,
54
+ mapMaybe,
55
+ NamespaceHelper,
56
+ )
57
+
58
+
59
+ if TYPE_CHECKING:
60
+ from torchgen.selective_build.selector import SelectiveBuilder
61
+
62
+
63
def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
    """
    A wrapper function to basically get `sig.decl(include_context=True)`.
    For ATen kernel, the codegen has no idea about ET contextArg, so we
    use this wrapper to add it.
    """
    # Executorch signatures already know about the context argument.
    if isinstance(sig, ExecutorchCppSignature):
        return sig.decl()

    # ATen signature: rebuild the declaration with contextArg prepended.
    ret = aten_cpp.returns_type(sig.func.returns).cpp_type()
    all_args = [contextArg.decl(), *(arg.decl() for arg in sig.arguments())]
    return f"{ret} {sig.name()}({', '.join(all_args)})"
77
+
78
+
79
def static_dispatch(
    sig: CppSignature | ExecutorchCppSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
    native function exists, error out. A simplified version of register_dispatch_key.py
    Arguments:
        sig: A CppSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);"
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    backends = [b for b in backend_indices if b.has_kernel(f)]
    static_block = None
    if len(backends) == 1:
        backend_metadata = backends[0].get_kernel(f)
        if backend_metadata:
            args = ", ".join(a.name for a in sig.arguments())
            # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch.
            static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});"
    if static_block is None:
        # Zero backends, multiple backends, or a single backend that reported
        # has_kernel but returned no metadata.  BUGFIX: the original only set
        # the error message in an `else` branch of `len(backends) == 1`, so
        # the metadata-less single-backend case left static_block as None and
        # rendered the literal text "None" into the generated C++.
        static_block = f"""
ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}.");
    """
    return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    {static_block}
}}
"""
115
+
116
+
117
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    """Per-operator generator for the public C++ function API in Functions.h.

    Called once per NativeFunction; returns the inline C++ definition, or
    None when the operator is not selected by the build.
    """

    # Backends considered when falling back to static dispatch.
    static_dispatch_backend_indices: list[BackendIndex]

    # Selective-build filter: only root operators get a public API.
    selector: SelectiveBuilder

    # True when generating against full ATen types rather than Executorch types.
    use_aten_lib: bool

    # Predicate marking custom (non-ATen) operators; those always go through
    # static dispatch even in ATen mode.
    is_custom_op: Callable[[NativeFunction], bool]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        is_method_variant = False
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return None

        if Variant.function not in f.variants and Variant.method in f.variants:
            is_method_variant = True

        # only valid remaining case is only function is in f.variants
        elif not (Variant.function in f.variants and Variant.method not in f.variants):
            raise Exception(  # noqa: TRY002
                f"Can't handle native function {f.func} with the following variant specification {f.variants}."
            )

        sig: CppSignature | ExecutorchCppSignature = (
            CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            if self.use_aten_lib
            else ExecutorchCppSignature.from_native_function(f)
        )
        if self.use_aten_lib and not self.is_custom_op(f):
            comma = ", "

            if is_method_variant:
                # Method-only variant: dispatch through the first argument
                # (the receiver), forwarding the remaining arguments.
                return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    return {sig.arguments()[0].name}.{sig.name()}({comma.join(e.name for e in sig.arguments()[1:])});
}}
"""
            else:
                # Function variant: forward straight to the at:: namespace.
                return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    return at::{sig.name()}({comma.join(e.name for e in sig.arguments())});
}}
"""

        else:
            # Executorch mode or custom op: call the backend kernel directly.
            return static_dispatch(
                sig,
                f,
                backend_indices=self.static_dispatch_backend_indices,
            )
+
176
+
177
# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    """Per-kernel generator for the unboxing wrappers registered at runtime.

    Called once per (NativeFunction, (kernel key, backend metadata)) entry;
    returns the ``Kernel(...)`` registration snippet(s), or "" when the
    operator/kernel is not selected by the build.
    """

    # Selective-build filter for operators and kernel keys.
    selector: SelectiveBuilder

    # True when generating against full ATen types rather than Executorch types.
    use_aten_lib: bool

    @method_with_nested_native_function
    def __call__(
        self,
        unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
    ) -> str:
        f: NativeFunction = unbox_kernel_entry[0]
        kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0]
        kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]

        op_name = f"{f.namespace}::{f.func.name}"
        if not self.selector.is_root_operator(op_name):
            return ""

        if not isinstance(kernel_key, list):
            kernel_key = [kernel_key]
        used_kernel_keys = self.selector.et_get_selected_kernels(
            op_name, [k.to_native_string() for k in kernel_key]
        )
        if not used_kernel_keys:
            return ""
        sig: CppSignature | ExecutorchCppSignature
        argument_type_gen: Callable[..., NamedCType]
        return_type_gen: Callable[..., CType]
        if self.use_aten_lib:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            argument_type_gen = aten_cpp.argumenttype_type
            return_type_gen = aten_cpp.returns_type
            arguments = sig.arguments()
            kernel_call = f"torch::executor::{f.namespace}::{sig.name()}"
        else:
            sig = ExecutorchCppSignature.from_native_function(f)
            argument_type_gen = et_cpp.argumenttype_type
            return_type_gen = et_cpp.returns_type
            arguments = sig.arguments(include_context=False)
            kernel_call = f"{kernel_meta.cpp_namespace}::{kernel_meta.kernel}"
        # parse arguments into C++ code
        binding_list, code_list = Unboxing(
            argument_type_gen=argument_type_gen
        ).convert_arguments(arguments)

        # for each C++ argument, generate the conversion code
        code_connector = "\n\t"
        arg_connector = ", "

        args_str = f"{arg_connector.join(e.name for e in binding_list)}"
        event_tracer_output_logging = ""
        output_ids = []

        # Decide how the kernel's result flows back onto the EValue stack:
        # void kernels with an out= arg alias the out slot; value-returning
        # kernels without out= store the result; out-variant kernels write
        # in place and need no assignment.
        if len(f.func.returns) == 0:
            if len(f.func.arguments.out) == 0:
                raise Exception(  # noqa: TRY002
                    f"Can't handle native function {f.func} with no returns and no out yet."
                )
            out = f.func.arguments.out[0]
            return_assignment = f"""stack[{len(binding_list)}] = &{out.name};"""
            ret_prefix = ""
            output_ids = [len(binding_list)]
        else:
            if len(f.func.arguments.out) == 0:
                return_assignment = (
                    f"""*stack[{len(binding_list)}] = EValue(result_);"""
                )
                ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = "
                output_ids = [len(binding_list)]
            else:
                return_assignment = ""
                ret_prefix = ""
                output_ids = [
                    len(binding_list) - (i + 1)
                    for i in reversed(range(len(f.func.arguments.out)))
                ]

        # Emit one event-tracer logging call per output slot.
        for output_id in output_ids:
            event_tracer_output_logging += (
                f"internal::event_tracer_log_evalue("
                f"context.internal_event_tracer(), "
                f"*stack[{output_id}]);\n"
            )

        newline = "\n    "
        # One Kernel(...) entry per selected kernel key; the "default" key is
        # registered without an explicit key string.
        return "\n".join(
            [
                f"""
Kernel(
    "{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != 'default' else ''}
    []({contextArg.defn()}, EValue** stack) {{
        {code_connector.join(code_list)}

        internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_{f.func.name}");
        EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
        {ret_prefix}{kernel_call}(context, {args_str});
        {event_tracer_output_logging}
        {return_assignment}
    }}
),
"""
                for k in used_kernel_keys
            ]
        )
+
286
+
287
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    kernel_index: ETKernelIndex,
    manual_registration: bool,
) -> None:
    """Write the unboxed-kernel registration source file.

    Expands every (native function, kernel) pair from ``kernel_index`` through
    :class:`ComputeCodegenUnboxedKernels` and writes the result to
    ``RegisterKernels.cpp`` (manual registration) or
    ``RegisterCodegenUnboxedKernels.cpp``.
    """

    # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
    def key_func(
        item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
    ) -> str:
        return item[0].root_name + ":" + item[1][0].to_native_string()

    items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [
        (native_function, (kernel_key, metadata))
        for native_function in native_functions
        for kernel_key, metadata in kernel_index.get_kernels(native_function).items()
    ]

    header = ["Functions.h" if use_aten_lib else "NativeFunctions.h"]
    filename = (
        "RegisterKernels.cpp"
        if manual_registration
        else "RegisterCodegenUnboxedKernels.cpp"
    )
    cpu_fm.write_sharded(
        filename,
        items,
        key_fn=key_func,
        env_callable=lambda unbox_kernel_entry: {
            "unboxed_kernels": [
                ComputeCodegenUnboxedKernels(selector, use_aten_lib)(unbox_kernel_entry)
            ],
            "fn_header": header
            if unbox_kernel_entry == items[0]
            else [],  # Only write header once
        },
        num_shards=1,
        sharded_keys={"unboxed_kernels", "fn_header"},
    )
+
330
+
331
@with_native_function_and_index  # type: ignore[arg-type]
def compute_native_function_declaration(
    g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex
) -> list[str]:
    """Return the C++ declarations for every kernel registered for ``g``.

    Declarations are emitted in both context-taking and context-free forms;
    the unused variant is cleaned up later (see comment below).
    """
    assert isinstance(g, NativeFunction)
    sig = ExecutorchCppSignature.from_native_function(f=g)
    metadata_list = kernel_index.get_kernels(g).values()
    # BUGFIX: `dict.values()` is never None, so the original `is None` guard
    # was dead code.  Check for emptiness instead (same result — an empty
    # view also produced [] via the comprehension — but no dead branch).
    if not metadata_list:
        return []

    # for kernels in lean mode, we declare two versions, one with context and one without.
    # In the end we will cleanup the unused one.
    def gen_decl(metadata: BackendMetadata, include_context: bool) -> str:
        return f"{sig.decl(name=metadata.kernel, include_context=include_context)};"

    return [
        gen_decl(metadata, include_context)
        for include_context in [False, True]
        for metadata in metadata_list
    ]
+
352
+
353
def gen_functions_declarations(
    *,
    native_functions: Sequence[NativeFunction],
    kernel_index: ETKernelIndex,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    custom_ops_native_functions: Sequence[NativeFunction] | None = None,
) -> str:
    """
    Generates namespace separated C++ function API inline declaration/definitions.
    Native functions are grouped by namespaces and the generated code is wrapped inside
    namespace blocks.

    E.g., for `custom_1::foo.out` in yaml file we will generate a C++ API as a symbol
    in `torch::executor::custom_1::foo_out`. This way we avoid symbol conflict when
    the other `custom_2::foo.out` is available.
    """

    # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
    # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.

    backend_index = kernel_index._to_backend_index()

    # Bucket the native functions by their namespace.
    by_namespace: defaultdict[str, list[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        by_namespace[fn.namespace].append(fn)

    chunks: list[str] = []
    for namespace, fns in by_namespace.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=3,
        )
        compute = ComputeFunction(
            static_dispatch_backend_indices=[backend_index],
            selector=selector,
            use_aten_lib=use_aten_lib,
            is_custom_op=lambda f: custom_ops_native_functions is not None
            and f in custom_ops_native_functions,
        )
        decls = "\n".join(mapMaybe(compute, fns))
        # Each namespace's declarations are wrapped in its prologue/epilogue.
        chunks.append(f"\n{ns_helper.prologue}\n{decls}\n{ns_helper.epilogue}\n")
    return "".join(chunks)
+
406
+
407
def get_ns_grouped_kernels(
    *,
    native_functions: Sequence[NativeFunction],
    kernel_index: ETKernelIndex,
    native_function_decl_gen: Callable[
        [
            NativeFunctionsGroup | NativeFunction,
            ETKernelIndex,
        ],
        list[str],
    ],
) -> dict[str, list[str]]:
    """Group generated kernel declarations by their C++ namespace.

    Asserts that all of an operator's kernels live in a single namespace.
    """
    grouped: dict[str, list[str]] = defaultdict(list)
    for func in native_functions:
        seen_namespaces = set()
        for metadata in kernel_index.get_kernels(func).values():
            if metadata:
                namespace = metadata.cpp_namespace
                seen_namespaces.add(namespace)
            else:
                namespace = DEFAULT_KERNEL_NAMESPACE
        assert (
            len(seen_namespaces) <= 1
        ), f"Codegen only supports one namespace per operator, got {seen_namespaces}"
        grouped[namespace].extend(native_function_decl_gen(func, kernel_index))
    return grouped
+
437
+
438
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    gen_custom_ops_header: bool,
    custom_ops_native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    cpu_fm: FileManager,
    use_aten_lib: bool,
) -> None:
    """Generate headers.

    Args:
        native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops.
        gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h
        custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops.
        selector (SelectiveBuilder): selective-build filter for operators.
        kernel_index (ETKernelIndex): kernel collection
        cpu_fm (FileManager): file manager manages output stream
        use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types.
    """
    aten_headers = ["#include <ATen/Functions.h>"]
    backend_indices = {DispatchKey.CPU: kernel_index._to_backend_index()}
    if gen_custom_ops_header:
        cpu_fm.write_with_template(
            "CustomOpsNativeFunctions.h",
            "NativeFunctions.h",
            lambda: {
                "nativeFunctions_declarations": get_native_function_declarations(
                    grouped_native_functions=custom_ops_native_functions,
                    backend_indices=backend_indices,
                    native_function_decl_gen=dest.compute_native_function_declaration,
                ),
                "headers": [
                    "#include <ATen/ATen.h>",
                    "#include <torch/torch.h>",
                ],
            },
        )
        aten_headers.append('#include "CustomOpsNativeFunctions.h"')
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": aten_headers
            if use_aten_lib
            else ['#include "NativeFunctions.h"'],
            "Functions_declarations": gen_functions_declarations(
                native_functions=native_functions,
                kernel_index=kernel_index,
                selector=selector,
                use_aten_lib=use_aten_lib,
                custom_ops_native_functions=custom_ops_native_functions,
            ),
        },
    )
    cpu_fm.write(
        "RegisterKernels.h",
        lambda: {
            "generated_comment": "@" + "generated by torchgen/gen_executorch.py",
        },
    )
    headers = {
        "headers": [
            "#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.",
            "#include <executorch/runtime/kernel/kernel_runtime_context.h>",
        ],
    }
    # NativeFunctions.h is produced from ATen-style declarations in ATen mode,
    # or from namespace-grouped Executorch declarations otherwise.
    if use_aten_lib:
        headers["headers"].append("#include <executorch/codegen/macros.h> // TORCH_API")
        cpu_fm.write(
            "NativeFunctions.h",
            lambda: dict(
                {
                    "nativeFunctions_declarations": get_native_function_declarations(
                        grouped_native_functions=native_functions,
                        backend_indices=backend_indices,
                        native_function_decl_gen=dest.compute_native_function_declaration,
                    ),
                },
                **headers,
            ),
        )
    else:
        ns_grouped_kernels = get_ns_grouped_kernels(
            native_functions=native_functions,
            kernel_index=kernel_index,
            native_function_decl_gen=compute_native_function_declaration,  # type: ignore[arg-type]
        )
        cpu_fm.write(
            "NativeFunctions.h",
            lambda: dict(
                {
                    "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels(
                        ns_grouped_kernels=ns_grouped_kernels,
                    ),
                },
                **headers,
            ),
        )
+
537
+
538
def gen_custom_ops(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    cpu_fm: FileManager,
    rocm: bool,
) -> None:
    """Generate the registration sources for custom (non-ATen) operators.

    Writes Register<CPU>CustomOps.cpp (real kernel registrations),
    Register<CPU>Stub.cpp (stub kernels for builds without the real ones),
    and RegisterSchema.cpp (operator schema registrations).
    """
    dispatch_key = DispatchKey.CPU
    (
        anonymous_definition,
        static_init_dispatch_registrations,
    ) = gen_custom_ops_registration(
        native_functions=native_functions,
        selector=selector,
        kernel_index=kernel_index,
        rocm=rocm,
    )
    cpu_fm.write_with_template(
        f"Register{dispatch_key}CustomOps.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": '#include "CustomOpsNativeFunctions.h"',
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": anonymous_definition,
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )
    # Same template, but kernel bodies are replaced with stubs.
    cpu_fm.write_with_template(
        f"Register{dispatch_key}Stub.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": "",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": list(
                mapMaybe(ComputeNativeFunctionStub(), native_functions)
            ),
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions,
        schema_selector=selector,
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "schema_registrations": schema_registrations,
            "aten_schema_registrations": aten_schema_registrations,
        },
    )
+
598
+
599
+ def translate_native_yaml(
600
+ tags_yaml_path: str,
601
+ aten_yaml_path: str,
602
+ native_yaml_path: str | None,
603
+ use_aten_lib: bool,
604
+ out_file: TextIO,
605
+ ) -> None:
606
+ """Translates Executorch DSL dialect to use the same syntax as
607
+ native_functions.yaml. The major difference is that Executorch DSL dialect
608
+ supports "op" key, where it refers to the operator name in native_functions.yaml.
609
+
610
+ For example, a functions.yaml may have the following entry:
611
+
612
+ - op: add.out
613
+ ...
614
+
615
+ It needs to be translated to the following:
616
+
617
+ - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
618
+ ...
619
+
620
+ We go in aten_yaml_path and find the operator schema for "add.out" and add it
621
+ to the original functions.yaml. We also add required field "variants", where for
622
+ Executorch it will always be "function".
623
+
624
+ For ATen mode we don't have to do the translation because native_yaml_path is
625
+ the same as native_functions.yaml.
626
+
627
+ Args:
628
+ tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
629
+ It is not optional.
630
+ aten_yaml_path: Path to ATen operator yaml file native_functions.yaml.
631
+ native_yaml_path: Path to a functions.yaml file to parse.
632
+ If the path does not exist in the filesystem, it is treated as an
633
+ empty file. If `custom_ops_yaml_path` exists, the contents of that
634
+ file are appended to the yaml input to be parsed.
635
+ use_aten_lib: We use this flag to determine if we want to generate native
636
+ functions. In ATen mode we should generate out= variants.
637
+ out_file: The IO object that we are writing into.
638
+ Returns:
639
+ None
640
+ """
641
+ if use_aten_lib:
642
+ with open(aten_yaml_path) as aten_yaml:
643
+ out_file.writelines(aten_yaml.readlines())
644
+ return
645
+
646
+ native_functions, persisted_fields = parse_et_yaml(
647
+ aten_yaml_path,
648
+ tags_yaml_path,
649
+ None,
650
+ skip_native_fns_gen=False,
651
+ )
652
+
653
+ func_to_scoped_name: dict[FunctionSchema, str] = {
654
+ f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
655
+ }
656
+ op_to_scoped_name: dict[OperatorName, str] = {
657
+ func.name: name for func, name in func_to_scoped_name.items()
658
+ }
659
+
660
+ schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
661
+ kernel_persist_dict: dict[str, dict[str, Any]] = {
662
+ op_to_scoped_name[op]: v for op, v in persisted_fields.items()
663
+ }
664
+
665
+ if (
666
+ not native_yaml_path
667
+ or not os.path.exists(native_yaml_path)
668
+ or os.stat(native_yaml_path).st_size == 0
669
+ ):
670
+ return
671
+ with open(native_yaml_path) as native_yaml:
672
+ native_es = yaml.load(native_yaml, Loader=LineLoader)
673
+ if not native_es:
674
+ return
675
+ for e in native_es:
676
+ assert isinstance(e.get("__line__"), int), e
677
+ loc = Location(native_yaml_path, e.pop("__line__"))
678
+ with context(lambda: f"in {loc}:\n "):
679
+ if "variants" not in e:
680
+ e["variants"] = "function"
681
+ if "func" in e:
682
+ continue
683
+ assert isinstance(e.get("op"), str), e
684
+ opname = e.pop("op")
685
+ if "::" not in opname:
686
+ opname = "aten::" + opname
687
+ assert opname in schema_dict
688
+ e["func"] = schema_dict.get(opname)
689
+
690
+ # Write out persisted kernel information
691
+ if opname in kernel_persist_dict:
692
+ for k, v in kernel_persist_dict[opname].items():
693
+ e[k] = v
694
+
695
+ yaml.dump(native_es, out_file, width=1000)
696
+
697
+
698
+ def parse_yaml(
699
+ path: str | None,
700
+ tags_yaml_path: str,
701
+ function_filter: Callable[[NativeFunction], bool],
702
+ skip_native_fns_gen: bool = False,
703
+ ) -> tuple[
704
+ list[NativeFunction],
705
+ dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex,
706
+ ]:
707
+ if path and os.path.exists(path) and os.stat(path).st_size > 0:
708
+ with open(path) as f:
709
+ es = yaml.load(f, Loader=LineLoader)
710
+
711
+ # Check for kernel index structure
712
+ kernel_index = (
713
+ parse_et_yaml_struct(es) if any("kernels" in e for e in es) else None
714
+ )
715
+
716
+ # Remove ET specific fields from entries for BC compatibility
717
+ for entry in es:
718
+ for field in ET_FIELDS:
719
+ entry.pop(field, None)
720
+
721
+ parsed_yaml = parse_native_yaml(
722
+ path,
723
+ tags_yaml_path,
724
+ None,
725
+ skip_native_fns_gen=skip_native_fns_gen,
726
+ loaded_yaml=es,
727
+ )
728
+ native_functions = list(filter(function_filter, parsed_yaml.native_functions))
729
+ op_names = [f.func.name for f in native_functions]
730
+
731
+ # (1) Return ETKernelIndex if kernel index is present
732
+ if kernel_index is not None:
733
+ filtered_index = {
734
+ op_name: kernel_mapping
735
+ for op_name, kernel_mapping in kernel_index.index.items()
736
+ if op_name in op_names
737
+ }
738
+ return native_functions, ETKernelIndex(index=filtered_index)
739
+
740
+ # (2) Return BackendIndices if kernel index is absent
741
+ def map_index(
742
+ m: dict[OperatorName, BackendMetadata]
743
+ ) -> dict[OperatorName, BackendMetadata]:
744
+ return {op: m[op] for op in m if op in op_names}
745
+
746
+ backend_indices = {
747
+ k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items()
748
+ }
749
+
750
+ return native_functions, backend_indices
751
+ else:
752
+ return [], {}
753
+
754
+
755
+ def parse_yaml_files(
756
+ tags_yaml_path: str,
757
+ aten_yaml_path: str,
758
+ native_yaml_path: str | None,
759
+ custom_ops_yaml_path: str | None,
760
+ selector: SelectiveBuilder,
761
+ use_aten_lib: bool,
762
+ ) -> tuple[ETParsedYaml, ETParsedYaml | None]:
763
+ """Parses functions.yaml and custom_ops.yaml files.
764
+
765
+ Args:
766
+ tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
767
+ It is not optional.
768
+ aten_yaml_path: Path to ATen operator yaml file native_functions.yaml.
769
+ native_yaml_path: Path to a functions.yaml file to parse.
770
+ If the path does not exist in the filesystem, it is treated as an
771
+ empty file. If `custom_ops_yaml_path` exists, the contents of that
772
+ file are appended to the yaml input to be parsed.
773
+ custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If
774
+ the path does not exist in the filesystem, it is ignored.
775
+ selector: For selective build.
776
+ use_aten_lib: We use this flag to determine if we want to generate native
777
+ functions. In ATen mode we should generate out= variants.
778
+ Returns:
779
+ A tuple with two elements:
780
+ [0]: The parsed results of concatenating the contents of
781
+ `native_yaml_path` and `custom_ops_yaml_path`.
782
+ [1]: The parsed results of the contents of `custom_ops_yaml_path`, if
783
+ present. If not present, None.
784
+ """
785
+ import tempfile
786
+
787
+ # only include selected ops, this is because we want to avoid
788
+ def function_filter(f: NativeFunction) -> bool:
789
+ return selector.is_native_function_selected(f)
790
+
791
+ with tempfile.TemporaryDirectory() as tmpdirname:
792
+ translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
793
+ with open(translated_yaml_path, "w") as translated:
794
+ translate_native_yaml(
795
+ tags_yaml_path,
796
+ aten_yaml_path,
797
+ native_yaml_path,
798
+ use_aten_lib,
799
+ translated,
800
+ )
801
+
802
+ translated_functions, translated_indices = parse_yaml(
803
+ translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib
804
+ )
805
+ custom_ops_functions, custom_ops_indices = parse_yaml(
806
+ custom_ops_yaml_path, tags_yaml_path, function_filter, True
807
+ )
808
+
809
+ # Convert BackendIndices to ETKernelIndex
810
+ if not isinstance(translated_indices, ETKernelIndex):
811
+ translated_indices = ETKernelIndex.from_backend_indices(translated_indices)
812
+ if not isinstance(custom_ops_indices, ETKernelIndex):
813
+ custom_ops_indices = ETKernelIndex.from_backend_indices(custom_ops_indices)
814
+
815
+ combined_functions = translated_functions + custom_ops_functions
816
+ combined_kernel_index = ETKernelIndex.merge_indices(
817
+ translated_indices, custom_ops_indices
818
+ )
819
+ combined_yaml = ETParsedYaml(combined_functions, combined_kernel_index)
820
+ custom_ops_parsed_yaml = ETParsedYaml(custom_ops_functions, custom_ops_indices)
821
+
822
+ return combined_yaml, custom_ops_parsed_yaml
823
+
824
+
825
+ def main() -> None:
826
+ parser = argparse.ArgumentParser(description="Generate operator source files")
827
+ # Although we don't refer to --source-path directly, make_file_manager()
828
+ # expects it to point to a directory that contains a templates/ subdirectory
829
+ # containing the file templates.
830
+ parser.add_argument(
831
+ "-s",
832
+ "--source-path",
833
+ help="path to source directory for kernel templates",
834
+ )
835
+ parser.add_argument(
836
+ "--functions-yaml-path",
837
+ "--functions_yaml_path",
838
+ help="path to the functions.yaml file to use. Optional, but at least "
839
+ "one of --functions-yaml-path and --custom-ops-yaml-path must be "
840
+ "specified.",
841
+ )
842
+ parser.add_argument(
843
+ "--custom-ops-yaml-path",
844
+ "--custom_ops_yaml_path",
845
+ help="path to the custom_ops.yaml file to use. Optional, but at least "
846
+ "one of --functions-yaml-path and --custom-ops-yaml-path must be "
847
+ "specified.",
848
+ )
849
+ parser.add_argument(
850
+ "--aten-yaml-path",
851
+ "--aten_yaml_path",
852
+ help="path to native_functions.yaml file.",
853
+ )
854
+ # Note that make_file_manager() also looks at --install-dir.
855
+ parser.add_argument(
856
+ "-d",
857
+ "--install-dir",
858
+ "--install_dir",
859
+ help="output directory",
860
+ default="build/generated",
861
+ )
862
+ parser.add_argument(
863
+ "-o",
864
+ "--output-dependencies",
865
+ help="output a list of dependencies into the given file and exit",
866
+ )
867
+ # Although we don't refer to --dry-run directly, make_file_manager() looks
868
+ # for it.
869
+ parser.add_argument(
870
+ "--dry-run",
871
+ action="store_true",
872
+ help="run without writing any files (still updates outputs)",
873
+ )
874
+ parser.add_argument(
875
+ "--static-dispatch-backend",
876
+ "--static_dispatch_backend",
877
+ nargs="*",
878
+ help="generate static dispatch code for the specific backend (if set)",
879
+ )
880
+ parser.add_argument(
881
+ "--op-registration-whitelist",
882
+ "--op_registration_whitelist",
883
+ nargs="*",
884
+ help="filter op registrations by the whitelist (if set); "
885
+ "each item is `namespace`::`operator name` without overload name; "
886
+ "e.g.: aten::empty aten::conv2d ...",
887
+ )
888
+ parser.add_argument(
889
+ "--op-selection-yaml-path",
890
+ "--op_selection_yaml_path",
891
+ help="Provide a path to the operator selection (for custom build) YAML "
892
+ "that contains the information about the set of selected operators "
893
+ "and their categories (training, ...). Each operator is either a "
894
+ "full operator name with overload or just a bare operator name. "
895
+ "The operator names also contain the namespace prefix (e.g. aten::)",
896
+ )
897
+ parser.add_argument(
898
+ "--tags-path",
899
+ help="Path to tags.yaml. Required by yaml parsing in codegen system.",
900
+ )
901
+ parser.add_argument(
902
+ "--rocm",
903
+ action="store_true",
904
+ help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
905
+ )
906
+ parser.add_argument(
907
+ "--use-aten-lib",
908
+ "--use_aten_lib",
909
+ action="store_true",
910
+ help="a boolean flag to indicate whether we use ATen kernels or not, in the future this flag will be per "
911
+ "operator",
912
+ )
913
+ parser.add_argument(
914
+ "--manual_registration",
915
+ "--manual-registration",
916
+ action="store_true",
917
+ help="a boolean flag to indicate whether we want to manually call"
918
+ "register_kernels() or rely on static init. ",
919
+ )
920
+ parser.add_argument(
921
+ "--generate",
922
+ type=str,
923
+ nargs="*",
924
+ choices=["headers", "sources"],
925
+ default=["headers", "sources"],
926
+ help="Generate only a subset of files",
927
+ )
928
+ options = parser.parse_args()
929
+ assert options.tags_path, "tags.yaml is required by codegen yaml parsing."
930
+
931
+ selector = get_custom_build_selector(
932
+ options.op_registration_whitelist,
933
+ options.op_selection_yaml_path,
934
+ )
935
+
936
+ parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
937
+ aten_yaml_path=options.aten_yaml_path,
938
+ tags_yaml_path=options.tags_path,
939
+ native_yaml_path=options.functions_yaml_path,
940
+ custom_ops_yaml_path=options.custom_ops_yaml_path,
941
+ selector=selector,
942
+ use_aten_lib=options.use_aten_lib,
943
+ )
944
+ native_functions, kernel_index = (
945
+ parsed_yaml.native_functions,
946
+ parsed_yaml.kernel_index,
947
+ )
948
+ custom_ops_native_functions = (
949
+ custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else []
950
+ )
951
+
952
+ cpu_fm = make_file_manager(options=options)
953
+
954
+ if "headers" in options.generate:
955
+ # generate CustomOpsNativeFunctions.h when custom_ops.yaml is present, to match the build system.
956
+ gen_headers(
957
+ native_functions=native_functions,
958
+ gen_custom_ops_header=options.custom_ops_yaml_path,
959
+ custom_ops_native_functions=custom_ops_native_functions,
960
+ selector=selector,
961
+ kernel_index=kernel_index,
962
+ cpu_fm=cpu_fm,
963
+ use_aten_lib=options.use_aten_lib,
964
+ )
965
+
966
+ if "sources" in options.generate:
967
+ gen_unboxing(
968
+ native_functions=native_functions,
969
+ cpu_fm=cpu_fm,
970
+ selector=selector,
971
+ use_aten_lib=options.use_aten_lib,
972
+ kernel_index=kernel_index,
973
+ manual_registration=options.manual_registration,
974
+ )
975
+ if custom_ops_native_functions:
976
+ gen_custom_ops(
977
+ native_functions=custom_ops_native_functions,
978
+ selector=selector,
979
+ kernel_index=kernel_index,
980
+ cpu_fm=cpu_fm,
981
+ rocm=options.rocm,
982
+ )
983
+
984
+ if options.output_dependencies:
985
+ depfile_path = Path(options.output_dependencies).resolve()
986
+ depfile_name = depfile_path.name
987
+ depfile_stem = depfile_path.stem
988
+
989
+ for fm, prefix in [
990
+ (cpu_fm, ""),
991
+ ]:
992
+ varname = prefix + depfile_stem
993
+ path = depfile_path.parent / (prefix + depfile_name)
994
+ fm.write_outputs(varname, str(path))
995
+
996
+
997
+ if __name__ == "__main__":
998
+ main()
.venv/lib/python3.11/site-packages/torchgen/gen_functionalization_type.py ADDED
@@ -0,0 +1,882 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Callable, TYPE_CHECKING
5
+
6
+ from torchgen.api import cpp, dispatcher
7
+ from torchgen.api.translate import translate
8
+ from torchgen.api.types import (
9
+ BaseCType,
10
+ Binding,
11
+ CType,
12
+ DispatcherSignature,
13
+ FunctionalizationLambda,
14
+ iTensorListRefT,
15
+ NativeSignature,
16
+ OptionalCType,
17
+ optionalSymIntArrayRefT,
18
+ symIntArrayRefT,
19
+ SymIntT,
20
+ tensorListT,
21
+ tensorT,
22
+ VectorCType,
23
+ ViewInverseSignature,
24
+ )
25
+ from torchgen.context import (
26
+ method_with_native_function,
27
+ native_function_manager,
28
+ with_native_function,
29
+ with_native_function_and,
30
+ )
31
+ from torchgen.model import (
32
+ Argument,
33
+ BackendIndex,
34
+ BaseTy,
35
+ BaseType,
36
+ FunctionSchema,
37
+ ListType,
38
+ NativeFunction,
39
+ NativeFunctionsGroup,
40
+ NativeFunctionsViewGroup,
41
+ Return,
42
+ SchemaKind,
43
+ SelfArgument,
44
+ TensorOptionsArguments,
45
+ )
46
+ from torchgen.native_function_generation import (
47
+ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
48
+ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
49
+ OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
50
+ )
51
+ from torchgen.utils import dataclass_repr
52
+
53
+
54
+ if TYPE_CHECKING:
55
+ from torchgen.selective_build.selector import SelectiveBuilder
56
+
57
+
58
+ # Note: [Mutable Ops Not Using Functionalization]
59
+ # Ops in this list currently do not work with functionalization and should be fixed.
60
+ MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
61
+ OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
62
+ + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
63
+ + INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
64
+ + [
65
+ # It will be BC-breaking, but we should fix their schemas.
66
+ # should be inplace?
67
+ "record_stream",
68
+ # See Note [resize_ in Functionalization]
69
+ "resize_",
70
+ "resize_as_",
71
+ # This function is used as for testing purposes only.
72
+ "_fill_mem_eff_dropout_mask_",
73
+ ]
74
+ )
75
+
76
+ # This file contains codegen that relates to the functionalization pass.
77
+ # It includes:
78
+ # - gen_functionalization_definition
79
+ # Generates dispatcher kernel definitions for the functionalization pass.
80
+ # - gen_functionalization_registration
81
+ # Generates dispatcher kernel registrations for the functionalization pass.
82
+ # - gen_functionalization_view_inverse_declaration
83
+ # Generates a declaration for an "inverse view", for every view op
84
+ # that is needed in functionalization. We manually implement their definitions.
85
+ # - gen_composite_view_copy_kernel
86
+ # Generates view_copy() composite kernels for all view_copy operators.
87
+
88
+
89
+ # Generates the body of the default composite C++ kernel for a {view}_copy NativeFunction
90
+ # See Note [view_copy NativeFunctions]
91
+ @dataclass(frozen=True)
92
+ class GenCompositeViewCopyKernel:
93
+ backend_index: BackendIndex
94
+
95
+ @method_with_native_function
96
+ def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
97
+ if g.view_copy is None:
98
+ return None
99
+ elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
100
+ # If the view_copy doesn't match the standard naming scheme of <op>_copy,
101
+ # assume it already exists and doesn't need to be generated.
102
+ # Example: slice_inverse() with the copy variant named slice_scatter()
103
+ # instead of slice_inverse_copy()
104
+ return None
105
+
106
+ metadata = self.backend_index.get_kernel(g.view_copy)
107
+ assert metadata is not None
108
+
109
+ # We can make view_copy work in more cases by using reshape()
110
+ # when a normal view call would ordinarily fail.
111
+ # This also makes LTC more efficient, because they don't need to include
112
+ # clone() calls in their graph (which is normally needed by reshape).
113
+ if str(g.view_copy.func.name) == "view_copy":
114
+ assert metadata.kernel == "view_copy_symint"
115
+ return """\
116
+ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
117
+ c10::SymDimVector shape = infer_size_dv(size, self.sym_numel());
118
+ if (!at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape).has_value()) {
119
+ return self.reshape_symint(size);
120
+ } else {
121
+ auto output = at::_ops::view::call(self, size);
122
+ return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
123
+ }
124
+ }
125
+ """
126
+ # view_copy is a native signature, since we're generating an at::native:: kernel
127
+ # Functionalization always operates on symints though
128
+ view_copy_sig = NativeSignature(
129
+ g.view_copy.func, symint=metadata.supports_symint()
130
+ )
131
+
132
+ # view is a dispatcher signature, since we're calling into the at::_ops API
133
+ view_sig = DispatcherSignature(g.view.func)
134
+
135
+ view_api_name = g.view.func.name.unambiguous_name()
136
+ exprs = ", ".join(
137
+ [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
138
+ )
139
+
140
+ # view ops today always return either a Tensor or a list of Tensors
141
+ assert len(g.view.func.returns) == 1
142
+ assert g.view.func.returns[0].type == BaseType(
143
+ BaseTy.Tensor
144
+ ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)
145
+
146
+ if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
147
+ return_cloned_output = """\
148
+ return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);"""
149
+ else:
150
+ # If the return type is a list, we need to clone each tensor in the list.
151
+ return_cloned_output = f"""\
152
+ {view_copy_sig.returns_type().cpp_type()} out_clone;
153
+ for (const auto i : c10::irange(output.size())) {{
154
+ out_clone.push_back(output[i].clone(/*memory_format=*/at::MemoryFormat::Contiguous));
155
+ }}
156
+ return out_clone;"""
157
+
158
+ # The default generated composite kernel for {view}_copy() operators just clones
159
+ # the input tensor, and runs the underlying view on the clone.
160
+ return f"""
161
+ {view_copy_sig.defn(name=metadata.kernel)} {{
162
+ auto output = at::_ops::{view_api_name}::call({exprs});
163
+ {return_cloned_output}
164
+ }}
165
+ """
166
+
167
+
168
+ def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
169
+ assert len(rets) == len(names)
170
+ if len(rets) == 0:
171
+ return ""
172
+ elif len(rets) == 1:
173
+ return f"return {names[0]};"
174
+ else:
175
+ return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
176
+
177
+
178
+ def modifies_arguments(f: NativeFunction) -> bool:
179
+ return any(
180
+ a.annotation is not None and a.annotation.is_write
181
+ for a in f.func.arguments.flat_all
182
+ )
183
+
184
+
185
+ def wrapper_name(func: FunctionSchema) -> str:
186
+ if func.name.overload_name:
187
+ return f"{cpp.name(func)}_{func.name.overload_name}"
188
+ else:
189
+ return cpp.name(func)
190
+
191
+
192
+ def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
193
+ return isinstance(a, SelfArgument) or (
194
+ isinstance(a, Argument) and a.type.is_tensor_like()
195
+ )
196
+
197
+
198
+ # We need to wrap / unwrap various arguments from the op in the functionalization kernels.
199
+ # Some op schemas include non-owning types though (like TensorList),
200
+ # and when we unwrap them we expect to get out an owning type!.
201
+ # We also return a lambda that tells you how to conver the non-owning type argument into the owning type.
202
+ def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
203
+ if t == BaseCType(tensorListT):
204
+ return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
205
+ if t == BaseCType(iTensorListRefT):
206
+ return VectorCType(BaseCType(tensorT)), lambda x: f"{{{x}.begin(), {x}.end()}}"
207
+ # There are technically other non-owning types out there (like IntArrayRef),
208
+ # but functionalization only actually cares about the ones involving tensors.
209
+ return t, lambda x: x
210
+
211
+
212
+ # unwraps all tensor-like arguments, returning:
213
+ # (1) a string containing all of the logic that does the unwrapping
214
+ # (2) a context, to be used by translate(), with all of the relevant bindings.
215
+ def unwrap_tensor_args(
216
+ sig: DispatcherSignature, *, is_view_op: bool
217
+ ) -> tuple[str, list[Binding]]:
218
+ context: list[Binding] = []
219
+ unwrapped_tensor_args: list[str] = []
220
+ for arg in sig.arguments():
221
+ if is_tensor_like(arg.argument):
222
+ # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
223
+ unwrapped_name = f"{arg.name}_"
224
+ # For most ops, the functionalization needs to sync any pending updates on the input tensors
225
+ # before calling the operator, since otherwise the operator will act on stale data.
226
+ # For view ops though, we can continue to defer syncing until the tensor is used by
227
+ # a non-view operator.
228
+ maybe_sync_input = (
229
+ "" if is_view_op else f"at::functionalization::impl::sync({arg.name});"
230
+ )
231
+ unwrapped_type, conversion_fn = get_owning_type(
232
+ arg.nctype.remove_const_ref().type
233
+ )
234
+ unwrapped_tensor_args.append(
235
+ f"""
236
+ {unwrapped_type.cpp_type()} {unwrapped_name};
237
+ if (at::functionalization::impl::isFunctionalTensor({arg.name})) {{
238
+ {maybe_sync_input}
239
+ {unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});
240
+ }} else {{
241
+ {unwrapped_name} = {conversion_fn(arg.name)};
242
+ }}"""
243
+ )
244
+ context.append(arg.with_name(unwrapped_name))
245
+ else:
246
+ # for non-tensor inputs, we want to pass them directly into the redispatch calls.
247
+ context.append(arg)
248
+ unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
249
+ return unwrap_tensor_args_str, context
250
+
251
+
252
+ # converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
253
+ # (1) a string containing all of the logic that does the conversions.
254
+ # (2) a context, to be used by translate(), with all of the relevant bindings.
255
+ def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
256
+ context: list[Binding] = []
257
+ unwrapped_tensor_args: list[str] = []
258
+ for arg in sig.arguments():
259
+ if is_tensor_like(arg.argument):
260
+ # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
261
+ a_ = arg.name
262
+ unwrapped_name = f"{arg.name}_meta"
263
+ unwrapped_tensor_args.append(f"auto {unwrapped_name} = to_meta({a_});")
264
+ context.append(arg.with_name(unwrapped_name))
265
+ else:
266
+ # for non-tensor inputs, we want to pass them directly into the redispatch calls.
267
+ context.append(arg)
268
+ unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
269
+ return unwrap_tensor_args_str, context
270
+
271
+
272
+ # The functionalization codegen currently expects view op schemas to have this form:
273
+ # foo(Tensor(a), ...) -> Tensor(a) (e.g. transpose)
274
+ # foo(Tensor(a!), ...) -> Tensor(a!) (e.g. transpose_)
275
+ def assert_view_op_properties(func: FunctionSchema) -> None:
276
+ def is_alias(a: Argument) -> bool:
277
+ return a.annotation is not None
278
+
279
+ args = func.arguments.flat_non_out
280
+ # The first argument is a tensor with an alias semantics (annotations)
281
+ assert len(args) > 0 and args[0].type == BaseType(
282
+ BaseTy.Tensor
283
+ ), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
284
+ but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
285
+ # No other arguments have aliasing semantics
286
+ assert is_alias(args[0]) and not any(
287
+ is_alias(a) for a in args[1:]
288
+ ), """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
289
+ View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint"""
290
+
291
+
292
+ # One-liner expression for checking if an expression expr of type type has any
293
+ # symbolic values.
294
+ def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
295
+ if type == BaseCType(SymIntT):
296
+ return f"{expr}.is_symbolic()"
297
+
298
+ if isinstance(type, OptionalCType):
299
+ innerexpr = f"(*{expr})"
300
+ return f"{expr}.has_value() ? {emit_expr_has_symbolic_values(innerexpr, type.elem)} : false"
301
+
302
+ if type == BaseCType(optionalSymIntArrayRefT):
303
+ return emit_expr_has_symbolic_values(
304
+ expr, OptionalCType(BaseCType(symIntArrayRefT))
305
+ )
306
+
307
+ if type in (BaseCType(symIntArrayRefT), VectorCType(BaseCType(SymIntT))):
308
+ argname = "arg"
309
+ lambda_check = emit_expr_has_symbolic_values(argname, BaseCType(SymIntT))
310
+ return (
311
+ "std::any_of("
312
+ f"{expr}.begin(), {expr}.end(), "
313
+ f"[=](auto& {argname}) {{ return {lambda_check}; }})"
314
+ )
315
+
316
+ raise ValueError(
317
+ "unsupported type for has_symbolic_values check. "
318
+ "It should be a SymInt or a collection of those. "
319
+ f"Got: {type.cpp_type()}"
320
+ )
321
+
322
+
323
+ # Detects whether any of the SymInt arguments are, in fact, symbolic values.
324
+ # This is used in the constructor of ViewMeta.
325
+ def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
326
+ name = "has_symbolic_inputs"
327
+ statements = [
328
+ f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
329
+ for binding in sig.arguments()
330
+ if (
331
+ isinstance(binding.argument, Argument)
332
+ and binding.argument.type.is_symint_like()
333
+ )
334
+ ]
335
+ body = "\n ".join(statements)
336
+ return (
337
+ name,
338
+ f"""
339
+ bool {name} = false;
340
+ {body}""",
341
+ )
342
+
343
+
344
+ # Generates the Functionalization kernel for:
345
+ # - ops that create aliases (e.g. transpose())
346
+ # - ops that are views AND mutations (e.g. transpose_())
347
+ def emit_view_functionalization_body(
348
+ g: NativeFunctionsViewGroup, *, view_inplace: bool
349
+ ) -> str:
350
+ if view_inplace:
351
+ # This op is both an inplace op AND a view op.
352
+ # See Note [Functionalization Pass - Inplace View Ops] for details.
353
+ # I currently have the view meta call into the out-of-place variant of the view, to avoid
354
+ # having to define an extra ~20 inplace {view}_inverse_ functions.
355
+ # Most view ops don't have NativeFunctionGroup's both, because we don't define out= variants for view ops.
356
+ # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
357
+ # with the same name but the trailing underscore removed.
358
+ # This is currently asserted at parse time in gen.py (see error_check_native_functions).
359
+ assert g.view_inplace is not None
360
+ f = g.view_inplace
361
+ else:
362
+ f = g.view
363
+
364
+ assert g.view_copy is not None
365
+ with native_function_manager(f):
366
+ call_sig = DispatcherSignature.from_schema(g.view_copy.func)
367
+
368
+ # the "view_copy" op name that the functionalization kernels need to call
369
+ api_name = g.view_copy.func.name.unambiguous_name()
370
+ # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
371
+ # "no-op"ing in this context is just redispatching to the original op.
372
+ noop_api_name = f.func.name.unambiguous_name()
373
+
374
+ dispatcher_sig = DispatcherSignature.from_schema(f.func)
375
+ assert_view_op_properties(f.func)
376
+ view_tensor_name = dispatcher_sig.arguments()[0].name
377
+
378
+ return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()
379
+
380
+ unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
381
+ dispatcher_sig, is_view_op=True
382
+ )
383
+ view_redispatch_args = [
384
+ e.expr
385
+ for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
386
+ ]
387
+
388
+ forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
389
+ reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
390
+
391
+ # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
392
+ meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
393
+ meta_call_args = [
394
+ e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
395
+ ]
396
+
397
+ (
398
+ symbolic_inputs_varname,
399
+ symbolic_inputs_check,
400
+ ) = emit_has_symbolic_inputs(call_sig)
401
+
402
+ if "inplace_view" in f.tags:
403
+ # See Note [Functionalization Pass - Inplace View Ops] for more details
404
+ return f"""
405
+ {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
406
+ if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
407
+ // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
408
+ {unwrap_tensor_args_str}
409
+ at::AutoDispatchSkipFunctionalize guard;
410
+ return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
411
+ }}
412
+ auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
413
+ auto inverse_return_mode = (
414
+ reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
415
+ : at::functionalization::InverseReturnMode::NeverView
416
+ );
417
+ {symbolic_inputs_check}
418
+ at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
419
+ {forward_lambda.decl()} {{
420
+ if (reapply_views) {{
421
+ return {forward_lambda.inner_call(reapply_views=True)}
422
+ }} else {{
423
+ return {forward_lambda.inner_call(reapply_views=False)}
424
+ }}
425
+ }},
426
+ {reverse_lambda.decl()} {{
427
+ return {reverse_lambda.inner_call()}
428
+ }},
429
+ /*has_symbolic_inputs=*/{symbolic_inputs_varname}
430
+ );
431
+ auto compute_reference_meta =
432
+ {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
433
+ {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
434
+ {return_type} reference_tensor_output;
435
+ if (compute_reference_meta) {{
436
+ {meta_conversion_str}
437
+ at::AutoDispatchSkipFunctionalize func_guard;
438
+ c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
439
+ reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
440
+ }}
441
+ // This function adds the above view meta to the current tensor and replays them off the base,
442
+ // mutating the size/stride info of the current FunctionalTensorWrapper.
443
+ // Because of this, we need to make sure to run the reference shape function above,
444
+ // BEFORE doing this (otherwise we'll end up runnin the reference function using the wrong sizes/strides)
445
+ at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
446
+ // See Note [Propagating strides in the functionalization pass]
447
+ // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
448
+ // on a reference implementation here (instead of relying on the output from the forward lambda
449
+ // having the correct stride info)
450
+ if (compute_reference_meta) {{
451
+ at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
452
+ }}
453
+ return {view_tensor_name};
454
+ }}
455
+ """
456
+
457
+ else:
458
+ is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
459
+ return f"""
460
+ {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
461
+ {unwrap_tensor_args_str}
462
+ if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
463
+ // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
464
+ at::AutoDispatchSkipFunctionalize guard;
465
+ return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
466
+ }}
467
+ auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
468
+ auto inverse_return_mode = (
469
+ reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
470
+ : at::functionalization::InverseReturnMode::NeverView
471
+ );
472
+ auto compute_reference_meta =
473
+ {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
474
+ {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
475
+ {return_type} reference_tensor_output;
476
+ if (compute_reference_meta) {{
477
+ {meta_conversion_str}
478
+ at::AutoDispatchSkipFunctionalize func_guard;
479
+ c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
480
+ reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
481
+ }}
482
+ {return_type} tmp_output;
483
+ {{
484
+ at::AutoDispatchSkipFunctionalize guard;
485
+ if (reapply_views) {{
486
+ tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
487
+ }} else {{
488
+ tmp_output = at::_ops::{api_name}::call({', '.join(view_redispatch_args)});
489
+ }}
490
+ }}
491
+ {symbolic_inputs_check}
492
+ at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
493
+ {forward_lambda.decl()} {{
494
+ if (reapply_views) {{
495
+ return {forward_lambda.inner_call(reapply_views=True)}
496
+ }} else {{
497
+ return {forward_lambda.inner_call(reapply_views=False)}
498
+ }}
499
+ }},
500
+ {reverse_lambda.decl()} {{
501
+ return {reverse_lambda.inner_call()}
502
+ }},
503
+ /*has_symbolic_inputs=*/{symbolic_inputs_varname},
504
+ /*is_multi_output=*/{str(is_multi_output_view).lower()},
505
+ /*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
506
+ );
507
+ auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
508
+ // See Note [Propagating strides in the functionalization pass]
509
+ if (compute_reference_meta) {{
510
+ at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
511
+ }}
512
+ return out;
513
+ }}
514
+ """
515
+
516
+
517
+ def maybe_create_output(f: NativeFunction, var_name: str) -> str:
518
+ if len(f.func.returns) == 0:
519
+ return ""
520
+ return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
521
+ return f"{return_type} {var_name} = "
522
+
523
+
524
+ # Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function,
525
+ # this returns two lists of names, consisting of:
526
+ # - the names of returns corresponding to the original (mutable) inputs of the outer function
527
+ # - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
528
+ def get_mutable_redispatch_return_names(
529
+ f: NativeFunction, inner_return_var: str
530
+ ) -> tuple[list[str], list[str]]:
531
+ aliased_returns = []
532
+ non_aliased_returns = []
533
+ for i, name in enumerate(f.func.aliased_return_names()):
534
+ if name is not None:
535
+ aliased_returns.append(name)
536
+ else:
537
+ non_aliased_returns.append(
538
+ inner_return_var
539
+ if len(f.func.returns) == 1
540
+ else f"std::get<{i}>({inner_return_var})"
541
+ )
542
+ return aliased_returns, non_aliased_returns
543
+
544
+
545
+ # When functionalization "no-op's" and redispatches on a mutable operator, we need to take care so that:
546
+ # - For fresh outputs, we return the result of the redispatch (without wrapping outputs)
547
+ # - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped)
548
+ def return_from_mutable_noop_redispatch(
549
+ f: NativeFunction, inner_return_var: str
550
+ ) -> str:
551
+ aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var)
552
+ # Just get all of the return names, and immediately return them
553
+ return return_str(f.func.returns, aliased + non_aliased)
554
+
555
+
556
+ def wrap_propagate_mutations_and_return(
557
+ f: NativeFunction, functional_op: NativeFunction, inner_return_var: str
558
+ ) -> str:
559
+ mutable_arg_names = f.func.arguments.mutable_arg_names()
560
+ (
561
+ aliased_outer_rets,
562
+ non_aliased_outer_rets,
563
+ ) = get_mutable_redispatch_return_names(f, inner_return_var)
564
+ _, non_aliased_inner_rets = get_mutable_redispatch_return_names(
565
+ functional_op, inner_return_var
566
+ )
567
+ # The outer function may have a mix of aliased and non-aliased outputs,
568
+ # But the inner functional op that we're transforming to should only have non-aliased outputs
569
+ assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(
570
+ non_aliased_inner_rets
571
+ )
572
+
573
+ # First, take all of the newly created outputs from the inner call and wrap them into functional tensors
574
+ updates = []
575
+ non_aliased_wrapped_ret_names = []
576
+ for i, inner_ret in enumerate(
577
+ non_aliased_inner_rets[: len(non_aliased_outer_rets)]
578
+ ):
579
+ ret_name = f"output_{i}"
580
+ updates.append(
581
+ f"""\
582
+ auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});"""
583
+ )
584
+ non_aliased_wrapped_ret_names.append(ret_name)
585
+
586
+ # Next, take all of the mutated outputs from the inner call corresponding to mutated inputs,
587
+ # and propagate the mutations
588
+ for outer_arg, inner_ret in zip(
589
+ mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :]
590
+ ):
591
+ updates.append(
592
+ f"""\
593
+ auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
594
+ at::functionalization::impl::replace_({outer_arg}, {inner_ret});
595
+ at::functionalization::impl::commit_update({outer_arg});
596
+ at::functionalization::impl::sync({outer_arg});
597
+ auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
598
+ at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
599
+ )
600
+
601
+ # Finally, we return:
602
+ # - Any mutable arguments that also returns
603
+ # - Any immutable returns that were created wrapping the output from the inner call
604
+ returns_str = return_str(
605
+ f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names
606
+ )
607
+ updates_str = "\n".join(updates)
608
+ return f"""\
609
+ {updates_str}
610
+ {returns_str}"""
611
+
612
+
613
+ # Generates the Functionalization kernel for:
614
+ # - mutation ops (inplace and out= ops)
615
+ @with_native_function_and
616
+ def emit_inplace_functionalization_body(
617
+ f: NativeFunction, g: NativeFunctionsGroup
618
+ ) -> str:
619
+ # mutation case
620
+ assert modifies_arguments(f)
621
+
622
+ dispatcher_sig = DispatcherSignature.from_schema(f.func)
623
+
624
+ unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
625
+ dispatcher_sig, is_view_op=False
626
+ )
627
+
628
+ mutated_names = [
629
+ a.name
630
+ for a in f.func.arguments.flat_all
631
+ if a.type.is_tensor_like() and a.annotation is not None
632
+ ]
633
+ non_mutated_names = [
634
+ a.name
635
+ for a in f.func.arguments.flat_all
636
+ if a.type.is_tensor_like() and a.annotation is None
637
+ ]
638
+ non_mutated_tensor_names = [
639
+ a.name
640
+ for a in f.func.arguments.flat_all
641
+ if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
642
+ ]
643
+ # all mutable inputs must be functional tensors in order to participate in functionalization
644
+ check_all_mutated_args_are_functional = " && ".join(
645
+ ["true"]
646
+ + [
647
+ f"at::functionalization::impl::isFunctionalTensor({a})"
648
+ for a in mutated_names
649
+ ]
650
+ )
651
+ check_any_non_mutated_args_are_functional = " || ".join(
652
+ ["false"]
653
+ + [
654
+ f"at::functionalization::impl::isFunctionalTensor({a})"
655
+ for a in non_mutated_names
656
+ ]
657
+ )
658
+
659
+ check_any_non_mutated_tensors_are_xla = " || ".join(
660
+ ["false"]
661
+ + [
662
+ f"{a}.device().type() == c10::DeviceType::XLA"
663
+ for a in non_mutated_tensor_names
664
+ ]
665
+ )
666
+ # These are used in the cases where we don't functionalize and redispatch to the inplace op
667
+ # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
668
+ # case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
669
+ inplace_exprs = [
670
+ e.expr
671
+ for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
672
+ ]
673
+
674
+ # call the out-of-place variant of the op
675
+ return_type = (
676
+ dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
677
+ )
678
+ functional_sig = DispatcherSignature.from_schema(g.functional.func)
679
+ functional_exprs = [
680
+ e.expr
681
+ for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
682
+ ]
683
+
684
+ if f.func.is_out_fn():
685
+ mutable_input_post_processing = "\n".join(
686
+ [
687
+ f"""
688
+ at::functionalization::impl::replace_(
689
+ {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
690
+ at::functionalization::impl::commit_update({a.name});"""
691
+ for (i, a) in enumerate(f.func.arguments.out)
692
+ if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
693
+ ]
694
+ )
695
+ else:
696
+ mutable_input_post_processing = "\n".join(
697
+ [
698
+ f"""
699
+ at::functionalization::impl::replace_({a.name}, tmp_output);
700
+ at::functionalization::impl::commit_update({a.name});"""
701
+ for a in f.func.arguments.flat_all
702
+ if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
703
+ ]
704
+ )
705
+
706
+ meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
707
+ # We don't want to run the inplace meta func for ops like .set_(), because:
708
+ # (1) they're unnecessary: inplace meta checks are only useful for ops like add_(),
709
+ # where broadcasting will work for the out-of-place case but should fail on the inplace call
710
+ # (2) They'll also fail without adding extra infra: we'd need to convert the input storage argument
711
+ # into a meta storage
712
+ any_storage_args = any(
713
+ a.type == BaseType(BaseTy.Storage) for a in f.func.arguments.flat_all
714
+ )
715
+
716
+ return f"""
717
+ {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
718
+ if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
719
+ // Before converting the mutable op to its functional variant, run meta tensors through the original op.
720
+ // This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
721
+ // (We can only do this for inplace ops today though, because they technically all support meta tensors).
722
+ {meta_conversion_str}
723
+ at::AutoDispatchSkipFunctionalize func_guard;
724
+ c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
725
+ at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
726
+ }}
727
+ {unwrap_tensor_args_str}
728
+ if (!({check_all_mutated_args_are_functional})) {{
729
+ // We want to disable this check if there are any XLA tensors.
730
+ // cpu_tensor.copy_(xla_tensor) is valid code.
731
+ if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
732
+ // case 1: trying to mutate a non functional tensor with a functional tensor is an error
733
+ TORCH_INTERNAL_ASSERT(false,
734
+ "mutating a non-functional tensor with a functional tensor is not allowed.",
735
+ " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
736
+ }} else {{
737
+ // case 2: arguments are not functional tensors, so we no-op and redispatch.
738
+ at::AutoDispatchSkipFunctionalize guard;
739
+ {maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
740
+ {return_from_mutable_noop_redispatch(f, 'tmp_output')}
741
+ }}
742
+ }} else {{
743
+ {return_type} tmp_output;
744
+ {{
745
+ at::AutoDispatchSkipFunctionalize guard;
746
+ tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
747
+ }}
748
+ {wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
749
+ }}
750
+ }}"""
751
+
752
+
753
+ # The below functions generate RegisterFunctionalization.cpp
754
+ # These files provide the kernels that run the functionalization pass, which can be opted into
755
+ # per backend (e.g. XLA or Vulkan), or as a composable transform (functionalize() in functorch).
756
+
757
+
758
+ # See Note [Functionalization Pass: View Inverses].
759
+ def gen_functionalization_view_inverse_declaration(
760
+ selector: SelectiveBuilder, g: NativeFunctionsViewGroup
761
+ ) -> str | None:
762
+ # For every (non-composite) view op, we need a corresponding "inverse view" function.
763
+ # This generates the declarations so we get a good compiler error when someone adds a new view.
764
+ @with_native_function
765
+ def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
766
+ if g.view.has_composite_implicit_autograd_kernel:
767
+ return None
768
+ view_inverse_sig = ViewInverseSignature(g)
769
+ return view_inverse_sig.decl()
770
+
771
+ return emit_decl_helper(g)
772
+
773
+
774
+ def gen_functionalization_registration(
775
+ selector: SelectiveBuilder,
776
+ g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
777
+ composite_implicit_autograd_index: BackendIndex,
778
+ ) -> list[str]:
779
+ @with_native_function
780
+ def emit_registration_helper(f: NativeFunction) -> str:
781
+ assert not f.has_composite_implicit_autograd_kernel
782
+ registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
783
+ return f'm.impl("{f.func.name}", {registration_str});'
784
+
785
+ # Don't generate kernels in mobile build
786
+ if not selector.include_all_operators:
787
+ return []
788
+
789
+ if isinstance(g, NativeFunctionsViewGroup):
790
+ # functionalization needs to register kernels for view + view_inplace ops
791
+ # See Note [Functionalization <> torch.Tensor constructor]
792
+ if str(g.view.func.name) == "lift_fresh":
793
+ return []
794
+ view_str = []
795
+ if not g.view.has_composite_implicit_autograd_kernel:
796
+ view_str.append(emit_registration_helper(g.view))
797
+ if (
798
+ g.view_inplace is not None
799
+ and not g.view_inplace.has_composite_implicit_autograd_kernel
800
+ ):
801
+ assert g.view_inplace.is_view_op
802
+ view_str.append(emit_registration_helper(g.view_inplace))
803
+ return view_str
804
+
805
+ elif isinstance(g, NativeFunctionsGroup):
806
+ # Gets a hand-written functionalization kernel
807
+ if g.inplace is not None and str(g.inplace.func.name) == "set_.source_Tensor":
808
+ fns = []
809
+ else:
810
+ fns = list(g.functions())
811
+ else:
812
+ if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
813
+ return []
814
+ fns = [g]
815
+
816
+ registrations = []
817
+ for f in fns:
818
+ if f.has_composite_implicit_autograd_kernel:
819
+ continue
820
+ if str(f.func.name) == "lift":
821
+ # See Note [Functionalization <> torch.Tensor constructor]
822
+ return []
823
+ if str(f.func.name) == "resize_":
824
+ # See Note [resize_ in Functionalization]
825
+ return []
826
+ if str(f.func.name.name) != "set_":
827
+ assert not f.is_view_op
828
+ # functionalization needs to generate and register kernels for inplace ops.
829
+ # We *also* need to directly register CompositeImplicitAUtograd kernels
830
+ # so that they decompose properly before functioanlization.
831
+ if modifies_arguments(f):
832
+ registrations.append(emit_registration_helper(f))
833
+ return registrations
834
+
835
+
836
+ def gen_functionalization_definition(
837
+ selector: SelectiveBuilder,
838
+ # Note: Ideally this code should never have to look at NativeFunction
839
+ # (and instead only need to operate on grouped NativeFunctions).
840
+ # The only reason currently is because we need to emit direct dispatch registrations
841
+ # For CompositeImplicitAutograd operators, which are potentially ungrouped.
842
+ g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
843
+ ) -> list[str]:
844
+ # Don't generate kernels in mobile build
845
+ if not selector.include_all_operators:
846
+ return []
847
+
848
+ if isinstance(g, NativeFunctionsViewGroup):
849
+ # Case 1: emit view -> view_copy kernels for the functionalization pass
850
+ view_defs = []
851
+ if not g.composite:
852
+ # invariant: NativeFunctionsViewGroup's always have a view_copy operator
853
+ # if the view is not composite (implicit autograd)
854
+ assert g.view_copy is not None, dataclass_repr(g, indent=1)
855
+ view_defs.append(emit_view_functionalization_body(g, view_inplace=False))
856
+ if g.view_inplace is not None:
857
+ view_defs.append(emit_view_functionalization_body(g, view_inplace=True))
858
+ return view_defs
859
+ elif isinstance(g, NativeFunction):
860
+ # Invariant: all mutable operators that we need to handle in functionalization
861
+ # should have been properly grouped up.
862
+ # TODO: The below ops all have "problematic" schemas that prevent them from
863
+ # getting functionalized. Instead of bending over backwards to get things to work,
864
+ # I think we should either:
865
+ # (1) fix their schemas (BC-breaking)
866
+ # (2) hand-write their functionalization kernels
867
+ if (
868
+ str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
869
+ and str(g.func.name.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
870
+ ):
871
+ assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g)
872
+ return []
873
+ else:
874
+ # Case 2: emit inplace -> out-of-place kernels for the functionalization pass
875
+ mutation_defs = []
876
+ mutation_defs.append(emit_inplace_functionalization_body(g.out, g))
877
+ if g.inplace is not None:
878
+ mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g))
879
+ if g.mutable is not None:
880
+ mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g))
881
+ return mutation_defs
882
+ return []
.venv/lib/python3.11/site-packages/torchgen/gen_lazy_tensor.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import os
5
+ from collections import namedtuple
6
+ from pathlib import Path
7
+ from typing import Any, Callable, Iterable, Iterator, Sequence
8
+
9
+ import yaml
10
+
11
+ import torchgen.dest as dest
12
+ from torchgen.api.lazy import setValueT
13
+ from torchgen.api.types import BaseCppType
14
+ from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
15
+ from torchgen.gen import get_grouped_native_functions, parse_native_yaml
16
+ from torchgen.gen_backend_stubs import (
17
+ error_on_missing_kernels,
18
+ gen_dispatcher_registrations,
19
+ gen_dispatchkey_nativefunc_headers,
20
+ parse_backend_yaml,
21
+ )
22
+ from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
23
+ from torchgen.selective_build.selector import SelectiveBuilder
24
+ from torchgen.utils import FileManager, NamespaceHelper
25
+ from torchgen.yaml_utils import YamlLoader
26
+
27
+
28
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
29
+ #
30
+ # Lazy Tensor Codegen
31
+ #
32
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
33
+ # Overview
34
+ # ~~~~~~~~
35
+ #
36
+ # This codegen script builds on existing data models and helpers used
37
+ # by all ATen backends, and adds new functionality specific to lazy
38
+ # tensor backends.
39
+ #
40
+ # Inputs:
41
+ # - <backend>_native_functions.yaml: controls which operators are
42
+ # supported by the backend.
43
+ #
44
+ # Outputs:
45
+ # (for all backends)
46
+ # <DispatchKey>Ir.h defines Lazy IR classes to be constructed during tracing
47
+ # - opt-in: also generate 'lowering' methods for the TorchScript backend only
48
+ # <DispatchKey>NativeFunctions.cpp defines implementations of native functions which perform lazy tracing
49
+ # - opt-in: 'full_codegen' section of backend yaml; 'supported' section omits these implementations
50
+ # <DispatchKey>NativeFunctions.h declares implementations of native functions for both 'supported' and 'full_codegen'
51
+ # ops
52
+ #
53
+ # Register<DispatchKey>.cpp registers all op implementations with the dispatcher
54
+ # RegisterAutograd<DispatchKey>.cpp registers all autograd implementations with the dispatcher
55
+ #
56
+ # Validation Helpers:
57
+ # - Shape Inference: errs if any ops in backend yaml require shape inference not provided by meta kernels or
58
+ # implementations in torch/csrc/lazy/core/shape_inference.*
59
+ # - native function impls: errs if any 'supported' ops do not have an implementation defined in the backend
60
+ # (non-codegen) implementation file
61
+ #
62
+ #
63
+ # About the Data Model
64
+ # ~~~~~~~~~~~~~~~~~~~~
65
+ #
66
+ # Modeled after ATen codegen, the first step is to parse yaml and build a data model for the operators
67
+ # we care about. In this case, the <backend>_native_functions yaml defines a subset of the core operators
68
+ # (defined in more detail in the main native_functions.yaml), which will be supported by your backend.
69
+ # Backends can list ops in two categories:
70
+ # - `supported` ops require hand-implementations but still get codegenned declarations and registrations
71
+ # - `full_codegen` ops get implementations (and IR classes) generated too
72
+ #
73
+ # Each native function is modeled as an object with a schema, and each schema has objects representing their
74
+ # arguments. Much of the codegen is manipulation of the arguments and their types. For example, lazy tensor
75
+ # backends need to transform 'at::Tensor' arguments into 'lazy::Value' objects, as well as replacing reference
76
+ # types (stringref) with actual string objects, and this is done by manipulating the data model objects.
77
+ # - see api/lazy.py for the lazy data model
78
+ #
79
+ # Once the data model is set up, the rest of this script processes a number of templates for output CPP file
80
+ # and fills in the template values using helpers in `dest/lazy_ir.py` and `dest/lazy_ts_lowering.py`. These
81
+ # helpers mostly iterate over functions and their arguments, outputting different c++ snippets.
82
+ #
83
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
84
+
85
+
86
+ # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
87
+ # Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping, full_codegen)
88
+ ParsedExternalYaml = namedtuple(
89
+ "ParsedExternalYaml",
90
+ ["backend_key", "autograd_key", "cpp_namespace", "backend_indices", "full_codegen"],
91
+ )
92
+
93
+
94
+ def parse_native_functions_keys(
95
+ backend_yaml_path: str,
96
+ grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
97
+ ) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
98
+ with open(backend_yaml_path) as f:
99
+ yaml_values = yaml.load(f, Loader=YamlLoader)
100
+ assert isinstance(yaml_values, dict)
101
+
102
+ full_codegen = yaml_values.pop("full_codegen", [])
103
+ non_native = yaml_values.pop("non_native", [])
104
+ ir_gen = yaml_values.pop("ir_gen", [])
105
+ assert isinstance(full_codegen, list)
106
+ assert isinstance(non_native, list)
107
+ assert isinstance(ir_gen, list)
108
+ full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
109
+ ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
110
+ return full_codegen_opnames, non_native, ir_gen_opnames
111
+
112
+
113
+ def validate_shape_inference_header(
114
+ shape_inference_hdr: str, expected_shape_infr_decls: list[str]
115
+ ) -> None:
116
+ try:
117
+ with open(shape_inference_hdr) as f:
118
+ shape_infr_decls = f.read()
119
+ shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
120
+ except OSError as e:
121
+ raise AssertionError(
122
+ f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
123
+ ) from e
124
+
125
+ # TODO(whc) add a check for shape inference functions that have meta kernels implement and should be retired.
126
+
127
+ missing_decls = [
128
+ decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
129
+ ]
130
+ if missing_decls:
131
+ raise Exception( # noqa: TRY002
132
+ f"""Missing shape inference function.\n
133
+ Please add declare this function in {shape_inference_hdr}:\n
134
+ and implement it in the corresponding shape_inference.cpp file.\n
135
+ {os.linesep.join(missing_decls)}"""
136
+ )
137
+
138
+
139
# Some helper functions for the codegen.
def get_ltc_helper_fns() -> str:
    """Return the C++ ``to_meta`` helper overloads that are pasted verbatim
    into the generated ``{backend}NativeFunctions.cpp`` file.

    The snippet defines three overloads (Tensor, optional<Tensor>,
    ITensorListRef) that convert inputs to the meta device so shape
    inference can run meta kernels.
    """
    helper_src = """\
at::Tensor to_meta(const at::Tensor& tensor) {
  // undefined tensors can't be converted to the meta device, since they don't have sizes/strides
  if (!tensor.defined()) return tensor;
  auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
/*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()), \
/*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
  // needs to handle wrapped numbers, so dtype promotion works properly.
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    out.unsafeGetTensorImpl()->set_wrapped_number(true);
  }
  return out;
}
std::optional<at::Tensor> to_meta(const std::optional<at::Tensor>& tensor) {
  if (tensor.has_value()) {
    return to_meta(*tensor);
  }
  return std::nullopt;
}

std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
  std::vector<at::Tensor> outs;
  outs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    outs.push_back(to_meta(tensor));
  }
  return outs;
}
"""
    return helper_src
170
+
171
+
172
class default_args:
    """Default configuration shared by the CLI (`main`) and
    `run_gen_lazy_tensor`; callers of `run_gen_lazy_tensor` may pass
    overrides for any of these values.
    """

    # C++ base class name for generated Lazy IR nodes.
    node_base: str = "Node"
    # Optional header declaring the node base class; None means no extra include.
    node_base_hdr: str | None = None
    # Header declaring the hand-written compute_shape_* functions.
    shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
    # Fully-qualified C++ lazy tensor class and the header that declares it.
    tensor_class: str = "torch::lazy::LazyTensor"
    tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
    # Generator classes used to emit IR declarations and native function definitions.
    lazy_ir_generator: type[GenLazyIR] = GenLazyIR
    native_func_definition_generator: type[
        GenLazyNativeFuncDefinition
    ] = GenLazyNativeFuncDefinition
    # Backend name threaded through to the generated registrations/headers.
    backend_name: str = "TorchScript"
183
+
184
+
185
def _parse_bool(value: str) -> bool:
    """argparse ``type=`` callback that parses common boolean spellings.

    ``type=bool`` is a classic argparse pitfall: ``bool("False")`` is True
    because any non-empty string is truthy, so ``--dry-run False`` would
    silently *enable* dry-run mode.  This converter interprets the usual
    spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # "" maps to False for backward compatibility with bool("") == False.
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def main() -> None:
    """Command-line entry point: parse arguments and invoke the Lazy Tensor
    code generator (`run_gen_lazy_tensor`).
    """
    parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
    parser.add_argument(
        "-s",
        "--source-yaml",
        "--source_yaml",
        help="path to source yaml file containing operator external definitions",
    )
    parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
    parser.add_argument(
        "--dry-run",
        "--dry_run",
        # BUG FIX: was `type=bool`, which treats any non-empty string
        # (including "False") as True; also had a copy-pasted help string.
        type=_parse_bool,
        default=False,
        help="if true, run the codegen without writing any output files",
    )
    parser.add_argument(
        "--impl-path",
        "--impl_path",
        type=str,
        default=None,
        help="path to the source C++ file containing kernel definitions",
    )
    parser.add_argument(
        "--gen-ts-lowerings",
        "--gen_ts_lowerings",
        action="store_true",
        help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
    )
    parser.add_argument(
        "--node-base",
        "--node_base",
        type=str,
        default=default_args.node_base,
        help="Name of backend specific custom Lazy IR Node base class",
    )
    parser.add_argument(
        "--node-base-hdr",
        "--node_base_hdr",
        type=str,
        default=default_args.node_base_hdr,
        help="Path to header file defining custom Lazy IR Node base class",
    )
    parser.add_argument(
        "--shape-inference-hdr",
        "--shape_inference_hdr",
        type=str,
        default=default_args.shape_inference_hdr,
        help="Path to header file defining custom Lazy shape inference functions",
    )
    parser.add_argument(
        "--tensor-class",
        "--tensor_class",
        type=str,
        default=default_args.tensor_class,
        help="Name of backend specific custom Lazy Tensor class",
    )
    parser.add_argument(
        "--tensor-class-hdr",
        "--tensor_class_hdr",
        type=str,
        default=default_args.tensor_class_hdr,
        help="Path to header file defining custom Lazy Tensor class",
    )
    parser.add_argument(
        "--backend-name",
        "--backend_name",
        type=str,
        default=default_args.backend_name,
        help="Name of the backend to generate",
    )
    options = parser.parse_args()

    # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
    # NOTE(review): three `.parent`s implies this file sits two directories
    # below the repo root -- confirm against the actual file location.
    torch_root = Path(__file__).parent.parent.parent.absolute()
    aten_path = str(torch_root / "aten" / "src" / "ATen")
    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
    if options.gen_ts_lowerings:
        lazy_ir_generator = GenTSLazyIR
    native_func_definition_generator: type[
        GenLazyNativeFuncDefinition
    ] = default_args.native_func_definition_generator

    run_gen_lazy_tensor(
        aten_path,
        options.source_yaml,
        options.output_dir,
        options.dry_run,
        options.impl_path,
        options.node_base,
        options.node_base_hdr,
        options.tensor_class,
        options.tensor_class_hdr,
        options.shape_inference_hdr,
        lazy_ir_generator,
        native_func_definition_generator,
        # BUG FIX: this was previously passed positionally, where the 13th
        # positional parameter of run_gen_lazy_tensor is `build_in_tree: bool`,
        # not `backend_name` -- the backend name string silently became a
        # truthy build_in_tree flag and the real backend_name kept its default.
        backend_name=options.backend_name,
    )
279
+
280
+
281
def run_gen_lazy_tensor(
    aten_path: str,
    source_yaml: str,
    output_dir: str,
    dry_run: bool,
    impl_path: str | None,
    node_base: str = default_args.node_base,
    node_base_hdr: str | None = default_args.node_base_hdr,
    tensor_class: str = default_args.tensor_class,
    tensor_class_hdr: str = default_args.tensor_class_hdr,
    shape_inference_hdr: str = default_args.shape_inference_hdr,
    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
    native_func_definition_generator: type[
        GenLazyNativeFuncDefinition
    ] = default_args.native_func_definition_generator,
    # build_in_tree is true for TS backend and affects include paths
    build_in_tree: bool = False,
    # per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
    # it must match how ATen was built
    per_operator_headers: bool = False,
    backend_name: str = default_args.backend_name,
    gen_forced_fallback_code: bool = False,
    use_lazy_shape: bool = True,
    # the following arguments are temporary customization points for xla backend migration.
    # do not rely on them otherwise, they should be removed once migration is complete
    backend_namespace: str = "torch::lazy",
    get_tensorlist: str = "GetTensorList",
    get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber",
    try_get_tensor: str = "TryGetLtcTensor",
    metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")',
    create_tensor: str = "LazyTensor::Create",
    create_from_first_tensor: bool = False,
    create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
    tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
    lazy_value_class: str = "torch::lazy::Value",
    lazy_tensor_ptr: str = "LazyTensorPtr",
    get_device_fn: str = "torch::lazy::GetBackendDevice",
) -> None:
    """Drive code generation for a Lazy Tensor backend.

    Parses ``native_functions.yaml``/``tags.yaml`` under *aten_path* and the
    backend's *source_yaml*, validates that every required
    ``compute_shape_{op}`` declaration exists in *shape_inference_hdr*, then
    writes (via ``FileManager`` templates into *output_dir*):
    ``{backend_key}NativeFunctions.h/.cpp``, per-dispatch-key registration
    files, ``LazyIr.h`` and ``LazyNonNativeIr.h``.

    The many string parameters after ``use_lazy_shape`` are C++ identifiers
    spliced into the generated code; they exist as temporary customization
    points for the XLA backend migration (see comments in the signature).
    """
    # Register the backend's IR value type (e.g. torch::lazy::Value) so the
    # codegen type system resolves Value references to this C++ type.
    lv_tokens = lazy_value_class.split("::")
    lv_class = lv_tokens[-1]
    lv_ns = "::".join(lv_tokens[:-1])
    setValueT(BaseCppType(lv_ns, lv_class))
    template_dir = os.path.join(aten_path, "templates")

    def make_file_manager(install_dir: str) -> FileManager:
        # All outputs share the same template dir and honor dry_run.
        return FileManager(
            install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
        )

    fm = make_file_manager(output_dir)

    # Parse the ATen operator universe, then group functional/out/inplace
    # variants together.
    native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(aten_path, "native/tags.yaml")
    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )
    grouped_native_functions = get_grouped_native_functions(native_functions)

    def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
        """
        We sort the native function because of the note in concat_map_codegen.
        TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
        """
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        return str(func.name.name)

    grouped_native_functions = sorted(
        grouped_native_functions, key=sort_native_function
    )

    # Overlay the backend's yaml on top of the grouped ATen functions.
    parsed_backend_yaml = parse_backend_yaml(
        source_yaml, grouped_native_functions, backend_indices
    )
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices
    # the following 3 keys are all processed differently
    # for full_codegen, we generate IR, kernels, etc
    # for ir_gen, we generate only IR
    # non_native is used to register kernels not declared in
    # native_functions.yaml
    full_codegen, non_native, ir_gen = parse_native_functions_keys(
        source_yaml, grouped_native_functions
    )

    def concat_map_codegen(
        func: Callable[[NativeFunction], Sequence[str]],
        xs: Iterable[NativeFunctionsGroup | NativeFunction],
        ops_list: list[OperatorName] = full_codegen,
    ) -> Iterator[str]:
        """
        We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
        only code-gen additional entries for the inplace variant for the native functions.
        """

        for x in xs:
            fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
            for f in fs:
                if f.func.name in ops_list:
                    yield from func(f)

    # No selective build for lazy codegen: keep every operator.
    selector = SelectiveBuilder.get_nop_selector()

    assert backend_key is not None
    class_name = backend_indices[backend_key].native_function_class_name()

    # If the backend supplied its kernel implementation file, cross-check that
    # every declared kernel actually has a definition.
    if impl_path is not None:
        error_on_missing_kernels(
            native_functions,
            backend_indices,
            backend_key,
            autograd_key,
            class_name,
            impl_path,
            full_codegen,
        )

    """ Validate Shape Inference Definitions

    Generated lazy native functions all perform shape inference, by first using a meta:: kernel
    if available for that op, and otherwise using a 'compute_shape_{op}' function instead. The generator
    knows the call signature for compute_shape_{op} because it matches the nativefunction (and meta::) signature,
    so it just has to check whether the op is structured and generate a call for one or the other. It's up to the dev
    to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
    the expected signature which can be copy-pasted into shape_inference.h.

    compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
    to structured kernels.

    See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
    """
    if shape_inference_hdr is not None:
        expected_shape_infr_decls = list(
            concat_map_codegen(
                dest.GenLazyShapeInferenceDefinition(
                    backend_indices[backend_key], tensor_class
                ),
                grouped_native_functions,
            )
        )

        validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
    assert class_name is not None

    # Generate nativefunction declarations
    # Note, eager registrations is set to False for the lazy TS backend as another LTC backend
    # may want to register their own lazy kernels instead of registering the TS ones.
    # The registration will lazily happen when init_ts_backend is called.
    gen_dispatchkey_nativefunc_headers(
        fm,
        class_name,
        cpp_namespace,
        backend_indices,
        grouped_native_functions,
        backend_key,
        autograd_key,
        backend_name,
    )

    # Generate Dispatcher registrations which hook up the nativefunctions
    for dispatch_key in (
        [backend_key] if autograd_key is None else [backend_key, autograd_key]
    ):
        gen_dispatcher_registrations(
            fm,
            output_dir,
            class_name,
            backend_indices,
            grouped_native_functions,
            backend_key,
            dispatch_key,
            selector,
            build_in_tree=build_in_tree,
            per_operator_headers=per_operator_headers,
            backend_name=backend_name,
            eager_registration=False,
        )

    # Generate native function impls that build IR nodes
    ns_helper = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f"{backend_key}NativeFunctions.cpp",
        "DispatchKeyNativeFunctions.cpp",
        lambda: {
            "includes": [
                f"#include <{path}>"
                for path in [
                    tensor_class_hdr,
                    shape_inference_hdr,
                    "ATen/Functions.h",
                    "ATen/native/TensorConversions.h",
                    "ATen/NativeFunctions.h",
                    "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
                    "ATen/MetaFunctions.h",
                    "ATen/Operators.h",
                    "ATen/native/CPUFallback.h",
                    "torch/csrc/lazy/core/ir_builder.h",
                    "torch/csrc/lazy/core/lazy_graph_executor.h",
                    "torch/csrc/lazy/core/metrics.h",
                    "torch/csrc/lazy/core/shape.h",
                    f"{output_dir}/{backend_key}NativeFunctions.h",
                    f"{output_dir}/LazyIr.h",
                ]
                + (
                    ["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
                    if gen_forced_fallback_code
                    else []
                )
            ],
            "helper_fns": get_ltc_helper_fns(),
            "native_functions_include": "",
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,
            "native_function_definitions": list(
                concat_map_codegen(
                    native_func_definition_generator(
                        f"{backend_key}NativeFunctions",
                        backend_indices[backend_key],
                        tensor_class,
                        gen_forced_fallback_code,
                        backend_namespace,
                        get_tensorlist,
                        get_tensor_or_wrap_number,
                        try_get_tensor,
                        metrics_counter,
                        create_tensor,
                        create_from_first_tensor,
                        create_aten_from_ltc_tensor,
                        tuple_aten_from_ltc_tensors,
                        lazy_tensor_ptr,
                        get_device_fn,
                    ),
                    grouped_native_functions,
                )
            ),
        },
    )
    # Generate IR node classes
    lazy_ir_obj = lazy_ir_generator(
        backend_indices[backend_key], backend_name, node_base, use_lazy_shape
    )

    fm.write_with_template(
        "LazyIr.h",
        "LazyIr.h",
        lambda: {
            "lazy_ir_sysinc": [
                f"#include <{path}>"
                for path in [
                    "ATen/core/Formatting.h",
                    "c10/core/ScalarType.h",
                    "torch/csrc/lazy/core/hash.h",
                    "torch/csrc/lazy/core/ir.h",
                    "torch/csrc/lazy/core/shape.h",
                    "optional",
                    "vector",
                ]
            ],
            "lazy_ir_inc": [f'#include "{node_base_hdr}"']
            if node_base_hdr is not None
            else [],
            # IR is generated for both full_codegen and ir_gen ops.
            "ir_declarations": list(
                concat_map_codegen(
                    lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen
                )
            ),
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,
        },
    )

    # Generate Non Native IR Node classes
    fm.write_with_template(
        "LazyNonNativeIr.h",
        "LazyNonNativeIr.h",
        lambda: {
            "lazy_non_native_ir_inc": [
                f"#include <{path}>"
                for path in [
                    "torch/csrc/lazy/core/ir.h",
                    "torch/csrc/lazy/core/ir_builder.h",
                    "torch/csrc/lazy/core/internal_ops/ltc_ops.h",
                    "torch/csrc/lazy/core/shape_inference.h",
                ]
                + ([node_base_hdr] if node_base_hdr else [])
                if path
            ],
            "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
                non_native, lazy_ir_obj
            ),
            "namespace_prologue": ns_helper.prologue,
            "namespace_epilogue": ns_helper.epilogue,
        },
    )
578
+
579
+
580
# Allow the generator to be invoked directly as a script.
if __name__ == "__main__":
    main()